Tuesday, April 13, 2010

My solution to the Matrix-Chain-Order problem (Chapter 15 of CLRS)

For practice, I tried solving the "Matrix-Chain-Order" problem that appears in section 15.2 of Introduction to Algorithms (that's the chapter on Dynamic Programming). In short, one gets a list of (compatible) matrices, A1, A2, A3, ..., An, and the problem is to find the optimal order in which to multiply them, i.e. the parenthesization that needs the fewest scalar multiplications.
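To see why the order matters at all, here is a minimal sketch that compares the two ways of multiplying three matrices. It assumes the usual cost model from the book (a p x q matrix times a q x r matrix takes p*q*r scalar multiplications). The class name, the method names and the sample dimensions 10 x 100, 100 x 5 and 5 x 50 are made up for illustration and are not part of the solution further down.

```java
// Hypothetical helper, for illustration only.
class ParenthesizationDemo {
    // Scalar multiplications needed for a (p x q) times (q x r) product.
    static long mulCost(long p, long q, long r) {
        return p * q * r;
    }

    // Cost of (A1 * A2) * A3, where A1 is d0 x d1, A2 is d1 x d2, A3 is d2 x d3.
    static long leftFirst(long d0, long d1, long d2, long d3) {
        return mulCost(d0, d1, d2) + mulCost(d0, d2, d3);
    }

    // Cost of A1 * (A2 * A3) for the same dimensions.
    static long rightFirst(long d0, long d1, long d2, long d3) {
        return mulCost(d1, d2, d3) + mulCost(d0, d1, d3);
    }

    public static void main(String[] args) {
        System.out.println("(A1 * A2) * A3: " + leftFirst(10, 100, 5, 50));
        System.out.println("A1 * (A2 * A3): " + rightFirst(10, 100, 5, 50));
    }
}
```

With these sample dimensions the first order costs 7500 scalar multiplications and the second 75000, so the choice of parenthesization is worth a factor of ten.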


I liked my solution, so I'm posting this just for the record. This should be helpful if one is reading that section and is looking for a Java implementation of that problem (in order to compare with their own solution of course :)). To make it tougher, I tried writing it first on paper, then typing it into my IDE, to see how many errors had slipped into the paper version. There were two: a single off-by-one error in the inner loop (damn!), and I forgot some components of the cost of a particular multiplication expression.


The algorithm creates optimal multiplication expressions of increasing width. Initially width == 1, and we just have the leaves, the matrices themselves. Next, we find the optimal expressions of width == 2, but there is only one way to multiply two matrices, so nothing special happens. After that, things get more interesting, since we get to create various possible trees for each sequence of matrices, and retain the best. (The code counts from zero: its width variable is the distance between the first and last index of a sequence, so a single matrix has width 0 there and a pair has width 1.)
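The same bottom-up idea can be sketched with plain arrays, which is closer to how the book presents it. This is not the implementation from this post, just a compact cost-only version under made-up names (ChainCostSketch, minCost and the dims array are assumptions for the example): dims has one more entry than there are matrices, and matrix i is dims[i] x dims[i + 1].

```java
// Cost-only sketch of the bottom-up idea; all names are hypothetical.
class ChainCostSketch {
    // There are dims.length - 1 matrices; matrix i is dims[i] x dims[i + 1].
    static long minCost(int[] dims) {
        int n = dims.length - 1;
        if (n <= 0) return 0;
        // best[i][j] = cheapest cost of multiplying matrices i..j (inclusive).
        // The diagonal stays 0: a single matrix needs no multiplication.
        long[][] best = new long[n][n];
        for (int width = 1; width < n; width++) {        // width = j - i
            for (int i = 0; i + width < n; i++) {
                int j = i + width;
                best[i][j] = Long.MAX_VALUE;
                // Split into (i..cut) * (cut+1..j) and keep the cheapest split.
                for (int cut = i; cut < j; cut++) {
                    long cost = best[i][cut] + best[cut + 1][j]
                            + (long) dims[i] * dims[cut + 1] * dims[j + 1];
                    if (cost < best[i][j]) best[i][j] = cost;
                }
            }
        }
        return best[0][n - 1];
    }

    public static void main(String[] args) {
        // Six matrices: 30x35, 35x15, 15x5, 5x10, 10x20, 20x25.
        System.out.println(minCost(new int[] {30, 35, 15, 5, 10, 20, 25}));
    }
}
```

Unlike the object-based version in this post, the sketch only returns the minimal cost and does not remember which split produced it, so it cannot print the expression.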


Amazingly, it was exactly one year ago (to the day!) that I posted a similar exercise in tree building (enumerating binary trees, also via dynamic programming): http://code-o-matic.blogspot.com/2009/04/wonderful-programming-exercise.html
I think I need more diversity :)


Anyway. The main method is this:



    public static void main(String[] args) {
        Op op = matrixChainOrder(Arrays.asList(
                new Matrix(30, 35),
                new Matrix(35, 15),
                new Matrix(15, 5),
                new Matrix(5, 10),
                new Matrix(10, 20),
                new Matrix(20, 25)));
        System.out.println(op);
        System.out.println("Cost: " + op.cost());
    }

And it prints:

(([30 X 35] * ([35 X 15] * [15 X 5])) * (([5 X 10] * [10 X 20]) * [20 X 25]))
Cost: 15125

(This mimics the example and solution of the book, but it also builds the multiplication expression itself, which is easy to pretty-print and ready to be computed.)
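The printed cost can be checked by hand by adding up the five multiplications in the expression above. A small sketch of that arithmetic (the class and method names are made up for this check):

```java
// Hand check of the printed expression; names are hypothetical.
class PrintedCostCheck {
    // Scalar multiplications for a (rows x inner) times (inner x columns) product.
    static int mul(int rows, int inner, int columns) {
        return rows * inner * columns;
    }

    static int printedExpressionCost() {
        int bc  = mul(35, 15, 5);   // [35 X 15] * [15 X 5]   gives 35 X 5
        int abc = mul(30, 35, 5);   // [30 X 35] * (35 X 5)   gives 30 X 5
        int de  = mul(5, 10, 20);   // [5 X 10] * [10 X 20]   gives 5 X 20
        int def = mul(5, 20, 25);   // (5 X 20) * [20 X 25]   gives 5 X 25
        int top = mul(30, 5, 25);   // (30 X 5) * (5 X 25)    gives 30 X 25
        return bc + abc + de + def + top;
    }

    public static void main(String[] args) {
        System.out.println("Cost: " + printedExpressionCost());
    }
}
```

The five terms are 2625, 5250, 1000, 2500 and 3750, which add up to the 15125 printed above.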

The full code:

import java.util.*;

public class Matrixes {
    public static void main(String[] args) {
        Op op = matrixChainOrder(Arrays.asList(
                new Matrix(30, 35),
                new Matrix(35, 15),
                new Matrix(15, 5),
                new Matrix(5, 10),
                new Matrix(10, 20),
                new Matrix(20, 25)));
        System.out.println(op);
        System.out.println("Cost: " + op.cost());
    }

    static Op matrixChainOrder(List<Matrix> matrixes) {
        // Maps each index interval [begin..end] to its cheapest expression.
        Map<Interval, Op> optima = new HashMap<Interval, Op>();
        for (int i = 0; i < matrixes.size(); i++) {
            optima.put(new Interval(i, i), new Leaf(matrixes.get(i)));
        }
        // width is end - begin, so each interval covers width + 1 matrixes.
        for (int width = 1; width < matrixes.size(); width++) {
            for (int offset = 0; offset < matrixes.size() - width; offset++) {
                Op best = DUMMY;
                // Try every split: [offset..offset+cut] * [offset+cut+1..offset+width].
                for (int cut = 0; cut < width; cut++) {
                    Op left = optima.get(new Interval(offset, offset + cut));
                    Op right = optima.get(new Interval(offset + cut + 1, offset + width));
                    Op mul = new Mul(left, right);
                    if (mul.cost() < best.cost()) {
                        best = mul;
                    }
                }
                optima.put(new Interval(offset, offset + width), best);
            }
        }
        return optima.get(new Interval(0, matrixes.size() - 1));
    }
    private static final Op DUMMY = new Op() {
        public int cost() { return Integer.MAX_VALUE; }
        public Matrix compute() { throw new AssertionError(); }
        public int rows() { throw new AssertionError(); }
        public int columns() { throw new AssertionError(); }
    };

    private static class Interval {
        final int begin;
        final int end;
        Interval(int begin, int end) { this.begin = begin; this.end = end; }
        public boolean equals(Object o) {
            if (!(o instanceof Interval)) return false;
            Interval that = (Interval)o;
            return this.begin == that.begin && this.end == that.end;
        }
        public int hashCode() { return 31 * (31 * 17 + begin) + end; }
        public String toString() { return "[" + begin + ".." + end + "]"; }
    }
}

class Matrix {
    final int rows; final int columns;
    Matrix(int rows, int columns) { this.rows = rows; this.columns = columns; }
    int rows() { return rows; }
    int columns() { return columns; }
    public String toString() { return "[" + rows + " X " + columns + "]"; }
}

interface Op {
    int cost();
    Matrix compute();
    int rows();
    int columns();
}

class Leaf implements Op {
    final Matrix matrix;
    Leaf(Matrix matrix) { this.matrix = matrix; }
    public int cost() { return 0; }
    public Matrix compute() { return matrix; }
    public int rows() { return matrix.rows(); }
    public int columns() { return matrix.columns(); }
    public String toString() { return matrix.toString(); }
}

class Mul implements Op {
    final Op left;
    final Op right;
    Mul(Op left, Op right) {
        this.left = left; this.right = right;
    }
    public int rows() {
        return left.rows();
    }
    public int columns() {
        return right.columns();
    }
    public int cost() {
        return left.rows() * left.columns() * right.columns() + left.cost() + right.cost();
    }
    public Matrix compute() { throw new UnsupportedOperationException("later"); }
    public String toString() { return "(" + left + " * " + right + ")"; }
}