相关文章推荐
大力的核桃  ·  JAVA ...·  1 年前    · 

Open Assistant 全流程训练细节(GPT3+RL)

5 个月前 · 来自专栏 AI 学习专栏

Open Assistant 是 LAION 机构开源的,旨在训练一个 ChatGPT 的小规模替代版本,就像 stable diffusion 相对于 dalle 一样,让普通人都可以跑起来,传播力会比较广。

LAION 机构全称为 Large-scale Artificial Intelligence Open Network,是一个非盈利的机器学习研究机构,致力于为公众提供 AI 模型、数据集和开源代码。Stable diffusion 使用的开源数据就是该机构提供的。

这一篇介绍如何一步一步在 Open Assistant 上训练一个完整的 ChatGPT。

代码: github.com/LAION-AI/Open-Assistant
Star 16.8k,Fork 1.3k,Issue 308open,535closed,代码更新三天前
文档: projects.laion.ai/Open-
在huggingface上面的模型: OpenAssistant (OpenAssistant)
数据格式介绍: github.com/LAION-AI/Ope

整体训练流程

ChatGPT 完整训练包括三个流程:

  1. Supervised FineTune(SFT) :使用人工编写的期望模型如何输出的数据集,对GPT-3进行微调
  2. Reward Model(RM) :使用人工标注的排序数据,训练奖励模型,预测人类更喜欢哪个输出
  3. 强化学习微调 SFT :使用奖励模型作为强化学习优化目标,微调SFT模型

数据收集

对应上面三步的训练过程,OpenAssistant依靠社区希望构建对应内容的数据:

训练 SFT:需要社区成员进行如下内容的标注

  • 随机写初始 Prompt,写一些你希望与机器人进行交互的问题
  • 写回答,模拟AI助理来给第一步的 Prompt合理的回答

训练 RM:需要社区成员进行如下内容的标注

  • 对同一个上下文的回答进行打分

配置环境

git clone https://github.com/LAION-AI/Open-Assistant.git
cd Open-Assistant/model
# 要求python3.8+
pip install -r model_training/requirements.txt
pip install -r reward/instructor/requirements.txt

我们把所有的预训练模型和数据都放在 Open-Assistat/model/.cache 目录下,我们设置一个全局变量 DATA_DIR

# 还是在 Open-Assistant/model 目录下
mkdir -p .cache 
mkdir -p .saved_models
export DATA_PATH=$PWD/.cache # 设置数据目录
export MODEL_PATH=$PWD/.saved_models # 设置模型目录

第一步:训练 SFT

第一步是用更接近用户使用情况的数据来 finetune 已经 pretrain 好的 gpt-3 模型,论文中写这样子第一步的 finetune 过拟合一点,对于后面的强化学习训练有帮助。

SFT是 Supervised FineTune,那么用什么数据进行监督呢?这一步使用的数据具体可以参考: github.com/LAION-AI/Ope ,包括如下数据

datasets:
    - webgpt
    - squad_v2
    - adversarial_qa
    - trivia_qa_nocontext
    - xsum
    - cnn_dailymail
    - prompt_dialogue # TODO: need to fix the url 这个数据目前无法自动下载
    - multi_news
    - scitldr
    - soda
    - joke
    - gsm8k
    - dive_mt
    - wmt2019_zh-en
    - wmt2019_ru-en
    - wmt2019_de-en
    - ted_trans_nl-en
    - ted_trans_de-ja
    - instruct_tuning
    - wmt2019_de-en
    - samsum
    - soda_dialogue

Open Assistant 里面的 OA Private(jsonl格式)数据没有说的很清楚,不知道在哪里下载,所以我们尽量跳过这个数据,用其他的数据进行训练。

怎么跳过 OA Private 数据,使用其他数据训练呢?

首先先进入 model_training 目录

cd Open-Assistant/model/model_training

训练 SFT

python trainer_sft.py --configs defaults galactica-125m --cache_dir $DATA_PATH --output_dir $MODEL_PATH/sft_model

上面 --config 后面的参数代表需要载入的 config 项目,例如上面的 galactica-125m 是 configs/config.yaml 里面的关于模型的指定,具体内容如下面所示:

galactica-125m:
  learning_rate: 5e-5
  model_name: facebook/galactica-125m
  weight_decay: 0.01
  warmup_steps: 600
  gradient_checkpointing: false
  gradient_accumulation_steps: 2
  per_device_train_batch_size: 4