动手实现和优化BPE Tokenizer的训练——第11部分:使用cython封装c++代码

Posted by lili on September 25, 2025

本系列文章完成Stanford CS336作业1的一个子任务——实现BPE Tokenizer的高效训练算法。通过一系列优化,我们的算法在OpenWebText上的训练时间从最初的10多个小时优化到小于10分钟。本系列文章解释这一系列优化过程,包括:算法的优化,数据结构的优化,并行(openmp)优化,cython优化,用c++实现关键代码和c++库的cython集成等内容。本文是第十二篇,也是最后一篇,使用cython把之前的c++代码封装成扩展模块供python调用。

目录

1. 目标

前面的文章我们已经做了很多探索,bpe训练可以分为两个阶段。第一个阶段是分词和统计,我们通过multiprocessing进行并行计算,实现的bpe_v3.py利用32核可以在120秒内完成。而读取bytes的bpe_v3_bytes_time.py可以使用64核在70秒内完成。第二个阶段是合并,python版本的bpe_v8_v3.py可以在500秒内完成,总时间是十多分钟。c++版本的合并最快的是bpe_train_updater_fine_grained_heap_emhash8_set9_opt.cpp,它的合并时间是100秒左右。

今天我们的目标就是实现一个最快的版本,通过cython封装c++代码成扩展模块,这样第二个阶段就可以调用这个模块。

2. 把c++代码封装成一个动态库

我们在cppupdate里做了很多实验,现在我们需要把结果最好的一些版本封装成动态库,这样方便后续cython使用。

我们会新建一个cppstep2的c++项目,它的完整代码在cppstep2

根据cppupdate的实验结果,我选择了bpe_train_updater_fine_grained_emhash8.cppbpe_train_updater_fine_grained_emhash8_set.cppbpe_train_updater_fine_grained_emhash8_set9.cppbpe_train_updater_fine_grained_heap_emhash8_set9.cppbpe_train_updater_fine_grained_heap_emhash8_set9_opt.cpp

选择emhash而不是absl的原因有二:一是它比较快;二是用它比较简单。我之前是git clone了emhash的完整代码,这里为了简化,我只是复制了需要的3个头文件:hash_set8.hpp, hash_set4.hpp和hash_table8.hpp。另外max_heap.h和max_heap.cpp也从cppupdate项目原封不动的复制了过来。

我们的动态库主要就两个文件:头文件bpe_train_step2.h和实现文件bpe_train_step2.cpp

2.1 头文件

#pragma once

#include <vector>
#include <unordered_map>
#include <unordered_set>
#include "emhash/hash_table8.hpp"

struct pair_hash {
    template <class T1, class T2>
    std::size_t operator () (const std::pair<T1,T2> &p) const {
        auto h1 = std::hash<T1>{}(p.first);
        auto h2 = std::hash<T2>{}(p.second);

        return h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2));
    }
};

void bpe_train_step2(int vocab_size, 
                emhash8::HashMap<std::pair<int, int>, int, pair_hash> & pair_counts, 
                std::unordered_map<std::pair<int, int>, std::vector<std::vector<int>>, pair_hash> & pair_strings, 
                std::unordered_map<int, std::vector<int>> & vocabulary, 
                std::unordered_map<std::pair<int, int>, std::unordered_set<int>, pair_hash> & pair_wordids, 
                const std::unordered_map<int, long long> & wordid_counts, 
                std::unordered_map<int, std::vector<int>> & wordid_encodings, 
                std::vector<std::pair<std::vector<int>, std::vector<int>>> & merges);

void bpe_train_step2_v2(int vocab_size,          
                std::unordered_map<int, std::vector<int>> & vocabulary, 
                const std::unordered_map<int, long long> & wordid_counts, 
                std::unordered_map<int, std::vector<int>> & wordid_encodings, 
                std::vector<std::pair<std::vector<int>, std::vector<int>>> & merges);

void bpe_train_step2_v3(int vocab_size,          
                std::unordered_map<int, std::vector<int>> & vocabulary, 
                const std::unordered_map<int, long long> & wordid_counts, 
                std::unordered_map<int, std::vector<int>> & wordid_encodings, 
                std::vector<std::pair<std::vector<int>, std::vector<int>>> & merges);   
                
void bpe_train_step2_v4(int vocab_size,          
                std::unordered_map<int, std::vector<int>> & vocabulary, 
                const std::unordered_map<int, long long> & wordid_counts, 
                std::unordered_map<int, std::vector<int>> & wordid_encodings, 
                std::vector<std::pair<std::vector<int>, std::vector<int>>> & merges);  

void bpe_train_step2_v5(int vocab_size,          
                std::unordered_map<int, std::vector<int>> & vocabulary, 
                const std::unordered_map<int, long long> & wordid_counts, 
                std::unordered_map<int, std::vector<int>> & wordid_encodings, 
                std::vector<std::pair<std::vector<int>, std::vector<int>>> & merges);  

void bpe_train_step2_v6(int vocab_size,          
                std::unordered_map<int, std::vector<int>> & vocabulary, 
                const std::unordered_map<int, long long> & wordid_counts, 
                std::unordered_map<int, std::vector<int>> & wordid_encodings, 
                std::vector<std::pair<std::vector<int>, std::vector<int>>> & merges); 

这个头文件定义了bpe_train_step2_v2、bpe_train_step2_v3、bpe_train_step2_v4、bpe_train_step2_v5和bpe_train_step2_v6这5个函数,它们分别对应bpe_train_updater_fine_grained_emhash8.cppbpe_train_updater_fine_grained_emhash8_set.cppbpe_train_updater_fine_grained_emhash8_set9.cppbpe_train_updater_fine_grained_heap_emhash8_set9.cppbpe_train_updater_fine_grained_heap_emhash8_set9_opt.cpp

bpe_train_step2是完全参考cppudate的实现,它的输入参数是6个。而后面我们会讲到pair_counts是wordid_counts是基于计算出来的,可以从python移植到c++,这样速度更快,而且少传递很多参数。所以bpe_train_step2_v2以及之后的版本都是只有5个参数。

2.2 CMakeLists.txt

我们的目标是编译一个动态库,按照CMake的语法编写如下内容:

cmake_minimum_required(VERSION 3.20)

project(BPE_TRAIN_STEP2 LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)


add_library(bpe_train_step2 SHARED bpe_train_step2.cpp max_heap.cpp)
target_include_directories(bpe_train_step2 PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(bpe_train_step2 PUBLIC MY_LIBRARY_EXPORT)

target_include_directories(bpe_train_step2 PUBLIC
    "${PROJECT_SOURCE_DIR}/emhash"
)

install(TARGETS bpe_train_step2
        EXPORT bpe_train_step2_export_targets
        RUNTIME DESTINATION bin
        LIBRARY DESTINATION lib
        ARCHIVE DESTINATION lib)

install(FILES "bpe_train_step2.h" 
        DESTINATION include
)

我们通过add_library创建bpe_train_step2这个target。接着的target_compile_definitions好像是为了windows的兼容性定义的宏。我不懂windows,也没有windows环境测试,所以不知道这个项目能不能在windows下编译。

然后的target_include_directories把emhash的头文件包含进来。最后是安装的时候把bpe_train_step2这个target(主要是libbpe_train_step2.so)和bpe_train_step2.h复制到合适的位置,后面我会讲到怎么安装。

2.3 bpe_train_step2

首先我们按照cppupdate的方式集成,我们需要传入pair_counts等7个参数。

void bpe_train_step2(int vocab_size, 
                emhash8::HashMap<std::pair<int, int>, int, pair_hash> & pair_counts, 
                std::unordered_map<std::pair<int, int>, std::vector<std::vector<int>>, pair_hash> & pair_strings, 
                std::unordered_map<int, std::vector<int>> & vocabulary, 
                std::unordered_map<std::pair<int, int>, std::unordered_set<int>, pair_hash> & pair_wordids, 
                const std::unordered_map<int, long long> & wordid_counts, 
                std::unordered_map<int, std::vector<int>> & wordid_encodings, 
                std::vector<std::pair<std::vector<int>, std::vector<int>>> & merges){
    auto start = std::chrono::steady_clock::now();
    while(vocabulary.size() < vocab_size){
        int max_count = -1;
        std::pair<int, int> max_pair;
        std::vector<std::vector<int>> max_strings;
        for(const auto& [pair, count] : pair_counts){
            if(count > max_count){
                max_count = count;
                max_pair = pair;
                max_strings = pair_strings[pair];
            }else if(count == max_count){
                std::vector<std::vector<int>> strings = pair_strings[pair];
                ComparisonResult r1 = three_way_compare(strings[0], max_strings[0]);
                if(r1 == ComparisonResult::Greater){
                    max_count = count;
                    max_pair = pair;
                    max_strings = strings;
                }else if(r1 == ComparisonResult::Equal){
                    ComparisonResult r2 = three_way_compare(strings[1], max_strings[1]);
                    if(r2 == ComparisonResult::Greater){
                        max_count = count;
                        max_pair = pair;
                        max_strings = strings;                        
                    }
                }
            }
        }

        const std::vector<int>& bytes1 = vocabulary[max_pair.first];
        const std::vector<int>& bytes2 = vocabulary[max_pair.second];
        std::vector<int> merge_bytes;
        merge_bytes.reserve(bytes1.size() + bytes2.size());
        merge_bytes.insert(merge_bytes.end(), bytes1.begin(), bytes1.end());
        merge_bytes.insert(merge_bytes.end(), bytes2.begin(), bytes2.end());

        int size = vocabulary.size();
        vocabulary[size] = merge_bytes;

        auto& affected_words = pair_wordids[max_pair];

        
        updated_affected_word_count(max_pair, affected_words, wordid_encodings, wordid_counts,
                                    pair_counts, pair_wordids, size, pair_strings, vocabulary);
  
        merges.push_back({bytes1, bytes2});


    }
    auto end = std::chrono::steady_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
    std::cout << "bpe_train_step2: " << duration.count() << "ms." << std::endl;
}

代码完全是复制bpe_train_updater_fine_grained_emhash8.cpp

2.4 bpe_train_step2_v2

如果读者仔细看过之前的时间统计,会发现总时间会比”统计词频时间” + “合并时间” 要多五六十秒。这个时间就是调用BPE_Trainer._count_pairs的时间:

    @staticmethod    
    def _count_pairs(word_counts, word_encodings, pair_strings, vocabulary, pair_to_words):
        pair_counts = defaultdict(int)
        for word, count in word_counts.items():
            encoding = word_encodings[word]
            for i in range(0, len(encoding) - 1):
                pair = encoding[i], encoding[i + 1]
                pair_counts[pair] += count
                if pair not in pair_strings:
                    pair_strings[pair] = (vocabulary[pair[0]], vocabulary[pair[1]])

                pair_to_words[pair].add(word)

        return pair_counts

这段代码的作用就是根据word_counts统计pair_counts,构建初始的倒排索引pair_to_words以及pair_strings。它的输入参数是word_counts, word_encodings和vocabulary。这段代码完全可以在c++里实现,这样一来可以加速,二来可以减少python和c++之间的参数传递。

因此bpe_train_step2_v2的输入参数就可以减少三个:

void bpe_train_step2_v2(int vocab_size,          
                std::unordered_map<int, std::vector<int>> & vocabulary, 
                const std::unordered_map<int, long long> & wordid_counts, 
                std::unordered_map<int, std::vector<int>> & wordid_encodings, 
                std::vector<std::pair<std::vector<int>, std::vector<int>>> & merges){
    std::unordered_map<std::pair<int, int>, std::vector<std::vector<int>>, pair_hash> pair_strings;
    emhash8::HashMap<std::pair<int, int>, int, pair_hash>  pair_counts;
    std::unordered_map<std::pair<int, int>, std::unordered_set<int>, pair_hash>  pair_wordids;
    
    std::pair<int, int> pair;
    for(const auto& [wordid, count] : wordid_counts){
        const auto& word_tokens = wordid_encodings[wordid];
        for(int i = 0; i < word_tokens.size() - 1; ++i){
            pair.first = word_tokens[i];
            pair.second = word_tokens[i + 1];
            pair_counts[pair] += count;
            if (pair_strings.find(pair) == pair_strings.end()) {
                pair_strings[pair] = {vocabulary[pair.first], vocabulary[pair.second]};
            }
            pair_wordids[pair].insert(wordid);
        }
    }
    

    bpe_train_step2(vocab_size, pair_counts, pair_strings, vocabulary,
                    pair_wordids, wordid_counts, wordid_encodings, merges);
}

bpe_train_step2_v2这个函数首先根据wordid_counts、wordid_encodings和vocabulary计算得到pair_counts、pair_strings和pair_wordids,最后还是调用bpe_train_step2。

bpe_train_step2_v3和之后的版本接口与bpe_train_step2_v2完全一样,代码也是复制相应的cppupdate里的代码,这里就不赘述了。

2.5 编译安装

使用如下命令编译安装:

cd cppstep2
mkdir build && cd build
cmake -D CMAKE_INSTALL_PREFIX=../../lib_bpe_train_step2/  -D CMAKE_BUILD_TYPE=Release ..
cmake --build . -- -j8
# 如果是gcc的话可以直接make -j8
cmake --install .  
# 或者make install

由于我们不希望把这个库安装到系统的路径,比如/usr/local下,因为这样的安装需要root权限。而且不同系统的路径不相同,也不便于cython的集成。所以我这里把它安装到../../lib_bpe_train_step2/,也就是项目根目录assignment1-basics-bpe之下的lib_bpe_train_step2:

$ ls assignment1-basics-bpe

cppstep2/
cppupdate/
cs336_basics/
data/
lib_bpe_train_step2/

安装后的lib_bpe_train_step2为:

lib_bpe_train_step2$ tree
.
├── include
│   ├── bpe_train_step2.h
│   └── emhash
│       └── hash_table8.hpp
└── lib
    └── libbpe_train_step2.so

除了bpe_train_step2.h,hash_table8.hpp也被包含了进来,因为bpe_train_step2.h要用到它。这主要是因为bpe_train_step2的声明要用到它。如果我们只保留bpe_train_step2_v2及其之后的版本,那么这个头文件也不需要安装(但是我们编译还是需要emhash下的所有头文件)。

3. 用cython封装成扩展模块

对于cython封装c++库不熟悉的读者可以参考Using C++ in Cython

我们需要两个文件:bpe_train_step2_wrapper.pxdbpe_train_step2_wrapper.pyx

3.1 bpe_train_step2_wrapper.pxd

首先来看pxd文件:

# distutils: language = c++

# 导入 C++ 标准库类型
from libcpp.utility cimport pair
from libcpp.vector cimport vector
from libcpp.unordered_map cimport unordered_map
from libcpp.unordered_set cimport unordered_set


cdef extern from "../lib_bpe_train_step2/include/bpe_train_step2.h" :
    cppclass pair_hash:
        pass

cdef extern from "../lib_bpe_train_step2/include/emhash/hash_table8.hpp" namespace "emhash8":
    cppclass HashMap[K, V, H]:
        #ValueT& operator[](const KeyT& key) noexcept
        V& operator[](const K& key)


cdef extern from "../lib_bpe_train_step2/include/bpe_train_step2.h":
    void bpe_train_step2(int vocab_size,
                         HashMap[pair[int, int], int, pair_hash] & pair_counts,
                         unordered_map[pair[int, int], vector[vector[int]], pair_hash] & pair_strings,
                         unordered_map[int, vector[int]] & vocabulary,
                         unordered_map[pair[int, int], unordered_set[int], pair_hash] & pair_wordids,
                         const unordered_map[int, long long] & wordid_counts,
                         unordered_map[int, vector[int]] & wordid_encodings,
                         vector[pair[vector[int], vector[int]]] & merges) except +
    
    void bpe_train_step2_v2(int vocab_size,
                         unordered_map[int, vector[int]] & vocabulary,
                         const unordered_map[int, long long] & wordid_counts,
                         unordered_map[int, vector[int]] & wordid_encodings,
                         vector[pair[vector[int], vector[int]]] & merges) except +


第一行用来告诉 distutils(deprecated)或setuptools 模块:在这个脚本中,你正在构建一个 C++ 扩展模块,而不是默认的 C 语言扩展。

接下来4行是我们要用到c++标准库的std::vector、std::pair、std::unordered_map和std::unordered_set。cython把这些常见的标准库都封装到libcpp下了,我们只需要用cimport(类似python的import,不过这是编译时而不是运行时的import)导入即可使用。要查看哪些c++标准库可以在cython直接使用可以参考这里

接下来用cppclass声明编译时要用到的pair_hash,因为我们使用lib_bpe_train_step2只需要知道pair_hash这个符号即可,它的内容我们不需要,所以它的内容就是一行pass。注意使用cdef extern from的路径是”../lib_bpe_train_step2/include/bpe_train_step2.h”,这就要求lib_bpe_train_step2按照前面的步骤安装到了合适的位置。

接下来是声明emhash8::HashMap,在cdef extern from后有一个”namespace emhash8”,这样cython知道HashMap是在emhash8这个namespace下。cppclass HashMap[K, V, H]说明HashMap是一个模板类,分别代表Key/Value/Hash函数。此外后面的pyx文件里我们会用到operator[],所以我们也需要声明这个重载运算符的原型。

最后就是声明bpe_train_step2这些函数了,cython和c++的语法类似,只不过<>要改成[],看起来有点别扭。

3.2 bpe_train_step2_wrapper.pyx

接下来就是把c++的函数封装成python可以调用的函数,这里主要做的就是参数的转换,比如把python的dict变成c++的std::unordered_map。我们先看最直接的实现。

3.2.1 py_bpe_train_step2

cpdef py_bpe_train_step2(int vocab_size,
                             pair_counts_py,
                             pair_strings_py,
                             vocabulary_py,
                             pair_wordids_py,
                             wordid_counts_py,
                             wordid_encodings_py,
                             merges_py):

    # 声明 C++ 容器
    cdef HashMap[pair[int, int], int, pair_hash] pair_counts_cpp
    cdef unordered_map[pair[int, int], vector[vector[int]], pair_hash] pair_strings_cpp
    cdef unordered_map[int, vector[int]] vocabulary_cpp
    cdef unordered_map[pair[int, int], unordered_set[int], pair_hash] pair_wordids_cpp
    cdef unordered_map[int, long long] wordid_counts_cpp
    cdef unordered_map[int, vector[int]] wordid_encodings_cpp
    cdef vector[pair[vector[int], vector[int]]] merges_cpp

 
    cdef pair[int, int] pair_key
    cdef vector[vector[int]] strings_value
    cdef vector[int] vector_value
    cdef unordered_set[int] set_value




    for p, count in pair_counts_py.items():
        pair_key.first = p[0]
        pair_key.second = p[1]
        pair_counts_cpp[pair_key] = count

    for p, string in pair_strings_py.items():
        pair_key.first = p[0]
        pair_key.second = p[1]
        strings_value.clear()
        value = [list(item) for item in string] 
        vector_value = value[0]
        strings_value.push_back(vector_value)
        vector_value = value[1]
        strings_value.push_back(vector_value)        
        pair_strings_cpp[pair_key] = strings_value     
    
    for k, v in vocabulary_py.items():
        value = list(v)
        vector_value = value
        vocabulary_cpp[k] = vector_value

    for p, wordids in pair_wordids_py.items():
        pair_key.first = p[0]
        pair_key.second = p[1]        
        set_value = wordids
        pair_wordids_cpp[pair_key] = set_value

    for k, v in wordid_counts_py.items():
        wordid_counts_cpp[k] = v

    for k, v in wordid_encodings_py.items():
        vector_value = v
        wordid_encodings_cpp[k] = vector_value

    # 调用 C++ 函数
    bpe_train_step2(vocab_size,
                    pair_counts_cpp,
                    pair_strings_cpp,
                    vocabulary_cpp,
                    pair_wordids_cpp,
                    wordid_counts_cpp,
                    wordid_encodings_cpp,
                    merges_cpp)

    return merges_cpp, vocabulary_cpp

这个函数前面的大部分代码都是把python的变量转换成c++的变量,然后调用bpe_train_step2。我们来看两个典型的例子:

    cdef HashMap[pair[int, int], int, pair_hash] pair_counts_cpp
    cdef pair[int, int] pair_key

    for p, count in pair_counts_py.items():
        pair_key.first = p[0]
        pair_key.second = p[1]
        pair_counts_cpp[pair_key] = count

pair_counts_cpp的key是pair[int,int],也就是std::pair<int,int>,我们可以通过first和second对它赋值。最后通过运算符[]对它进行插入,这也是之前我们需要声明V& operator的原因。我们可以发现上面的代码是类似python的for遍历。

再看一个复杂一点的:

    cdef unordered_map[pair[int, int], vector[vector[int]], pair_hash] pair_strings_cpp

    for p, string in pair_strings_py.items():
        pair_key.first = p[0]
        pair_key.second = p[1]
        strings_value.clear()
        value = [list(item) for item in string] 
        vector_value = value[0]
        strings_value.push_back(vector_value)
        vector_value = value[1]
        strings_value.push_back(vector_value)        
        pair_strings_cpp[pair_key] = strings_value   

pair_strings_py的key是一个tuple,value也是一个tuple,这个tuple有2个元素,每个都是一个bytes。说起来很费劲,我们看一个例子:

(111,110): (b'o', b'n')

我们对应的pair_strings_cpp是unordered_map[pair[int, int], vector[vector[int]], pair_hash],所以需要把bytes变成list[int]。这是通过下面的语句实现:

value = [list(item) for item in string] 
vector_value = value[0]
strings_value.push_back(vector_value)

首先是用列表推导把tuple[bytes]变成list[[list[int]]],然后通过赋值把python的list[int]变成c++的vector[int],最后push_back到strings_value里。

3.2.2 py_bpe_train_step2_v2

cpdef py_bpe_train_step2_v2(int vocab_size,
                             vocabulary_py,
                             wordid_counts_py,
                             wordid_encodings_py,
                             merges_py):

    # 声明 C++ 容器
    cdef unordered_map[int, vector[int]] vocabulary_cpp
    cdef unordered_map[int, long long] wordid_counts_cpp
    cdef unordered_map[int, vector[int]] wordid_encodings_cpp
    cdef vector[pair[vector[int], vector[int]]] merges_cpp

 
    cdef pair[int, int] pair_key
    cdef vector[vector[int]] strings_value
    cdef vector[int] vector_value
    cdef unordered_set[int] set_value
 
    
    for k, v in vocabulary_py.items():
        value = list(v)
        vector_value = value
        vocabulary_cpp[k] = vector_value


    for k, v in wordid_counts_py.items():
        wordid_counts_cpp[k] = v

    for k, v in wordid_encodings_py.items():
        vector_value = v
        wordid_encodings_cpp[k] = vector_value

    # 调用 C++ 函数
    bpe_train_step2_v2(vocab_size,
                    vocabulary_cpp,
                    wordid_counts_cpp,
                    wordid_encodings_cpp,
                    merges_cpp)

    return merges_cpp, vocabulary_cpp

这个版本和之前差不多,只不过少了3个参数。

注意:最后我们返回的是c++的变量merges_cpp, vocabulary_cpp,它们的类型是:

cdef unordered_map[int, vector[int]] vocabulary_cpp
cdef vector[pair[vector[int], vector[int]]] merges_cpp

返回到python时,cython会自动把它转换成dict[int,list[int]]和list[tuple[list[int], list[int]]]。我们后面需要再把list[int]变成bytes。

其实不只是返回值,我们把一个python变量复制给一个c++变量或者把一个c++变量复制给python时cython也会自动的做这些常见的转换:

Python type => C++ type => Python type
bytes std::string bytes
iterable std::vector list
iterable std::list list
iterable std::set set
iterable std::unordered_set set
mapping std::map dict
mapping std::unordered_map dict
iterable (len 2) std::pair tuple (len 2)
complex std::complex complex

我们可以用这个特性来简化变量之间的转换,得到py_bpe_train_step2_opt:

cpdef py_bpe_train_step2_opt(int vocab_size,
                             vocabulary_py,
                             wordid_counts_py,
                             wordid_encodings_py,
                             merges_py):

    # 声明 C++ 容器
    cdef unordered_map[int, vector[int]] vocabulary_cpp
    cdef unordered_map[int, long long] wordid_counts_cpp
    cdef unordered_map[int, vector[int]] wordid_encodings_cpp
    cdef vector[pair[vector[int], vector[int]]] merges_cpp

 
    vocabulary_cpp = vocabulary_py

    wordid_counts_cpp = wordid_counts_py

    wordid_encodings_cpp = wordid_encodings_py
    # 调用 C++ 函数
    bpe_train_step2_v2(vocab_size,
                    vocabulary_cpp,
                    wordid_counts_cpp,
                    wordid_encodings_cpp,
                    merges_cpp)

    return merges_cpp, vocabulary_cpp

这里我们直接用3条复制语句,cython自动就会帮我们在python和c++之间进行转换。其它版本的调用都是和py_bpe_train_step2_opt一样,比如py_bpe_train_step2_v3:

cpdef py_bpe_train_step2_v3(int vocab_size,
                             vocabulary_py,
                             wordid_counts_py,
                             wordid_encodings_py,
                             merges_py):

    # 声明 C++ 容器
    cdef unordered_map[int, vector[int]] vocabulary_cpp
    cdef unordered_map[int, long long] wordid_counts_cpp
    cdef unordered_map[int, vector[int]] wordid_encodings_cpp
    cdef vector[pair[vector[int], vector[int]]] merges_cpp

 
    vocabulary_cpp = vocabulary_py

    wordid_counts_cpp = wordid_counts_py

    wordid_encodings_cpp = wordid_encodings_py
    # 调用 C++ 函数
    bpe_train_step2_v3(vocab_size,
                    vocabulary_cpp,
                    wordid_counts_cpp,
                    wordid_encodings_cpp,
                    merges_cpp)

    return merges_cpp, vocabulary_cpp

3.3 修改setup.py

project_root = os.path.dirname(os.path.abspath(__file__))


ext_modules = [
    Extension(
        name="cs336_basics.bpe_train_step2_wrapper",
        sources=["cs336_basics/bpe_train_step2_wrapper.pyx"],

        language="c++",
        #extra_compile_args=['-std=c++17', '-O3'],
        extra_compile_args=['-std=c++17'],
        libraries=["bpe_train_step2"],

        library_dirs=[f"{project_root}/lib_bpe_train_step2/lib"],
        runtime_library_dirs=[f"{project_root}/lib_bpe_train_step2/lib"],
        include_dirs=[f"{project_root}/lib_bpe_train_step2/include",
                      f"{project_root}/lib_bpe_train_step2/include/emhash"],
    )
]

setup(
    packages=['cs336_basics'],
    name='bpe_train_step2',
    ext_modules=cythonize(ext_modules),
)

我们需要编译bpe_train_step2_wrapper.pyx。

  • name是指定模块的名字,这样python里可以import cs336_basics.bpe_train_step2_wrapper。这里的cs336_basics是package的名字,bpe_train_step2_wrapper是模块名。
  • sources指定要编译的源代码
  • language指定模块是c++语言
  • extra_compile_args指定额外的编译选项,这里指定了’-std=c++17’
  • libraries是依赖的库
  • library_dirs指定编译时库的位置
  • runtime_library_dirs指定运行时库的位置
  • include_dirs指定编译时的头文件位置

为了避免硬编码,project_root为setup.py文件所在的目录。

这些选项最终会使得c++编译器的编译和链接命令为(我的gcc环境,不同环境可能会有差异):

c++ -pthread -fno-strict-overflow -Wsign-compare -Wunreachable-code -DNDEBUG -g -O3 -Wall -fPIC -fPIC -Ics336_basics -I......codes/assignment1-basics-bpe/lib_bpe_train_step2/include -I......codes/assignment1-basics-bpe/lib_bpe_train_step2/include/emhash -I......codes/assignment1-basics-bpe/.venv/include -I.......local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/include/python3.12 -c cs336_basics/bpe_train_step2_wrapper.cpp -o build/temp.linux-x86_64-cpython-312/cs336_basics/bpe_train_step2_wrapper.o -std=c++17
c++ -pthread -fno-strict-overflow -Wsign-compare -Wunreachable-code -DNDEBUG -g -O3 -Wall -fPIC -shared -Wl,--exclude-libs,ALL build/temp.linux-x86_64-cpython-312/cs336_basics/bpe_train_step2_wrapper.o -L......codes/assignment1-basics-bpe/lib_bpe_train_step2/lib -L.......local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib -Wl,--enable-new-dtags,-rpath,......codes/assignment1-basics-bpe/lib_bpe_train_step2/lib -lbpe_train_step2 -o build/lib.linux-x86_64-cpython-312/cs336_basics/bpe_train_step2_wrapper.cpython-312-x86_64-linux-gnu.so

3.4 在python里使用

bpe_v9.py调用的是py_bpe_train_step2函数;bpe_v10.py调用的是py_bpe_train_step2_v2;bpe_v10_v2.py调用的是py_bpe_train_step2_opt;bpe_v11.py调用的是py_bpe_train_step2_v3;bpe_v11_bytes.py调用的是py_bpe_train_step2_v3;bpe_v11_v2.py调用的是py_bpe_train_step2_v4;bpe_v11_v3.py调用的是py_bpe_train_step2_v5;bpe_v11_v3_bytes.py调用的是py_bpe_train_step2_v5;bpe_v11_v4.py调用的是py_bpe_train_step2_v6;bpe_v11_v4_bytes.py调用的是py_bpe_train_step2_v6。

它们的代码基本相同,这里看一下bpe_v11.py


        vocabulary = {i: bytes([i]) for i in range(N_BYTES)} # every byte
        for i, token in enumerate(special_tokens):
            vocabulary[N_BYTES + i] = token.encode('utf-8')
        size = N_BYTES + len(special_tokens)
        merges = []

        # initial word encodings are utf-8
        word_encodings = {}
        for word in word_counts:
            word_encodings[word] = list(word.encode('utf-8'))

        word_ids = {word:id for id, word in enumerate(word_counts)}

        wordid_counts = {word_ids[word]:count for word, count in word_counts.items()}

        wordid_encodings = {word_ids[word]:encoding for word, encoding in word_encodings.items()}      


        merges_cpp, vocabulary_cpp = py_bpe_train_step2_v3(vocab_size, 
                             vocabulary,
                             wordid_counts,
                             wordid_encodings,
                             merges)


        vocabulary = {k:bytes(v) for k, v in vocabulary_cpp.items()}
        merges = [(bytes(arr[0]), bytes(arr[1])) for arr in merges_cpp]

在调用之前,我们需要把str的word变成int的id,这个映射关系保存在word_ids。然后利用word_ids把word_counts变成wordid_counts,把word_encodings变成wordid_encodings。调用后得到的merges_cpp, vocabulary_cpp需要把list[int]转换成bytes。

4. 测试

版本 数据 总时间(s) 统计词频时间(s) 合并时间(s) 其它
bpe_v8_v3 open_web 897/899/951 395/399/395 442/438/493 num_counter=8, num_merger=1
bpe_v9 open_web 831/816/867 400/401/390 prepare & convert: 93/86/94 py_bpe_train_step2: 330/320/374 c++:289/281/326 num_counter=8, num_merger=1
bpe_v10 open_web 816/769/788 390/400/400 prepare & convert: 21/17/19 py_bpe_train_step2: 402/350/367 c++:338/296/309 num_counter=8, num_merger=1
bpe_v10_v2 open_web 767/774/767 400/401/401 prepare & convert: 18/17/17 py_bpe_train_step2: 346/355/347 c++:292/298/294 num_counter=8, num_merger=1
bpe_v10_v2 open_web 498/477/495 120/120/120 prepare & convert: 20/19/18 py_bpe_train_step2: 355/336/354 c++:299/282/298 num_counter=32, num_merger=4
bpe_v11 open_web 350/340/354 120/120/120 prepare & convert: 21/19/19 py_bpe_train_step2: 207/199/212 c++:183/175/190 num_counter=32, num_merger=4
bpe_v11_bytes open_web 311/307/305 80/80/80 prepare & convert: 18/19/18 py_bpe_train_step2: 211/206/204 c++:189/183/182 num_counter=64, num_merger=8, chunk_size 8mb
bpe_v11_v2 open_web 362/350/338 130/120/120 prepare & convert: 19/18/19 py_bpe_train_step2: 210/210/197 c++:189/190/176 num_counter=32, num_merger=4
bpe_v11_v3 open_web 269/274/270 120/120/120 prepare & convert: 18/19/18 py_bpe_train_step2: 129/133/129 c++: 106/109/106 num_counter=32, num_merger=4
bpe_v11_v3_bytes open_web 218/219/215 72/74/69 prepare & convert: 21/21/21 py_bpe_train_step2: 123/122/123 c++: 101/100/101 num_counter=64, num_merger=8, chunk_size 8mb
bpe_v11_v4 open_web 258/256/261 116/117/117 prepare & convert: 19/18/19 py_bpe_train_step2: 121/119/123 c++: 98/97/100 num_counter=32, num_merger=4
bpe_v11_v4_bytes open_web 210/206/207 71/69/70 prepare & convert: 20/18/19 py_bpe_train_step2: 117/117/117 c++: 95/96/95 num_counter=64, num_merger=8, chunk_size 8mb

对比bpe_v10和bpe_v10_v2,它们的差别一是手动在python和c++之间转换参数一是cython自动转换。自动转换不仅更方便,而且更快。

最终我们使用64核bpe_v11_v4_bytes最快的训练时间是200多秒,这比最初的十多个小时快了100多倍!

5. 总结

这就是本系列文章的全部内容。下面是简要的总结和对应文章的链接: