gussmith23 / glenside

A pure, low-level tensor program representation enabling tensor program optimization via program rewriting. See the web demo at https://gussmith23.github.io/glenside-web-demo/

Implementing batch norms in hardware #46

Open gussmith23 opened 3 years ago

gussmith23 commented 3 years ago

cc @stdavids

I know I'm a bit behind on this, but we're finally ready to start looking at implementing batch norms from the Glenside side of things. We can talk about it in the hackathon. From the Glenside perspective, this involves:

The interesting thing about batch norms in our workloads is that they aren't a single "batch norm" operator; instead, they appear as a chain of simpler primitive operators implementing the linear transformation that batch norm represents at inference time. That's because I ran Relay's SimplifyInference pass over the workloads before importing them, which is the habitual thing to do when working with workloads in Relay. However, it might be easier for Glenside to take in opaque, un-simplified batch norm operators; it could then implement its own equivalent of SimplifyInference and create a "batch norm inference" node. Currently that isn't possible, since the batch norm nodes have already been blown up into smaller operators on the Relay side. I'm not sure this will be necessary, though -- I'm hoping we can just enable Glenside to fold the computation back together into some efficient vectorized format.
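To make the "chain of simpler primitive operators" concrete, here's a NumPy sketch of what batch norm reduces to at inference time. The variable names and shapes are mine, not Relay's or Glenside's; the point is just that the running statistics are constants, so the whole thing folds into one per-channel multiply and one add:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))     # activations, shape (N, C) -- hypothetical
gamma = rng.standard_normal(8)      # learned per-channel scale
beta = rng.standard_normal(8)       # learned per-channel shift
mean = rng.standard_normal(8)       # running mean (a constant at inference)
var = rng.random(8) + 0.5           # running variance (a constant at inference)
eps = 1e-5

# Batch norm as usually written, evaluated at inference time:
y_full = gamma * (x - mean) / np.sqrt(var + eps) + beta

# What a SimplifyInference-style rewrite exposes: the per-channel constants
# can be folded ahead of time, leaving a multiply-add chain over x.
scale = gamma / np.sqrt(var + eps)
shift = beta - mean * scale
y_folded = x * scale + shift

assert np.allclose(y_full, y_folded)
```

Folding the computation "back together" would mean recognizing that multiply-add chain and mapping it onto one vectorized scale-and-shift kernel.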

gussmith23 commented 3 years ago

Joseph has a kernel for batch norms; take a look at that. That's a good first step. We may want to collapse batch norm back into its non-inference version. I can also write a Glenside rewrite that detects the inference-time batch norm pattern.
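A rewrite that detects the inference-time pattern might look roughly like this toy sketch (Python for brevity; Glenside's actual rewrites operate over its Rust e-graph representation, and every node name below is hypothetical, not a real Glenside constructor):

```python
from dataclasses import dataclass

# A toy expression tree standing in for the program representation.
@dataclass
class Var:
    name: str

@dataclass
class Const:
    name: str

@dataclass
class Mul:
    lhs: object
    rhs: object

@dataclass
class Add:
    lhs: object
    rhs: object

@dataclass
class BatchNormInference:
    data: object
    scale: object
    shift: object

def fold_batch_norm(expr):
    """Rewrite (add (mul x scale) shift) into a single fused node."""
    if isinstance(expr, Add) and isinstance(expr.lhs, Mul):
        return BatchNormInference(expr.lhs.lhs, expr.lhs.rhs, expr.rhs)
    return expr  # no match: leave the expression alone

# The multiply-add chain that SimplifyInference leaves behind...
chain = Add(Mul(Var("x"), Const("scale")), Const("shift"))
# ...collapses into one fused "batch norm inference" node.
fused = fold_batch_norm(chain)
assert isinstance(fused, BatchNormInference)
```

In the real system this would be an e-graph rewrite rule rather than a one-shot tree transformation, so the fused and unfused forms can coexist and extraction can pick whichever maps best onto the hardware kernel.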

gussmith23 commented 3 years ago

We should just replace the batch norm operator chain with a single "batch norm" call.

gussmith23 commented 3 years ago

Here's my plan: