Today I will try to explain how the forward and backward modes in automatic differentiation work. I will only cover the principle, not actual algorithms and the optimizations they apply. While the so-called forward mode is quite intuitive, it is not so easy to wrap your head around the backward mode. I will try to go through all steps and not leave out anything seemingly trivial.
We consider the computation of a function $f\colon \mathbb{R}^n \to \mathbb{R}^m$ with independent variables $x_1, \dots, x_n$ and dependent variables $y_1, \dots, y_m$. The ultimate goal is to compute the Jacobian

$$J = \left(\frac{\partial y_i}{\partial x_j}\right)_{i = 1, \dots, m;\; j = 1, \dots, n}.$$
We view the function as a composite of elementary operations

$$v_i = \varphi_i\big((v_j)_{j \in P_i}\big)$$

for $i = n + 1, \dots, N$, where we set $v_i = x_i$ for $i = 1, \dots, n$ (i.e. we reserve these indices for the start values of the computation) and $y_i = v_{N - m + i}$ for $i = 1, \dots, m$ (i.e. these are the final results of the computation). The notation should suggest that $v_i$ depends on prior results $v_j$ with $j$ in some index set $P_i$. Note that $j \in P_i$ only refers to a direct dependency of $v_i$ on $v_j$, i.e. if $v_i$ depends on $v_j$, but $v_j$ does not enter the calculation of $v_i$ directly, then $j \notin P_i$.
As an example consider the function

$$f(x_1, x_2) = \sin(x_1 x_2),$$

for which we would have $v_1 = x_1$, $v_2 = x_2$, $v_3 = v_1 v_2$, $v_4 = \sin(v_3) = y_1$ (so $n = 2$, $m = 1$ and $N = 4$). The direct dependencies are $1 \in P_3$, $2 \in P_3$ and $3 \in P_4$, but not $1 \in P_4$, because $v_1$ does not enter the expression for $v_4$ directly.
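To make the decomposition concrete, here is a small Python sketch (the names and the encoding are illustrative choices, not a fixed interface) that evaluates the example $f(x_1, x_2) = \sin(x_1 x_2)$ as a chain of elementary operations $\varphi_i$ with predecessor sets $P_i$:

```python
import math

# Illustrative encoding of f(x1, x2) = sin(x1 * x2):
# node i maps to (P_i, phi_i), where phi_i is applied to the
# values of the direct predecessors listed in P_i.
graph = {
    3: ([1, 2], lambda a, b: a * b),     # v3 = v1 * v2
    4: ([3],    lambda a: math.sin(a)),  # v4 = sin(v3)
}

def evaluate(x1, x2):
    v = {1: x1, 2: x2}  # indices 1..n are reserved for the inputs
    for i, (preds, phi) in graph.items():
        v[i] = phi(*(v[j] for j in preds))
    return v[4]         # y1 = v4
```

Note that node $4$ receives only $v_3$ as an argument, mirroring $1 \notin P_4$.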
We can view the computation chain as a directed graph with vertices $1, \dots, N$ and edges $j \to i$ if $j \in P_i$. There are no cycles allowed in this graph (it is an acyclic graph) and it consists of $N$ vertices.
We write $d(i, j)$ for the length of the longest path from $i$ to $j$ and call that number the distance from $i$ to $j$. Note that this is not the usual definition of distance, which is normally the length of the shortest path.
If $j$ is not reachable from $i$ we set $d(i, j) = -\infty$. If $j$ is reachable from $i$ the distance is finite, since the graph is acyclic.
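This distance can be computed by a simple recursion over the predecessor sets. A minimal sketch, assuming the small example graph with $P_3 = \{1, 2\}$ and $P_4 = \{3\}$ (the memoized recursion terminates because the graph is acyclic):

```python
from functools import lru_cache

# Predecessor sets of the example graph; nodes 1 and 2 are inputs.
P = {1: set(), 2: set(), 3: {1, 2}, 4: {3}}

@lru_cache(maxsize=None)
def distance(i, j):
    """Length of the longest path from i to j, -inf if j is unreachable."""
    if i == j:
        return 0
    # a longest path to j extends a longest path to some direct predecessor
    best = max((distance(i, k) for k in P[j]), default=float("-inf"))
    return best + 1 if best != float("-inf") else float("-inf")
```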
We can compute a partial derivative using the chain rule

$$\frac{\partial v_i}{\partial x_k} = \sum_{j \in P_i} \frac{\partial \varphi_i}{\partial v_j} \frac{\partial v_j}{\partial x_k}.$$
This suggests a forward propagation scheme: We start at the initial nodes $1, \dots, n$. For all nodes $i$ with maximum distance $1$ from all of these nodes we compute

$$\dot v_i = \sum_{j \in P_i} \frac{\partial \varphi_i}{\partial v_j} \dot v_j,$$

where we can choose the values $\dot v_j$ for $j = 1, \dots, n$ freely at this stage. This assigns the dot product of the gradient of $v_i$ w.r.t. $x$ and the vector $\dot x = (\dot v_1, \dots, \dot v_n)$ to the node $i$.
If we choose $\dot v_k = 1$ for one specific $k$ and zero otherwise, we get the partial derivative of $v_i$ by $x_k$, but we can compute any other directional derivative using other vectors $\dot x$. (Remember that the directional derivative is the gradient times the direction w.r.t. which the derivative shall be computed.)
Next we consider nodes $i$ with maximum distance $2$ from all initial nodes. For such a node

$$\dot v_i = \sum_{j \in P_i} \frac{\partial \varphi_i}{\partial v_j} \dot v_j,$$

where we can assume that the $\dot v_j$ were computed in the previous step, because their maximum distance to all initial nodes must be less than $2$, hence at most $1$.
Also note that $\dot v_j = \partial v_j / \partial x_k$ if $j \leq n$, which may be the case, since we chose $\dot v_j = 1$ if $j = k$ and zero otherwise, so the relation holds trivially for initial nodes as well. Or seemingly trivially.
The same argument can be iterated for nodes with maximum distance $3, 4, \dots$ until we reach the final nodes $N - m + 1, \dots, N$. This way we can work forward through the computational graph and compute the directional derivative we seek.
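The whole forward sweep can be sketched in a few lines of Python. Everything here is an illustrative assumption rather than a fixed interface: the example function is again $f(x_1, x_2) = \sin(x_1 x_2)$, and each operation is stored with its partial derivatives so that the tangents $\dot v_i$ can be propagated alongside the values:

```python
import math

# (node i, predecessors P_i, phi_i, partial derivatives of phi_i),
# listed in an order compatible with the graph (sorted by distance).
ops = [
    (3, [1, 2], lambda a, b: a * b,    lambda a, b: (b, a)),
    (4, [3],    lambda a: math.sin(a), lambda a: (math.cos(a),)),
]

def forward_mode(x, x_dot):
    """Evaluate f and the directional derivative along x_dot."""
    v = {1: x[0], 2: x[1]}           # values v_i
    dv = {1: x_dot[0], 2: x_dot[1]}  # tangents dot v_i, seeded freely
    for i, preds, phi, dphi in ops:
        args = [v[j] for j in preds]
        v[i] = phi(*args)
        # dot v_i = sum over j in P_i of (dphi_i/dv_j) * dot v_j
        dv[i] = sum(p * dv[j] for p, j in zip(dphi(*args), preds))
    return v[4], dv[4]               # y1 and its directional derivative
```

Seeding `x_dot = (1, 0)` yields $\partial y_1 / \partial x_1$; one sweep per independent variable assembles the Jacobian column by column.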
In the backward mode we do very similar things, but in a dual way: We start at the final nodes $N - m + 1, \dots, N$, choose values $\bar v_j = \bar y_{j - (N - m)}$ for these nodes freely, and compute for all nodes $i$ with maximum distance $1$ to all of these nodes

$$\bar v_i = \sum_{j \,:\, i \in P_j} \frac{\partial \varphi_j}{\partial v_i} \bar v_j.$$
Note that we compute a weighted sum in the dependent variables now. By setting a specific $\bar y_k$ to $1$ and the rest to zero again we can compute the partial derivatives of a single final variable. Again using the chain rule we can compute

$$\bar v_i = \sum_{j \,:\, i \in P_j} \frac{\partial \varphi_j}{\partial v_i} \bar v_j$$

for all nodes $i$ with maximum distance of $2$ to all the final nodes.
Note that the chain rule formally requires us to include all variables whose computation $v_i$ enters. However, if $\varphi_j$ does not depend on $v_i$ directly, the whole term $\frac{\partial \varphi_j}{\partial v_i} \bar v_j$ will effectively be zero, so we can drop these summands from the beginning. Also we may include indices $j$ for which $\varphi_j$ does not depend on $v_i$ in the first place, which is not harmful for the same reason.
As above we can assume all $\bar v_j$ to have been computed in the previous step, so that we can iterate backwards to the initial nodes $1, \dots, n$ to get all partial derivatives of the weighted sum of the final nodes w.r.t. the initial nodes.
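A minimal sketch of the backward sweep, again assuming the example function $f(x_1, x_2) = \sin(x_1 x_2)$ and an illustrative tuple encoding of the operations: after a forward pass that records every $v_i$, the adjoint of the output is seeded and pushed back to the inputs.

```python
import math

# (node i, predecessors P_i, phi_i, partial derivatives of phi_i)
ops = [
    (3, [1, 2], lambda a, b: a * b,    lambda a, b: (b, a)),
    (4, [3],    lambda a: math.sin(a), lambda a: (math.cos(a),)),
]

def reverse_mode(x):
    """Evaluate f and its full gradient in one forward + one backward pass."""
    # forward pass: remember every intermediate value
    v = {1: x[0], 2: x[1]}
    for i, preds, phi, _ in ops:
        v[i] = phi(*(v[j] for j in preds))
    # backward pass: adjoints bar v_i, nodes visited in reverse order
    bar = {i: 0.0 for i in v}
    bar[4] = 1.0  # seed bar y_1 = 1
    for i, preds, _, dphi in reversed(ops):
        for p, j in zip(dphi(*(v[j] for j in preds)), preds):
            bar[j] += bar[i] * p  # bar v_j accumulates (dphi_i/dv_j) * bar v_i
    return v[4], (bar[1], bar[2])
```

One backward sweep yields all partial derivatives of a single (weighted combination of) output(s) at once, which is why this mode pays off when there are far fewer dependent than independent variables.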