论文阅读:GraphSAGE
基本信息
- 论文题目:Inductive Representation Learning on Large Graphs
- 日期: Sep 2018
- 链接:https://arxiv.org/abs/1706.02216
背景
为了处理图结构中的数据,需要将图结构中的节点进行降维(embedding),将图的结构信息融入节点的表示中。而当时的主流降维方法难以处理以前没见过的节点。即便能处理,也需要重新进行train,代价高昂。特别是对于很大的图,这更难以实现。
主要思想
一个节点的表示会受到邻居节点的影响。
因此,对于每一个节点,先将其距离为1的节点进行信息聚合。然后将聚合的结果用于影响原始节点。
然后依次是距离为2、3、4…的节点。
最终,每个节点的表示都含有整个图的信息。
实现方法
有两大类可学习参数:矩阵、聚合函数中的参数
note: 对于不同的k,有不同的参数
计算方式
对于给定的节点,其原始的特征为
不难得到递推式:
为什么每次都只融合一阶邻居的信息就能得到全局信息?对于每个节点,每次特征融合都会将自身的信息流入邻居。经过k次操作,就能将自身的特征流入k hop以内的节点。
loss 函数
let :
损失由两部分构成:
正样本损失(1)
我们希望节点v与其邻接节点非常相似,即越大越好。
负样本损失(2)
我们希望节点v与非邻接节点(负样本)的相似度越小越好,即越小越好。
但实际上不可能直接计算所有负样本(计算量过大),因此可以使用蒙特卡洛方法,随机采样Q个负样本,认为其均值能代表这个期望。
为什么要将负样本损失乘上Q?负样本本身得到的期望是比较小的,乘上Q能放大这个差异。等效于。
聚合函数的选择
由于一个节点的邻居是没有顺序的,因此,聚合函数必须具有轮换不变性。
Mean
即对邻接节点取均值
LSTM
值得注意的是,LSTM并不满足轮换对称性(输入的是有序序列)。用在此处仅仅是为了观察LSTM是否有足够的特征提取能力。
Pooling
其中f可以是任意复杂的网络,但在此,可以先用一个简单的线性层看看效果:
实验结果
Name | Citation Unsup. F1 | Citation Sup. F1 | Reddit Unsup. F1 | Reddit Sup. F1 | PPI Unsup. F1 | PPI Sup. F1 |
---|---|---|---|---|---|---|
Random | 0.206 | 0.206 | 0.043 | 0.042 | 0.396 | 0.396 |
Raw features | 0.575 | 0.575 | 0.585 | 0.585 | 0.422 | 0.422 |
DeepWalk | 0.565 | 0.565 | 0.324 | 0.324 | — | — |
DeepWalk + features | 0.701 | 0.701 | 0.691 | 0.691 | — | — |
GraphSAGE-GCN | 0.742 | 0.772 | 0.908 | 0.930 | 0.465 | 0.500 |
GraphSAGE-mean | 0.778 | 0.820 | 0.897 | 0.950 | 0.486 | 0.598 |
GraphSAGE-LSTM | 0.788 | 0.832 | 0.907 | 0.954 | 0.482 | 0.612 |
GraphSAGE-pool | 0.798 | 0.839 | 0.892 | 0.948 | 0.502 | 0.600 |
% gain over feat. | 39% | 46% | 55% | 63% | 19% | 45% |
可以看到,LSTM与池化的效果都很好。并且池化的参数量更少,池化是更优的方法。
推理
在实际使用中,一般不需要将设为最长路径,甚至取值1,2都能取得很好的效果。并且,由于一个节点的度可能很大,因此,这个过程也是随机采样的。
在训练阶段,已经算出了矩阵以及聚合函数中的参数。
对于新的节点,通过递推式即能得到其特征表示。