Length-indexed lists in C++

Over the weekend, while lamenting the lack of dependent types in Haskell, I thought I'd see what C++ can offer in this space. After all, C++ templates allow value parameters as well as types, so I decided to try implementing a length-indexed list.

The list implementation itself is based on Bartosz Milewski's immutable list class (Functional Data Structures in C++: Lists) with some simplifications.

#include <cassert>
#include <functional>
#include <memory>

// Defined externally to List because it needs to be independent from the size parameter.
template <typename T>  
struct Item {  
    Item(T v, std::shared_ptr<const Item<T>> const & tail) :
        _val(v), _next(tail) {}
    T _val;
    std::shared_ptr<const Item<T>> _next;
};

// List type, parameterized over type and length.
template <typename T, int size>  
class List {  
public:  
    List() {}

    List(T v, List<T, size-1> const & tail) :
        _head(std::make_shared<Item<T>>(v, tail._head)) {}

    explicit List(std::shared_ptr<const Item<T>> items) : _head(items) {}

    bool isEmpty() const { return !_head; }

    T front() const {
        assert(!isEmpty());
        return _head->_val;
    }

    List<T, size-1> pop_front() const {
        assert(!isEmpty());
        return List<T, size-1>(_head->_next);
    }

    List<T, size+1> push_front(T v) const {
        return List<T, size+1>(v, *this);
    }

    // May be null.
    std::shared_ptr<const Item<T>> _head;
};

The default List constructor isn't intended to be used directly, because it would happily build an empty list with any size parameter. A "safe" constructor function, empty, should be used instead, which ensures that the size template parameter is set to 0.

// Helper for constructing empty lists.
template <typename T>  
List<T, 0> empty() {  
    return List<T, 0>();
}
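To see the length travelling along in the type, here is a small sketch that repeats the definitions above in compact form. The helper oneTwo is a hypothetical name used only for illustration; the static_asserts check that each push_front and pop_front changes the size parameter as expected.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>

template <typename T>
struct Item {
    Item(T v, std::shared_ptr<const Item<T>> const & tail) : _val(v), _next(tail) {}
    T _val;
    std::shared_ptr<const Item<T>> _next;
};

template <typename T, int size>
class List {
public:
    List() {}
    List(T v, List<T, size-1> const & tail)
        : _head(std::make_shared<Item<T>>(v, tail._head)) {}
    explicit List(std::shared_ptr<const Item<T>> items) : _head(items) {}
    bool isEmpty() const { return !_head; }
    T front() const { assert(!isEmpty()); return _head->_val; }
    List<T, size-1> pop_front() const {
        assert(!isEmpty());
        return List<T, size-1>(_head->_next);
    }
    List<T, size+1> push_front(T v) const { return List<T, size+1>(v, *this); }
    std::shared_ptr<const Item<T>> _head;
};

template <typename T>
List<T, 0> empty() { return List<T, 0>(); }

// Hypothetical helper (not from the original code): builds the list [1, 2].
inline List<int, 2> oneTwo() {
    return empty<int>().push_front(2).push_front(1);
}

// Each push_front bumps the size in the type, each pop_front lowers it.
static_assert(std::is_same<decltype(empty<int>().push_front(2)), List<int, 1>>::value,
              "pushing onto an empty list gives a list of size 1");
static_assert(std::is_same<decltype(oneTwo().pop_front()), List<int, 1>>::value,
              "popping from a list of size 2 gives a list of size 1");
```

Assigning the result of oneTwo() to a List<int, 3> would be rejected by the compiler, since List<int, 2> and List<int, 3> are unrelated types.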

Why is this useful? Quite simply, it helps ensure type safety. For example, if you're working on a neural network library, you'll probably need to calculate the dot product of two vectors whose length varies from layer to layer. You could of course use std::vector and check that the lengths are equal at runtime. But we're going to go one better and enforce equal lengths at compile-time!

The function we want to implement will look something like this:

template <typename T, int size>  
T dotProduct(List<T, size> a, List<T, size> b) {  
    ...
}

Because both parameters share the same size argument, trying to take the dot product of two lists with different sizes simply won't compile. Instead you will get a lovely cannot convert argument... error (or template parameter 'size' is ambiguous if you let the compiler infer the types).
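A failed compile is hard to show in a runnable snippet, so here is a sketch that turns the question "would this call compile?" into a compile-time boolean. The List here is an empty stand-in (only its type matters), dotProduct is declared but never defined, and CanDot and MakeVoid are hypothetical names introduced just for this check.

```cpp
#include <type_traits>
#include <utility>

// Stand-in for the list type; only the template parameters matter here.
template <typename T, int size>
struct List {};

// Declaration only: enough for the compiler to attempt the call, never executed.
template <typename T, int size>
T dotProduct(List<T, size> a, List<T, size> b);

// C++11-friendly equivalent of std::void_t.
template <typename...>
struct MakeVoid { using type = void; };

// Hypothetical trait: false unless dotProduct(A, B) is a well-formed call.
template <typename A, typename B, typename = void>
struct CanDot : std::false_type {};

template <typename A, typename B>
struct CanDot<A, B,
              typename MakeVoid<decltype(dotProduct(std::declval<A>(),
                                                    std::declval<B>()))>::type>
    : std::true_type {};

// Same length: accepted.
static_assert(CanDot<List<int, 3>, List<int, 3>>::value,
              "equal sizes should be accepted");
// Different lengths: deduction of 'size' conflicts, so the call is rejected.
static_assert(!CanDot<List<int, 3>, List<int, 2>>::value,
              "different sizes should be rejected");
```

The second static_assert passes precisely because the compiler cannot pick a single value for size, which is the same failure a direct call would report as an error.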

We're going to need two other functions first to implement this: zipWith and fold.

The zipWith function takes a function T(U, V) and two lists of equal length, one containing elements of type U, and the other elements of type V. A new list is generated from the results of the function applied to each pair of elements from the lists.

// (u -> v -> t) -> [u] -> [v] -> [t]
template <typename T, typename U, typename V, int size>  
List<T, size> zipWith(std::function<T(U, V)> f, List<U, size> us, List<V, size> vs) {  
    if (us.isEmpty()) {
        // Use constructor directly instead of empty().
        // The type checker can't infer List<T, size> ~ List<T, 0>.
        return List<T, size>();
    } else {
        return zipWith<T, U, V, size - 1>(
            f,
            us.pop_front(),
            vs.pop_front()).push_front(f(us.front(), vs.front()));
    }
}

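There is a slight problem with this implementation: it never finishes compiling. Even though it's clear at the term level that the recursion stops when the list is empty, the runtime isEmpty() check doesn't stop the compiler from instantiating both branches. Instantiating zipWith for size requires size - 1, which requires size - 2, and so on through List<T, -1> and beyond, until the compiler gives up at its template instantiation depth limit.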

To stop this, it's necessary to provide a specialization for size == 0. However, we only want to specialize on size, leaving T, U and V as abstract types. Unfortunately, C++ only allows partial specialization of class templates, not function templates. What we would like is something like this:

// Invalid C++: function templates cannot be partially specialized.
template <typename T, typename U, typename V>
List<T, 0> zipWith<T, U, V, 0>(std::function<T(U, V)>, List<U, 0> us, List<V, 0> vs) {
    return empty<T>();
}

But until that feature finds its way into the spec, we'll have to write a full specialization for each combination of types. For the dot product example, we want one for T, U, V == int:

template <>  
List<int, 0> zipWith(std::function<int(int, int)>, List<int, 0> us, List<int, 0> vs) {  
    return empty<int>();
}

One of the nice things about using specialization is that the isEmpty() check is no longer required in the main definition.
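Since class templates can be partially specialized, one possible way to avoid writing a specialization per element type is to move the recursion into a helper class and have zipWith forward to it. This is only a sketch of that idea, not what the code in this post does; ZipWithImpl and run are hypothetical names, and the list definitions are repeated in compact form so the snippet stands alone.

```cpp
#include <cassert>
#include <functional>
#include <memory>

template <typename T>
struct Item {
    Item(T v, std::shared_ptr<const Item<T>> const & tail) : _val(v), _next(tail) {}
    T _val;
    std::shared_ptr<const Item<T>> _next;
};

template <typename T, int size>
class List {
public:
    List() {}
    List(T v, List<T, size-1> const & tail)
        : _head(std::make_shared<Item<T>>(v, tail._head)) {}
    explicit List(std::shared_ptr<const Item<T>> items) : _head(items) {}
    bool isEmpty() const { return !_head; }
    T front() const { assert(!isEmpty()); return _head->_val; }
    List<T, size-1> pop_front() const {
        assert(!isEmpty());
        return List<T, size-1>(_head->_next);
    }
    List<T, size+1> push_front(T v) const { return List<T, size+1>(v, *this); }
    std::shared_ptr<const Item<T>> _head;
};

template <typename T>
List<T, 0> empty() { return List<T, 0>(); }

// Hypothetical helper class: the recursive case for any size.
template <typename T, typename U, typename V, int size>
struct ZipWithImpl {
    static List<T, size> run(std::function<T(U, V)> const & f,
                             List<U, size> const & us,
                             List<V, size> const & vs) {
        return ZipWithImpl<T, U, V, size - 1>::run(f, us.pop_front(), vs.pop_front())
            .push_front(f(us.front(), vs.front()));
    }
};

// Partial specialization on size == 0 only; T, U and V stay generic.
template <typename T, typename U, typename V>
struct ZipWithImpl<T, U, V, 0> {
    static List<T, 0> run(std::function<T(U, V)> const &,
                          List<U, 0> const &,
                          List<V, 0> const &) {
        return empty<T>();
    }
};

// Same signature as before; the body just forwards to the helper.
template <typename T, typename U, typename V, int size>
List<T, size> zipWith(std::function<T(U, V)> f, List<U, size> us, List<V, size> vs) {
    return ZipWithImpl<T, U, V, size>::run(f, us, vs);
}
```

The trade-off is an extra layer of indirection in exchange for a single base case that covers every combination of element types.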

The fold function takes a function, a starting value and a list to produce a summary value. This implementation requires a specialization for the empty list just like zipWith.

// (u -> t -> t) -> t -> [u] -> t
template <typename T, typename U, int size>  
T fold(std::function<T(U, T)> f, T acc, List<U, size> xs) {  
    T nextAcc = f(xs.front(), acc);
    return fold<T, U, size - 1>(f, nextAcc, xs.pop_front());
}

template <>  
int fold(std::function<int(int, int)> f, int acc, List<int, 0> xs) {  
    return acc;
}
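It's worth checking which order fold visits the elements in, because that matters for any function that isn't commutative. The sketch below reuses the definitions above (repeated compactly) and folds [1, 2, 3] with a function that appends each element as a decimal digit. The helper digitsInVisitOrder is a hypothetical name for illustration.

```cpp
#include <cassert>
#include <functional>
#include <memory>

template <typename T>
struct Item {
    Item(T v, std::shared_ptr<const Item<T>> const & tail) : _val(v), _next(tail) {}
    T _val;
    std::shared_ptr<const Item<T>> _next;
};

template <typename T, int size>
class List {
public:
    List() {}
    List(T v, List<T, size-1> const & tail)
        : _head(std::make_shared<Item<T>>(v, tail._head)) {}
    explicit List(std::shared_ptr<const Item<T>> items) : _head(items) {}
    bool isEmpty() const { return !_head; }
    T front() const { assert(!isEmpty()); return _head->_val; }
    List<T, size-1> pop_front() const {
        assert(!isEmpty());
        return List<T, size-1>(_head->_next);
    }
    List<T, size+1> push_front(T v) const { return List<T, size+1>(v, *this); }
    std::shared_ptr<const Item<T>> _head;
};

template <typename T>
List<T, 0> empty() { return List<T, 0>(); }

// Recursive case: combine the front element with the accumulator, then recurse.
template <typename T, typename U, int size>
T fold(std::function<T(U, T)> f, T acc, List<U, size> xs) {
    T nextAcc = f(xs.front(), acc);
    return fold<T, U, size - 1>(f, nextAcc, xs.pop_front());
}

// Base case for int lists: nothing left to combine.
template <>
int fold<int, int, 0>(std::function<int(int, int)>, int acc, List<int, 0>) {
    return acc;
}

// Hypothetical helper: folds [1, 2, 3] with f(x, acc) = acc * 10 + x.
inline int digitsInVisitOrder() {
    List<int, 3> xs = empty<int>().push_front(3).push_front(2).push_front(1);
    return fold<int, int, 3>([](int x, int acc) { return acc * 10 + x; }, 0, xs);
}
```

Here digitsInVisitOrder() returns 123: the accumulator goes 0, 1, 12, 123, so elements are consumed from the front of the list first.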

So finally we can implement our typesafe dotProduct:

template <typename T, int size>  
T dotProduct(List<T, size> a, List<T, size> b) {  
    // Sum the element-wise products.
    return fold<T, T, size>(
        [](T x, T y) { return x + y; },
        0,
        zipWith<T, T, T, size>(
            [](T x, T y) { return x * y; },
            a,
            b));
}
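Putting the pieces together, here is a compact, self-contained sketch of the whole thing for int lists. It follows the definitions above, with the template arguments of the two specializations written out explicitly; only int has base cases here, so other element types would need their own.

```cpp
#include <cassert>
#include <functional>
#include <memory>

template <typename T>
struct Item {
    Item(T v, std::shared_ptr<const Item<T>> const & tail) : _val(v), _next(tail) {}
    T _val;
    std::shared_ptr<const Item<T>> _next;
};

template <typename T, int size>
class List {
public:
    List() {}
    List(T v, List<T, size-1> const & tail)
        : _head(std::make_shared<Item<T>>(v, tail._head)) {}
    explicit List(std::shared_ptr<const Item<T>> items) : _head(items) {}
    bool isEmpty() const { return !_head; }
    T front() const { assert(!isEmpty()); return _head->_val; }
    List<T, size-1> pop_front() const {
        assert(!isEmpty());
        return List<T, size-1>(_head->_next);
    }
    List<T, size+1> push_front(T v) const { return List<T, size+1>(v, *this); }
    std::shared_ptr<const Item<T>> _head;
};

template <typename T>
List<T, 0> empty() { return List<T, 0>(); }

// Recursive case; the size == 0 specialization below ends the recursion.
template <typename T, typename U, typename V, int size>
List<T, size> zipWith(std::function<T(U, V)> f, List<U, size> us, List<V, size> vs) {
    return zipWith<T, U, V, size - 1>(f, us.pop_front(), vs.pop_front())
        .push_front(f(us.front(), vs.front()));
}

template <>
List<int, 0> zipWith<int, int, int, 0>(std::function<int(int, int)>,
                                       List<int, 0>, List<int, 0>) {
    return empty<int>();
}

template <typename T, typename U, int size>
T fold(std::function<T(U, T)> f, T acc, List<U, size> xs) {
    return fold<T, U, size - 1>(f, f(xs.front(), acc), xs.pop_front());
}

template <>
int fold<int, int, 0>(std::function<int(int, int)>, int acc, List<int, 0>) {
    return acc;
}

// Sum of the element-wise products; both lists must share the same size.
template <typename T, int size>
T dotProduct(List<T, size> a, List<T, size> b) {
    return fold<T, T, size>(
        [](T x, T y) { return x + y; },
        0,
        zipWith<T, T, T, size>([](T x, T y) { return x * y; }, a, b));
}
```

With a = [1, 2, 3] and b = [4, 5, 6], dotProduct(a, b) gives 1*4 + 2*5 + 3*6 = 32, while passing a List<int, 3> and a List<int, 2> is rejected at compile time.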

The full source code for this example can be found in this Gist.