coreylowman opened this issue 1 year ago

Let's discuss how operator fusion might work in dfdx. I suspect it will require a lot of work. On the CUDA side of things it will at least require JIT compiling kernels.

Originally posted by @jafioti in https://github.com/coreylowman/dfdx/issues/590#issuecomment-1482930062
Does anyone know how other frameworks do it? My naive first thought is to add some `Vec<Fn>` to `Tensor` that indicates elementwise operations to lazily apply to the tensor. Then kernels that can't be appended to this lazy list would have to apply all the functions currently in the list while they run.
I have some ideas for this, but they would require a pretty major overhaul to tensor_ops, and some changes to how tensors are represented.
I think that we should target folding for reshapes (stack, concat, slice, permute, select, gather), unary operations, binary operations, and reductions. Effectively planning folds for these operations requires access to the full graph of tensor operations, and for tensors to only sometimes contain their data. I do not think we can reasonably fold convolutions or matrix multiplications, as each requires calling external kernels. In this scheme, kernels would consist of four phases.

On the CPU, folded operations can be represented with Rust structs/enums that contain information on all reshape operations and all mathematical operations, with the mathematical operations represented something like:
```rust
struct BinaryOperation {
    input_register1: usize,
    input_register2: usize,
    output_register: usize,
    op: Box<dyn Fn(f32, f32) -> f32>, // or &'static dyn Fn(f32, f32) -> f32
    cuda_representation: String,
}
```
or with a specialized trait. This format should be designed so it can be directly translated into a CUDA kernel and JIT compiled, or executed on the CPU.
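For illustration, here is a rough sketch (hypothetical, not an actual dfdx API) of both halves: interpreting a folded op list directly on the CPU, and stitching the per-op CUDA snippets into a single fused kernel body for JIT compilation.

```rust
// Hypothetical sketch: the struct from above, repeated for completeness.
struct BinaryOperation {
    input_register1: usize,
    input_register2: usize,
    output_register: usize,
    op: Box<dyn Fn(f32, f32) -> f32>,
    cuda_representation: String, // e.g. "r2 = r0 + r1;"
}

/// Run the folded ops directly on a register file of f32 values.
fn eval_on_cpu(ops: &[BinaryOperation], registers: &mut [f32]) {
    for op in ops {
        registers[op.output_register] =
            (op.op)(registers[op.input_register1], registers[op.input_register2]);
    }
}

/// Concatenate the per-op CUDA snippets into one fused kernel body.
fn emit_cuda_body(ops: &[BinaryOperation]) -> String {
    ops.iter()
        .map(|op| op.cuda_representation.as_str())
        .collect::<Vec<_>>()
        .join("\n")
}
```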
Blocking questions: how would we compute

```rust
let tensor = ...;
let ln_tensor = tensor.clone().ln();
let exp_tensor = tensor.exp();
```

with a single kernel, which requires `exp_tensor` to modify the data of `ln_tensor`?
> folding for reshapes (stack, concat, slice, permute, select, gather), unary operations, binary operations, and reductions

I was imagining something simpler where we just fuse unary operations. Anything that doesn't fuse (binary ops/data movements/matmul/conv) would have to be passed a list of functions to call on input data before using it.
Do you know what other frameworks fuse? Each op we could fuse would save on allocation/data movement, but a whole rewrite of tensor ops seems like a big price to pay.

I also want to add to the discussion: when should we not fuse? If a tensor is only used once, it makes sense to fuse, and this does save computation time. However, if a tensor is re-used, does fusing cost twice as much computation?
How do we represent the graph of operations within tensors? I would like to be able to execute the following operations. In my head we would have:
```rust
struct Tensor {
    // ...
    lazy_fns: Arc<Vec<D::Func>>,
}
```
After `tensor.ln()`, the `lazy_fns` would look like `[..., "ln"]`. Then after `tensor.ln().exp()`, the `lazy_fns` would look like `[..., "ln", "exp"]`.

Interestingly enough, ln/exp are opposites so we could actually just remove these two 😁 but for the sake of example, whatever does the fusing would compile this to `exp(ln(x))`.

If we had a case of the tensor being used multiple times (e.g. `tensor.ln()` was also used for another operation like `tensor.ln().sqrt()`), this would be handled the same way we handle it now, where the resulting tensor's `lazy_fns` would look like `[..., "ln", "sqrt"]`.
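For concreteness, a minimal sketch of what this could look like, assuming a hypothetical `LazyTensor` rather than the real dfdx `Tensor`: unary ops are pushed onto `lazy_fns` instead of being executed, and a non-fusable kernel flushes them in one pass before touching the data.

```rust
use std::sync::Arc;

// Hypothetical sketch, not the actual dfdx Tensor type.
#[derive(Clone)]
struct LazyTensor {
    data: Arc<Vec<f32>>,
    // Unary functions recorded but not yet applied to `data`.
    lazy_fns: Vec<Arc<dyn Fn(f32) -> f32 + Send + Sync>>,
}

impl LazyTensor {
    fn ln(mut self) -> Self {
        self.lazy_fns.push(Arc::new(f32::ln));
        self
    }

    fn exp(mut self) -> Self {
        self.lazy_fns.push(Arc::new(f32::exp));
        self
    }

    /// Fold all pending unary ops into the data in a single traversal.
    /// A kernel that can't be appended to the lazy list (matmul, conv, ...)
    /// would call this before reading `data`.
    fn materialize(&mut self) {
        if self.lazy_fns.is_empty() {
            return;
        }
        let fns = std::mem::take(&mut self.lazy_fns);
        let fused: Vec<f32> = self
            .data
            .iter()
            .map(|&x| {
                let mut acc = x;
                for f in &fns {
                    acc = f(acc);
                }
                acc
            })
            .collect();
        self.data = Arc::new(fused);
    }
}
```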
What about using something similar to Rust's `Iterator`? Each operation could create a new type and the evaluation only happens once it is needed. The main difference from @coreylowman's idea is that the operations are stored in the types instead of in a `Vec` at runtime.

We could even go as far as to do various reductions such as the one you mentioned (`ln` followed by `exp`). The `ln` function on the `ExpOperation` type could reduce to an `Identity` or something like that.
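A minimal sketch of the type-based idea, with hypothetical `Value`/`Exp`/`Ln` types standing in for real tensor ops; the inherent `ln` on `Exp<T>` is one way to express the `exp`-then-`ln` reduction:

```rust
// Each op wraps the previous one in a new type; nothing runs until `eval`.
trait LazyOp: Sized {
    fn eval(&self) -> f32;
    fn exp(self) -> Exp<Self> {
        Exp(self)
    }
    fn ln(self) -> Ln<Self> {
        Ln(self)
    }
}

struct Value(f32);
struct Exp<T>(T);
struct Ln<T>(T);

impl LazyOp for Value {
    fn eval(&self) -> f32 {
        self.0
    }
}
impl<T: LazyOp> LazyOp for Exp<T> {
    fn eval(&self) -> f32 {
        self.0.eval().exp()
    }
}
impl<T: LazyOp> LazyOp for Ln<T> {
    fn eval(&self) -> f32 {
        self.0.eval().ln()
    }
}

// The reduction: an inherent `ln` on `Exp<T>` shadows the trait method and
// simply unwraps to the inner `T` instead of building `Ln<Exp<T>>`.
impl<T: LazyOp> Exp<T> {
    fn ln(self) -> T {
        self.0
    }
}

fn main() {
    let x = Value(2.0).exp().ln(); // reduces back to `Value`
    assert_eq!(x.eval(), 2.0);
}
```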
Another topic related to this that is worth discussing is CUDA Graphs #360. If we go with this lazy evaluation pattern then we could also generate CUDA graphs automatically (reduce overhead of calling individual kernels).
The type-based system would need to be at compile time, right? That's how rust's iterators work.
Would the vec-based system be implementable as a special tape? Since all operations are recorded on the tape, it could build up this vec of operations during some initial pass, then combine the eligible ops, generate the kernels, and then produce a new forward function?
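A rough sketch of that tape idea under those assumptions (hypothetical `FusingTape`, not dfdx's actual tape types): the tracing pass records op names, and a planning step groups consecutive fusable ops so each group could become one generated kernel in the new forward function.

```rust
// Hypothetical sketch of a fusing tape.
#[derive(Default)]
struct FusingTape {
    recorded: Vec<String>, // e.g. ["ln", "exp", "matmul", "relu"]
}

impl FusingTape {
    fn record(&mut self, op: &str) {
        self.recorded.push(op.to_string());
    }

    /// Split the recorded ops into groups; consecutive fusable ops share a
    /// group, and each group would be compiled into a single kernel.
    fn plan(&self, is_fusable: impl Fn(&str) -> bool) -> Vec<Vec<String>> {
        let mut groups: Vec<Vec<String>> = Vec::new();
        for op in &self.recorded {
            let extend_last = is_fusable(op)
                && groups
                    .last()
                    .map_or(false, |g| g.iter().all(|o| is_fusable(o)));
            if extend_last {
                groups.last_mut().unwrap().push(op.clone());
            } else {
                groups.push(vec![op.clone()]);
            }
        }
        groups
    }
}
```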
Another idea I'll add to the mix is adding some wrapper type around the device, like `Fusing<Cpu>`, where we could change the device storage:
```rust
struct FusedVec<E> {
    data: Vec<E>,
    ops: Vec<Box<dyn Fn()>>,
}

impl<E> DeviceStorage<E> for Fusing<Cpu> {
    type Storage = FusedVec<E>;
}

// ...

let dev: Fusing<Cpu> = Default::default();
```
This would require moving away from the GATs that were introduced in #633.
For future reference: https://live.juliacon.org/talk/9RFTHY, https://github.com/PumasAI/SimpleChains.jl
> What about using something similar to Rust's `Iterator`? Each operation could create a new type and the evaluation only happens once it is needed. Main difference from @coreylowman's idea is that the operations are stored in the types instead of a `Vec` at runtime.
>
> We could even go as far as to do various reductions such as the one you mentioned (`ln` followed by `exp`). The `ln` function on the `ExpOperation` type could reduce to an `Identity` or something like that.
>
> Another topic related to this that is worth discussing is CUDA Graphs #360. If we go with this lazy evaluation pattern then we could also generate CUDA graphs automatically (reduce overhead of calling individual kernels).
I gener(ic)ally like this idea, but reducing something like `exp().ln()` is not as easy as with the dynamic approach. Rust has some limitations on multiple implementations for one trait, so implementing special logic for `Exp<Ln>` won't work unless we check that at runtime with `type_id()`s. Using generics also makes it difficult to find "more complex" reductions, like `(exp() - 1).ln()`...
I think how it would be done is with some trait like `Apply` with an associated type `Applied`, where for non-simplifying operations `Applied` is just a wrapper, but for simplifying cases you'd have something like `<Ln as Apply>::Applied<Exp<T>> = T`. I don't know how selective you can get with trait impls, since you would need to implement something for `T != Exp`, for example.
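A rough sketch of how that `Apply` trait could look (hypothetical names, and with the input as a trait parameter rather than a generic associated type); the enumerated impls compile, but the coherence issue described above appears as soon as you want a blanket impl for "everything that is not `Exp`":

```rust
// Hypothetical sketch: LnOp applied to Exp<T> simplifies back to T, while
// applied to a plain value it just wraps it. Without specialization there is
// no way to write one blanket impl covering "every T that is not Exp", so
// every non-simplifying case has to be listed explicitly.
struct Value(f32);
struct Exp<T>(T);
struct Ln<T>(T);

struct LnOp;

trait Apply<Input> {
    type Applied;
    fn apply(input: Input) -> Self::Applied;
}

// Simplifying case: ln(exp(x)) reduces to x.
impl<T> Apply<Exp<T>> for LnOp {
    type Applied = T;
    fn apply(input: Exp<T>) -> T {
        input.0
    }
}

// Non-simplifying case: ln of a plain value is just wrapped.
impl Apply<Value> for LnOp {
    type Applied = Ln<Value>;
    fn apply(input: Value) -> Ln<Value> {
        Ln(input)
    }
}
```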
I'll throw in an idea I've been thinking about when reading through tinygrad code.
There's essentially a spectrum of ways to run DL computation.
On one end, there's Pytorch (or at least Pytorch 1.x, and currently dfdx), where everything is eager. When `z = x + y` is interpreted, the `Add` kernel is fired off, and the execution halts until that result comes back. Then it moves on to the next line. Super straightforward, works exactly how any programmer would expect, and is really easy to debug. Print statements can be thrown in anywhere.
On the opposite end is Tensorflow 1.x, where everything is fully static. The entire network gets built as a huge DAG of operations, and nothing gets ran until the model is compiled and executed. This means when you write the operations in the model, they can be reordered, changed, or entirely deleted so long as the end behavior is the same. This allows the TF compiler to work with the network at a global level, and understand every single thing going on all at once. Of course, this means the limit to optimization is the power of the compiler and the creativity of the people programming it. It results in the fastest models with the most aggressive optimization. The downside is that it's really hard to debug, as no prints can be put in the middle, operations aren't straightforward, and the network is difficult to program, with things like `tf.Session`. This is sort of what led to the downfall of TF (requiring a bunch of other APIs because no one understood static graphs, hence fragmentation).
In tinygrad, the goal is to compute everything lazily, basically only run the computations when the data is actually needed. Which means that when `z = x + y` is interpreted, it doesn't actually fire that kernel off, but rather just tracks the operation and moves on. Only when z is actually used (or, more likely, something else that derives from z) is the graph computed, optimized, and ran. That way fancy fusions can be done at runtime. However, if needed, z can just be printed immediately after, in which case the computation will be ran, same as pytorch. In this approach, you get the best of both worlds: fast execution with aggressive optimizations when no debug / dynamic control flow is needed and only primitive operations are used, but eager dynamism when needed.
However, in typical python fashion, this is all super implicit and handled behind the scenes. If one dynamic line is added somewhere deep down in the module tree, a potentially very large graph that could have been well optimized gets split, perf goes down, and it's hard to understand why without going through every line of code.
I think this can be done in an explicit (albeit less developer friendly) way. We can still keep the current Tensor with the eager operations, which means all modules can be directly ported over. But for modules that have a high performance cost or get ran often, we can instead define a local graph:
```rust
/// Feedforward layer for a transformer
struct FF<I, M, O> {
    lin1: Linear<I, M>,
    lin2: Linear<M, O>,
}
```
```rust
/// Typical eager forward
fn forward(&self, x: Tensor<I>) -> Tensor<O> {
    let mut mid = self.lin1.forward(x);
    mid = mid.relu();
    self.lin2.forward(mid)
}
```
```rust
/// Graph-ified forward
fn forward(&self, x: Tensor<I>) -> Tensor<O> {
    let graph = Graph::new(x)
        .apply(|t| self.lin1.forward(t))
        .relu()
        .apply(|t| self.lin2.forward(t));
    graph.compute()
}
```
In this second forward, we create a graph object, which then wraps the tensor and goes through the same ops, only now we don't actually execute them, only track them. This API still preserves the type-safety / tensor shape safety, but the computation only happens when compute is called. At that point, the graph is optimized and ran (and the optimized graph can be cached).
We can go much further and allow the forward function to take in something that turns into a Graph, which would be either a graph or a tensor, and output a graph so that this module can be part of a larger graph, rather than only the small graph we demonstrated above:
```rust
/// Maybe part of a larger graph!
fn forward<G: Into<Graph<Tensor<I>>>>(&self, input: G) -> Graph<Tensor<O>> {
    let graph = input.into(); // If `input` is already a graph, this is a no-op
    let graph = graph
        .apply(|t| self.lin1.forward(t))
        .relu()
        .apply(|t| self.lin2.forward(t));
    graph // Notice no compute here; we're not doing any computation, just passing it back to the caller!
}
```
Apologies for the wall of text, just a view of what I think the future of DL libs will look like. I think dfdx is uniquely positioned to take the lead in perf if these graphs can be optimized well enough. Safe and fast!
@coreylowman I know you're busy but when you get a chance would love to hear some thoughts
Another idea (that fails to address the "automatic" fusing problem, but is probably simpler) would be to implement "manual" fusing: e.g. a closure composed of a mathematical expression that either evaluates to a value or to a source string usable in JIT compilation.
```rust
let f = |x: Resolve<f32>, y: Resolve<f32>| x.add(y).mul(3.6).sub(y);

let a = f(4f32.to_val(), 3f32.to_val());
assert_eq!(a.eval(), 22.2);

let src = f("x".to_marker(), "y".to_marker()).to_cl_source();
assert_eq!(src, "(((x + y) * 3.6) - y)");
```
The user would then need to specify the operations applied to the tensor in a similar closure. I implemented this approach in custos: https://github.com/elftausend/custos/blob/main/src/two_way_ops/mod.rs
I've been working on a DL library that does fully static computation, which allows it to do aggressive fusion / compilation before running: https://github.com/jafioti/luminal
Llama now runs on it!
The approach I took is pretty incompatible with dfdx, which relies on eager execution, but it might be another useful approach to look at.
> Llama now runs on it!
Awesome work! Any benchmarks to report? I'm super curious what the performance benefits of this are.
I'm leaning towards not approaching this because it's very complex, and as noted in other places, writing custom fused kernels is pretty standard these days, and lets you take advantage of things that automatic fusion maybe won't be able to (see flash attention & flash attention 2).
@coreylowman Perf right now is absolute dogwater because not much in the way of fusion optimizers has been written yet. I've been trying to speedrun achieving llama with the smallest set of primitive operators possible (11!)
Next step will be to write optimizers to start fusing the primitive graph down to be reasonably fast. Good news is that this should be super straightforward, since optimizers take a global view of the graph, and so optimizations can pretty much go as far as your imagination takes you. Also, since optimizers just take in a global graph and mutate it, both manual kernels and automatic fusion are possible. In a week or so I expect to have some decent benchmarks.
Keep us posted 👍