Almost all of the libraries for creating neural networks (TensorFlow, Theano, Torch, etc.) use automatic differentiation (AD) in one way or another. AD is useful well beyond neural networks too. It computes exact derivatives of functions written as programs, up to floating-point rounding, without symbolic manipulation or finite-difference approximations. It works by first building a computational graph of the operations and then traversing it in either forward mode or reverse mode. Let us see how to implement both modes in Julia, using operator overloading, to calculate first-order partial derivatives. I highly recommend reading Colah's blog post "Calculus on Computational Graphs: Backpropagation" first. It explains computational graphs very well, and this post focuses on the implementation side. This may not be the best-performing AD code, but I think it is the simplest way to get your head around the concept. The example function we are considering here is:
$$f(a, b) = (a + b) * (b + 1)$$
Write $c = a + b$, $d = b + 1$ and $e = f(a, b) = c * d$. Basic differential calculus, specifically the sum rule and the product rule, then gives the derivatives of this function with respect to a and b:
$$
\begin{aligned}
\frac{\partial e}{\partial a} &= c * \frac{\partial d}{\partial a} + \frac{\partial c}{\partial a} * d \\
&= (a + b) * \left(\frac{\partial b}{\partial a} + \frac{\partial 1}{\partial a}\right) + \left(\frac{\partial a}{\partial a} + \frac{\partial b}{\partial a}\right) * (b + 1) \\
&= (a + b) * (0 + 0) + (1 + 0) * (b + 1) \\
&= b + 1
\end{aligned}
$$
and similarly,
$$
\begin{aligned}
\frac{\partial e}{\partial b} &= c * \frac{\partial d}{\partial b} + \frac{\partial c}{\partial b} * d \\
&= (a + b) * \left(\frac{\partial b}{\partial b} + \frac{\partial 1}{\partial b}\right) + \left(\frac{\partial a}{\partial b} + \frac{\partial b}{\partial b}\right) * (b + 1) \\
&= (a + b) * (1 + 0) + (0 + 1) * (b + 1) \\
&= (a + b) + (b + 1) = a + 2b + 1
\end{aligned}
$$
To compute these derivatives programmatically, there are two ways to operate on the computational graph: forward mode and reverse mode. Both apply the chain rule. They differ in the direction in which they traverse the graph.
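The hand-derived partials can be sanity-checked numerically before we implement any AD. The sketch below is plain Julia. The helper names f_plain, dfda, dfdb, central_diff_a and central_diff_b are hypothetical and introduced only for this check. At a = 2, b = 1 both the formulas and the central-difference estimates should come out at about 2 and 5.

```julia
# Closed-form partial derivatives of f(a, b) = (a + b) * (b + 1), as derived above.
f_plain(a, b) = (a + b) * (b + 1)
dfda(a, b) = b + 1
dfdb(a, b) = (a + b) + (b + 1)

# Central finite differences: an approximation, used here only as a sanity check.
central_diff_a(g, a, b; h = 1e-6) = (g(a + h, b) - g(a - h, b)) / (2h)
central_diff_b(g, a, b; h = 1e-6) = (g(a, b + h) - g(a, b - h)) / (2h)

a0, b0 = 2.0, 1.0
println("de/da: exact = ", dfda(a0, b0), ", finite difference = ", central_diff_a(f_plain, a0, b0))
println("de/db: exact = ", dfdb(a0, b0), ", finite difference = ", central_diff_b(f_plain, a0, b0))
```

Finite differences depend on the step size h and only approximate the derivative. AD avoids that truncation error by applying each operation's derivative rule directly.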
```julia
# extend Base's + and * operators with methods for our AD types
import Base: +, *
```
Forward Mode
Forward mode is very similar to the calculation we did above. We first pick the independent variable with respect to which we want the partial derivative. We set its derivative with respect to itself to 1, and the derivatives of all other inputs to 0. Then we move forward through the graph, computing the derivative of each sub-expression from the derivatives of its inputs, until we reach the output node.
In a pen-and-paper calculation, one can do so by repeatedly substituting the derivative of the inner functions in the chain rule:
$${\displaystyle {\frac {\partial y}{\partial x}}={\frac {\partial y}{\partial w_{1}}}{\frac {\partial w_{1}}{\partial x}}={\frac {\partial y}{\partial w_{1}}}\left({\frac {\partial w_{1}}{\partial w_{2}}}{\frac {\partial w_{2}}{\partial x}}\right)={\frac {\partial y}{\partial w_{1}}}\left({\frac {\partial w_{1}}{\partial w_{2}}}\left({\frac {\partial w_{2}}{\partial w_{3}}}{\frac {\partial w_{3}}{\partial x}}\right)\right)=\cdots }$$
The graph can be thought of as being built in the order in which a programming language would evaluate the expression, following the usual operator precedence (BODMAS). In terms of simple operations, the function above breaks down into:
$$
\begin{aligned}
c &= a + b \\
d &= b + 1 \\
e &= c * d
\end{aligned}
$$
hence the operations to calculate the partial derivative of the function with respect to a may look like this:
$$
\begin{array}{l|l|l}
\text{value} & \text{derivative} & \text{node} \\
\hline
a = a & \frac{\partial a}{\partial a} = 1 & \text{node 1} \\
b = b & \frac{\partial b}{\partial a} = 0 & \text{node 2} \\
c = a + b & \frac{\partial c}{\partial a} = \frac{\partial a}{\partial a} + \frac{\partial b}{\partial a} & \text{node 3} \Leftarrow \text{node 1} + \text{node 2} \\
1 & \frac{\partial 1}{\partial a} = 0 & \text{node 4} \\
d = b + 1 & \frac{\partial d}{\partial a} = \frac{\partial b}{\partial a} + \frac{\partial 1}{\partial a} & \text{node 5} \Leftarrow \text{node 2} + \text{node 4} \\
e = c * d & \frac{\partial e}{\partial a} = c * \frac{\partial d}{\partial a} + \frac{\partial c}{\partial a} * d & \text{node 6} \Leftarrow \text{node 3} * \text{node 5}
\end{array}
$$
To simulate the above steps, we define a type ADFwd, which also represents a node in the computational graph.
```julia
# type to store the float value of a variable (node) and
# its derivative with respect to the chosen input variable at that value
struct ADFwd
    value::Float64       # say, to store c
    derivative::Float64  # say, to store dc/da
    ADFwd(val::Float64) = new(val, 0.0)
    ADFwd(val::Float64, der::Float64) = new(val, der)
end
```
Next, we define the operations on this type together with the derivative rule each one follows. Operator overloading lets us write expressions over ADFwd values with the ordinary + and * operators.
```julia
# sum rule: (x + y)' = x' + y'
function adf_add(x::ADFwd, y::ADFwd)
    return ADFwd(x.value + y.value, x.derivative + y.derivative)
end
+(x::ADFwd, y::ADFwd) = adf_add(x, y)

# product rule: (x * y)' = y * x' + x * y'
function adf_mul(x::ADFwd, y::ADFwd)
    return ADFwd(x.value * y.value, y.value * x.derivative + x.value * y.derivative)
end
*(x::ADFwd, y::ADFwd) = adf_mul(x, y)

# define the test function; the constant 1 is an ADFwd with derivative 0
function f(x::ADFwd, y::ADFwd)
    (x + y) * (y + ADFwd(1.0))
end
```
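The two rules above are all f needs, but the same pattern extends to other operations. The sketch below uses a separate, hypothetical Dual type that is not part of the post's code. It adds subtraction, division via the quotient rule, and mixing with plain numbers. Plain numbers are treated as constants with derivative 0, just like node 4.

```julia
import Base: +, -, *, /

# Hypothetical forward-mode number for this sketch: a value and its derivative
# with respect to the chosen input variable.
struct Dual
    value::Float64
    derivative::Float64
end

# A plain number is a constant, so its derivative is 0.
Dual(c::Real) = Dual(Float64(c), 0.0)

+(x::Dual, y::Dual) = Dual(x.value + y.value, x.derivative + y.derivative)
-(x::Dual, y::Dual) = Dual(x.value - y.value, x.derivative - y.derivative)
*(x::Dual, y::Dual) = Dual(x.value * y.value, y.value * x.derivative + x.value * y.derivative)
# quotient rule: (x / y)' = (x' * y - x * y') / y^2
/(x::Dual, y::Dual) = Dual(x.value / y.value,
                           (x.derivative * y.value - x.value * y.derivative) / y.value^2)

# let plain numbers appear on either side by wrapping them as constants
+(x::Dual, c::Real) = x + Dual(c)
+(c::Real, x::Dual) = Dual(c) + x
*(x::Dual, c::Real) = x * Dual(c)
*(c::Real, x::Dual) = Dual(c) * x

# the post's test function, now written without wrapping the constant 1
f_mixed(a, b) = (a + b) * (b + 1)
g(x) = x / (x + 1)          # g'(x) = 1 / (x + 1)^2

println(f_mixed(Dual(2.0, 1.0), Dual(1.0)).derivative)   # 2.0 (de/da)
println(g(Dual(1.0, 1.0)).derivative)                     # 0.25
```

Because Julia dispatches on both argument types, each new rule is just one more method, and f_mixed works unchanged for plain numbers as well.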
Now let us get the partial derivative of the above function with respect to 'a'.
```julia
# define the variables: a is seeded with da/da = 1, b keeps db/da = 0
aFwd = ADFwd(2.0, 1.0)
bFwd = ADFwd(1.0)
# forward mode AD
eaFwd = f(aFwd, bFwd)
eaFwd.value        # 6.0
# calculated derivative: de/da = b + 1
eaFwd.derivative   # 2.0
```
Similarly, for 'b'.
```julia
# define the variables: this time b is seeded with db/db = 1
aFwd = ADFwd(2.0)
bFwd = ADFwd(1.0, 1.0)
# forward mode AD
ebFwd = f(aFwd, bFwd)
ebFwd.value        # 6.0
# calculated derivative: de/db = (a + b) + (b + 1)
ebFwd.derivative   # 5.0
```
The partial derivative ends up in the derivative field of the output ADFwd value. It measures how much the output (dependent variable) changes for a small change in the chosen input (independent variable). Forward mode is simple to implement and needs little memory, since each node only carries its value and one derivative. However, to get the derivatives with respect to several variables, we need a separate forward pass for each of them. In such cases, reverse mode AD proves useful.
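The cost of a full gradient in forward mode shows up in a small sketch. The type FwdNum and the helper forward_gradient below are hypothetical; FwdNum simply mirrors ADFwd. The helper runs one forward pass per input, seeding that input's derivative with 1 and all others with 0. A function of n inputs therefore needs n passes.

```julia
import Base: +, *

# Hypothetical forward-mode type for this sketch; it mirrors the post's ADFwd.
struct FwdNum
    value::Float64
    derivative::Float64
end

+(x::FwdNum, y::FwdNum) = FwdNum(x.value + y.value, x.derivative + y.derivative)
*(x::FwdNum, y::FwdNum) = FwdNum(x.value * y.value, y.value * x.derivative + x.value * y.derivative)

# Hypothetical helper: gradient of f at the point xs using forward mode.
# Input i is seeded with derivative 1 and all other inputs with 0, so each
# pass yields one partial derivative and n inputs need n forward passes.
function forward_gradient(f, xs::Vector{Float64})
    grad = zeros(length(xs))
    for i in eachindex(xs)
        args = [FwdNum(xs[j], j == i ? 1.0 : 0.0) for j in eachindex(xs)]
        grad[i] = f(args...).derivative
    end
    return grad
end

# the post's test function, written for FwdNum
f_fwd(a, b) = (a + b) * (b + FwdNum(1.0, 0.0))

println(forward_gradient(f_fwd, [2.0, 1.0]))   # [2.0, 5.0], using two forward passes
```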
Reverse Mode
Reverse mode computes how the output changes with respect to each of the inputs by propagating derivatives backwards, from the output towards the inputs. The first half of reverse mode is similar to the forward mode calculation, except that we don't compute any derivatives. We move forward through the graph computing the value of each sub-expression, and on reaching the output node we set the output's derivative component to 1. We then use this derivative together with the values computed in the forward pass to apply the chain rule. This gives the derivative components of the node's parents, that is, the operands it was computed from. We repeat this until the independent variables are reached.
In a pen-and-paper calculation, one can perform the equivalent by repeatedly substituting the derivative of the outer functions in the chain rule:
$${\displaystyle {\frac {\partial y}{\partial x}}={\frac {\partial y}{\partial w_{1}}}{\frac {\partial w_{1}}{\partial x}}=\left({\frac {\partial y}{\partial w_{2}}}{\frac {\partial w_{2}}{\partial w_{1}}}\right){\frac {\partial w_{1}}{\partial x}}=\left(\left({\frac {\partial y}{\partial w_{3}}}{\frac {\partial w_{3}}{\partial w_{2}}}\right){\frac {\partial w_{2}}{\partial w_{1}}}\right){\frac {\partial w_{1}}{\partial x}}=\cdots }$$
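Before building this into a type, the reverse sweep for our example can be written out by hand with plain numbers. In the sketch below, the function manual_reverse and the _bar names for the adjoints (∂e/∂node) are our own. The forward pass stores c, d and e, and the reverse pass starts from ∂e/∂e = 1 and walks back to a and b.

```julia
# Hand-written reverse sweep for e = (a + b) * (b + 1), using plain Float64 values.
# Each *_bar variable holds the adjoint de/d(node); the names are just for this sketch.
function manual_reverse(a, b)
    # forward pass: compute and keep the intermediate values
    c = a + b
    d = b + 1
    e = c * d
    # reverse pass: start from de/de = 1 and apply the chain rule towards the inputs
    e_bar = 1.0
    c_bar = e_bar * d                    # de/dc = d
    d_bar = e_bar * c                    # de/dd = c
    a_bar = c_bar * 1.0                  # dc/da = 1
    b_bar = c_bar * 1.0 + d_bar * 1.0    # b feeds both c and d, so both contributions add up
    return e, a_bar, b_bar
end

println(manual_reverse(2.0, 1.0))   # (6.0, 2.0, 5.0)
```

Both partial derivatives come out of the single reverse sweep, and b_bar shows why contributions from several children must be added together.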
We can see the equations during the reverse pass as:
$$
\begin{array}{l|l}
\text{derivative} & \text{parent node} \Leftarrow \text{child node} \\
\hline
\frac{\partial e}{\partial e} = 1 & \text{node 6} \\
\frac{\partial e}{\partial c} = \frac{\partial e}{\partial e} * \frac{\partial e}{\partial c} = 1 * d & \text{node 3} \Leftarrow \text{node 6} \\
\frac{\partial e}{\partial d} = \frac{\partial e}{\partial e} * \frac{\partial e}{\partial d} = 1 * c & \text{node 5} \Leftarrow \text{node 6} \\
\frac{\partial e}{\partial a} = \frac{\partial e}{\partial c} * \frac{\partial c}{\partial a} = d * 1 & \text{node 1} \Leftarrow \text{node 3} \\
\frac{\partial e}{\partial b} = \frac{\partial e}{\partial c} * \frac{\partial c}{\partial b} + \frac{\partial e}{\partial d} * \frac{\partial d}{\partial b} = d * 1 + c * 1 & \text{node 2} \Leftarrow \text{node 3}, \text{node 5}
\end{array}
$$
In the implementation, we have a type ADRev that stores the value and the derivative of a particular node. During the forward pass we also store each node's parents (its operands), so that the derivative can be propagated backwards during the reverse pass.
```julia
# type to store the float value of a node during the forward pass
# and its derivative during the reverse pass
mutable struct ADRev
    value::Float64
    derivative::Float64
    derivativeOp::Function    # rule that passes this node's derivative on to its parents
    parents::Vector{ADRev}    # the operands this node was computed from
    ADRev(val::Float64) = new(val, 0.0, ad_constD, ADRev[])
    ADRev(val::Float64, der::Float64) = new(val, der, ad_constD, ADRev[])
end

# variables and constants have no parents, so there is nothing to propagate
function ad_constD(prevDerivative::Float64, adNodes::Vector{ADRev})
    return nothing
end

# define the actual addition operation and the derivative rule to use
# during the reverse pass
function adr_add(x::ADRev, y::ADRev)
    result = ADRev(x.value + y.value)
    result.derivativeOp = adr_addD
    push!(result.parents, x)
    push!(result.parents, y)
    return result
end

# sum rule: d(x + y)/dx = d(x + y)/dy = 1
function adr_addD(prevDerivative::Float64, adNodes::Vector{ADRev})
    adNodes[1].derivative = adNodes[1].derivative + prevDerivative * 1
    adNodes[2].derivative = adNodes[2].derivative + prevDerivative * 1
    return
end

+(x::ADRev, y::ADRev) = adr_add(x, y)

# define the actual multiplication operation and the derivative rule to use
# during the reverse pass
function adr_mul(x::ADRev, y::ADRev)
    result = ADRev(x.value * y.value)
    result.derivativeOp = adr_mulD
    push!(result.parents, x)
    push!(result.parents, y)
    return result
end

# product rule: d(x * y)/dx = y, d(x * y)/dy = x
function adr_mulD(prevDerivative::Float64, adNodes::Vector{ADRev})
    adNodes[1].derivative = adNodes[1].derivative + prevDerivative * adNodes[2].value
    adNodes[2].derivative = adNodes[2].derivative + prevDerivative * adNodes[1].value
    return
end

*(x::ADRev, y::ADRev) = adr_mul(x, y)
```
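In the reverse pass we traverse the graph from the output node back towards the inputs, propagating the derivatives as we go. The chainRule function below keeps the nodes still to be processed in an array and takes them off the end with pop!, so the traversal is depth-first (stack-based) rather than breadth-first. The ADRev objects are mutable and shared by reference, so a parent with several children can simply collect a contribution from each of them. For example, node 2 receives derivatives from both node 3 and node 5 in our case, and these are propagated at different points of the traversal. This is why we add the calculated derivative to the node's existing derivative instead of assigning it directly:

```julia
adNodes[1].derivative = adNodes[1].derivative + ...
```

Note that a node is processed every time it is reached. This works here because the only node with several children, node 2, is a leaf with nothing further to propagate.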
```julia
# this is the reverse pass, where we apply the chain rule
function chainRule(graph::ADRev)
    current = graph
    # seed the output: de/de = 1
    current.derivative = 1.0
    stack = [current]          # nodes whose derivative still has to be propagated
    while !isempty(stack)
        current = pop!(stack)  # pop! takes the last element, so this is depth-first
        currDerivative = current.derivative
        # push this node's derivative into its parents using the stored rule
        current.derivativeOp(currDerivative, current.parents)
        for parent in current.parents
            push!(stack, parent)
        end
    end
    return graph
end
```
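```julia
# define the function; the constant 1 is a leaf ADRev node
function f(x::ADRev, y::ADRev)
    (x + y) * (y + ADRev(1.0))
end
# create the variables
aRev = ADRev(2.0)
bRev = ADRev(1.0)
# forward pass: builds the graph and computes the values
eRev_forward = f(aRev, bRev)
# reverse pass: fills in the derivatives
eRev_reverse = chainRule(eRev_forward)
```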
Because the graph is stored during the forward pass for use in the reverse pass, the output variable also records the parent-child relationships and the derivative rule attached to each node.
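```julia
# derivatives with respect to all the independent variables, from one reverse pass
aRev.derivative   # 2.0, i.e. de/da = b + 1
bRev.derivative   # 5.0, i.e. de/db = (a + b) + (b + 1)
```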
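The chainRule traversal above propagates a node's derivative every time the node is popped. That is harmless for f, because the only node with several children, b, is a leaf. If an intermediate node is shared, though, its derivative is pushed to its parents more than once. For u = x * x and y = u * u, the traversal visits u twice and reports dy/dx = 64 at x = 2 instead of 4x³ = 32. A common remedy is to sort the graph topologically and process each node once, in reverse order. The sketch below shows this with a hypothetical Node type and helper names of our own (topo_order, backprop!), not the post's ADRev.

```julia
import Base: +, *

# Hypothetical reverse-mode node for this sketch (similar to ADRev): the value,
# the accumulated derivative of the output with respect to this node, the
# operands it was computed from, and the rule that pushes its derivative to them.
mutable struct Node
    value::Float64
    derivative::Float64
    parents::Vector{Node}
    backward::Function
end

leaf_back(n::Node) = nothing            # leaves have no parents
Node(v::Real) = Node(Float64(v), 0.0, Node[], leaf_back)

function add_back(n::Node)
    x, y = n.parents
    x.derivative += n.derivative        # d(x + y)/dx = 1
    y.derivative += n.derivative        # d(x + y)/dy = 1
end

function mul_back(n::Node)
    x, y = n.parents
    x.derivative += n.derivative * y.value   # d(x * y)/dx = y
    y.derivative += n.derivative * x.value   # d(x * y)/dy = x
end

+(x::Node, y::Node) = Node(x.value + y.value, 0.0, [x, y], add_back)
*(x::Node, y::Node) = Node(x.value * y.value, 0.0, [x, y], mul_back)

# Depth-first topological sort: every node appears after all of its parents,
# and each node appears exactly once, however many times it is used.
function topo_order(root::Node)
    order = Node[]
    visited = IdDict{Node, Bool}()
    function visit(n::Node)
        haskey(visited, n) && return
        visited[n] = true
        for p in n.parents
            visit(p)
        end
        push!(order, n)
    end
    visit(root)
    return order
end

# Reverse pass in reverse topological order: by the time a node is processed,
# all of its children have already added their contributions to its derivative.
function backprop!(root::Node)
    order = topo_order(root)
    for n in order
        n.derivative = 0.0              # clear results of any earlier pass
    end
    root.derivative = 1.0
    for n in reverse(order)
        n.backward(n)
    end
    return root
end

# usage: a shared intermediate node
x = Node(2.0)
u = x * x          # u = x^2 is used twice below
y = u * u          # y = x^4, so dy/dx = 4x^3 = 32 at x = 2
backprop!(y)
println(x.derivative)   # 32.0
```

For the post's f, this version gives the same results as chainRule (2.0 and 5.0). Resetting the derivatives at the start also makes it safe to run the reverse pass more than once on the same graph.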
As mentioned before, the benefit of reverse mode AD is that a single reverse pass gives the derivative of the output with respect to every input variable. We'll use this property to implement a neural network in the coming post.