Tail call optimization in Scala

Summary

We analyze a Scala function with a recursive tail call, and show that the compiler rewrites it as a nonrecursive loop. We use the @tailrec annotation to verify the optimization. Printing the stack at run time confirms it. Bytecode analysis shows how the optimization is implemented, and how the Java virtual machine manages the stack in the recursive and iterative cases. A bytecode decompiler reverse engineers the bytecode into one possible source form. While the Scala compiler eliminates such tail calls, the Java compiler does not.

Recursion

Recursion is when a function calls itself, usually for the purpose of divide and conquer, meaning dividing a large problem into smaller pieces, and solving the smaller problems. Quicksort is a great example of divide and conquer, and is a problem that can be solved using a recursive algorithm.
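As an illustration of divide and conquer, here is a minimal quicksort sketch on immutable lists. The object name QuickSortSketch and the choice of the first element as pivot are assumptions made for this example, not a tuned implementation.

```scala
object QuickSortSketch {

  // Divide and conquer: pick a pivot, split the rest into two smaller
  // problems, solve each recursively, then combine the results.
  def quicksort(xs: List[Int]): List[Int] = xs match {
    case Nil => Nil // an empty list is already sorted
    case pivot :: rest =>
      val (smaller, larger) = rest.partition(_ < pivot)
      quicksort(smaller) ::: (pivot :: quicksort(larger))
  }
}
```

Calling QuickSortSketch.quicksort(List(3, 1, 4, 1, 5, 9, 2, 6)) should return the same elements in ascending order. Note that neither recursive call is the last action of the function, because the two results still have to be concatenated.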

Imperative languages make judicious use of recursion. We might use imperative languages without realizing that is what they are called. Imperative programming means executing a series of steps in order. For example:

restore database;
recover database;
alter database open resetlogs;

On the other hand, pure functional languages express actions as functions. Claimed advantages are that side-effect-free functions and immutable data lead to improved readability, maintainability, and concurrency.

In a pure functional language, iteration is implemented as recursion. Usually, recursion works by placing the caller’s return address and the calling arguments on the stack, and then jumping to the same function. In the call, the function accesses the values on the stack as local variables. In the return, the function result is loaded into the caller’s operand stack, and the current frame is discarded. Execution resumes in the caller. Ordinarily, stack resources would limit such an approach.

Enter the tail call. If the recursive call is the last action before the return, the recursion can be optimized into iteration. No new stack frame is allocated; the recursive call is implemented as a goto back to the start of the function.
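To make "last action before the return" concrete, the sketch below contrasts two ways of summing 1 to n. The names sumTo and sumToAcc are hypothetical and exist only for this illustration. The scala.annotation.tailrec annotation asks the compiler to reject the method if it cannot eliminate the recursive call.

```scala
import scala.annotation.tailrec

object TailPositionSketch {

  // Not a tail call: after sumTo(n - 1) returns, the addition still has to
  // run, so every call needs its own stack frame.
  def sumTo(n: Int): Long =
    if (n == 0) 0L else n + sumTo(n - 1)

  // Tail call: the recursive call is the last action. The running total
  // travels in the accumulator argument instead of waiting on the stack.
  @tailrec
  def sumToAcc(n: Int, acc: Long = 0L): Long =
    if (n == 0) acc else sumToAcc(n - 1, acc + n)
}
```

Both functions compute the same value, for example 5050 for n = 100. Only the second one has its recursive call in tail position.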

The objective of this blog post is to see how tail call optimization works.

Recursion Example

This example calculates the nth Fibonacci number, conventionally indexed from 0 as 0, 1, 1, 2, 3, etc. or from 1 as 1, 1, 2, 3, etc. Here we index from 0. I have divided the program into two files for later analysis.

object FibTailRec contains fib, the inner, recursive function. The arguments are i, p, and f:

  • i: the counter, counting down from n
  • p: the previous Fibonacci number
  • f: the Fibonacci number

package fibo

object FibTailRec {

  def fib(i: Int, p: Int, f: Int): Int = i match {
    case 0 => p
    case _ => fib(i - 1, f, p + f)
  }
}

Here is the driver. It has two functions. main calls function fibTailRec with a value (6), and prints the result. fibTailRec calls fib with n and seed values for p and f.

import fibo.FibTailRec.fib

object RunFibTailRec {

  def main(args: Array[String]): Unit = {
    println(fibTailRec(6))
  }

  def fibTailRec(n: Int): Int =
    fib(n, 0, 1)

}

The output:

8
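To see why the answer is 8, it can help to follow the arguments through each call. The sketch below is a hypothetical variant of fib (the names FibTrace, trace, and seen are not part of the original program) that records every (i, p, f) triple it visits.

```scala
object FibTrace {

  // Same recursion as fib, but returns every (i, p, f) triple visited,
  // ending with the triple where i reaches 0.
  def trace(i: Int, p: Int, f: Int,
            seen: Vector[(Int, Int, Int)] = Vector.empty): Vector[(Int, Int, Int)] = {
    val next = seen :+ ((i, p, f))
    if (i == 0) next else trace(i - 1, f, p + f, next)
  }
}
```

For trace(6, 0, 1) the triples are (6,0,1), (5,1,1), (4,1,2), (3,2,3), (2,3,5), (1,5,8), (0,8,13). When i reaches 0 the function returns p, which is 8.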

The @tailrec annotation

To make sure that the compiler optimized the tail call, we make two changes. We add this import

import scala.annotation.tailrec

and the @tailrec annotation.

package fibo
import scala.annotation.tailrec

object FibTailRec {

  @tailrec def fib(i: Int, p: Int, f: Int): Int = i match {
    case 0 => p
    case _ => fib(i - 1, f, p + f)
  }
}

The code compiles. Because @tailrec turns a failed optimization into a compile error, we know that the compiler optimized the tail call.
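One practical consequence is that an annotated function can recurse far deeper than the thread's stack would otherwise allow. A minimal sketch, using a hypothetical counter function named steps rather than the fib function from this post:

```scala
import scala.annotation.tailrec

object DeepCountdown {

  // Hypothetical helper: counts down from n and returns the number of steps.
  // Because the call is eliminated, the recursion depth does not grow with n.
  @tailrec
  def steps(n: Int, taken: Int = 0): Int =
    if (n == 0) taken else steps(n - 1, taken + 1)
}
```

DeepCountdown.steps(5000000) completes normally, since the compiled method is a loop. A version that needed one frame per call would be expected to exhaust a default-sized stack long before reaching that depth.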

Displaying the stack

We can also display the stack at the innermost recursion depth by replacing

    case 0 => p

with

    case 0 => {
      new Exception().printStackTrace()
      p
    }

The modified function looks like this:

package fibo
import scala.annotation.tailrec

object FibTailRec {

  @tailrec def fib(i: Int, p: Int, f: Int): Int = i match {
    case 0 => {
      new Exception().printStackTrace()
      p
    }
    case _ => fib(i - 1, f, p + f)
  }
}

The output is below. Scala produces two class files for each of my source files. The code runs mainly in classes with “$” added to the original name. Notice that function fib appears only once, i.e., no additional stack frames were allocated as i counted down to 0.

java.lang.Exception
	at fibo.FibTailRec$.fib(FibTailRec.scala:8)
	at RunFibTailRec$.fibTailRec(RunFibTailRec.scala:10)
	at RunFibTailRec$.main(RunFibTailRec.scala:6)
	at RunFibTailRec.main(RunFibTailRec.scala)
8

Notice that we displayed the stack without the use of a debugger or external tracing tool, and without throwing an exception. Execution finished, and the value, 8, was displayed.
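The same check can be made programmatically instead of by reading a printed trace. The sketch below is an assumption-laden variant of fib: the object name FrameCountRec and the helper fibFrames are hypothetical, and the function returns a pair so the frame count can be inspected by the caller.

```scala
import scala.annotation.tailrec

object FrameCountRec {

  // Hypothetical helper: counts the frames on the current thread's stack
  // whose method name is exactly "fib".
  private def fibFrames(): Int =
    Thread.currentThread().getStackTrace().count(_.getMethodName == "fib")

  // Variant of fib returning (result, number of fib frames seen at i == 0).
  @tailrec
  def fib(i: Int, p: Int, f: Int): (Int, Int) = i match {
    case 0 => (p, fibFrames())
    case _ => fib(i - 1, f, p + f)
  }
}
```

With the tail call eliminated, FrameCountRec.fib(6, 0, 1) should report a single fib frame alongside the result 8, matching the printed stack trace above.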

Code that cannot be tail-call optimized

To demonstrate code ineligible for tail call optimization we introduce a coding error. Replace

case _ => fib(i - 1, f, p + f)

with

case _ => {
      var r: Int = fib(i - 1, f, p + f)
      r
    }

The function now looks like this:

package fibo

object FibTailBad {

  def fib(i: Int, p: Int, f: Int): Int = i match {
    case 0 => p
    case _ => {
      var r: Int = fib(i - 1, f, p + f)
      r
    }
  }
}

The recursive call (fib) is no longer the last expression in the brace-delimited code block; its result is first stored in r. If you try to use the @tailrec annotation:

package fibo
import scala.annotation.tailrec

object FibTailBad {

  @tailrec def fib(i: Int, p: Int, f: Int): Int = i match {
    case 0 => p
    case _ => {
      var r: Int = fib(i - 1, f, p + f)
      r
    }
  }
}

the compiler reports this error:

could not optimize @tailrec annotated method fib: it contains 
a recursive call not in tail position

No class file is produced. Without the annotation, the code compiles. If you then display the stack at the recursion termination condition, you see “fib” 7 times: 6 for the calls where i > 0 and 1 for the call where i = 0.

java.lang.Exception
	at fibo.FibTailBad$.fib(FibTailBad.scala:8)
	at fibo.FibTailBad$.fib(FibTailBad.scala:12)
	at fibo.FibTailBad$.fib(FibTailBad.scala:12)
	at fibo.FibTailBad$.fib(FibTailBad.scala:12)
	at fibo.FibTailBad$.fib(FibTailBad.scala:12)
	at fibo.FibTailBad$.fib(FibTailBad.scala:12)
	at fibo.FibTailBad$.fib(FibTailBad.scala:12)
	at RunFibTailBad$.fibTailBad(RunFibTailBad.scala:10)
	at RunFibTailBad$.main(RunFibTailBad.scala:6)
	at RunFibTailBad.main(RunFibTailBad.scala)
8
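The frame count can also be checked in code rather than by counting lines in the trace. In this sketch the object name FrameCountBad and the helper fibFrames are hypothetical. The recursive call is bound to a local value first, in the same way as in FibTailBad, so it is not in tail position.

```scala
object FrameCountBad {

  // Hypothetical helper: counts the frames on the current thread's stack
  // whose method name is exactly "fib".
  private def fibFrames(): Int =
    Thread.currentThread().getStackTrace().count(_.getMethodName == "fib")

  // Variant of the non-optimized fib returning (result, fib frames at i == 0).
  def fib(i: Int, p: Int, f: Int): (Int, Int) = i match {
    case 0 => (p, fibFrames())
    case _ =>
      val r = fib(i - 1, f, p + f) // result is stored first, so not a tail call
      r
  }
}
```

FrameCountBad.fib(6, 0, 1) should report 7 fib frames together with the result 8, one frame for each value of i from 6 down to 0.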

Bytecode listing

It is interesting to compare and contrast the bytecode of the optimized code vs. the non-optimized code. The 9-line and 12-line Scala files (respectively) expand to 71-line and 78-line bytecode listings. If I had not isolated function fib, the listing would have been much longer. The optimized code:

public final class fibo/FibTailRec$ {
     <ClassVersion=52>
     <SourceFile=FibTailRec.scala>

     public static fibo.FibTailRec$ MODULE$;

     public static  { //  //()V
             new fibo/FibTailRec$
             invokespecial fibo/FibTailRec$.<init>()V
             return
     }

     public fib(int arg0, int arg1, int arg2) { //(III)I
         <localVar:index=0 , name=this , desc=Lfibo/FibTailRec$;, sig=null, start=L1, end=L2>
         <localVar:index=1 , name=i , desc=I, sig=null, start=L1, end=L2>
         <localVar:index=2 , name=p , desc=I, sig=null, start=L1, end=L2>
         <localVar:index=3 , name=f , desc=I, sig=null, start=L1, end=L2>

         L1 {
             f_new (Locals[4]: fibo/FibTailRec$, 1, 1, 1) (Stack[0]: null)
             iload1 // reference to arg0
             istore5
             iload5
             tableswitch 
                val: 0 -> L3
                default -> L4
         }
         L3 {
             f_new (Locals[6]: fibo/FibTailRec$, 1, 1, 1, 0, 1) (Stack[0]: null)
             iload2 // reference to arg1
             goto L5
         }
         L4 {
             f_new (Locals[6]: fibo/FibTailRec$, 1, 1, 1, 0, 1) (Stack[0]: null)
             iload1 // reference to arg0
             iconst_1
             isub
             iload3
             iload2 // reference to arg1
             iload3
             iadd
             istore3
             istore2 // reference to arg1
             istore1 // reference to arg0
             goto L1
         }
         L5 {
             f_new (Locals[6]: fibo/FibTailRec$, 1, 1, 1, 0, 1) (Stack[1]: 1)
             ireturn
         }
         L2 {
         }
     }

     private FibTailRec$() { //  //()V
         <localVar:index=0 , name=this , desc=Lfibo/FibTailRec$;, sig=null, start=L1, end=L2>

         L1 {
             aload0 // reference to self
             invokespecial java/lang/Object.<init>()V
             aload0 // reference to self
             putstatic fibo/FibTailRec$.MODULE$:fibo.FibTailRec$
         }
         L3 {
             return
         }
         L2 {
         }
     }

Scala: [B@2ddd5edcScalaInlineInfo: [B@1a12f6f1}

The unoptimized code:

public final class fibo/FibTailBad$ {
     <ClassVersion=52>
     <SourceFile=FibTailBad.scala>

     public static fibo.FibTailBad$ MODULE$;

     public static  { //  //()V
             new fibo/FibTailBad$
             invokespecial fibo/FibTailBad$.<init>()V
             return
     }

     public fib(int arg0, int arg1, int arg2) { //(III)I
         <localVar:index=5 , name=r , desc=I, sig=null, start=L1, end=L2>
         <localVar:index=0 , name=this , desc=Lfibo/FibTailBad$;, sig=null, start=L3, end=L4>
         <localVar:index=1 , name=i , desc=I, sig=null, start=L3, end=L4>
         <localVar:index=2 , name=p , desc=I, sig=null, start=L3, end=L4>
         <localVar:index=3 , name=f , desc=I, sig=null, start=L3, end=L4>

         L3 {
             iload1 // reference to arg0
             istore4
             iload4
             tableswitch 
                val: 0 -> L5
                default -> L6
         }
         L5 {
             f_new (Locals[5]: fibo/FibTailBad$, 1, 1, 1, 1) (Stack[0]: null)
             iload2 // reference to arg1
             goto L7
         }
         L6 {
             f_new (Locals[5]: fibo/FibTailBad$, 1, 1, 1, 1) (Stack[0]: null)
             aload0 // reference to self
             iload1 // reference to arg0
             iconst_1
             isub
             iload3
             iload2 // reference to arg1
             iload3
             iadd
             invokevirtual fibo/FibTailBad$.fib(III)I
         }
         L1 {
             istore5
         }
         L8 {
             iload5
         }
         L2 {
             goto L7
         }
         L7 {
             f_new (Locals[5]: fibo/FibTailBad$, 1, 1, 1, 1) (Stack[1]: 1)
             ireturn
         }
         L4 {
         }
     }

     private FibTailBad$() { //  //()V
         <localVar:index=0 , name=this , desc=Lfibo/FibTailBad$;, sig=null, start=L1, end=L2>

         L1 {
             aload0 // reference to self
             invokespecial java/lang/Object.<init>()V
             aload0 // reference to self
             putstatic fibo/FibTailBad$.MODULE$:fibo.FibTailBad$
         }
         L3 {
             return
         }
         L2 {
         }
     }

Scala: [B@6cec5b09ScalaInlineInfo: [B@7eea59ce}

Comparison of stacks

The Fibonacci calculation is in block L4 (optimized) and block L6 (not optimized). The optimized code breakdown follows:

bytecode    description                        depth
iload1      push i                             1
iconst_1    push 1                             2
isub        pop 1, pop i, push i - 1           1
iload3      push f                             2
iload2      push p                             3
iload3      push f                             4
iadd        pop f, pop p, push p + f           3
istore3     pop p + f and store it in f        2
istore2     pop the old f and store it in p    1
istore1     pop i - 1 and store it in i        0
goto L1     jump back to the loop test         0

Notice that the optimized code block ends with nothing on the stack and a goto to L1, the loop termination test.
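The hand trace of block L4 can be replayed with a toy interpreter. This is only a sketch: the instruction classes below model a tiny subset of JVM integer instructions, and all the names (OperandStackSketch, Instr, run, blockL4) are invented for this illustration.

```scala
object OperandStackSketch {

  // A tiny subset of JVM integer instructions, modelled as data.
  sealed trait Instr
  final case class ILoad(index: Int)  extends Instr // push a local variable
  final case class IStore(index: Int) extends Instr // pop into a local variable
  final case class IConst(value: Int) extends Instr // push a constant
  case object ISub extends Instr                    // pop b, pop a, push a - b
  case object IAdd extends Instr                    // pop b, pop a, push a + b

  // Body of block L4, up to but not including "goto L1".
  // Local 1 is i, local 2 is p, local 3 is f.
  val blockL4: List[Instr] = List(
    ILoad(1), IConst(1), ISub,      // i - 1
    ILoad(3),                       // f
    ILoad(2), ILoad(3), IAdd,       // p + f
    IStore(3), IStore(2), IStore(1) // f = p + f; p = old f; i = i - 1
  )

  // Runs the program and returns the final locals plus the operand stack
  // depth observed after each instruction.
  def run(program: List[Instr], locals: Map[Int, Int]): (Map[Int, Int], List[Int]) = {
    val start = (locals, List.empty[Int], Vector.empty[Int])
    val (finalLocals, _, depths) = program.foldLeft(start) {
      case ((vars, stack, seen), instr) =>
        val (nextVars, nextStack) = (instr, stack) match {
          case (ILoad(n), _)            => (vars, vars(n) :: stack)
          case (IConst(v), _)           => (vars, v :: stack)
          case (IStore(n), top :: rest) => (vars + (n -> top), rest)
          case (ISub, b :: a :: rest)   => (vars, (a - b) :: rest)
          case (IAdd, b :: a :: rest)   => (vars, (a + b) :: rest)
          case _                        => sys.error(s"stack underflow at $instr")
        }
        (nextVars, nextStack, seen :+ nextStack.length)
    }
    (finalLocals, depths.toList)
  }
}
```

Running blockL4 with locals i = 6, p = 0, f = 1 should leave i = 5, p = 1, f = 1, with depths 1, 2, 1, 2, 3, 4, 3, 2, 1, 0. The operand stack is empty again before the jump back to L1.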

The non-optimized code analysis (block L6) is:

bytecode       description                                        depth
aload0         push this, the object that fib is invoked on       1
iload1         push i                                             2
iconst_1       push 1                                             3
isub           pop 1, pop i, push i - 1                           2
iload3         push f                                             3
iload2         push p                                             4
iload3         push f                                             5
iadd           pop f, pop p, push p + f                           4
invokevirtual  call fibo/FibTailBad$.fib(III)I in a new frame

Leading up to the recursive call, the operand stack holds:

  • the object reference (this)
  • the decremented counter
  • the previous Fibonacci number
  • the new Fibonacci number

The return address is not pushed by the bytecode; the JVM keeps track of it when invokevirtual creates the new frame.

No tail-call optimization in Java

If you try the above analysis on Java code, you will find that it is not tail-call optimized. The Java compiler does not perform tail-call optimization. For example:

	private static int fib(int i, int p, int f) {
		switch (i) {
		case 0:
			new Exception().printStackTrace();
			return p;
		default:
			return fib(i - 1, f, p + f);
		}
	}

 

Decompiling

A decompiler cannot recover the original source code exactly from the bytecode. In particular, once the recursion has been compiled to a loop, the bytecode gives no indication that the source was recursive. Finally, the decompiler returns Java code, not Scala. Bytecode Viewer has several decompilers to choose from. Here is the output from one of the decompilers.

package fibo;

public final class FibTailRec$ {
   public static FibTailRec$ MODULE$;

   static {
      new FibTailRec$();
   }

   public int fib(int i, int p, int f) {
      while(true) {
         switch(i) {
         case 0:
            return p;
         default:
            int var10000 = i - 1;
            int var10001 = f;
            f += p;
            p = var10001;
            i = var10000;
         }
      }
   }

   private FibTailRec$() {
      MODULE$ = this;
   }
}

The switch default block corresponds to bytecode block L4, analyzed earlier. I chose this listing because it shows the block-scope variables var10000 and var10001 that hold temporary results.
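For comparison, the decompiled loop can be written by hand in Scala. This is a sketch of what the optimization amounts to, not compiler output; the names FibLoopSketch and fibLoop are hypothetical, and the seed values 0 and 1 are taken from the driver shown earlier.

```scala
object FibLoopSketch {

  // Hand-written loop mirroring the decompiled Java: reassign i, p and f,
  // then repeat until i reaches 0.
  def fibLoop(n: Int): Int = {
    var i = n
    var p = 0 // previous Fibonacci number
    var f = 1 // current Fibonacci number
    while (i != 0) {
      val next = p + f // corresponds to the temporary held on the operand stack
      p = f
      f = next
      i -= 1
    }
    p
  }

  // The recursive form from this post, repeated here for comparison.
  def fib(i: Int, p: Int, f: Int): Int = i match {
    case 0 => p
    case _ => fib(i - 1, f, p + f)
  }
}
```

For small n the two forms should agree, for example both give 8 for n = 6.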

Remarks

For various reasons, people decide to code in Scala. When they do, interest in system implementation and application performance can arise. It can be helpful for administrators to be familiar with the technology and the underlying mechanisms.

This posting is outside my field. Comments, suggestions, and corrections are welcome and will be acknowledged.

Summary

  • Functional language programmers tend to implement iteration as recursion.
  • A common case of recursion is tail call recursion.
  • The optimizer can eliminate the tail call.
  • The @tailrec annotation makes the compiler report which routines cannot be tail-call optimized.
  • It is possible to inadvertently defeat tail call optimization.
  • new Exception().printStackTrace() can demonstrate that a tail call was optimized.
  • Like the Java compiler, the Scala compiler produces bytecode that is executable by the Java virtual machine.
  • Bytecode Viewer, by Konloch, can display the optimization.
  • Hand tracing bytecode shows the internals of optimized (iterative) code compared to recursive code.
  • Bytecode Viewer decompiles the optimized bytecode to iterative Java, not the original, recursive Scala.
  • Java does not support tail call optimization.
