Removing a recursion in Python, part 2

Good day all! Before we get into the continuation of the previous episode, a few bookkeeping notes.

First, congratulations to the Python core developers on successfully choosing a new governance model. People who follow Python know that the Benevolent Dictator For Life has stepped down after 30 years of service, which led to a great many questions about how to best run an open-source, community-driven programming language. I’m really looking forward to seeing how the new governance model works and evolves.

Second, there were a number of good comments on my last episode. Several people noted that the given problem could be turned into a much simpler iterative form than the one I gave. That’s certainly true! Remember, the problem that inspired this series was not a particularly interesting or important problem, it was just a problem where I realized that I could demonstrate the general method of removing a recursion. The general method is not the best method for all possible specific cases. Also, some commenters noted that there are other general ways to remove a recursion; indeed there are, and some of them do make for more readable code.

Third, have a festive holiday season all, and a happy New Year; we’ll pick up with more fabulous adventures in 2019!


Today I want to look at a rather different technique for removing a recursion, called continuation passing style (CPS). Long-time readers will recall that I discussed this technique in JavaScript back in the early days of this blog, and that a variation on CPS is the mechanism that underlies async/await in C#.

In our example of last time we made a stack of “afters” — each “after” is a method that takes the return value of the “previous” recursive call, and computes the result of the “current” recursive call, which then becomes the argument to the “next” after, and so on. The “continuation” in “continuation passing style” just means “what happens after”; the code that happens after a bit of code is the continuation of that code in your program.

Put another way: in our original recursive program, we return and that returned value is then passed directly to the continuation of the recursive call. A return is logically the same as a method call to the continuation.

The idea of continuation passing style is, you guessed it, we’re going to pass in the continuation as an argument, and then instead of returning, we’re going to call the continuation.

Let’s suppose we again have this general form of a one-recursion recursive function:

def f(n):
    if is_base_case(n):
        return base_case_value
    argument = get_argument(n)
    after = get_after(n)
    return after(f(argument))  

How might we turn this into the equivalent program in CPS? We have two returns; both of them must turn into calls to the continuation. The first one is straightforward:

# We're now passing in continuation c, a function of one argument
def f(n, c):
    if is_base_case(n):
        c(base_case_value)
        return

Easy peasy. What about the non-base case? Well, what were we going to do in the original program? We were going to (1) get the arguments, (2) get the “after” — the continuation of the recursive call, (3) do the recursive call, (4) call “after”, its continuation, and (5) return the computed value to whatever the continuation of f is.

We’re going to do all the same things, except that when we do step (3) now we need to pass in a continuation that does steps 4 and 5:

    argument = get_argument(n)
    after = get_after(n)
    f(argument, lambda x: c(after(x)))

Hey, that is so easy! What do we do after the recursive call? Well, we call after with the value returned by the recursive call. But now that value is going to be passed to the recursive call’s continuation function, so it just goes into x. What happens after that? Well, whatever was going to happen next, and that’s in c, so it needs to be called, and we’re done.

Let’s try it out. Previously we would have said

print(f(100))

but now we have to pass in what happens after f(100). Well, what happens is, the value gets printed!

f(100, print)

and we’re done.
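Here is a minimal runnable sketch of this CPS version. It assumes the coin-cost helpers from part 1 of this series, with one adjustment: base_case_value is a plain value here, as in the general form above, rather than a function. Passing a list's append method instead of print lets us capture the result.

```python
# Helpers for the coin-cost problem from part 1. Assumption: base_case_value
# is a plain value here, matching the general form used in this post.
base_case_value = 0

def is_base_case(s):
    return s <= 1

def add_one(n):
    return n + 1

def get_min(n):
    return min(n + 1, 5)

def get_argument(s):
    return s // 2 if s % 2 == 0 else s - 1

def get_after(s):
    return add_one if s % 2 == 0 else get_min

def f(n, c):
    """CPS: instead of returning a value, pass it to the continuation c."""
    if is_base_case(n):
        c(base_case_value)
        return
    argument = get_argument(n)
    after = get_after(n)
    # Continuation of the recursive call: apply after, then do whatever
    # was going to happen next, which is c.
    f(argument, lambda x: c(after(x)))

f(3, print)  # prints 2
```

This version still nests one Python frame per step, so a large input such as 2 ** 5000 raises RecursionError under CPython's default recursion limit.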

So… big deal. The function is still recursive. Why is this interesting? Because the function is now tail recursive! A tail-recursive method is one where the last thing that happens is the recursive call, and it is very easy to remove the recursion from such a method, because it can be turned into a loop:

# This is wrong!
def f(n, c):
    while True:
        if is_base_case(n):
            c(base_case_value)
            return
        argument = get_argument(n)
        after = get_after(n)
        n = argument
        c = lambda x: c(after(x))

The naive transformation of the code above into a loop is wrong; do you see why?  Give it some thought and then read on.

.

.

.

.

.

.

.

Hopefully you figured it out: the lambda is closed over the variables c and after, not over the values they held when it was created. When the loop changes those variables, every lambda created so far sees the change. Worse, the c inside the newest lambda now refers to that lambda itself, so calling it recurses forever. In Python, as in C#, VB, JavaScript and many other languages, lambdas close over variables, not the values the variables had when the lambda was created.
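A small sketch of both halves of this pitfall, using hypothetical function names: lambdas created in a loop all see the final value of the loop variable, and a lambda assigned to c whose body mentions c ends up calling itself. A factory function avoids the first problem because each call gets its own fresh parameter to close over.

```python
def closures_in_loop(n):
    funcs = []
    for i in range(n):
        funcs.append(lambda: i)  # closes over the variable i, not its value
    return funcs

def make_constant(value):
    return lambda: value  # closes over this call's own parameter

def closures_from_factory(n):
    return [make_constant(i) for i in range(n)]

def self_referential_continuation():
    c = print
    c = lambda x: c(x)  # 'c' is looked up when the lambda runs: it is this lambda
    return c

print([g() for g in closures_in_loop(3)])       # [2, 2, 2]
print([g() for g in closures_from_factory(3)])  # [0, 1, 2]
```

Calling self_referential_continuation()(0) raises RecursionError, which is exactly what the first broken loop does.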

We can work around this though by making a factory that closes over the values:

def continuation_factory(c, after):
    return lambda x: c(after(x))

# This is wrong too!
def f(n, c):
    while True:
        if is_base_case(n):
            c(base_case_value)
            return
        argument = get_argument(n)
        after = get_after(n)
        n = argument
        c = continuation_factory(c, after)

It sure looks like we’ve correctly removed the recursion here, by turning it into a loop. But this is not right either! Again, give it some thought; this one is more subtle.

.

.

.

.

.

.

The problem we started with was that a recursive algorithm is blowing the stack. We’ve turned this into an iterative algorithm — there’s no recursive call at all here! We just sit in a loop updating local variables.

The question though is — what happens when the final continuation is called, in the base case? What does that continuation do? Well, it calls its after, and then it calls its continuation. What does that continuation do? Same thing.

All we’ve done here is moved the recursive control flow into a collection of function objects that we’ve built up iteratively, and calling that thing is still going to blow the stack. So we haven’t actually solved the problem.
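To see this concretely, here is a sketch of the loop-plus-factory version using the coin-cost helpers from part 1 (again assuming base_case_value is a plain value). Building the chain uses no recursion at all, but invoking it nests one call per step.

```python
base_case_value = 0  # assumption: a plain value, as in this post

def is_base_case(s):
    return s <= 1

def add_one(n):
    return n + 1

def get_min(n):
    return min(n + 1, 5)

def get_argument(s):
    return s // 2 if s % 2 == 0 else s - 1

def get_after(s):
    return add_one if s % 2 == 0 else get_min

def continuation_factory(c, after):
    # Each call gets fresh parameters, so this lambda captures these values.
    return lambda x: c(after(x))

def f(n, c):
    while True:
        if is_base_case(n):
            # No recursion so far, but this call unwinds the whole chain of
            # continuations, one nested Python call per step.
            c(base_case_value)
            return
        argument = get_argument(n)
        after = get_after(n)
        n = argument
        c = continuation_factory(c, after)

f(3, print)  # prints 2, but f(2 ** 5000, print) raises RecursionError
```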

Or… have we?

What we can do here is add one more level of indirection, and that will solve the problem. (This solves every problem in computer programming except one problem; do you know what that problem is?)

What we’ll do is we’ll change the contract of f so that it is no longer “I am void-returning and will call my continuation when I’m done”. We will change it to “I will return a function that, when it is called, calls my continuation. And furthermore, my continuation will do the same.”

That sounds a little tricky, but really it’s not. Again, let’s reason it through. What does the base case have to do? It has to return a function which, when called, calls my continuation. But my continuation already meets that requirement by assumption:

def f(n, c):
    if is_base_case(n):
        return c(base_case_value)

What about the recursive case? We need to return a function which, when called, executes the recursion. The continuation of that call needs to be a function that takes a value and returns a function which, when called, executes the continuation on that value. We know how to do that:

    argument = get_argument(n)
    after = get_after(n)
    return lambda : f(argument, lambda x: lambda: c(after(x)))

OK, so how does this help? We can now move the loop into a helper function:

def trampoline(f, n, c):
    t = f(n, c)
    while t is not None:
        t = t()

And call it:

trampoline(f, 3, print)

And holy goodness it works.
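Here is the whole thing as a runnable sketch, again assuming the part 1 helpers with base_case_value as a plain value. The loop stops because the final continuation (print, or list.append in the checks) returns None once the answer has been delivered.

```python
base_case_value = 0  # assumption: a plain value, as in this post

def is_base_case(s):
    return s <= 1

def add_one(n):
    return n + 1

def get_min(n):
    return min(n + 1, 5)

def get_argument(s):
    return s // 2 if s % 2 == 0 else s - 1

def get_after(s):
    return add_one if s % 2 == 0 else get_min

def f(n, c):
    if is_base_case(n):
        return c(base_case_value)
    argument = get_argument(n)
    after = get_after(n)
    # Return a thunk instead of recursing; the continuation we pass along
    # also returns a thunk instead of calling c directly.
    return lambda: f(argument, lambda x: lambda: c(after(x)))

def trampoline(f, n, c):
    t = f(n, c)
    while t is not None:  # each step returns the next step, or None when done
        t = t()

trampoline(f, 3, print)          # prints 2
trampoline(f, 2 ** 5000, print)  # prints 5000, no RecursionError
```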

Follow along what happens here. Here’s the call sequence with indentation showing stack depth:

trampoline(f, 3, print)
  f(3, print)

What does this call return? It effectively returns

lambda:
  f(2, lambda x:
    lambda:
      print(get_min(x)))

so that’s the new value of t.

That’s not None, so we call t(), which calls:

  f(2, lambda x: lambda: print(get_min(x)))

What does that thing do? It immediately returns this mess: (I’ll indent it to make it more clear, but this is not necessarily legal Python.)

lambda: f(1,
  lambda x:
    lambda:
      (lambda x:
        lambda:
          print(get_min(x)))(add_one(x)))

So that’s the new value of t. It’s not None, so we invoke it. That calls:

  f(1,
    lambda x:
      lambda:
        (lambda x:
          lambda:
            print(get_min(x)))(add_one(x)))

Now we’re in the base case, so we call the continuation, substituting 0 for x. It returns:

      lambda:
        (lambda x:
          lambda:
            print(get_min(x)))(add_one(0))

So that’s the new value of t. It’s not None, so we invoke it.

That calls add_one(0) and gets 1. It then passes 1 for x in the middle lambda. That thing returns:

lambda: print(get_min(1))

So that’s the new value of t. It’s not None, so we invoke it. And that calls

  print(get_min(1))

Which prints out the correct answer, print returns None, and the loop stops.

Notice what happened there. The stack never grew beyond a small, fixed number of frames, because every call handed the loop a function describing what to do next, rather than calling that function itself.
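One way to check this is to measure the stack depth whenever an after runs. This sketch uses CPython's sys._getframe to count frames (an implementation detail, so treat it as a diagnostic only); note_depth and measure are hypothetical names. The maximum depth should come out the same whether the recursion would have been three levels deep or about a thousand.

```python
import sys

base_case_value = 0  # assumption: a plain value, as in this post
deepest = 0

def stack_depth():
    """Count the frames on the current call stack (CPython-specific)."""
    frame, depth = sys._getframe(), 0
    while frame is not None:
        depth += 1
        frame = frame.f_back
    return depth

def note_depth():
    global deepest
    deepest = max(deepest, stack_depth())

def is_base_case(s):
    return s <= 1

def add_one(n):
    note_depth()
    return n + 1

def get_min(n):
    note_depth()
    return min(n + 1, 5)

def get_argument(s):
    return s // 2 if s % 2 == 0 else s - 1

def get_after(s):
    return add_one if s % 2 == 0 else get_min

def f(n, c):
    if is_base_case(n):
        return c(base_case_value)
    argument = get_argument(n)
    after = get_after(n)
    return lambda: f(argument, lambda x: lambda: c(after(x)))

def trampoline(f, n, c):
    t = f(n, c)
    while t is not None:
        t = t()

def measure(n):
    """Run the trampolined computation and report the deepest stack seen."""
    global deepest
    deepest = 0
    trampoline(f, n, lambda x: None)
    return deepest

print(measure(3), measure(2 ** 1000 + 1))  # the two numbers are equal
```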

If this sounds familiar, it should. Basically what we’re doing here is making a very simple work queue. Every time we “enqueue” a job, it is immediately dequeued, and the only thing the job does is enqueues the next job by returning a lambda to the trampoline, which sticks it in its “queue”, the variable t.

We break the problem up into little pieces, and make each piece responsible for saying what the next piece is.

Now, you’ll notice that we end up with arbitrarily deep nested lambdas, just as we ended up in the previous technique with an arbitrarily deep stack of afters. Essentially what we’ve done here is moved the workflow description from an explicit list into a network of nested lambdas, but unlike before, this time we’ve done a little trick to avoid those lambdas ever calling each other in a manner that increases the stack depth.

Once you see this pattern of “break it up into pieces and describe a workflow that coordinates execution of the pieces”, you start to see it everywhere. This is how Windows works; each window has a queue of messages, and messages can represent portions of a workflow. When a portion of a workflow wishes to say what the next portion is, it posts a message to the queue, and it runs later. This is how async await works — again, we break up the workflow into pieces, and each await is the boundary of a piece. It’s how generators work, where each yield is the boundary, and so on. Of course they don’t actually use trampolines like this, but they could.
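The generator version of this idea is a different technique from the trampoline above, but it shows the same "break it into pieces" pattern. In this sketch, each recursive call in the part 1 cost function becomes a yield of the argument, and a small driver keeps an explicit stack of suspended generators; cost_gen and run are hypothetical names.

```python
def cost_gen(s):
    """The part 1 cost function, with each recursive call replaced by a yield."""
    if s <= 1:
        return 0
    if s % 2 == 0:
        result = yield s // 2  # "call" cost(s // 2); the driver sends back its result
        return 1 + result
    result = yield s - 1
    return min(1 + result, 5)

def run(gen_func, arg):
    """Drive gen_func without Python recursion, using a list of suspended generators."""
    stack = [gen_func(arg)]
    value = None  # a freshly created generator must first be sent None
    while stack:
        try:
            request = stack[-1].send(value)
        except StopIteration as stop:
            stack.pop()         # this "call" finished;
            value = stop.value  # hand its result to the caller
        else:
            stack.append(gen_func(request))  # start the "recursive call"
            value = None
    return value

print(run(cost_gen, 3))          # 2
print(run(cost_gen, 2 ** 5000))  # 5000
```

Each yield marks the boundary between one piece of the workflow and the next; the driver, not the call stack, remembers what comes next.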

The key thing to understand here is the notion of continuation. Once you realize that you can treat continuations as objects that can be manipulated by the program, any control flow becomes possible. Want to implement your own try-catch? try-catch is just a workflow where every step has two continuations: the normal continuation and the exceptional continuation. When there’s an exception, you branch to the exceptional continuation instead of the regular continuation. And so on.
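A minimal sketch of that idea, with hypothetical names: every CPS step takes a normal continuation and an exceptional one, and a try_catch combinator simply substitutes a handler for the exceptional continuation. This is an illustration of the concept, not how Python's own try/except is implemented.

```python
def divide(a, b, on_ok, on_error):
    """CPS division: exactly one of the two continuations gets called."""
    if b == 0:
        return on_error(ZeroDivisionError("division by zero"))
    return on_ok(a / b)

def try_catch(body, handler, k):
    """Run body with k as its normal continuation. If body takes its
    exceptional continuation instead, run handler, which may recover
    by calling k itself."""
    return body(k, lambda error: handler(error, k))

def reciprocal_or(default, x, k):
    return try_catch(
        lambda on_ok, on_error: divide(1, x, on_ok, on_error),
        lambda error, k2: k2(default),  # the "catch" block: recover with default
        k,
    )

print(reciprocal_or(0.0, 4, lambda v: v))  # 0.25
print(reciprocal_or(0.0, 0, lambda v: v))  # 0.0
```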

The question here was, again: how, in general, do we eliminate a stack overflow caused by a deep recursion? I’ve shown that any recursive method of the form

def f(n):
    if is_base_case(n):
        return base_case_value
    argument = get_argument(n)
    after = get_after(n)
    return after(f(argument))
...
print(f(10))

can be rewritten as:

def f(n, c):
    if is_base_case(n):
        return c(base_case_value)
    argument = get_argument(n)
    after = get_after(n)
    return lambda : f(argument, lambda x: lambda: c(after(x)))
...
trampoline(f, 10, print)

and that the “recursive” method will now use only a very small, fixed amount of stack.
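Since only the four helpers vary, the transformation can be packaged once. Here is a sketch of a hypothetical remove_recursion function that builds the trampolined version from the helpers and, unlike trampoline above, returns the final value: its last continuation stashes the value and then returns None to stop the loop. The factorial usage is an added illustration, not from the original series.

```python
def remove_recursion(is_base_case, base_case_value, get_argument, get_after):
    """Build a stack-safe version of a one-recursion function in the
    general form above, using CPS plus a trampoline (hypothetical helper)."""

    def f(n, c):
        if is_base_case(n):
            return c(base_case_value)
        argument = get_argument(n)
        after = get_after(n)
        return lambda: f(argument, lambda x: lambda: c(after(x)))

    def run(n):
        results = []

        def finish(value):
            results.append(value)
            return None  # None tells the trampoline loop to stop

        t = f(n, finish)
        while t is not None:
            t = t()
        return results[0]

    return run

cost = remove_recursion(
    is_base_case=lambda s: s <= 1,
    base_case_value=0,
    get_argument=lambda s: s // 2 if s % 2 == 0 else s - 1,
    get_after=lambda s: (lambda r: r + 1) if s % 2 == 0 else (lambda r: min(r + 1, 5)),
)

factorial = remove_recursion(
    is_base_case=lambda n: n <= 1,
    base_case_value=1,
    get_argument=lambda n: n - 1,
    get_after=lambda n: lambda r: n * r,  # "multiply by n" after the recursive call
)

print(cost(3), cost(2 ** 5000))  # 2 5000
print(factorial(10))             # 3628800
```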

Of course you probably would not actually make this transformation yourself; there are libraries in Python that can do it for you. But it is certainly character-building to understand how these mechanisms work.


Removing a recursion in Python, part 1

For the last two decades or so I’ve admired the simplicity and power of the Python language without ever actually doing any work in it or learning about the details. I’ve been taking a closer look lately and having a lot of fun with it; it’s a very pleasant language indeed.

A recent question on Stack Overflow got me thinking about how to turn a recursive algorithm into an iterative one, and it turns out that Python is a pretty decent language for this. The problem that the poster faced was:

  • A player is on a positive-integer numbered “space”.
  • The aim is to get back to space 1.
  • If you’re on an even numbered space, you must pay one coin and jump backwards half that many spaces.
  • If you’re on an odd numbered space, you have two choices: pay five coins to go directly to space 1, or pay one coin and jump back one space (to an even numbered space, obviously) and go from there.

The problem is: given the space the player is on, what is the lowest cost in coins to get home? The recursive solution is straightforward:

def cost(s):
  if s <= 1:
    return 0
  if s % 2 == 0:
    return 1 + cost(s // 2) 
  return min(1 + cost(s - 1), 5)

However, apparently the user was experimenting with values so large that the program was crashing from exceeding the maximum recursion depth! Those must have been some big numbers. The question then is: how do we transform this recursive algorithm into an iterative algorithm in Python?
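A quick sketch to reproduce the crash. Each call either halves s or subtracts one (after which the next call halves), so the recursion depth lies between log2(s) and roughly 2·log2(s). With CPython's default recursion limit of 1000, s needs hundreds of bits before this version fails; 2 ** 5000 is an input chosen here purely for illustration.

```python
import sys

def cost(s):
    if s <= 1:
        return 0
    if s % 2 == 0:
        return 1 + cost(s // 2)
    return min(1 + cost(s - 1), 5)

print(cost(3), cost(7))         # 2 4
print(sys.getrecursionlimit())  # CPython's default is 1000
try:
    cost(2 ** 5000)             # needs about 5,000 nested calls
except RecursionError as error:
    print("RecursionError:", error)
```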

Before we get into it, of course there are fast ways of solving this specific problem; I’m not interested in this problem per se. Rather, it was just a jumping-off point for the question of how, in general, to remove a single recursion from a Python program. The point is that we can refactor any simple recursive method to remove the recursion; this is just the example that was at hand.

Of course, the technique that I’m going to show you is not necessarily “Pythonic”. There are probably more Pythonic solutions using generators and so on. What I’d like to show here is that to remove this sort of recursion, you can do so by re-organizing the code using a series of small, careful refactorings until the program is in a form where removing the recursion is easy. Let’s see first how to get the program into that form.

The first step of our transformation is: I want the thing that precedes every recursive call to be a computation of the argument, and the thing that follows every recursive call to be a return of a method call that takes the recursive result:

def add_one(n):
  return n + 1
def get_min(n):
  return min(n + 1, 5)
def cost(s):
  if s <= 1:
    return 0
  if s % 2 == 0:
    argument = s // 2
    result = cost(argument)
    return add_one(result) 
  argument = s - 1
  result = cost(argument)
  return get_min(result)

The second step is: I want to compute the argument in a helper function:

...
def get_argument(s):
  if s % 2 == 0:
    return s // 2
  return s - 1
def cost(s):
  if s <= 1:
    return 0
  argument = get_argument(s)
  result = cost(argument)
  if s % 2 == 0:
    return add_one(result) 
  return get_min(result)

The third step is: I want to decide which function to call afterwards in a helper function. Notice that we have a function which returns a function here!

...
def get_after(s):
  if s % 2 == 0:
    return add_one
  return get_min
def cost(s):
  if s <= 1:
    return 0
  argument = get_argument(s)
  after = get_after(s) # after is a function!
  result = cost(argument)
  return after(result) 

And you know, let’s make this a little bit more general, and make the recursive case a bit more concise:

...
def is_base_case(s):
  return s <= 1
def base_case_value(s):
  return 0
def cost(s):
  if is_base_case(s):
    return base_case_value(s)
  argument = get_argument(s)
  after = get_after(s)
  return after(cost(argument)) 

I hope it is clear that at every small refactoring we have maintained the meaning of the program. We’re now doing a small amount of work twice; we have two tests for even space number per recursion, whereas before we had just one, but we could solve that problem if we wanted by combining our two helpers into one function that returned a tuple. Let’s not worry about it for the sake of this exercise.
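For the record, here is a sketch of that tuple-returning helper (get_argument_and_after is a hypothetical name), together with a brute-force check that the refactored cost agrees with the original recursive one on small inputs. A check like this is a cheap way to gain confidence that each refactoring step preserved the meaning.

```python
def original_cost(s):
    if s <= 1:
        return 0
    if s % 2 == 0:
        return 1 + original_cost(s // 2)
    return min(1 + original_cost(s - 1), 5)

def add_one(n):
    return n + 1

def get_min(n):
    return min(n + 1, 5)

def get_argument_and_after(s):
    """One parity test per step: return (argument, after) together."""
    if s % 2 == 0:
        return s // 2, add_one
    return s - 1, get_min

def cost(s):
    if s <= 1:
        return 0
    argument, after = get_argument_and_after(s)
    return after(cost(argument))

print(all(cost(s) == original_cost(s) for s in range(1, 2000)))  # True
```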

We’ve reduced our recursive method to an extremely general form:

  • If we are in the base case:
    • compute the base case value to be returned
    • return it
  • If we are not in the base case:
    • get the recursive argument
    • make the recursive call
    • compute the value to be returned
    • return it

Something important to notice at this stage is that none of the “afters” may themselves contain any calls to cost; the technique that I’m showing today only removes a single recursion. If you’re recursing two or more times, well then, we’ll need more specialized techniques for that.

Once we’ve got our recursive algorithm in this form, turning it into an iterative algorithm is straightforward. The trick is to think about what happens in the recursive program. As we do the recursive descent, we call get_argument before every recursion, and we call whatever is in after, well, after every recursion. That is, all of the calls to get_argument happen before all of the calls to every after. Therefore we can turn that into two loops: one calls all the get_arguments and makes a list of all the afters, and the other calls all the afters:

...
def cost(s):
  # Let's make a stack of afters. Remember, these are functions
  # that take the value returned by the "recursive" call, and
  # return the value to be returned by the "recursive" method.
  afters = [ ]
  while not is_base_case(s):
    argument = get_argument(s)
    after = get_after(s)
    afters.append(after)
    s = argument
  # Now we have a stack of afters:
  result = base_case_value(s)
  while afters:
    after = afters.pop()
    result = after(result)
  return result

No more recursion! It looks like magic, but of course all we are doing here is exactly what the recursive version of the program did, in the same order.
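Here is the iterative version assembled into a self-contained sketch, using the helpers from the refactoring steps above, plus a check on an input far too large for the recursive version.

```python
def is_base_case(s):
    return s <= 1

def base_case_value(s):
    return 0

def add_one(n):
    return n + 1

def get_min(n):
    return min(n + 1, 5)

def get_argument(s):
    return s // 2 if s % 2 == 0 else s - 1

def get_after(s):
    return add_one if s % 2 == 0 else get_min

def cost(s):
    afters = []  # stack of "what to do after the recursive call"
    while not is_base_case(s):
        afters.append(get_after(s))  # must look at s before it is replaced
        s = get_argument(s)
    result = base_case_value(s)
    while afters:
        result = afters.pop()(result)
    return result

print(cost(3))          # 2
print(cost(2 ** 5000))  # 5000, with no recursion at all
```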

This illustrates a point I make often about the call stack: its purpose is to tell you what is coming next, not what happened before! The only relevant information on the call stack in the recursive version of our program is what the value of after is, since that is the function that is going to be called next; nothing else matters. Instead of using the call stack as an inefficient and bulky mechanism for storing a stack of afters, we can just, you know, store a stack of afters.

Next time on FAIC, we’ll look at a more advanced technique for removing a recursion in Python that very-long-time readers of my blog will recognize.