RWKV pip 使用指南
以下内容将指引你使用 RWKV pip 库开发基于 RWKV 模型的应用。
RWKV pip 库的原始代码可以在 ChatRWKV 仓库中找到。
API_DEMO_CHAT.py 详解
API_DEMO_CHAT 是一个基于 RWKV pip 库的开发 Demo,用于实现基于命令行的聊天机器人。
下文将以详细的注释,分段介绍这个聊天机器人 DEMO 的代码设计。
这部分代码是导入一些使用 RWKV 模型推理时需要用到的包,需要注意以下两点:
- torch 版本最低 1.13 ,推荐 2.x+cu121
- 需要先
pip install rwkv
在推理 RWKV-7 模型时,请务必将 os.environ["RWKV_V7_ON"]
设置为 1
。
这里是一些加快推理速度的 torch 设置和操作环境的优化项。
这一段引入了 RWKV 工具包中的两个工具类 RWKV 和 PIPELINE ,同时指定了加载 RWKV 模型的设备和精度,以及 RWKV 模型的本地文件路径。
args.strategy
会影响模型的生成效果和生成速度,chatRWKV 支持下表中的 strategy:
下表中 ,fp16i8
指在 fp16 精度基础上进行 int8 量化。
量化可以减少 VRAM 需求,但在精度上略逊于 fp16。因此只要 VRAM 够用,尽量使用 fp16 层。
策略 | VRAM & RAM | 速度 |
---|---|---|
cpu fp32 | 7B 模型需要 32GB 内存 | 使用 CPU fp32 精度加载模型,适合 Intel。对 AMD 非常慢,因为 pytorch 的 cpu gemv 在 AMD 上有问题,并且只会运行在一个单核上。 |
cpu bf16 | 7B 模型需要 16GB 内存 | 使用 CPU bf16 精度加载模型。在支持 bfloat16 的新 Intel CPU(如 Xeon Platinum)上速度较快。 |
cpu fp32i8 | 7B 模型需要 12GB 内存 | 使用 CPU int8 量化精度加载模型。速度较慢(比 cpu fp32 更慢)。 |
cuda fp16 | 7B 模型需要 15GB VRAM | 使用 fp16 精度加载模型所有层,速度最快,但对显存(VRAM)的需求也最高。 |
cuda fp16i8 | 7B 模型需要 9GB VRAM | 使用 int8 量化模型所有层,速度较快。如果设置 os.environ["RWKV_CUDA_ON"] = '1' 来编译 CUDA 内核,可减少 1~2GB VRAM 使用。 |
cuda fp16i8 *20 -> cuda fp16 | VRAM 占用介于 fp16 和 fp16i8 之间 | 将模型的前 20 层(*20 指层数)量化为 fp16i8,其余层使用 fp16 加载。 如果量化后还有较多 VRAM ,则酌情减少 fp16i8 层数(减少 20)。 如果 VRAM 不足则继续增加 fp16i8 量化层数 |
cuda fp16i8 *20+ | 比 fp16i8 使用更少 VRAM | 将模型的前 20 层(*20 指层数)量化为 fp16i8 并固定在 GPU 上,其他层按需动态加载(未固定的层加载速度会慢 3 倍,但节省 VRAM)。 如果 VRAM 不足,减少固定层数(*20 )。 如果 VRAM 充足,增加固定层数。 |
cuda fp16i8 *20 -> cpu fp32 | 比 fp16i8 使用更少 VRAM,但消耗更多内存 | 将模型的前 20 层(*20 )量化为 fp16i8 并固定在 GPU 上,其他层使用 CPU fp32 加载。当 CPU 性能比较强时,此策略比上一个策略(只在 GPU 上固定 20 层)更快。 如果加载 20 层还有剩余 VRAM ,则继续增加 GPU 层数。 如果没有足够 VRAM,减少 GPU 层数。 |
cuda:0 fp16 *20 -> cuda:1 fp16 | 使用双卡驱动模型 | 使用 cuda:0(卡1) fp16 加载模型的前 20 层,然后使用 cuda:1(卡2) fp16 加载剩余的层(自动计算剩余层数)。 建议在最快的 GPU 上运行更多层。 如果某张卡的 VRAM 不够,可以将 fp16 换成 fp16i8 (int8 量化)。 |
这一段决定是否要加载一个 State 文件,"None"
表示不加载自定义 State ,如需加载请填写 State 文件的绝对路径。
State 是 RWKV 这类 RNN 模型特有的状态。通过搭载自定义的 State 文件,可以强化 RWKV 模型在不同任务上的表现。(类似于增强插件)
RWKV State 的介绍和用法可以参照 State 文件介绍和用法文章。
这里主要是设置加载或不加载 State 时, RWKV 模型分别使用哪些解码参数。
有关 RWKV 解码参数的含义和作用,请查看RWKV 解码参数文档。
指定一个自定义 State 文件后,我们希望模型能更好地遵循 State 中的格式和风格,所以调低了 topp 参数和惩罚参数。
CHUNK_LEN
将输入文本切分成指定大小的块。这个数值越大,模型并行处理的文本越多,但使用的显存也更多。在显存不足时建议调整到 128 或者 64。
这一段开始使用前面设置的 strategy 和解码参数加载 RWKV 模型。
如果你希望模型加载完后也有提示,可以在这一段末尾插入:print(f"{args.MODEL_NAME} - 模型加载完毕")
这一段代码用于加载自定义的 State 文件,将其写入模型的初始化 State 中。
通常无需修改这部分代码。
这是控制 RWKV 模型使用 RNN 模式进行 prefill 的函数,这个函数会将 ctx(前文)切成长度为 CHUNK_LEN 的段落,一段段送入 RNN 处理,最后得到处理完前文后的 model_state 和 out 。
这个函数接收一个 ctx 参数,通常是文本(string)。然后依次对文本和文本转化的 token 进行了几项处理:
- 使用
replace
方法将文本的换行符统一为\n
,因为 RWKV 模型的训练数据集使用\n
作为标准换行符格式。 - 使用
pipeline.encode
方法,将用户的输入文本按照 RWKV-World 词表转换成对应的 token 。 - 将 tokens 转换为整数(int)列表,确保类型一致性
- 基于当前 token 前向传播,并行处理输入文本,更新模型状态并返回 out
注意,函数返回的 out
不是具体的 token 或文本,它返回的是模型对下一个 token 的原始预测(张量)。
要将 out
转换为实际的 token 或文本,需要通过采样(例如后文中的 pipeline.sample_logits
函数)预测下一个 token ,再从 token decode 成文本。
如果未加载任何 State 文件,则使用一段默认的对话文本进行 prefill 。
这一段是循环检测用户输入、并使用 RNN 模式进行推理,生成文本的功能代码。
以上代码的主要逻辑如下:
- 接收用户消息,规范空格空行,判断输入文本的内容长度
- 如果规范后用户输入为空,则提示“请说点什么”
- 如果规范后用户的输入非空,则进入步骤 2
- 将用户的输入拼接成聊天格式的 prompt ,然后进行 prefill ,获得 logits
- 预测 token ,并打印解码得到的文本字符
- 应用存在惩罚(GEN_alpha_presence)和频率惩罚(GEN_alpha_frequency)
- 基于 temperature 和 topp 参数对
out
进行采样,获得下一个 token - 使用新 token 前向传播,开启下一轮预测
- 应用惩罚衰减参数(penalty_decay)调整 token 生成的概率
- 把已经生成的 token 列表解码(decode)成字符文本
- 实时输出解码得到的字符文本,判断文本里面有没有 \n\n 停止词。如果出现停止词,则退出本轮推理。
从推理过程可以看出,模型在每个时间步都更新隐藏状态(State),并利用当前的隐藏状态来生成下一个时间步的输出。这符合 RNN 的核心特性: 模型的每次输出依赖于前一步的生成结果。
意见反馈(可选)
联系方式(可选)