This series of articles implements a subtask of Stanford’s CS336 Assignment 1: building an efficient training algorithm for a BPE Tokenizer. Through a series of optimizations, our algorithm’s training time on OpenWebText was reduced from over 10 hours to less than 10 minutes. This series explains these optimizations, including algorithmic improvements, data structure enhancements, parallelization with OpenMP, Cython optimization, and implementing key code in C++ along with its integration via Cython. This is the fifth article, documenting a failed attempt at parallel optimization. Through this failure, we can understand the problems that exist with Python’s multiprocessing
module and, in turn, learn when it should be used.
Table of Contents
- 1. Algorithm Optimization
- 2. Using multiprocessing.Process to Parallelize the max Operation
- 3. Test Results for multiprocessing.Process
- 4. Analysis of the Results
- 5. Using multiprocessing.Pool for Parallel max
- 6. Test Results for multiprocessing.Pool
- 7. Analysis of the Results
- 8. Verifying the multiprocessing.Process Algorithm with spawn
- 9. Conclusion
1. Algorithm Optimization
After a series of optimizations, our current algorithm (bpe_v3_time) now takes only 120 seconds on OpenWebText for the first step, token frequency counting (on a 32-core machine). However, the second step, merging, still takes over 30,000 seconds. This article focuses on optimizing the second step.
The second step is divided into two parts: finding the pair with the highest frequency using the max
function, and incrementally updating the pair counts. The max
operation alone takes 30,000 seconds, while the incremental update takes only a few hundred seconds. This analysis clearly shows that finding the most frequent pair is the part we should prioritize for optimization.
We use Python’s built-in max
function, whose source code can be found here. It ultimately calls the min_max
function, located here:
PyObject *
min_max(PyObject *iterable, PyObject *args, PyObject *kwds, int op)
{
// ... code ...
}
The code for this function is quite long, and much of it is related to the Python/C API. The core algorithm, however, is very simple and consists of the following snippet:
while (( item = PyIter_Next(it) )) {
    /* get the value from the key function */
    if (keyfunc != NULL) {
        val = PyObject_CallOneArg(keyfunc, item);
        if (val == NULL)
            goto Fail_it_item;
    }
    /* no key function; the value is the item */
    else {
        val = Py_NewRef(item);
    }

    /* maximum value and item are unset; set them */
    if (maxval == NULL) {
        maxitem = item;
        maxval = val;
    }
    /* maximum value and item are set; update them as necessary */
    else {
        int cmp = PyObject_RichCompareBool(val, maxval, op);
        if (cmp < 0)
            goto Fail_it_item_and_val;
        else if (cmp > 0) {
            Py_DECREF(maxval);
            Py_DECREF(maxitem);
            maxval = val;
            maxitem = item;
        }
        else {
            Py_DECREF(item);
            Py_DECREF(val);
        }
    }
}
The logic is as follows: the loop iterates over each item of the iterator it. For each item, it calls keyfunc (if a key argument was passed, as in our case) to obtain val. For the first element, maxval and maxitem are simply set to val and item. Otherwise, val is compared to the current maxval using PyObject_RichCompareBool(val, maxval, op). The third argument, op, specifies the comparison operator; max passes Py_GT to check whether val is greater than maxval. If the return value is greater than 0, then val > maxval, so maxval and maxitem are updated.
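For readers who prefer Python, here is a rough pure-Python rendering of that loop (my own illustration, not code from CPython); it is of course much slower than the C builtin, which is exactly the problem discussed next:

_UNSET = object()

def my_max(iterable, key=None):
    # Rough Python rendering of CPython's min_max loop (illustration only).
    maxitem = maxval = _UNSET
    for item in iterable:
        val = key(item) if key is not None else item
        # mirrors PyObject_RichCompareBool(val, maxval, Py_GT)
        if maxval is _UNSET or val > maxval:
            maxval, maxitem = val, item
    if maxitem is _UNSET:
        raise ValueError("my_max() arg is an empty sequence")
    return maxitem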
The built-in max
function is implemented in C, making it very difficult to create a faster version at the Python level. So, what’s the solution?
The most obvious approach, similar to our previous successful optimization, is to use a parallel algorithm. The max
function is associative (max(max(a,b),c) = max(a,max(b,c))
), which makes it easy to parallelize. For example, with N CPUs, we can split all the pairs into N chunks, find the local maximum in each chunk, and then find the global maximum from those N local maximums.
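Conceptually, the chunked reduction looks like this (a minimal sequential sketch of the idea; the multiprocess versions below replace the inner loop with worker processes):

def chunked_max(items, key, num_chunks):
    # Split items into chunks, take a local max per chunk, then reduce the
    # local maxima to a global max. Because max is associative, the result
    # equals max(items, key=key).
    chunk_size = max(1, len(items) // num_chunks)
    chunks = [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]
    local_maxes = [max(chunk, key=key) for chunk in chunks]  # the parallelizable step
    return max(local_maxes, key=key)

data = [(("a", "b"), 3), (("c", "d"), 7), (("e", "f"), 5)]
assert chunked_max(data, key=lambda x: x[1], num_chunks=2) == max(data, key=lambda x: x[1])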
2. Using multiprocessing.Process to Parallelize the max Operation
My first attempt was to use the successful strategy from the previous article: parallelizing the max
operation with multiple processes. The code for this approach is in bpe_v4_mp.py
. Let’s look at the changes relative to bpe_v3.py
.
@staticmethod
def _merge_a_pair(pair_counts, pair_strings, vocabulary, pair_to_words,
                  word_counts, word_encodings, merges, size, num_processes):
    _, _, merge_pair = BPE_Trainer._parallel_max(pair_counts, pair_strings, num_processes)
    ...
The main change is the call to BPE_Trainer._parallel_max
. Let’s examine this function:
@staticmethod
def _parallel_max(pair_counts, pair_strings, num_processes):
    # need a copy
    pair_counts = list(pair_counts.items())
    chunk_size = len(pair_counts) // num_processes
    data_chunks = [pair_counts[i:i + chunk_size] for i in range(0, len(pair_counts), chunk_size)]
    if len(data_chunks) > num_processes:
        data_chunks[num_processes - 1].extend(data_chunks.pop())

    processes = []
    queue = mp.Queue()
    for i in range(num_processes):
        p = mp.Process(target=BPE_Trainer._find_max_pair,
                       args=(pair_strings, data_chunks[i], queue),
                       name=f"find_max_pair-{i+1}")
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

    local_maxes = []
    for i in range(num_processes):
        local_maxes.append(queue.get())
    return max(local_maxes)
To split the tasks, we first need to copy the contents of the pair_counts
dictionary into a randomly accessible list, which requires an extra copy operation. The list is then divided into num_processes
chunks. For example, if pair_counts
has a size of 10 and num_processes
is 4, the tasks are split so that the first three processes get 10 // 4 = 2 tasks each, and the remainder goes to the last process, resulting in a task distribution of [2, 2, 2, 4]. While not the most optimal distribution, it is acceptable since our number of processes is small (typically equal to the number of CPU cores) and pair_counts
is large (hundreds of thousands of items).
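A quick standalone check of that splitting logic (the same slicing as in _parallel_max, with toy data standing in for pair_counts):

pair_counts = list(range(10))        # stand-in for list(pair_counts.items())
num_processes = 4
chunk_size = len(pair_counts) // num_processes                     # 10 // 4 == 2
data_chunks = [pair_counts[i:i + chunk_size]
               for i in range(0, len(pair_counts), chunk_size)]    # 5 chunks of size 2
if len(data_chunks) > num_processes:                               # fold the extra chunk into the last one
    data_chunks[num_processes - 1].extend(data_chunks.pop())
print([len(c) for c in data_chunks])                               # [2, 2, 2, 4]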
To collect the results from each process, we create a multiprocessing.Queue
to store the local maximum values. We then create num_processes
processes to execute BPE_Trainer._find_max_pair
. Once all processes finish, we collect their results into local_maxes
and find the global maximum.
The BPE_Trainer._find_max_pair
code is:
@staticmethod
def _find_max_pair(pair_strings, pair_counts, queue):
    max_pair, max_count = max(pair_counts, key=lambda x: (x[1], pair_strings[x[0]]))
    queue.put((max_count, pair_strings[max_pair], max_pair))
This is our original max
function, but it puts the result into the queue.
For convenience during testing, a command-line argument was added to set the number of concurrent max
processes:
parser.add_argument("--num_max",
type=int,
default=NUM_MAX_PROCESS,
help="number of processes for max")
3. Test Results for multiprocessing.Process
To measure the performance, I implemented bpe_v4_mp_time.py
. The initial tests showed that it was incredibly slow on OpenWebText, so I switched to the smaller TinyStories dataset. The results are below:
| Version | Data | Total Time(s) | Word Count Time(s) | Merge Time(s) | Other |
| --- | --- | --- | --- | --- | --- |
| bpe_v3_time | tinystory | 211/206/208 | 90/90/90 | total: 120/115/117, max_time: 117/112/115, update_time: 3/3/3 | num_counter=8, num_merger=1 |
| bpe_v4_mp_time | tinystory | 730/715/740 | 80/90/80 | merge time: 650/624/659, copy_time: 109/127/123, max_time: 459/404/443, compute_time: 221/238/231 | num_counter=8, num_merger=1, num_max=1 |
| bpe_v4_mp_time | tinystory | 931/943/1007 | 90/90/90 | merge time: 841/852/917, copy_time: 146/157/136, max_time: 600/593/691, compute_time: 125/128/118 | num_counter=8, num_merger=1, num_max=4 |
| bpe_v4_mp_time | tinystory | 1354/1405/1431 | 80/90/80 | merge time: 1274/1315/1351, copy_time: 143/158/160, max_time: 1039/1055/1090, compute_time: 80/85/86 | num_counter=8, num_merger=1, num_max=8 |
Here are the time metrics:
merge time
is the total time for all merge operations:
start_time = time.perf_counter()
while size < vocab_size:
    BPE_Trainer._merge_a_pair(pair_counts, pair_strings, vocabulary,
                              pair_to_words, word_counts, word_encodings,
                              merges, size, num_max_processes, times)
    size += 1
end_time = time.perf_counter()
print(f"merge time: {end_time - start_time:.2f}s")
copy_time
is the time to copy dict.items()
to a list:
start_time = time.perf_counter()
pair_counts = list(pair_counts.items())
end_time = time.perf_counter()
times[0] += (end_time - start_time)
max_time
is the total time for all parallel max
operations:
start_time = time.perf_counter()
processes = []
queue = mp.Queue()
for i in range(num_processes):
    p = mp.Process(target=BPE_Trainer._find_max_pair,
                   args=(pair_strings, data_chunks[i], queue),
                   name=f"find_max_pair-{i+1}")
    p.start()
    processes.append(p)
for p in processes:
    p.join()
end_time = time.perf_counter()
times[1] += (end_time - start_time)
compute_time
is the maximum time taken by a single process to find its local maximum:
local_maxes = []
local_times = []
for i in range(num_processes):
    t = queue.get()
    local_times.append(t[-1])
    local_maxes.append(t[:-1])
times[2] += max(local_times)
The actual timing happens inside the child process and is returned to the main process as part of the result tuple placed on the queue:
@staticmethod
def _find_max_pair(pair_strings, pair_counts, queue):
    start_time = time.perf_counter()
    max_pair, max_count = max(pair_counts, key=lambda x: (x[1], pair_strings[x[0]]))
    end_time = time.perf_counter()
    queue.put((max_count, pair_strings[max_pair], max_pair, (end_time - start_time)))
4. Analysis of the Results
Here are some key takeaways from the results:
- The merge time for bpe_v3_time is only about 120 seconds, but for bpe_v4_mp_time it is around 700 seconds with 1 process, over 900 seconds with 4 processes, and more than 1,300 seconds with 8 processes.
- The max_time for bpe_v3_time is about 110 seconds, whereas bpe_v4_mp_time takes over 400 seconds with 1 process, 600+ seconds with 4 processes, and 1,000+ seconds with 8 processes.
- The compute_time for bpe_v4_mp_time does decrease with more processes, from over 220 seconds to 120+ and 80+ seconds for 1, 4, and 8 processes, respectively.
- The copy_time for bpe_v4_mp_time is between 120 and 160+ seconds.
The results clearly show that as the number of processes increases, the total time spent also increases. While the actual computation time (compute_time
) does decrease, the total max_time
increases significantly. This suggests that the overhead of creating and destroying processes is extremely high. Our previous _pretokenize_and_count_mp
function only created and destroyed processes once, with most of the time spent on regex and counting. However, here, _merge_a_pair
is called 10,000 times, which means _parallel_max
is called 10,000 times, leading to 10,000 rounds of process creation and destruction. This cost is substantial.
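The per-round overhead is easy to measure in isolation. Here is a small benchmark of my own (not from the article's repository) that only starts and joins no-op processes; on a typical Linux machine each round costs on the order of tens of milliseconds, which over 10,000 merge steps adds up to many minutes:

import time
import multiprocessing as mp

def noop():
    pass

if __name__ == "__main__":
    rounds = 100
    start = time.perf_counter()
    for _ in range(rounds):
        processes = [mp.Process(target=noop) for _ in range(8)]
        for p in processes:
            p.start()
        for p in processes:
            p.join()
    elapsed = time.perf_counter() - start
    print(f"{elapsed / rounds * 1000:.1f} ms per round of creating/joining 8 processes")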
Furthermore, for a parallel algorithm to work, we need to be able to randomly access pair_counts
. Since Python’s dict
only allows sequential access via an iterator, we need to copy it to a list. This copy operation itself can take longer than the max
operation. We saw from the max
source code that it has a time complexity of O(n) and a space complexity of O(1), as it only needs to iterate once to find the maximum value. The copy operation also has O(n) time complexity and O(n) space complexity. To verify this, I wrote bpe_v4_time2.py
to compare the max
and copy times.
start_time = time.perf_counter()
merge_pair, max_count = max(pair_counts.items(), key = lambda x: (x[1], pair_strings[x[0]]))
end_time = time.perf_counter()
times[0] += (end_time - start_time)
start_time = time.perf_counter()
merge_pair3, max_count3 = max(pair_counts.items(), key = lambda x: x[1])
end_time = time.perf_counter()
times[4] += (end_time - start_time)
start_time = time.perf_counter()
pair_counts_list = list(pair_counts.items())
end_time = time.perf_counter()
times[1] += (end_time - start_time)
start_time = time.perf_counter()
pair_strings_list = [pair_strings[pair] for pair in pair_counts]
end_time = time.perf_counter()
times[2] += (end_time - start_time)
The results showed that after only 7,000 merges, max
took 70 seconds, while copying to a list took 55 seconds. When you add the overhead of process creation and destruction, this approach is clearly not viable.
The biggest issue with this approach isn’t just the overhead of process creation/destruction; it’s also that Python’s dictionaries cannot be iterated in parallel. If the computation were very large, we could copy the dictionary to a list and use multiple CPUs to compute, and the parallel speed-up would compensate for the copy overhead. However, our max
algorithm is dominated by memory traversal (iterating over all the pairs) with only a trivial comparison per item, so this strategy is not worthwhile.
At this point, we could have abandoned this approach. However, a few questions remain:
- Why is the compute_time for bpe_v4_mp_time with a single process (220+) so much slower than bpe_v3_time (110+)? One would expect iterating a list to be faster than iterating a dictionary.
- Does using multiprocessing.Pool solve the high overhead of process creation and destruction?

To explore these questions, I tried implementing the parallel max algorithm using multiprocessing.Pool.
5. Using multiprocessing.Pool for Parallel max
The full code is in bpe_v4_time.py
.
The train function
with mp.Pool(processes=num_max_processes) as pool:
    times = [0] * 4
    start_time = time.perf_counter()
    while size < vocab_size:
        BPE_Trainer._merge_a_pair(pair_counts, pair_strings, vocabulary,
                                  pair_to_words, word_counts, word_encodings,
                                  merges, size, pool, num_max_processes, times)
        size += 1
    end_time = time.perf_counter()
    print(f"merge time: {end_time - start_time:.2f}s")
    print(f"copy_time: {times[0]:.2f}s, max_time: {times[1]:.2f}s, compute_time: {times[2]:.2f}s, copy_trunks_time: {times[3]:.2f}s")
We use mp.Pool(processes=num_max_processes)
to create a fixed-size process pool. This approach avoids the need to create and destroy processes on every call to _merge_a_pair
.
The BPE_Trainer._merge_a_pair
function calls BPE_Trainer._parallel_max
, which we’ll examine next.
BPE_Trainer._parallel_max
@staticmethod
def _parallel_max(pair_counts, pair_strings, pool, num_processes, times):
    # need a copy
    start_time = time.perf_counter()
    pair_counts = list(pair_counts.items())
    end_time = time.perf_counter()
    times[0] += (end_time - start_time)

    start_time = time.perf_counter()
    chunk_size = len(pair_counts) // num_processes
    data_chunks = [pair_counts[i:i + chunk_size] for i in range(0, len(pair_counts), chunk_size)]
    if len(data_chunks) > num_processes:
        data_chunks[num_processes - 1].extend(data_chunks.pop())
    end_time = time.perf_counter()
    times[3] += (end_time - start_time)

    start_time = time.perf_counter()
    find_max_pair = partial(BPE_Trainer._find_max_pair, pair_strings)
    local_maxes = pool.map(find_max_pair, data_chunks)
    end_time = time.perf_counter()
    times[1] += (end_time - start_time)

    computing_time = max([r[-1] for r in local_maxes])
    times[2] += computing_time
    return max(local_maxes)[:-1]
The first step is still to copy pair_counts.items()
to a list. Next, the list is divided into data_chunks
, just as before. I’ve added timing for this step, though it’s negligible.
Our child process function is _find_max_pair
:
@staticmethod
def _find_max_pair(pair_strings, pair_counts):
    start_time = time.perf_counter()
    max_pair, max_count = max(pair_counts, key=lambda x: (x[1], pair_strings[x[0]]))
    end_time = time.perf_counter()
    return max_count, pair_strings[max_pair], max_pair, (end_time - start_time)
This is the actual function that finds the maximum value. To track the time, the run time is included in the return value.
Then, we call Pool.map
to submit the tasks to the process pool. The first argument is the function to be called, and the second is an iterable of arguments for that function. Our BPE_Trainer._find_max_pair
needs two arguments: pair_strings
and data_chunks[i]
. We could pass them as a tuple, but since pair_strings
is the same for all child processes, we can use functools.partial
to create a partial function that “freezes” this argument:
find_max_pair = partial(BPE_Trainer._find_max_pair, pair_strings)
local_maxes = pool.map(find_max_pair, data_chunks)
This makes the code cleaner. When Pool.map
sends a task to a process, the arguments are serialized (using pickle
), sent via inter-process communication (IPC), and then deserialized. While it might seem that using partial
reduces the data being passed, the find_max_pair
function itself, with pair_strings
embedded, still needs to be transferred to each process in the pool.
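One way to make this cost concrete is to measure how large the pickled payload is; the dictionary travels with the partial object. A small sketch (illustrative only; the pair_strings built here is a made-up stand-in for the real one, which has hundreds of thousands of entries):

import pickle
from functools import partial

def find_max_pair(pair_strings, chunk):
    return max(chunk, key=lambda x: (x[1], pair_strings[x[0]]))

# Fake pair_strings: maps a pair of token ids to a pair of byte strings.
pair_strings = {(i, i + 1): (bytes([i % 256]), bytes([(i + 1) % 256]))
                for i in range(200_000)}

payload = pickle.dumps(partial(find_max_pair, pair_strings))
print(f"pickled partial: {len(payload) / 1e6:.1f} MB")  # the whole dict is serialized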
6. Test Results for multiprocessing.Pool
| Version | Data | Total Time(s) | Word Count Time(s) | Merge Time(s) | Other |
| --- | --- | --- | --- | --- | --- |
| bpe_v3_time | tinystory | 211/206/208 | 90/90/90 | total: 120/115/117, max_time: 117/112/115, update_time: 3/3/3 | num_counter=8, num_merger=1 |
| bpe_v4_mp_time | tinystory | 730/715/740 | 80/90/80 | merge time: 650/624/659, copy_time: 109/127/123, max_time: 459/404/443, compute_time: 221/238/231 | num_counter=8, num_merger=1, num_max=1 |
| bpe_v4_mp_time | tinystory | 931/943/1007 | 90/90/90 | merge time: 841/852/917, copy_time: 146/157/136, max_time: 600/593/691, compute_time: 125/128/118 | num_counter=8, num_merger=1, num_max=4 |
| bpe_v4_mp_time | tinystory | 1354/1405/1431 | 80/90/80 | merge time: 1274/1315/1351, copy_time: 143/158/160, max_time: 1039/1055/1090, compute_time: 80/85/86 | num_counter=8, num_merger=1, num_max=8 |
| bpe_v4_time | tinystory | 981/967/960 | 90/90/90 | merge time: 890/876/870, copy_time: 91/92/88, max_time: 761/743/744, compute_time: 79/79/81, copy_trunks_time: 2/2/2 | num_counter=8, num_merger=1, num_max=1 |
| bpe_v4_time | tinystory | 1784/1688/1543 | 90/80/90 | merge time: 1693/1607/1452, copy_time: 149/151/97, max_time: 1483/1393/1311, compute_time: 26/26/25, copy_trunks_time: 3/3/2 | num_counter=8, num_merger=1, num_max=4 |
| bpe_v4_time | tinystory | 2660/2399/2503 | 90/90/90 | merge time: 2569/2308/2412, copy_time: 165/95/119, max_time: 2336/2169/2238, compute_time: 16/15/16, copy_trunks_time: 3/2/2 | num_counter=8, num_merger=1, num_max=8 |
7. Analysis of the Results
Here are a few observations:
- Pool is even slower than Process. The total time (taking the median) for 1, 4, and 8 processes is roughly 900/1600/2500 seconds for Pool vs. 700/900/1300 seconds for Process.
- Pool's compute_time decreases as the number of processes increases. Interestingly, Pool's compute_time with a single process is less than both Process's and bpe_v3_time's.
- Pool's max_time is significantly slower than Process's.
The first and third results were unexpected. I thought that Pool
, by reducing process creation/destruction overhead, would be faster than Process
.
The second result is as expected, since Pool
’s max
iterates over a list, which should be faster than bpe_v3_time
’s dict
iteration.
However, the question remains: why is Pool
slower than Process
and why is Pool
’s compute_time
faster than Process
’s?
After some research, I found the reason for Pool
being slower than Process
. For Process
, I can only offer a guess.
My guess for Process
’s slower compute_time
is that a new process is created in each loop iteration, possibly on a different CPU core. This leads to cache misses when iterating through pair_counts
as the data needs to be fetched from main memory. In contrast, Pool
reuses processes. Since pair_counts
is updated incrementally, and we always submit tasks in the same order, a given chunk of data is likely always handled by the same process running on the same core, which could improve cache hit rates. This is just a hypothesis, and to verify it, we would need to either shuffle the data chunks before submitting them to Pool
or explicitly bind Process
instances to specific CPUs using something like psutil.Process.cpu_affinity()
.
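If one wanted to test the affinity half of this hypothesis, a child process can be pinned from the standard library. A sketch under the assumption of a Linux host (os.sched_setaffinity is not available on macOS or Windows; psutil offers an equivalent on several platforms):

import os
import multiprocessing as mp

def pinned_find_max(cpu_id, chunk, queue):
    os.sched_setaffinity(0, {cpu_id})   # pin this child to one core (Linux-only)
    queue.put(max(chunk, key=lambda x: x[1]))

if __name__ == "__main__":
    queue = mp.Queue()
    chunk = [((1, 2), 5), ((3, 4), 9)]
    p = mp.Process(target=pinned_find_max, args=(0, chunk, queue))
    p.start()
    result = queue.get()
    p.join()
    print(result)   # ((3, 4), 9)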
The reason for the difference in max_time
and total time between Process
and Pool
is that Process
(on Linux by default) uses fork
to create child processes. A child process created via fork
inherits the parent’s address space, avoiding the need for inter-process data copying. Pool
, on the other hand, must go through three steps for data transfer: serialization via pickle
, IPC communication, and deserialization.
I learned this key distinction between Process
and Pool
from this very problem. Many online resources are vague or incorrect about this. Let’s verify this conclusion with some code.
The code to test this is here.
from multiprocessing import Process, set_start_method

class Unpickleable:
    def __reduce__(self):
        raise TypeError("I hate pickle")

    def __str__(self) -> str:
        return "Unpickleable"

def func(obj):
    print(obj)

if __name__ == '__main__':
    # set_start_method('spawn')
    set_start_method('fork')
    o = Unpickleable()
    p = Process(target=func, args=(o,))
    p.start()
    p.join()
We define an Unpickleable
class that raises an exception if pickle
is attempted. If we set the start method to fork
, the code runs without error, demonstrating that fork
inherits the parent’s objects without serialization. If we set it to spawn
, it raises an exception, proving that spawn
requires serialization.
Note: fork is faster, but it is problematic in multithreaded programs. Since Python 3.12, using fork while the main process has multiple threads emits a DeprecationWarning, and starting with Python 3.14 fork is no longer the default start method on Linux (the default there changed to forkserver; macOS and Windows already default to spawn). If you are running on a newer Python version, you may need to set the start method to fork explicitly.
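For example, the start method can be requested explicitly (a small sketch; using get_context avoids changing the process-wide default):

import multiprocessing as mp

def work():
    print("child running")

if __name__ == "__main__":
    # Ask for fork explicitly; this is Linux-oriented, since fork is not
    # available on Windows and is not the default on macOS.
    ctx = mp.get_context("fork")
    p = ctx.Process(target=work)
    p.start()
    p.join()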
8. Verifying the multiprocessing.Process Algorithm with spawn
Based on our analysis, Process
has high overhead from creation/destruction and poor cache utilization. The only reason it was faster than Pool
was fork
’s lack of IPC data transfer. To test this, I modified the Process
-based algorithm to use spawn
. If my reasoning is correct, this version should be the slowest of all. The code is in bpe_v4_mp_spawn_time.py
, with the only change being:
mp.set_start_method('spawn')
The results are below:
| Version | Data | Total Time(s) | Word Count Time(s) | Merge Time(s) | Other |
| --- | --- | --- | --- | --- | --- |
| bpe_v3_time | tinystory | 211/206/208 | 90/90/90 | total: 120/115/117, max_time: 117/112/115, update_time: 3/3/3 | num_counter=8, num_merger=1 |
| bpe_v4_mp_time | tinystory | 730/715/740 | 80/90/80 | merge time: 650/624/659, copy_time: 109/127/123, max_time: 459/404/443, compute_time: 221/238/231 | num_counter=8, num_merger=1, num_max=1 |
| bpe_v4_mp_time | tinystory | 931/943/1007 | 90/90/90 | merge time: 841/852/917, copy_time: 146/157/136, max_time: 600/593/691, compute_time: 125/128/118 | num_counter=8, num_merger=1, num_max=4 |
| bpe_v4_mp_time | tinystory | 1354/1405/1431 | 80/90/80 | merge time: 1274/1315/1351, copy_time: 143/158/160, max_time: 1039/1055/1090, compute_time: 80/85/86 | num_counter=8, num_merger=1, num_max=8 |
| bpe_v4_time | tinystory | 981/967/960 | 90/90/90 | merge time: 890/876/870, copy_time: 91/92/88, max_time: 761/743/744, compute_time: 79/79/81, copy_trunks_time: 2/2/2 | num_counter=8, num_merger=1, num_max=1 |
| bpe_v4_time | tinystory | 1784/1688/1543 | 90/80/90 | merge time: 1693/1607/1452, copy_time: 149/151/97, max_time: 1483/1393/1311, compute_time: 26/26/25, copy_trunks_time: 3/3/2 | num_counter=8, num_merger=1, num_max=4 |
| bpe_v4_time | tinystory | 2660/2399/2503 | 90/90/90 | merge time: 2569/2308/2412, copy_time: 165/95/119, max_time: 2336/2169/2238, compute_time: 16/15/16, copy_trunks_time: 3/2/2 | num_counter=8, num_merger=1, num_max=8 |
| bpe_v4_mp_spawn_time | tinystory | 2317/2277/2318 | 90/90/90 | merge time: 2227/2186/2227, copy_time: 160/148/157, max_time: 2003/1978/2007, compute_time: 81/81/81 | num_counter=8, num_merger=1, num_max=1 |
| bpe_v4_mp_spawn_time | tinystory | 6193/6173/6218 | 90/90/90 | merge time: 6102/6083/6127, copy_time: 160/146/161, max_time: 5877/5876/5901, compute_time: 26/26/26 | num_counter=8, num_merger=1, num_max=4 |
| bpe_v4_mp_spawn_time | tinystory | 11071 | 90 | merge time: 10980, copy_time: 145, max_time: 10774, compute_time: 15 | num_counter=8, num_merger=1, num_max=8 |
This experiment confirms that spawn
is much slower than fork
, which is what we expected.
However, the compute_time
is similar to Pool
’s, which seems to contradict my cache-related hypothesis for why Process
was slow.
If that’s not the reason, why is iterating the pair_counts
list in a fork
-created child process slower than in a spawn
- or Pool
-created process? The key difference is that a fork
-ed child process reads from the parent’s list(pair_counts.items())
(a shallow copy), while spawn
- and Pool
-created processes get a deep copy of the data via pickle
. This deep copy might result in a more memory-compact list layout, leading to faster iteration.
To test this, I implemented bpe_v4_mp_deepcopy_time.py
, which manually serializes and deserializes the list to mimic the deep copy. The only difference from bpe_v4_mp_time.py
is:
start_time = time.perf_counter()
pair_counts = list(pair_counts.items())
pickled_bytes = pickle.dumps(pair_counts)
pair_counts = pickle.loads(pickled_bytes)
end_time = time.perf_counter()
The results are below:
| Version | Data | Total Time(s) | Word Count Time(s) | Merge Time(s) | Other |
| --- | --- | --- | --- | --- | --- |
| bpe_v3_time | tinystory | 211/206/208 | 90/90/90 | total: 120/115/117, max_time: 117/112/115, update_time: 3/3/3 | num_counter=8, num_merger=1 |
| bpe_v4_mp_time | tinystory | 730/715/740 | 80/90/80 | merge time: 650/624/659, copy_time: 109/127/123, max_time: 459/404/443, compute_time: 221/238/231 | num_counter=8, num_merger=1, num_max=1 |
| bpe_v4_mp_time | tinystory | 931/943/1007 | 90/90/90 | merge time: 841/852/917, copy_time: 146/157/136, max_time: 600/593/691, compute_time: 125/128/118 | num_counter=8, num_merger=1, num_max=4 |
| bpe_v4_mp_time | tinystory | 1354/1405/1431 | 80/90/80 | merge time: 1274/1315/1351, copy_time: 143/158/160, max_time: 1039/1055/1090, compute_time: 80/85/86 | num_counter=8, num_merger=1, num_max=8 |
| bpe_v4_time | tinystory | 981/967/960 | 90/90/90 | merge time: 890/876/870, copy_time: 91/92/88, max_time: 761/743/744, compute_time: 79/79/81, copy_trunks_time: 2/2/2 | num_counter=8, num_merger=1, num_max=1 |
| bpe_v4_time | tinystory | 1784/1688/1543 | 90/80/90 | merge time: 1693/1607/1452, copy_time: 149/151/97, max_time: 1483/1393/1311, compute_time: 26/26/25, copy_trunks_time: 3/3/2 | num_counter=8, num_merger=1, num_max=4 |
| bpe_v4_time | tinystory | 2660/2399/2503 | 90/90/90 | merge time: 2569/2308/2412, copy_time: 165/95/119, max_time: 2336/2169/2238, compute_time: 16/15/16, copy_trunks_time: 3/2/2 | num_counter=8, num_merger=1, num_max=8 |
| bpe_v4_mp_spawn_time | tinystory | 2317/2277/2318 | 90/90/90 | merge time: 2227/2186/2227, copy_time: 160/148/157, max_time: 2003/1978/2007, compute_time: 81/81/81 | num_counter=8, num_merger=1, num_max=1 |
| bpe_v4_mp_spawn_time | tinystory | 6193/6173/6218 | 90/90/90 | merge time: 6102/6083/6127, copy_time: 160/146/161, max_time: 5877/5876/5901, compute_time: 26/26/26 | num_counter=8, num_merger=1, num_max=4 |
| bpe_v4_mp_spawn_time | tinystory | 11071 | 90 | merge time: 10980, copy_time: 145, max_time: 10774, compute_time: 15 | num_counter=8, num_merger=1, num_max=8 |
| bpe_v4_mp_deepcopy_time | tinystory | 1262/1393/1377 | 90/90/90 | merge time: 1171/1302/1287, copy_time: 590/668/644, max_time: 492/538/549, compute_time: 278/303/298 | num_counter=8, num_merger=1, num_max=1 |
| bpe_v4_mp_deepcopy_time | tinystory | 1529/1481/1525 | 80/90/90 | merge time: 1449/1391/1434, copy_time: 661/712/680, max_time: 687/571/651, compute_time: 115/121/116 | num_counter=8, num_merger=1, num_max=4 |
| bpe_v4_mp_deepcopy_time | tinystory | 1924/1872/1983 | 80/90/80 | merge time: 1843/1782/1903, copy_time: 730/682/721, max_time: 1000/994/1070, compute_time: 72/68/70 | num_counter=8, num_merger=1, num_max=8 |
The compute_time
for bpe_v4_mp_deepcopy_time
is similar to bpe_v4_mp_time
(around 200+ seconds), which again refutes my guess. pickle
serialization/deserialization doesn’t seem to make list iteration faster.
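This can also be sanity-checked outside the BPE code: time a max over a list and over its pickle round-tripped copy. In my understanding the two should be close, matching the article's observation that the round trip buys nothing (a standalone sketch):

import pickle
import time

data = list(zip(range(5_000_000), range(5_000_000)))

def time_max(lst):
    start = time.perf_counter()
    max(lst, key=lambda x: x[1])
    return time.perf_counter() - start

copied = pickle.loads(pickle.dumps(data))
print(f"original list: {time_max(data):.2f}s, round-tripped copy: {time_max(copied):.2f}s")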
A final question remains: why is the subprocess created with fork slower at iterating through the pair_counts list than one created with spawn or Pool? A fork-ed subprocess does not need to copy the arguments passed from the parent process; only when the subprocess modifies a memory page does the operating system create a private copy of that page, a mechanism known as Copy-on-Write (CoW). Since our subprocess's traversal of pair_counts should be read-only, it should not, in theory, trigger CoW at all. I asked Gemini, and its answer is summarized below.
Even if the subprocess only performs a read operation on pair_counts
, it can still trigger Copy-on-Write (CoW). While this seems counterintuitive, the reason is typically related to the Python interpreter and underlying memory management mechanisms.
When a subprocess created with fork
inherits the parent process's memory space, it gets read-only access to the pair_counts
data. Ideally, if the subprocess only iterates through the data, it shouldn’t trigger CoW. However, several scenarios can break this ideal state:
1. Garbage Collection
This is the most common and subtle reason. When the Python interpreter runs in the subprocess, it performs its own memory management and garbage collection. To ensure correct object reference counts or to mark and clear objects that are no longer in use, the interpreter may make small modifications to certain metadata in memory, such as reference counts.
- Python objects: in Python, every object has an ob_refcnt (reference count) field.
- Subprocess reading: when the subprocess iterates through the tuples in pair_counts, the interpreter may need to inspect or update the reference counts of those tuples.
- Triggering CoW: even updating a single byte of a reference count counts as a modification from the operating system's point of view. To preserve the independence of the parent and child processes, the kernel immediately creates a private copy of the entire memory page containing that object.
2. Memory Alignment and Cache Lines
When a processor accesses memory, it does so in units called cache lines (typically 64 bytes). When a memory page is accessed, the entire cache line is loaded into the processor’s cache. If this cache line contains data that needs to be modified (even if it belongs to other, unrelated objects), the entire memory page may be flagged for copying.
3. JIT Compilers and Other System Calls
If the subprocess uses a Just-In-Time (JIT) compiler or other complex libraries, these libraries may perform internal memory operations that unintentionally trigger CoW. Even a simple memory access can have a complex interaction with the kernel’s memory management.
Conclusion
The read operation on pair_counts
in the subprocess triggers CoW not because of an explicit modification in your code, but because of small, background modifications made by the Python interpreter to maintain its internal memory state (most commonly, reference counts for garbage collection).
The spawn
method doesn’t have this issue because it creates a separate memory copy from the beginning. While this copying process (serialization and deserialization) is time-consuming at startup, once complete, the subprocess has its own fully independent memory space. Subsequent operations will not be affected by or affect the parent process’s memory, so there’s no CoW overhead.
To verify this, I wrote a program to test it:
import time
import multiprocessing as mp

def worker(lst):
    t0 = time.perf_counter()
    for x in lst:
        pass
    print(f"traverse time: {time.perf_counter() - t0:.2f}s")

if __name__ == '__main__':
    mp.set_start_method('spawn')   # switch to 'fork' to compare the two methods
    lst = list(zip(range(5_000_000), range(5_000_000)))
    start_time = time.perf_counter()
    p = mp.Process(target=worker, args=(lst,))
    p.start()
    p.join()
    end_time = time.perf_counter()
    print(f"total time: {end_time - start_time:.2f}s")
With the fork
method, the total time was 1.07s and the subprocess traversal time was 0.28s. With the spawn
method, the total time was 4.8s and the subprocess traversal time was 0.08s. Therefore, even with the CoW issue, fork
is still faster overall than spawn
.
9. Conclusion
This exploration shows that Python's multiprocessing is a poor fit for this kind of workload, where large amounts of data must be handed to workers thousands of times but each worker performs only trivial computation. While using multiple CPUs does reduce the compute_time of the max operation itself, the overhead of process management and inter-process data transfer far outweighs the gain. This type of problem is better suited to multithreading, where threads in the same process share memory and avoid IPC entirely; however, because of Python's Global Interpreter Lock (GIL), pure-Python threads cannot run the computation in parallel.
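For completeness, a thread-based version of the chunked max is trivial to write, but under the GIL the per-chunk max calls run one at a time, so it gains nothing over the plain serial max (a sketch using concurrent.futures; a free-threaded CPython build would change this picture):

from concurrent.futures import ThreadPoolExecutor

def threaded_max(items, key, num_threads):
    chunk_size = max(1, len(items) // num_threads)
    chunks = [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Threads share memory, so no copying or IPC is needed, but the GIL
        # serializes these pure-Python max calls.
        local_maxes = list(executor.map(lambda chunk: max(chunk, key=key), chunks))
    return max(local_maxes, key=key)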
Therefore, we will now turn to C++ to attempt to parallelize the max
function using multithreading.