Optimizing tail recursion in Python using bytecode manipulations.
Allison Kaptur, Paul Tagliamonte, Liuda Nikolaeva (all errors are my own)
Problem:
Python has a limit on recursion depth:
def factorial(n, accum):
    if n <= 1:
        return accum
    else:
        return factorial(n-1, accum*n)
>>> factorial(1000, 1)
RuntimeError: maximum recursion depth exceeded
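For reference, the limit can be inspected with sys.getrecursionlimit() (1000 by default in CPython). The sketch below assumes Python 3, where the error is raised as RecursionError, a subclass of RuntimeError added in Python 3.5; the helper name hits_limit is hypothetical and only used for illustration.

```python
import sys


def factorial(n, accum):
    # Tail-recursive: the recursive call is the last thing the function does,
    # but CPython still creates a new frame for every call.
    if n <= 1:
        return accum
    else:
        return factorial(n - 1, accum * n)


def hits_limit(n):
    """Hypothetical helper: True if factorial(n, 1) exceeds the recursion limit."""
    try:
        factorial(n, 1)
    except RecursionError:  # Python 3.5+; Python 2 raises a plain RuntimeError
        return True
    return False


if __name__ == "__main__":
    limit = sys.getrecursionlimit()
    print(limit)
    print(hits_limit(10))           # shallow recursion: no error
    print(hits_limit(limit + 100))  # deeper than the limit: RecursionError
```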
Challenge:
• Optimize recursive function calls so that they don’t create new frames, thus avoiding stack overflow.
• What we want: eliminate the recursive call; instead, reset the variables and jump to the beginning of the function.
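In plain Python, the effect we want is the same as rewriting the function by hand as a loop: instead of calling itself, the function rebinds its parameters and goes back to the top. The name factorial_loop is only for illustration; this is a hand-written equivalent, not the bytecode transformation itself.

```python
def factorial_loop(n, accum):
    # Same body as factorial(n, accum), wrapped in a loop.
    while True:
        if n <= 1:
            return accum
        # Instead of "return factorial(n - 1, accum * n)":
        # compute both new values first, then rebind the parameters
        # and jump back to the top of the loop (no new frame).
        n, accum = n - 1, accum * n


print(factorial_loop(5, 1))  # 120
print(len(str(factorial_loop(5000, 1))))  # far deeper than the recursion limit
```

The tuple assignment matters: both right-hand sides are evaluated before either name is rebound, which is the same order the optimized bytecode needs (compute all new argument values on the stack, then store them).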
Problem:
How do you change the insides of a function?
Solution:
Bytecode! (obviously)
Quick intro to bytecode.
def f(n, accum):
    if n <= 1:
        return accum
    else:
        return f(n-1, accum*n)
>>> f.__code__.co_code
'|\x00\x00d\x01\x00k\x01\x00r\x10\x00|\x01\x00St\x00\x00|\x00\x00d\x01\x00\x18|\x01\x00|\x00\x00\x14\x83\x02\x00Sd\x00\x00S'
>>> print [ord(b) for b in f.__code__.co_code]
[124, 0, 0, 100, 1, 0, 107, 1, 0, 114, 16, 0, 124, 1, 0, 83, 116, 0, 0, 124, 0, 0, 100, 1, 0, 24, 124, 1, 0, 124, 0, 0, 20, 131, 2, 0, 83, 100, 0, 0, 83]
def f(n, accum):
    if n <= 1:
        return accum
    else:
        return f(n-1, accum*n)
>>> import dis
>>> dis.dis(f)
  2           0 LOAD_FAST                0 (n)
              3 LOAD_CONST               1 (1)
              6 COMPARE_OP               1 (<=)
              9 POP_JUMP_IF_FALSE       16

  3          12 LOAD_FAST                1 (accum)
             15 RETURN_VALUE

  5     >>   16 LOAD_GLOBAL              0 (f)
             19 LOAD_FAST                0 (n)
             22 LOAD_CONST               1 (1)
             25 BINARY_SUBTRACT
             26 LOAD_FAST                1 (accum)
             29 LOAD_FAST                0 (n)
             32 BINARY_MULTIPLY
             33 CALL_FUNCTION            2
             36 RETURN_VALUE
             37 LOAD_CONST               0 (None)
             40 RETURN_VALUE
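The listings above come from Python 2, where an instruction with an argument takes 3 bytes (opcode plus a 2-byte argument). In Python 3.6+ CPython uses 2-byte "wordcode", co_code is a bytes object (so iterating over it already yields integers), and opcode names and offsets vary between versions (for example, CALL_FUNCTION was replaced by CALL in Python 3.11). A sketch of a more version-independent way to look at the same function, using dis.get_instructions:

```python
import dis


def f(n, accum):
    if n <= 1:
        return accum
    else:
        return f(n - 1, accum * n)


raw = list(f.__code__.co_code)  # Python 3: bytes -> list of ints, no ord() needed
instructions = list(dis.get_instructions(f))

# Print offset, opcode name and resolved argument, similar to dis.dis(f).
for ins in instructions:
    print(ins.offset, ins.opname, ins.argval)

# The recursive call starts by loading the global name "f".
loads_self = [ins for ins in instructions
              if ins.opname == "LOAD_GLOBAL" and ins.argval == "f"]
print(loads_self[0].offset)
```

Offsets and opcode names printed by this will differ from the Python 2 listing; the structure (load f, compute the arguments, call, return) is the same.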
Before optimization:
0 LOAD_FAST 0 (n)
3 LOAD_CONST 1 (1)
6 COMPARE_OP 1 (<=)
9 POP_JUMP_IF_FALSE 16
12 LOAD_FAST 1 (accum)
15 RETURN_VALUE
>> 16 LOAD_GLOBAL 0 (f)
19 LOAD_FAST 0 (n)
22 LOAD_CONST 1 (1)
25 BINARY_SUBTRACT
26 LOAD_FAST 1 (accum)
29 LOAD_FAST 0 (n)
32 BINARY_MULTIPLY
33 CALL_FUNCTION 2
36 RETURN_VALUE
After optimization:
>> 0 LOAD_FAST 0 (n)
3 LOAD_CONST 1 (1)
6 COMPARE_OP 1 (<=)
9 POP_JUMP_IF_FALSE 16
12 LOAD_FAST 1 (accum)
15 RETURN_VALUE
>> 16 LOAD_FAST 0 (n)
19 LOAD_CONST 1 (1)
22 BINARY_SUBTRACT
23 LOAD_FAST 1 (accum)
26 LOAD_FAST 0 (n)
29 BINARY_MULTIPLY
30 STORE_FAST 1 (accum)
33 STORE_FAST 0 (n)
36 JUMP_ABSOLUTE 0
39 RETURN_VALUE
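To check that the optimized listing computes the factorial inside a single frame, here is a toy stack machine (a simplified model, not CPython's interpreter) that executes the after-optimization instructions. As simplifying assumptions, instructions are (opname, argument) pairs, jump targets are list indices rather than byte offsets, LOAD_CONST carries the constant itself, and LOAD_FAST/STORE_FAST use variable names; the names AFTER and run are hypothetical.

```python
import operator

# Toy version of the "after optimization" listing. Comments give the
# Python 2 byte offsets from the listing; jump targets are list indices.
AFTER = [
    ("LOAD_FAST", "n"),            # 0
    ("LOAD_CONST", 1),             # 3
    ("COMPARE_OP", "<="),          # 6
    ("POP_JUMP_IF_FALSE", 6),      # 9  (offset 16 -> index 6)
    ("LOAD_FAST", "accum"),        # 12
    ("RETURN_VALUE", None),        # 15
    ("LOAD_FAST", "n"),            # 16
    ("LOAD_CONST", 1),             # 19
    ("BINARY_SUBTRACT", None),     # 22
    ("LOAD_FAST", "accum"),        # 23
    ("LOAD_FAST", "n"),            # 26
    ("BINARY_MULTIPLY", None),     # 29
    ("STORE_FAST", "accum"),       # 30
    ("STORE_FAST", "n"),           # 33
    ("JUMP_ABSOLUTE", 0),          # 36
    ("RETURN_VALUE", None),        # 39 (never reached)
]

COMPARE = {"<=": operator.le, "<": operator.lt}
BINARY = {"BINARY_SUBTRACT": operator.sub, "BINARY_MULTIPLY": operator.mul}


def run(code, local_vars):
    """Execute toy instructions in one frame and return the RETURN_VALUE."""
    stack = []
    pc = 0  # index of the next instruction
    while True:
        op, arg = code[pc]
        pc += 1
        if op == "LOAD_FAST":
            stack.append(local_vars[arg])
        elif op == "STORE_FAST":
            local_vars[arg] = stack.pop()
        elif op == "LOAD_CONST":
            stack.append(arg)
        elif op == "COMPARE_OP":
            right = stack.pop()
            left = stack.pop()
            stack.append(COMPARE[arg](left, right))
        elif op in BINARY:
            right = stack.pop()
            left = stack.pop()
            stack.append(BINARY[op](left, right))
        elif op == "POP_JUMP_IF_FALSE":
            if not stack.pop():
                pc = arg
        elif op == "JUMP_ABSOLUTE":
            pc = arg
        elif op == "RETURN_VALUE":
            return stack.pop()
        else:
            raise ValueError(f"unsupported instruction: {op}")


print(run(AFTER, {"n": 5, "accum": 1}))  # 120
print(len(str(run(AFTER, {"n": 5000, "accum": 1}))))  # no recursion involved
```

Note the order of the two STORE_FAST instructions: after BINARY_MULTIPLY the stack holds n-1 with accum*n on top, so accum must be stored first.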
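Simplified algorithm (pseudocode):
# instructions() and make_code() are hypothetical helpers
def recursion_optimizer(f):
    new_bytecode = []
    for instr in instructions(f.__code__.co_code):
        if instr.opname == 'LOAD_GLOBAL' and instr.argval == f.__name__:
            continue  # get rid of this instruction
        elif instr.opname == 'CALL_FUNCTION':
            # replace it with resetting the variables and jumping to 0:
            # the arguments sit on the stack with the last one on top,
            # so store them into the parameters in reverse order
            for i in reversed(range(instr.arg)):
                new_bytecode.append(('STORE_FAST', i))
            new_bytecode.append(('JUMP_ABSOLUTE', 0))
        else:  # regular instruction
            new_bytecode.append(instr)
    # co_code is read-only, so build a new code object instead;
    # jump targets shifted by removed/added bytes must be fixed up too
    f.__code__ = make_code(f.__code__, new_bytecode)
    return f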
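The same simplified algorithm can be run on a toy representation: instructions as (opname, argument) pairs, with jump targets given as list indices (an assumption for illustration, not CPython's format). Like the slide, this version drops LOAD_GLOBAL of the function's own name and replaces every CALL_FUNCTION; it also remaps jump targets, which shift once instructions are removed or added. The names optimize, BEFORE and JUMPS are hypothetical.

```python
BEFORE = [
    ("LOAD_FAST", "n"),            # 0
    ("LOAD_CONST", 1),             # 1
    ("COMPARE_OP", "<="),          # 2
    ("POP_JUMP_IF_FALSE", 6),      # 3: jump to index 6 if not n <= 1
    ("LOAD_FAST", "accum"),        # 4
    ("RETURN_VALUE", None),        # 5
    ("LOAD_GLOBAL", "fact"),       # 6: push the function itself
    ("LOAD_FAST", "n"),            # 7
    ("LOAD_CONST", 1),             # 8
    ("BINARY_SUBTRACT", None),     # 9
    ("LOAD_FAST", "accum"),        # 10
    ("LOAD_FAST", "n"),            # 11
    ("BINARY_MULTIPLY", None),     # 12
    ("CALL_FUNCTION", 2),          # 13: the recursive call
    ("RETURN_VALUE", None),        # 14
]

# Instructions whose argument is an absolute jump target (list index).
JUMPS = {"POP_JUMP_IF_FALSE", "POP_JUMP_IF_TRUE", "JUMP_ABSOLUTE"}


def optimize(code, name, params):
    """Simplified algorithm: drop LOAD_GLOBAL `name` and replace every
    CALL_FUNCTION with stores into `params` plus a jump to index 0."""
    new_code = []
    old_to_new = {}  # old instruction index -> index in new_code
    for i, (op, arg) in enumerate(code):
        old_to_new[i] = len(new_code)
        if op == "LOAD_GLOBAL" and arg == name:
            continue  # get rid of this instruction
        elif op == "CALL_FUNCTION":
            # The arguments are on the stack, last one on top,
            # so store them into the parameters in reverse order.
            for param in reversed(params[:arg]):
                new_code.append(("STORE_FAST", param))
            new_code.append(("JUMP_ABSOLUTE", 0))
        else:  # regular instruction
            new_code.append((op, arg))
    # Fix jump targets. A jump to a removed instruction lands on the one
    # that followed it; index 0 always maps to 0, so the new
    # JUMP_ABSOLUTE 0 is unaffected.
    return [(op, old_to_new[arg]) if op in JUMPS else (op, arg)
            for op, arg in new_code]


for instr in optimize(BEFORE, "fact", ["n", "accum"]):
    print(instr)
```

The output mirrors the Before/After listings (index 6 here corresponds to byte offset 16 there). Because it rewrites every CALL_FUNCTION, this version breaks as soon as the function also calls some other function, which is the sum_squares problem below.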
Not only does it work, it is slightly faster than the original function:
• Timed 10000 calls to fact(450).
Original fact: 1.7009999752
Optimized fact: 1.6970000267
• It was also faster than other ways of optimizing this.
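A measurement of this kind can be set up with the standard timeit module. The sketch below only shows the harness: the bytecode optimizer itself is not reproduced here, so a hand-written loop (fact_loop, a hypothetical stand-in) plays the role of the optimized function, and absolute times will depend on the machine and Python version.

```python
import timeit


def fact(n, accum=1):
    # Original, recursive version (one new frame per call).
    if n <= 1:
        return accum
    return fact(n - 1, accum * n)


def fact_loop(n, accum=1):
    # Hypothetical stand-in for the optimized version: same logic as a loop.
    while n > 1:
        n, accum = n - 1, accum * n
    return accum


def time_calls(func, n=450, calls=10000):
    """Total seconds (as reported by timeit) for `calls` calls of func(n)."""
    return timeit.timeit(lambda: func(n), number=calls)


if __name__ == "__main__":
    for func in (fact, fact_loop):
        print(func.__name__, time_calls(func, calls=1000))
```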
Here is the most interesting part so far:
If our function calls another function…
def sq(x): return x*x
@tailbytes_v1
def sum_squares(n, accum):
    if n < 1:
        return accum
    else:
        return sum_squares(n-1, accum+sq(n))
• Our initial algorithm replaced every CALL_FUNCTION, not only the recursive one, so the call to sq would also be turned into a jump and the function would break.
How do we fix this?
• We need to keep track of function calls and replace only the recursive call.
• Unfortunately, bytecode doesn’t know which function it’s calling: it just calls whatever is on the stack:
29 CALL_FUNCTION 2
So we just need to keep track of the stack…
• When we hit the LOAD_GLOBAL of the function's own name (‘LOAD_GLOBAL self’), we start tracking the stack size from that point (stack_size = 0).
• With every following instruction, we update stack_size by that instruction's net stack effect.
• The CALL_FUNCTION that brings stack_size back to 0 is the one that consumes our function object, i.e. the recursive call, so that is the only call we replace.
• Calls to other functions (e.g., sq) are left alone.
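The stack-tracking rule can be sketched on a toy listing of sum_squares, written as (opname, argument) pairs in the Python 2 style used above. The stack-effect table is a hand-written assumption covering only the opcodes in this example (modern Python offers dis.stack_effect, but its opcodes differ from these), and the names STACK_EFFECT, find_recursive_calls and SUM_SQUARES are hypothetical. The sketch also assumes the argument expressions contain no jumps and no nested call to the function itself.

```python
# Net change in stack size for each toy opcode (hand-written assumption).
STACK_EFFECT = {
    "LOAD_FAST": +1, "LOAD_CONST": +1, "LOAD_GLOBAL": +1,
    "STORE_FAST": -1, "BINARY_ADD": -1, "BINARY_SUBTRACT": -1,
    "COMPARE_OP": -1, "POP_JUMP_IF_FALSE": -1, "RETURN_VALUE": -1,
}


def stack_effect(op, arg):
    if op == "CALL_FUNCTION":
        # Pops `arg` arguments plus the function, pushes the result.
        return -arg
    return STACK_EFFECT[op]


def find_recursive_calls(code, name):
    """Indices of the CALL_FUNCTION instructions that call `name` itself."""
    found = []
    stack_size = None  # None: not currently tracking
    for i, (op, arg) in enumerate(code):
        if stack_size is None:
            if op == "LOAD_GLOBAL" and arg == name:
                stack_size = 0  # start counting right after pushing `name`
            continue
        stack_size += stack_effect(op, arg)
        if op == "CALL_FUNCTION" and stack_size == 0:
            found.append(i)  # this call consumed our function object
            stack_size = None
    return found


SUM_SQUARES = [
    ("LOAD_FAST", "n"), ("LOAD_CONST", 1), ("COMPARE_OP", "<"),
    ("POP_JUMP_IF_FALSE", 6),
    ("LOAD_FAST", "accum"), ("RETURN_VALUE", None),
    ("LOAD_GLOBAL", "sum_squares"),                      # 6
    ("LOAD_FAST", "n"), ("LOAD_CONST", 1), ("BINARY_SUBTRACT", None),
    ("LOAD_FAST", "accum"),
    ("LOAD_GLOBAL", "sq"), ("LOAD_FAST", "n"),
    ("CALL_FUNCTION", 1),                                # 13: sq(n)
    ("BINARY_ADD", None),
    ("CALL_FUNCTION", 2),                                # 15: sum_squares(...)
    ("RETURN_VALUE", None),
]

print(find_recursive_calls(SUM_SQUARES, "sum_squares"))  # [15]
```

Only index 15 would then be rewritten into stores plus a jump; the call to sq at index 13 stays a normal call.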
Road ahead:
• Make it harder to break.
• Translate “normal” (non-tail) recursion into tail recursion (possibly using ASTs)
• Handle mutual recursion
…And some crazy ideas:
https://github.com/lohmataja/recursion
Or: http://tinyurl.com/tailbytes
Liuda Nikolaeva