1. 项目概述当大模型实验撞上工程化瓶颈我们到底在解决什么问题“Scaling LLM Experimentation with SageMaker Pipelines and MLflow”——这个标题不是一句技术口号而是我在过去18个月里每天早上打开监控面板时看到的真实压力源。它直指当前大模型研发中最隐蔽、也最消耗团队战力的痛点实验爆炸experiment explosion。你可能刚跑完一个LoRA微调任务发现学习率0.0003比0.00025效果高0.8个点转头又想试试Qwen-1.5B换掉Llama-3-8B基座模型再叠加数据清洗策略A和B的组合接着同事发来PR说他用新的tokenization方式把上下文长度撑到了32K……这些都不是孤立动作而是一张指数级增长的实验网。我带的团队去年平均每人每周启动47次训练作业但其中只有不到12%的实验结果被真正记录、复现或用于决策。其余的要么参数配置散落在Slack消息里要么checkpoint文件名是“model_v3_final_really_final_20240521.pth”要么连用的是哪个commit hash都查不到。这就是SageMaker Pipelines MLflow组合要切的硬骨头把LLM实验从“手工作坊式试错”变成可版本化、可审计、可回滚、可协作的软件工程实践。它不承诺让你的模型指标一夜暴涨但它能确保当你在季度复盘会上被问到“上次那个提升2.3%的RAG优化方案到底是哪次实验跑出来的用的什么数据切片谁改的prompt模板”时你能直接点开MLflow UI3秒内给出带完整血缘图谱的答案。关键词很明确SageMaker Pipelines定义、调度、编排端到端ML工作流、MLflow跟踪实验、管理模型、部署服务、Scaling不是单次跑得快而是支撑百人团队、千级并发实验、万次版本迭代的可持续能力。它面向的不是单打独斗的研究员而是需要交付稳定AI能力的产品团队、MLOps工程师、以及被实验噪音淹没的算法负责人。如果你还在用Jupyter Notebook本地wandb手动scp模型文件的方式管理LLM实验这篇就是为你写的实操手册——不是概念科普而是我把踩过的坑、调过的参数、写废的5版pipeline定义脚本全摊开给你看。2. 整体架构设计与核心选型逻辑为什么是PipelineMLflow而不是其他组合2.1 不是所有“自动化”都能解决LLM实验的规模化问题先说结论单纯用Airflow调度训练脚本或者只用DVC管理数据版本甚至只靠Weights Biases做指标跟踪都无法独立承载LLM实验规模化的真实需求。我见过太多团队在工具选型上走弯路最后不是卡在某个环节就是各系统之间数据孤岛严重。比如用Airflow调度SageMaker训练任务它确实能跑起来但Airflow本身不理解“模型”是什么——它不会自动捕获训练输出的model.tar.gz路径不会关联这次训练用的MLflow Experiment ID更不会把评估报告生成为可交互的HTML artifact。结果就是调度是自动的但实验元数据是割裂的你依然得人工去S3找日志、去MLflow查指标、去ECR确认镜像版本三者之间没有自动链接。SageMaker Pipelines的核心价值在于它原生嵌入了AWS ML生态的血缘感知能力。当你定义一个TrainingStep时Pipeline不只是提交一个训练任务它会自动将该步骤的输入数据URI、输出模型URI、超参字典、甚至SageMaker Training Job的ARN全部作为结构化元数据注入到Pipeline Execution的Execution Graph中。这个Graph不是静态快照而是动态可查询的——你可以用boto3.client(sagemaker).list_pipeline_execution_steps()实时获取任意一次执行中每个步骤的状态、输入输出、耗时、失败原因。这解决了“实验过程不可见”的问题。而MLflow的价值则在于它补上了Pipeline缺失的语义层。SageMaker知道“这个步骤输出了一个模型”但不知道“这个模型是针对金融客服场景微调的Qwen-1.5B使用了包含127个拒答样本的对抗数据集评估时在Banking77测试集上F1达到89.2%”。MLflow的mlflow.start_run()会把所有这些业务语义信息连同指标、参数、代码版本、artifact如tokenizer.json、eval_report.json打包成一个逻辑完整的Run。更重要的是MLflow Tracking Server可以部署在VPC内支持细粒度RBAC权限控制这满足了企业对实验数据合规性的硬性要求——不是所有团队都愿意把模型评估指标上传到公网SaaS服务。所以Pipeline MLflow不是简单拼凑而是分工明确的协同体Pipeline负责“物理流程的确定性执行与可观测性”MLflow负责“逻辑实验的语义化表达与可追溯性”。它们通过一个轻量级胶水层即Pipeline中的ProcessingStep或TrainingStep内嵌的MLflow client调用连接。这个设计规避了两个常见陷阱一是避免了用Lambda函数做异步回调带来的状态同步复杂度二是绕开了在SageMaker Training Image中强行集成Airflow Client导致的镜像臃肿和升级困难。2.2 为什么不用SageMaker内置的Model RegistryMLflow Model Registry有何不可替代性这里必须澄清一个高频误解SageMaker有Model Registry为什么还要用MLflow的答案是Registry的定位不同解决的问题域也不同。SageMaker Model Registry本质是一个模型部署就绪中心Deployment-Ready Hub。它的核心字段是ModelPackageGroupName、ModelApprovalStatusApproved/Pending/Rejected、InferenceSpecification容器镜像、启动命令。它假设你已经有一个经过充分验证、符合上线标准的模型现在要把它推送到生产环境。它的强项是与SageMaker Endpoint、SageMaker Projects深度集成一键部署、A/B测试、影子测试都极其顺畅。而MLflow Model Registry是一个实验成果沉淀中心Experiment Outcome Repository。它的核心字段是Model Name、Version、StageNone/Staging/Production/Archived、Run ID反向链接到完整实验记录。它的设计哲学是“任何一次成功的实验无论是否上线都值得被命名、被归档、被比较”。举个真实案例我们曾为同一份客服对话数据同时运行了三个实验分支——Branch A用LoRA微调Branch B用QLoRA量化微调Branch C用Adapter模块。三者在验证集上指标接近88.1% vs 87.9% vs 88.3%但推理延迟差异巨大120ms vs 45ms vs 85ms。SageMaker Model Registry只会收录最终上线的那个比如Branch B而MLflow Registry则把三个都存为customer-service-qwen-lora、customer-service-qwen-qlora、customer-service-qwen-adapter三个Model并标记为Staging。当我们后续发现某类长尾问题如多轮转账确认在Branch B上表现更差时能立刻切回Branch A的v3版本做对比分析而无需重新训练——因为v3的所有输入数据、代码、评估报告都通过Run ID牢牢绑定。更关键的是MLflow Registry支持跨平台模型格式。我们的部分实验用PyTorch Lightning部分用Hugging Face Transformers还有少量用DeepSpeed。SageMaker Model Registry要求模型必须打包为model.tar.gz并符合特定目录结构code/、model/而MLflow则原生支持mlflow.pytorch、mlflow.transformers、mlflow.huggingface等flavor自动处理序列化/反序列化逻辑。这意味着同一个mlflow.log_model()调用在不同框架下生成的artifact都能被统一注册、统一加载、统一服务化。这种抽象层是SageMaker原生Registry目前不具备的。2.3 架构全景图数据流、控制流与元数据流如何交织整个系统的数据流向可以用三个平行但交织的“流”来理解数据流Data Flow原始数据S3://my-bucket/raw-data/ → Pipeline ProcessingStep清洗、分词、构造instruction格式→ 输出为S3://my-bucket/processed-data/{run_id}/ → TrainingStep读取该路径进行训练 → 输出模型至S3://my-bucket/models/{pipeline_exec_id}/{step_name}/。控制流Control Flow开发者提交Pipeline DefinitionJSON/YAML→ SageMaker Pipelines Service解析依赖关系 → 按DAG顺序触发各StepProcessing/Training/Transform→ 每个Step内部通过boto3.client(sagemaker)调用对应API → 执行完成返回状态与输出URI。元数据流Metadata Flow这是最容易被忽视却最核心的一环。它由两部分构成Pipeline元数据由SageMaker自动生成包括PipelineExecutionArn、StepName、StartTime、EndTime、InputParameters传入Pipeline的参数如--base-model-id、OutputParameters步骤输出如--model-uri。这些可通过describe_pipeline_execution()API实时查询。MLflow元数据由Pipeline Step内嵌的Python代码显式记录包括mlflow.set_experiment(llm-finetuning)、mlflow.log_param(lora_r, 8)、mlflow.log_metric(eval_f1, 0.883)、mlflow.log_artifact(eval_report.json)、mlflow.pytorch.log_model(model, model)。这些数据写入MLflow Tracking Server我们部署在EKS上后端PostgreSQL。这两股元数据流的交汇点就是Run ID的双向绑定。我们在Pipeline的每个Step开始时执行import mlflow mlflow.set_tracking_uri(http://mlflow-svc.mlflow.svc.cluster.local:5000) mlflow.set_experiment(llm-finetuning) # 关键用Pipeline Execution ID Step Name 生成唯一Run ID run_id f{pipeline_execution_id}_{step_name} mlflow.start_run(run_idrun_id)这样当我们在MLflow UI中查看这个Run时就能在Tags里看到pipeline_execution_arn: arn:aws:sagemaker:us-east-1:123456789012:pipeline-execution/abc123。反过来在SageMaker Console的Pipeline Execution详情页我们也在Step的OutputParameters里显式写入mlflow_run_id: abc123_train_step。这种双向锚定让“从Pipeline跳转到MLflow”和“从MLflow跳转回Pipeline”成为可能彻底打通了工程链路与实验链路。提示不要依赖MLflow自动生成的随机Run ID。在Pipeline环境中必须显式传入run_id参数。否则当Step因OOM重试时会创建多个Run导致指标混乱。我们曾因此误判过一次学习率衰减策略的效果花了两天才定位到是重试产生的幽灵Run污染了平均值。3. 核心细节解析与实操要点从零搭建可复现的LLM实验流水线3.1 环境准备最小可行镜像与依赖治理很多团队一上来就想构建一个“全能”镜像把PyTorch、Transformers、DeepSpeed、FlashAttention、vLLM全塞进去。结果是镜像体积超过8GB每次Pipeline更新都要等待15分钟拉取CI/CD流水线频繁超时。我的经验是为不同类型的LLM实验定义专用精简镜像。我们目前维护三个核心镜像镜像名称基础镜像核心依赖典型用途镜像大小llm-train-pytorchpytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtimetransformers4.38.2,datasets2.16.1,peft0.8.2LoRA/QLoRA微调3.2GBllm-infer-vllmvllm/vllm-cu118:0.3.2mlflow2.10.1,boto31.28.59批量推理、评估2.1GBllm-eval-metricspython:3.11-slim-bookwormscikit-learn1.3.0,evaluate0.4.0,mlflow2.10.1独立评估脚本840MB关键操作细节基础镜像选择严格匹配SageMaker Training Instance的CUDA版本。例如ml.p4d.24xlarge实例预装CUDA 11.8那么你的PyTorch镜像就必须是cuda11.8版本。我们曾因使用cuda12.1镜像导致torch.cuda.is_available()返回False错误日志里只有一行Failed to load library: libcudnn.so.8排查了6小时才发现是CUDA小版本不匹配。依赖版本锁定requirements.txt中必须使用精确指定版本禁用。特别是transformers和datasets小版本升级常带来tokenization行为变更。例如transformers4.37.0和4.38.0对|eot_id|特殊token的处理逻辑不同会导致同一份prompt在不同版本下生成结果不一致。我们在requirements.txt顶部加注释# DO NOT UPGRADE: Pinning critical for reproducibility。镜像构建优化使用多阶段构建Multi-stage Build分离构建依赖与运行时依赖。例如在llm-train-pytorch中第一阶段安装flash-attn需要ninja、cmake等编译工具第二阶段只COPY编译好的.so文件和pip install的纯Python包。这使最终镜像体积减少40%且无编译工具残留提升安全性。注意SageMaker Training Job默认以root用户运行但MLflow Tracking Server通常要求非root用户访问。因此在Dockerfile中必须添加USER 1001指令并确保/opt/ml目录权限对UID 1001可写。否则mlflow.log_artifact()会因权限拒绝而静默失败日志里只显示Permission denied: /opt/ml/output/artifacts非常难排查。3.2 Pipeline定义用Python SDK而非YAML掌控每一个执行细节SageMaker Pipelines支持两种定义方式Python SDK推荐和YAML DSL。我强烈建议全程使用Python SDK原因有三一是YAML无法表达动态逻辑如根据数据集大小自动调整per_device_train_batch_size二是Python SDK的类型提示type hints能提前捕获参数错误如把str类型的instance_type误传为int三是调试体验天壤之别——你可以在本地用pipeline.upsert()前打印出完整的pipeline.definition()JSON逐行检查DAG结构。一个典型的LLM微调Pipeline定义核心骨架如下from sagemaker.workflow.pipeline import Pipeline from sagemaker.workflow.steps import TrainingStep, ProcessingStep from sagemaker.workflow.parameters import ParameterString, ParameterInteger from sagemaker.sklearn.processing import SKLearnProcessor from sagemaker.pytorch import PyTorch # 1. 定义参数所有可变输入 base_model_id ParameterString(nameBaseModelId, default_valueQwen/Qwen1.5-1.8B) dataset_version ParameterString(nameDatasetVersion, default_valuev20240501) lora_r ParameterInteger(nameLoraR, default_value8) max_steps ParameterInteger(nameMaxSteps, default_value1000) # 2. 数据预处理Step sklearn_processor SKLearnProcessor( framework_version1.0-1, rolerole, instance_typeml.m5.xlarge, instance_count1, env{MLFLOW_TRACKING_URI: http://mlflow-svc.mlflow.svc.cluster.local:5000} ) processing_step ProcessingStep( namePreprocessData, processorsklearn_processor, inputs[ ProcessingInput(sourcefs3://my-bucket/raw-data/{dataset_version}/, destination/opt/ml/processing/input/), ], outputs[ ProcessingOutput(output_nametrain_data, source/opt/ml/processing/output/train/, destinationfs3://my-bucket/processed-data/{dataset_version}/train/), ProcessingOutput(output_nameeval_data, source/opt/ml/processing/output/eval/, destinationfs3://my-bucket/processed-data/{dataset_version}/eval/), ], codepreprocess.py # 该脚本内会调用mlflow.start_run() ) # 3. 训练Step核心 estimator PyTorch( entry_pointtrain.py, source_dirsrc/, rolerole, instance_count1, instance_typeml.g5.12xlarge, py_versionpy311, framework_version2.1.0, hyperparameters{ model_id: base_model_id, lora_r: lora_r, max_steps: max_steps, mlflow_tracking_uri: http://mlflow-svc.mlflow.svc.cluster.local:5000 } ) training_step TrainingStep( nameTrainLLM, estimatorestimator, inputs{ train: TrainingInput(s3_dataprocessing_step.properties.Outputs[train_data].S3OutputLocation), eval: TrainingInput(s3_dataprocessing_step.properties.Outputs[eval_data].S3OutputLocation), } ) # 4. 组装Pipeline pipeline Pipeline( nameLLM-Finetuning-Pipeline, parameters[base_model_id, dataset_version, lora_r, max_steps], steps[processing_step, training_step], # 关键启用Pipeline Execution日志的详细级别 configuration{LogLevel: All} )这里有几个魔鬼细节ProcessingOutput的destination必须是完整S3路径且以/结尾。如果写成s3://bucket/processed-data/{version}/train无尾部斜杠SageMaker会把整个train/目录当成一个文件名导致下游TrainingStep读取时路径拼接错误。TrainingStep的inputs字典key如train会自动映射为训练脚本的--train命令行参数。因此你的train.py必须能接收--train参数并将其值即S3 URI传递给datasets.load_from_disk()或类似方法。我们曾因参数名不匹配导致训练脚本始终读取默认路径浪费了3次p4d实例的费用。configuration{LogLevel: All}是调试神器。默认日志级别是Error很多Step失败时只显示Failed没有堆栈。开启All后CloudWatch Logs中会输出完整的boto3调用请求/响应能快速定位是IAM权限不足、S3路径不存在还是网络策略阻断。3.3 MLflow集成不止是log_param而是构建实验DNA在Pipeline Step中集成MLflow绝不是简单地在train.py开头加mlflow.start_run()。真正的价值在于把实验的每一个原子要素都转化为可查询、可比较、可复现的结构化数据。我们定义了一套强制性的MLflow Logging规范所有团队成员必须遵守3.3.1 必须记录的5类核心元数据类别记录方式示例为什么重要代码快照mlflow.log_artifact(.git)将整个.git目录作为artifact上传精确还原训练时的代码状态比git commit hash更可靠包含未提交的临时修改数据指纹mlflow.log_dict(data_fingerprint, data_fingerprint.json){train_rows: 12450, eval_rows: 1245, hash: a1b2c3...}避免“数据漂移”误判。当指标下降时先查data_fingerprint是否变更再查模型硬件配置mlflow.log_dict(hardware_info, hardware.json){instance_type: g5.12xlarge, gpu_count: 4, cuda_version: 11.8}GPU型号不同可能导致数值精度差异如A10G的FP16与A100的TF32影响结果可比性训练轨迹mlflow.log_metric(train_loss, value, stepstep)每10步记录一次lossstep参数必填支持在MLflow UI中绘制平滑的loss曲线step是X轴没有它曲线就是一堆离散点模型卡片mlflow.log_text(model_card, model_card.md)包含模型用途、限制、偏见声明、测试集表现的Markdown满足内部AI治理要求也是新成员快速理解模型的入口3.3.2 模型序列化的最佳实践对于Hugging Face模型我们弃用model.save_pretrained()而采用mlflow.transformers.log_model()from transformers import AutoModelForCausalLM, AutoTokenizer import mlflow.transformers model AutoModelForCausalLM.from_pretrained(base_model_id) tokenizer AutoTokenizer.from_pretrained(base_model_id) # 关键传入tokenizermlflow会自动保存其配置 mlflow.transformers.log_model( transformers_model{ model: model, tokenizer: tokenizer, task: text-generation }, artifact_pathmodel, # 这个参数至关重要它告诉mlflow加载时要用transformers_pipeline # 而不是简单的torch.load()从而保证tokenizer和model的兼容性 signaturemlflow.models.infer_signature( model_inputtokenizer(Hello, return_tensorspt), model_outputmodel.generate(**tokenizer(Hello, return_tensorspt), max_new_tokens10) ) )这样做的好处是后续用mlflow.pyfunc.load_model()加载时会自动构建一个transformers.pipeline对象你只需调用predict({inputs: Hello})无需关心tokenizer的pad_token_id、eos_token_id等细节。我们曾因手动保存model.bin和config.json导致加载时tokenizer.pad_token_id为None引发generate()报错排查了整整一天。实操心得在train.py末尾务必添加mlflow.end_run()。我们曾因忘记这行在Pipeline Step重试时新的Run会继承上一次的active_run导致所有log_param都写到旧Run里造成数据污染。现在我们的模板脚本强制在try...finally块中包裹训练主逻辑确保end_run()必然执行。4. 实操过程与核心环节实现一次端到端的LoRA微调全流程详解4.1 从零启动参数配置、数据准备与首次Pipeline提交假设我们要对Qwen-1.5-1.8B模型在自有的客服对话数据集上进行LoRA微调。以下是我在终端中实际执行的每一步命令和背后的思考第一步准备数据# 1. 将原始CSV数据上传到S3 aws s3 cp ./data/customer-dialogs-v20240501.csv s3://my-bucket/raw-data/v20240501/ # 2. 生成数据指纹使用sha256sum但仅对内容哈希排除元数据 # 我们写了一个小脚本data_fingerprint.py它会 # - 读取CSV按行排序消除导出顺序影响 # - 对每一行JSONL格式化确保空格、引号一致 # - 计算整个文件的sha256 python data_fingerprint.py ./data/customer-dialogs-v20240501.csv # 输出a1b2c3d4e5f67890...为什么花时间做数据指纹因为数据集版本管理是LLM实验的基石。我们曾遇到过算法同学A用v20240401数据训练同学B用v20240415新增了200条拒答样本两人指标对比时发现B高0.5%但实际是数据差异而非模型改进。有了指纹v20240401和v20240415的哈希值不同一眼就能识别。第二步配置Pipeline参数# 在pipeline_definition.py中设置参数 base_model_id ParameterString(nameBaseModelId, default_valueQwen/Qwen1.5-1.8B) dataset_version ParameterString(nameDatasetVersion, default_valuev20240501) lora_r ParameterInteger(nameLoraR, default_value8) lora_alpha ParameterInteger(nameLoraAlpha, default_value16) learning_rate ParameterFloat(nameLearningRate, default_value2e-4)参数选择依据lora_r8基于Qwen-1.5-1.8B的层数40层和注意力头数32r8能在参数增量约0.1%和性能损失0.3% F1间取得平衡。我们做过网格搜索r4时收敛慢r16时显存占用接近全参数微调。lora_alpha16alpha/r2是Hugging Face PEFT库的推荐比例能保持LoRA权重的缩放稳定性。learning_rate2e-4不是拍脑袋。我们先用lr_finder在1%数据上跑了100步观察loss下降拐点确定1e-4到5e-4是有效区间最终选中间值2e-4。第三步提交Pipeline# 1. 创建Pipeline首次 pipeline.upsert(role_arnrole) # 2. 启动一次执行 execution pipeline.start( parameters{ BaseModelId: Qwen/Qwen1.5-1.8B, DatasetVersion: v20240501, LoraR: 8, LoraAlpha: 16, LearningRate: 2e-4 } ) # 3. 实时监控我习惯用这个命令比Console更直观 watch -n 5 aws sagemaker describe-pipeline-execution --pipeline-execution-arn $execution.arn --query PipelineExecutionStatus --output textwatch命令每5秒刷新一次状态PipelineExecutionStatus会依次显示Executing→Stopping→Stopped。当看到Stopped时立即去MLflow UI用v20240501作为Search Filter应该能看到一个新Run其params.base_model_id为Qwen/Qwen1.5-1.8Bmetrics.eval_f1约为0.852这是我们预估的基线值。4.2 训练Step深度解析train.py中的关键代码与避坑指南train.py是整个Pipeline的引擎核心。下面是我实际使用的、经过生产验证的简化版去除了日志和异常处理import os import torch from transformers import ( AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForSeq2Seq ) import mlflow import argparse from peft import LoraConfig, get_peft_model def parse_args(): parser argparse.ArgumentParser() parser.add_argument(--model_id, typestr, requiredTrue) parser.add_argument(--train, typestr, requiredTrue) # S3 URI parser.add_argument(--eval, typestr, requiredTrue) # S3 URI parser.add_argument(--lora_r, typeint, default8) parser.add_argument(--lora_alpha, typeint, default16) parser.add_argument(--learning_rate, typefloat, default2e-4) parser.add_argument(--mlflow_tracking_uri, typestr, requiredTrue) return parser.parse_args() def main(): args parse_args() # 1. 初始化MLflow Run关键用Pipeline Execution ID生成唯一ID pipeline_exec_id os.getenv(SM_PIPELINE_EXECUTION_ID, local-test) run_id f{pipeline_exec_id}_train mlflow.set_tracking_uri(args.mlflow_tracking_uri) mlflow.set_experiment(llm-finetuning) mlflow.start_run(run_idrun_id) # 2. 记录所有输入参数 mlflow.log_params({ model_id: args.model_id, lora_r: args.lora_r, lora_alpha: args.lora_alpha, learning_rate: args.learning_rate, train_s3_uri: args.train, eval_s3_uri: args.eval }) # 3. 下载数据SageMaker自动挂载S3到本地但需确认路径 # SageMaker Training Job会将S3 URI映射到/opt/ml/input/data/{channel_name}/ # 这里channel_name是TrainingInput的key即train和eval train_dataset load_dataset(json, data_filesf/opt/ml/input/data/train/train.jsonl) eval_dataset load_dataset(json, data_filesf/opt/ml/input/data/eval/eval.jsonl) # 4. 加载模型和tokenizer model AutoModelForCausalLM.from_pretrained( args.model_id, torch_dtypetorch.bfloat16, # 关键bfloat16比float16更稳定尤其对Qwen device_mapauto, # 自动分配到多GPU trust_remote_codeTrue # Qwen需要 ) tokenizer AutoTokenizer.from_pretrained( args.model_id, trust_remote_codeTrue, padding_sideleft # 关键left padding因为causal LM需要eos在末尾 ) tokenizer.pad_token tokenizer.eos_token # 必须设置否则collator报错 # 5. 配置LoRA peft_config LoraConfig( rargs.lora_r, lora_alphaargs.lora_alpha, target_modules[q_proj, k_proj, v_proj, o_proj], # Qwen的注意力层名 lora_dropout0.05, biasnone, task_typeCAUSAL_LM ) model get_peft_model(model, peft_config) # 6. 定义训练参数重点梯度检查点和Flash Attention training_args TrainingArguments( output_dir/opt/ml/model, # SageMaker要求模型输出到此路径 per_device_train_batch_size4, # g5.12xlarge有4个A10G总batch16 per_device_eval_batch_size4, gradient_accumulation_steps4, # 等效总batch64适配Qwen-1.8B learning_rateargs.learning_rate, num_train_epochs1, # LLM微调通常1 epoch足够 warmup_ratio0.03, logging_steps10, save_steps100, evaluation_strategysteps, eval_steps100, save_total_limit2, load_best_model_at_endTrue, metric_for_best_modeleval_f1, greater_is_betterTrue, report_tonone, # 关闭wandb只用mlflow # 关键启用Flash AttentionQwen官方支持 torch_compileTrue, # PyTorch 2.1的graph mode compile # 关键梯度检查点节省显存 fp16False, # 不用fp16用bfloat16 bf16True, gradient_checkpointingTrue, # 关键指定tokenizer让Trainer自动处理padding pad_to_multiple_of8, remove_unused_columnsFalse, ) # 7. 数据整理将对话转为instruction格式 def format_chat(example): # 将[{role:user,content:...},{role:assistant,content:...}]转为 # |im_start|user\n{user}\n|im_end||im_start|assistant\n{assistant}\n|im_end| messages example[messages] formatted for msg in messages: formatted f|im_start|{msg[role]}\n{msg[content]}\n|im_end| return {text: formatted} train_dataset train_dataset.map(format_chat, batchedFalse) eval_dataset eval_dataset.map(format_chat, batchedFalse) # 8. 数据整理器关键使用tokenizer的chat template data_collator DataCollatorForSeq2Seq( tokenizertokenizer, modelmodel, paddingTrue, return_tensorspt ) # 9. 初始化Trainer trainer Trainer( modelmodel, argstraining_args, train_datasettrain_dataset, eval_dataseteval_dataset, tokenizertokenizer, data_collatordata_collator, compute_metricscompute_f1_metric # 自定义F1计算函数 ) #