GNN

【GNN】- 2.实战 QM7b 数据集

Posted by Orchid on March 8, 2025

数据集介绍

QM7b是一个常用的分子图数据集,包含7211个分子图。可以通过 from torch_geometric.datasets import QM7b 直接获取。

  • 图结构:每个分子图的节点代表原子,边代表原子之间的化学键。
  • 节点特征:在 torch_geometric.datasets 提供的 QM7b 中,Node feature 为 None。
  • 边特征:在 torch_geometric.datasets 提供的 QM7b 中,Edge feature 为一个一维特征。
  • 目标值:在 torch_geometric.datasets 提供的 QM7b 中,Target 是一个 14 维向量。所以这是一个回归任务

代码介绍

这一小节将详细的介绍每一步代码,及其背后的逻辑和思路。

  1. 引用库

    1
    2
    3
    4
    5
    6
    7
    8
    9
    
    import torch
    import torch.nn.functional as F
    from torch_geometric.datasets import QM7b
    from torch_geometric.loader import DataLoader
    from torch_geometric.nn import MessagePassing, global_add_pool
    from torch_geometric.utils import add_self_loops
    from torch.nn import Linear, Sequential, ReLU
    from torch.optim import Adam
    from torch.nn import MSELoss
    
  2. 确定 GNN 的训练设备(CPU 还是 GPU)

    1
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
  3. 加载数据集,并查看数据集相关参数

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    
    dataset = QM7b(root='data/QM7b')	# 如果是第一次运行,将从官方下载数据集到 data/QM7b 路径下
       
    # Print dataset information
    print('Number of graphs:', len(dataset))
    print('Number of node features:', dataset.num_node_features)
    print('Number of edge features:', dataset.num_edge_features)
    print('Number of classes:', dataset.num_classes)
       
    # Print the first graph in the dataset
    data = dataset[0]
    print(data)
    print(f"Node features: {data.x}")
    print(f"Edge indices: {data.edge_index}")
    print(f"Edge features: {data.edge_attr}")
    print(f"Target: {data.y}")
    
  4. 数据集划分

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    
    # 数据集划分
    train_size = int(0.8 * len(dataset))
    val_size = int(0.1 * len(dataset))
    test_size = len(dataset) - train_size - val_size
       
    indices = torch.randperm(len(dataset))
    train_dataset = dataset[indices[:train_size]]
    val_dataset = dataset[indices[train_size:train_size + val_size]]
    test_dataset = dataset[indices[train_size + val_size:]]
       
    # 创建数据加载器
    batch_size = 128
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    
  5. 模型架构建立

    • 这里可能存在初学者看不懂的情况

    • 首先,EdgeFeatureGNN 继承了 MessagePassing,其 forwardmessageupdate 函数均是继承后的重构函数,其执行的内在顺序是:

      • forward 开启消息传递的过程
      • message生成所有的消息
      • update将所有的消息聚合,并返回更新后的节点信息

      上述过程对应 GNN 的消息传递过程。

    • data.batch

      tensor([0, 0, 0, 1, 1, 2, 2])

      这表示前 3 个节点属于第 0 张图,接下来的 2 个节点属于第 1 张图,最后的 2 个节点属于第 2 张图。

    • global_add_pool()

      将同一张图的向量相叠加

    • 同时因为数据集中,不存在节点特征,这里人工将节点特征设置为 hidden_dim 的零向量,以便后续的消息传播。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    
    # 自定义图卷积层(处理无节点特征的情况)
    class EdgeFeatureGNN(MessagePassing):
        def __init__(self, hidden_dim):
            super().__init__(aggr='sum')  # 消息聚合方式
               
            # 消息网络:处理边特征和邻居信息
            self.msg_mlp = Sequential(
                Linear(2 * hidden_dim + 1, hidden_dim),  # 输入:两个节点嵌入 + 边特征
                ReLU(),
                Linear(hidden_dim, hidden_dim)
            )
               
            # 节点更新网络
            self.node_update = Sequential(
                Linear(hidden_dim + hidden_dim, hidden_dim),  # 原嵌入 + 聚合消息
                ReLU()
            )
       
        def forward(self, edge_index, edge_attr, node_emb):
            """
            Args:
                edge_index: 边连接 [2, E]
                edge_attr: 边特征 [E, 1]
                node_emb: 可学习节点嵌入 [N, hidden_dim]
            """
            # 添加自环边(保留原始嵌入),自环边分配属性值默认值为 1.0
            edge_index, edge_attr = add_self_loops(
                edge_index, 
                edge_attr,
                num_nodes=node_emb.size(0)
            )
               
            # 开始消息传递
            return self.propagate(
                edge_index, 
                x=node_emb, 
                edge_attr=edge_attr
            )
       
        def message(self, x_i, x_j, edge_attr):
            """
            x_i: 目标节点嵌入 [E, hidden_dim]
            x_j: 源节点嵌入 [E, hidden_dim]
            edge_attr: 边特征 [E, 1]
            """
            # 将 edge_attr 从 [num_edges] 转换为 [num_edges, 1]
            edge_attr = edge_attr.unsqueeze(-1)  # 添加一个维度
            # 拼接源节点、目标节点和边特征
            message_input = torch.cat([x_i, x_j, edge_attr], dim=1)
            return self.msg_mlp(message_input)  # [E, hidden_dim]
       
        def update(self, aggr_out, x):
            """
            aggr_out: 聚合后的消息 [N, hidden_dim]
            x: 原始节点嵌入 [N, hidden_dim]
            """
            # 拼接原始嵌入和聚合消息
            update_input = torch.cat([x, aggr_out], dim=1)
            return self.node_update(update_input)  # [N, hidden_dim]
       
    # 完整模型架构
    class EdgeGNN(torch.nn.Module):
        def __init__(self, hidden_dim=64):
            super().__init__()
               
            # 初始化节点嵌入(解决无节点特征问题)
            self.hidden_dim = hidden_dim
               
            # 两个图卷积层
            self.conv1 = EdgeFeatureGNN(hidden_dim)
            self.conv2 = EdgeFeatureGNN(hidden_dim)
               
            # 输出层(预测14个目标)
            self.lin = Sequential(
                Linear(hidden_dim, hidden_dim * 2),
                ReLU(),
                Linear(hidden_dim * 2, 14)  # 输出14个回归值
            )
       
        def forward(self, data):
            # 初始化节点嵌入
            num_nodes = data.num_nodes
            x = torch.zeros(num_nodes, self.hidden_dim).to(device)
               
            # 第一层消息传递
            x = self.conv1(data.edge_index, data.edge_attr, x)
            x = F.relu(x)
               
            # 第二层消息传递
            x = self.conv2(data.edge_index, data.edge_attr, x)
            x = F.relu(x)
               
            # 全局池化(将节点特征聚合为图特征)
            graph_emb = global_add_pool(x, data.batch)  # [batch_size, hidden_dim]
               
            # 最终预测
            return self.lin(graph_emb)  # [batch_size, 14]
    
  6. 初始化模型、优化器和损失函数

    1
    2
    3
    
    model = EdgeGNN(hidden_dim=20).to(device)
    optimizer = Adam(model.parameters(), lr=0.01)
    criterion = MSELoss().to(device)
    
  7. 定义训练函数、验证函数

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    
    # 训练函数
    def train():
        model.train()
        total_loss = 0
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            out = model(data)  # 前向传播
            loss = criterion(out, data.y.view(-1, 14))  # 适配形状
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * data.num_graphs
        return total_loss / len(train_dataset)
       
    # 验证函数
    def evaluate(loader):
        model.eval()
        total_loss = 0
        with torch.no_grad():
            for data in loader:
                data = data.to(device)
                out = model(data)
                total_loss += criterion(out, data.y.view(-1, 14)).item() * data.num_graphs
        return total_loss / len(loader.dataset)
    
  8. 训练

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    
    best_val_loss = float('inf')
    for epoch in range(100):
        train_loss = train()
        val_loss = evaluate(val_loader)
        print(f'Epoch: {epoch + 1:03d}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
           
        # 保存最佳模型
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model_qm7b.pth')
    
  9. 测试模型

    1
    2
    3
    
    model.load_state_dict(torch.load('best_model_qm7b.pth'))
    test_loss = evaluate(test_loader)
    print(f'Test Loss: {test_loss:.4f}')