RWKV 开发者教程
RWKV pip 库使用指南

以下内容将指引你使用 RWKV pip 库 (opens in a new tab)开发基于 RWKV 模型的应用

RWKV pip 库的原始代码可以在 ChatRWKV (opens in a new tab) 仓库中找到。

API_DEMO_CHAT.py 详解

ℹ️

API_DEMO_CHAT (opens in a new tab) 是一个基于 RWKV pip 库的开发 Demo,用于实现基于命令行的聊天机器人

下文将以详细的注释,分段介绍这个聊天机器人 DEMO 的代码设计。

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
 
print("RWKV Chat Simple Demo")  # 打印一个简单的消息,表明这是 RWKV 聊天的简单演示。
import os, copy, types, gc, sys, re  # 导入操作系统、对象复制、类型、垃圾回收、系统、正则表达式等包
import numpy as np  # 导入 numpy 库
from prompt_toolkit import prompt  # 从 prompt_toolkit 导入 prompt,用于命令行输入
import torch  # 导入 pytorch 库

这部分代码是导入一些使用 RWKV 模型推理时需要用到的包,需要注意以下两点:

  • torch 版本最低 1.13 ,推荐 2.x+cu121
  • 需要先 pip install rwkv

# 优化 PyTorch 设置,允许使用 tf32 
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
 
os.environ["RWKV_JIT_ON"] = "1" # 启用 JIT 编译
os.environ["RWKV_CUDA_ON"] = "0"  # 禁用原生 CUDA 算子,改成 '1' 表示启用 CUDA 算子(速度更快,但需要 c++ 编译器和 CUDA 库)

这里是一些加快推理速度的 torch 设置和操作环境的优化项。


from rwkv.model import RWKV  # 从 RWKV 模型库中导入 RWKV 类,用于加载和操作 RWKV 模型。
from rwkv.utils import PIPELINE  # 从 RWKV 工具库中导入 PIPELINE,用于数据的编码和解码
 
args = types.SimpleNamespace()
 
args.strategy = "cuda fp16"  # 模型推理的设备和精度,使用 CUDA (GPU)并采用 FP16 精度
args.MODEL_NAME = "E://RWKV-Runner//models//rwkv-final-v6-2.1-1b6"  # 指定 RWKV 模型的路径,建议写绝对路径

这一段引入了 RWKV 工具包中的两个工具类 RWKV 和 PIPELINE ,同时指定了加载 RWKV 模型的设备精度,以及 RWKV 模型的本地文件路径。

args.strategy 会影响模型的生成效果和生成速度,chatRWKV 支持下表中的 strategy:

ℹ️

下表中 ,fp16i8 指在 fp16 精度基础上进行 int8 量化。

量化可以减少 VRAM 需求,但在精度上略逊于 fp16。因此只要 VRAM 够用,尽量使用 fp16 层。

策略VRAM & RAM速度
cpu fp327B 模型需要 32GB 内存使用 CPU fp32 精度加载模型,适合 Intel。对 AMD 非常慢,因为 pytorch 的 cpu gemv 在 AMD 上有问题,并且只会运行在一个单核上。
cpu bf167B 模型需要 16GB 内存使用 CPU bf16 精度加载模型。在支持 bfloat16 的新 Intel CPU(如 Xeon Platinum)上速度较快。
cpu fp32i87B 模型需要 12GB 内存使用 CPU int8 量化精度加载模型。速度较慢(比 cpu fp32 更慢)。
cuda fp167B 模型需要 15GB VRAM使用 fp16 精度加载模型所有层,速度最快,但对显存(VRAM)的需求也最高。
cuda fp16i87B 模型需要 9GB VRAM使用 int8 量化模型所有层,速度较快。如果设置 os.environ["RWKV_CUDA_ON"] = '1' 来编译 CUDA 内核,可减少 1~2GB VRAM 使用。
cuda fp16i8 *20 -> cuda fp16VRAM 占用介于 fp16 和 fp16i8 之间将模型的前 20 层(*20 指层数)量化为 fp16i8,其余层使用 fp16 加载。 如果量化后还有较多 VRAM ,则酌情减少 fp16i8 层数(减少 20)。 如果 VRAM 不足则继续增加 fp16i8 量化层数
cuda fp16i8 *20+比 fp16i8 使用更少 VRAM将模型的前 20 层(*20 指层数)量化为 fp16i8 并固定在 GPU 上,其他层按需动态加载(未固定的层加载速度会慢 3 倍,但节省 VRAM)。 如果 VRAM 不足,减少固定层数(*20)。 如果 VRAM 充足,增加固定层数。
cuda fp16i8 *20 -> cpu fp32比 fp16i8 使用更少 VRAM,但消耗更多内存将模型的前 20 层(*20)量化为 fp16i8 并固定在 GPU 上,其他层使用 CPU fp32 加载。当 CPU 性能比较强时,此策略比上一个策略(只在 GPU 上固定 20 层)更快。 如果加载 20 层还有剩余 VRAM ,则继续增加 GPU 层数。 如果没有足够 VRAM,减少 GPU 层数。
cuda:0 fp16 *20 -> cuda:1 fp16使用双卡驱动模型使用 cuda:0(卡1) fp16 加载模型的前 20 层,然后使用 cuda:1(卡2) fp16 加载剩余的层(自动计算剩余层数)。 建议在最快的 GPU 上运行更多层。 如果某张卡的 VRAM 不够,可以将 fp16 换成 fp16i8 (int8 量化)。

# STATE_NAME = None # 不使用 State
 
# 指定要加载的 State 文件路径。
STATE_NAME = "E://RWKV-Runner//models//rwkv-x060-eng_single_round_qa-1B6-20240516-ctx2048"  # 指定要加载的自定义 State 文件路径。

这一段决定是否要加载一个 State 文件,"None" 表示不加载自定义 State ,如需加载请填写 State 文件的绝对路径。

ℹ️

State 是 RWKV 这类 RNN 模型特有的状态。通过搭载自定义的 State 文件,可以强化 RWKV 模型在不同任务上的表现。(类似于增强插件)

RWKV State 的介绍和用法可以参照 State 文件介绍和用法 (opens in a new tab)文章。


# 设置模型的解码参数
GEN_TEMP = 1.0
GEN_TOP_P = 0.3
GEN_alpha_presence = 0.5
GEN_alpha_frequency = 0.5
GEN_penalty_decay = 0.996
 
# 判断是否加载了一个 State 文件。如果指定了 State ,则调整生成参数,使回答的效果更好。
if STATE_NAME != None:
    GEN_TOP_P = 0.2
    GEN_alpha_presence = 0.3
    GEN_alpha_frequency = 0.3
 
CHUNK_LEN = 256  # 对输入进行分块处理

这里主要是设置加载或不加载 State 时, RWKV 模型分别使用哪些解码参数。

ℹ️

有关 RWKV 解码参数的含义和作用,请查看RWKV 解码参数文档 (opens in a new tab)

指定一个自定义 State 文件后,我们希望模型能更好地遵循 State 中的格式和风格,所以调低了 topp 参数和惩罚参数

CHUNK_LEN 将输入文本切分成指定大小的块。这个数值越大,模型并行处理的文本越多,但使用的显存也更多。在显存不足时建议调整到 128 或者 64。


print(f"Loading model - {args.MODEL_NAME}")# 打印模型的加载消息
model = RWKV(model=args.MODEL_NAME, strategy=args.strategy)  # 加载 RWKV 模型。
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")  # 初始化 PIPELINE ,使用 RWKV-World 词表处理输入和输出的编码/解码。

这一段开始使用前面设置的strategy解码参数加载 RWKV 模型。

如果你希望模型加载完后也有提示,可以在这一段末尾插入:print(f"{args.MODEL_NAME} - 模型加载完毕")


model_tokens = []
model_state = None
 
# 如果指定了 STATE_NAME,则加载自定义 State 文件,并初始化模型 State
if STATE_NAME != None:
    args = model.args  # 获取模型参数
    state_raw = torch.load(STATE_NAME + '.pth')  # 从指定的 State 文件中加载 State 数据
    state_init = [None for i in range(args.n_layer * 3)]  # 初始化状态列表
    for i in range(args.n_layer): #开始循环,遍历每一层。
        dd = model.strategy[i]  # 获取模型每一层的加载策略
        dev = dd.device  # 获取每一层的加载设备(如 GPU)
        atype = dd.atype  # 获取每一层的数据类型(FP32/FP16 或 int8 等)
        # 初始化模型的状态
        state_init[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
        state_init[i*3+1] = state_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
        state_init[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
    model_state = copy.deepcopy(state_init)  # 复制初始化的状态
 

这一段代码用于加载自定义的 State 文件,将其写入模型的初始化 State 中。

通常无需修改这部分代码。


def run_rnn(ctx):
    # 定义两个全局变量,用于更新 token 和模型状态(state)
    global model_tokens, model_state
    ctx = ctx.replace("\r\n", "\n")  # 将文本中的 CRLF(Windows 系统的换行符)转换为 LF(Linux 系统的换行符)
    tokens = pipeline.encode(ctx)  # 基于 RWKV 模型的词汇表,将文本编码为 tokens
    tokens = [int(x) for x in tokens]  # 将 tokens 转换为整数(int)列表,确保类型一致性
    model_tokens += tokens  # 将 tokens 添加到全局的模型 token 列表中
 
    while len(tokens) > 0:  # 使用一个 while 循环执行模型前向传播,直到所有 tokens 处理完毕
        out, model_state = model.forward(tokens[:CHUNK_LEN], model_state)  # 模型前向传播,处理大小为 CHUNK_LEN 的 token 列表,并更新模型状态
        tokens = tokens[CHUNK_LEN:]  # 移除已处理的 tokens 块,并继续处理剩余的 tokens
 
    return out  # 返回模型的 prefill 结果

这是控制 RWKV 模型使用 RNN 模式进行 prefill 的函数,这个函数会将 ctx(前文)切成长度为 CHUNK_LEN 的段落,一段段送入 RNN 处理,最后得到处理完前文后的 model_state 和 out 。

这个函数接收一个 ctx 参数,通常是文本(string)。然后依次对文本和文本转化的 token 进行了几项处理:

  1. 使用 replace 方法将文本的换行符统一为\n ,因为 RWKV 模型的训练数据集使用 \n 作为标准换行符格式。
  2. 使用 pipeline.encode 方法,将用户的输入文本按照 RWKV-World 词表转换成对应的 token 。
  3. 将 tokens 转换为整数(int)列表,确保类型一致性
  4. 基于当前 token 前向传播,并行处理输入文本,更新模型状态并返回 out
⚠️

注意,函数返回的 out 不是具体的 token 或文本,它返回的是模型对下一个 token 的原始预测(张量)。

要将 out 转换为实际的 token 或文本,需要通过采样(例如后文中的 pipeline.sample_logits 函数)预测下一个 token ,再从 token decode 成文本


# 如果没有加载自定义 State ,则使用初始提示进行对话
if STATE_NAME == None:
    init_ctx = "User: hi" + "\n\n"
    init_ctx += "Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it." + "\n\n"
    run_rnn(init_ctx)  # 运行 RNN 模式对初始提示文本进行 prefill
    print(init_ctx, end="")  # 打印初始化对话文本

如果未加载任何 State 文件,则使用一段默认的对话文本进行 prefill 。


# 从用户输入中读取消息、循环生成下一个 token 
while True:
    msg = prompt("User: ")  # 从用户输入中读取消息,存到 msg 变量
    msg = msg.strip()  # 使用 strip 方法去除消息的首尾空格
    msg = re.sub(r"\n+", "\n", msg)  # 替换多个换行符为单个换行符
    if len(msg) > 0:  # 如果处理完后,用户输入的消息非空
        occurrence = {}  # 使用 occurrence 字典这个字典用于记录每个 token 在生成上下文中出现的次数,等会用在实现重复惩罚(Penalty)
        out_tokens = []  # 使用 out_tokens 列表记录即将输出的 tokens
        out_last = 0  # 用于记录上一次生成的 token 位置
 
        out = run_rnn("User: " + msg + "\n\nAssistant:")  # 将用户输入拼接成 RWKV 数据集的对话格式,进行 prefill  
        print("\nAssistant:", end="")  # 打印 "Assistant:" 标签
 
        for i in range(99999):  
            for n in occurrence: 
                out[n] -= GEN_alpha_presence + occurrence[n] * GEN_alpha_frequency  # 应用存在惩罚和频率惩罚参数
            out[0] -= 1e10  # 禁用 END_OF_TEXT 
 
            token = pipeline.sample_logits(out, temperature=GEN_TEMP, top_p=GEN_TOP_P)  # 采样生成下一个 token
 
            out, model_state = model.forward([token], model_state)  # 模型前向传播
            model_tokens += [token] 
            out_tokens += [token]  # 将新生成的 token 添加到输出的 token 列表中
 
            for xxx in occurrence:
                occurrence[xxx] *= GEN_penalty_decay  # 应用衰减重复惩罚
            occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)  # 更新 token 的出现次数
 
            tmp = pipeline.decode(out_tokens[out_last:])  # 将最新生成的 token 解码成文本
            if ("\ufffd" not in tmp) and (not tmp.endswith("\n")):  # 当生成的文本是有效 UTF-8 字符串且不以换行符结尾时
                print(tmp, end="", flush=True) #实时打印解码得到的文本
                out_last = i + 1 #更新输出位置变量 out_last 
 
            if "\n\n" in tmp:  # 如果生成的文本包含双换行符,表示模型的响应已结束(可以将 \n\n 改成其他停止词)
                print(tmp, end="", flush=True) # 实时打印解码得到的文本
                break #结束本轮推理
    else:
        print("!!! Error: please say something !!!")  # 如果用户没有输入消息,提示“输入错误,说点啥吧!”

这一段是循环检测用户输入、并使用 RNN 模式进行推理,生成文本的功能代码。

以上代码的主要逻辑如下:

  1. 接收用户消息,规范空格空行,判断输入文本的内容长度
    • 如果规范后用户输入为空,则提示“请说点什么”
    • 如果规范后用户的输入非空,则进入步骤 2
  2. 将用户的输入拼接成聊天格式的 prompt ,然后进行 prefill ,获得 logits
  3. 预测 token ,并打印解码得到的文本字符
    • 应用存在惩罚(GEN_alpha_presence)和频率惩罚(GEN_alpha_frequency)
    • 基于 temperature 和 topp 参数对 out 进行采样,获得下一个 token
    • 使用新 token 前向传播,开启下一轮预测
    • 应用惩罚衰减参数(penalty_decay)调整 token 生成的概率
    • 把已经生成的 token 列表解码(decode)成字符文本
    • 实时输出解码得到的字符文本,判断文本里面有没有 \n\n 停止词。如果出现停止词,则退出本轮推理。

从推理过程可以看出,模型在每个时间步都更新隐藏状态(State),并利用当前的隐藏状态来生成下一个时间步的输出。这符合 RNN 的核心特性: 模型的每次输出依赖于前一步的生成结果