【A Generalist Agent论文阅读】: 首次展示了单一模型可以执行数百种不同任务
论文信息标题A Generalist Agent会议Transactions on Machine Learning Research (TMLR) 2022单位DeepMind代码无公开官方代码论文https://arxiv.org/pdf/2205.06175引言一个AIN种超能力你能想象一个AI既能打《Pong》《Breakout》等经典Atari游戏又能给图片写标题还能和你聊天甚至控制真实机器人手臂堆积木吗DeepMind在2022年推出的Gato就做到了它用同一套神经网络权重搞定了604个完全不同的任务从文本对话到机器人控制从图像理解到游戏通关打破了传统AI“一个任务一个模型”的范式。就像人类一样Gato可以根据不同的输入输出完全不同类型的结果——看到游戏画面就输出按键看到图片就输出描述看到机器人传感器数据就输出关节力矩。图1 Gato的多模态多任务能力来源论文Figure1从图1可以清晰看到Gato可以无缝处理Atari游戏图像、文本对话、机器人本体感觉数据等完全不同的输入输出对应的动作或文本。这是通用人工智能AGI道路上的重要一步。一、Gato的核心设计把所有东西都变成序列Gato的设计灵感来自大语言模型LLM的成功。既然Transformer可以把文本变成序列来处理那为什么不能把所有东西都变成序列呢Gato的核心思想就是将所有模态的数据图像、文本、传感器数据、动作等都统一序列化然后用一个标准的解码器-only Transformer来处理。1.1 多模态数据的统一Token化这是Gato最关键的一步。不管是什么类型的数据都要先转换成一串整数token然后输入到Transformer中。论文中设计了四种不同的token化方案数据类型Token化方法Token范围文本使用SentencePiece分词器32000个子词[0, 32000)离散值如Atari按键直接展平为整数[0, 1024)连续值如机器人关节角度先进行mu-law编码再离散化为1024个bin[32000, 33024)图像分成16×16的patch用ResNet块嵌入与其他token共享嵌入空间通俗解释这就像把不同语言的书都翻译成同一种语言然后让同一个翻译官来读。不管是中文书、英文书还是图画书都先转换成统一的“机器语言”然后交给Transformer处理。1.2 序列排序规则所有数据token化后还要按照固定的顺序排列成一个长序列文本token按原始顺序排列图像patch token按光栅顺序从左到右从上到下排列张量数据按行优先顺序排列每个时间步的结构是观察token 分隔符 动作token整个episode按时间顺序排列1.3 训练目标与损失函数Gato采用自回归训练方式和大语言模型完全一样。它的损失函数是L(θ,B)−∑b1∣B∣∑l1Lm(b,l)logpθ(sl(b)∣s1(b),...,sl−1(b))\mathcal{L}(\theta, \mathcal{B})-\sum_{b1}^{|\mathcal{B}|} \sum_{l1}^{L} m(b, l) log p_{\theta}\left(s_{l}^{(b)} | s_{1}^{(b)}, ..., s_{l-1}^{(b)}\right)L(θ,B)−b1∑∣B∣l1∑Lm(b,l)logpθ(sl(b)∣s1(b),...,sl−1(b))公式解释L\mathcal{L}L总损失函数θ\thetaθ模型的可训练参数B\mathcal{B}B训练批次batch∣B∣|\mathcal{B}|∣B∣批次大小LLL序列长度Gato中固定为1024m(b,l)m(b,l)m(b,l)掩码函数1表示该位置是文本或动作需要计算损失0表示是观察值如图像、传感器数据不计算损失pθ(sl(b)∣s1(b),...,sl−1(b))p_{\theta}(s_l^{(b)} | s_1^{(b)}, ..., s_{l-1}^{(b)})pθ(sl(b)∣s1(b),...,sl−1(b))模型在给定前l-1个token的情况下预测第l个token的概率通俗解释这个损失函数就是让模型学会“根据前面的内容预测下一个内容”。但我们只让它学习预测文本和动作不用预测输入的图像或传感器数据。就像你学英语时只需要背单词和句子不用背课本上的插图一样。1.4 模型架构Gato使用标准的解码器-only Transformer架构和GPT系列完全一致。最大的版本有1.2B参数具体超参数如下表1 Gato的Transformer超参数来源论文Table5为什么选择1.2B参数论文中明确说明这个规模是为了能在真实机器人上实现20Hz的实时控制。如果模型太大推理速度就会跟不上机器人的控制频率如果太小能力又不够。1.2B是一个很好的平衡点。二、训练数据604个任务的大杂烩Gato的强大能力来自于它海量且多样化的训练数据。它在604个不同的任务上进行了训练涵盖了以下几个大类模拟控制任务包括Atari游戏、DM Control Suite、Meta-World、BabyAI等真实机器人任务RGB堆叠任务真实和模拟视觉语言任务图像字幕、视觉问答、文本对话等纯文本任务MassiveText数据集网页、书籍、新闻、代码等不同数据集的采样权重如下表2 视觉语言数据集的采样权重来源论文Table1有趣的细节论文中提到他们在训练时会过滤掉那些表现不好的episode只保留专家水平80%以上的数据。这就像你学习时只看学霸的笔记不看学渣的作业一样。三、惊人的实验结果一个模型打天下Gato的实验结果可以用“震撼”来形容。它用同一套权重在数百个完全不同的任务上都取得了不错的表现。3.1 模拟控制任务表现Gato在超过450个模拟控制任务上达到了专家水平的50%以上图2 Gato在模拟控制任务上的表现来源论文Figure5亮点在23个Atari游戏上达到了人类平均水平在11个游戏上超过了人类两倍的分数在BabyAI的几乎所有关卡上达到了专家水平的80%以上最难的BossLevel也达到了75%在Meta-World的45个任务中44个达到了专家水平的50%以上35个达到了80%以上3.2 真实机器人任务RGB堆叠这是最令人印象深刻的实验之一。Gato被用来控制一个真实的Sawyer机器人手臂完成堆叠不同形状积木的任务。实验分为两个部分Skill Generalization技能泛化训练时用的积木形状和测试时不同Skill Mastery技能掌握训练和测试用相同的积木形状在Skill Generalization任务中Gato的表现甚至超过了专门的BC-IMP基线表3 Gato在真实机器人RGB堆叠任务上的Skill Generalization表现来源论文Table2通俗解释这意味着Gato学会了“堆叠”这个通用技能而不是只会堆叠特定形状的积木。就像人类学会了搭积木后不管是方形、圆形还是三角形的积木都能搭起来一样。3.3 文本与图像能力Gato还展示了不错的文本和图像理解能力可以生成合理的图像字幕可以进行简单的对话可以回答视觉问题图3 Gato生成的图像字幕示例来源论文Figure6从图3可以看到Gato生成的字幕虽然不是完美的但基本都能准确描述图片的主要内容。3.4 少样本泛化能力Gato最强大的地方在于它的少样本泛化能力。它可以在只看到几个新任务的演示后就学会完成这个新任务。图4 Gato的少样本泛化能力来源论文Figure9从图4可以看到在Cartpole Swingup、Assembly-v2等任务上Gato只需要10-100个演示episode就能达到接近专家的水平。这比从头训练一个模型要高效得多。四、核心代码实现简化版下面是一个极度简化的Gato实现展示了它的核心思想importtorchimporttorch.nnasnnfromtransformersimportGPT2Model,GPT2ConfigimportnumpyasnpclassGato(nn.Module): 简化版Gato通用智能体 核心思想用一个解码器-only Transformer处理所有模态的序列 def__init__(self,num_tokens33024,# 总token数32000文本 1024连续值embed_dim2048,# 嵌入维度num_layers24,# Transformer层数num_heads16,# 注意力头数patch_size16# 图像patch大小):super().__init__()self.patch_sizepatch_size# 1. 配置GPT2模型解码器-only TransformerconfigGPT2Config(vocab_sizenum_tokens,n_embdembed_dim,n_layernum_layers,n_headnum_heads,activation_functiongelu_new,n_positions1024# 上下文长度)self.transformerGPT2Model(config)self.lm_headnn.Linear(embed_dim,num_tokens,biasFalse)# 2. 图像patch嵌入简化版用卷积代替论文中的ResNet块self.patch_embednn.Conv2d(in_channels3,out_channelsembed_dim,kernel_sizepatch_size,stridepatch_size)# 3. 位置编码self.pos_embednn.Embedding(1024,embed_dim)deftokenize_image(self,image): 将图像转换为patch token序列 Args: image: [batch, 3, H, W]归一化到[-1, 1] Returns: patches: [batch, num_patches, embed_dim] # 提取patch并嵌入patchesself.patch_embed(image)# [batch, embed_dim, H/16, W/16]# 展平为序列patchespatches.flatten(2).transpose(1,2)# [batch, num_patches, embed_dim]returnpatchesdefforward(self,tokens,attention_maskNone): 前向传播 Args: tokens: [batch, seq_len]已经token化的序列 attention_mask: [batch, seq_len]注意力掩码 Returns: logits: [batch, seq_len, num_tokens]下一个token的预测logits batch_size,seq_lentokens.shape# 添加位置编码positionstorch.arange(seq_len,devicetokens.device).unsqueeze(0).expand(batch_size,-1)pos_embself.pos_embed(positions)# 嵌入token并加上位置编码token_embself.transformer.wte(tokens)hidden_statestoken_embpos_emb# Transformer前向传播outputsself.transformer(inputs_embedshidden_states,attention_maskattention_mask)# 预测下一个tokenlogitsself.lm_head(outputs.last_hidden_state)returnlogits# 测试代码if__name____main__:modelGato()# 测试文本输入text_tokenstorch.randint(0,32000,(1,10))logitsmodel(text_tokens)print(f文本输入输出形状:{logits.shape})# 应该是 [1, 10, 33024]# 测试图像输入实际使用时需要先tokenizeimagetorch.randn(1,3,64,64)image_patchesmodel.tokenize_image(image)print(f图像patch形状:{image_patches.shape})# 应该是 [1, 16, 2048]代码说明这只是一个概念验证版本实际的Gato还包含更复杂的多模态嵌入、位置编码和动作解码逻辑核心思想就是用一个统一的Transformer处理所有模态的序列图像被分成16×16的patch每个patch被嵌入成一个向量和文本token一起输入到Transformer中五、局限性与未来展望虽然Gato取得了令人瞩目的成就但它仍然有很多局限性上下文长度有限Gato的上下文长度只有1024个token对于需要长序列记忆的任务如长对话、复杂机器人任务来说不够用纯监督学习Gato是纯监督学习训练的没有使用强化学习。这意味着它只能模仿专家的行为不能通过试错来改进表现不如专门模型在大多数任务上Gato的表现都不如专门为该任务训练的模型数据依赖严重Gato的能力完全依赖于训练数据的质量和多样性。如果某个任务没有足够的高质量数据它的表现就会很差未来展望扩大模型规模和训练数据规模进一步提升能力引入强化学习让Gato可以通过试错来学习增加上下文长度支持更长的序列探索更好的多模态融合方法六、总结Gato是通用智能体发展史上的一个重要里程碑。它证明了用一个统一的序列模型处理所有模态、所有任务是完全可行的。虽然Gato还不是真正的通用人工智能但它向我们展示了一条清晰的道路只要我们有足够多的多样化数据和足够大的模型我们就可以训练出一个能完成各种任务的通用智能体。就像论文中所说的“Transformer序列模型作为多任务多载体策略是有效的包括真实世界的文本、视觉和机器人任务。未来这样的模型可以作为学习新行为的默认起点而不是从头开始训练。”这可能就是通用人工智能的未来。