以下内容将指引你使用 RWKV pip 库 (opens in a new tab)开发基于 RWKV 模型的应用。
RWKV pip 库的原始代码可以在 ChatRWKV (opens in a new tab) 仓库中找到。
API_DEMO_CHAT.py 详解
API_DEMO_CHAT (opens in a new tab) 是一个基于 RWKV pip 库的开发 Demo,用于实现基于命令行的聊天机器人。
下文将以详细的注释,分段介绍这个聊天机器人 DEMO 的代码设计。
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
print("RWKV Chat Simple Demo") # 打印一个简单的消息,表明这是 RWKV 聊天的简单演示。
import os, copy, types, gc, sys, re # 导入操作系统、对象复制、类型、垃圾回收、系统、正则表达式等包
import numpy as np # 导入 numpy 库
from prompt_toolkit import prompt # 从 prompt_toolkit 导入 prompt,用于命令行输入
import torch # 导入 pytorch 库
这部分代码是导入一些使用 RWKV 模型推理时需要用到的包,需要注意以下两点:
- torch 版本最低 1.13 ,推荐 2.x+cu121
- 需要先
pip install rwkv
# 优化 PyTorch 设置,允许使用 tf32
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
os.environ["RWKV_JIT_ON"] = "1" # 启用 JIT 编译
os.environ["RWKV_CUDA_ON"] = "0" # 禁用原生 CUDA 算子,改成 '1' 表示启用 CUDA 算子(速度更快,但需要 c++ 编译器和 CUDA 库)
这里是一些加快推理速度的 torch 设置和操作环境的优化项。
from rwkv.model import RWKV # 从 RWKV 模型库中导入 RWKV 类,用于加载和操作 RWKV 模型。
from rwkv.utils import PIPELINE # 从 RWKV 工具库中导入 PIPELINE,用于数据的编码和解码
args = types.SimpleNamespace()
args.strategy = "cuda fp16" # 模型推理的设备和精度,使用 CUDA (GPU)并采用 FP16 精度
args.MODEL_NAME = "E://RWKV-Runner//models//rwkv-final-v6-2.1-1b6" # 指定 RWKV 模型的路径,建议写绝对路径
这一段引入了 RWKV 工具包中的两个工具类 RWKV 和 PIPELINE ,同时指定了加载 RWKV 模型的设备和精度,以及 RWKV 模型的本地文件路径。
args.strategy
会影响模型的生成效果和生成速度,chatRWKV 支持下表中的 strategy:
下表中 ,fp16i8
指在 fp16 精度基础上进行 int8 量化。
量化可以减少 VRAM 需求,但在精度上略逊于 fp16。因此只要 VRAM 够用,尽量使用 fp16 层。
策略 | VRAM & RAM | 速度 |
---|---|---|
cpu fp32 | 7B 模型需要 32GB 内存 | 使用 CPU fp32 精度加载模型,适合 Intel。对 AMD 非常慢,因为 pytorch 的 cpu gemv 在 AMD 上有问题,并且只会运行在一个单核上。 |
cpu bf16 | 7B 模型需要 16GB 内存 | 使用 CPU bf16 精度加载模型。在支持 bfloat16 的新 Intel CPU(如 Xeon Platinum)上速度较快。 |
cpu fp32i8 | 7B 模型需要 12GB 内存 | 使用 CPU int8 量化精度加载模型。速度较慢(比 cpu fp32 更慢)。 |
cuda fp16 | 7B 模型需要 15GB VRAM | 使用 fp16 精度加载模型所有层,速度最快,但对显存(VRAM)的需求也最高。 |
cuda fp16i8 | 7B 模型需要 9GB VRAM | 使用 int8 量化模型所有层,速度较快。如果设置 os.environ["RWKV_CUDA_ON"] = '1' 来编译 CUDA 内核,可减少 1~2GB VRAM 使用。 |
cuda fp16i8 *20 -> cuda fp16 | VRAM 占用介于 fp16 和 fp16i8 之间 | 将模型的前 20 层(*20 指层数)量化为 fp16i8,其余层使用 fp16 加载。 如果量化后还有较多 VRAM ,则酌情减少 fp16i8 层数(减少 20)。 如果 VRAM 不足则继续增加 fp16i8 量化层数 |
cuda fp16i8 *20+ | 比 fp16i8 使用更少 VRAM | 将模型的前 20 层(*20 指层数)量化为 fp16i8 并固定在 GPU 上,其他层按需动态加载(未固定的层加载速度会慢 3 倍,但节省 VRAM)。 如果 VRAM 不足,减少固定层数(*20 )。 如果 VRAM 充足,增加固定层数。 |
cuda fp16i8 *20 -> cpu fp32 | 比 fp16i8 使用更少 VRAM,但消耗更多内存 | 将模型的前 20 层(*20 )量化为 fp16i8 并固定在 GPU 上,其他层使用 CPU fp32 加载。当 CPU 性能比较强时,此策略比上一个策略(只在 GPU 上固定 20 层)更快。 如果加载 20 层还有剩余 VRAM ,则继续增加 GPU 层数。 如果没有足够 VRAM,减少 GPU 层数。 |
cuda:0 fp16 *20 -> cuda:1 fp16 | 使用双卡驱动模型 | 使用 cuda:0(卡1) fp16 加载模型的前 20 层,然后使用 cuda:1(卡2) fp16 加载剩余的层(自动计算剩余层数)。 建议在最快的 GPU 上运行更多层。 如果某张卡的 VRAM 不够,可以将 fp16 换成 fp16i8 (int8 量化)。 |
# STATE_NAME = None # 不使用 State
# 指定要加载的 State 文件路径。
STATE_NAME = "E://RWKV-Runner//models//rwkv-x060-eng_single_round_qa-1B6-20240516-ctx2048" # 指定要加载的自定义 State 文件路径。
这一段决定是否要加载一个 State 文件,"None"
表示不加载自定义 State ,如需加载请填写 State 文件的绝对路径。
State 是 RWKV 这类 RNN 模型特有的状态。通过搭载自定义的 State 文件,可以强化 RWKV 模型在不同任务上的表现。(类似于增强插件)
RWKV State 的介绍和用法可以参照 State 文件介绍和用法 (opens in a new tab)文章。
# 设置模型的解码参数
GEN_TEMP = 1.0
GEN_TOP_P = 0.3
GEN_alpha_presence = 0.5
GEN_alpha_frequency = 0.5
GEN_penalty_decay = 0.996
# 判断是否加载了一个 State 文件。如果指定了 State ,则调整生成参数,使回答的效果更好。
if STATE_NAME != None:
GEN_TOP_P = 0.2
GEN_alpha_presence = 0.3
GEN_alpha_frequency = 0.3
CHUNK_LEN = 256 # 对输入进行分块处理
这里主要是设置加载或不加载 State 时, RWKV 模型分别使用哪些解码参数。
有关 RWKV 解码参数的含义和作用,请查看RWKV 解码参数文档 (opens in a new tab)。
指定一个自定义 State 文件后,我们希望模型能更好地遵循 State 中的格式和风格,所以调低了 topp 参数和惩罚参数。
CHUNK_LEN
将输入文本切分成指定大小的块。这个数值越大,模型并行处理的文本越多,但使用的显存也更多。在显存不足时建议调整到 128 或者 64。
print(f"Loading model - {args.MODEL_NAME}")# 打印模型的加载消息
model = RWKV(model=args.MODEL_NAME, strategy=args.strategy) # 加载 RWKV 模型。
pipeline = PIPELINE(model, "rwkv_vocab_v20230424") # 初始化 PIPELINE ,使用 RWKV-World 词表处理输入和输出的编码/解码。
这一段开始使用前面设置的strategy和解码参数加载 RWKV 模型。
如果你希望模型加载完后也有提示,可以在这一段末尾插入:print(f"{args.MODEL_NAME} - 模型加载完毕")
model_tokens = []
model_state = None
# 如果指定了 STATE_NAME,则加载自定义 State 文件,并初始化模型 State
if STATE_NAME != None:
args = model.args # 获取模型参数
state_raw = torch.load(STATE_NAME + '.pth') # 从指定的 State 文件中加载 State 数据
state_init = [None for i in range(args.n_layer * 3)] # 初始化状态列表
for i in range(args.n_layer): #开始循环,遍历每一层。
dd = model.strategy[i] # 获取模型每一层的加载策略
dev = dd.device # 获取每一层的加载设备(如 GPU)
atype = dd.atype # 获取每一层的数据类型(FP32/FP16 或 int8 等)
# 初始化模型的状态
state_init[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
state_init[i*3+1] = state_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
state_init[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
model_state = copy.deepcopy(state_init) # 复制初始化的状态
这一段代码用于加载自定义的 State 文件,将其写入模型的初始化 State 中。
通常无需修改这部分代码。
def run_rnn(ctx):
# 定义两个全局变量,用于更新 token 和模型状态(state)
global model_tokens, model_state
ctx = ctx.replace("\r\n", "\n") # 将文本中的 CRLF(Windows 系统的换行符)转换为 LF(Linux 系统的换行符)
tokens = pipeline.encode(ctx) # 基于 RWKV 模型的词汇表,将文本编码为 tokens
tokens = [int(x) for x in tokens] # 将 tokens 转换为整数(int)列表,确保类型一致性
model_tokens += tokens # 将 tokens 添加到全局的模型 token 列表中
while len(tokens) > 0: # 使用一个 while 循环执行模型前向传播,直到所有 tokens 处理完毕
out, model_state = model.forward(tokens[:CHUNK_LEN], model_state) # 模型前向传播,处理大小为 CHUNK_LEN 的 token 列表,并更新模型状态
tokens = tokens[CHUNK_LEN:] # 移除已处理的 tokens 块,并继续处理剩余的 tokens
return out # 返回模型的 prefill 结果
这是控制 RWKV 模型使用 RNN 模式进行 prefill 的函数,这个函数会将 ctx(前文)切成长度为 CHUNK_LEN 的段落,一段段送入 RNN 处理,最后得到处理完前文后的 model_state 和 out 。
这个函数接收一个 ctx 参数,通常是文本(string)。然后依次对文本和文本转化的 token 进行了几项处理:
- 使用
replace
方法将文本的换行符统一为\n
,因为 RWKV 模型的训练数据集使用\n
作为标准换行符格式。 - 使用
pipeline.encode
方法,将用户的输入文本按照 RWKV-World 词表转换成对应的 token 。 - 将 tokens 转换为整数(int)列表,确保类型一致性
- 基于当前 token 前向传播,并行处理输入文本,更新模型状态并返回 out
注意,函数返回的 out
不是具体的 token 或文本,它返回的是模型对下一个 token 的原始预测(张量)。
要将 out
转换为实际的 token 或文本,需要通过采样(例如后文中的 pipeline.sample_logits
函数)预测下一个 token ,再从 token decode 成文本。
# 如果没有加载自定义 State ,则使用初始提示进行对话
if STATE_NAME == None:
init_ctx = "User: hi" + "\n\n"
init_ctx += "Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it." + "\n\n"
run_rnn(init_ctx) # 运行 RNN 模式对初始提示文本进行 prefill
print(init_ctx, end="") # 打印初始化对话文本
如果未加载任何 State 文件,则使用一段默认的对话文本进行 prefill 。
# 从用户输入中读取消息、循环生成下一个 token
while True:
msg = prompt("User: ") # 从用户输入中读取消息,存到 msg 变量
msg = msg.strip() # 使用 strip 方法去除消息的首尾空格
msg = re.sub(r"\n+", "\n", msg) # 替换多个换行符为单个换行符
if len(msg) > 0: # 如果处理完后,用户输入的消息非空
occurrence = {} # 使用 occurrence 字典这个字典用于记录每个 token 在生成上下文中出现的次数,等会用在实现重复惩罚(Penalty)
out_tokens = [] # 使用 out_tokens 列表记录即将输出的 tokens
out_last = 0 # 用于记录上一次生成的 token 位置
out = run_rnn("User: " + msg + "\n\nAssistant:") # 将用户输入拼接成 RWKV 数据集的对话格式,进行 prefill
print("\nAssistant:", end="") # 打印 "Assistant:" 标签
for i in range(99999):
for n in occurrence:
out[n] -= GEN_alpha_presence + occurrence[n] * GEN_alpha_frequency # 应用存在惩罚和频率惩罚参数
out[0] -= 1e10 # 禁用 END_OF_TEXT
token = pipeline.sample_logits(out, temperature=GEN_TEMP, top_p=GEN_TOP_P) # 采样生成下一个 token
out, model_state = model.forward([token], model_state) # 模型前向传播
model_tokens += [token]
out_tokens += [token] # 将新生成的 token 添加到输出的 token 列表中
for xxx in occurrence:
occurrence[xxx] *= GEN_penalty_decay # 应用衰减重复惩罚
occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0) # 更新 token 的出现次数
tmp = pipeline.decode(out_tokens[out_last:]) # 将最新生成的 token 解码成文本
if ("\ufffd" not in tmp) and (not tmp.endswith("\n")): # 当生成的文本是有效 UTF-8 字符串且不以换行符结尾时
print(tmp, end="", flush=True) #实时打印解码得到的文本
out_last = i + 1 #更新输出位置变量 out_last
if "\n\n" in tmp: # 如果生成的文本包含双换行符,表示模型的响应已结束(可以将 \n\n 改成其他停止词)
print(tmp, end="", flush=True) # 实时打印解码得到的文本
break #结束本轮推理
else:
print("!!! Error: please say something !!!") # 如果用户没有输入消息,提示“输入错误,说点啥吧!”
这一段是循环检测用户输入、并使用 RNN 模式进行推理,生成文本的功能代码。
以上代码的主要逻辑如下:
- 接收用户消息,规范空格空行,判断输入文本的内容长度
- 如果规范后用户输入为空,则提示“请说点什么”
- 如果规范后用户的输入非空,则进入步骤 2
- 将用户的输入拼接成聊天格式的 prompt ,然后进行 prefill ,获得 logits
- 预测 token ,并打印解码得到的文本字符
- 应用存在惩罚(GEN_alpha_presence)和频率惩罚(GEN_alpha_frequency)
- 基于 temperature 和 topp 参数对
out
进行采样,获得下一个 token - 使用新 token 前向传播,开启下一轮预测
- 应用惩罚衰减参数(penalty_decay)调整 token 生成的概率
- 把已经生成的 token 列表解码(decode)成字符文本
- 实时输出解码得到的字符文本,判断文本里面有没有 \n\n 停止词。如果出现停止词,则退出本轮推理。
从推理过程可以看出,模型在每个时间步都更新隐藏状态(State),并利用当前的隐藏状态来生成下一个时间步的输出。这符合 RNN 的核心特性: 模型的每次输出依赖于前一步的生成结果。