RWKV 微调教程
State tuning 微调教程
ℹ️

State 微调是什么?

RWKV 是纯 RNN,因此可以做 transformer 难以做到的事情。例如,作为 RNN 有固定大小的 state,所以,微调 RWKV 的初始 state,就相当于最彻底的 prompt tuning,甚至可以用于 alignment,因为迁移能力很强。

本文的 State tuning 方法来自 RWKV 社区微调项目 RWKV-PEFT (opens in a new tab)

开始之前,请确保你拥有一个 Linux 工作区,以及支持 CUDA 的 NVIDIA 显卡。

State tuning 的显存需求可参考下表:

模型参数fp16int8nf4
RWKV6-1.6B5.8GB GPU4.5GB GPU3.9GB GPU
RWKV6-3B8.7GB GPU6.2GB GPU4.9GB GPU
RWKV6-7B17.8GB GPU11.9GB GPU8.5GB GPU

上表的数据基于以下测试参数:

  • ctxlen=1024
  • micro_bsz=1
  • strategy=deepspeed_stage_1

整理训练数据

收集 jsonl 格式训练数据

要 state tuning 微调 RWKV 模型,需要使用收集适合训练 RWKV 的数据(jsonl 格式),具体方法可参考准备微调数据集 (opens in a new tab)

为了突出 state tuning 的特性,我们为本教程选取了大量带 emoji 的对话数据作为训练数据:

state-tuning-dataset-demo

转化成 binidx 数据

使用 make_data.py (opens in a new tab) 脚本,可以将 jsonl 训练数据乱序重复生成 10 遍,同时转化成可供 RWKV 训练的 binidx 数据。

在 Linux 工作区依次运行以下命令,以使用 RWKV-LM 仓库中的 make_data.py 脚本生成 binidx 数据:

git clone https://github.com/BlinkDL/RWKV-LM.git # 克隆 RWKV-LM 仓库
cd RWKV-LM/RWKV-v5 # 进入 RWKV-v5 目录
python make_data.py /home/rwkv/RWKV-PEFT/data/qaemoji.jsonl 10 512 # 乱序复制数据并生成 binidx 数据

state-tuning-make-data

/home/rwkv/RWKV-PEFT/data/qaemoji.jsonl 需要改成你自己的 jsonl 文件路径,10 是数据重复的次数,512 则是训练数据的 ctx_len(上下文长度)。

ℹ️

在社区的实验中,state tuning 的 ctx_len 应当尽可能小,建议从 512 开始尝试。

配置训练环境

请参考RWKV 微调环境配置 (opens in a new tab)板块,配置 Conda 等训练环境。

克隆仓库并安装依赖

在 Linux 或 WSL 中,使用 git 命令克隆 RWKV-PEFT 仓库​:

git clone https://github.com/JL-er/RWKV-PEFT.git

克隆完成后,使用 cd RWKV-PEFT 命令进入 RWKV-PEFT 目录。并运行以下命令,安装项目所需依赖:

pip install -r requirements.txt

修改训练参数

使用任意文本编辑器(如 vscode)打开 RWKV-PEFT/scripts 目录的 demo-state-tuning.sh 文件,修改训练参数,进而控制微调的训练过程和训练效果:

state-tuning-configs

以下是一次 state tuning 调参过程:

调整路径参数

demo-state-tuning.sh 文件前三行是文件路径参数:

  • load_model: 基底 RWKV 模型的路径
  • proj_dir:训练日志和训练得到的 state 文件输出路径
  • data_file:训练数据集的路径,注意路径中不需要带 bin 和 idx 后缀,仅需文件名称。

调整 n_layer 和 n_embd 参数

基底 RWKV 模型的参数大小不同,训练时使用的 n_layer 和 n_embd 数值也不一样。以下是不同参数的 RWKV 模型和 n_layer 和 n_embd 数值的对照列表:

模型参数n_layern_embd
0.1B12768
0.4B241024
1.5B242048
3B322560
7B324096
14B614096

调整重要训练参数

ℹ️

以下参数建议根据你的微调数据、设备性能进行调整。

参数描述
micro_bsz=1微批次大小,根据显存大小调整,微调时从 1 开始逐渐增大
epoch_save=1每隔多少个训练轮次保存一次 State 文件
epoch_steps=1000每个训练轮次的步数,增加会拉长单个epoch的训练时间
ctx_len=512微调模型的上下文长度,state tuning 建议从短开始尝试,如 512

调整其他训练参数

下面列出了脚本中其他可修改的训练参数,及其修改的效果。

⚠️

注意:微调 state 时,建议 --warmup_steps 10--lr_init 1--lr_final 0.01 ,以及尽可能短的 ctxlen 。(是的,state tuning 需要使用非常高的学习率。)

参数描述
--data_type binidx训练语料的文件格式,支持:"utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"
--vocab_size 65536词表大小,默认为 65536,设置为 0 表示模型自动确定词汇表大小
--epoch_count 5总训练轮次,可根据效果调整
--epoch_begin 0初始训练轮次,即从第 N 个 epoch 开始加载
--pre_ffn 0用 ffn 替换第一个 att 层,有时可能有益
--head_qk 0通常保持默认值 0,即关闭状态
--lr_init 1初始学习率,state tuning 建议为 1,其他微调建议不超过 1e-4
--lr_final 0.01最终学习率,state tuning 建议为 0.01,其他微调建议不超过 1e-4
--warmup_steps 10预热步骤数,state tuning 建议为 10
--beta1 0.9Adam 优化器的 beta1 参数
--beta2 0.99Adam 优化器的 beta2 参数
--adam_eps 1e-8Adam 优化器的 epsilon 参数
--accelerator gpu使用的加速器类型,目前主要支持 gpu, cpu 基本不支持训练
--devices 1单显卡填 1,多卡按实际数量填写
--precision bf16训练精度,默认为 bf16,支持:"fp32", "tf32", "fp16", "bf16"
--strategy deepspeed_stage_1lightning 训练策略参数,微调推荐使用 deepspeed_stage_1
--grad_cp 1梯度累积步数,0 训练更快但需更多显存,1 训练较慢但节省显存
--my_testing "x060"训练的 RWKV 模型版本,v5 选 x052,v6 选 x060
--dataload pad数据加载选项,pad 支持 bsz>1,only 限制 bsz=1
--train_type "state"训练类型 state tuning ,保持默认
--fla是否启用 fla,bsz < 8 时建议开启以减小显存需求
--quant int8/nf4RWKV 默认使用 bf16 训练精度,但支持 int8 和 nf4 两种量化训练类型,推荐使用精度损失较小的 int8
--wandb PEFT-State-tuning是否使用 wandb 可视化记录训练日志,需提前配置 wandb (opens in a new tab) 账号

量化训练可以降低显存需求,但会导致模型精度损失。如果不需要量化训练,可以删除 quant 相关的参数。

⚠️

参数调整完成后,请记得保存 demo-state-tuning.sh 文件。

附录:state tuning 配置参考

demo-state-tuning.sh
load_model='/home/rwkv/models/basemodel/3b.pth'
proj_dir='/home/rwkv/RWKV-PEFT/output_state'
data_file='/home/rwkv/RWKV-PEFT/data/qaemoji'
 
 
n_layer=32
n_embd=2560
 
micro_bsz=2
epoch_save=1
epoch_steps=500
ctx_len=512
 
python train.py --load_model $load_model \
--proj_dir $proj_dir --data_file $data_file \
--data_type binidx --vocab_size 65536 \
--ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 5 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \
--n_layer $n_layer --n_embd $n_embd \
--pre_ffn 0 --head_qk 0 --lr_init 1 --lr_final 1e-2 --warmup_steps 10 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \
--my_testing "x060" \
--train_type "state"  --dataload pad --wandb PEFT-State-tuning
 
# 以下为可选项
# --fla 
# --quant int8/nf4 (是否量化训练)
# --wandb PEFT-State-tuning (是否使用 wandb 监控训练过程)

开始训练

在 RWKV-PEFT 目录,运行 sh scripts/demo-state-tuning.sh 命令,开启 state tuning 。

正常开始训练后,应当是如下画面:

state-tuning-running

训练完毕后,应当可以在输出文件夹中找到训练好的 state 文件(.pth 格式)和训练日志(.txt 文件):

state-tuning-get-model

如何使用 state 文件

获得 state 文件后,你可以如此使用:

  • 使用 demo-state-merge.sh 工具将 state 文件合并到基底 RWKV 模型中,获得一个完整的 state 微调模型。

  • 可以选择在 RWKV Runner 或 Ai00 等工具中单独挂载 state 文件。(推荐用法)

ℹ️

注意:挂载 state 文件时,必须使用训练此 state 文件的同款 RWKV 模型。

举个例子:这个 state 文件是基于 RWKV-6-World-3B-v2.1 模型微调而来,那么你在 RWKV Runner 或 Ai00 等工具中必须启动 RWKV-6-World-3B-v2.1 模型,挂载的 state 文件才会生效。

state-file-usage

由于我们的示例数据基于大量 emoji 且 ctx 非常短,训练出来的 State 文件效果如下:

state-tuning-examples

ℹ✨

由于 state 文件支持单独挂载,其他用户也可以通过挂载你训练出来的的 state 文件,增强 RWKV 模型的使用体验。

State 文件的用法可以参考 RWKV Runner (opens in a new tab) | Ai00 (opens in a new tab)