Automatic Differentiation using Operator Overloading

Automatic differentiation, derivatives

Posted on November 6, 2016  

Almost all of the libraries for creating neural networks (TensorFlow, Theano, Torch, etc.) use automatic differentiation (AD) in one way or another. It has applications in other areas of mathematics and computing as well, since it is a clever and efficient way to calculate gradients with little effort. It works by first building a computational graph of the operations and then traversing it in either forward mode or reverse mode. Let us see how to implement both modes using operator overloading to calculate first-order partial derivatives. I highly recommend reading Colah's blog post on computational graphs first: it explains them very well, and this post focuses on the implementation side. This may not be the best-performing code for AD, but I think it is the simplest way to get your head around the concept. The example function we are considering here is:

$$f(a, b) = (a + b) * (b + 1)$$

Writing c = a + b, d = b + 1 and e = c * d (so that e = f(a, b)), we can use basic differential calculus to calculate the partial derivatives of the function with respect to a and b by applying the sum rule and the product rule:

$$\begin{eqnarray} \frac{\partial e}{\partial a} &=& c*\frac{\partial d}{\partial a} + \frac{\partial c}{\partial a}*d \nonumber \\ &=& (a + b)*\left(\frac{\partial b}{\partial a} + \frac{\partial 1}{\partial a}\right) + \left(\frac{\partial a}{\partial a} + \frac{\partial b}{\partial a}\right)*(b + 1) \nonumber \\ &=& (a + b)*(0 + 0) + (1 + 0)*(b + 1) \nonumber \\ &=& b + 1 \nonumber \end{eqnarray}$$

and similarly,

$$\begin{eqnarray} \frac{\partial e}{\partial b} &=& c*\frac{\partial d}{\partial b} + \frac{\partial c}{\partial b}*d \nonumber \\ &=& (a + b)*\left(\frac{\partial b}{\partial b} + \frac{\partial 1}{\partial b}\right) + \left(\frac{\partial a}{\partial b} + \frac{\partial b}{\partial b}\right)*(b + 1) \nonumber \\ &=& (a + b)*(1 + 0) + (0 + 1)*(b + 1) \nonumber \\ &=& (a + b) + (b + 1) \nonumber \\ &=& a + 2b + 1 \nonumber \end{eqnarray}$$
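As a quick sanity check on these hand-derived results, the sketch below compares them with central finite differences at a = 2, b = 1, the values used in the code later in this post. The names `f_plain`, `dfda`, `dfdb` and `central_diff` are illustrative helpers, not part of the implementation that follows, and finite differences are only a numerical approximation, not AD.

```julia
# Hypothetical helpers for checking the hand derivation numerically.
f_plain(a, b) = (a + b) * (b + 1)

# Hand-derived partial derivatives from above.
dfda(a, b) = b + 1
dfdb(a, b) = a + 2b + 1

# Central finite difference: (g(x + h) - g(x - h)) / 2h
central_diff(g, x; h = 1e-6) = (g(x + h) - g(x - h)) / (2h)

a0, b0 = 2.0, 1.0
fd_a = central_diff(t -> f_plain(t, b0), a0)
fd_b = central_diff(t -> f_plain(a0, t), b0)

println("de/da: analytic = ", dfda(a0, b0), ", finite difference = ", fd_a)
println("de/db: analytic = ", dfdb(a0, b0), ", finite difference = ", fd_b)
```

Both estimates should agree with the analytic values 2 and 5 up to rounding error.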

To compute the derivative of a function programmatically, we can operate on its computational graph in one of two ways: forward mode or reverse mode. Both approaches make use of the chain rule.

In [1]:
# bring + and * into scope so we can add methods for our own types
import Base: +, *

Forward Mode

Forward mode is very similar to the calculation we did above. We pick the independent variable with respect to which we want the partial derivative, set its derivative with respect to itself to 1 (and that of every other input to 0), and then move forward through the graph, computing the derivative of each sub-expression from the derivatives of its inputs until we reach the output node.

In a pen-and-paper calculation, one can do so by repeatedly substituting the derivative of the inner functions in the chain rule:

$${\displaystyle {\frac {\partial y}{\partial x}}={\frac {\partial y}{\partial w_{1}}}{\frac {\partial w_{1}}{\partial x}}={\frac {\partial y}{\partial w_{1}}}\left({\frac {\partial w_{1}}{\partial w_{2}}}{\frac {\partial w_{2}}{\partial x}}\right)={\frac {\partial y}{\partial w_{1}}}\left({\frac {\partial w_{1}}{\partial w_{2}}}\left({\frac {\partial w_{2}}{\partial w_{3}}}{\frac {\partial w_{3}}{\partial x}}\right)\right)=\cdots }$$

- Wikipedia

We can think of the graph as being built the way a programming language would evaluate the expression, following operator precedence (the BODMAS rule). In terms of simple operations, the function above breaks down into:

$$c = a + b$$
$$d = b + 1$$
$$e = c * d$$

Hence, the steps to calculate the partial derivative of the function with respect to a are:

$$\begin{array}{l|l|l} \text{value} & \text{derivative} & \text{node} \\ \hline a = a & \frac{\partial a}{\partial a} = 1 & \text{node 1} \\ b = b & \frac{\partial b}{\partial a} = 0 & \text{node 2} \\ c = a + b & \frac{\partial c}{\partial a} = \frac{\partial a}{\partial a} + \frac{\partial b}{\partial a} & \text{node 3} \Leftarrow \text{node 1} + \text{node 2} \\ 1 & \frac{\partial 1}{\partial a} = 0 & \text{node 4} \\ d = b + 1 & \frac{\partial d}{\partial a} = \frac{\partial b}{\partial a} + \frac{\partial 1}{\partial a} & \text{node 5} \Leftarrow \text{node 2} + \text{node 4} \\ e = c * d & \frac{\partial e}{\partial a} = c*\frac{\partial d}{\partial a} + \frac{\partial c}{\partial a}*d & \text{node 6} \Leftarrow \text{node 3} * \text{node 5} \end{array}$$
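To make the table concrete, here is the same forward pass evaluated at a = 2, b = 1 (the values used in the code later in this post), with the derivative column seeded for a:

$$\begin{array}{l|l|l} \text{value} & \text{derivative w.r.t. } a & \text{node} \\ \hline a = 2 & 1 & \text{node 1} \\ b = 1 & 0 & \text{node 2} \\ c = 2 + 1 = 3 & 1 + 0 = 1 & \text{node 3} \\ 1 & 0 & \text{node 4} \\ d = 1 + 1 = 2 & 0 + 0 = 0 & \text{node 5} \\ e = 3 * 2 = 6 & 3 * 0 + 1 * 2 = 2 & \text{node 6} \end{array}$$

These are the value 6.0 and the derivative 2.0 that the forward-mode code reports for 'a'.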

To simulate these steps, we define a type ADFwd whose instances also represent the nodes of the computational graph.

In [2]:
# type to store the float value for a variable and
# the derivative with respect to the chosen variable at that value.
# (`struct` replaces the pre-1.0 `type` keyword; ADFwd is never mutated.)
struct ADFwd
    value::Float64 # say, to store c
    derivative::Float64 # say, to store dc/da

    ADFwd(val::Float64) = new(val, 0.0)
    ADFwd(val::Float64, der::Float64) = new(val, der)
end

Next, we define the operations on this type together with the differentiation rule each one follows. Operator overloading lets the ordinary + and * operators work directly on ADFwd values.

In [3]:
# sum rule
function adf_add(x::ADFwd, y::ADFwd)
    return ADFwd(x.value + y.value, x.derivative + y.derivative)
end
+(x::ADFwd, y::ADFwd) = adf_add(x, y)

# product rule
function adf_mul(x::ADFwd, y::ADFwd)
    return ADFwd(x.value * y.value, y.value * x.derivative + x.value * y.derivative)
end
*(x::ADFwd, y::ADFwd) = adf_mul(x, y)
Out[3]:
* (generic function with 150 methods)
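The `+(x::ADFwd, y::ADFwd) = ...` definitions work because Julia treats infix operators as ordinary generic functions, and each definition adds one more method to them, which is why the output reports a method count for `*`. The short check below uses only Base functionality to show that `x + y` is just a call to `+`:

```julia
# In Julia, infix arithmetic is ordinary function-call syntax: `x + y`
# parses to the call `+(x, y)`. Adding a method to the generic function
# `+` for a new type is therefore all that operator overloading needs.
@assert Meta.parse("x + y") == :(+(x, y))
@assert (+)(2.0, 3.0) == 2.0 + 3.0

# The method is chosen from the types of all arguments (multiple
# dispatch); `which` shows the method a given call would use.
println(which(+, (Float64, Float64)))
```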
In [4]:
# define test function
function f(x::ADFwd,y::ADFwd)
    (x+y)*(y + ADFwd(1.0))
end
Out[4]:
f (generic function with 1 method)
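The constant has to be written as `ADFwd(1.0)` because `+` is only defined for two ADFwd values. One possible extension, sketched below under the assumption that plain numbers should be treated as constants with derivative 0, adds mixed methods so the test function can be written as `(x + y) * (y + 1)`. To keep the block self-contained it uses a hypothetical stand-in type `Dual` with the same fields as ADFwd, and `g` is an illustrative name for the test function.

```julia
import Base: +, *

# Hypothetical stand-in for ADFwd so that this block runs on its own.
struct Dual
    value::Float64
    derivative::Float64
end

+(x::Dual, y::Dual) = Dual(x.value + y.value, x.derivative + y.derivative)
*(x::Dual, y::Dual) = Dual(x.value * y.value,
                           y.value * x.derivative + x.value * y.derivative)

# Assumption: a plain number is a constant, so its derivative is 0.
+(x::Dual, c::Real) = Dual(x.value + c, x.derivative)
+(c::Real, x::Dual) = x + c
*(x::Dual, c::Real) = Dual(x.value * c, x.derivative * c)
*(c::Real, x::Dual) = x * c

# The constant 1 no longer needs to be wrapped.
g(x, y) = (x + y) * (y + 1)

out = g(Dual(2.0, 1.0), Dual(1.0, 0.0))
println(out.value, " ", out.derivative)   # 6.0 2.0
```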

Now let us get the partial derivative of the above function with respect to 'a'.

In [5]:
# define variables
aFwd = ADFwd(2.0, 1.0)
bFwd = ADFwd(1.0)
Out[5]:
ADFwd(1.0,0.0)
In [6]:
# forward mode AD
eaFwd = f(aFwd, bFwd)
eaFwd.value
Out[6]:
6.0
In [7]:
# calculated derivative: de/da
eaFwd.derivative
Out[7]:
2.0

Similarly, for 'b'.

In [8]:
# define variables
aFwd = ADFwd(2.0)
bFwd = ADFwd(1.0, 1.0)
Out[8]:
ADFwd(1.0,1.0)
In [9]:
# forward mode AD
ebFwd = f(aFwd, bFwd)
ebFwd.value
Out[9]:
6.0
In [10]:
# calculated derivative: de/db
ebFwd.derivative
Out[10]:
5.0

The partial derivative ends up in the output ADFwd value: its derivative field holds the rate of change of the output (dependent) variable with respect to the seeded input (independent) variable. Forward mode is simple to implement and needs little memory, since only the current value and derivative of each node are carried along. However, each forward pass gives the derivative with respect to just one input, so a function of n inputs needs n passes to obtain its full gradient. In such cases, reverse mode AD proves more useful.
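The cost described above is easy to see in a sketch of a gradient function built on forward mode: it seeds one input at a time and repeats the whole forward pass for each. `Dual` and `forward_gradient` are hypothetical names; `Dual` mirrors ADFwd so that the block runs on its own.

```julia
import Base: +, *

# Hypothetical stand-in for ADFwd so that this block runs on its own.
struct Dual
    value::Float64
    derivative::Float64
end

+(x::Dual, y::Dual) = Dual(x.value + y.value, x.derivative + y.derivative)
*(x::Dual, y::Dual) = Dual(x.value * y.value,
                           y.value * x.derivative + x.value * y.derivative)

# Gradient by forward mode: one full pass per input, seeding that input's
# derivative with 1 and every other input's derivative with 0.
function forward_gradient(f, xs::Vector{Float64})
    grad = zeros(length(xs))
    for i in eachindex(xs)
        duals = [Dual(xs[j], j == i ? 1.0 : 0.0) for j in eachindex(xs)]
        grad[i] = f(duals...).derivative
    end
    return grad
end

h(a, b) = (a + b) * (b + Dual(1.0, 0.0))
grad = forward_gradient(h, [2.0, 1.0])   # returns [2.0, 5.0]
```

For the two-input example this means two passes; for n inputs it means n passes, each re-evaluating the whole function.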

Reverse Mode

Reverse mode answers the opposite question: how sensitive is the output to each of the inputs? The first half of reverse mode resembles forward mode, except that we do not calculate any derivatives. We move forward through the graph computing the value of each sub-expression, and on reaching the output node we set its derivative component to 1. We then apply the chain rule, using this derivative together with the values stored during the forward pass, to calculate the derivative components of the nodes that fed into it (its parents), and so on back through the graph until the independent variables are reached.

In a pen-and-paper calculation, one can perform the equivalent by repeatedly substituting the derivative of the outer functions in the chain rule:

$${\displaystyle {\frac {\partial y}{\partial x}}={\frac {\partial y}{\partial w_{1}}}{\frac {\partial w_{1}}{\partial x}}=\left({\frac {\partial y}{\partial w_{2}}}{\frac {\partial w_{2}}{\partial w_{1}}}\right){\frac {\partial w_{1}}{\partial x}}=\left(\left({\frac {\partial y}{\partial w_{3}}}{\frac {\partial w_{3}}{\partial w_{2}}}\right){\frac {\partial w_{2}}{\partial w_{1}}}\right){\frac {\partial w_{1}}{\partial x}}=\cdots }$$

- Wikipedia

We can see the equations during the reverse pass as:

$$\begin{array}{l|l} \text{derivative} & \text{node} \Leftarrow \text{propagated from} \\ \hline \frac{\partial e}{\partial e} = 1 & \text{node 6} \\ \frac{\partial e}{\partial c} = \frac{\partial e}{\partial e}*\frac{\partial (c*d)}{\partial c} = 1*d & \text{node 3} \Leftarrow \text{node 6} \\ \frac{\partial e}{\partial d} = \frac{\partial e}{\partial e}*\frac{\partial (c*d)}{\partial d} = 1*c & \text{node 5} \Leftarrow \text{node 6} \\ \frac{\partial e}{\partial a} = \frac{\partial e}{\partial c}*\frac{\partial c}{\partial a} = d*1 & \text{node 1} \Leftarrow \text{node 3} \\ \frac{\partial e}{\partial b} = \frac{\partial e}{\partial c}*\frac{\partial c}{\partial b} + \frac{\partial e}{\partial d}*\frac{\partial d}{\partial b} = d*1 + c*1 & \text{node 2} \Leftarrow \text{node 3}, \text{node 5} \end{array}$$
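Plugging in a = 2 and b = 1 (so c = 3, d = 2 and e = 6), the reverse pass produces:

$$\begin{eqnarray} \frac{\partial e}{\partial e} &=& 1 \nonumber \\ \frac{\partial e}{\partial c} &=& d = 2 \nonumber \\ \frac{\partial e}{\partial d} &=& c = 3 \nonumber \\ \frac{\partial e}{\partial a} &=& 2*1 = 2 \nonumber \\ \frac{\partial e}{\partial b} &=& 2*1 + 3*1 = 5 \nonumber \end{eqnarray}$$

The intermediate values 2 and 3 appear as the derivative fields of nodes 3 and 5 in the reverse-pass output printed later, and 2 and 5 are the final results for a and b.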

In the implementation, the type ADRev stores the value and the derivative for a particular node. During the forward pass it also records the node's parents (the operands that produced it) and the rule for propagating a derivative back to them, so that the reverse pass can walk the graph backwards.

In [11]:
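# type to store the float value for a variable during the forward pass
# and the derivative during the reverse pass.
# (`mutable struct` replaces the pre-1.0 `type`; derivatives are updated in place.)
mutable struct ADRev
    value::Float64
    derivative::Float64
    derivativeOp::Function  # pushes this node's derivative to its parents
    parents::Vector{ADRev}  # operand nodes recorded during the forward pass

    ADRev(val::Float64) = new(val, 0.0, ad_constD, ADRev[])
    ADRev(val::Float64, der::Float64) = new(val, der, ad_constD, ADRev[])
end

# leaf nodes (variables and constants) have no parents to propagate to
function ad_constD(prevDerivative::Float64, adNodes::Vector{ADRev})
    return nothing
end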
Out[11]:
ad_constD (generic function with 1 method)
In [12]:
# define the actual addition operation and the derivative rule to use
# during the reverse pass.
function adr_add(x::ADRev, y::ADRev)
    result = ADRev(x.value + y.value)
    result.derivativeOp = adr_addD
    push!(result.parents, x)
    push!(result.parents, y)
    return result
end
function adr_addD(prevDerivative::Float64, adNodes::Array{ADRev})
    adNodes[1].derivative = adNodes[1].derivative + prevDerivative * 1
    adNodes[2].derivative = adNodes[2].derivative + prevDerivative * 1
    return
end
+(x::ADRev, y::ADRev) = adr_add(x, y)
Out[12]:
+ (generic function with 165 methods)
In [13]:
# define the actual multiplication operation and the derivative rule to use
# during the reverse pass.
function adr_mul(x::ADRev, y::ADRev)
    result = ADRev(x.value * y.value)
    result.derivativeOp = adr_mulD
    push!(result.parents, x)
    push!(result.parents, y)
    return result
end
function adr_mulD(prevDerivative::Float64, adNodes::Array{ADRev})
    adNodes[1].derivative = adNodes[1].derivative + prevDerivative * adNodes[2].value
    adNodes[2].derivative = adNodes[2].derivative + prevDerivative * adNodes[1].value
    return
end
*(x::ADRev, y::ADRev) = adr_mul(x, y)
Out[13]:
* (generic function with 151 methods)

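We traverse the graph backwards from the output to propagate the derivatives during the reverse pass. Because pop! removes the most recently pushed node, the traversal uses a stack and visits nodes depth-first rather than breadth-first. Since Julia passes mutable objects by reference, a node that feeds several operations can simply accumulate the contributions from each of them. For example, node 2 (b) needs to accumulate the derivative from both node 3 and node 5, which are processed separately during the traversal. This simple traversal is sufficient here because the only nodes used by more than one operation are leaves; a general implementation processes nodes in reverse topological order so that an intermediate node passes on its derivative only once, after all contributions have arrived. Accumulation is why we add the calculated derivative to the node's existing derivative instead of assigning it directly: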

adNodes[1].derivative = adNodes[1].derivative + ...
In [14]:
# this is the reverse pass where we apply the chain rule
function chainRule(graph::ADRev)
    current = graph
    # seed the output: de/de = 1
    current.derivative = 1.0
    # pop! takes the most recently pushed node, so this is a LIFO stack
    # (depth-first order)
    stack = [current]
    while !isempty(stack)
        current = pop!(stack)
        # add this node's contribution to each parent's derivative
        current.derivativeOp(current.derivative, current.parents)
        for parent in current.parents
            push!(stack, parent)
        end
    end
    return graph
end
Out[14]:
chainRule (generic function with 1 method)
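`chainRule` processes a node as soon as it is popped. That is safe for f, where the only nodes used by more than one operation are the leaves a and b (processing a leaf does nothing). If an intermediate node fed several operations, it could pass on its derivative before all contributions had arrived, and would pass it on once per visit. A common remedy, sketched below as an alternative rather than the method used in this post, is to sort the nodes topologically and run the reverse pass in reverse of that order. `Node`, `toposort` and `backward!` are hypothetical names; to keep the sketch short, `Node` stores the local partial derivatives computed in the forward pass instead of a `derivativeOp` function.

```julia
import Base: +, *

# Hypothetical stand-alone reverse-mode node. Unlike ADRev it stores the
# local partials d(node)/d(parent) instead of a derivativeOp function.
mutable struct Node
    value::Float64
    derivative::Float64
    parents::Vector{Node}
    localgrads::Vector{Float64}
end
Node(v::Float64) = Node(v, 0.0, Node[], Float64[])

+(x::Node, y::Node) = Node(x.value + y.value, 0.0, [x, y], [1.0, 1.0])
*(x::Node, y::Node) = Node(x.value * y.value, 0.0, [x, y], [y.value, x.value])

# Depth-first post-order: each node is listed after all of its parents,
# and exactly once even if it is reached along several paths.
function toposort(out::Node)
    order = Node[]
    seen = Set{Node}()   # mutable structs hash by identity
    function visit(n::Node)
        n in seen && return
        push!(seen, n)
        for p in n.parents
            visit(p)
        end
        push!(order, n)
    end
    visit(out)
    return order
end

# Reverse topological order: by the time a node is processed, every
# operation that used it has already added its share of the derivative.
function backward!(out::Node)
    order = toposort(out)
    for n in order
        n.derivative = 0.0   # clear results of any earlier pass
    end
    out.derivative = 1.0
    for n in reverse(order)
        for (p, g) in zip(n.parents, n.localgrads)
            p.derivative += n.derivative * g
        end
    end
    return out
end

x = Node(3.0)
y = x + x        # intermediate node used twice below
out = y * y      # out = (2x)^2 = 4x^2, so d(out)/dx = 8x = 24
backward!(out)
println(x.derivative)   # 24.0

# The post's example: f(a, b) = (a + b) * (b + 1) at a = 2, b = 1.
a = Node(2.0); b = Node(1.0)
ex = (a + b) * (b + Node(1.0))
backward!(ex)
println(a.derivative, " ", b.derivative)   # 2.0 5.0
```

Tracing `chainRule` by hand on the `y * y` graph, y is popped twice and its derivative is added to x twice, giving 48 rather than 24; the topological version visits each node once.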
In [15]:
# define the function
function f(x::ADRev,y::ADRev)
    (x+y)*(y + ADRev(1.0))
end
Out[15]:
f (generic function with 2 methods)
In [16]:
# create the variables
aRev = ADRev(2.0)
bRev = ADRev(1.0)
Out[16]:
ADRev(1.0,0.0,ad_constD,ADRev[])
In [17]:
# forward pass
eRev_forward = f(aRev, bRev)
Out[17]:
ADRev(6.0,0.0,adr_mulD,ADRev[ADRev(3.0,0.0,adr_addD,ADRev[ADRev(2.0,0.0,ad_constD,ADRev[]),ADRev(1.0,0.0,ad_constD,ADRev[])]),ADRev(2.0,0.0,adr_addD,ADRev[ADRev(1.0,0.0,ad_constD,ADRev[]),ADRev(1.0,0.0,ad_constD,ADRev[])])])
In [18]:
# reverse pass
eRev_reverse = chainRule(eRev_forward)
Out[18]:
ADRev(6.0,1.0,adr_mulD,ADRev[ADRev(3.0,2.0,adr_addD,ADRev[ADRev(2.0,2.0,ad_constD,ADRev[]),ADRev(1.0,5.0,ad_constD,ADRev[])]),ADRev(2.0,3.0,adr_addD,ADRev[ADRev(1.0,5.0,ad_constD,ADRev[]),ADRev(1.0,3.0,ad_constD,ADRev[])])])

Because the forward pass stores the graph for use in the reverse pass, the output variable records the parent-child relationships as well as the operation performed at each node.

In [19]:
# derivative with respect to all the independent variables
aRev.derivative
Out[19]:
2.0
In [20]:
bRev.derivative
Out[20]:
5.0

As mentioned before, the benefit of reverse mode AD is that a single reverse pass gives the derivative of the output with respect to every input variable. We'll use this property to implement a neural network in the coming post.