图神经网络-GNN

本文最后更新于 2025年2月14日 下午

图神经网络(GNN)数学公式解析

图神经网络(Graph Neural Networks, GNN)是一类专门用于处理图结构数据的深度学习模型。GNN 的核心思想是信息传递(Message Passing):通过迭代地在节点之间传递和聚合邻居的信息,逐步更新节点的表示,最终用于节点分类、链接预测或图分类等任务。


1. 图与初始表示

设图为:
$$
\mathcal{G} = (\mathcal{V}, \mathcal{E})
$$
其中 $\mathcal{V}$ 是节点集合,$\mathcal{E}$ 是边集合。对于每个节点 $v \in \mathcal{V}$,我们有一个初始特征向量:
$$
\mathbf{x}_v \in \mathbb{R}^d
$$
节点的初始隐藏状态:
$$
\mathbf{h}_v^{(0)} = \mathbf{x}_v
$$


2. 信息传递机制

GNN 的核心步骤是信息传递(Message Passing),一般在每一层(或迭代步)进行。第 $k$ 层时,每个节点 $v$ 的隐藏状态 $\mathbf{h}_v^{(k)}$ 根据自身以及邻居节点的信息更新。一般来说,该过程可以拆分为三个步骤:

  1. 消息计算(Message)
    对于每条边 $(u, v) \in \mathcal{E}$,计算从节点 $u$ 到节点 $v$ 的消息:

    $$ \mathbf{m}_{uv}^{(k)} = \text{MESSAGE}^{(k)}\big( \mathbf{h}_u^{(k-1)}, \mathbf{h}_v^{(k-1)}, \mathbf{e}_{uv} \big) $$

    其中 $\mathbf{e}_{uv}$ 表示边的特征(如果有)。

  2. 消息聚合(Aggregate)
    将节点 $v$ 所有来自邻居的消息聚合成一个综合信息:

    $$ \mathbf{m}_v^{(k)} = \text{AGGREGATE}^{(k)}\left( \left\{ \mathbf{m}_{uv}^{(k)} : u \in \mathcal{N}(v) \right\} \right) $$

    这里 $\mathcal{N}(v)$ 表示节点 $v$ 的邻居集合。常用的聚合函数有求和(sum)、平均(mean)和最大值(max)。

  3. 状态更新(Update)
    根据当前节点的表示和聚合的信息更新节点状态:
    $$
    \mathbf{h}_v^{(k)} = \text{UPDATE}^{(k)}\left( \mathbf{h}_v^{(k-1)}, \mathbf{m}_v^{(k)} \right)
    $$

综合上述步骤,节点 $v$ 的更新可以简写为:

$$ \mathbf{h}_v^{(k)} = \text{UPDATE}^{(k)}\left( \mathbf{h}_v^{(k-1)},\, \text{AGGREGATE}^{(k)}\Big( \big\{ \text{MESSAGE}^{(k)}\big( \mathbf{h}_u^{(k-1)},\, \mathbf{h}_v^{(k-1)},\, \mathbf{e}_{uv} \big) : u \in \mathcal{N}(v) \big\} \Big) \right) $$

经过 $K$ 层的迭代后,每个节点获得了包含更丰富结构信息的表示 $\mathbf{h}_v^{(K)}$。

3. Readout 层

在 GNN 任务中,我们通常面临两种主要的预测任务:

  1. 节点级任务(Node-Level Tasks):例如节点分类、节点回归,直接使用最终的节点表示 $\mathbf{h}_v^{(K)}$ 作为特征输入到后续的分类器或回归模型。
  2. 图级任务(Graph-Level Tasks):例如分子分类、社交网络分析等,需要将整个图的信息汇总成一个固定长度的向量 $\mathbf{y}$,然后进行分类或回归。这一过程被称为 Readout

1. Readout 的数学定义

Readout 层的目标是将所有节点的最终表示聚合成一个全局图级表示 $ \mathbf{y}$:

$$ \mathbf{y} = \text{READOUT}\Big( \big\{ \mathbf{h}_v^{(K)} : v \in \mathcal{V} \big\} \Big) $$

其中,READOUT 函数需要满足 排列不变性(Permutation Invariance),即不受节点顺序的影响。这与 GNN 本身的特性一致,因为图的结构不依赖于节点的排列顺序。


2. 常见的 Readout 方法

(1) 全局池化(Global Pooling)

最常见的 Readout 操作是使用池化(Pooling)方法对所有节点表示进行聚合,主要包括:

  • 求和池化(Sum Pooling)
  • 平均池化(Mean Pooling)
  • 最大池化(Max Pooling)
  1. 求和池化
    $$
    \mathbf{y} = \sum_{v \in \mathcal{V}} \mathbf{h}_v^{(K)}
    $$
    这种方法适用于节点个数不同的图,并且保留了所有节点的信息。但是,如果节点数目变化较大,可能导致数值尺度不稳定。

  2. 平均池化
    $$
    \mathbf{y} = \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \mathbf{h}_v^{(K)}
    $$
    这种方法可以缓解求和池化的尺度问题,使得不同大小的图得到的向量具有相似的尺度。

  3. 最大池化
    $$
    \mathbf{y} = \max_{v \in \mathcal{V}} \mathbf{h}_v^{(K)}
    $$
    这种方法选取每个维度上最大的值,适用于强调局部重要性信息的场景,但可能会丢失部分全局信息。


(2) 注意力加权池化(Attention Pooling)

全局池化方法对所有节点一视同仁,但在实际应用中,某些节点比其他节点更重要。因此,我们可以引入注意力机制,为不同节点分配不同的权重:

$$ \alpha_v = \frac{\exp\left( \mathbf{w}^T \mathbf{h}_v^{(K)} \right)}{\sum_{u \in \mathcal{V}} \exp\left( \mathbf{w}^T \mathbf{h}_u^{(K)} \right)} $$

$$
\mathbf{y} = \sum_{v \in \mathcal{V}} \alpha_v \mathbf{h}_v^{(K)}
$$

其中:

  • $\mathbf{w}$ 是一个可学习的向量参数。
  • $\alpha_v$ 是归一化的注意力权重(类似于 softmax)。
  • 通过训练,模型可以自动学习哪些节点的重要性较高。

(3) 递归池化(Set2Set)

Set2Set(Vinyals et al., 2015)是一种用于集合数据的聚合方法,基于递归神经网络(RNN)和注意力机制:

$$ \mathbf{q}_t = \text{LSTM}(\mathbf{q}_{t-1}) $$ $$ \mathbf{y}_t = \sum_{v \in \mathcal{V}} \alpha_v^{(t)} \mathbf{h}_v^{(K)} $$ $$ \mathbf{y} = \text{Concat}(\mathbf{y}_T, \mathbf{y}_{T-1}, \dots) $$

其中:

  • $\mathbf{q}_t$ 是一个动态查询向量。
  • LSTM 用于动态更新 $\mathbf{q}_t$。
  • 通过多步计算,Set2Set 可以捕获更复杂的全局信息。

3. Readout 在不同任务中的应用

任务类型Readout 方法适用场景
节点分类无需 Readout直接用 $\mathbf{h}_v^{(K)}$
图分类Sum, Mean, Max Pooling一般场景
大规模图分类Attention Pooling关键节点影响较大
复杂结构图Set2Set, Transformer Readout需要更复杂的全局信息

图神经网络-GNN
https://jimes.cn/2025/02/14/GNN/
作者
Jimes
发布于
2025年2月14日
许可协议