Transformative pattern: re-factoring a function to make it tail recursive

A few days ago I talked a little about tail recursion, and I mentioned a pattern I called iterator/builder to transform some simple recursive functions into tail-recursive functions. The transformation looks like this (in Erlang).

% not tail recursive
factorial_a( 0 ) -> 1;
factorial_a( N ) -> N * factorial_a( N - 1 ).

% tail recursive
factorial_b( N ) -> fact_tail( N, 1 ).
fact_tail( 0, Product ) -> Product;
fact_tail( N, Product ) -> fact_tail( N - 1, N * Product ).

In this example of the iterator/builder pattern, N is the iterator and Product is the builder. Here are some more examples.

% count items in a list
% not tail recursive
count_a( [     ] ) -> 0;
count_a( [_ | R] ) -> 1 + count_a( R).

% tail recursive
count_b( List           ) -> count_t( List, 0).
count_t( [     ], Total ) -> Total;
count_t( [_ | R], Total ) -> count_t( R, 1 + Total ).

% triangle number (0 + 1 + 2 + 3 ...)
% not tail recursive
triangle_a( 0 ) -> 0;
triangle_a( N ) -> N + triangle_a( N - 1 ).

% tail recursive
triangle_b( N      ) -> triangle_t( N, 0 ).
triangle_t( 0, Sum ) -> Sum;
triangle_t( N, Sum ) -> triangle_t( N - 1, N + Sum ).

Transforming these functions to be tail-recursive is simple. The pattern looks like this:

% Iterator/Builder pattern to transform a recursive
% function into a tail-recursive function.
% This is a common pattern. It is probably described
% elsewhere with a different name.
%
% Pattern variables:
%   null_pattern - matches null iter
%   null_value   - initial value
%   iter_pattern - matches non-null iter
%   iter_first   - first element value
%   iter_rest    - iter with first element removed
%   fn_name( iter )
%   combine_expression( iter_first, partial_result )
%   pre_transform( iter ) - optional fix-up of the initial
%     iterator (the identity in the simple cases above)

% Before (not tail recursive):
fn_name( null_pattern ) ->
  null_value;
fn_name( iter_pattern ) ->
  combine_expression( iter_first, fn_name( iter_rest )).

% After (tail recursive):
fn_name( X ) ->
  fn_name_tail( pre_transform( X ), null_value ).
fn_name_tail( null_pattern, Collect ) ->
  Collect;
fn_name_tail( iter_pattern, Collect ) ->
  fn_name_tail( iter_rest,
    combine_expression( iter_first, Collect )).

This is not perfect, however. Consider the triangle function above. triangle_a(3) calculates (3+(2+(1+0))) while triangle_b(3) calculates (1+(2+(3+0))). If you change N+Sum to Sum+N in triangle_b you end up calculating (((0+3)+2)+1) instead. This is the left-fold vs right-fold problem that you sometimes run into when you flatten recursion, and it also shows how the initial value for the fold operation (zero in this case) can move around. For an operation like + (plus) it doesn’t matter since + is associative and commutative, but it matters in other cases. Consider the stutter/1 function.

% stutter( [a,b,c] ) -> [a,a,b,b,c,c]

% not tail recursive:
stutter_a( [     ] ) -> [];
stutter_a( [H | R] ) -> glom( H, stutter_a( R )).

% tail recursive
stutter_b( A ) -> lists:reverse( stutter_t( A, [] )).
stutter_t( []     , C ) -> C;
stutter_t( [H | R], C ) -> stutter_t( R, glom( H, C )).

% another tail recursive version
stutter_c( A ) -> stutter_t( lists:reverse( A ), [] ).

% combine expression for stutter
glom( H, R ) -> [H, H | R].

stutter_a([a,b,c]) calculates glom(a,glom(b,glom(c,[]))) while stutter_b([a,b,c]) calculates glom(c,glom(b,glom(a,[]))) and then reverses the result. stutter_c([a,b,c]) starts by reversing the iterator and so calculates glom(a,glom(b,glom(c,[]))) just like stutter_a.

So the simple transformative pattern described above sometimes has to be modified if the combine expression is not associative and commutative. Sometimes you can fix the result, sometimes you can reverse the initial iterator, sometimes you need an extra end-of-iterator parameter when you reverse the iterator, and sometimes you have to change the combine expression.

And of course sometimes you just want to leave the function alone and forget about tail recursing. This kind of code transformation or re-factoring requires some understanding of the problem before it is applied.

Exactly when are templates expanded?

C++ templates are only expanded as necessary. The following is error-free even though does_not_exist() does not, in fact, exist. The method never_called() is never called and thus never expanded, so does_not_exist() is never needed.

  template< typename X >
  struct
use_non_existant_stuff
  {
    int never_called( )  { return does_not_exist( ); }
  };

// no errors or warnings here
use_non_existant_stuff< int > inst_that_does_not_use_methods;

But what about when a template method is expanded? When are the identifiers resolved? Consider the class uses_hidden_fn<..>.

  template< typename X >
  struct
uses_hidden_fn
  {
    X x;
    int call_hidden_fn( )  { return hidden_fn( x); }
  };

namespace hiding_ns {
int hidden_fn( double);
}

// This is OK.
uses_hidden_fn< double > inst_double;

// This fails because hidden_fn( double) is not visible.
int a = inst_double.call_hidden_fn( );

Invoking the method call_hidden_fn() fails because although hidden_fn(double) is declared, it is hidden in the hiding_ns namespace.

You can solve this problem with a simple using hiding_ns::hidden_fn; declaration. And hidden_fn(double) doesn’t have to be exposed before it is needed, only before the end of the compilation unit. The following compiles without error or warning.

  template< typename X >
  struct
uses_hidden_fn
  {
    X x;
    int call_hidden_fn( )  { return hidden_fn( x); }
  };

// This is OK.
uses_hidden_fn< double > inst_double;

// This is OK because hidden_fn( double) is declared below.
int a = inst_double.call_hidden_fn( );

// Declare hidden_fn(double) after we needed it above.
namespace hiding_ns {
int hidden_fn( double);
}
using hiding_ns::hidden_fn;

I’m testing this on the Microsoft (ms9) compiler, but I think it’s standard that templates are not expanded and identifiers are not bound until the end of the compilation unit.

This also compiles without complaint.

  template< typename X >
  struct
uses_hidden_fn
  {
    X x;
    int call_hidden_fn( )  { return hidden_fn( x); }
    int call_hidden_fn2( ) { return hidden_fn2( ); }
  };

namespace hiding_ns {
struct hidden_type { };
}

uses_hidden_fn< hiding_ns::hidden_type > inst;
int a = inst.call_hidden_fn( );
int b = inst.call_hidden_fn2( );

// Declare hidden_fn(..) after we needed it above.
namespace hiding_ns {
int hidden_fn( hidden_type &);
int hidden_fn2( );
}
using hiding_ns::hidden_fn2; // needed
//using hiding_ns::hidden_fn; // not needed!

I was surprised. I thought that we’d still need the using hiding_ns::hidden_fn; declaration at the bottom to make this work, but the (MS) compiler doesn’t complain. Apparently calling hidden_fn(..) with an argument whose type is in the hiding_ns namespace is enough of a hint for the compiler to dig hiding_ns::hidden_fn(hidden_type&) out of that namespace. This looks like argument-dependent lookup (Koenig lookup) at work.

If you declare another hidden_fn(hiding_ns::hidden_type&) at top namespace scope the compiler complains that there are two hidden_fn(..)s to choose from. So it does not prefer the exposed declaration over the hidden one.

Although this sort of behavior is interesting, I would strongly recommend against relying on it. Even if it is standard (is it?), it feels like something on the edge. It is clearer to declare hiding_ns::hidden_fn(..) near the top of the file, followed by a using declaration. That way hidden_fn(..) is exposed before uses_hidden_fn<..> is expanded.
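
For example, something like this arrangement (a sketch that just rearranges the declarations already shown above):

// Declare and expose hidden_fn(..) before the template.
namespace hiding_ns {
int hidden_fn( double);
}
using hiding_ns::hidden_fn;

template< typename X >
struct uses_hidden_fn
{
  X x;
  int call_hidden_fn( )  { return hidden_fn( x); }
};

// Both of these are now clearly fine.
uses_hidden_fn< double > inst_double;
int a = inst_double.call_hidden_fn( );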

Tail recursion

In a recent post I showed some Erlang functions that were tail recursive. So I thought I’d talk a little about tail recursion today.

Let’s say you have a function A(..), and the very last thing it does is call another function B(..).

int A( int); // declare function A(..)
int B( int); // declare function B(..)
int do_something( int, int);      // helper used below
int do_something_else( int, int); // helper used below

// The last thing A(..) does is call B(..).
int A( int x)
{
  int y = do_something( x, 2 * x);
  int z = do_something_else( x, y);
  return B( x + y + z);
}

Calling A(..) sets up a stack frame with x on it. Soon y and z are pushed onto the stack frame, and finally a new stack frame for B(..) is created. When B(..) returns, A(..)’s stack frame is popped and the value is returned.

But once B(..) is called, A(..)’s stack frame isn’t needed for anything. The compiler could arrange to have A(..)’s stack frame destroyed before B(..) is called, so that B(..)’s stack frame could be built in exactly the same place. The return value from B(..) would be returned directly to A(..)’s caller. In other words, B(..) steps on A(..)’s tail, which is known as tail-call optimization.

The benefit is that the runtime stack is smaller since it doesn’t have to hold both frames at the same time. In this case it’s a small optimization, perhaps useful in tight embedded situations. But consider the following:

int do_something( int); // helper used below

int A( int x)
{
  do_something( x);
  return (x > 1000000000) ? x : A( x + 1);
}

In this example the last thing A(x) does is call A(x+1), so A(x+1) can step on A(x)’s tail. But it’s no longer just a small optimization since this repeats about a billion times. You’ll need a very big runtime stack if the compiler doesn’t arrange for A(..) to step on its own tail.
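
When A(..) does step on its own tail, the effect is as if the compiler had rewritten the recursion as a plain loop. Roughly like this (my sketch of the equivalent code, not actual compiler output):

int do_something( int); // same helper as above

int A( int x)
{
  // What tail-call optimization amounts to here: the recursive
  // call becomes a jump back to the top, reusing the same frame.
  for (;;)
  {
    do_something( x);
    if (x > 1000000000)
      return x;
    x = x + 1; // stands in for the tail call A( x + 1)
  }
}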

This is what they call tail recursion. I first read about it in Guy Steele’s famous paper Lambda: The Ultimate GOTO. It’s essential for languages like Scheme and Erlang to optimize tail recursion because they don’t provide loop constructs (loops being disguised gotos). In these languages you recurse instead of looping.

But the programmer has to be aware of when he is and isn’t tail recursing. If the recursion is not tail recursion the compiler cannot tail-optimize. Consider this Erlang function.

% Take a list like [a,b,c] and produce [a,a,b,b,c,c].
stutter( [] ) ->
  [];
stutter( [Head | Rest] ) ->
  [Head, Head | stutter( Rest )].

It looks like the last thing stutter/1 does is call stutter(Rest). But really the last thing it does is make a list incorporating the result of stutter(Rest). So despite appearances, this is NOT tail recursive.

Here is the tail-recursive version of stutter/1.

stutter( A ) ->
  lists:reverse( stutter_tail( A, [] )).

stutter_tail( [], Collect ) ->
  Collect;
stutter_tail( [Head | Rest], Collect ) ->
  stutter_tail( Rest, [Head, Head | Collect] ).

In this, stutter/1 calls stutter_tail/2 which is tail recursive. stutter_tail/2 takes two arguments, the iterator and the builder. The iterator is the list that we take apart, peeling the head off at each iteration. And while we use up the iterator, we add to the builder and construct the stuttering list.

Or consider factorial/1. First the intuitive version, which at first glance looks tail recursive even though it’s not.

% not tail recursive
factorial( 0 ) -> 1;
factorial( N ) -> N * factorial( N - 1 ).

And the tail-recursive version, which follows the iterator/builder pattern like stutter_tail/2 above.

factorial( N ) ->
  factorial( N, 1 ).

% tail-recursive:
factorial( 0, Product ) ->
  Product;
factorial( N, Product ) ->
  factorial( N - 1, N * Product ).

You see this iterator/builder pattern a lot in tail-recursive realizations, where one parameter is used up while the other is built up. Here are three more examples.

count( List ) ->
  count( List, 0 ).
count( [], Total ) ->
  Total;
count( [_ | Rest], Total ) ->
  count( Rest, 1 + Total ).

triangle( N ) ->
  triangle( N, 0 ).
triangle( 0, Sum ) ->
  Sum;
triangle( N, Sum ) ->
  triangle( N - 1, N + Sum ).

reverse( A ) ->
  reverse( A, [] ).
reverse( [], Collect ) ->
  Collect;
reverse( [Head | Rest], Collect ) ->
  reverse( Rest, [Head | Collect] ).

When you’re programming in Scheme or Erlang you get used to reaching for a recursive solution whenever you get into a looping situation. And you’re always conscious of whether your implementation is tail recursive or not. And you soon find yourself thinking in recursive patterns like iterator/builder instead of iterative patterns like while(test_this()){do_that();}.

Any recursive algorithm can be expressed as an iterative loop with an explicit stack. If it’s tail recursive, you don’t need the stack to make it into a loop. Some languages, like Scheme and Erlang, automatically translate tail recursion into in-place looping whenever possible. This allows you to express many algorithms more naturally than you could with a loop, without having to worry about stack overflow.
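
To make that concrete, here is the triangle function done both ways in C++ (my own illustration; the function names and the explicit std::stack are mine):

#include <stack>

// Non-tail-recursive triangle(N) = N + triangle(N - 1),
// turned into a loop with an explicit stack: push the
// pending "+ N" work on the way down, then unwind it.
int triangle_with_stack( int n)
{
  std::stack< int > pending;
  while (n > 0)
  {
    pending.push( n);
    n = n - 1;
  }
  int sum = 0;              // the base case
  while (!pending.empty( )) // the unwinding
  {
    sum = pending.top( ) + sum;
    pending.pop( );
  }
  return sum;
}

// The tail-recursive version needs no stack at all.
int triangle_as_loop( int n)
{
  int sum = 0;
  while (n > 0)
  {
    sum = n + sum;
    n = n - 1;
  }
  return sum;
}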

It would be nice if C compilers optimized tail recursion as a loop. It would be even better if C compilers could arrange for a trailing function to step on its caller’s tail whenever possible, even in non-recursive situations. This would allow a more functional coding style in C, and would make it easier for Scheme/Erlang “compilers” to use C as a target language. (I think one of the design goals for C should be to make it a universal target language.)

Tail recursion is more problematic in C++. Usually the last thing a C++ function does is run destructors for local variables. Sometimes this is absolutely essential, such as when you are using a wrapper class to lock/unlock (see Resource Acquisition is Initialization, or RAII). If the C++ compiler optimized tail recursion or tail stepping, the compiler would have to run the destructors before overwriting the caller’s stack frame. In the end the programmer would have to be given a way to control this, thus making C++ even more complex than it already is.
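
To illustrate the lock/unlock case (a minimal sketch; lock_wrapper is a made-up RAII class standing in for any lock/unlock wrapper):

// A made-up RAII lock wrapper, just for illustration.
struct lock_wrapper
{
  lock_wrapper( )  { /* acquire a lock */ }
  ~lock_wrapper( ) { /* release the lock */ }
};

int B( int); // some other function

// Looks like a tail call, but the compiler cannot treat it as
// one: lock_wrapper's destructor must run after B(..) returns,
// so A(..)'s stack frame has to stay alive.
int A( int x)
{
  lock_wrapper hold;
  return B( x);
}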
