【爆火的图神经网络模型】GCN/GraphSAGE/GAT
Graph Network1. GCN节点特征的更新公式:H(l+1)=σ(D~−12A~D~−12H(l)W(l))H^{(l+1)}=\sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)}\right)H(l+1)=σ(D~−21A~D~−21H(l)W(l))其中 D~
Graph Network
1. GCN
-
节点特征的更新公式:
H(l+1)=σ(D~−12A~D~−12H(l)W(l)) H^{(l+1)}=\sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)}\right) H(l+1)=σ(D~−21A~D~−21H(l)W(l))
其中 D~\tilde{D}D~相当于加上了自环后的度矩阵,A~\tilde{A}A~相当于加上了自环后的邻接矩阵 -
半监督的节点分类任务:
-
网络架构:
-
模型公式:
Z=f(X,A)=softmax(A^ReLU(A^XW(0))W(1)) Z=f(X, A)=\operatorname{softmax}\left(\hat{A} \operatorname{ReLU}\left(\hat{A} X W^{(0)}\right) W^{(1)}\right) Z=f(X,A)=softmax(A^ReLU(A^XW(0))W(1))
其中A^=D~−12A~D~−12\hat{A}=\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}}A^=D~−21A~D~−21,W(0)∈RC×HW^{(0)}\in R^{C \times H}W(0)∈RC×H,W(1)∈RH×FW^{(1)}\in R^{H \times F}W(1)∈RH×F维度变换:(N,C) * (C,H) = (N,H) (N,H) * (H,F) = (N,F)
最后一层GCN的隐层特征数等于类别数,直接使用softmax输出概率
-
损失函数:
L=−∑l∈YL∑f=1FYlflnZlf \mathcal{L}=-\sum_{l \in \mathcal{Y}_{L}} \sum_{f=1}^{F} Y_{l f} \ln Z_{l f} L=−l∈YL∑f=1∑FYlflnZlf其中 YL\mathcal{Y}_{L}YL表示有标签的节点索引的集合; 例如,当节点1属于第5类时,那么Y15Y_{15}Y15等于1,其余的Y1?Y_{1?}Y1?都等于0
-
数据集介绍:
- Cora数据集:
- 其中包含两个文件,cora.cites 和 cora.content
- cora.cites
- 每行都由两个编号组成。
- 例如:“论文1 论文2” 表示 论文2 -> 论文1,就是说论文2引用了论文1
- 共有5429行,表示有5429条边
- cora.content
- 每行第一个条目为 论文编号,中间条目为二进制数据,其中1表示该单词在这篇论文中出现,最后一个条目为 论文类别(共7个类别)
- 中间条目也就是代表了论文的feature,为1433维
- 共有 2708行,表示有2708个节点
- cora.cites
- 其中包含两个文件,cora.cites 和 cora.content
- Cora数据集:
-
关键代码如下:
地址:https://github.com/tkipf/pygcn
class GraphConvolution(Module): def __init__(self, in_features, out_features, bias=True): super(GraphConvolution, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = Parameter(torch.FloatTensor(in_features, out_features)) if bias: self.bias = Parameter(torch.FloatTensor(out_features)) else: self.register_parameter('bias', None) self.reset_parameters() # 参数初始化 def reset_parameters(self): stdv = 1. / math.sqrt(self.weight.size(1)) self.weight.data.uniform_(-stdv, stdv) # 随机化参数 if self.bias is not None: self.bias.data.uniform_(-stdv, stdv) def forward(self, input, adj): support = torch.mm(input, self.weight) # X * W output = torch.spmm(adj, support) # 稀疏矩阵的相乘,和torch.mm一样的效果 A*X*W if self.bias is not None: return output + self.bias else: return output class GCN(nn.Module): def __init__(self, nfeat, nhid, nclass, dropout): super(GCN, self).__init__() self.gc1 = GraphConvolution(nfeat, nhid) # 构建第一层 GCN self.gc2 = GraphConvolution(nhid, nclass) # 构建第二层 GCN self.dropout = dropout def forward(self, x, adj): x = F.relu(self.gc1(x, adj)) # 第一层输出 x = F.dropout(x, self.dropout) x = self.gc2(x, adj) # 第二层输出 2708*7 return F.log_softmax(x, dim=1)
-
2. GraphSAGE
-
和GCN的不同点:
- GCN是直推式学习模型,也就是说只能在一张固定的图上进行表示学习,这种模型既不能够对那些在训练中未见的节点进行有效的向量表示,也不能够跨图进行节点表示学习。
- GraphSage是归纳式学习模型
-
核心思想:通过学习一个对邻居节点进行聚合表示的函数来产生中心节点的特征表示,而不是学习节点本身的embedding。
-
模型框架:
-
嵌入生成(前向传播)的算法:
-
细节部分:
-
邻居采样:
- 采用的是固定大小的采样方式。
- 策略:如果邻居数量多于固定值,那么随机不重复采样固定值的邻居;相反,如果邻居数量少于固定值,那么随机重复采样固定值的邻居。
- 实验效果最好的是:第一次采样的数目和第二次采样的数目乘积小于等于500
-
聚合方式
-
Mean aggregator:
-
公式:
hvk←σ(W⋅MEAN({hvk−1}∪{huk−1,∀u∈N(v)}) \mathbf{h}_{v}^{k} \leftarrow \sigma\left(\mathbf{W} \cdot \operatorname{MEAN}\left(\left\{\mathbf{h}_{v}^{k-1}\right\} \cup\left\{\mathbf{h}_{u}^{k-1}, \forall u \in \mathcal{N}(v)\right\}\right)\right. hvk←σ(W⋅MEAN({hvk−1}∪{huk−1,∀u∈N(v)}) -
就是对中心节点的邻居节点的特征向量进行求均值操作,然后和中心节点特征向量进行拼接。
-
-
LSTM aggregator
- 将中心节点的邻居节点随机打乱作为输入序列,将所得向量表示与中心节点的向量表示分别经过非线性变换后拼接得到中心节点在该层的向量表示。LSTM本身是用于序列数据,而邻居节点没有明显的序列关系,因此输入到LSTM中的邻居节点需要随机打乱顺序。
-
Pooling aggregator
-
公式:
AGGREGATE kpool =max({σ(Wpool huik+b),∀ui∈N(v)}) \text { AGGREGATE }_{k}^{\text {pool }}=\max \left(\left\{\sigma\left(\mathbf{W}_{\text {pool }} \mathbf{h}_{u_{i}}^{k}+\mathbf{b}\right), \forall u_{i} \in \mathcal{N}(v)\right\}\right) AGGREGATE kpool =max({σ(Wpool huik+b),∀ui∈N(v)}) -
先对中心节点的邻居节点表示向量进行一次非线性变换,然后对变换后的邻居表示向量进行池化操作(mean pooling或者max pooling),最后将pooling所得结果与中心节点的特征表示分别进行非线性变换,并将所得结果进行拼接或者相加从而得到中心节点在该层的向量表示。
-
-
-
-
-
损失函数:
- 为了在完全无监督的设置下学习表示,我们将基于图的损失函数应用于输出表示zu\mathbf{z}_uzu
JG(zu)=−log(σ(zu⊤zv))−Q⋅Evn∼Pn(v)log(σ(−zu⊤zvn)) J_{\mathcal{G}}\left(\mathbf{z}_{u}\right)=-\log \left(\sigma\left(\mathbf{z}_{u}^{\top} \mathbf{z}_{v}\right)\right)-Q \cdot \mathbb{E}_{v_{n} \sim P_{n}(v)} \log \left(\sigma\left(-\mathbf{z}_{u}^{\top} \mathbf{z}_{v_{n}}\right)\right) JG(zu)=−log(σ(zu⊤zv))−Q⋅Evn∼Pn(v)log(σ(−zu⊤zvn))
其中zu\mathbf{z}_uzu是uuu输出的表示,vvv是在固定长度随机游走时在uuu附近共同出现的节点,PnP_nPn是一个负采样的分布,QQQ表示负样本的数量
-
代码地址:https://github.com/twjiang/graphSAGE-pytorch
3. GAT
-
简介:利用注意力机制来对邻居节点特征加权求和,从而聚合邻域信息,完全摆脱了图结构的束缚,是一种归纳式学习模型。
-
模型架构:
-
计算注意力系数:
- 公式:
αij=exp( LeakyReLU (a→T[Wh⃗i∥Wh⃗j]))∑k∈Niexp(LeakyReLU(a→T[Wh⃗i∥Wh⃗k])) \alpha_{i j}=\frac{\exp \left(\text { LeakyReLU }\left(\overrightarrow{\mathbf{a}}^{T}\left[\mathbf{W} \vec{h}_{i} \| \mathbf{W} \vec{h}_{j}\right]\right)\right)}{\sum_{k \in \mathcal{N}_{i}} \exp \left(\operatorname{LeakyReLU}\left(\overrightarrow{\mathbf{a}}^{T}\left[\mathbf{W} \vec{h}_{i} \| \mathbf{W} \vec{h}_{k}\right]\right)\right)} αij=∑k∈Niexp(LeakyReLU(aT[Whi∥Whk]))exp( LeakyReLU (aT[Whi∥Whj]))
- 公式:
-
聚合节点表示:
- 公式:
h⃗i′=σ(∑j∈NiαijWh⃗j) \vec{h}_{i}^{\prime}=\sigma\left(\sum_{j \in \mathcal{N}_{i}} \alpha_{i j} \mathbf{W} \vec{h}_{j}\right) hi′=σ⎝⎛j∈Ni∑αijWhj⎠⎞
- 公式:
-
-
为了稳定自注意力的学习过程,GAT也采用多头注意力机制(multi-head attention)来捕获邻居节点在不同的方面对中心节点影响力的强弱,将K 个head分别提取的节点特征表示进行拼接作为最终的节点表示:
h⃗i′=∥k=1Kσ(∑j∈NiαijkWkh⃗j) \vec{h}_{i}^{\prime}=\|_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}_{i}} \alpha_{i j}^{k} \mathbf{W}^{k} \vec{h}_{j}\right) hi′=∥k=1Kσ⎝⎛j∈Ni∑αijkWkhj⎠⎞ -
代码地址:https://github.com/Diego999/pyGAT
更多推荐
所有评论(0)