Template recursion
Today I’m going to compare a recursive template function with a similar non-template function. I’m choosing factorial because it’s simple and yet still interesting.
I’m going to work with 64-bit unsigned integers and put everything in a namespace called factorial. The factorial functions are factorial::get(n) and factorial::get<N>(). The header file looks like this.
// factorial.h
# ifndef FACTORIAL_H
# define FACTORIAL_H
namespace factorial {
// 20! fits in 64 bits, 21! does not
const int max_n = 20;
// uint_n is the param type
typedef unsigned int uint_n;
// uint_v is the return type (64-bit unsigned).
// We could use uint64_t (boost::uint64_t from
// <boost/cstdint.hpp>) instead of unsigned long long.
typedef unsigned long long uint_v;
// get( n) and get< N >( )
// An overflow will assert and return 0 (zero).
// get< N >( ) is defined only in factorial.cpp, so other
// files can't instantiate it from this header alone.
extern uint_v get( uint_n);
template< uint_n > extern uint_v get( );
} /* end namespace factorial */
# endif
The CPP file starts with the non-template factorial function. It uses a loop instead of recursion because a loop is simple and fast. I’d use recursion if I could count on the compiler to flatten tail-recursive functions into loops.
// factorial.cpp
# include <cassert>
# include <boost/static_assert.hpp>
# include "factorial.h"
namespace factorial {
// Several places below assume (max_n + 1) is not a valid
// param and is not zero (we divide by it).
BOOST_STATIC_ASSERT( (max_n + 1) > max_n );
// exported function
uint_v get( uint_n n)
{
    // Check for overflow.
    if ( n > max_n ) {
        assert( false);
        return 0; /* overflow */
    }
    // Calculate factorial with a loop.
    uint_v fact = 1;
    for ( ; n > 1 ; n -= 1 ) {
        fact *= n;
    }
    return fact;
}
// .. still inside namespace factorial ..
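For comparison, here is a minimal sketch of the tail-recursive form mentioned above. It uses an accumulator so that the recursive call is the last thing the function does. The sketch lives in a hypothetical namespace, factorial_sketch, with its own copies of the typedefs. The C++ standard doesn't require tail-call elimination. Optimizing compilers such as GCC and Clang usually perform it at -O2, though, so in practice this often compiles to a loop much like the one above.

```cpp
#include <cassert>

namespace factorial_sketch {  // hypothetical namespace for this sketch

typedef unsigned int uint_n;
typedef unsigned long long uint_v;
const uint_n max_n = 20;  // 20! fits in 64 bits, 21! does not

// Tail-recursive helper: acc carries the product so far, and the
// recursive call is the last thing evaluated (a tail call).
inline uint_v get_acc( uint_n n, uint_v acc)
{
    return (n <= 1) ? acc : get_acc( n - 1, acc * n);
}

inline uint_v get( uint_n n)
{
    if ( n > max_n ) {
        assert( false);
        return 0; /* overflow */
    }
    return get_acc( n, 1);
}

}  // namespace factorial_sketch
```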
Before I go on, let me define a macro for one-time asserts at file or namespace scope (top-level scope). I’ll post more about asserts later. I’m not using BOOST_STATIC_ASSERT(..) because it only works with compile-time constants, and get(n) is an ordinary function. The constexpr specifier, due in C++0x, lets some functions produce compile-time constants. However, a C++11 constexpr function body must be a single return statement, so a function that loops, like get(n), cannot be constexpr. (Recursion is allowed, and C++14 relaxes the rules to allow loops as well.)
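I define this macro over several lines for clarity. You don't need to put it on one line to keep __LINE__ (and therefore the assert message) accurate. __LINE__ expands to the line where the macro is used, not to the lines of its definition.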
# define ASSERT_top( ID, A) \
namespace { \
const int debug_junk_ ## ID = \
(assert( A ), 0); \
}
It’s too bad the above macro needs an ID param, but the param is necessary if we want to use the macro more than once. Every anonymous namespace at the same scope in a file is the same namespace, so each use needs a distinct variable name. And the parent namespace is automatically “using” the child anonymous namespace, and there’s no way to switch that off and bury the names inside.
(Later note: I can get rid of the ID param in ASSERT_top( ID, A). I can hack up a unique ID using __LINE__. Doh!)
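Here is a minimal sketch of that __LINE__ idea. It uses a hypothetical macro named ASSERT_top_line so that it doesn't clash with ASSERT_top. The two helper macros are needed because ## pastes its operands before they are macro-expanded. Passing __LINE__ through one extra macro lets it expand to a number first.

```cpp
#include <cassert>

// Two-level paste: ASSERT_TOP_CAT's arguments are macro-expanded
// before they reach ##, so __LINE__ becomes a number such as 14.
#define ASSERT_TOP_CAT2(a, b) a ## b
#define ASSERT_TOP_CAT(a, b) ASSERT_TOP_CAT2(a, b)

// Hypothetical one-time assert at namespace scope. The variable name
// is made unique by the line number, so no ID parameter is needed.
#define ASSERT_top_line(A) \
    namespace { \
    const int ASSERT_TOP_CAT(debug_junk_, __LINE__) = (assert(A), 0); \
    }

// Each use is on its own line, so each gets its own variable.
ASSERT_top_line(1 + 1 == 2);
ASSERT_top_line(2 * 3 == 6);
```

Two uses on the same line would still collide. So would uses at the same line number in two headers included into one file. The widely supported but non-standard __COUNTER__ macro avoids both problems.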
The following code comes next. It checks that get(n) is calculating factorial correctly, and that max_n is the right value.
# ifndef NDEBUG
// .. still inside namespace factorial ..
// check a few factorial values
ASSERT_top( check_val_0, (get( 0) == 1) );
ASSERT_top( check_val_1, (get( 1) == 1) );
ASSERT_top( check_val_2, (get( 2) == 2) );
ASSERT_top( check_val_9, (get( 9) == (9*8*7*6*5*4*3*2)) );
// define max_v, a private debug-only value used below
namespace /* anonymous debug-only */ {
const uint_v max_v = get( max_n); /* max factorial allowed */
} /* end anonymous debug-only namespace */
// assert we did not overflow too soon
ASSERT_top( max_not_z, (max_v != 0));
// assert largest factorial is valid
ASSERT_top( max_is_valid, (get( max_n - 1) == (max_v / max_n)));
// assert next factorial would overflow
ASSERT_top( max_is_max,
(((max_v * (max_n + 1)) / (max_n + 1)) < max_v));
# endif
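With a C++14 compiler, these checks could fail the build instead of asserting at startup. Here is a minimal sketch in a hypothetical namespace, factorial_constexpr. In it, get(n) keeps its loop but is declared constexpr, and the checks above become static_asserts. To keep the sketch short, this version just returns 0 on overflow without asserting.

```cpp
namespace factorial_constexpr {  // hypothetical namespace for this sketch

typedef unsigned int uint_n;
typedef unsigned long long uint_v;
constexpr uint_n max_n = 20;  // 20! fits in 64 bits, 21! does not

// C++14 allows loops (and modifying locals) in a constexpr function,
// so the same loop as get(n) can be evaluated by the compiler.
constexpr uint_v get( uint_n n)
{
    if ( n > max_n ) {
        return 0; /* overflow, no assert in this sketch */
    }
    uint_v fact = 1;
    for ( ; n > 1 ; n -= 1 ) {
        fact *= n;
    }
    return fact;
}

// The startup checks above, now evaluated at compile time.
static_assert( get( 0) == 1, "0! == 1");
static_assert( get( 1) == 1, "1! == 1");
static_assert( get( 2) == 2, "2! == 2");
static_assert( get( 9) == 9ULL*8*7*6*5*4*3*2, "9! is correct");
static_assert( get( max_n) != 0, "did not overflow too soon");
static_assert( get( max_n - 1) == get( max_n) / max_n, "largest factorial is valid");
// Unsigned arithmetic wraps, so this is well defined at compile time.
static_assert( (get( max_n) * (max_n + 1)) / (max_n + 1) < get( max_n),
               "next factorial would overflow");

}  // namespace factorial_constexpr
```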
Now we define the raw<N>() template function to calculate factorial. This time we’ll define it recursively. It’s a very intuitive definition. We use explicit template specialization to define zero factorial, which also ends the recursion. The compiler picks the best match for the template arguments, and it can match values (non-type template parameters), not just types.
// .. still inside namespace factorial ..
namespace /* anonymous */ {
// template version of factorial
template< uint_n N >
uint_v raw ( ) { return raw< N - 1 >( ) * N; }
template<>
uint_v raw< 0 >( ) { return 1; }
// overflow returns zero
template<>
uint_v raw< max_n + 1 >( ) { return 0; }
} /* end anonymous namespace */
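raw<N>() is still an ordinary function, so its result is a run-time value, even if an optimizer folds it. For comparison, here is a minimal sketch of the same recursion in the classic template-metaprogramming form: a class template with a static constant member. The names factorial_meta and raw_c are hypothetical. The member value is a true compile-time constant, so it can be used in a static assertion or as a template argument. This sketch has no overflow cutoff, so raw_c< 21 >::value would silently wrap around.

```cpp
namespace factorial_meta {  // hypothetical namespace for this sketch

typedef unsigned int uint_n;
typedef unsigned long long uint_v;

// Primary template: N! = N * (N - 1)!, computed during compilation.
template< uint_n N >
struct raw_c {
    static const uint_v value = N * raw_c< N - 1 >::value;
};

// Explicit specialization on the value 0 ends the recursion.
template<>
struct raw_c< 0 > {
    static const uint_v value = 1;
};

// Because value is a constant expression, it can be checked at compile time.
static_assert( raw_c< 5 >::value == 120, "5! == 120");
static_assert( raw_c< 20 >::value == 2432902008176640000ULL, "20! fits in 64 bits");

}  // namespace factorial_meta
```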
Now I’m ready to define factorial::get<N>(). I want it to call raw<N>() after bounds checking. My first try looked like this. Notice the bounds checking makes the raw<max_n+1>() specialization above unnecessary.
// .. still inside namespace factorial ..
// first attempt
template< uint_n N >
uint_v get( )
{
    if ( N > max_n ) {
        assert( false);
        return 0; /* overflow */
    }
    return raw< N >( );
}
This works great for expressions like get<0>(), get<10>(), and get<25>(). The last one correctly asserts, or returns zero when asserts are disabled. But the expression get<1000000>() chokes the compiler because the expansion is “too complicated”. That tells us the compiler is not trimming the code. raw<1000000>() instantiates raw<999999>(), which instantiates raw<999998>(), and so on. That is nearly a million nested instantiations before the chain reaches the raw<max_n+1>() specialization, far past the compiler’s limit on instantiation depth. If the compiler could recognize that get<1000000>() becomes { assert( false); return 0; } after evaluating “if ( 1000000 > max_n ) ..” then this would work fine.
I was a little disappointed when I saw this. I was expecting the compiler to recognize compile-time constants and trim code around ifs, logical ands (&&), ors (||), and (a?b:c) operators.
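But this is a lesson in how template instantiation works. The compiler doesn’t examine and trim a template body before instantiating the templates it uses. An optimizer may remove the dead branch later, but only after everything has been instantiated. Template arguments, however, are evaluated as compile-time constants. So we can fix our problem by moving the bounds check into the template argument, defining get<N>() like this.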
// .. still inside namespace factorial ..
// final attempt
template< uint_n N >
uint_v get( )
{
    assert( N <= max_n);
    return raw< ((N > max_n) ? (max_n + 1) : N) >( );
}
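C++17 later added the trimming the first attempt needed: if constexpr. Inside a template, when the condition depends on a template parameter, the branch not taken is discarded and never instantiated. Here is a minimal sketch, assuming a C++17 compiler. It uses a hypothetical namespace, factorial_cxx17, with its own copy of raw<N>(). With if changed to if constexpr, the first attempt works, and get<1000000>() compiles without the clamping trick.

```cpp
#include <cassert>

namespace factorial_cxx17 {  // hypothetical namespace for this sketch

typedef unsigned int uint_n;
typedef unsigned long long uint_v;
constexpr uint_n max_n = 20;  // 20! fits in 64 bits, 21! does not

// Recursive template, as in raw<N>() above.
template< uint_n N >
inline uint_v raw() { return raw< N - 1 >() * N; }

template<>
inline uint_v raw< 0 >() { return 1; }

// The first attempt with if constexpr. Once N is substituted, the
// condition is a constant and the branch not taken is discarded,
// so get<1000000>() never instantiates raw<1000000>().
template< uint_n N >
uint_v get()
{
    if constexpr (N > max_n) {
        assert(false);
        return 0;  // overflow
    } else {
        return raw< N >();
    }
}

}  // namespace factorial_cxx17
```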
(It’d be cool if we could solve this with template specialization like “template<> get<N when (N > max_n) >() { return 0; }”. This is like a guard clause in Erlang and not part of C++. Yet.)
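C++ has no guard syntax like that. Function templates can't be partially specialized either. Since C++11, though, SFINAE gives a similar effect. In this minimal sketch, which uses a hypothetical namespace called factorial_guard, each get<N>() overload carries its condition in a defaulted template parameter built with std::enable_if. When the condition is false, substituting N fails, and that overload silently drops out of overload resolution. As a result, get<1000000>() never instantiates raw<1000000>(). C++20 reads even more like the Erlang guard. It uses a requires-clause, as in template< uint_n N > requires (N > max_n) uint_v get( ).

```cpp
#include <cassert>
#include <type_traits>

namespace factorial_guard {  // hypothetical namespace for this sketch

typedef unsigned int uint_n;
typedef unsigned long long uint_v;
const uint_n max_n = 20;  // 20! fits in 64 bits, 21! does not

template< uint_n N >
inline uint_v raw() { return raw< N - 1 >() * N; }

template<>
inline uint_v raw< 0 >() { return 1; }

// "Guard" N <= max_n: enable_if<false>::type does not exist, so for
// larger N this overload is removed from overload resolution.
template< uint_n N,
          typename std::enable_if< (N <= max_n), int >::type = 0 >
uint_v get()
{
    return raw< N >();
}

// "Guard" N > max_n: the overflow case, chosen without touching raw<N>().
template< uint_n N,
          typename std::enable_if< (N > max_n), int >::type = 0 >
uint_v get()
{
    assert(false);
    return 0;  // overflow
}

}  // namespace factorial_guard
```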
(Another solution is to punt: abandon template recursion and define “template< uint_n N > uint_v get( ) { return get( N); }”. Always consider the easy way out before wasting too much time!)
Now factorial::get<1000000>() compiles with no complaint. Notice that we now rely on the specialization for raw<max_n+1>(). Every N > max_n is mapped to raw<max_n+1>(), which returns zero without recursing.
The last bit of code checks the template function and closes the namespace.
// .. still inside namespace factorial ..
# ifndef NDEBUG
// check that the two versions agree
ASSERT_top( agree_0 , (get( 0) == get< 0 >( )) );
ASSERT_top( agree_1 , (get( 1) == get< 1 >( )) );
ASSERT_top( agree_2 , (get( 2) == get< 2 >( )) );
ASSERT_top( agree_3 , (get( 3) == get< 3 >( )) );
ASSERT_top( agree_y , (get( max_n - 1) == get< max_n - 1 >( )));
ASSERT_top( agree_x , (get( max_n) == get< max_n >( )));
// The following works, but the call to get<max_n+1>()
// asserts false when the program starts.
// ASSERT_top( template_over0, (get< max_n + 1 >( ) == 0));
// The following compiles. At startup it asserts false.
// ASSERT_top( template_over1, (get< 1000000 >( ) == 0));
# endif
} /* end namespace factorial */
So that’s how to define a recursive template function, and how to use template specialization, in this case to terminate the recursion.
I’d get rid of the template functions if I were releasing this as production code, even though the two versions don’t collide. The template functions aren’t needed and add complication, and it’s best to keep things simple. And of course the non-template function is much more useful, because the templates work only with constants known at compile time.