UPDATE: This code is now available in both Java and Python!
I’ve been on an automatic differentiation kick ever since reading about dual numbers on Wikipedia.
I implemented a simple forward-mode autodiff system in Rust, thinking it would let me do ML faster. What I failed to realize (or read) is that forward-mode differentiation, while simpler, requires one forward pass to get the derivative of ALL outputs with respect to ONE input variable. Reverse mode, in contrast, gives you the derivative of ONE output with respect to ALL inputs in a single pass.
That is to say, if I had f(x, y, z) = [a, b, c], forward mode would give me da/dx, db/dx, dc/dx in a single pass. Reverse mode would give me da/dx, da/dy, da/dz in a single pass. Since training a model means differentiating one scalar loss with respect to many parameters, reverse mode is the one you actually want for ML.
Forward mode is really easy. I have a repo with the code here: https://github.com/JosephCatrambone/RustML
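To make the dual-number idea concrete, here's a rough Python sketch of forward mode (illustrative only, not the Rust code from the repo): each value carries its derivative along with it, and the arithmetic rules propagate both.

# A minimal dual-number sketch of forward mode (toy example, not the repo code).
class Dual:
    def __init__(self, value, derivative=0.0):
        self.value = value            # The ordinary value.
        self.derivative = derivative  # d(value)/d(seed variable).

    def __add__(self, other):
        # Sum rule: (u + v)' = u' + v'.
        return Dual(self.value + other.value, self.derivative + other.derivative)

    def __mul__(self, other):
        # Product rule: (u*v)' = u'*v + u*v'.
        return Dual(self.value * other.value,
                    self.derivative * other.value + self.value * other.derivative)

# Seed x with derivative 1 to differentiate with respect to x; y gets 0.
x = Dual(2.0, 1.0)
y = Dual(3.0, 0.0)
b = (x + y) * x
print(b.value, b.derivative)  # 10.0 and db/dx = 2x + y = 7.0

Note that to get db/dy you would have to re-run with y seeded to 1 instead; that's the one-pass-per-input cost described above.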
Reverse mode took me a while to figure out, mostly because I was confused about how adjoints worked. (The adjoint of a node is just the derivative of the final output with respect to that node's value.) I'm still confused, but I'm now so accustomed to the strangeness that I've stopped noticing it. Here's some simple, single-variable reverse-mode autodiff. It's about 100 lines of Python:
#!/usr/bin/env python
# JAD: Joseph's Automatic Differentiation
from collections import deque

class Graph(object):
    def __init__(self):
        self.names = list()
        self.operations = list()
        self.derivatives = list()  # A list of LISTS, where each item is the gradient with respect to that argument.
        self.node_inputs = list()  # A list of the indices of the input nodes.
        self.shapes = list()
        self.graph_inputs = list()
        self.forward = list()  # Cleared on forward pass.
        self.adjoint = list()  # Cleared on reverse pass.

    def get_output(self, input_set, node=-1):
        self.forward = list()
        for i, op in enumerate(self.operations):
            self.forward.append(op(input_set))
        return self.forward[node]
    def get_gradient(self, input_set, node, forward_data=None):
        if forward_data is not None:
            self.forward = forward_data
        else:
            self.forward = list()
            for i, op in enumerate(self.operations):
                self.forward.append(op(input_set))
        # Initialize adjoints to 0 except our target, which is 1.
        self.adjoint = [0.0]*len(self.forward)
        self.adjoint[node] = 1.0
        gradient_stack = deque()
        for arg_index, input_node in enumerate(self.node_inputs[node]):
            gradient_stack.append((input_node, node, arg_index))  # Keep (child, parent, which argument of the parent).
        while gradient_stack:  # While not empty.
            current_node, parent_node, arg_index = gradient_stack.popleft()
            # Chain rule: d(target)/d(current) += d(target)/d(parent) * d(parent)/d(current).
            dop = self.derivatives[parent_node][arg_index]
            self.adjoint[current_node] += self.adjoint[parent_node]*dop(input_set)
            # This simple breadth-first walk assumes a parent's adjoint is complete
            # before its edges are popped, which holds for small graphs like the example below.
            for child_arg, input_arg in enumerate(self.node_inputs[current_node]):
                gradient_stack.append((input_arg, current_node, child_arg))
        return self.adjoint
    def get_shape(self, node):
        return self.shapes[node]

    def add_input(self, name, shape):
        index = len(self.names)
        self.names.append(name)
        self.operations.append(lambda inputs: inputs[name])
        self.derivatives.append([lambda inputs: 1])
        self.node_inputs.append([])
        self.graph_inputs.append(index)
        self.shapes.append(shape)
        return index
    def add_add(self, name, left, right):
        index = len(self.names)
        self.names.append(name)
        self.operations.append(lambda inputs: self.forward[left] + self.forward[right])
        self.derivatives.append([lambda inputs: 1, lambda inputs: 1])  # d/da (a+b) = 1 and d/db (a+b) = 1.
        self.node_inputs.append([left, right])
        self.shapes.append(self.get_shape(left))
        return index

    def add_multiply(self, name, left, right):
        index = len(self.names)
        self.names.append(name)
        self.operations.append(lambda inputs: self.forward[left] * self.forward[right])
        self.derivatives.append([lambda inputs: self.forward[right], lambda inputs: self.forward[left]])  # d/da (a*b) = b and d/db (a*b) = a.
        self.node_inputs.append([left, right])
        self.shapes.append(self.get_shape(left))
        return index
if __name__=="__main__":
    g = Graph()
    x = g.add_input("x", (1, 1))
    y = g.add_input("y", (1, 1))
    a = g.add_add("a", x, y)
    b = g.add_multiply("b", a, x)
    input_map = {'x': 2, 'y': 3}
    print(g.get_output(input_map))  # (2 + 3) * 2 = 10
    print(g.get_gradient(input_map, b))  # [7.0, 2.0, 2.0, 1.0] -> db/dx, db/dy, db/da, db/db.
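As a quick sanity check (not part of the listing above), the adjoints agree with doing the calculus by hand: b = (x + y)*x = x^2 + xy, so db/dx = 2x + y = 7 and db/dy = x = 2 at x=2, y=3. A throwaway finite-difference helper, run after the main block above, confirms the same thing numerically:

# Hypothetical check against central finite differences; not part of the graph code itself.
def finite_difference(graph, inputs, wrt, eps=1e-6):
    up = dict(inputs)
    down = dict(inputs)
    up[wrt] += eps
    down[wrt] -= eps
    return (graph.get_output(up) - graph.get_output(down)) / (2 * eps)

print(finite_difference(g, input_map, 'x'))  # ~7.0, matches the adjoint for x.
print(finite_difference(g, input_map, 'y'))  # ~2.0, matches the adjoint for y.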