nvukobratTT opened this issue 4 months ago
I think this section should be edited:
tt_model = pybuda.compile(framework_model.to("tt"))
tt_optimizer = pybuda.compile(framework_optimizer.to("tt"))
tt_loss = pybuda.compile(framework_loss.to("tt"))
To just:
tt_model = pybuda.compile(framework_model)
tt_optimizer = pybuda.compile(framework_optimizer)
tt_loss = pybuda.compile(framework_loss)
And then when the respective module is ready to be run on the device you have to do:
tt_model.to("tt")
What .to("tt") does on a module is prepare/copy the weights to device DRAM. It could be legal to run:
tt_model.to("tt")
tt_loss.to("tt")
But what this means is that all of the parameters for tt_model and tt_loss will be copied to device DRAM and live there simultaneously.
We should try and stick to these semantics as closely as we can: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.to
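A minimal sketch of the proposed flow (assuming the pybuda.compile entry point and the .to("tt") semantics described above; the concrete model/loss instances are just placeholders):

import torch
import pybuda  # usage below follows the proposal in this comment, not a shipped API

framework_model = torch.nn.Linear(784, 10)
framework_loss = torch.nn.CrossEntropyLoss()

# Compile each component independently; nothing is placed on the device yet.
tt_model = pybuda.compile(framework_model)
tt_loss = pybuda.compile(framework_loss)

# Moving a compiled module to "tt" prepares/copies its weights into device DRAM,
# mirroring torch.nn.Module.to() semantics.
tt_model.to("tt")
tt_loss.to("tt")  # legal, but now both parameter sets live in device DRAM at once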
Nikola, this looks great! Once we have a POC it'd be great to turn this into a documentation spec for training.
I think an R&D effort needs to try and see how we can hook:
loss.backward() # executes bwd graph
tt_optimizer.step() # executes optimizer graph
An additional challenge, if loss is compiled separately (regardless of whether it's scheduled on CPU or TT), is how we propagate gradients between graphs. At first glance it seems like it'd just be a matter of wiring up loss backward outputs -> model backward inputs, but there are likely some finicky details to work out. For example, when loss.backward() is called it implicitly runs tt_model backward, so I think the runtime will need to keep track of this and schedule it automatically.
Optimizer step is essentially the same issue, but it feels slightly more explicit since the outputs are given up front and it's somewhat decoupled from the loss and model graphs.
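For reference, plain PyTorch behaves exactly like this today, which is the behavior a split-graph runtime would need to reproduce; the model and loss below are just stand-ins:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
loss_fn = nn.MSELoss()

loss = loss_fn(model(torch.randn(8, 4)), torch.randn(8, 2))
loss.backward()  # a single call on the loss also runs the model's backward

# Parameter gradients are populated even though backward() was only
# invoked on the loss output.
print(model.weight.grad.shape)  # torch.Size([2, 4])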
A potentially reasonable intermediate step could be to force users to embed loss into their module directly:
import torch.nn as nn

class ModuleAndLoss(nn.Module):
    def __init__(self, module, loss):
        super().__init__()  # required so submodules/parameters are registered
        self.module = module
        self.loss = loss

    def forward(self, inputs, targets):
        return self.loss(self.module(inputs), targets)
This has the benefit of not having to propagate gradients across flatbuffer boundaries; however, in the long term the decoupled approach would provide the most flexibility.
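Usage could then look something like this (a sketch; pybuda.compile as in the snippets above, and the concrete module/loss choices are placeholders):

import torch.nn as nn
import pybuda  # illustrative; compile entry point as proposed earlier in the thread

model = nn.Linear(784, 10)
loss_fn = nn.CrossEntropyLoss()

# One compile call over the combined module yields a single graph, so no
# gradients need to cross flatbuffer boundaries between model and loss.
tt_module_and_loss = pybuda.compile(ModuleAndLoss(model, loss_fn))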
Regarding the details around .to("tt"): as in this discussion, I agree!

"Nikola, this looks great! Once we have a POC it'd be great to turn this into a documentation spec for training." - For sure!
"An additional challenge, if loss is compiled separately (regardless of whether it's scheduled on CPU or TT), is how we propagate gradients between graphs. At first glance it seems like it'd just be a matter of wiring up loss backward outputs -> model backward inputs, but there are likely some finicky details to work out. For example, when loss.backward() is called it implicitly runs tt_model backward, so I think the runtime will need to keep track of this and schedule it automatically."
I see your reasoning. The caveat with backward() is not output propagation, but rather runtime awareness of how it should be mapped. More precisely, when we call loss.backward(), the runtime needs to run tt_model bwd by fetching the loss bwd output and then running the rest of the bwd graph.

Regarding the optimizer, it's a bit simpler, as it doesn't need to run different modules. However, it has to be aware of where the accumulated gradients are stored, and where the parameters that need to be updated live.
Let's see during R&D what the tricky points are.
Summary
In order to train this model, the following key details are required:
- fwd, bwd, loss, and opt ops are supported e2e

Having the components mentioned above, pushing initial e2e training should be straightforward.
Details
With all the mentioned components, most of the details should be in place. For easier understanding, here is an overview of how the compile flow should work and look; see pybuda/test/mlir/mnist/test_training.py for a sample.

Model compile

It's important to note that each model is compiled separately, i.e. each pybuda.compile call executes a separate compile workflow that generates flatbuffers and returns them to the PyBuda runtime together with the metadata the runtime needs (e.g. where to find inputs). During model compile, each component should be compiled separately and be unaware of the others. That way, we can strive toward more streamlined support of other frameworks and keep compilation flexible. This ensures that users can keep the PyTorch training-loop style, so our customers can integrate our compiler as plug-and-play, where only a few lines of their existing workflows need to be updated.
Likewise, we should utilize something like runtime stitching in order to support independent module/component compilation and runtime workloads. Only the runtime for each frontend should be aware of inference/training specifics; MLIR should be agnostic to it (the same applies to the flatbuffer schemas).
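As a sketch of that independence (the dict below is purely a frontend-side illustration, not an actual runtime structure): each component goes through its own pybuda.compile run, and only the PyBuda runtime labels the results for training purposes.

import torch
import pybuda  # illustrative; each compile call is an independent pipeline run

framework_model = torch.nn.Linear(784, 10)
framework_loss = torch.nn.CrossEntropyLoss()
framework_optimizer = torch.optim.SGD(framework_model.parameters(), lr=0.1)

# Each call generates its own flatbuffer plus runtime metadata; MLIR and the
# flatbuffer schema stay agnostic to whether a graph is fwd, bwd, loss, or opt.
programs = {
    "model": pybuda.compile(framework_model),
    "loss": pybuda.compile(framework_loss),
    "optimizer": pybuda.compile(framework_optimizer),
}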
Details around model training
The main point of training is for the model to "learn" how to do the desired task. You can view the model as a simple "mapper", i.e. "mapping" one type of input to a desired output; for example, "mapping" an image of a cat into the prediction string "cat". In this section, I'll focus only on the training points that are related to the compiler.
In sum, once the model weights are initialized (done automatically when the model is created), we run the training loop and perform the steps shown in the code sample (pybuda/test/mlir/mnist/test_training.py).

For a more precise example, here are a few graphs showing how training worked on the older PyBuda/BBE stack.
(Diagrams from the older stack: forward graph with attached loss, backward graph, and optimizer graph.)
Training loop
To easily map the above-mentioned concepts, here is a proposed sample of how a training loop could look.
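Since the original sample isn't reproduced here, below is a rough sketch of the shape such a loop could take, assembled from the calls proposed earlier in this thread (pybuda.compile, .to("tt"), tt_optimizer.step()); loss is kept on the host/CPU as suggested for the initial PoC, and the model/data choices are purely illustrative.

import torch
import torch.nn as nn
import pybuda  # API usage below follows the proposals in this thread, not a final interface

framework_model = nn.Linear(784, 10)
loss_fn = nn.CrossEntropyLoss()                      # kept on host/CPU for the PoC
framework_optimizer = torch.optim.SGD(framework_model.parameters(), lr=0.1)

tt_model = pybuda.compile(framework_model)
tt_optimizer = pybuda.compile(framework_optimizer)
tt_model.to("tt")                                    # copy weights into device DRAM

for step in range(100):
    inputs = torch.randn(32, 784)                    # dummy MNIST-style batch
    targets = torch.randint(0, 10, (32,))

    outputs = tt_model(inputs)                       # executes the fwd graph on device
    loss = loss_fn(outputs, targets)                 # loss computed on host

    loss.backward()                                  # runtime schedules the device bwd graph,
                                                     # feeding it the loss bwd outputs
    tt_optimizer.step()                              # executes the optimizer graph and
                                                     # updates parameters in device DRAM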
Training components
In sum, this is still exploratory work, and we should strive to do as many R&D experiments as possible in order to choose the best approach for the long run. Therefore, some of these concepts might change over time.
Firstly, regarding the forward graph: in sum, it represents an inference graph. One caveat is that we should determine how to simplify the way loss is connected to it. In the previous stack, loss was attached to the forward graph (like an extension). The details that need to be checked here are whether we have to be aware of the loss part of the graph, or whether we can simply integrate it as part of the forward graph and compute backward on top of it.
As already mentioned, the loss part is currently examined as an addition to the forward graph. If we prove that we can follow a similar mechanism and get valid training, it can simplify graph compilation a bit. However, this can become tricky for more advanced loss functions, so we should explore this path in more detail. If it's not possible to merge the two, the runtime will have to be aware of a separate loss component in the long run. To bridge that complexity, we can also run loss on the host/CPU for the initial training PoC.
Next is the backward part. It's generated from the forward graph using the PyBuda autograd mechanism and is used to compute gradients for each model parameter, which are later used by the optimizer to properly update the parameters based on the loss output. Overall, MLIR should view the backward graph as a plain graph, without any notion of which type it is; the PyBuda runtime is the one that will orchestrate what is what and how to properly manage those components and data.
And lastly, there is the optimizer. Its main point is to update model parameters using the calculated gradients. The optimizer can also have additional inputs that define how parameters are updated (e.g. learning rate). Like each graph mentioned above, the optimizer should also be a standalone graph. The trick here is for the runtime to know how to map gradients to the corresponding parameters. To handle this properly, we should ensure that the model is compiled first; that way, we can gather the list of model parameters and match them to the calculated/accumulated gradients. Afterward, once we compile the optimizer and determine where its inputs/outputs are stored, we can just run the graph and update parameters according to how the training loop is defined by the user.
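As a plain-PyTorch reference for the mapping the runtime needs (no PyBuda here): the bwd graph produces one gradient per trainable parameter, and the optimizer consumes them keyed by parameter, shown below with a hand-rolled SGD-style update.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
loss = nn.MSELoss()(model(torch.randn(8, 4)), torch.randn(8, 2))
loss.backward()  # bwd graph output: one gradient tensor per trainable parameter

# Gather accumulated gradients keyed by parameter name, as the runtime would
# after running the bwd graph, then apply a plain SGD-style update.
lr = 0.1
grads = {name: p.grad for name, p in model.named_parameters()}
with torch.no_grad():
    for name, p in model.named_parameters():
        p -= lr * grads[name]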
Limitations
As expected, this initial overview has some limitations. The main one is that we're not exploring, as part of the initial run, how to scale out training. During this push, we should keep in mind that training will eventually also run across multiple chips; however, for the initial PoC it won't be required. That said, here is a list of the current limitations that we should think about:
Passing Criteria