从零实现AG_NEWS新闻分类PyTorchKeras避坑实战指南你是否曾在网上搜索新闻分类代码时发现要么环境配置不全要么预处理步骤缺失甚至模型根本无法运行本文将带你完整复现AG_NEWS新闻分类任务从数据集下载到模型测试每个环节都包含可执行的代码片段和关键避坑点。不同于零散的教程这里提供的是一套开箱即用的解决方案。1. 环境准备与数据获取首先确保已安装Python 3.8环境建议使用conda创建虚拟环境conda create -n news_classify python3.8 conda activate news_classify安装必要的依赖库pip install torch keras pandas numpy scikit-learnAG_NEWS数据集包含四个文件classes.txt4个新闻类别标签train.csv120,000条训练数据test.csv7,600条测试数据常见问题原始数据集中的标签从1开始编号1-4而PyTorch的交叉熵损失期望从0开始0-3。我们会在预处理阶段进行修正。2. 数据预处理全流程2.1 数据加载与清洗使用pandas读取CSV文件时需要注意原始数据没有表头import pandas as pd def load_agnews(filepath): df pd.read_csv(filepath, headerNone) texts [] labels [] for _, row in df.iterrows(): # 合并标题和内容 text f{row[1]} {row[2]}.lower() # 统一转为小写 label row[0] - 1 # 关键步骤标签减1 texts.append(text) labels.append(label) return texts, labels2.2 文本向量化处理结合Keras的Tokenizer和PyTorch的数据加载from keras.preprocessing.text import Tokenizer from keras.utils import pad_sequences from torch.utils.data import Dataset, DataLoader import torch class AGNewsDataset(Dataset): def __init__(self, texts, labels, tokenizer, max_len64): self.sequences tokenizer.texts_to_sequences(texts) self.padded pad_sequences(self.sequences, maxlenmax_len) self.labels labels def __len__(self): return len(self.labels) def __getitem__(self, idx): return torch.LongTensor(self.padded[idx]), torch.tensor(self.labels[idx])创建词汇表的技巧tokenizer Tokenizer(oov_tokenUNK) tokenizer.fit_on_texts(train_texts test_texts) # 合并训练测试集构建词汇表 vocab_size len(tokenizer.word_index) 1 # 加1保留0给padding3. 模型架构设计3.1 自定义文本分类模型使用PyTorch实现带Embedding层的文本分类器import torch.nn as nn import torch.nn.functional as F class NewsClassifier(nn.Module): def __init__(self, vocab_size, embed_dim, num_classes): super().__init__() self.embedding nn.Embedding(vocab_size, embed_dim) self.fc nn.Linear(embed_dim, num_classes) self.init_weights() def init_weights(self): initrange 0.5 self.embedding.weight.data.uniform_(-initrange, initrange) self.fc.weight.data.uniform_(-initrange, initrange) self.fc.bias.data.zero_() def forward(self, x): embedded self.embedding(x) # [batch, seq_len, embed_dim] pooled F.avg_pool1d(embedded.transpose(1, 2), kernel_sizeembedded.size(1)) return self.fc(pooled.squeeze(2))3.2 批处理函数实现自定义collate_fn处理变长序列def collate_batch(batch): texts, labels zip(*batch) texts torch.stack(texts) labels torch.tensor(labels) return texts, labels4. 训练与评估实战4.1 训练循环配置设置优化器和学习率调度from torch.optim import SGD from torch.optim.lr_scheduler import StepLR model NewsClassifier(vocab_size, 128, 4).to(device) criterion nn.CrossEntropyLoss() optimizer SGD(model.parameters(), lr4.0) scheduler StepLR(optimizer, 1, gamma0.9)4.2 训练与验证函数完整训练流程实现def train_epoch(model, train_loader, optimizer, criterion, device): model.train() total_loss, total_acc 0, 0 for texts, labels in train_loader: texts, labels texts.to(device), labels.to(device) optimizer.zero_grad() outputs model(texts) loss criterion(outputs, labels) loss.backward() optimizer.step() total_loss loss.item() total_acc (outputs.argmax(1) labels).sum().item() return total_loss / len(train_loader.dataset), total_acc / len(train_loader.dataset)验证函数需要注意关闭梯度计算def evaluate(model, data_loader, criterion, device): model.eval() total_loss, total_acc 0, 0 with torch.no_grad(): for texts, labels in data_loader: texts, labels texts.to(device), labels.to(device) outputs model(texts) loss criterion(outputs, labels) total_loss loss.item() total_acc (outputs.argmax(1) labels).sum().item() return total_loss / len(data_loader.dataset), total_acc / len(data_loader.dataset)5. 模型优化与部署5.1 超参数调优建议经过多次实验验证的有效参数组合参数推荐值说明学习率4.0初始学习率batch_size32平衡内存和性能embed_dim128词向量维度max_len64文本截断长度epochs20训练轮数5.2 模型保存与加载训练完成后保存最佳模型torch.save({ model_state_dict: model.state_dict(), tokenizer: tokenizer, vocab_size: vocab_size }, ag_news_classifier.pth)加载模型进行预测def predict(text, model, tokenizer, max_len64): sequence tokenizer.texts_to_sequences([text]) padded pad_sequences(sequence, maxlenmax_len) tensor torch.LongTensor(padded).to(device) with torch.no_grad(): output model(tensor) pred output.argmax(1).item() return pred在实际项目中这个模型可以达到约90%的测试准确率。值得注意的是当遇到OOV未登录词时由于我们在Tokenizer中设置了oov_token模型仍能进行合理预测。