DAGWorks-Inc / hamilton

Hamilton helps data scientists and engineers define testable, modular, self-documenting dataflows that encode lineage/tracing and metadata. Runs and scales everywhere Python does.
https://hamilton.dagworks.io/en/latest/
BSD 3-Clause Clear License

provide examples of back-propagation #1137

Open vograno opened 2 months ago

vograno commented 2 months ago

A DAG defines a forward flow of information. At the same time, it implies a back-flow of information: reverse each edge of the DAG and associate a back-flow function with each node. Additionally, we need a merging function that combines the back-flow inputs that attach to a common forward output, but this is a technicality.

Gradient descent on a feed-forward neural network is an example of such a back-flow: the forward pass computes node outputs and local gradients given the model parameters, while the backward pass updates the model parameters according to the computed gradients. I think the merging function is summation in this case.

The question is then whether Hamilton is an appropriate framework for inferring the back-flow DAG from the forward one. Here, inferring means computing the back-flow driver given the forward-flow one.

Use gradient descent as a case study.
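
To make the reversed-edge picture concrete, here is a minimal, framework-free sketch (plain Python; the node names and functions are made up for illustration). One node x feeds two downstream nodes, and the two backward contributions that land on x are merged by summation:

```python
# Forward DAG:  x -> a -> loss
#               x -> b -> loss
def a(x: float) -> float:
    return 2.0 * x

def b(x: float) -> float:
    return x ** 2

def loss(a: float, b: float) -> float:
    return a + b

# Implied back-flow: reverse each edge and attach a backward function to each
# node. Each backward function maps the incoming back-flow value to one
# contribution per forward input (here, upstream value times local derivative).
def loss_backward(upstream: float) -> dict:
    return {"a": upstream, "b": upstream}

def a_backward(upstream: float) -> dict:
    return {"x": upstream * 2.0}        # d(a)/d(x) = 2

def b_backward(upstream: float, x: float) -> dict:
    return {"x": upstream * 2.0 * x}    # d(b)/d(x) = 2x

# Merging function: contributions landing on the same forward node are summed.
x = 3.0
up = loss_backward(1.0)
grad_x = a_backward(up["a"])["x"] + b_backward(up["b"], x)["x"]
print(grad_x)  # 8.0, which matches d(2x + x^2)/dx at x = 3
```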

skrawcz commented 2 months ago

@vograno haven't tried it. The internal FunctionGraph does have the bi-directional linking, so the building blocks are there. To me it sounds like you'd want to change a bit of how the graph is walked and what state is stored where, e.g. via a new Driver.
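
For reference, a rough sketch of how those bi-directional links could be inspected. This leans on the internal, non-public API (a `graph` attribute on the Driver holding a `FunctionGraph` with a `nodes` dict, and each `Node` exposing `dependencies`/`depended_on_by`); attribute names may change between versions, and `my_forward_module` is a placeholder:

```python
from hamilton import driver
import my_forward_module  # placeholder: the module holding the forward dataflow

dr = driver.Builder().with_modules(my_forward_module).build()

# Internal FunctionGraph: every node links to its inputs (dependencies) and to
# the nodes that consume it (depended_on_by), so it can be walked either way.
fg = dr.graph
for n in fg.nodes.values():
    print(
        n.name,
        "<-", [d.name for d in n.dependencies],
        "->", [d.name for d in n.depended_on_by],
    )
```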

elijahbenizzy commented 2 months ago

Adding to what @skrawcz said:

Heh, fascinating idea. I've had a similar idea but never really pursued it. Some thoughts (feel free to disagree on these points!):

As @skrawcz said -- it would also require rewalking the graph, at least backwards.

What I'd do, if you're interested, is first build a simple 2-node graph out of PyTorch objects, which already have the forward/backward relationship, and POC it out from there. You can compute gradients individually, then expand to more ways of specifying nodes/gradients.
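
One possible shape for that POC (a sketch only; the node functions are made up, and `ad_hoc_utils.create_temporary_module` is just a convenience for turning functions into a module): have the Hamilton nodes pass around torch tensors with `requires_grad=True`, so the forward run of the Driver also builds the autograd graph, and the backward pass falls out of calling `.backward()` on the result.

```python
import torch
from hamilton import ad_hoc_utils, driver

def weights() -> torch.Tensor:
    # leaf tensor whose gradient we want
    return torch.tensor([1.0, 2.0], requires_grad=True)

def hidden(weights: torch.Tensor) -> torch.Tensor:
    return weights * 3.0

def loss(hidden: torch.Tensor) -> torch.Tensor:
    return hidden.sum()

temp_module = ad_hoc_utils.create_temporary_module(weights, hidden, loss)
dr = driver.Builder().with_modules(temp_module).build()

results = dr.execute(["weights", "loss"])   # forward pass through the DAG
results["loss"].backward()                  # backward pass via torch autograd
print(results["weights"].grad)              # tensor([3., 3.])
```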

I'd also consider other optimization routines if the goal is flexibility!

vograno commented 2 months ago

... The internal FunctionGraph does have the bi-directional linking, so the building blocks are there. To me it sounds like you'd want to change a bit of how the graph is walked and what state is stored where, e.g. via a new Driver.

I can walk the graph backward all right, but I also need to create new nodes for the back-flow graph along the way, and this is where I'm not sure. I can see two options:

Option 1. Generate the back-flow node functions in a temporary module and build the back-flow driver from that module with the standard Builder (sketched below, after Option 2).

Option 2. Start the back-flow graph empty and add nodes to it as I traverse the forward graph. Here I need to create nodes outside the Builder and I'm not sure what API to use.
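
Here is what Option 1 might look like mechanically, assuming the back-flow functions have already been generated as plain Python functions (all names below are invented for illustration): wrap them into a temporary module and hand that module to an ordinary Builder to obtain the back-flow driver.

```python
from hamilton import ad_hoc_utils, driver

# Invented back-flow functions for a 3-node forward chain weights -> hidden -> loss.
def loss_back(seed: float) -> float:
    # entry point of the back-flow: the value injected at the former sink
    return seed

def hidden_back(loss_back: float) -> float:
    return loss_back * 2.0

def weights_back(hidden_back: float) -> float:
    return hidden_back

# Wrap the generated functions into a module and build the back-flow driver
# with the standard Builder; no node-level construction API is needed.
backflow_module = ad_hoc_utils.create_temporary_module(loss_back, hidden_back, weights_back)
back_dr = driver.Builder().with_modules(backflow_module).build()
print(back_dr.execute(["weights_back"], inputs={"seed": 1.0}))  # {'weights_back': 2.0}
```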

vograno commented 2 months ago

Let me propose a less fascinating, but simpler example to work with.

  1. There are two kinds of nodes, Red and Blue.
  2. Let y denote the node output, and (x_1, ..., x_n) denote its inputs.
  3. The forward function of every node is just the sum of its inputs, regardless of the node's color, i.e. y = sum(x_i).
  4. The backward function of a Red node sends its entire back-flow input to its first input node and zero to all the others, i.e. x_1 = y, x_2 = 0, ..., x_n = 0.
  5. The backward function of a Blue node splits its back-flow input equally among all its input nodes, i.e. x_i = y/n for every i.
  6. The merging function that combines the back-flow values attached to a common upstream node is the sum.

The goal is to compute the back-flow driver given the forward-flow one.
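
To anchor the discussion, here is a sketch of the forward side only, written as ordinary Hamilton functions (the node names, the DAG shape, and the use of @tag to record the color are my own choices for illustration); deriving the back-flow driver from it is exactly the open question.

```python
from hamilton import ad_hoc_utils, driver
from hamilton.function_modifiers import tag

# Forward side of the Red/Blue example; x1 and x2 are external inputs.
@tag(color="red")
def red_node(x1: float, x2: float) -> float:
    return x1 + x2            # rule 3: every forward function sums its inputs

@tag(color="blue")
def blue_node(red_node: float, x2: float) -> float:
    return red_node + x2

forward_module = ad_hoc_utils.create_temporary_module(red_node, blue_node)
fwd_dr = driver.Builder().with_modules(forward_module).build()
print(fwd_dr.execute(["blue_node"], inputs={"x1": 1.0, "x2": 2.0}))  # {'blue_node': 5.0}

# The back-flow driver to be derived would, for a seed value y injected at
# blue_node, produce (by rules 4-6 above):
#   blue_node (Blue) splits y equally:  red_node <- y/2, x2 <- y/2
#   red_node  (Red) sends all to x1:    x1 <- y/2, x2 <- 0
#   merge by sum:                       x1 = y/2,  x2 = y/2
```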

skrawcz commented 2 months ago

@vograno sounds cool. Unfortunately we're a little bandwidth-constrained at the moment, so we can't be as helpful as we'd like; just wanted to mention that. Do continue to add context/questions here, and we'll try to help when we can.