Removing a recursion in Python, part 1

For the last two decades or so I’ve admired the simplicity and power of the Python language without ever actually doing any work in it or learning about the details. I’ve been taking a closer look lately and having a lot of fun with it; it’s a very pleasant language indeed.

A recent question on Stack Overflow got me thinking about how to turn a recursive algorithm into an iterative one, and it turns out that Python is a pretty decent language for this. The problem that the poster faced was:

• A player is on a positive-integer numbered “space”.
• The aim is to get back to space 1.
• If you’re on an even numbered space, you must pay one coin and jump backwards half that many spaces.
• If you’re on an odd numbered space, you have two choices: pay five coins to go directly to space 1, or pay one coin and jump back one space (to an even numbered space, obviously) and go from there.

The problem is: given the space the player is on, what is the lowest cost in coins to get home? The recursive solution is straightforward:

```python
def cost(s):
    if s <= 1:
        return 0
    if s % 2 == 0:
        return 1 + cost(s // 2)
    return min(1 + cost(s - 1), 5)
```

However, apparently the user was experimenting with values so large that the program was crashing from exceeding the maximum recursion depth! Those must have been some big numbers. The question then is: how do we transform this recursive algorithm into an iterative algorithm in Python?
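How big is "big"? Each halving costs one stack frame and each odd step costs another, so the depth is roughly the bit length of `s` plus the number of 1 bits in it. Exceeding CPython's default recursion limit of 1000 frames (reported by `sys.getrecursionlimit()`, and changeable with `sys.setrecursionlimit`) therefore needs a number several hundred bits long. A minimal sketch that reproduces the failure, assuming that default limit:

```python
import sys

def cost(s):
    # The recursive solution from above, unchanged.
    if s <= 1:
        return 0
    if s % 2 == 0:
        return 1 + cost(s // 2)
    return min(1 + cost(s - 1), 5)

print(sys.getrecursionlimit())  # 1000 by default in CPython
print(cost(2 ** 10))            # 10 halvings, so 10 coins

try:
    cost(2 ** 2000)             # about 2000 nested calls
except RecursionError as e:
    print("RecursionError:", e)
```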

Before we get into it: of course there are fast ways of solving this specific problem; I’m not interested in this problem per se. Rather, it was just a jumping-off point for the question of how, in general, to remove a single recursion from a Python program. The point is that we can refactor any simple recursive method to remove the recursion; this is just the example that was at hand.

Of course, the technique that I’m going to show you is not necessarily “Pythonic”. There are probably more Pythonic solutions using generators and so on. What I’d like to show here is that to remove this sort of recursion, you can do so by re-organizing the code using a series of small, careful refactorings until the program is in a form where removing the recursion is easy. Let’s see first how to get the program into that form.

The first step of our transformation is: I want the thing that precedes every recursive call to be a computation of the argument, and the thing that follows every recursive call to be a return of a method call that takes the recursive result:

```python
def add_one(n):
    return n + 1

def get_min(n):
    return min(n + 1, 5)

def cost(s):
    if s <= 1:
        return 0
    if s % 2 == 0:
        argument = s // 2
        result = cost(argument)
        return add_one(result)
    argument = s - 1
    result = cost(argument)
    return get_min(result)
```

The second step is: I want to compute the argument in a helper function:

```python
...
def get_argument(s):
    if s % 2 == 0:
        return s // 2
    return s - 1

def cost(s):
    if s <= 1:
        return 0
    argument = get_argument(s)
    result = cost(argument)
    if s % 2 == 0:
        return add_one(result)
    return get_min(result)
```

The third step is: I want to decide which function to call afterwards in a helper function. Notice that we have a function which returns a function here!

```python
...
def get_after(s):
    if s % 2 == 0:
        return add_one
    return get_min

def cost(s):
    if s <= 1:
        return 0
    argument = get_argument(s)
    after = get_after(s)  # after is a function!
    result = cost(argument)
    return after(result)
```

And you know, let’s make this a little bit more general, and make the recursive case a bit more concise:

```python
...
def is_base_case(s):
    return s <= 1

def base_case_value(s):
    return 0

def cost(s):
    if is_base_case(s):
        return base_case_value(s)
    argument = get_argument(s)
    after = get_after(s)
    return after(cost(argument))
```

I hope it is clear that at every small refactoring we have maintained the meaning of the program. We’re now doing a small amount of work twice; we have two tests for even space number per recursion, whereas before we had just one, but we could solve that problem if we wanted by combining our two helpers into one function that returned a tuple. Let’s not worry about it for the sake of this exercise.
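For the curious, merging the two helpers might look something like this. It is a sketch only; `get_argument_and_after` is a hypothetical name, not part of the code above. It does the parity test once per step and hands back both pieces as a tuple:

```python
def add_one(n):
    return n + 1

def get_min(n):
    return min(n + 1, 5)

def get_argument_and_after(s):
    # One parity test per step: returns (recursive argument, after function).
    if s % 2 == 0:
        return s // 2, add_one
    return s - 1, get_min

def cost(s):
    if s <= 1:
        return 0
    argument, after = get_argument_and_after(s)
    return after(cost(argument))
```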

We’ve reduced our recursive method to an extremely general form:

• If we are in the base case:
  • compute the base case value to be returned
  • return it
• If we are not in the base case:
  • get the recursive argument
  • make the recursive call
  • compute the value to be returned
  • return it
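Since every part of that general form is now its own function, the shape can be written down once and reused. The following is a sketch, not code from this post: `recursive`, `iterative`, `cost_after` and the factorial helpers are hypothetical names. Each builder takes the four helpers as parameters; `recursive` follows the outline above, and `iterative` uses the two-loop approach developed below. Factorial is included only to suggest that other single recursions fit the same mold.

```python
import math

def recursive(is_base_case, base_case_value, get_argument, get_after):
    # Builds a function in the general recursive form outlined above.
    def f(s):
        if is_base_case(s):
            return base_case_value(s)
        argument = get_argument(s)
        after = get_after(s)
        return after(f(argument))
    return f

def iterative(is_base_case, base_case_value, get_argument, get_after):
    # Builds an equivalent function that keeps the afters on an explicit stack.
    def f(s):
        afters = []
        while not is_base_case(s):
            afters.append(get_after(s))
            s = get_argument(s)
        result = base_case_value(s)
        while afters:
            result = afters.pop()(result)
        return result
    return f

def cost_after(s):
    # The after for space s: add one if s is even, otherwise add one and cap at 5.
    if s % 2 == 0:
        return lambda r: r + 1
    return lambda r: min(r + 1, 5)

cost_helpers = (lambda s: s <= 1, lambda s: 0,
                lambda s: s // 2 if s % 2 == 0 else s - 1, cost_after)
cost_rec = recursive(*cost_helpers)
cost_it = iterative(*cost_helpers)

# Factorial has the same shape; its after must remember n, so it is a closure.
fact_helpers = (lambda n: n <= 1, lambda n: 1, lambda n: n - 1,
                lambda n: (lambda r: n * r))
fact_rec = recursive(*fact_helpers)
fact_it = iterative(*fact_helpers)

print(cost_rec(11), cost_it(11))              # 5 5
print(fact_rec(10), fact_it(10))              # 3628800 3628800
print(fact_it(3000) == math.factorial(3000))  # True, with no recursion
```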

Something important to notice at this stage is that none of the “afters” may themselves contain any calls to cost; the technique that I’m showing today only removes a single recursion. If you’re recursing two or more times, well then, we’ll need more special techniques for that.

Once we’ve got our recursive algorithm in this form, turning it into an iterative algorithm is straightforward. The trick is to think about what happens in the recursive program. As we do the recursive descent, we call get_argument before every recursion, and we call whatever is in after, well, after every recursion. That is, all of the calls to get_argument happen before all of the calls to every after. Therefore we can turn that into two loops: one calls all the get_arguments and makes a list of all the afters, and the other calls all the afters:

```python
...
def cost(s):
    # Let's make a stack of afters. Remember, these are functions
    # that take the value returned by the "recursive" call, and
    # return the value to be returned by the "recursive" method.
    afters = [ ]
    while not is_base_case(s):
        argument = get_argument(s)
        after = get_after(s)
        afters.append(after)
        s = argument
    # Now we have a stack of afters:
    result = base_case_value(s)
    while len(afters) != 0:
        after = afters.pop()
        result = after(result)
    return result
```

No more recursion! It looks like magic, but of course all we are doing here is exactly what the recursive version of the program did, in the same order.
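A quick way to gain some confidence in that claim is to run both versions side by side. This sketch repeats the helpers so it runs on its own; `cost_recursive` is just the original solution under a new name so the two can coexist. It compares them on small inputs, then hands the iterative version a space the recursive one could never reach:

```python
def add_one(n):
    return n + 1

def get_min(n):
    return min(n + 1, 5)

def is_base_case(s):
    return s <= 1

def base_case_value(s):
    return 0

def get_argument(s):
    if s % 2 == 0:
        return s // 2
    return s - 1

def get_after(s):
    if s % 2 == 0:
        return add_one
    return get_min

def cost(s):
    # The iterative version from above, slightly condensed.
    afters = []
    while not is_base_case(s):
        afters.append(get_after(s))
        s = get_argument(s)
    result = base_case_value(s)
    while afters:
        result = afters.pop()(result)
    return result

def cost_recursive(s):
    # The original recursive solution, renamed for comparison.
    if s <= 1:
        return 0
    if s % 2 == 0:
        return 1 + cost_recursive(s // 2)
    return min(1 + cost_recursive(s - 1), 5)

# Spot check: identical answers on every space below 10,000.
assert all(cost(s) == cost_recursive(s) for s in range(1, 10_000))

# 5000 halvings: far deeper than the default recursion limit allows.
print(cost(2 ** 5000))  # 5000
```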

This illustrates a point I make often about the call stack: its purpose is to tell you what is coming next, not what happened before! The only relevant information on the call stack in the recursive version of our program is what the value of after is, since that is the function that is going to be called next; nothing else matters. Instead of using the call stack as an inefficient and bulky mechanism for storing a stack of afters, we can just, you know, store a stack of afters.
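To see what actually ends up on that stack, here is a small sketch that records the afters pushed for space 11 and then the running result as each one is popped. `trace` is a hypothetical helper written only for illustration; it inlines the logic of get_argument and get_after:

```python
def add_one(n):
    return n + 1

def get_min(n):
    return min(n + 1, 5)

def trace(s):
    # Descent: push the after that each "recursive call" would return to.
    afters = []
    while s > 1:
        if s % 2 == 0:
            afters.append(add_one)
            s = s // 2
        else:
            afters.append(get_min)
            s = s - 1
    print("afters, bottom to top:", [f.__name__ for f in afters])
    # Ascent: pop each after and apply it to the running result.
    result = 0
    while afters:
        after = afters.pop()
        result = after(result)
        print(after.__name__, "->", result)
    return result

trace(11)
# afters, bottom to top: ['get_min', 'add_one', 'get_min', 'add_one', 'add_one']
# add_one -> 1
# add_one -> 2
# get_min -> 3
# add_one -> 4
# get_min -> 5
```

Only the afters are kept. The spaces 10, 5, 4 and 2 visited on the way down are never consulted again, which is exactly the sense in which the stack records what comes next rather than what happened before.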

Next time on FAIC, we’ll look at a more advanced technique for removing a recursion in Python that very-long-time readers of my blog will recognize.

17 thoughts on “Removing a recursion in Python, part 1”

1. While it’s interesting to consider translating the call stack to a literal stack, what about the DP approach that permits recognition of the “bottom-up” tabulation, at the cost of some extra space?

• My goal here was to solve the general problem of removing a single recursion, but you are certainly right that in some specific problems you can get big wins with dynamic programming; can you describe in more detail the sort of transformation you’re thinking about?

• D. Ben Knoble said:

I’m not sure my brain was quite turned on all the way when I posted that comment; in the end, I think the transformation I considered is similar to yours. The difference is that, rather than a stack, I would have created a table of the values and computed up to n.

This has, to me, the advantage that the structure of the recursion is still visible, but the cost of its computation is a constant-time lookup.

Plus, it can be memoized/cached between calls, so that subsequent runs can take advantage of previous computations.
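A minimal sketch of the bottom-up table described in this thread, assuming the same rules and the same jump cost of 5; `cost_table` is a hypothetical name. Every entry depends only on smaller indices, so filling the table in increasing order never needs recursion:

```python
def cost_table(n):
    # table[i] is the lowest cost to get home from space i, for 0 <= i <= n.
    table = [0] * (n + 1)  # spaces 0 and 1 are already home
    for i in range(2, n + 1):
        if i % 2 == 0:
            table[i] = 1 + table[i // 2]         # i // 2 < i: already filled in
        else:
            table[i] = min(1 + table[i - 1], 5)  # i - 1 < i: already filled in
    return table

table = cost_table(20)
print(table[11], table[20])  # 5 5
```

Once built, the table answers any query up to n with a single index and can be kept around for later calls. The trade-off is that time and memory grow with n itself rather than with the number of bits in n, so this suits moderate n, not the enormous values that hit the recursion limit.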

2. mr mindor said:

Not a complete removal of recursion, but if we look at the interesting part of the original problem (the choice when you are on an odd-numbered space) and recognize that 5 is the best cost you will pay for any odd number beyond 10, and the worst you will pay for any odd number below 10, we can control the amount of recursion needed to reach an answer…
I apologize if this is not Pythonic (or not actually valid Python; I have very little experience with the language and just typed it in here on the fly).

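```python
def cost(s):
    accumulatedCost = 0
    # make the even space path iterative
    while s % 2 == 0:
        s = s // 2
        accumulatedCost = accumulatedCost + 1
    # at this point we have an odd number. if it is greater than our threshold, it costs 5 to get to 1.
    if s > 10:
        return accumulatedCost + 5
    # if we find ourselves at 1, we've already calculated the cost
    if s <= 1:
        return accumulatedCost
    # otherwise we are on a small odd number, where stepping back one space is cheapest
    return accumulatedCost + 1 + cost(s - 1)
```

And a second version, with all of the recursion removed:

```python
def cost(s):
    accumulatedCost = 0
    while s > 1:
        if s % 2 == 0:
            s = s // 2
            accumulatedCost = accumulatedCost + 1
        elif s > 10:
            s = 1
            accumulatedCost = accumulatedCost + 5
        else:
            s = s - 1
            accumulatedCost = accumulatedCost + 1
    return accumulatedCost
```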

Here is where I acknowledge that I’m kinda dodging the point of the article, which is presenting a general pattern for removing recursion that could be applied to many recursive problems by replacing the helper functions. This would also fail if the cost to jump to 1 were changed from 5. In fact, if it were changed to 7 or 8, there would no longer be a simple threshold below which we can’t do worse than the jump cost and above which we can’t do better than it.

• mr mindor said:

I seem to have lost all the formatting and the less than symbols. Definitely not valid python as is ;). I also seem to have lost a few whole lines in the middle which held the actual recursion, and started a second version with all the recursion removed.

• Yeah, WordPress is bad about that. I mostly fixed it up.

• Ah, good analysis. In the original SO posting the constant was actually given as a parameter to the problem, but I decided to simplify the problem for presentation reasons without analyzing whether that made the problem easier to solve or not! I’m glad someone did that work so I didn’t have to. 🙂
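The threshold claim is easy to spot-check against the original recursive solution. This sketch checks an arbitrarily chosen finite range, so it is evidence rather than a proof:

```python
def cost(s):
    # The original recursive solution.
    if s <= 1:
        return 0
    if s % 2 == 0:
        return 1 + cost(s // 2)
    return min(1 + cost(s - 1), 5)

# Every odd space above 10 (in this range) costs exactly 5...
assert all(cost(s) == 5 for s in range(11, 20_001, 2))
# ...and every odd space below 10 costs less than 5.
assert all(cost(s) < 5 for s in range(1, 10, 2))
print([cost(s) for s in range(1, 10, 2)])  # [0, 2, 3, 4, 4]
```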

3. Svetlana Rosemond said:

You should do a blog post explaining what implementation details you would have done differently.

Out of curiosity, what resources did you use to learn Python? I used Python Crash Course and Python Module of the Week. The book is available, but you can also view the content on the author’s website: https://pymotw.com/3/. You might find the chapter on Concurrency interesting: https://pymotw.com/3/concurrency.html

• Re: resources: Mostly I read the source code of the CPython compiler, read lots of bug reports and PEPs, and got meetings with Python core team developers.

4. Yeah, you’re definitely approaching the problem with more generality than I did:

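```python
def cost(s):
    r = 0
    # Walk the bits after the leading 1, most significant first:
    # a 0 bit is a doubling (cost 1); a 1 bit is a doubling plus a step (cost 2),
    # after which the 5-coin jump home caps the cost.
    for c in bin(s)[3:]:
        if c == '0':
            r += 1
        else:
            r += 2
            r = min(r, 5)
    return r
```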

For the sake of generality, I will point out that I find a lot of these problems become simpler if you start looking for the hidden monoids. Many non-tail-recursive functions boil down to “concatenate a bunch of monoid terms, presented here from right to left”, and making the function iterative boils down to evaluating the same monoid concatenation from left to right.

• Oh yeah, wordpress. Bloody indents…

• Right, the original problem is trivial and irrelevant; it was just a convenient jumping-off place to talk about the general problem of turning a call stack into an explicit stack.
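To illustrate the hidden-monoid remark with the original problem, here is a sketch using an assumed representation (not code from this thread; `compose`, `cost_fold`, `ADD_ONE` and `GET_MIN` are hypothetical names). Both afters have the shape x ↦ min(x + a, b), with b = ∞ for add_one, and composing two functions of that shape gives another one: min(min(x + a2, b2) + a1, b1) = min(x + (a1 + a2), min(b1, b2 + a1)). Because composition is associative, the afters can be combined left to right during the descent, and no stack is needed at all:

```python
import math

# An after is represented as a pair (a, b) meaning x -> min(x + a, b).
ADD_ONE = (1, math.inf)  # x -> x + 1
GET_MIN = (1, 5)         # x -> min(x + 1, 5)

def compose(f, g):
    # Pair for x -> f(g(x)).
    fa, fb = f
    ga, gb = g
    return (fa + ga, min(fb, gb + fa))

def cost_fold(s):
    total = (0, math.inf)  # identity function x -> x
    while s > 1:
        # The after met first is the outermost, so new terms go on the right.
        if s % 2 == 0:
            total = compose(total, ADD_ONE)
            s //= 2
        else:
            total = compose(total, GET_MIN)
            s -= 1
    a, b = total
    return min(a, b)  # apply the composed function to the base value 0

print([cost_fold(s) for s in range(1, 12)])  # [0, 1, 2, 2, 3, 3, 4, 3, 4, 4, 5]
print(cost_fold(2 ** 5000 + 1))              # 5
```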

5. Kelly said:

Nice writeup! I recently hit the recursion limit on a backjumping algorithm (well, I’ve been around the block enough that my code didn’t actually hit that limit; I just knew it was going to be a problem because my use case would be several thousand calls deep…). I labored for hours to find a readable, performant, and general solution to the problem, so I was excited to see somebody else addressing this.

I do have a couple of issues with your approach, which are both essentially the same issue.

The first is readability. You’ve separated the logic into two places — to understand the code, a newcomer needs to reference two functions to find (a) how the recursive parameter evolves, and (b) how the returned value evolves. Maintaining such code is a nightmare, especially with more complex recursion.

The second is performance. By separating the logic into two places, the interpreter has two copies of the branching decision — it evaluates once to evolve the recursive parameter, and then once to record the value subroutine.

Here’s the solution I came up with. It’s just as general, but the conditional only evaluates once and it’s almost as readable as the original.

```python
def cost(n):
    callstack = None

    while n > 1:
        if n % 2 == 0:
            n //= 2  # integer division keeps n an int; /= would turn it into a float
            f = lambda val: val + 1
        else:
            n -= 1
            f = lambda val: min(1 + val, 5)
        callstack = f, callstack

    val = 0
    while callstack is not None:
        f, callstack = callstack
        val = f(val)

    return val
```

The difference between your use of append/len/pop and my use of tuples/None is familiar to performance-oriented Pythoners, but nothing that I’d criticize; moreover it’s probably clearer to the novice reader.

I did come up with an even more general solution, using co-routines (next/yield/send)… but it ended up an unreadable mess so I won’t reproduce it here.