不管是GraphSAGE,还是GCN,它们的核心思想其实与朴素的GNN思想一致,都是每个节点根据图的连接结构,通过聚合邻居信息来更新自身节点的信息,再把更新后的节点向量送入神经网络层做进一步的学习或预测。GraphSAGE与GCN的区别有两个:一个是结合节点自身信息的方式不同,第二个是泛化聚合邻居点时所采用的聚合函数不同。
GraphSAGE节点信息的更新过程主要分三步:
(1)聚合邻居节点的信息,这个聚合函数有三种;
(2)将聚合后的信息与自身的节点信息进行拼接;
(3)送入神经网络模型中进行映射,得到更新后的节点信息。
举例:图数据如下图所示,现在使用GraphSAGE对节点1进行更新。
![图片[1]-Graph Sample and Aggregate Network(GraphSAGE)-点头深度学习网站](https://venusai-1311496010.cos.ap-beijing.myqcloud.com/wp-content/upload-images/2024/03/20240307194121473.png)
(1) 聚合邻居节点: \(h_{N(1)}^1 \leftarrow \text{AGGREGATE}\left(h_3^0, h_4^0, h_5^0, h_6^0\right)\);
(2) 拼接自身信息: \(h_1^1 \leftarrow \text{CONCAT}\left(h_1^0, h_{\mathcal{N}(1)}^0\right)\);
(3) 经过神经网络映射: \(h_1^1 \leftarrow \sigma\left(\boldsymbol{W}^1 \cdot \text{CONCAT}\left(h_1^0, h_{N(1)}^0\right)\right)\) 。
假设聚合函数 AGGREGATE 是 Mean 函数, 则代数得:
$$
h_{N(1)}^1 \leftarrow \text{AGGREGATE}\left(h_3^0, h_4^0, h_5^0, h_6^0\right)=\text{Mean}([0.3,0.4],[0.2,0.2],[0.7,0.8],[0.5,0.6])
$$
另外,在这个计算流程中有两个地方需要额外注意。第一,GraphSAGE在聚合某节点邻居信息的时候,并不是聚合全部的邻居,而是聚合K个邻居,K是一个超参数。举例,在图9-45中,若K等于3,则在聚合节点1的周围邻居时,随机从节点3、4、5、6中选择3个进行聚合。若K等于5,则除了选择节点1的周围4个邻居以外,再重复从这4个邻居中抽样一个节点。这样做的好处是,当图数据非常庞大时,选取某节点的全部邻居做聚合是非常耗时耗力的,若只选择其中的K个邻居,可以更快的进行计算。超参数K本质上是计算精度和计算速度之间的一种权衡。
第二个需要注意的是GraphSAGE定义了三种不同的聚合函数:
(1) Mean: \(A G G=\sum_{u \in N(v)} \frac{h_u^{(l)}}{|N(v)|}\)
(2) Pool: \(A G G=\gamma\left(\left\{\mathrm{MLP}\left(h_u^{(l)}\right), \forall u \in N(v)\right\}\right)\)
(3) \(\text{LSTM}: \quad A G G=\text{LSTM}\left(\left[h_u^{(l)}, \forall u \in \pi(N(v))\right]\right)\)
Mean操作就是简单的对节点的邻居信息做平均。Pool操作就是先把节点的邻居节点向量送入一个MLP中,对MLP的输出结果做 γ 操作得到聚合后的节点向量,这个 γ 就是池化的算子,可以是mean,也可以是max,分别对应平均池化和最大池化,它们的实际使用效果都差不多。至于第三种LSTM的聚合方式与第二种Pool聚合类似,区别在于把MLP换成了LSTM模型。
通过这些设计,GraphSAGE提供了一种灵活且高效的方法来学习图数据中节点的表示,特别适用于处理大规模和动态变化的图。
暂无评论内容