Motivation

In CAD software, we often want to express constraints, e.g. that a point lies on a particular axis or sits at a fixed distance from another point.

Drag points in the systems below and see what constraints are obeyed:

In general, constraints can be expressed as systems of equations.

Given a system of equations, e.g.

$$ (2x + 3y) \times (x - y) = 2 $$ $$ 3x + y = 5 $$

How would we go about solving it?

One option is to use a computer algebra system like Sage, which performs symbolic manipulation to give closed-form solutions:

sage: solve([(2*x + 3*y) * (x - y) == 2, 3*x + y == 5], x,y)
[[x == -1/56*sqrt(401) + 95/56, y ==  3/56*sqrt(401) - 5/56],
 [x ==  1/56*sqrt(401) + 95/56, y == -3/56*sqrt(401) - 5/56]]

However, symbolic solving scales poorly to large systems with many equations and variables. Instead, we'll approach the problem numerically.

Move your mouse around the graph below. The arrow at your cursor points to a nearby solution: see if you can find it.

There are two possible solutions to this system of equations, which are equal to the closed-form solutions found above.

You probably used the strategy of following the arrow "downhill" towards a solution. We will formalize this technique (known as gradient descent) as a numeric solver for arbitrary systems of equations.

Representing equations

Intuition

To start, we need some way of representing sets of variables and systems of equations.

We'd like a way to evaluate $f(x,y)$ at a given point $(x,y)$. Since we're using gradient descent, we'd also like to know all of its partial derivatives $\partial f/\partial x$ and $\partial f/\partial y$ at that point.

We'll use a technique known as automatic differentiation, which involves keeping track of both the result and its derivatives at every step of a calculation.

Consider the equation $f(x,y) = (x + 5) \times (x + y)$. If we want to evaluate it at $x = 2$, $y = 3$, we'll end up with the following computation tree:

Computation tree

In this tree, derivatives are computed with the sum and product rules.
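Tracing the tree by hand: the node $x + 5$ evaluates to $7$ with derivative $1$ with respect to $x$; the node $x + y$ evaluates to $5$ with derivative $1$ with respect to both $x$ and $y$. The product rule then combines them:

$$ f = 7 \times 5 = 35, \qquad \frac{\partial f}{\partial x} = 1 \times 5 + 7 \times 1 = 12, \qquad \frac{\partial f}{\partial y} = 7 \times 1 = 7 $$

These are the same values we'll compute in code later on.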

Now that we know what we want, let's work out an implementation. As in most of my recent recreational coding, we'll be using Haskell.

Implementation

We'll represent variables (with associated values) as a Data.Map.

import qualified Data.Map as Map

type Vars a = Map.Map a Double

The set of variables-and-values $x=1$, $y=3$ is represented by Map.fromList [('x', 1), ('y', 3)].

An Equation has a single function eval. When called, eval returns a value and first derivatives with respect to each variable:

newtype Equation a = Equation {eval :: Vars a -> (Double, Vars a)}

Note the function-fu here: an Equation only returns a value when activated with eval and a map of variable IDs to values.

The simplest equation is a single variable. When called, its value is found in the Map and its derivative is set to 1:

var :: Ord a => a -> Equation a
var tag = Equation $ \vars -> (Map.findWithDefault 0 tag vars,
                               Map.singleton tag 1)

We can test this out with eval:

λ> eval (var 'x') $ Map.fromList [('x', 3.0)]
(3.0, fromList [('x',1.0)])

As expected, we get back the same value and a partial derivative of 1.

The system is polymorphic in tag types. Above, we used a Char, but we could also use Strings (or anything of typeclass Ord):

λ> eval (var "distance") $ Map.fromList [("distance", 3.0)]
(3.0, fromList [("distance",1.0)])

We'd like to do math on equations, e.g.:

λ> let z = var 'x' + var 'y'

How can we implement this?

The function (+) has type Num a => a -> a -> a, which means we'll need to make our Equation type an instance of Num. We'll also make it an instance of Fractional; between these two typeclasses, we'll be able to do all kinds of arithmetic on our equations.

As discussed above, we'll also be tracking partial derivatives with respect to each variable. Look closely and you'll spot implementations of the sum, product, and quotient rules for differentiation:

instance Ord a => Num (Equation a) where
    a + b = Equation $ \vars ->
        let (ra, das) = eval a vars
            (rb, dbs) = eval b vars
         in (ra + rb, Map.unionWith (+) das dbs)
    a * b = Equation $ \vars ->
        let (ra, das) = eval a vars
            (rb, dbs) = eval b vars
         in (ra * rb, Map.unionWith (+) (Map.map (*rb) das)
                                        (Map.map (*ra) dbs))
    abs a = Equation $ \vars ->
        let (r, ds) = eval a vars
         in (abs r, Map.map (* signum r) ds)
    negate a = Equation $ \vars ->
        let (r, ds) = eval a vars
         in (negate r, Map.map negate ds)
    signum a = Equation $ \vars ->
        let (r, _) = eval a vars
         in (signum r, Map.empty)
    fromInteger a = Equation $ const (fromInteger a, Map.empty)

instance Ord a => Fractional (Equation a) where
    a / b = Equation $ \vars ->
        let (ra, das) = eval a vars
            (rb, dbs) = eval b vars
         in (ra / rb, Map.map (/rb**2) $
                      Map.unionWith (+)
                          (Map.map (*rb) das)
                          (Map.map (negate . (*ra)) dbs))
    fromRational a = Equation $ const (fromRational a, Map.empty)

With these typeclass instances defined, we get arithmetic!
Let's check our math from the example above:

λ> let x = var 'x'
λ> let y = var 'y'
λ> eval ((x + 5)*(x + y)) $ Map.fromList [('x', 2), ('y', 3)]
(35.0,fromList [('x',12.0),('y',7.0)])

Gradient descent

Now, let's focus on the mechanics of gradient descent.

We start at some point $[x_0,y_0,z_0,...]$ in $n$-dimensional space (where $n$ is the number of variables in the system). Our goal is to find $[x,y,z,...]$ such that a particular function $f(x,y,z,...)$ is zero.


We'll move through $n$-dimensional space until a terminating condition is met. The three terminating conditions are as follows:

The cost function is below some threshold

$$ f(x,y,z,...) < \epsilon$$

All partial derivatives are close to zero (indicating a local minimum)

$$ \left|\frac{\partial f}{\partial x}\right| < \epsilon \text{ and } \left|\frac{\partial f}{\partial y}\right| < \epsilon \text{ and } \left|\frac{\partial f}{\partial z}\right| < \epsilon \text{ and ...} $$

The solver fails to converge

$$ \left| f(x_n,y_n,z_n,...) - f(x_{n+1},y_{n+1},z_{n+1},...) \right| < \epsilon$$


The step moves in the direction opposite the system's gradient, i.e. $$ \nabla f(x,y,z,...) = \left[\frac{\partial f}{\partial x}\vec{x}, \frac{\partial f}{\partial y}\vec{y}, \frac{\partial f}{\partial z}\vec{z},...\right]$$ where $\vec{x}$, $\vec{y}$, $\vec{z}$ are unit vectors along each dimension.
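Each step moves some distance $\gamma$ against the gradient: $$ [x_{n+1}, y_{n+1}, z_{n+1}, ...] = [x_n, y_n, z_n, ...] - \gamma \nabla f(x_n, y_n, z_n, ...) $$ In the implementation below, $\gamma$ is chosen with a backtracking line search.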

The logic of our step function is as follows: evaluate the equation and its partial derivatives at the current point; stop if any terminating condition is met; otherwise, start with $\gamma = 1$ and halve it until the step decreases the cost by a sufficient amount.

Here's our implementation:

epsilon :: Double
epsilon = 1e-12

-- Solves a single step of gradient descent,
-- using a backtracking line search.
--
-- Returns Nothing if the descent has converged,
-- otherwise Just nextPoint
step :: Ord a => Equation a -> Vars a -> Maybe (Vars a)
step eqn vars =
    if r < epsilon || all ((< epsilon) . abs) ds || converged
    then Nothing
    else Just next
    where (r, ds) = eval eqn vars
          (next, converged) = backtrack 1
          threshold = 0.5 * (sum $ Map.map (^2) ds)
          backtrack stepSize =
              if r - r' >= stepSize * threshold
              then (vars', abs (r - r') < epsilon)
              else backtrack (stepSize * 0.5)
              where vars' = Map.unionWith (-) vars $
                            Map.map (*stepSize) ds
                    r' = fst (eval eqn vars')

We can repeat this over and over again until we get Nothing back, indicating that the solver has converged. Because step returns a Maybe, we use the reversed bind operator (=<<) to chain function calls:

import Data.Maybe (fromJust, isJust)

-- Find a local minimum from an Equation and a starting point
minimize :: Ord a => Equation a -> Vars a -> Vars a
minimize eqn vars =
    fromJust $ last $ takeWhile isJust
             $ iterate (step eqn =<<) (return vars)
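As a quick sanity check, minimizing the parabola $(x - 3)^2$ from a starting point of $x = 0$ converges to its minimum:

λ> minimize ((var 'x' - 3)^2) $ Map.fromList [('x', 0)]
fromList [('x',3.0)]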

The grid below contains the same function you tried to solve earlier, but shows the solver's path instead of the local gradient:

Notice that the solver converges to different solutions depending on its starting point. This is actually desirable behavior for a CAD constraint system: given multiple valid solutions, it should pick the one that's closest to the existing state of the drawing.

Solving systems of equations

We've been glossing over how the solver actually solves a system of equations.

We've written a tool that minimizes a single equation, then used it to satisfy multiple constraints – how does that work?

The answer: sum-of-squares cost functions.

In our example, we're trying to solve $$ (2x + 3y) \times (x - y) = 2 $$ $$ 3x + y = 5 $$

First, we rephrase them as cost functions: $$ \left((2x + 3y) \times (x - y)\right) - 2 $$ $$ (3x + y) - 5 $$

For ease-of-use, we define an infix operator ===:

infixl 5 ===
(===) :: Ord a => Equation a -> Equation a -> Equation a
(===) = (-)

This allows us to construct equations that look like equality expressions but are in fact cost functions, e.g. var "x" === 5 is actually var "x" - 5.
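For example, evaluating var 'x' === 5 at $x = 2$ returns the residual $2 - 5 = -3$ (and its derivative):

λ> eval (var 'x' === 5) $ Map.fromList [('x', 2)]
(-3.0,fromList [('x',1.0)])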

Now that we can express cost functions, let's combine a set of them by summing their squares. The result is a function that's always $\geq 0$; it's only equal to zero when all of the constraints are met: $$ \left(\left((2x + 3y) \times (x - y)\right) - 2\right)^2 + \left((3x + y) - 5\right)^2$$

This is the function that we put into our minimizer.
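Concretely, here's that cost function built out of the operators we've already defined:

λ> let x = var 'x'
λ> let y = var 'y'
λ> let cost = ((2*x + 3*y) * (x - y) === 2)^2 + (3*x + y === 5)^2

Minimizing cost should converge to whichever of the two solutions is near the starting point.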

The simplest solver is thus

\eqns vars -> minimize (sum $ map (^2) eqns) vars

However, this solver is not robust against overconstrained systems:
if you ask it to solve var 'x' === 3 and var 'x' === 5, it will give you an answer somewhere in between and fail to satisfy both constraints.
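We can reproduce this in the REPL with the simple solver above; starting from $x = 2$, it settles exactly halfway between the two targets:

λ> let x = var 'x'
λ> minimize ((x === 3)^2 + (x === 5)^2) $ Map.fromList [('x', 2)]
fromList [('x',4.0)]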

We'd like the solver to satisfy as many constraints as possible; in the example above, it should pick either $x=3$ or $x=5$.

Here's one implementation strategy:

import Data.List (elemIndex)

-- Returns a list of booleans indicating constraint
-- satisfaction and a map of resulting variable values.
solveSystem :: Ord a => [Equation a] -> Vars a -> ([Bool], Vars a)
solveSystem eqns vars =
    if and satisfied
    then (satisfied, vars')
    else -- If not all constraints are satisfied, drop
         -- an unsatisfied constraint and recurse
        let index = fromJust $ elemIndex False satisfied
            (front, back) = splitAt index eqns
            (satisfied', out) =
                solveSystem (front ++ (drop 1 back)) vars'
            (a, b) = splitAt index satisfied'
        in (a ++ False:b, out)
    where vars' = minimize (sum $ map (^2) eqns) vars
          scores = map (\eqn -> (fst $ eval eqn vars')^2) eqns
          satisfied = map (< sqrt epsilon) scores

If a constraint is not satisfied, we throw it out and try again. The function returns both a solution and a list of which constraints were satisfied:

λ> let x = var 'x'
λ> solveSystem [x === 5, x === 3] $ Map.fromList [('x', 2)]
([False,True],fromList [('x',3.0)])

Constraint solving

With all of this explained, we can now understand the earlier interactive examples. Each one defines a set of constraints-as-equations then uses gradient descent to minimize the total sum-of-squares cost function.

We apply an extra constraint to the dragged point, setting it equal to the cursor's position. This constraint is the first to be discarded if infeasible.
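As a sketch (with hypothetical names px, py for the dragged point's variables – the real demo code differs), the per-drag solve might look like:

dragSolve :: [Equation String] -> Vars String
          -> Double -> Double -> ([Bool], Vars String)
dragSolve constraints previous cursorX cursorY =
    solveSystem (drag ++ constraints) previous
  where
    -- The cursor constraints come first in the list, so
    -- solveSystem discards them first if they're unsatisfiable
    drag = [ var "px" === realToFrac cursorX
           , var "py" === realToFrac cursorY ]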

Here's one more for the road – thanks for following along!


$$ a_x^2 + a_y^2 = 1 $$ $$ (a_x - b_x)^2 + (a_y - b_y)^2 = 2 $$ $$ b_y = c_y = 0 $$ $$ c_x - b_x = 1 $$
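Expressed with our machinery, that system might look like this sketch (the variable names are invented for illustration):

demo :: [Equation String]
demo = [ ax^2 + ay^2 === 1                -- a lies on the unit circle
       , (ax - bx)^2 + (ay - by)^2 === 2  -- squared distance from a to b is 2
       , by === 0                         -- b lies on the x axis
       , cy === 0                         -- c lies on the x axis
       , cx - bx === 1 ]                  -- c is one unit right of b
  where [ax, ay, bx, by, cx, cy] =
            map var ["ax", "ay", "bx", "by", "cx", "cy"]

Passing demo and the current point positions to solveSystem yields the new layout.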


About the visualizations

The interactive visualizations are running Haskell code that was cross-compiled into Javascript with Haste. I looked at both Haste and GHCJS; the latter didn't have a good way to make Haskell functions available from Javascript, so it wasn't a suitable choice.

I've been impressed by Haste: it worked out-of-the-box and exporting functions is a single call. However, there's a significant performance penalty: all of these simulations are instantaneous on the desktop, but you'll see a bit of lag when moving points around the diagrams.

The graphics are made with d3.js and lots of amateurish (actual) Javascript.

Fun fact: the solver is 30% less code than the graphics
(146 lines vs. 212, as reported by cloc).


Thanks to Sam Calisch, Neil Gershenfeld, and Tikhon Jelvis for feedback on a draft of this article; props to Chaoya Li for catching a bug in my implementation of division.