Open Assistant 全流程训练细节(GPT3+RL)
Open Assistant 是 LAION 机构开源的,旨在训练一个 ChatGPT 的小规模替代版本,就像 stable diffusion 相对于 dalle 一样,让普通人都可以跑起来,传播力会比较广。
LAION 机构全称为 Large-scale Artificial Intelligence Open Network,是一个非盈利的机器学习研究机构,致力于为公众提供 AI 模型、数据集和开源代码。Stable diffusion 使用的开源数据就是该机构提供的。
这一篇介绍如何一步一步在 Open Assistant 上训练一个完整的 ChatGPT。
代码: github.com/LAION-AI/Open-Assistant
Star 16.8k,Fork 1.3k,Issue 308open,535closed,代码更新三天前
文档: https:// projects.laion.ai/Open- Assistant/docs/intro
在huggingface上面的模型: OpenAssistant (OpenAssistant)
数据格式介绍: https:// github.com/LAION-AI/Ope n-Assistant/blob/363a3a124471217e723d57b084122ae1ca41ab2a/notebooks/data-augmentation/stackexchange-builder/README.md
整体训练流程
ChatGPT 完整训练包括三个流程:
- Supervised FineTune(SFT) :使用人工编写的期望模型如何输出的数据集,对GPT-3进行微调
- Reward Model(RM) :使用人工标注的排序数据,训练奖励模型,预测人类更喜欢哪个输出
- 强化学习微调 SFT :使用奖励模型作为强化学习优化目标,微调SFT模型
数据收集
对应上面三步的训练过程,OpenAssistant依靠社区希望构建对应内容的数据:
训练 SFT:需要社区成员进行如下内容的标注
- 随机写初始 Prompt,写一些你希望与机器人进行交互的问题
- 写回答,模拟AI助理来给第一步的 Prompt合理的回答
训练 RM:需要社区成员进行如下内容的标注
- 对同一个上下文的回答进行打分
配置环境
git clone https://github.com/LAION-AI/Open-Assistant.git
cd Open-Assistant/model
# 要求python3.8+
pip install -r model_training/requirements.txt
pip install -r reward/instructor/requirements.txt
我们把所有的预训练模型和数据都放在 Open-Assistat/model/.cache 目录下,我们设置一个全局变量 DATA_DIR
# 还是在 Open-Assistant/model 目录下
mkdir -p .cache
mkdir -p .saved_models
export DATA_PATH=$PWD/.cache # 设置数据目录
export MODEL_PATH=$PWD/.saved_models # 设置模型目录
第一步:训练 SFT
第一步是用更接近用户使用情况的数据来 finetune 已经 pretrain 好的 gpt-3 模型,论文中写这样子第一步的 finetune 过拟合一点,对于后面的强化学习训练有帮助。
SFT是 Supervised FineTune,那么用什么数据进行监督呢?这一步使用的数据具体可以参考: https:// github.com/LAION-AI/Ope n-Assistant/blob/main/model/model_training/configs/config.yaml ,包括如下数据
datasets:
- webgpt
- squad_v2
- adversarial_qa
- trivia_qa_nocontext
- xsum
- cnn_dailymail
- prompt_dialogue # TODO: need to fix the url 这个数据目前无法自动下载
- multi_news
- scitldr
- soda
- joke
- gsm8k
- dive_mt
- wmt2019_zh-en
- wmt2019_ru-en
- wmt2019_de-en
- ted_trans_nl-en
- ted_trans_de-ja
- instruct_tuning
- wmt2019_de-en
- samsum
- soda_dialogue
Open Assistant 里面的 OA Private(jsonl格式)数据没有说的很清楚,不知道在哪里下载,所以我们尽量跳过这个数据,用其他的数据进行训练。
怎么跳过 OA Private 数据,使用其他数据训练呢?
首先先进入 model_training 目录
cd Open-Assistant/model/model_training
训练 SFT
python trainer_sft.py --configs defaults galactica-125m --cache_dir $DATA_PATH --output_dir $MODEL_PATH/sft_model
上面 --config 后面的参数代表需要载入的 config 项目,例如上面的 galactica-125m 是 configs/config.yaml 里面的关于模型的指定,具体内容如下面所示:
galactica-125m:
learning_rate: 5e-5
model_name: facebook/galactica-125m
weight_decay: 0.01
warmup_steps: 600
gradient_checkpointing: false
gradient_accumulation_steps: 2
per_device_train_batch_size: 4