Working on the type level

0. Introduction

Scala has a sophisticated type system that allows us to catch many logic errors at compile time. Let's put it to good use.

For instance, we may want to rule out division by zero statically (on the JVM, integer division by zero throws an ArithmeticException). This would be possible if we had a type that represents integers different from zero. Or we may want to zip two lists together only when they are guaranteed to have the same length.

Dependent type systems allow types to be defined in terms of values. We will not make this precise, but we will show some examples in Scala's type system.

Applications will include:

  • fixed-size sequences
  • modular arithmetic
  • balanced trees

Note: in general, dependent type systems come with a proof system, because the programmer may need to help the compiler with type checking.

For instance, say we have a type NonZeroInt that represents integers different from 0. We know that the function x => x * x + 1 sends Int to NonZeroInt, but the compiler may not be able to infer this. Examples of dependently-typed languages include:

  • Coq
  • Agda
  • Idris
  • ATS
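As a rough illustration of the NonZeroInt example, here is a minimal sketch in plain Scala. It assumes a hypothetical NonZeroInt class with a checked constructor `from` and an unchecked `squarePlusOne`; none of these names come from the notebook. Without a proof system, the fact that x * x + 1 is never zero has to be asserted by the programmer rather than checked by the compiler.

```scala
object NonZeroDemo {
  // Hypothetical wrapper: the private constructor means a NonZeroInt can only
  // be built through the companion object.
  final class NonZeroInt private (val value: Int)

  object NonZeroInt {
    // Checked constructor: the only safe entry point for arbitrary data.
    def from(x: Int): Option[NonZeroInt] =
      if (x != 0) Some(new NonZeroInt(x)) else None

    // x * x + 1 is never 0, even with 32-bit wraparound (x * x is 0 or 1
    // modulo 4), but the compiler cannot see this: we bypass the check by
    // hand, where a dependently-typed language would ask for a proof.
    def squarePlusOne(x: Int): NonZeroInt = new NonZeroInt(x * x + 1)
  }

  // Division that can never receive a zero divisor.
  def div(a: Int, d: NonZeroInt): Int = a / d.value
}
```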

Scala has some of the needed features. In particular, objects can have type members, and two abstract type members are known to be the same type only when they are the same member of the same value (more precisely, of the same stable path). This will be the basis for all our constructions.
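Here is a small sketch of this last point. It assumes a hypothetical trait Box with an abstract type member T. The Boolean-valued equality test follows the style of the type_== function from the Preliminaries, but is built on the standard library's =:=.

```scala
object PathDemo {
  // Hypothetical trait with an abstract type member.
  trait Box { type T }

  // Returns true when the compiler can prove A and B are the same type.
  def same[A, B](implicit ev: A =:= B = null): Boolean = ev != null

  def mkBox: Box = new Box { type T = Int }

  // Statically typed as Box, so a.T and b.T are abstract to the compiler.
  val a: Box = mkBox
  val b: Box = mkBox

  val aWithA: Boolean = same[a.T, a.T] // same member of the same value
  val aWithB: Boolean = same[a.T, b.T] // different values: not provably equal
}
```

Both boxes actually hide Int. However, a and b are statically typed as Box, so the compiler only equates a.T with itself.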

1. Preliminaries

We start with a few functions and types that are defined in the Scala Predef. We redefine them here, since we need to understand how they work.

Since we will need a few complex signatures in the sequel, a nice first step is to let the compiler summon implicit instances and do some work for us.

In [ ]:
def implicitly[A](implicit a: A) = a

Let us check that this works as expected:

In [ ]:
case class Foo(x: Int)

implicit val foo = Foo(5)
implicitly[Foo]

Can we use the same method to check whether some types are the same?

In [ ]:
class =:=[A, B]
    
implicit def equalTypeInstance[A] = new =:=[A, A]

We use this definition, together with implicitly, to check for type equality:

In [ ]:
//# this does not compile
//# implicitly[Int =:= String]
implicitly[Int =:= Int]

It is a little annoying to work with code that does not compile. If we provide a dummy default value, we can turn this check into a value-level Boolean.

In [ ]:
def type_==[A, B](implicit ev: A =:= B = null) = ev != null
In [ ]:
type_==[Int, Int]
In [ ]:
type_==[Int, String]

Great! Next we do the same for subtyping:

In [ ]:
class <:<[-A, +B]
    
implicit def subTypeInstance[A] = new <:<[A, A]

The variance annotations are all that is needed to make this work. See:

In [ ]:
//# this does not compile
//# implicitly[Int <:< String]
implicitly[List[Int] <:< Seq[Int]]
In [ ]:
def type_<[A, B](implicit ev: A <:< B = null) = ev != null
In [ ]:
type_<[Int, String]
In [ ]:
type_<[List[String], Seq[String]]
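This self-contained sketch shows why the variance annotations suffice. The object name VarianceDemo and its vals are illustrative. Since <:< is contravariant in its first parameter and covariant in its second, a value of type A <:< A also has type X <:< Y whenever X <: A and A <: Y. Implicit search relies on exactly this when it instantiates subTypeInstance.

```scala
object VarianceDemo {
  // Local copy of the notebook's class; it shadows Predef's <:< in this object.
  class <:<[-A, +B]

  // The only instances ever created relate a type to itself.
  implicit def subTypeInstance[A]: A <:< A = new <:<[A, A]

  // Accepted without any implicit: List[Int] <:< List[Int] widens to
  // List[Int] <:< Seq[Int] because the second parameter is covariant.
  val widened: List[Int] <:< Seq[Int] = new <:<[List[Int], List[Int]]

  def type_<[A, B](implicit ev: A <:< B = null): Boolean = ev != null

  val listToSeq: Boolean = type_<[List[Int], Seq[Int]]  // some A fits
  val intToString: Boolean = type_<[Int, String]         // no A with Int <: A <: String
}
```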

2. Encoding some values as types

Booleans

It is not difficult to encode some values, and the corresponding operations, at the type level. We start with booleans:

In [ ]:
sealed trait Bool
sealed trait True extends Bool
sealed trait False extends Bool

This is a start, but we do not have any operations yet. Let us do better:

In [ ]:
sealed trait Bool {
    type Not
}
sealed trait True extends Bool {
    type Not = False
}
sealed trait False extends Bool {
    type Not = True
}
In [ ]:
type_==[True, False#Not]

We can even encode the if operation:

In [ ]:
sealed trait Bool {
    type Not
    type If[T <: Up, F <: Up, Up] <: Up
}
sealed trait True extends Bool {
    type Not = False
    type If[T <: Up, F <: Up, Up] = T
}
sealed trait False extends Bool {
    type Not = True
    type If[T <: Up, F <: Up, Up] = F
}

Note that If takes three parameters: Up is a common upper bound of both branches, so the result of If is known to be a subtype of Up whichever branch is selected.
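Here is a small sketch of what the Up bound buys us. The Shape hierarchy and the area method are hypothetical and serve only as an illustration. For a concrete Bool, the projection reduces to one of the branches. For an abstract B, the compiler only knows that B#If[...] is a subtype of Up, but that is still enough to call the methods of Up.

```scala
object IfDemo {
  sealed trait Bool { type If[T <: Up, F <: Up, Up] <: Up }
  sealed trait True extends Bool { type If[T <: Up, F <: Up, Up] = T }
  sealed trait False extends Bool { type If[T <: Up, F <: Up, Up] = F }

  // Hypothetical hierarchy used as the common upper bound.
  trait Shape { def area: Double }
  final case class Square(side: Double) extends Shape { def area: Double = side * side }
  final case class Circle(r: Double) extends Shape { def area: Double = math.Pi * r * r }

  // B is abstract here: only the bound `<: Shape` lets us call area.
  def area[B <: Bool](s: B#If[Square, Circle, Shape]): Double = s.area

  // With a concrete Bool the projection reduces to a branch type.
  val squareArea: Double = area[True](Square(2.0))
  val circleArea: Double = area[False](Circle(1.0))
}
```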

We can then use If to define Not, and we add some more operations as well:

In [ ]:
sealed trait Bool {
    type If[T <: Up, F <: Up, Up] <: Up
}
sealed trait True extends Bool {
    type If[T <: Up, F <: Up, Up] = T
}
sealed trait False extends Bool {
    type If[T <: Up, F <: Up, Up] = F
}

object Bool {
    type &&[A <: Bool, B <: Bool] = A#If[B, False, Bool]
    type || [A <: Bool, B <: Bool] = A#If[True, B, Bool]
    type Not[A <: Bool] = A#If[False, True, Bool]
}

Let us check that everything works as expected:

In [ ]:
import Bool._

type_==[True, Not[False]]
In [ ]:
implicitly[False =:= &&[True, False]]
implicitly[True =:= ||[True, False]]

We can use our type test function to convert Bool types to Boolean values:

In [ ]:
def bool2val[A <: Bool]: Boolean = type_==[A, True]
In [ ]:
bool2val[True]

Oops... this does not work! The implicit for type_== is resolved inside the body of bool2val, where A is still an abstract type parameter. There the compiler cannot find an instance of A =:= True, so it falls back to the default null, and the function always returns false.

We fix this by requesting the implicit instance as a parameter of bool2val, so that it is resolved at the call site, where A is known:

In [ ]:
def bool2val[A <: Bool](implicit ev: A =:= True = null) = type_==[A, True]
In [ ]:
bool2val[True]
In [ ]:
bool2val[False]

Much better

Natural numbers

We can construct naturals the Peano way:

In [ ]:
sealed trait Nat
sealed trait _0 extends Nat
sealed trait Succ[N <: Nat] extends Nat

and then define

In [ ]:
type _1 = Succ[_0]

and so on. Of course, we can add some operations too:

In [ ]:
sealed trait Nat {
    type IfZero[T <: Up, F <: Up, Up] <: Up
}
sealed trait _0 extends Nat {
    type IfZero[T <: Up, F <: Up, Up] = T
}
sealed trait Succ[N <: Nat] extends Nat {
    type IfZero[T <: Up, F <: Up, Up] = F
}
type _1 = Succ[_0]

We check that these work as expected:

In [ ]:
type_==[_0#IfZero[True, False, Bool], True]

We can extend this to define, say, the predecessor function. Since zero has no predecessor among the naturals, we define it so that Pred[_0] is _0:

In [ ]:
sealed trait Nat {
    type IfZero[T <: Up, F <: Up, Up] <: Up
    type Pred <: Nat
}
sealed trait _0 extends Nat {
    type IfZero[T <: Up, F <: Up, Up] = T
    type Pred = _0
}
sealed trait Succ[N <: Nat] extends Nat {
    type IfZero[T <: Up, F <: Up, Up] = F
    type Pred = N
}

object Nat {
    type Pred[N <: Nat] = N#Pred
}
type _1 = Succ[_0]
In [ ]:
import Nat._

type_==[_0, Pred[_1]]
In [ ]:
type_==[_1, Pred[_0]]

One would think this is enough to define addition recursively. Let's see:

In [ ]:
object NatOps {
    type Plus[A <: Nat, B <: Nat] = A#IfZero[B, Succ[Plus[Pred[A], B]], Nat]
}

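The compiler rejects this definition with an illegal cyclic reference error, because a type alias cannot refer to itself. We can work around this by adding one more type member to Nat and defining it separately in _0 and Succ. This way the recursion goes through N#Plus instead of through the alias: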

In [ ]:
sealed trait Nat {
    type IfZero[T <: Up, F <: Up, Up] <: Up
    type Pred <: Nat
    type Plus[B <: Nat] <: Nat
}
sealed trait _0 extends Nat {
    type IfZero[T <: Up, F <: Up, Up] = T
    type Pred = _0
    type Plus[B <: Nat] = B
}
sealed trait Succ[N <: Nat] extends Nat {
    type IfZero[T <: Up, F <: Up, Up] = F
    type Pred = N
    type Plus[B <: Nat] = Succ[N#Plus[B]]
}

object Nat {
    type Pred[N <: Nat] = N#Pred
    type Plus[A <: Nat, B <: Nat] = A#Plus[B]
}
type _1 = Succ[_0]
type _2 = Succ[_1]
type _3 = Succ[_2]
type _4 = Succ[_3]

Let's see this in action!

In [ ]:
import Nat._
type_==[Plus[_2, _2], _4]
In [ ]:
type_==[Plus[_2, _2], _3]

As before, we can convert these types to values.

In [ ]:
class NatConverter[N <: Nat](val n: Int)

implicit val zeroConverter: NatConverter[_0] = new NatConverter(0)
implicit def succConverter[N <: Nat](implicit ev: NatConverter[N]): NatConverter[Succ[N]] = new NatConverter(ev.n + 1)

def nat2value[N <: Nat](implicit ev: NatConverter[N]) = ev.n
In [ ]:
nat2value[_2]

Now suppose we want to add multiplication as well. Just as above, we would need to modify our definition of Nat. What if, instead, we add a general operation for recursion?

To do so, we define a FoldR operation on Nat. We first need the type-level analogue of a function (A, B) => B. At the value level, such a function is equivalent to the trait:

In [ ]:
trait FoldOp[A, B] {
    def apply(a: A, b: B): B
}
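For comparison, here is a sketch of the value-level recursion that FoldR mirrors. It assumes hypothetical value-level naturals VNat, VZero and VSucc. Folding over a natural applies the operation once per successor, passing both the current number and the accumulated result. Addition then starts from the second argument and increments once per step, which is how Plus is defined on the type level below.

```scala
object FoldDemo {
  // Same shape as the FoldOp trait above.
  trait FoldOp[A, B] {
    def apply(a: A, b: B): B
  }

  // Hypothetical value-level Peano naturals mirroring Nat, _0 and Succ.
  sealed trait VNat {
    // Zero yields init; Succ(m) combines the current number with the fold over m.
    def foldR[B](init: B)(op: FoldOp[VNat, B]): B = this match {
      case VZero    => init
      case VSucc(m) => op(this, m.foldR(init)(op))
    }
  }
  case object VZero extends VNat
  final case class VSucc(pred: VNat) extends VNat

  // Analogue of Incr: ignore the current number, increment the accumulator.
  val incr: FoldOp[VNat, VNat] = new FoldOp[VNat, VNat] {
    def apply(n: VNat, acc: VNat): VNat = VSucc(acc)
  }

  def plus(a: VNat, b: VNat): VNat = a.foldR(b)(incr)

  def fromInt(k: Int): VNat = if (k <= 0) VZero else VSucc(fromInt(k - 1))

  def toInt(n: VNat): Int = n.foldR(0)(new FoldOp[VNat, Int] {
    def apply(m: VNat, acc: Int): Int = acc + 1
  })
}
```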

We then define the type-level analogue:

In [ ]:
trait Fold[A, B] {
    type Apply[X <: A, Y <: B] <: B
}

The FoldR operation can then be defined on Nat as follows:

In [ ]:
sealed trait Nat {
    type Pred <: Nat
    type FoldR[Init <: Up, Op <: Fold[Nat, Up], Up] <: Up
}
sealed trait _0 extends Nat {
    type Pred = _0
    type FoldR[Init <: Up, Op <: Fold[Nat, Up], Up] = Init
}
sealed trait Succ[N <: Nat] extends Nat {
    type Pred = N
    type FoldR[Init <: Up, Op <: Fold[Nat, Up], Up] = Op#Apply[Succ[N], N#FoldR[Init, Op, Up]]
}

Phew... that was quite complex! Anyway, we can now define a few operations:

In [ ]:
object NatOps {
    type Incr = Fold[Nat, Nat] {
       type Apply[N <: Nat, Acc <: Nat] = Succ[Acc]
    }
    type Plus[A <: Nat, B <: Nat] = A#FoldR[B, Incr, Nat]
}
In [ ]:
class NatConverter[N <: Nat](val n: Int)

implicit val zeroConverter: NatConverter[_0] = new NatConverter(0)
implicit def succConverter[N <: Nat](implicit ev: NatConverter[N]): NatConverter[Succ[N]] = new NatConverter(ev.n + 1)

def nat2value[N <: Nat](implicit ev: NatConverter[N]) = ev.n
In [ ]:
import NatOps._
type _1 = Succ[_0]
type _2 = Succ[_1]
type _3 = Succ[_2]
type _4 = Succ[_3]
type _5 = Succ[_4]
type _6 = Succ[_5]
type _7 = Succ[_6]
type _8 = Succ[_7]

type_==[Plus[_3, _4], _7]

In the same way, we can add multiplication and truncated difference (Minus[A, B] is _0 whenever B is greater than A):

In [ ]:
object NatOps {
    type Incr = Fold[Nat, Nat] {
       type Apply[N <: Nat, Acc <: Nat] = Succ[Acc]
    }
    type Plus[A <: Nat, B <: Nat] = A#FoldR[B, Incr, Nat]
    type Sum[B <: Nat] = Fold[Nat, Nat] {
       type Apply[N <: Nat, Acc <: Nat] = Plus[Acc, B]
    }
    type Times[A <: Nat, B <: Nat] = A#FoldR[_0, Sum[B], Nat]
    type Decr = Fold[Nat, Nat] {
       type Apply[N <: Nat, Acc <: Nat] = Acc#Pred
    }
    type Minus[A <: Nat, B <: Nat] = B#FoldR[A, Decr, Nat]
}
In [ ]:
import NatOps._

type_==[Plus[_2, Times[_2, _3]], _8]
In [ ]:
type_==[Minus[_5, _3], _2]
In [ ]:
type_==[Minus[_3, _5], _0]

In fact, we can recover the IfZero operation from FoldR as well, and use it to implement comparisons:

In [ ]:
object MoreNatOps {
    type NonZero[Up, F <: Up] = Fold[Nat, Up] {
        type Apply[N <: Nat, Acc <: Up] = F
    }
    type IfZero[A <: Nat, T <: Up, F <: Up, Up] = A#FoldR[T, NonZero[Up, F], Up]
    type Pred[N <: Nat] = N#Pred
    type Leq[A <: Nat, B <: Nat] = IfZero[Minus[A, B], True, False, Bool]
    type Less[A <: Nat, B <: Nat] = Leq[Succ[A], B]
}
In [ ]:
import MoreNatOps._
type_==[IfZero[_0, True, False, Bool], True]
In [ ]:
type_==[IfZero[_3, True, False, Bool], False]
In [ ]:
type_==[Succ[Pred[_3]], _3]
In [ ]:
bool2val[_3 Leq _4]
In [ ]:
bool2val[_4 Leq _3]

Notice that none of our Apply definitions uses its first argument N, the current number that FoldR passes along. That argument would be needed to define, for instance, the factorial, but we will stop here.
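As a sketch of how the first argument could be used, here is a self-contained version of the definitions above with a factorial added. Prod and Fact are our own additions, and the sketch assumes the compiler reduces these projections the same way it does for Times. We only check a small case, since deeper type-level computations may become slow to type-check.

```scala
object FactDemo {
  trait Fold[A, B] { type Apply[X <: A, Y <: B] <: B }

  sealed trait Nat {
    type FoldR[Init <: Up, Op <: Fold[Nat, Up], Up] <: Up
  }
  sealed trait _0 extends Nat {
    type FoldR[Init <: Up, Op <: Fold[Nat, Up], Up] = Init
  }
  sealed trait Succ[N <: Nat] extends Nat {
    // FoldR passes the current number Succ[N] as the first argument of Apply.
    type FoldR[Init <: Up, Op <: Fold[Nat, Up], Up] = Op#Apply[Succ[N], N#FoldR[Init, Op, Up]]
  }

  type _1 = Succ[_0]
  type _2 = Succ[_1]
  type _3 = Succ[_2]
  type _4 = Succ[_3]
  type _5 = Succ[_4]
  type _6 = Succ[_5]

  type Incr = Fold[Nat, Nat] { type Apply[N <: Nat, Acc <: Nat] = Succ[Acc] }
  type Plus[A <: Nat, B <: Nat] = A#FoldR[B, Incr, Nat]
  type Sum[B <: Nat] = Fold[Nat, Nat] { type Apply[N <: Nat, Acc <: Nat] = Plus[Acc, B] }
  type Times[A <: Nat, B <: Nat] = A#FoldR[_0, Sum[B], Nat]

  // Unlike Incr and Sum, Prod uses the current number: Apply[N, Acc] = N * Acc.
  type Prod = Fold[Nat, Nat] { type Apply[N <: Nat, Acc <: Nat] = Times[N, Acc] }
  type Fact[A <: Nat] = A#FoldR[_1, Prod, Nat]

  // Boolean-valued type equality, in the style of type_==.
  def same[A, B](implicit ev: A =:= B = null): Boolean = ev != null

  // Does not compile unless Fact[_3] reduces to _6.
  implicitly[Fact[_3] =:= _6]

  val fact3IsSix: Boolean = same[Fact[_3], _6]
  val fact3IsFive: Boolean = same[Fact[_3], _5]
}
```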

Dependent types

Scala is not known to be a dependently-typed language, but has a few features that allow us to simulate some structures typical of languages with dependent types, such as:

  • path-dependent types, encoded as type members;
  • value-dependent return types in methods;
  • implicits, to make all the necessary wiring less verbose.

Let us see some examples in action.

Fixed-length lists

The poster children for dependent types are range types and fixed-length lists. We shall start with the latter.

We begin by recalling the usual definition of List:

In [ ]:
sealed trait List[+A] {
    def ::[B >: A](b: B): List[B] = Cons[B](b, this)
    def head: A
    def tail: List[A]
}
case object Nil extends List[Nothing] {
    def head = sys.error("empty list")
    def tail = sys.error("empty list")
}
case class Cons[A](val head: A, val tail: List[A]) extends List[A]

We can use it as usual:

In [ ]:
val list = 1 :: 2 :: 3 :: 4 :: Nil
In [ ]:
list.tail

Of course, this definition is unsafe, as we cannot take the head or tail of an empty list. We could remove head and tail from the List interface, leaving them as operations on Cons. But we would not fare much better: the compiler does not keep track of lengths, so it has no way to know statically whether a given List is a Cons.

Let us parametrize lists by their length, which is a natural number:

In [ ]:
sealed trait NList[N <: Nat, +A] {
    def ::[B >: A](b: B) = NCons[N, B](b, this)
}
case object NNil extends NList[_0, Nothing]
case class NCons[N <: Nat, A](val head: A, val tail: NList[N, A]) extends NList[Succ[N], A]
In [ ]:
val nlist: NCons[_2, Int] = 1 :: 2 :: 3 :: NNil
In [ ]:
nlist.tail
In [ ]:
%type nlist.tail
In [ ]:
nlist.tail.tail

Too bad! The compiler keeps track of the length, but nlist.tail has the static type NList[_2, Int], which has no tail method. The compiler does not know that a list of positive length is an NCons, whose tail we could keep taking. Let us try to write an append operation:

In [ ]:
def append[M <: Nat, N <: Nat, A](xs: NList[M, A], ys: NList[N, A]): NList[Plus[M, N], A] = xs match {
    case NNil => ys
    case NCons(h, t) => h :: append(t, ys)
}

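This naive definition does not compile. The compiler is not able to exploit what the patterns tell us about M (for instance, that in the NCons branch M is the successor of the length of t), so it cannot check that either branch has length Plus[M, N].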

We will try a different approach: wrap an unsafe data structure and only provide operations when we know they can be carried out.

In [ ]:
class Sized[N <: Nat, A](underlying: Seq[A]) {
    def apply[M <: Nat](implicit ev: Less[M, N] =:= True, ev1: NatConverter[M]) = underlying(nat2value[M])
}
In [ ]:
val seq = new Sized[_3, Int](List(1, 2, 3))

seq[_1]
In [ ]:
seq[_4]

We still have the problem that the following compiles, but fails at runtime:

In [ ]:
val seq1 = new Sized[_8, Int](List(1, 2, 3))

seq1[_5]

Symmetrically, the declared size can be smaller than the actual one. In that case, valid indices are rejected:

In [ ]:
val seq2 = new Sized[_3, Int](List(1, 2, 3, 4, 5))

seq2[_3]

To avoid this problem, we make the constructor private (it stays accessible from the companion object) and expose a safe API:

In [ ]:
class Sized[N <: Nat, A] private (underlying: Seq[A]) {
    def apply[M <: Nat](implicit ev: Less[M, N] =:= True, ev1: NatConverter[M]) = underlying(nat2value[M])
}
object Sized {
    def sized[N <: Nat, A](xs: Seq[A])(implicit ev: NatConverter[N]): Option[Sized[N, A]] =
        if (xs.length == nat2value[N]) Some(new Sized(xs)) else None
}
In [ ]:
import Sized._

val Some(s) = sized[_3, Int](List(1, 2, 3))
In [ ]:
sized[_8, Int](List(1, 2, 3))
In [ ]:
new Sized[_8, Int](List(1, 2, 3))
In [ ]:
s[_2]

We can also add some other convenience methods. In each case, we ask the compiler for evidence that the operation is well-typed. For simplicity, we leave the A parameter invariant:

In [ ]:
class Sized[N <: Nat, A] private (val underlying: Seq[A])(implicit ev: NatConverter[N]) {
    def apply[M <: Nat](implicit ev: Less[M, N] =:= True, ev1: NatConverter[M]) = underlying(nat2value[M])

    def +:(x: A)(implicit ev: NatConverter[Succ[N]]) = new Sized[Succ[N], A](x +: underlying)
    def :+(x: A)(implicit ev: NatConverter[Succ[N]]) = new Sized[Succ[N], A](underlying :+ x)
    def ++[M <: Nat](xs: Sized[M, A])(implicit ev: NatConverter[Plus[N, M]]) = new Sized[Plus[N, M], A](underlying ++ xs.underlying)
    def head(implicit ev: Less[_0, N] =:= True) = underlying.head
    def tail(implicit ev: Less[_0, N] =:= True, ev1: NatConverter[Pred[N]]) = new Sized[Pred[N], A](underlying.tail)

    override def toString = s"Sized[${ nat2value[N] }]: ${ underlying }"
}
object Sized {
    def sized[N <: Nat, A](xs: Seq[A])(implicit ev: NatConverter[N]): Option[Sized[N, A]] =
        if (xs.length == nat2value[N]) Some(new Sized(xs)) else None
    def single[A](x: A) = new Sized[_1, A](List(x))
    def SNil[A] = new Sized[_0, A](Seq.empty[A])
}
In [ ]:
import Sized._

val Some(s) = sized[_3, Int](List(1, 2, 3))
In [ ]:
s :+ 5
In [ ]:
s ++ s
In [ ]:
s.head
In [ ]:
s.tail

So far, everything works as expected. Will the compiler prevent us from taking too many tails?

In [ ]:
s.tail.tail.tail.tail

A few more operations show that this data structure is quite simple to use:

In [ ]:
single(1) :+ 2 :+ 3
In [ ]:
1 +: 2 +: 3 +: SNil[Int]

Modular arithmetic

Another nice example of type-level encoding is to use Nat to define classes of integers modulo N, and implement their operations in a type-safe way.

We will implement modular reduction for a natural N greater than _0.

In [ ]:
type IsLeq[A <: Nat, B <: Nat] = Leq[A, B] =:= True
In [ ]:
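class Mod[N <: Nat](x: Int)(implicit ev: _1 IsLeq N, ev1: NatConverter[N]) {
    // reduced representative in [0, N); unlike %, floorMod never returns
    // a negative value, which matters after a subtraction
    lazy val n = Math.floorMod(x, nat2value[N])

    def +(y: Mod[N]) = new Mod[N](n + y.n)
    def -(y: Mod[N]) = new Mod[N](n - y.n)
    def *(y: Mod[N]) = new Mod[N](n * y.n)
    override def toString = s"$n mod ${ nat2value[N] }"
}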
object Mod {
    def apply[N <: Nat](x: Int)(implicit ev: _1 IsLeq N, ev1: NatConverter[N]) = new Mod[N](x)
}
In [ ]:
val a = Mod[_4](10)
In [ ]:
a * a

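Notice that every operation is defined in terms of the lazy val n, a reduced representative of the class of x modulo N whose absolute value is less than N. This keeps intermediate values small, so we can keep performing operations without overflowing, as long as (N - 1)^2 fits in an Int (that is, for N up to 46341). For instance, we can compute large powers: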

In [ ]:
def power[N <: Nat](x: Mod[N], y: Int): Mod[N] = if (y == 1) x else x * power(x, y - 1)
In [ ]:
val b = Mod[_7](6)
In [ ]:
power(b, 1000)
In [ ]:
b + Mod[_5](3)
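Two details of the reduction deserve a closer look. First, on the JVM the remainder operator % keeps the sign of the dividend, so a difference such as n - y.n needs a floored modulus to land in [0, N). Second, the overflow guarantee only holds while the product of two representatives fits in an Int. The following value-level sketch checks both points; the names ModArith, productFits and mulMod are illustrative. The mulMod function shows one way to lift the bound, by widening to Long.

```scala
object ModArith {
  // `%` keeps the sign of the dividend, so the result can be negative.
  val truncated: Int = -2 % 4
  // Math.floorMod returns a value in [0, m) for a positive modulus m.
  val floored: Int = Math.floorMod(-2, 4)

  // True when the product of two representatives in [0, m) fits in an Int.
  def productFits(m: Int): Boolean = (m - 1).toLong * (m - 1) <= Int.MaxValue

  // Widening to Long before multiplying avoids overflow for any positive Int modulus.
  def mulMod(a: Int, b: Int, m: Int): Int = Math.floorMod(a.toLong * b, m.toLong).toInt
}
```

For instance, mulMod(65536, 65536, 7) is 4, whereas multiplying in Int first would wrap 2^32 around to 0.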

Compile guards

Another neat application of dependent types lets us choose one of two branches depending on which one type-checks. Let us define a trait:

In [ ]:
trait Application[A, F, G]{
  type Out

  def apply(a: A, f: F, g: G): Out
}

This represents the application of either the function f or g to the value a, with a result of type Out. None of the types is determined yet. We now define our choice method:

In [ ]:
def choose[A, F, G](a: A, f: F, g: G)(implicit app: Application[A, F, G]): app.Out = app(a, f, g)

This does nothing but let the compiler summon an appropriate implicit. Notice that the return type is app.Out, hence it depends on the implicit chosen.

We need to provide suitable implicits to make the above work. The trick is to provide two implicits: the first requires F to be a function taking an A, and the second requires G to be one. Whichever of the two type-checks is chosen:

In [ ]:
implicit def applyf[A, B, G] = new Application[A, A => B, G] {
    type Out = B
    def apply(a: A, f: A => B, g: G) = f(a)
}
implicit def applyg[A, B, F] = new Application[A, F, A => B] {
    type Out = B
    def apply(a: A, f: F, g: A => B) = g(a)
}

See it in action:

In [ ]:
choose(3, { x: Int => x * x }, { x: String => "hello " + x })
In [ ]:
choose("world", { x: Int => x * x }, { x: String => "hello " + x })
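The dependent return type app.Out is what makes choose useful, because the compiler knows the precise result type at each call site. The sketch below restates the example in a common idiom, often called the Aux pattern, where the implicits declare their result types explicitly as refinements. The Aux alias and the ChooseDemo wrapper are our additions.

```scala
object ChooseDemo {
  trait Application[A, F, G] {
    type Out
    def apply(a: A, f: F, g: G): Out
  }

  object Application {
    // Aux exposes the type member Out as a type parameter.
    type Aux[A, F, G, O] = Application[A, F, G] { type Out = O }

    // Applies when F is a function from A.
    implicit def applyf[A, B, G]: Aux[A, A => B, G, B] = new Application[A, A => B, G] {
      type Out = B
      def apply(a: A, f: A => B, g: G): B = f(a)
    }

    // Applies when G is a function from A.
    implicit def applyg[A, B, F]: Aux[A, F, A => B, B] = new Application[A, F, A => B] {
      type Out = B
      def apply(a: A, f: F, g: A => B): B = g(a)
    }
  }

  def choose[A, F, G](a: A, f: F, g: G)(implicit app: Application[A, F, G]): app.Out =
    app(a, f, g)

  // The result types are known statically, so no casts are needed.
  val n: Int = choose(3, (x: Int) => x * x, (x: String) => "hello " + x)
  val s: String = choose("world", (x: Int) => x * x, (x: String) => "hello " + x)
  val next: Int = n + 1
}
```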

Balanced trees

As a final example, we will build a balanced binary tree type, so that building an unbalanced tree is a compile-time error.

Since binary trees with a number of leaves that is not a power of two cannot be perfectly balanced, we will require that for each subtree, starting from the root:

  • the left branch is at least as high as the right one
  • the left branch is at most one level higher than the right one

We do so by adding a type member that represents the tree height, and by providing tree constructors only when implicits certify that the height conditions are satisfied:

In [ ]:
sealed trait Tree[+T]{
  type N <: Nat
}
object Leaf extends Tree[Nothing]{
  type N = _0
}
trait Branch[T] extends Tree[T]{
  def value: T
  def left: Tree[T]
  def right: Tree[T]
}
In [ ]:
object Branch {
  type HBranch[T, N0 <: Nat] = Branch[T] { type N = N0 }

  def apply[T](value0: T, left0: Tree[T], right0: Tree[T])
    (implicit ev: right0.N IsLeq left0.N, ev1: left0.N IsLeq Succ[right0.N]): HBranch[T, Succ[left0.N]] =
      new Branch[T] {
        val value = value0
        val left = left0
        val right = right0
        type N = Succ[left0.N]
      }
}
In [ ]:
val b1 = Branch(1, Leaf, Leaf) //# N = 1
val b2 = Branch(2, Leaf, Leaf) //# N = 1
val b3 = Branch(3, b1, Leaf) //# N = 2
val b4 = Branch(4, b2, b1) //# N = 2

Let us check that the heights are as expected:

In [ ]:
type_==[b1.N, _1]
In [ ]:
type_==[b3.N, _2]
In [ ]:
val b5 = Branch(5, b3, b4)

Look what happens as soon as we try to define unbalanced trees:

In [ ]:
Branch(5, b1, b3) //# left branch too small
In [ ]:
Branch(5, b4, Leaf) //# left branch too high
In [ ]:
Branch(5, b5, b1) //# left branch too high

References