用DGL和PyTorch复现异构图注意力网络HAN:从IMDB电影分类到DBLP学者研究领域预测
用DGL和PyTorch实战异构图注意力网络HAN从电影分类到学术预测在机器学习领域图神经网络(GNN)正以前所未有的速度改变着我们处理结构化数据的方式。而异构图注意力网络(HAN)作为这一领域的重要突破将注意力机制与异构图分析完美结合为复杂关系数据的建模提供了全新思路。本文将带您深入HAN的实现细节使用DGL和PyTorch框架从IMDB电影分类到DBLP学者研究领域预测手把手完成整个工程实践流程。1. 环境准备与数据加载1.1 安装必要依赖在开始之前我们需要配置好开发环境。推荐使用Python 3.8和CUDA 11.3如果使用GPU加速pip install torch1.12.0cu113 -f https://download.pytorch.org/whl/torch_stable.html pip install dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html pip install scikit-learn pandas numpy tqdm提示如果使用CPU版本可以去掉cu113后缀。DGL的版本需要与PyTorch版本匹配。1.2 数据集概览HAN论文中使用了三个经典异构网络数据集数据集顶点类型边关系元路径任务IMDB电影(M), 演员(A), 导演(D)M-A, M-DMAM, MDM电影类型分类DBLP论文(P), 作者(A), 会议(C), 关键词(T)P-A, P-C, P-TAPA, APCPA, APTPA学者领域预测ACM论文(P), 作者(A), 主题(S)P-A, P-SPAP, PSP论文分类我们将以IMDB和DBLP为例展示完整的数据处理流程。2. 数据预处理与图构建2.1 IMDB数据处理IMDB数据集包含电影、演员和导演三类顶点我们需要先构建异构图结构import dgl import torch def build_imdb_graph(): # 假设已经加载了原始数据 num_movies 5000 num_actors 8000 num_directors 2000 # 构建异构图 graph_data { (movie, ma, actor): (torch.tensor([...]), torch.tensor([...])), (movie, md, director): (torch.tensor([...]), torch.tensor([...])), (actor, am, movie): (torch.tensor([...]), torch.tensor([...])), (director, dm, movie): (torch.tensor([...]), torch.tensor([...])) } g dgl.heterograph(graph_data) # 添加特征 g.nodes[movie].data[feat] torch.randn(num_movies, 100) # 电影特征 g.nodes[actor].data[feat] torch.randn(num_actors, 100) # 演员特征 g.nodes[director].data[feat] torch.randn(num_directors, 100) # 导演特征 return g2.2 元路径邻居提取HAN的核心是基于元路径的邻居聚合。以IMDB的MAM(电影-演员-电影)元路径为例def extract_metapath_neighbors(g, metapath): # 使用DGL的metapath_reachable_graph函数 meta_g dgl.metapath_reachable_graph(g, metapath) return meta_g # 定义IMDB的元路径 mam [movie, ma, actor, am, movie] mdm [movie, md, director, dm, movie] # 提取元路径子图 mam_g extract_metapath_neighbors(g, mam) mdm_g extract_metapath_neighbors(g, mdm)3. HAN模型实现3.1 顶点层次注意力顶点层次注意力与GAT类似但需要处理不同类型顶点的特征import torch.nn as nn import torch.nn.functional as F class NodeLevelAttention(nn.Module): def __init__(self, in_dim, out_dim, num_heads): super(NodeLevelAttention, self).__init__() self.num_heads num_heads self.fc nn.Linear(in_dim, out_dim * num_heads, biasFalse) self.attn_fc nn.Linear(2 * out_dim, 1, biasFalse) self.reset_parameters() def reset_parameters(self): gain nn.init.calculate_gain(relu) nn.init.xavier_normal_(self.fc.weight, gaingain) nn.init.xavier_normal_(self.attn_fc.weight, gaingain) def forward(self, feat, adj): # feat: (N, in_dim) # adj: 稀疏邻接矩阵 h self.fc(feat).view(-1, self.num_heads, self.out_dim) # (N, K, out_dim) # 准备注意力计算 N h.size(0) a_input torch.cat([h.repeat(1, 1, N).view(N, self.num_heads, N, -1), h.repeat(N, 1, 1).view(N, self.num_heads, N, -1)], dim-1) # 计算注意力系数 e F.leaky_relu(self.attn_fc(a_input).squeeze(-1), negative_slope0.2) # 应用邻接矩阵掩码 zero_vec -9e15 * torch.ones_like(e) attention torch.where(adj 0, e, zero_vec) attention F.softmax(attention, dim2) # 多头注意力聚合 h_prime torch.matmul(attention, h) return h_prime.mean(dim1) # 平均多头结果3.2 语义层次注意力语义层次注意力负责融合不同元路径的信息class SemanticLevelAttention(nn.Module): def __init__(self, in_dim, hidden_dim128): super(SemanticLevelAttention, self).__init__() self.project nn.Sequential( nn.Linear(in_dim, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, 1, biasFalse) ) def forward(self, z): # z: (M, N, D) M是元路径数量N是节点数D是特征维度 w self.project(z).mean(1) # (M, 1) beta torch.softmax(w, dim0) # (M, 1) beta beta.expand((z.size(0),) z.size()[1:]) # (M, N, 1) return (beta * z).sum(0) # (N, D)3.3 完整HAN模型将两个注意力机制组合成完整模型class HAN(nn.Module): def __init__(self, num_metapaths, in_dim, hidden_dim, out_dim, num_heads, dropout): super(HAN, self).__init__() self.node_attentions nn.ModuleList() for _ in range(num_metapaths): self.node_attentions.append( NodeLevelAttention(in_dim, hidden_dim, num_heads)) self.semantic_attention SemanticLevelAttention(hidden_dim) self.dropout nn.Dropout(dropout) # 分类层 self.classify nn.Linear(hidden_dim, out_dim) def forward(self, g_list, h_list): # g_list: 元路径子图列表 # h_list: 对应特征列表 semantic_embeddings [] for g, h, node_att in zip(g_list, h_list, self.node_attentions): semantic_embeddings.append(node_att(h, g).flatten(1)) semantic_embeddings torch.stack(semantic_embeddings, dim0) # (M, N, D) h self.semantic_attention(semantic_embeddings) h self.dropout(h) return self.classify(h)4. 模型训练与评估4.1 训练流程实现def train(model, g_list, features, labels, train_mask, val_mask, epochs100): optimizer torch.optim.Adam(model.parameters(), lr0.005, weight_decay0.001) loss_fn nn.CrossEntropyLoss() for epoch in range(epochs): model.train() logits model(g_list, features) loss loss_fn(logits[train_mask], labels[train_mask]) optimizer.zero_grad() loss.backward() optimizer.step() # 验证集评估 acc evaluate(model, g_list, features, labels, val_mask) print(fEpoch {epoch}: Loss {loss.item():.4f}, Acc {acc:.4f}) return model def evaluate(model, g_list, features, labels, mask): model.eval() with torch.no_grad(): logits model(g_list, features) pred logits.argmax(1)[mask] acc (pred labels[mask]).float().mean() return acc4.2 DBLP数据集实战DBLP数据集处理略有不同需要特别注意元路径定义def prepare_dblp(): # APA: 作者-论文-作者 apa [author, ap, paper, pa, author] # APCPA: 作者-论文-会议-论文-作者 apcpa [author, ap, paper, pc, conference, cp, paper, pa, author] # APTPA: 作者-论文-关键词-论文-作者 aptpa [author, ap, paper, pt, term, tp, paper, pa, author] # 构建元路径子图 apa_g extract_metapath_neighbors(dblp_g, apa) apcpa_g extract_metapath_neighbors(dblp_g, apcpa) aptpa_g extract_metapath_neighbors(dblp_g, aptpa) return [apa_g, apcpa_g, aptpa_g]4.3 超参数调优经验在实际项目中我们发现以下调优策略效果显著学习率0.001-0.01范围内表现较好过大容易震荡注意力头数8头通常足够继续增加收益递减Dropout0.5-0.7有助于防止过拟合嵌入维度64-256之间根据数据集大小选择注意DBLP数据集中APCPA元路径通常获得最高注意力权重这与论文结论一致说明会议信息对学者领域预测最重要。5. 高级技巧与性能优化5.1 邻居采样策略对于大规模图可以使用邻居采样提高训练效率from dgl.dataloading import MultiLayerNeighborSampler sampler MultiLayerNeighborSampler([15, 10]) # 两层采样每层采样15和10个邻居 dataloader dgl.dataloading.NodeDataLoader( g, train_nids, sampler, batch_size1024, shuffleTrue, drop_lastFalse )5.2 混合精度训练使用AMP(自动混合精度)加速训练from torch.cuda.amp import GradScaler, autocast scaler GradScaler() for epoch in range(epochs): optimizer.zero_grad() with autocast(): logits model(g_list, features) loss loss_fn(logits[train_mask], labels[train_mask]) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()5.3 可视化注意力权重理解模型关注的重点def visualize_attention(model, g_list, node_idx): # 获取顶点层次注意力 node_attentions [] for i, att in enumerate(model.node_attentions): att_weights get_attention_weights(att, g_list[i]) node_attentions.append(att_weights[node_idx]) # 获取语义层次注意力 semantic_weights model.semantic_attention.project.weight # 绘制热力图...6. 实际应用中的挑战与解决方案6.1 数据不平衡问题在IMDB数据集中电影类型分布可能不均衡类型样本数处理策略动作1200类别权重喜剧3000欠采样剧情800过采样在损失函数中添加类别权重class_weight torch.tensor([1.0, 0.4, 1.5]) # 与样本数成反比 loss_fn nn.CrossEntropyLoss(weightclass_weight)6.2 新节点冷启动问题对于图中新加入的节点可以采用以下策略特征传播利用邻居特征生成初始嵌入元学习在小样本上微调模型浅层嵌入结合DeepWalk等传统方法6.3 模型解释性提升通过分析注意力权重我们可以识别重要的元路径如DBLP中的APCPA发现关键邻居节点如对电影分类最重要的关联电影验证业务假设如导演风格对电影类型的影响# 获取DBLP数据集中某学者的重要合作者 author_idx 42 # 目标学者 apa_attention get_attention_weights(model.node_attentions[0], apa_g) top_collaborators torch.topk(apa_attention[author_idx], k5)