Recursion is a useful technique that lets us implement many algorithms in a more succinct and readable fashion than an iterative approach. Recursion also has a few drawbacks, namely:
- Stack space is limited in most programming languages, so deep recursion can lead to a stack overflow
- Proper tail call optimization is not supported in all programming languages
These limitations may eventually force you to re-implement a recursive solution as an iterative solution. To that end, I want to look at a generic method for implementing any recursive method (regardless of complexity) as an iterative method and then offer some commentary on the subject.
The code in this post will be written in Python, but this technique works in any OO language and can be adapted to work in non-OO languages as well.
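To make the stack-space drawback concrete, here is a minimal sketch of what happens in Python when a recursion goes deeper than the interpreter allows. The helper names `depth` and `try_depth` are made up for this illustration.

```python
import sys


def depth(n):
    """Recurse n levels deep and return n."""
    if n == 0:
        return 0
    return 1 + depth(n - 1)


def try_depth(n):
    """Return depth(n), or None if the interpreter refuses to recurse that far."""
    try:
        return depth(n)
    except RecursionError:
        return None


limit = sys.getrecursionlimit()  # often 1000 by default in CPython
print(limit)
print(try_depth(50))             # shallow recursion works fine
print(try_depth(limit + 100))    # deeper than the limit: RecursionError, so None
```

Raising the limit with `sys.setrecursionlimit` only postpones the problem, which is the motivation for converting to iteration.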
A Model for Recursive Functions
Let’s start with a model for any recursive function and then build up our solution from there. We’ll be looking at a simple binary tree implementation and the method to print it. This code comes from TutorialsPoint:

    class Node:
        def __init__(self, data):
            self.left = None
            self.right = None
            self.data = data

        def insert(self, data):
            # Compare the new value with the parent node.
            # "is not None" (rather than plain truthiness) keeps a stored 0 from
            # being treated as an empty node.
            if self.data is not None:
                if data < self.data:
                    if self.left is None:
                        self.left = Node(data)
                    else:
                        self.left.insert(data)
                elif data > self.data:
                    if self.right is None:
                        self.right = Node(data)
                    else:
                        self.right.insert(data)
            else:
                self.data = data

        # Print the tree in order: left subtree, this node, right subtree
        def PrintTree(self):
            if self.left:
                self.left.PrintTree()
            print(self.data)
            if self.right:
                self.right.PrintTree()

The method we are concerned with implementing iteratively is the PrintTree method. PrintTree essentially walks all the way down the left side of an arbitrary tree and prints it, then backtracks up, prints the data, and then tracks all the way down the right side. This has the effect of printing the tree in order.
We can start working on a solution to our problem by realizing that all recursion involves a stack. The recursion keeps running while there are frames on the stack. Each one of those frames contains a number of operations (the literal code in the method). Finally, new frames can be added onto the stack via a method call. So our model has to include at least these three things:
- A stack to track where in the recursion we are
- A state tracking mechanism that can say which operations we are executing now
- A way to add new frames to the stack
Implementing a Stack
Our stack implementation is relatively straightforward:

    class TerminatedException(Exception):
        """Raised by a state when its frame has no operations left."""


    class IterativeProcessor:
        def __init__(self, initial_state):
            self._initial_state = initial_state

        def traverse(self):
            operator = Operator(self._initial_state)
            # The only loop in the whole approach
            while operator.has_operations:
                operator.operate()


    class Operator:
        def __init__(self, initial_state):
            self._stack = [initial_state]

        @property
        def has_operations(self):
            return len(self._stack) > 0

        def operate(self):
            top_of_stack = self._stack[-1]
            try:
                # A state may return a new state, which becomes a new frame
                new_state = top_of_stack.next()
                if new_state:
                    self._stack.append(new_state)
            except TerminatedException:
                # The frame is finished, so drop it and resume the one below
                self._stack.pop()

We have two classes, IterativeProcessor and Operator. The job of IterativeProcessor is to initialize an Operator with our initial state (we’ll get into this next) and simply run operations as long as there are frames on the stack. This is our looping construct and prevents us from having to write a loop anywhere else.
Operator has the responsibility of maintaining the actual call stack. At each call of operate, we check the top of the stack and call the next method. All State objects have only one method, next, which may optionally return a State object which is then added to the stack. Finally, if the frame has concluded and has no next operations (it throws a TerminatedException), then we can pop it off the stack and allow another frame to start executing. This fulfills two of our requirements:
- Operator maintains knowledge of the current call stack and can tell when there are no frames left to execute
- State objects can add new frames to the stack by returning a new State object from next
The best way to see how the state objects work is to look at an example. Here’s an implementation of PrintTree using a State object that’s compatible with our stack implementation:

    class TreeTraverseState:
        def __init__(self, tree):
            self._tree = tree
            self._current_state = 0

        def next(self):
            if self._current_state == 0:
                return self._left_state()
            elif self._current_state == 1:
                return self._process_this_node_state()
            elif self._current_state == 2:
                return self._right_state()
            else:
                raise TerminatedException

        def _left_state(self):
            if self._tree.left:
                self._current_state = 1
                return TreeTraverseState(self._tree.left)
            else:
                # No left subtree, so process this node right away
                return self._process_this_node_state()

        def _process_this_node_state(self):
            self._current_state = 2
            print(self._tree.data)

        def _right_state(self):
            self._current_state = 3
            if self._tree.right:
                return TreeTraverseState(self._tree.right)
            else:
                raise TerminatedException

There’s nothing particularly fancy going on here, and that’s the point. We’ve translated our original PrintTree code into a series of individual states in a state machine. We then transition through the steps of the state machine with next, and each call to next can optionally return another state machine.
Thus our model is recursion as a stack of state machines. You can view a visualization here. Notice how the stack builds up with new TreeTraverseState instances for every node, and how those instances preserve their state, allowing us to implement the algorithm correctly.
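As a rough end-to-end check, the pieces above can be combined into one runnable sketch. To keep it short and testable, this version makes a few assumptions that are not in the original code: the state machine is condensed and appends values to a list instead of printing them, the Operator loop is inlined into a hypothetical `traverse` function, and the function also records the deepest the stack ever gets.

```python
class TerminatedException(Exception):
    """Raised by a state when its frame has no operations left."""


class Node:
    def __init__(self, data):
        self.left = None
        self.right = None
        self.data = data

    def insert(self, data):
        if data < self.data:
            if self.left is None:
                self.left = Node(data)
            else:
                self.left.insert(data)
        elif data > self.data:
            if self.right is None:
                self.right = Node(data)
            else:
                self.right.insert(data)


class TreeTraverseState:
    """Condensed variant of the state machine; collects into `out`."""

    def __init__(self, tree, out):
        self._tree = tree
        self._out = out
        self._current_state = 0

    def next(self):
        if self._current_state == 0:      # recurse left, if there is a left child
            self._current_state = 1
            if self._tree.left:
                return TreeTraverseState(self._tree.left, self._out)
        if self._current_state == 1:      # process this node
            self._current_state = 2
            self._out.append(self._tree.data)
            return None
        if self._current_state == 2:      # recurse right, if there is a right child
            self._current_state = 3
            if self._tree.right:
                return TreeTraverseState(self._tree.right, self._out)
        raise TerminatedException         # nothing left to do in this frame


def traverse(root):
    """Inlined Operator loop that also tracks the maximum stack depth."""
    out = []
    stack = [TreeTraverseState(root, out)]
    max_depth = 0
    while stack:
        max_depth = max(max_depth, len(stack))
        try:
            new_state = stack[-1].next()
            if new_state:
                stack.append(new_state)
        except TerminatedException:
            stack.pop()
    return out, max_depth


root = Node(12)
for value in (6, 14, 3, 8):
    root.insert(value)

values, max_depth = traverse(root)
print(values)     # [3, 6, 8, 12, 14]
print(max_depth)  # 3
```

The maximum stack depth matches the height of the tree, which is the same depth the recursive PrintTree would reach on the interpreter's call stack.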
Translating recursive functions in this way has a few good features:
- The method is totally generic and can be re-used for any recursive function no matter how complex. In fact, neither Operator nor IterativeProcessor needs to change. The only class that needs to be re-implemented for different recursive functions is the State class.
- Translating the body of a recursive function into a State class is simple. The number of possible states in the machine is equal to the number of times the method can recurse + 1. After that, you simply need to return a new instance of the State class with the correct arguments any time you would normally recurse, and raise TerminatedException for any base cases or any time you would not recurse.
- State classes are easy to unit test. Simply initialize one with any arbitrary state and step through it with next, asserting that the correct values are returned.
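To illustrate the unit-testing point, here is a small sketch using the standard `unittest` module. `CountdownState` is a hypothetical state class invented for this example (it records `n` and then "recurses" on `n - 1`); the same pattern should apply to any State class that follows the next/TerminatedException contract.

```python
import unittest


class TerminatedException(Exception):
    """Raised by a state when its frame has no operations left."""


class CountdownState:
    """Hypothetical state: records n, then recurses on n - 1 until n is 0."""

    def __init__(self, n, seen):
        self.n = n
        self._seen = seen
        self._current_state = 0

    def next(self):
        if self._current_state == 0:
            self._current_state = 1
            if self.n == 0:
                raise TerminatedException  # base case
            self._seen.append(self.n)
            return CountdownState(self.n - 1, self._seen)
        raise TerminatedException          # frame already did its work


class CountdownStateTest(unittest.TestCase):
    def test_recursive_step_returns_child_state(self):
        seen = []
        child = CountdownState(3, seen).next()
        self.assertIsInstance(child, CountdownState)
        self.assertEqual(child.n, 2)
        self.assertEqual(seen, [3])

    def test_base_case_terminates_immediately(self):
        with self.assertRaises(TerminatedException):
            CountdownState(0, []).next()

    def test_frame_terminates_after_recursing(self):
        state = CountdownState(1, [])
        state.next()
        with self.assertRaises(TerminatedException):
            state.next()


# Run the tests explicitly so the snippet also works when pasted into a script.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(CountdownStateTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

No Operator is needed in these tests, since each state is stepped by hand.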
There are also two primary drawbacks:
- This method is memory intensive compared to other possible iterative implementations. Each State object is a full object and, depending on how much you recurse, there can be a lot of them. Other iterative solutions may be more difficult to implement but will use less memory because they instantiate fewer objects.
- This method is much harder to read than the recursive solution and is subtle if you aren’t familiar with state machines. Not only is there more code, you have to keep track of the state of each frame and the stack manually which is normally done by the underlying language for you.
For these reasons, this solution is mostly of academic interest. It could potentially find some use in replacing a very complex recursion, such as in a recursive descent parser. Any simple recursion would be best replaced either by a more straightforward, less mechanical iterative approach, or by simply introducing a maximum recursion depth.
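For comparison, here is what a more straightforward iterative approach might look like for the tree example: the well-known in-order traversal with an explicit stack of nodes. This is a sketch, not part of the generic method; the `Node` constructor with `left` and `right` arguments and the `in_order` function are simplifications assumed for this example.

```python
class Node:
    """Minimal stand-in for the tree node used in this post."""

    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def in_order(root):
    """Return the tree's values in order, using an explicit stack of nodes."""
    result = []
    stack = []
    current = root
    while stack or current is not None:
        # Walk as far left as possible, remembering each node on the way down.
        while current is not None:
            stack.append(current)
            current = current.left
        # Visit the most recently remembered node, then move to its right subtree.
        current = stack.pop()
        result.append(current.data)
        current = current.right
    return result


tree = Node(12, Node(6, Node(3), Node(8)), Node(14))
print(in_order(tree))  # [3, 6, 8, 12, 14]
```

The stack here holds plain node references rather than one state object per frame, which is where the memory savings come from, at the cost of a loop that is specific to this one algorithm.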
Tom Moertel wrote an excellent series along similar lines, about re-implementing recursion as iteration, with motivations similar to the ones listed here. His translation mechanics are not as generic, but they are probably more appropriate for a real-world implementation because they generally have lower memory overhead than this generic implementation.
Finally, I want to leave a challenge for the reader. The code below is an implementation of the factorial algorithm using our stack-of-state-machines model:

    class FactorialState:
        def __init__(self, n, accumulator):
            self._n = n
            self._accumulator = accumulator
            self._current_state = 0

        def next(self):
            if self._current_state == 0:
                self._current_state = 1
                if self._n < 2:
                    raise TerminatedException
                self._accumulator.mult(self._n)
                return FactorialState(self._n - 1, self._accumulator)
            else:
                raise TerminatedException


    class Accumulator:
        def __init__(self, start):
            self._start = start

        def mult(self, n):
            self._start *= n

        def div(self, n):
            self._start /= n

        @property
        def value(self):
            return self._start


    accumulator = Accumulator(1)
    IterativeProcessor(FactorialState(6, accumulator)).traverse()
    print(accumulator.value)

If you visualize the code at Python Tutor (or step through it with an IDE debugger), you’ll see that there are many state machines sitting on the stack in Operator once the final stack frame is hit. These frames were superfluous all along and basically just raise TerminatedException, wasting both time and memory. If Operator implemented proper tail call optimization, these frames would never exist in the first place; we would simply pop the top of the stack and push the new stack frame at each step, which would give us the same result. How can we implement tail call optimization in our Operator while still preserving the proper behavior for non-tail-call recursion (such as our tree traversal)?
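One possible direction, offered only as a sketch and not as the definitive answer: let a state wrap its return value in a marker object when the recursive step is the last thing its frame does. The `TailCall` class, the `max_depth` counter, and the simplified `Accumulator` below are assumptions introduced for this illustration.

```python
class TerminatedException(Exception):
    """Raised by a state when its frame has no operations left."""


class TailCall:
    """Hypothetical marker meaning 'replace my frame with this state'."""

    def __init__(self, state):
        self.state = state


class Operator:
    def __init__(self, initial_state):
        self._stack = [initial_state]
        self.max_depth = 1  # bookkeeping for the demo only

    @property
    def has_operations(self):
        return len(self._stack) > 0

    def operate(self):
        try:
            new_state = self._stack[-1].next()
        except TerminatedException:
            self._stack.pop()
            return
        if isinstance(new_state, TailCall):
            # Tail call: the current frame has nothing left to do, so reuse its slot.
            self._stack[-1] = new_state.state
        elif new_state is not None:
            # Ordinary call: the current frame must resume later, so keep it.
            self._stack.append(new_state)
        self.max_depth = max(self.max_depth, len(self._stack))


class Accumulator:
    def __init__(self, start):
        self.value = start

    def mult(self, n):
        self.value *= n


class FactorialState:
    def __init__(self, n, accumulator):
        self._n = n
        self._accumulator = accumulator

    def next(self):
        if self._n < 2:
            raise TerminatedException
        self._accumulator.mult(self._n)
        # Nothing happens after the recursive step, so mark it as a tail call.
        return TailCall(FactorialState(self._n - 1, self._accumulator))


accumulator = Accumulator(1)
operator = Operator(FactorialState(6, accumulator))
while operator.has_operations:
    operator.operate()

print(accumulator.value)   # 720
print(operator.max_depth)  # 1
```

States that return a plain State object (or nothing) are handled exactly as before, so a non-tail-call state machine such as the tree traversal keeps its frames. Because a frame that makes a tail call is never resumed, this FactorialState no longer needs its own state counter.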