llama.cpp 推理方法
ℹ️

Llama.cpp 是一个轻量化的大语言模型运行框架,专门优化了在 CPU 上运行模型的性能。

随着 RWKV 社区成员 @MollySophia (opens in a new tab) 的工作,llama.cpp 现已适配 RWKV-6 模型。

本章节介绍如何在 llama.cpp 中使用 RWKV-6 模型进行推理。

llama.cpp 推理 RWKV 模型

本地构建 llama.cpp

可以选择从 llama.cpp 的 release 页面 (opens in a new tab)下载已编译的 llama.cpp 程序。

也可以参照 llama.cpp 官方构建文档 (opens in a new tab),选择适合的方法本地编译构建。

获取 gguf 格式模型

llama.cpp 支持 .gguf 格式的模型,但 RWKV 官方仅发布了 .pth 格式模型。因此,我们需要使用以下两种方法之一获取 .gguf 格式的 RWKV 模型。

可以从 RWKV - GGUF 仓库 (opens in a new tab) 下载已量化并转化成 gguf 格式的 RWKV 模型。

⚠️

建议使用 Q5_1Q8_0 两种量化精度的 gguf 模型,更低的量化可能会使模型的回答变得很差。

运行 RWKV 模型推理

在 llama.cpp 目录运行以下命令,可驱动 RWKV 模型基于 prompt 生成文本:

./build/bin/llama-cli -m ./v6-Finch-1B6-HF/v6-Finch-1.6B-HF-F16.gguf --no-warmup -p "User: Write me a poem\n\nAssistant:" -t 8 -ngl 25 -n 500

这条命令通过 llama-cli 运行 RWKV 模型,使用 8 个线程、跳过预热、并根据给定的 prompt 生成最多 500 个 token。

RWKV 模型推理

参数解释:

  • ./build/bin/llama-cli:编译好的 llama-cli 程序,打开命令化界面
  • -m ./v6-Finch-1B6-HF/v6-Finch-1.6B-HF-F16.gguf:模型的路径参数
  • --no-warmup:跳过“预热”阶段,直接开始生成文本(以少量性能换取速度)
  • -p "User: Write me a poem\n\nAssistant:": prompt 参数,模型根据该提示词生成文本。"User: Write me a poem\n\nAssistant:" 是符合 RWKV 模型格式的 prompt,更多 RWKV prompt 格式请在RWKV文档-提示词指南 (opens in a new tab)中查看。
  • -t 8:-t 指定线程数,建议根据可用的物理 CPU 核心数调整
  • - ngl:指定模型使用的 n-gpu-layers ,25 是在 GPU 上运行 25 层(1.6B 的 24层 + head 算一层)。可以无脑设定 -ngl 99,使 llama.cpp 加载 RWKV 模型所有层
  • -n 500:-n 参数表示生成的最大 token 数
ℹ️

完整的参数列表可以在 llama.cpp 参数文档 (opens in a new tab)中查看。

附加功能(可选)

批量推理生成

使用以下命令,以进行批量推理:

ℹ️

使用 \n 隔开不同的 prompt

./build/bin/llama-parallel -ns 4 -np 4 -m v6-Finch-1B6-HF/v6-Finch-1.6B-HF-F16.gguf --no-warmup -ngl 25 -n 500 -p "Who are you?\nWhat do we have for dinner?\nWhat's the meaning of life\nHello\nWhat is the end of the universe?"

批量推理生成

参数解释:

  • -ns 4: n_sequence,推理序列的数量
  • -np 4: n_parallel,并行推理的数量

gguf 模型量化方法

运行以下命令,对 .gguf 模型进行量化:

./build-cuda-rel/bin/llama-quantize v6-Finch-1B6-HF/v6-Finch-1.6B-HF-F16.gguf(量化前的 gguf 模型路径) ./v6-Finch-1B6-HF/v6-Finch-1.6B-HF-Q5_1.gguf(量化后的 gguf 模型路径) Q5_1(量化精度)

所有可选的量化精度:

可选的量化精度

ℹ️

建议使用 Q5_1Q8_0 两种量化精度。

启动 Web 服务

使用以下命令,以启动 llama.cpp 的 Web 服务:

./build/bin/llama-server -m v6-Finch-1B6-HF/v6-Finch-1.6B-HF-F16.gguf --no-warmup -ngl 25

llama.cpp 的 Web 服务

启动后,可以访问 http://127.0.0.1:8080 以检查 Web 页面:

WebUI

chatUI

点击右上方的 New Ul按钮,或者直接访问 http://127.0.0.1:8080/index-new.html ,可以打开新版本的 WebUI

新版本的 WebUI

新版本的 chatUI

搭载 RWKV State 文件

ℹ️

State 是 RWKV 这类 RNN 模型特有的状态。通过搭载自定义的 State 文件,可以强化 RWKV 模型在不同任务上的表现。(类似于增强插件)

RWKV State 的介绍可参照 State 文件介绍和用法 (opens in a new tab)文章。

遵循以下步骤,可在 llama.cpp 推理 RWKV 模型时搭载 State 文件:

下载 .pth 格式 State 文件

从以下链接下载 RWKV State 文件:

使用脚本转换 State 文件

由于 llama.cpp for RWKV 只接受 .bin 格式的 State 文件,因此我们需要使用附录中的 convert_rwkv_state_to_llamacpp.py 脚本,将 .pth 格式的 State 文件转换为 .bin 格式。

在脚本目录,运行以下转换命令:

python convert_rwkv_state_to_llamacpp.py rwkv-x060-chn_single_round_qa-3B-20240516-ctx2048.pth rwkv-x060-chn_single_round_qa-3B-20240516-ctx2048-llamacpp.bin

参数解释:

  • rwkv-x060-chn_single_round_qa-3B-20240516-ctx2048.pth:需要转换的 .pth State 文件路径
  • rwkv-x060-chn_single_round_qa-3B-20240516-ctx2048-llamacpp.bin:转换后的 .bin State 文件路径

推理时搭载 State 文件

在 llama.cpp 推理 RWKV 模型时,添加额外的-prompt-cache--prompt-cache-ro 参数,以搭载 RWKV State :

./build/bin/llama-cli -m ../models/rwkv-6-world-3b-v2.1/rwkv-6-world-3B-v2.1-F16.gguf --no-warmup -p "User: Hello\n\nAssistant:" --prompt-cache rwkv-x060-chn_single_round_qa-3B-20240516-ctx2048-llamacpp.bin --prompt-cache-ro`

参数解释:

  • --prompt-cache :加载 prompt cache 文件 ,填写 State 文件路径,如 rwkv-x060-chn_single_round_qa-3B-20240516-ctx2048-llamacpp.bin
  • --prompt-cache-ro : 以只读方模式加载,生成时不要覆盖原 State 文件

附录代码

.pth 模型转 HF 格式

以下代码可将 .pth 模型转化为 HF 格式模型(.bin):

# Convert the model for the pytoch_model.bin
import torch
 
SOURCE_MODEL="./v6-FinchX-14B-pth/rwkv-14b-final.pth"
TARGET_MODEL="./v6-Finch-14B-HF/pytorch_model.bin"
 
# delete target model
import os
if os.path.exists(TARGET_MODEL):
    os.remove(TARGET_MODEL)
 
model = torch.load(SOURCE_MODEL, mmap=True, map_location='cpu')
# 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
# Rename all the keys, to include "rwkv."
new_model = {}
for key in model.keys():
 
    # If the keys start with "blocks"
    if key.startswith("blocks."):
        new_key = "rwkv." + key
        # Replace .att. with .attention.
        new_key = new_key.replace(".att.", ".attention.")
        # Replace .ffn. with .feed_forward.
        new_key = new_key.replace(".ffn.", ".feed_forward.")
        # Replace `0.ln0.` with `0.pre_ln.`
        new_key = new_key.replace("0.ln0.", "0.pre_ln.")
    else:
        # No rename needed
        new_key = key
 
        # Rename `emb.weight` to `rwkv.embeddings.weight`
        if key == "emb.weight":
            new_key = "rwkv.embeddings.weight"
 
        # Rename the `ln_out.x` to `rwkv.ln_out.x
        if key.startswith("ln_out."):
            new_key = "rwkv." + key
 
    print("Renaming key:", key, "--to-->", new_key)
    new_model[new_key] = model[key]
 
# Save the new model
print("Saving the new model to:", TARGET_MODEL)
torch.save(new_model, TARGET_MODEL)

.pth State 文件转成 .bin

以下代码可将 RWKV 官方发布的.pth State 文件,转换为 llama.cpp 可使用的 .bin格式 State 文件:

convert_rwkv_state_to_llamacpp.py
import torch
import argparse
import struct
import array
from pathlib import Path
 
mt19937_state = "0 1 1812433255 1900727105 1208447044 2481403966 4042607538 337614300 3232553940 1018809052 3202401494 1775180719 3192392114 594215549 184016991 829906058 610491522 3879932251 3139825610 297902587 4075895579 2943625357 3530655617 1423771745 2135928312 2891506774 1066338622 135451537 933040465 2759011858 2273819758 3545703099 2516396728 1272276355 3172048492 3267256201 2332199830 1975469449 392443598 1132453229 2900699076 1998300999 3847713992 512669506 1227792182 1629110240 112303347 2142631694 3647635483 1715036585 2508091258 1355887243 1884998310 3906360088 952450269 3647883368 3962623343 3077504981 2023096077 3791588343 3937487744 3455116780 1218485897 1374508007 2815569918 1367263917 472908318 2263147545 1461547499 4126813079 2383504810 64750479 2963140275 1709368606 4143643781 835933993 1881494649 674663333 2076403047 858036109 1667579889 1706666497 607785554 1995775149 1941986352 3448871082 2109910019 1474883361 1623095288 1831376534 2612738285 81681830 2204289242 1365038485 251164610 4268495337 1805601714 1262528768 1442526919 1675006593 965627108 646339161 499795587 840887574 380522518 3023789847 1457635507 1947093157 2600365344 2729853143 1550618999 1390905853 3021294812 882647559 838872117 1663880796 4222103589 2754172275 3844026123 3199260319 4176064873 3591027019 2690294242 2978135515 3172796441 3263669796 1451257057 1427035359 4174826006 2171992010 1537002090 3122405306 4162452508 3271954368 3794310005 3240514581 1270412086 3030475836 2281945856 2644171349 3109139423 4253563838 1289926431 1396919653 733220100 2753316645 1196225013 3699575255 3569440056 2675979228 2624079148 3463113149 863430286 623703199 2113837653 2656425919 175981357 4271478366 4238022735 1665483419 86880610 2963435083 1830392943 847801865 3237296945 332143967 3973606945 2671879697 2236330279 2360127810 3283955434 203240344 4048139172 13189264 2263058814 247241371 1566765783 3084408095 3719371299 1958375251 1985924622 1712739232 1861691451 2644502937 2337807839 784993770 2962208780 2190810177 1523122731 714888527 578678761 3698481324 1801168075 534650483 3390213921 3923356461 3586009066 2059432114 52511333 1969897376 3630122061 524661135 3513619765 563070233 501359785 477489274 658768624 938973567 1548584683 1345287459 2488691004 3441144905 3849305094 2430000078 855172178 614463281 2092744749 176381493 1655802051 2273888101 2474494847 3471978030 2138918303 575352373 1658230985 1675972553 2946663114 915579339 284981499 53939948 3022598146 1861218535 3403620774 4203516930 2360471119 3134536268 1383448498 1307602316 3847663247 3027225131 3597251613 3186237127 725127595 1928526954 1843386923 3560410503 54688266 1791983849 2519860352 4256389699 2328812602 486464275 3578698363 301279829 1303654791 4181868765 971794070 1933885487 3996807464 2144053754 4079903755 3775774765 3481760044 1212862354 1067356423 3764189132 1609862325 2209601551 2565747501 161962392 4045451782 2605574664 2520953090 3490240017 1082791980 44474324 101811128 4268650669 4171338684 772375154 3920460306 2319139534 599033750 2950874441 3373922995 1496848525 4095253594 1271943484 1498723121 3097453329 3698082465 281869581 3148270661 3591477288 747441437 2809508504 3896107498 303747862 2368081624 1844217645 886825352 287949781 1444561207 2512101757 2062331723 741720931 1383797313 3876746355 2041045348 2627599118 1124169970 200524822 3484820454 55883666 1135054804 669498692 2677215504 3097911127 1509628615 617580381 2229022193 85601568 3243896546 3715672328 912168347 2359163500 1180347564 4243175048 2092067103 880183327 4000664709 2045044777 3500474644 1515175520 1862207123 186628841 3337252925 708933575 4015964629 3136815297 3314919747 2891909013 3316567785 3944275369 3608506218 2884839110 3054055598 2707439927 1381111877 3275487281 4292456216 2639563270 3327301876 3576924628 721056309 2002808140 748967365 52380958 2200261692 763456477 1708381337 2038446433 2682979402 1526413779 2211263302 3879771969 75966584 3645059271 2985763524 4085690255 82390958 1883631385 1647521260 1598026998 3038041577 2501913134 3279302868 1738888524 805035483 756399074 3863810982 1097797270 1505792529 898904527 583561003 717152376 3333867738 1099456544 1663473545 1242141229 3828627682 1966201676 1713552361 3852160017 1584965284 21695908 1013262144 145341901 3995441263 3462066219 2239637848 1214086163 2428868268 1650037305 1545513388 1621198806 4232947817 1823092073 256414624 1745018809 1357102386 2055139770 3280958307 2482431613 1664870585 859130423 4097751123 3079768369 2470211009 2984880786 2808568948 2877071923 1984903163 302768457 1866396789 869566317 3746415787 4169433075 3025005404 3980733379 3539207278 3953071536 876960847 2548872156 800507464 1865466907 1273317878 3754712872 1757188269 3229950355 3731640200 2283390608 2204990292 411873449 447423849 1852437802 472825525 3044219944 2913114194 1859709265 4053786194 574820536 2104496732 865469814 2438352724 4208743605 4215067542 1364015250 4139974345 3838747005 1818502786 2914274940 1402365828 1751123528 2302578077 2463168652 1968705496 1730700144 3023943273 1139096844 2658667767 2063547264 705791165 1444775274 2415454225 1575664730 921044163 648101324 1212387162 4191962054 1787702169 1888718041 1518218010 3398792842 4079359729 149721439 750400353 2661036076 3802767886 520152586 951852508 2939585975 1375969109 385733137 3523607459 1902438415 4250996086 2712727066 484493674 3932107461 1428488210 1764242548 3424801055 4004904451 2226862072 2393366939 3609584727 3614444319 317349896 3826527525 204023804 981902443 3356042039 3051207045 1869902661 561831895 3706675415 1527687593 1227610446 2596341042 3191717368 3269246891 557877074 4062070629 3052520266 3772487029 400039836 3195205275 4085394797 1655557239 1345770144 2864727192 449281238 73189507 528365765 2727400656 247880434 2408277395 777039183 2210179398 1088433648 2124356402 1555630141 604790219 195012151 3312518356 923728373 3999251660 3313059535 3478133921 3395026960 383464614 3425869222 2446885186 4032184426 157195416 3158909476 1663750443 2046427584 1658453076 1784483001 3146546889 1238739785 2297306523 3472330897 2953326031 2421672215 1221694592 1588568605 2546987845 3375168573 2137961649 3056565164 330165219 235900365 1000384800 2697255904 579122283 3050664825 73426122 1232986102 2940571064 3076486824 1708182873 2796363264 292154131 4280019913 1102652157 1185393592 1494991690 4270076389 2384840717 425785147 2385321880 317514772 3926962743 392176856 3465421709 1878853468 122662664 2958252160 1858961315 2244939588 2361884409 2860936803 683833250 3291277128 1686857206 1112632275 1200680507 3342928196 2677058150 939442136 3407104669 2906783932 3668048733 2030009470 1910839172 1234925283 3575831445 123595418 2362440495 3048484911 1796872496 624"
mt19937_state = bytes(array.array('B', [ord(c) for c in mt19937_state]))
 
def convert_rwkv_state_to_llamacpp_cache(path_input, path_output):
    state = torch.load(path_input, map_location='cpu')
    n_head, head_size, _ = state['blocks.0.att.time_state'].shape
    n_embd = n_head * head_size
    n_layer = len(state.keys())
    print(f'n_layer: {n_layer}, n_embd: {n_embd}')
    
    LLAMA_SESSION_MAGIC = 0x6767736e
    LLAMA_SESSION_VERSION = 8
 
    with open(path_output, 'wb') as file:
        # magic, version, n_token_count
        file.write(struct.pack('3I', LLAMA_SESSION_MAGIC, LLAMA_SESSION_VERSION, 0))
        # model arch
        file.write(struct.pack('I5s', 5, b'rwkv6'))
        # rng seed (may need a PR in llama.cpp to be able to skip this too)
        file.write(struct.pack(f'I{len(mt19937_state)}s', len(mt19937_state), mt19937_state))
        # session output ids, logits, embeddings (skip)
        file.write(struct.pack('=IQQ', 0, 0, 0))
        # cell_count, pos, n_seq_id, v_trans, n_layer
        file.write(struct.pack('=5I', 1, 0, 0, 0, n_layer))
        for _ in range(n_layer):
            # k_type = GGML_TYPE_F32, k_size = 2 * n_embd * sizeof(float)
            file.write(struct.pack('=iq', 0, 2 * n_embd * 4))
            for _ in range(2*n_embd):
                file.write(struct.pack('=f', 0.0))
            
        for i in range(n_layer):
            # v_type = GGML_TYPE_F32, v_size = n_head * head_size * head_size * sizeof(float)
            file.write(struct.pack('=iq', 0, n_head * head_size * head_size * 4))
            data = state[f'blocks.{i}.att.time_state'].float().transpose(1, 2).flatten().tolist()
            file.write(struct.pack(f'{n_head * head_size * head_size}f', *data))
 
def main():
    parser = argparse.ArgumentParser(description='Convert RWKV state to llama.cpp prompt cache')
    parser.add_argument('input', type=Path, help='Path to RWKV state')
    parser.add_argument('output', type=Path, help='File name for llama.cpp prompt cache')
    args = parser.parse_args()
 
    convert_rwkv_state_to_llamacpp_cache(args.input, args.output)
 
if __name__ == '__main__':
    main()