基本信息

背景

为了处理图结构中的数据,需要将图结构中的节点进行降维(embedding),将图的结构信息融入节点的表示中。而当时的主流降维方法难以处理以前没见过的节点。即便能处理,也需要重新进行train,代价高昂。特别是对于很大的图,这更难以实现。

主要思想

一个节点的表示会受到邻居节点的影响。

因此,对于每一个节点,先将其距离为1的节点进行信息聚合。然后将聚合的结果用于影响原始节点。

然后依次是距离为2、3、4…的节点。

最终,每个节点的表示都含有整个图的信息。

实现方法

有两大类可学习参数:WkW^k矩阵、聚合函数Aggregatork\text{Aggregator}^k中的参数

note: 对于不同的k,有不同的参数

计算方式

对于给定的节点vv,其原始的特征为h0h^0

不难得到递推式:

hnk=Aggregator([huk1 for u in neighbors-of(v)])得到一阶邻接的聚合表示hvk=Activator(WkConcat(hvk1,hnk))将邻居的聚合信息与自己进行特征融合\begin{aligned} h^k_n&=\text{Aggregator}([h_u^{k-1}\text{ for u in neighbors-of(v)}]) \qquad \text{得到一阶邻接的聚合表示}\\ h^k_v&=\text{Activator}(W^k\cdot \text{Concat}(h^{k-1}_v,h^k_n))\qquad \text{将邻居的聚合信息与自己进行特征融合} \end{aligned}

为什么每次都只融合一阶邻居的信息就能得到全局信息?对于每个节点,每次特征融合都会将自身的信息流入邻居。经过k次操作,就能将自身的特征流入k hop以内的节点。

loss 函数

let zv=hvkz_v=h_v^k:

J(zv)=log(σ(zvTzu))(1)QEvnNegtive(v)log(σ(zvTzvn))(2)\begin{aligned} J(z_v)=&-\log(\sigma(z_v^T\cdot z_u))\qquad \cdots (1)\\ &-QE_{v_n\sim \text{Negtive}(v)}\log(\sigma(-z_v^T\cdot z_{v_n})) \qquad \cdots(2) \end{aligned}

损失由两部分构成:

正样本损失(1)

我们希望节点v与其邻接节点非常相似,即zvTzuz_v^T\cdot z_u越大越好。

负样本损失(2)

我们希望节点v与非邻接节点(负样本)的相似度越小越好,即zvTzvnz_v^T\cdot z_{v_n}越小越好。

但实际上不可能直接计算所有负样本(计算量过大),因此可以使用蒙特卡洛方法,随机采样Q个负样本,认为其均值能代表这个期望。

为什么要将负样本损失乘上Q?负样本本身得到的期望是比较小的,乘上Q能放大这个差异。等效于i=1Qlog(σ(zvTzvni))-\sum_{i=1}^Q \log(\sigma(-z_v^T\cdot z_{v_{ni}}))

聚合函数的选择

由于一个节点的邻居是没有顺序的,因此,聚合函数必须具有轮换不变性。

Mean

Aggregator=Mean\text{Aggregator}=\text{Mean}

即对邻接节点取均值

hnk=Mean([huk1 for u in neighbors-of(v)])h^k_n=\text{Mean}([h_u^{k-1}\text{ for u in neighbors-of(v)}])

LSTM

值得注意的是,LSTM并不满足轮换对称性(输入的是有序序列)。用在此处仅仅是为了观察LSTM是否有足够的特征提取能力。

Pooling

Aggregator()=max(f([h1,h2]T))\text{Aggregator}(\cdots)=\max(f([h_1,h_2\cdots]^T))

其中f可以是任意复杂的网络,但在此,可以先用一个简单的线性层看看效果:

Aggregator()=max(WT[h1,h2]T+b)\text{Aggregator}(\cdots)=\max(W^T[h_1,h_2\cdots]^T+b)

实验结果

Name Citation Unsup. F1 Citation Sup. F1 Reddit Unsup. F1 Reddit Sup. F1 PPI Unsup. F1 PPI Sup. F1
Random 0.206 0.206 0.043 0.042 0.396 0.396
Raw features 0.575 0.575 0.585 0.585 0.422 0.422
DeepWalk 0.565 0.565 0.324 0.324
DeepWalk + features 0.701 0.701 0.691 0.691
GraphSAGE-GCN 0.742 0.772 0.908 0.930 0.465 0.500
GraphSAGE-mean 0.778 0.820 0.897 0.950 0.486 0.598
GraphSAGE-LSTM 0.788 0.832 0.907 0.954 0.482 0.612
GraphSAGE-pool 0.798 0.839 0.892 0.948 0.502 0.600
% gain over feat. 39% 46% 55% 63% 19% 45%

可以看到,LSTM与池化的效果都很好。并且池化的参数量更少,池化是更优的方法。

推理

在实际使用中,一般不需要将kk设为最长路径,甚至取值1,2都能取得很好的效果。并且,由于一个节点的度可能很大,因此,这个过程也是随机采样的。

在训练阶段,已经算出了WkW^k矩阵以及聚合函数Aggregatork\text{Aggregator}^k中的参数。

对于新的节点,通过递推式即能得到其特征表示。