动手实现和优化BPE Tokenizer的训练——第10部分:使用cython和pypy来加速

Posted by lili on September 24, 2025

本系列文章完成Stanford CS336作业1的一个子任务——实现BPE Tokenizer的高效训练算法。通过一系列优化,我们的算法在OpenWebText上的训练时间从最初的10多个小时优化到小于10分钟。本系列文章解释这一系列优化过程,包括:算法的优化,数据结构的优化,并行(openmp)优化,cython优化,用c++实现关键代码和c++库的cython集成等内容。本文是第十一篇,使用cython和pypy来加速python代码。

目录

1. 问题分析

上一篇文章我们已经对算法做了极致的优化,目前Python版本最快的是bpe_v7_maxheapc_opt_time,它的merge时间是500多秒;C++版本最快的是bpe_train_updater_fine_grained_heap_emhash8_set9_opt,它的merge时间是100秒。后面我们会把c++版本的代码通过cython和前面的python代码集成。不过今天我们想要在python的层面做一些优化,毕竟使用两种语言还是比较麻烦。

同样的代码逻辑,python比c++要慢,这有很多的原因。最主要的原因有:python(更准确的说是CPython)是通过解释执行代码而c++是编译成机器码执行;python是动态类型而c++是静态类型;python的GIL锁使得python多线程无法真正并行;python的内存布局对于cpu不友好。

第三点原因在前面比较python的多进程和c++/openmp并行求max已经分析过了,而且通过算法的优化,我们不需要扫描整个pair_counts来求max,所以这个差异我们可以忽略。

第四点其实也是很重要的,即使是c++,如果使用std::unordered_map这样的链表这种内存布局,它的遍历速度也比absl::flat_hash_map和emhash8::HashMap要慢很多。在python里,所有的一切都是对象,其实从内存布局来说就是void *,所以python的list的用法虽然看起来和std::vector很像,但是它们内存布局完全不同。

当我们创建一个key/value是整数的dict时,比如 {1: 10, 2: 20},Python 在内存中实际做了以下几件事:

  • 为字典本身创建一个哈希表结构。
  • 为键 1 创建一个 PyLongObject。
  • 为值 10 创建一个 PyLongObject。
  • 将键对象和值对象的引用(内存地址)存储到哈希表中的一个条目里。
  • 对键 2 和值 20 重复上述过程。

这意味着,即使你的键和值只是一些小小的整数,每个整数背后都有一个完整的 Python 对象开销(包括类型信息、引用计数等)。这种设计使得 dict 极为灵活,可以存储任何类型的数据,但代价就是会消耗更多的内存。

不过在python层面,我们很难对dict进行优化。这是使用python代价,如果想要优化内存,最好的方法是使用其它的语言(比如c/c++)然后通过Python/C API集成。我们之前的大堆算法就是通过Python/C API集成,不过这个API非常复杂。我们后面会通过cython把c++的代码集成,cython最终也是把我们的c++代码编译成扩展模块。类似的numpy就是通过Python/C集成。

剩下的第一点和第二点就是我们今天优化的主题:一是尝试通过cython来重写部分关键代码;二是尝试使用pypy来替代CPython。

2. cython原理简介

Cython是一种编程语言,它结合了Python的易用性和C语言的高性能。它旨在成为Python的超集,这意味着你可以用它来编写普通的Python代码,同时也能添加额外的C语言特性,比如静态类型声明。Cython的核心思想是将你的代码从Python语言编译成C语言,然后再编译成机器码,形成一个可导入的Python模块。这个过程解决了Python解释器在执行速度上的瓶颈。

关于cython的详细介绍请参考官方文档,另外如果读者想系统学习cython也可以阅读书籍Cython: A Guide for Python Programmers,这本书虽然很久远了,但是大部分内容依然又用。

3. 用cython优化fine_grained_pair_counter_diff函数

我们的优化基于bpe_v7_opt,根据之前的分析,max的时间已经可以忽略,剩下大部分时间在函数_updated_affected_word_count。这个函数主要是调用fine_grained_pair_counter_diff然后更新pair_counts和pair_to_words这两个词典。更新词典很难用cython优化,所以我们先优化fine_grained_pair_counter_diff函数。

我们首先创建一个bpe_updater.pyx文件,它用cython实现了fine_grained_pair_counter_diff。

cpdef void fine_grained_pair_counter_diff(set affected_words, 
                                          word_encodings, 
                                          word_counts, 
                                          tuple merge_pair, 
                                          diff_pairs, 
                                          int new_id, 
                                          pair_to_words, 
                                          set new_pairs):
    cdef str word
    cdef int wc
    cdef int idx
    cdef int first_idx
    cdef int last_idx
    cdef int i
    cdef int tk_len

    for word in affected_words:
        word_tokens = word_encodings[word]
        wc = word_counts[word]

        # find first and last pairs
        idx = 0
        unaffected_pairs = set()
        tk_len = len(word_tokens)
        #first_idx = -1
        while idx < tk_len - 1:
            if word_tokens[idx] == merge_pair[0] and word_tokens[idx+1] == merge_pair[1]:
                first_idx = idx
                break
            idx += 1

        # assert first_idx exists

        idx = tk_len - 2
        while idx > first_idx + 1:
            if word_tokens[idx] == merge_pair[0] and word_tokens[idx+1] == merge_pair[1]:
                last_idx = idx
                break
            idx -= 1
        else:
            last_idx = first_idx

        start_idx = max_int(0, first_idx - 1) # inclusive
        end_idx = min_int(last_idx + 3, tk_len) # exclusive

        # unaffected [0, start_idx)


        for i in range(start_idx):
            pair = word_tokens[i], word_tokens[i + 1]
            unaffected_pairs.add(pair)
        # unaffected [end_idx-1, :-1]
        for i in range(end_idx - 1, tk_len - 1):
            pair = word_tokens[i], word_tokens[i + 1]
            unaffected_pairs.add(pair)                

        # TODO avoid slice copy
        affected_tokens = word_tokens[start_idx: end_idx]
        for i in range(len(affected_tokens) - 1):
            old_pair = (affected_tokens[i], affected_tokens[i + 1])
            diff_pairs[old_pair] -= wc 
            if old_pair not in unaffected_pairs:   
                pair_to_words[old_pair].discard(word)
        

        new_tokens = []
        all_new_tokens = []
        for i in range(start_idx):
            all_new_tokens.append(word_tokens[i])
        
        i = 0
        # account for multiple occurrences of the pair
        while i < len(affected_tokens):
            if i < len(affected_tokens) - 1 and (affected_tokens[i], affected_tokens[i + 1]) == merge_pair:
                new_tokens.append(new_id)
                all_new_tokens.append(new_id)
                # jump past pair
                i += 2
            else:
                new_tokens.append(affected_tokens[i])
                all_new_tokens.append(affected_tokens[i])
                i += 1



        for i in range(end_idx, len(word_tokens)):
            all_new_tokens.append(word_tokens[i])
        
        word_encodings[word] = all_new_tokens

        # add new pairs from the updated word
        for i in range(len(new_tokens) - 1):
            new_pair = (new_tokens[i], new_tokens[i + 1])

            diff_pairs[new_pair] += wc
            pair_to_words[new_pair].add(word)

            new_pairs.add(new_pair)

这个代码和python的版本基本一样,只不过把循环变量都加了静态类型声明,这样cython就可以把它编译成c语言版本的for/while循环,而不需要走python的迭代器协议。

另外我们可以发现上面的代码:

affected_tokens = word_tokens[start_idx: end_idx]

这个slice会复制受影响的tokens,我们在c++实现的代码使用了指针避免复制:

        const int * affected_tokens = word_tokens.data() + start_idx;
        int affected_tokens_len = end_idx - start_idx;

        for(int i = 0; i < affected_tokens_len - 1; ++i){
            std::pair<int, int> old_pair(affected_tokens[i], affected_tokens[i + 1]);
            diff_pairs[old_pair] -= wc;
            if (unaffected_pairs.find(old_pair) == unaffected_pairs.end()) {
                pair_wordids[old_pair].erase(wordid);
            }
        }

python的list没有办法返回一个视图,所以我们只能自己处理繁琐的下标来避免复制,这样我们得到fine_grained_pair_counter_diff_v2。这里就不详细介绍其中的改动了,感兴趣的读者可以阅读fine_grained_pair_counter_diff_v2

此外根据上一篇文章的优化,我们可以把new_pairs去掉,这样在fine_grained_pair_counter_diff_v2的基础上得到fine_grained_pair_counter_diff_v3

然后我们需要修改setup.py:

setup(
    packages=['cs336_basics'],
    name='cs336_basics.bpe_updater',
    ext_modules=cythonize("cs336_basics/bpe_updater.pyx"),
)

然后执行python setup.py build_ext -i,这样编译得到cs336_basics/bpe_updater.cpython-312-x86_64-linux-gnu.so。

最后我们在python里使用bpe_updater这个扩展模块,相应的代码在bpe_v8.pybpe_v8_v2.pybpe_v8_v3.py。它们分别调用fine_grained_pair_counter_diff、fine_grained_pair_counter_diff_v2和fine_grained_pair_counter_diff_v3。

4. 测试结果

因为统计max和update会带来额外的时间开销,比如bpe_v7_time要比bpe_v7慢10多秒,所以我这里只统计整体的合并时间。

版本 数据 总时间(s) 统计词频时间(s) 合并时间(s) 其它
bpe_v7_time open_web 1062/1035/1036 392/397/395 total: 606/573/577 max:6/6/6 update: 599/567/571 num_counter=8, num_merger=1
bpe_v7 open_web 1051/1017/1023 399/389/398 590/568/568 num_counter=8, num_merger=1
bpe_v7_opt open_web 1021/1007/980 399/393/393 558/553/528 num_counter=8, num_merger=1
bpe_v8 open_web 934/930/935 390/393/393 479/472/476 num_counter=8, num_merger=1
bpe_v8_v2 open_web 917/908/965 394/392/395 460/455/505 num_counter=8, num_merger=1
bpe_v8_v3 open_web 897/899/951 395/399/395 442/438/493 num_counter=8, num_merger=1

bpe_v8和bpe_v7的实现逻辑完全相同,我们可以对比得出cython版本的bpe_v8要比bpe_v7的合并时间少17%。bpe_v8_v2避免了复制affected_tokens = word_tokens[start_idx: end_idx],它的前两次运行要比bpe_v8快10多秒,但是第三次运行反而慢了30秒。这个原因不太清楚,也许是服务器的负载波动。bpe_v7_opt对bpe_v7做了new_pairs的删除优化,和它对比的是bpe_v8_v3,但是bpe_v8_v3还多做了bpe_v8_v2的避免复制的优化。bpe_v8_v3的平均时间是457秒,比bpe_v7_opt快了16%。

5. 把整个merge过程用cython实现

更进一步,我们可以把整个merge过程都用cython实现,这能不能加快速度呢?对bpe_v8_v3进行重构我实现了bpe_updater_v2.pyxbpe_v8_v4.py。这里,我把merge过程封装成一个函数:

cpdef void bpe_train_step2(int vocab_size,
                      pair_counts,
                      pair_strings,
                      vocabulary,
                      pair_to_words,
                      word_counts,
                      word_encodings,
                      merges,
                      pair_heap):

    cdef int size = len(vocabulary)
    while size < vocab_size:
        _merge_a_pair(pair_counts, pair_strings, vocabulary,
                                pair_to_words, word_counts, word_encodings,
                                merges, size, pair_heap)
        size += 1

测试结果如下:

版本 数据 总时间(s) 统计词频时间(s) 合并时间(s) 其它
bpe_v7_time open_web 1062/1035/1036 392/397/395 total: 606/573/577 max:6/6/6 update: 599/567/571 num_counter=8, num_merger=1
bpe_v7 open_web 1051/1017/1023 399/389/398 590/568/568 num_counter=8, num_merger=1
bpe_v7_opt open_web 1021/1007/980 399/393/393 558/553/528 num_counter=8, num_merger=1
bpe_v8 open_web 934/930/935 390/393/393 479/472/476 num_counter=8, num_merger=1
bpe_v8_v2 open_web 917/908/965 394/392/395 460/455/505 num_counter=8, num_merger=1
bpe_v8_v3 open_web 897/899/951 395/399/395 442/438/493 num_counter=8, num_merger=1
bpe_v8_v4 open_web 917/915/982 391/404/400 462/448/506 num_counter=8, num_merger=1

bpe_v8_v4比bpe_v8_v3还慢。原因在于fine_grained_pair_counter_diff函数我们通过cython的静态类型声明,把python的迭代器循环变成了c的循环,所以速度能够变快。但是其它部分的代码都是操作python的dict,这些代码用cython实现还是会回到python,这反而更慢。

6. 使用pypy

pypy 是 Python 语言的一个替代性实现。简单来说,它不是一个库或框架,而是一个完整的 Python 解释器,就像我们平时用的 CPython 一样。pypy 的最大亮点在于它内置了 JIT (Just-In-Time) 编译器。这是它和标准 CPython 解释器最大的区别,也是它速度快的原因。CPython将 Python 代码编译成字节码,然后由一个虚拟机逐条解释执行。这个过程比较慢。pypy同样先将 Python 代码编译成字节码。但是,当程序运行时,pypy 的 JIT 编译器会监控哪些代码被频繁执行。它会把这些“热点”代码直接编译成机器码,然后缓存起来。下次再遇到同样的代码时,pypy 就会直接执行高速的机器码,而不是重新解释字节码。pypy和Java的JVM是比较类似的。

要把运行环境从CPython切换到pypy不需要修改任何代码,但是需要重新安装环境。由于我们是使用uv来管理环境的,切换环境非常简单,执行如下命令即可:

deactivate
UV_PROJECT_ENVIRONMENT=.venv_pypy uv sync --python pypy@3.11
source .venv_pypy/bin/activate

第一个是先从当前的虚拟环境中退出。然后用uv sync重新创建一个pypy 3.11的环境,为了不覆盖原来的.venv,我们使用环境变量UV_PROJECT_ENVIRONMENT,告诉uv新创建的虚拟环境存放在.venv_pypy目录下。最后source这个环境。

我首先测试了bpe_v7,测试结果如下:

版本 数据 总时间(s) 统计词频时间(s) 合并时间(s) 其它
bpe_v7 open_web 1028/1053/1045 395/397/393 total: 575/589/590 max: 6/6/6 update: 569/583/583 make heap: 0.01/0.01 heap_push_time: 102/107/122 num_counter=8, num_merger=1
bpe_v7(pypy) open_web 2047/1694/1913 1644/1306/1552 403/388/361 num_counter=8, num_merger=1

结果发现pypy的整体运行时间远远慢于CPython,从1000多秒增加到了2000多秒。不过pypy的merge时间却比CPython快了34%。为什么会出现前后速度不一致的情况呢?通过一系列测试我逐渐把问题锁定到了regex匹配的速度差异上,于是写了一个程序bpe_v2_time_pypy.py。这个程序只执行_pretokenize_and_count这个函数并且统计时间。下面是我在pypy3.11、CPython3.12和CPython3.11之下用单个CPU(num_counter=1)在tinystory数据集上测试的结果:

版本 数据 split_time(s) match_time(s)
bpe_v2_pypy(pypy 3.11) tiny_story 40 888
bpe_v2_pypy(cpython 3.12) tiny_story 3 300  
bpe_v2_pypy(cpython 3.11) tiny_story 3 288  

split_time和match_time分别统计regex.split和regex.finditer的时间:

        for chunk in BPE_Trainer._chunk_documents_streaming(input_path):
            start_time = time.perf_counter()
            blocks = re.split(special_pattern, chunk)
            end_time = time.perf_counter()
            split_time += (end_time - start_time)
            for block in blocks:
                start_time = time.perf_counter()
                for match in re.finditer(pattern, block):
                    text = match.group(0)
                    text_len += len(text)
                end_time = time.perf_counter()
                match_time += (end_time - start_time)

可以看到,pypy的版本要比CPython的慢很多。为了找到原因,我在第三方库regex的github上提交了一个issueregex is much slower in pypy than cpython。根据mattip的回复,pypy慢的原因在于regex使用了CPython的C接口,而pypy只能通过模拟来实现类似CPython的这个接口,所以速度很慢。而且不只是慢,它的结果甚至都可能不正确。根据regex的文档:

This module is targeted at CPython. It expects that all codepoints are the same width, so it won’t behave properly with pypy outside U+0000..U+007F because pypy stores strings as UTF-8.

非常遗憾,pypy的JIT编译确实使得后面的merge过程变快了,但是由于第三方库regex不支持,我们用pypy整体更慢甚至无法实现。我们当然可以把train函数分成两部分,第一部分是_pretokenize_and_count,第二部分是merge。第一个部分用CPython,第二个部分用pypy,然后两个部分之间通过进程间通信来完成。但是进程间通信比较复杂,而且数据复制的开销也很大。

所以pypy的最大问题其实是无法和CPython”兼容”,但是现实问题是很多高性能第三方库都会用到CPython的C接口。比如numpy,pypy无法直接加速 numpy 底层的 C 代码。相反,它通过一个名为 cpyext 的兼容层来与 numpy 的 C 扩展进行通信。这个兼容层是为了让 pypy 能够运行 CPython 的 C 扩展而设计的。这样单纯使用numpy的话pypy比CPython还要慢。

7. 总结

通过pypy虽然能够加快python代码的执行,但是由于第三方库regex不支持,我们只能放弃。而通过cython,我们没有怎么修改python代码(不过为了模块调用进行了一定代码的重构)的情况下提升了速度。当然我们还可以进一步优化cython代码,比如使用c++的unordered_map替代python的dict。不过更好的方法是直接在c++语言里完成这些优化,然后把它封装成动态库供cython调用。这就是下一篇文章我们要研究的内容。

本系列全部文章