图神经网络-GNN
本文最后更新于 2025年2月14日 下午
图神经网络(GNN)数学公式解析
图神经网络(Graph Neural Networks, GNN)是一类专门用于处理图结构数据的深度学习模型。GNN 的核心思想是信息传递(Message Passing):通过迭代地在节点之间传递和聚合邻居的信息,逐步更新节点的表示,最终用于节点分类、链接预测或图分类等任务。
1. 图与初始表示
设图为:
$$
\mathcal{G} = (\mathcal{V}, \mathcal{E})
$$
其中 $\mathcal{V}$ 是节点集合,$\mathcal{E}$ 是边集合。对于每个节点 $v \in \mathcal{V}$,我们有一个初始特征向量:
$$
\mathbf{x}_v \in \mathbb{R}^d
$$
节点的初始隐藏状态:
$$
\mathbf{h}_v^{(0)} = \mathbf{x}_v
$$
2. 信息传递机制
GNN 的核心步骤是信息传递(Message Passing),一般在每一层(或迭代步)进行。第 $k$ 层时,每个节点 $v$ 的隐藏状态 $\mathbf{h}_v^{(k)}$ 根据自身以及邻居节点的信息更新。一般来说,该过程可以拆分为三个步骤:
消息计算(Message)
$$ \mathbf{m}_{uv}^{(k)} = \text{MESSAGE}^{(k)}\big( \mathbf{h}_u^{(k-1)}, \mathbf{h}_v^{(k-1)}, \mathbf{e}_{uv} \big) $$
对于每条边 $(u, v) \in \mathcal{E}$,计算从节点 $u$ 到节点 $v$ 的消息:其中 $\mathbf{e}_{uv}$ 表示边的特征(如果有)。
消息聚合(Aggregate)
$$ \mathbf{m}_v^{(k)} = \text{AGGREGATE}^{(k)}\left( \left\{ \mathbf{m}_{uv}^{(k)} : u \in \mathcal{N}(v) \right\} \right) $$
将节点 $v$ 所有来自邻居的消息聚合成一个综合信息:这里 $\mathcal{N}(v)$ 表示节点 $v$ 的邻居集合。常用的聚合函数有求和(sum)、平均(mean)和最大值(max)。
状态更新(Update)
根据当前节点的表示和聚合的信息更新节点状态:
$$
\mathbf{h}_v^{(k)} = \text{UPDATE}^{(k)}\left( \mathbf{h}_v^{(k-1)}, \mathbf{m}_v^{(k)} \right)
$$
综合上述步骤,节点 $v$ 的更新可以简写为:
$$ \mathbf{h}_v^{(k)} = \text{UPDATE}^{(k)}\left( \mathbf{h}_v^{(k-1)},\, \text{AGGREGATE}^{(k)}\Big( \big\{ \text{MESSAGE}^{(k)}\big( \mathbf{h}_u^{(k-1)},\, \mathbf{h}_v^{(k-1)},\, \mathbf{e}_{uv} \big) : u \in \mathcal{N}(v) \big\} \Big) \right) $$经过 $K$ 层的迭代后,每个节点获得了包含更丰富结构信息的表示 $\mathbf{h}_v^{(K)}$。
3. Readout 层
在 GNN 任务中,我们通常面临两种主要的预测任务:
- 节点级任务(Node-Level Tasks):例如节点分类、节点回归,直接使用最终的节点表示 $\mathbf{h}_v^{(K)}$ 作为特征输入到后续的分类器或回归模型。
- 图级任务(Graph-Level Tasks):例如分子分类、社交网络分析等,需要将整个图的信息汇总成一个固定长度的向量 $\mathbf{y}$,然后进行分类或回归。这一过程被称为 Readout。
1. Readout 的数学定义
Readout 层的目标是将所有节点的最终表示聚合成一个全局图级表示 $ \mathbf{y}$:
$$ \mathbf{y} = \text{READOUT}\Big( \big\{ \mathbf{h}_v^{(K)} : v \in \mathcal{V} \big\} \Big) $$其中,READOUT 函数需要满足 排列不变性(Permutation Invariance),即不受节点顺序的影响。这与 GNN 本身的特性一致,因为图的结构不依赖于节点的排列顺序。
2. 常见的 Readout 方法
(1) 全局池化(Global Pooling)
最常见的 Readout 操作是使用池化(Pooling)方法对所有节点表示进行聚合,主要包括:
- 求和池化(Sum Pooling)
- 平均池化(Mean Pooling)
- 最大池化(Max Pooling)
求和池化
$$
\mathbf{y} = \sum_{v \in \mathcal{V}} \mathbf{h}_v^{(K)}
$$
这种方法适用于节点个数不同的图,并且保留了所有节点的信息。但是,如果节点数目变化较大,可能导致数值尺度不稳定。平均池化
$$
\mathbf{y} = \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \mathbf{h}_v^{(K)}
$$
这种方法可以缓解求和池化的尺度问题,使得不同大小的图得到的向量具有相似的尺度。最大池化
$$
\mathbf{y} = \max_{v \in \mathcal{V}} \mathbf{h}_v^{(K)}
$$
这种方法选取每个维度上最大的值,适用于强调局部重要性信息的场景,但可能会丢失部分全局信息。
(2) 注意力加权池化(Attention Pooling)
全局池化方法对所有节点一视同仁,但在实际应用中,某些节点比其他节点更重要。因此,我们可以引入注意力机制,为不同节点分配不同的权重:
$$ \alpha_v = \frac{\exp\left( \mathbf{w}^T \mathbf{h}_v^{(K)} \right)}{\sum_{u \in \mathcal{V}} \exp\left( \mathbf{w}^T \mathbf{h}_u^{(K)} \right)} $$$$
\mathbf{y} = \sum_{v \in \mathcal{V}} \alpha_v \mathbf{h}_v^{(K)}
$$
其中:
- $\mathbf{w}$ 是一个可学习的向量参数。
- $\alpha_v$ 是归一化的注意力权重(类似于 softmax)。
- 通过训练,模型可以自动学习哪些节点的重要性较高。
(3) 递归池化(Set2Set)
Set2Set(Vinyals et al., 2015)是一种用于集合数据的聚合方法,基于递归神经网络(RNN)和注意力机制:
$$ \mathbf{q}_t = \text{LSTM}(\mathbf{q}_{t-1}) $$ $$ \mathbf{y}_t = \sum_{v \in \mathcal{V}} \alpha_v^{(t)} \mathbf{h}_v^{(K)} $$ $$ \mathbf{y} = \text{Concat}(\mathbf{y}_T, \mathbf{y}_{T-1}, \dots) $$其中:
- $\mathbf{q}_t$ 是一个动态查询向量。
- LSTM 用于动态更新 $\mathbf{q}_t$。
- 通过多步计算,Set2Set 可以捕获更复杂的全局信息。
3. Readout 在不同任务中的应用
任务类型 | Readout 方法 | 适用场景 |
---|---|---|
节点分类 | 无需 Readout | 直接用 $\mathbf{h}_v^{(K)}$ |
图分类 | Sum, Mean, Max Pooling | 一般场景 |
大规模图分类 | Attention Pooling | 关键节点影响较大 |
复杂结构图 | Set2Set, Transformer Readout | 需要更复杂的全局信息 |