本系列文章完成Stanford CS336作业1的一个子任务——实现BPE Tokenizer的高效训练算法。通过一系列优化,我们的算法在OpenWebText上的训练时间从最初的10多个小时优化到小于10分钟。本系列文章解释这一系列优化过程,包括:算法的优化,数据结构的优化,并行(openmp)优化,cython优化,用c++实现关键代码和c++库的cython集成等内容。本文是第十一篇,使用cython和pypy来加速python代码。
目录
- 1. 问题分析
- 2. cython原理简介
- 3. 用cython优化fine_grained_pair_counter_diff函数
- 4. 测试结果
- 5. 把整个merge过程用cython实现
- 6. 使用pypy
- 7. 总结
- 本系列全部文章
1. 问题分析
上一篇文章我们已经对算法做了极致的优化,目前Python版本最快的是bpe_v7_maxheapc_opt_time,它的merge时间是500多秒;C++版本最快的是bpe_train_updater_fine_grained_heap_emhash8_set9_opt,它的merge时间是100秒。后面我们会把c++版本的代码通过cython和前面的python代码集成。不过今天我们想要在python的层面做一些优化,毕竟使用两种语言还是比较麻烦。
同样的代码逻辑,python比c++要慢,这有很多的原因。最主要的原因有:python(更准确的说是CPython)是通过解释执行代码而c++是编译成机器码执行;python是动态类型而c++是静态类型;python的GIL锁使得python多线程无法真正并行;python的内存布局对于cpu不友好。
第三点原因在前面比较python的多进程和c++/openmp并行求max已经分析过了,而且通过算法的优化,我们不需要扫描整个pair_counts来求max,所以这个差异我们可以忽略。
第四点其实也是很重要的,即使是c++,如果使用std::unordered_map这样的链表这种内存布局,它的遍历速度也比absl::flat_hash_map和emhash8::HashMap要慢很多。在python里,所有的一切都是对象,其实从内存布局来说就是void *,所以python的list的用法虽然看起来和std::vector很像,但是它们内存布局完全不同。
当我们创建一个key/value是整数的dict时,比如 {1: 10, 2: 20},Python 在内存中实际做了以下几件事:
- 为字典本身创建一个哈希表结构。
- 为键 1 创建一个 PyLongObject。
- 为值 10 创建一个 PyLongObject。
- 将键对象和值对象的引用(内存地址)存储到哈希表中的一个条目里。
- 对键 2 和值 20 重复上述过程。
这意味着,即使你的键和值只是一些小小的整数,每个整数背后都有一个完整的 Python 对象开销(包括类型信息、引用计数等)。这种设计使得 dict 极为灵活,可以存储任何类型的数据,但代价就是会消耗更多的内存。
不过在python层面,我们很难对dict进行优化。这是使用python代价,如果想要优化内存,最好的方法是使用其它的语言(比如c/c++)然后通过Python/C API集成。我们之前的大堆算法就是通过Python/C API集成,不过这个API非常复杂。我们后面会通过cython把c++的代码集成,cython最终也是把我们的c++代码编译成扩展模块。类似的numpy就是通过Python/C集成。
剩下的第一点和第二点就是我们今天优化的主题:一是尝试通过cython来重写部分关键代码;二是尝试使用pypy来替代CPython。
2. cython原理简介
Cython是一种编程语言,它结合了Python的易用性和C语言的高性能。它旨在成为Python的超集,这意味着你可以用它来编写普通的Python代码,同时也能添加额外的C语言特性,比如静态类型声明。Cython的核心思想是将你的代码从Python语言编译成C语言,然后再编译成机器码,形成一个可导入的Python模块。这个过程解决了Python解释器在执行速度上的瓶颈。
关于cython的详细介绍请参考官方文档,另外如果读者想系统学习cython也可以阅读书籍Cython: A Guide for Python Programmers,这本书虽然很久远了,但是大部分内容依然又用。
3. 用cython优化fine_grained_pair_counter_diff函数
我们的优化基于bpe_v7_opt,根据之前的分析,max的时间已经可以忽略,剩下大部分时间在函数_updated_affected_word_count。这个函数主要是调用fine_grained_pair_counter_diff然后更新pair_counts和pair_to_words这两个词典。更新词典很难用cython优化,所以我们先优化fine_grained_pair_counter_diff函数。
我们首先创建一个bpe_updater.pyx文件,它用cython实现了fine_grained_pair_counter_diff。
cpdef void fine_grained_pair_counter_diff(set affected_words,
word_encodings,
word_counts,
tuple merge_pair,
diff_pairs,
int new_id,
pair_to_words,
set new_pairs):
cdef str word
cdef int wc
cdef int idx
cdef int first_idx
cdef int last_idx
cdef int i
cdef int tk_len
for word in affected_words:
word_tokens = word_encodings[word]
wc = word_counts[word]
# find first and last pairs
idx = 0
unaffected_pairs = set()
tk_len = len(word_tokens)
#first_idx = -1
while idx < tk_len - 1:
if word_tokens[idx] == merge_pair[0] and word_tokens[idx+1] == merge_pair[1]:
first_idx = idx
break
idx += 1
# assert first_idx exists
idx = tk_len - 2
while idx > first_idx + 1:
if word_tokens[idx] == merge_pair[0] and word_tokens[idx+1] == merge_pair[1]:
last_idx = idx
break
idx -= 1
else:
last_idx = first_idx
start_idx = max_int(0, first_idx - 1) # inclusive
end_idx = min_int(last_idx + 3, tk_len) # exclusive
# unaffected [0, start_idx)
for i in range(start_idx):
pair = word_tokens[i], word_tokens[i + 1]
unaffected_pairs.add(pair)
# unaffected [end_idx-1, :-1]
for i in range(end_idx - 1, tk_len - 1):
pair = word_tokens[i], word_tokens[i + 1]
unaffected_pairs.add(pair)
# TODO avoid slice copy
affected_tokens = word_tokens[start_idx: end_idx]
for i in range(len(affected_tokens) - 1):
old_pair = (affected_tokens[i], affected_tokens[i + 1])
diff_pairs[old_pair] -= wc
if old_pair not in unaffected_pairs:
pair_to_words[old_pair].discard(word)
new_tokens = []
all_new_tokens = []
for i in range(start_idx):
all_new_tokens.append(word_tokens[i])
i = 0
# account for multiple occurrences of the pair
while i < len(affected_tokens):
if i < len(affected_tokens) - 1 and (affected_tokens[i], affected_tokens[i + 1]) == merge_pair:
new_tokens.append(new_id)
all_new_tokens.append(new_id)
# jump past pair
i += 2
else:
new_tokens.append(affected_tokens[i])
all_new_tokens.append(affected_tokens[i])
i += 1
for i in range(end_idx, len(word_tokens)):
all_new_tokens.append(word_tokens[i])
word_encodings[word] = all_new_tokens
# add new pairs from the updated word
for i in range(len(new_tokens) - 1):
new_pair = (new_tokens[i], new_tokens[i + 1])
diff_pairs[new_pair] += wc
pair_to_words[new_pair].add(word)
new_pairs.add(new_pair)
这个代码和python的版本基本一样,只不过把循环变量都加了静态类型声明,这样cython就可以把它编译成c语言版本的for/while循环,而不需要走python的迭代器协议。
另外我们可以发现上面的代码:
affected_tokens = word_tokens[start_idx: end_idx]
这个slice会复制受影响的tokens,我们在c++实现的代码使用了指针避免复制:
const int * affected_tokens = word_tokens.data() + start_idx;
int affected_tokens_len = end_idx - start_idx;
for(int i = 0; i < affected_tokens_len - 1; ++i){
std::pair<int, int> old_pair(affected_tokens[i], affected_tokens[i + 1]);
diff_pairs[old_pair] -= wc;
if (unaffected_pairs.find(old_pair) == unaffected_pairs.end()) {
pair_wordids[old_pair].erase(wordid);
}
}
python的list没有办法返回一个视图,所以我们只能自己处理繁琐的下标来避免复制,这样我们得到fine_grained_pair_counter_diff_v2。这里就不详细介绍其中的改动了,感兴趣的读者可以阅读fine_grained_pair_counter_diff_v2。
此外根据上一篇文章的优化,我们可以把new_pairs去掉,这样在fine_grained_pair_counter_diff_v2的基础上得到fine_grained_pair_counter_diff_v3。
然后我们需要修改setup.py:
setup(
packages=['cs336_basics'],
name='cs336_basics.bpe_updater',
ext_modules=cythonize("cs336_basics/bpe_updater.pyx"),
)
然后执行python setup.py build_ext -i,这样编译得到cs336_basics/bpe_updater.cpython-312-x86_64-linux-gnu.so。
最后我们在python里使用bpe_updater这个扩展模块,相应的代码在bpe_v8.py、bpe_v8_v2.py和bpe_v8_v3.py。它们分别调用fine_grained_pair_counter_diff、fine_grained_pair_counter_diff_v2和fine_grained_pair_counter_diff_v3。
4. 测试结果
因为统计max和update会带来额外的时间开销,比如bpe_v7_time要比bpe_v7慢10多秒,所以我这里只统计整体的合并时间。
版本 | 数据 | 总时间(s) | 统计词频时间(s) | 合并时间(s) | 其它 |
bpe_v7_time | open_web | 1062/1035/1036 | 392/397/395 | total: 606/573/577 max:6/6/6 update: 599/567/571 | num_counter=8, num_merger=1 |
bpe_v7 | open_web | 1051/1017/1023 | 399/389/398 | 590/568/568 | num_counter=8, num_merger=1 |
bpe_v7_opt | open_web | 1021/1007/980 | 399/393/393 | 558/553/528 | num_counter=8, num_merger=1 |
bpe_v8 | open_web | 934/930/935 | 390/393/393 | 479/472/476 | num_counter=8, num_merger=1 |
bpe_v8_v2 | open_web | 917/908/965 | 394/392/395 | 460/455/505 | num_counter=8, num_merger=1 |
bpe_v8_v3 | open_web | 897/899/951 | 395/399/395 | 442/438/493 | num_counter=8, num_merger=1 |
bpe_v8和bpe_v7的实现逻辑完全相同,我们可以对比得出cython版本的bpe_v8要比bpe_v7的合并时间少17%。bpe_v8_v2避免了复制affected_tokens = word_tokens[start_idx: end_idx]
,它的前两次运行要比bpe_v8快10多秒,但是第三次运行反而慢了30秒。这个原因不太清楚,也许是服务器的负载波动。bpe_v7_opt对bpe_v7做了new_pairs的删除优化,和它对比的是bpe_v8_v3,但是bpe_v8_v3还多做了bpe_v8_v2的避免复制的优化。bpe_v8_v3的平均时间是457秒,比bpe_v7_opt快了16%。
5. 把整个merge过程用cython实现
更进一步,我们可以把整个merge过程都用cython实现,这能不能加快速度呢?对bpe_v8_v3进行重构我实现了bpe_updater_v2.pyx和bpe_v8_v4.py。这里,我把merge过程封装成一个函数:
cpdef void bpe_train_step2(int vocab_size,
pair_counts,
pair_strings,
vocabulary,
pair_to_words,
word_counts,
word_encodings,
merges,
pair_heap):
cdef int size = len(vocabulary)
while size < vocab_size:
_merge_a_pair(pair_counts, pair_strings, vocabulary,
pair_to_words, word_counts, word_encodings,
merges, size, pair_heap)
size += 1
测试结果如下:
版本 | 数据 | 总时间(s) | 统计词频时间(s) | 合并时间(s) | 其它 |
bpe_v7_time | open_web | 1062/1035/1036 | 392/397/395 | total: 606/573/577 max:6/6/6 update: 599/567/571 | num_counter=8, num_merger=1 |
bpe_v7 | open_web | 1051/1017/1023 | 399/389/398 | 590/568/568 | num_counter=8, num_merger=1 |
bpe_v7_opt | open_web | 1021/1007/980 | 399/393/393 | 558/553/528 | num_counter=8, num_merger=1 |
bpe_v8 | open_web | 934/930/935 | 390/393/393 | 479/472/476 | num_counter=8, num_merger=1 |
bpe_v8_v2 | open_web | 917/908/965 | 394/392/395 | 460/455/505 | num_counter=8, num_merger=1 |
bpe_v8_v3 | open_web | 897/899/951 | 395/399/395 | 442/438/493 | num_counter=8, num_merger=1 |
bpe_v8_v4 | open_web | 917/915/982 | 391/404/400 | 462/448/506 | num_counter=8, num_merger=1 |
bpe_v8_v4比bpe_v8_v3还慢。原因在于fine_grained_pair_counter_diff函数我们通过cython的静态类型声明,把python的迭代器循环变成了c的循环,所以速度能够变快。但是其它部分的代码都是操作python的dict,这些代码用cython实现还是会回到python,这反而更慢。
6. 使用pypy
pypy 是 Python 语言的一个替代性实现。简单来说,它不是一个库或框架,而是一个完整的 Python 解释器,就像我们平时用的 CPython 一样。pypy 的最大亮点在于它内置了 JIT (Just-In-Time) 编译器。这是它和标准 CPython 解释器最大的区别,也是它速度快的原因。CPython将 Python 代码编译成字节码,然后由一个虚拟机逐条解释执行。这个过程比较慢。pypy同样先将 Python 代码编译成字节码。但是,当程序运行时,pypy 的 JIT 编译器会监控哪些代码被频繁执行。它会把这些“热点”代码直接编译成机器码,然后缓存起来。下次再遇到同样的代码时,pypy 就会直接执行高速的机器码,而不是重新解释字节码。pypy和Java的JVM是比较类似的。
要把运行环境从CPython切换到pypy不需要修改任何代码,但是需要重新安装环境。由于我们是使用uv来管理环境的,切换环境非常简单,执行如下命令即可:
deactivate
UV_PROJECT_ENVIRONMENT=.venv_pypy uv sync --python pypy@3.11
source .venv_pypy/bin/activate
第一个是先从当前的虚拟环境中退出。然后用uv sync重新创建一个pypy 3.11的环境,为了不覆盖原来的.venv,我们使用环境变量UV_PROJECT_ENVIRONMENT,告诉uv新创建的虚拟环境存放在.venv_pypy目录下。最后source这个环境。
我首先测试了bpe_v7,测试结果如下:
版本 | 数据 | 总时间(s) | 统计词频时间(s) | 合并时间(s) | 其它 |
bpe_v7 | open_web | 1028/1053/1045 | 395/397/393 | total: 575/589/590 max: 6/6/6 update: 569/583/583 make heap: 0.01/0.01 heap_push_time: 102/107/122 | num_counter=8, num_merger=1 |
bpe_v7(pypy) | open_web | 2047/1694/1913 | 1644/1306/1552 | 403/388/361 | num_counter=8, num_merger=1 |
结果发现pypy的整体运行时间远远慢于CPython,从1000多秒增加到了2000多秒。不过pypy的merge时间却比CPython快了34%。为什么会出现前后速度不一致的情况呢?通过一系列测试我逐渐把问题锁定到了regex匹配的速度差异上,于是写了一个程序bpe_v2_time_pypy.py。这个程序只执行_pretokenize_and_count这个函数并且统计时间。下面是我在pypy3.11、CPython3.12和CPython3.11之下用单个CPU(num_counter=1)在tinystory数据集上测试的结果:
版本 | 数据 | split_time(s) | match_time(s) |
bpe_v2_pypy(pypy 3.11) | tiny_story | 40 | 888 |
bpe_v2_pypy(cpython 3.12) tiny_story | 3 | 300 | |
bpe_v2_pypy(cpython 3.11) tiny_story | 3 | 288 |
split_time和match_time分别统计regex.split和regex.finditer的时间:
for chunk in BPE_Trainer._chunk_documents_streaming(input_path):
start_time = time.perf_counter()
blocks = re.split(special_pattern, chunk)
end_time = time.perf_counter()
split_time += (end_time - start_time)
for block in blocks:
start_time = time.perf_counter()
for match in re.finditer(pattern, block):
text = match.group(0)
text_len += len(text)
end_time = time.perf_counter()
match_time += (end_time - start_time)
可以看到,pypy的版本要比CPython的慢很多。为了找到原因,我在第三方库regex的github上提交了一个issueregex is much slower in pypy than cpython。根据mattip的回复,pypy慢的原因在于regex使用了CPython的C接口,而pypy只能通过模拟来实现类似CPython的这个接口,所以速度很慢。而且不只是慢,它的结果甚至都可能不正确。根据regex的文档:
This module is targeted at CPython. It expects that all codepoints are the same width, so it won’t behave properly with pypy outside U+0000..U+007F because pypy stores strings as UTF-8.
非常遗憾,pypy的JIT编译确实使得后面的merge过程变快了,但是由于第三方库regex不支持,我们用pypy整体更慢甚至无法实现。我们当然可以把train函数分成两部分,第一部分是_pretokenize_and_count,第二部分是merge。第一个部分用CPython,第二个部分用pypy,然后两个部分之间通过进程间通信来完成。但是进程间通信比较复杂,而且数据复制的开销也很大。
所以pypy的最大问题其实是无法和CPython”兼容”,但是现实问题是很多高性能第三方库都会用到CPython的C接口。比如numpy,pypy无法直接加速 numpy 底层的 C 代码。相反,它通过一个名为 cpyext 的兼容层来与 numpy 的 C 扩展进行通信。这个兼容层是为了让 pypy 能够运行 CPython 的 C 扩展而设计的。这样单纯使用numpy的话pypy比CPython还要慢。
7. 总结
通过pypy虽然能够加快python代码的执行,但是由于第三方库regex不支持,我们只能放弃。而通过cython,我们没有怎么修改python代码(不过为了模块调用进行了一定代码的重构)的情况下提升了速度。当然我们还可以进一步优化cython代码,比如使用c++的unordered_map替代python的dict。不过更好的方法是直接在c++语言里完成这些优化,然后把它封装成动态库供cython调用。这就是下一篇文章我们要研究的内容。
本系列全部文章
- 第0部分:简介 介绍bpe训练的基本算法和相关任务,并且介绍开发环境。
- 第1部分:最简单实现 bpe训练最简单的实现。
- 第2部分:优化算法 实现pair_counts的增量更新。
- 第3部分:并行分词和统计词频 使用multiprocessing实现多进程并行算法。
- 第4部分:一次失败的并行优化 尝试用多进程并行计算max pair。
- 第5部分:用C++实现Merge算法 用C++实现和Python等价的merge算法,并且比较std::unordered_map的两种遍历方式。
- 第6部分:用OpenMP实现并行求最大 用OpenMP并行求pair_counts里最大pair。
- 第7部分:使用flat hashmap替代std::unordered_map 使用flat hashmap来替代std::unordered_map。
- 第8部分:实现细粒度更新 使用倒排索引实现pair_counts的细粒度更新算法。
- 第9部分:使用堆来寻找最大pair 使用堆来求最大pair,提升性能。
- 第10部分:使用cython和pypy来加速 使用cython和pypy来加速python代码。
- 第11部分:使用cython封装c++代码 使用cython封装c++代码。
- 显示Disqus评论(需要科学上网)