tenstorrent / tt-forge-fe

The TT-Forge FE is a graph compiler designed to optimize and transform computational graphs for deep learning models, enhancing their performance and efficiency.
https://docs.tenstorrent.com/tt-forge-fe/
Apache License 2.0

Training MNIST based Linear model e2e on device #176

Open nvukobratTT opened 4 months ago

nvukobratTT commented 4 months ago

Summary

In order to train this model, the following key details are required:

Having the components mentioned above, pushing initial e2e training should be straightforward.

Details

With all the mentioned components, most of the details should be in place. For easier understanding, here is an overview of how the compile flow should work and look:

  1. We define our training model and training loop. Check pybuda/test/mlir/mnist/test_training.py for a sample.
  2. Based on the training loop, the runtime pushes inputs to the appropriate locations (e.g. host, DRAM) and executes model compilation
  3. Each model is compiled separately, but the runtime is aware of input/output locations (e.g. L1, DRAM). Runtime stitching is the concept that should drive this
  4. Depending on how the user defines the training loop, the runtime executes graphs and updates specific memory locations as defined by flatbuffer metadata
    • MLIR shouldn't be aware of a graph's purpose; graphs should be independent (more details in follow-up sections)
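To make steps 2-4 concrete, here is a minimal Python sketch of a runtime that places inputs and writes outputs according to per-graph metadata. Everything in it is hypothetical: `Location`, `GraphMetadata`, `CompiledGraph` and `Runtime` are illustrative names, not PyBuda or flatbuffer APIs, and a plain Python callable stands in for a compiled flatbuffer. What it tries to show is that the graph itself only sees named values, while where those values live is decided by the metadata the runtime reads.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Dict


class Location(Enum):
    """Hypothetical memory locations a tensor can live in."""
    HOST = "host"
    DRAM = "dram"
    L1 = "l1"


@dataclass
class GraphMetadata:
    """Where the runtime reads each named input from and writes each output to."""
    inputs: Dict[str, Location]
    outputs: Dict[str, Location]


@dataclass
class CompiledGraph:
    metadata: GraphMetadata
    # Stand-in for a compiled flatbuffer: maps named inputs to named outputs
    fn: Callable[[Dict[str, float]], Dict[str, float]]


class Runtime:
    def __init__(self) -> None:
        # One simulated memory space per location, keyed by tensor name
        self.memory: Dict[Location, Dict[str, float]] = {loc: {} for loc in Location}

    def push(self, name: str, value: float, location: Location) -> None:
        """Place a tensor (here just a scalar) at the given location."""
        self.memory[location][name] = value

    def execute(self, graph: CompiledGraph) -> None:
        # Gather inputs from wherever the metadata says they live
        args = {n: self.memory[loc][n] for n, loc in graph.metadata.inputs.items()}
        results = graph.fn(args)
        # Write each output to the location the metadata prescribes
        for n, loc in graph.metadata.outputs.items():
            self.memory[loc][n] = results[n]


rt = Runtime()
rt.push("x", 3.0, Location.HOST)  # input pushed by the training loop
rt.push("w", 2.0, Location.DRAM)  # parameter already resident on device
fwd = CompiledGraph(
    GraphMetadata(inputs={"x": Location.HOST, "w": Location.DRAM},
                  outputs={"y": Location.DRAM}),
    fn=lambda a: {"y": a["w"] * a["x"]},
)
rt.execute(fwd)
print(rt.memory[Location.DRAM]["y"])  # 6.0
```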

Model compile

It's important to note that each model is compiled separately: each pybuda.compile call executes a separate compile workflow that generates flatbuffers and returns them to the PyBuda runtime, together with the metadata the runtime needs (e.g. where to find inputs).

During model compilation, each component should be compiled separately, unaware of the others. That way, we can more easily support other frameworks and keep compilation flexible. It also lets users keep a PyTorch-style training loop, so our customers can integrate our compiler in a plug-and-play fashion, updating only a few lines of their existing workflows.

Likewise, we should use something like runtime stitching to support independent compilation and execution of modules/components. Only the runtime for each frontend should be aware of inference/training specifics; MLIR should be agnostic to them (the same applies to flatbuffer schemas).
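Runtime stitching can be sketched as follows, assuming each independently compiled graph exposes named inputs and outputs. The `Graph` tuple type and the `run_stitched` function are hypothetical; the sketch only illustrates the idea that the graphs never reference each other, and the runtime connects them by matching tensor names in a shared pool.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical stand-in for an independently compiled graph plus its I/O
# metadata: (input names, output names, callable returning a tuple of outputs)
Graph = Tuple[List[str], List[str], Callable[..., Tuple[float, ...]]]


def run_stitched(graphs: List[Graph], feeds: Dict[str, float]) -> Dict[str, float]:
    """Run graphs in order, wiring each graph's inputs to earlier outputs by name."""
    pool = dict(feeds)  # every tensor the runtime currently tracks
    for inputs, outputs, fn in graphs:
        missing = [n for n in inputs if n not in pool]
        if missing:
            raise KeyError(f"unresolved inputs: {missing}")
        results = fn(*(pool[n] for n in inputs))
        pool.update(zip(outputs, results))
    return pool


# Two graphs that know nothing about each other; only tensor names connect them
forward: Graph = (["x", "w"], ["pred"], lambda x, w: (x * w,))
loss: Graph = (["pred", "target"], ["loss"], lambda p, t: (abs(p - t),))

pool = run_stitched([forward, loss], {"x": 2.0, "w": 3.0, "target": 5.0})
print(pool["loss"])  # 1.0
```

In this sketch, the same `forward` graph can run alone (inference) or be followed by a loss graph (training) without either graph changing; only the list the runtime executes differs.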

Details around model training

The main point of training is for the model to "learn" how to perform the desired task. You can view the model as a simple "mapper", i.e. something that "maps" one type of input to the desired output; for example, it maps an image of a cat to the prediction string "cat". In this section, I'll focus only on the aspects of training that are relevant to the compiler.

In short, once the model weights are initialized (this happens automatically when the model is created), we run the training loop, which does the following (code sample: pybuda/test/mlir/mnist/test_training.py):

  1. Forward pass and get model predictions
  2. Calculate loss by referencing forward pass outputs (predictions) and targets (ground truth labels)
  3. Backward pass that uses loss output to calculate gradients for each model parameter
  4. Optimization pass that uses calculated (or accumulated) gradients to update original model parameters
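The four steps above can be sketched framework-free for a 1-D linear model `y = w * x + b` trained with an L1 loss. This is an illustration, not the PyBuda implementation: each step is a separate function to mirror the idea that forward (with loss), backward and optimizer are separate graphs, and the data and hyperparameters are made up for the example.

```python
from typing import Dict, List, Tuple


def forward(params: Dict[str, float], xs: List[float]) -> List[float]:
    # 1. Forward pass: predictions of a 1-D linear model y = w * x + b
    return [params["w"] * x + params["b"] for x in xs]


def l1_loss(preds: List[float], targets: List[float]) -> Tuple[float, List[float]]:
    # 2. Loss: mean absolute error, plus d(loss)/d(pred) needed by the backward pass
    n = len(preds)
    loss = sum(abs(p - t) for p, t in zip(preds, targets)) / n
    grad_pred = [((p > t) - (p < t)) / n for p, t in zip(preds, targets)]
    return loss, grad_pred


def backward(xs: List[float], grad_pred: List[float]) -> Dict[str, float]:
    # 3. Backward pass: chain rule through y = w * x + b
    return {
        "w": sum(g * x for g, x in zip(grad_pred, xs)),
        "b": sum(grad_pred),
    }


def sgd_step(params: Dict[str, float], grads: Dict[str, float], lr: float) -> None:
    # 4. Optimizer: plain SGD update of every parameter, in place
    for name in params:
        params[name] -= lr * grads[name]


params = {"w": 0.0, "b": 0.0}
xs, targets = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]  # samples of y = 2x
losses = []
for _ in range(50):
    preds = forward(params, xs)
    loss, grad_pred = l1_loss(preds, targets)
    sgd_step(params, backward(xs, grad_pred), lr=0.1)
    losses.append(loss)
```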

For a more concrete example, here are a few graphs showing how training worked on the older PyBuda/BBE stack.

[Figure: forward graph with attached loss]

[Figure: backward graph]

[Figure: optimizer graph]

Training loop

To map the above-mentioned concepts onto code, here is a proposed sample of what a training loop could look like.

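    import torch
    from torch import nn

    import pybuda

    # MNISTLinear, train_loader, learning_rate and num_epochs are assumed
    # to be defined elsewhere

    # Define model and instruct it to compile and run on TT device
    framework_model = MNISTLinear()
    tt_model = pybuda.compile(framework_model)
    tt_model.to("tt")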

    # Create a torch loss and leave on CPU
    tt_loss_fn = torch.nn.L1Loss()

    # Define optimizer and instruct it to compile and run on TT device
    framework_optimizer = torch.optim.SGD(framework_model.parameters(), lr=learning_rate)
    tt_optimizer = pybuda.compile(framework_optimizer)
    tt_optimizer.to("tt")

    for epoch_idx in range(num_epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            # Put inputs on device
            data = data.to("tt")

            # Create target tensor and leave on CPU
            target = nn.functional.one_hot(target, num_classes=10).float()

            # Reset gradients (every batch)
            tt_optimizer.zero_grad()

            # Forward pass (prediction) on device
            pred = tt_model(data)

            # Pull output back to CPU
            pred = pred.to("cpu")

            # Compute loss on CPU
            loss = tt_loss_fn(pred, target)

            # Run backward pass on device
            loss.backward()

            # Adjust weights (on device)
            tt_optimizer.step()

Training components

Overall, this is still exploratory work, and we should run as many R&D experiments as possible in order to choose the best approach for the long run. Therefore, some of these concepts might change over time.

First, the forward graph. In essence, it represents an inference graph. One caveat is that we need to determine how the loss is connected to it. In the previous stack, the loss was attached to the forward graph (like an extension). What needs to be checked here is whether we have to be aware of the loss part of the graph, or whether we can simply integrate it as part of the forward graph and compute the backward pass on top of it.

As already mentioned, the loss part is currently treated as an extension of the forward graph. If we can prove that following a similar mechanism yields valid training, it can simplify graph compilation a bit. However, this can become tricky for more advanced loss functions, so we should explore this path in more detail. If it's not possible to merge the two, the runtime will have to be aware of a loss component in the long run. To sidestep that complexity, we can also run the loss on the host/CPU for the initial training PoC.
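For a simple loss, attaching it to the forward graph and compiling it separately yield the same parameter gradients, because the chain rule composes them either way. The sketch below checks this on a hypothetical three-op IR (`mul`, `sub`, `abs`, mirroring `pred = w * x` followed by an absolute-error loss) with a tiny reverse-mode pass standing in for autograd. All names are illustrative; it says nothing about the harder cases (more advanced loss functions) or about device-side details.

```python
from typing import Dict, List, Tuple

# Hypothetical mini IR: (output name, op kind, input names)
Op = Tuple[str, str, Tuple[str, ...]]

FORWARD: List[Op] = [("pred", "mul", ("w", "x"))]
LOSS: List[Op] = [("diff", "sub", ("pred", "target")), ("loss", "abs", ("diff",))]


def run(ops: List[Op], env: Dict[str, float]) -> Dict[str, float]:
    """Execute ops in order, returning every intermediate value."""
    env = dict(env)
    for out, kind, ins in ops:
        a = [env[n] for n in ins]
        if kind == "mul":
            env[out] = a[0] * a[1]
        elif kind == "sub":
            env[out] = a[0] - a[1]
        elif kind == "abs":
            env[out] = abs(a[0])
        else:
            raise ValueError(f"unknown op: {kind}")
    return env


def backward(ops: List[Op], env: Dict[str, float], seed: Dict[str, float]) -> Dict[str, float]:
    """Reverse-mode pass over ops, starting from the given seed gradients."""
    g = dict(seed)
    for out, kind, ins in reversed(ops):
        up = g.get(out, 0.0)
        a = [env[n] for n in ins]
        if kind == "mul":
            local = [a[1], a[0]]
        elif kind == "sub":
            local = [1.0, -1.0]
        elif kind == "abs":
            local = [float((a[0] > 0) - (a[0] < 0))]
        else:
            raise ValueError(f"unknown op: {kind}")
        for name, d in zip(ins, local):
            g[name] = g.get(name, 0.0) + up * d
    return g


env = run(FORWARD + LOSS, {"w": 3.0, "x": 2.0, "target": 5.0})

# Option A: loss attached to the forward graph, one backward over both
fused = backward(FORWARD + LOSS, env, {"loss": 1.0})

# Option B: loss compiled separately; its backward output seeds the model backward
loss_grads = backward(LOSS, env, {"loss": 1.0})
split = backward(FORWARD, env, {"pred": loss_grads["pred"]})

print(fused["w"], split["w"])  # 2.0 2.0
```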

Next is the backward part. It is generated from the forward graph using the PyBuda autograd mechanism, and it computes gradients for each model parameter, which the optimizer later uses to update those parameters. Overall, MLIR should view the backward graph as a plain graph, without any notion of what type of graph it is. The PyBuda runtime is the one that orchestrates which graph is which and how to properly manage those components and their data.

And lastly, there is the optimizer. Its main purpose is to update model parameters using the calculated gradients. The optimizer can also have additional inputs that define how parameters are updated (e.g. learning rate). Like each graph mentioned above, the optimizer should be a standalone graph. The trick here is for the runtime to know how to map gradients to the correct parameters. To handle this properly, we should ensure that the model is compiled first. That way, we can gather the list of model parameters and match them to the calculated/accumulated gradients. Afterward, once we compile the optimizer and determine where its inputs/outputs are stored, we can simply run the graph and update parameters according to the training loop defined by the user.
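A minimal sketch of the parameter-to-gradient mapping, assuming gradients are keyed by the same names as the parameters gathered after model compilation. `sgd_graph` and `optimizer_step` are hypothetical names; a scalar SGD update stands in for the compiled optimizer graph, and the learning rate is passed as an extra input, as described above.

```python
from typing import Callable, Dict


def sgd_graph(param: float, grad: float, lr: float) -> float:
    # Hypothetical standalone optimizer "graph": one SGD update of one parameter
    return param - lr * grad


def optimizer_step(
    params: Dict[str, float],
    grads: Dict[str, float],
    graph: Callable[[float, float, float], float],
    lr: float,
) -> None:
    """Runtime side: match each parameter to its gradient, run the graph, write back."""
    missing = params.keys() - grads.keys()
    if missing:
        raise KeyError(f"no gradient for parameters: {sorted(missing)}")
    for name in params:
        params[name] = graph(params[name], grads[name], lr)


# Parameter list gathered after compiling the model; gradients from the backward graph
params = {"fc.weight": 0.5, "fc.bias": 0.1}
grads = {"fc.weight": 2.0, "fc.bias": -1.0}
optimizer_step(params, grads, sgd_graph, lr=0.1)
```

Failing loudly on a missing gradient (rather than skipping the parameter) is a design choice in this sketch: a silent mismatch between the model's parameter list and the backward outputs would otherwise look like a parameter that never trains.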

Limitations

As expected, this initial overview has some limitations. The main one is that this initial run does not explore how to scale out training. During this push, we should keep in mind that training will eventually run across multiple chips; however, this won't be required for the initial PoC. With that in mind, here is a list of the current limitations that we should think about:

Passing Criteria

nsmithtt commented 4 months ago

I think this section should be edited:

    tt_model = pybuda.compile(framework_model.to("tt"))
    tt_optimizer = pybuda.compile(framework_optimizer.to("tt"))
    tt_loss = pybuda.compile(framework_loss.to("tt"))

To just:

    tt_model = pybuda.compile(framework_model)
    tt_optimizer = pybuda.compile(framework_optimizer)
    tt_loss = pybuda.compile(framework_loss)

Then, when the respective module is ready to run on the device, you do:

    tt_model.to("tt")

Calling .to("tt") on a module prepares/copies its weights to device DRAM. It would be legal to run:

    tt_model.to("tt")
    tt_loss.to("tt")

But this means that all of the parameters for tt_model and tt_loss will be copied to device DRAM and live there simultaneously.
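These `.to("tt")` semantics can be modeled with a hypothetical `CompiledModule` whose `to()` copies parameters into a simulated device DRAM and, like `torch.nn.Module.to`, returns the module itself. `TT_DRAM` and `CompiledModule` are illustrative names only, and freeing the DRAM copy when a module moves back to `"cpu"` is an assumption of this sketch, not behavior stated above.

```python
from typing import Dict

# Hypothetical stand-in for device DRAM: tensors currently resident on the device
TT_DRAM: Dict[str, float] = {}


class CompiledModule:
    """Hypothetical compiled module mimicking torch.nn.Module.to semantics."""

    def __init__(self, name: str, params: Dict[str, float]) -> None:
        self.name = name
        self.params = params
        self.device = "cpu"

    def to(self, device: str) -> "CompiledModule":
        for pname, value in self.params.items():
            key = f"{self.name}.{pname}"
            if device == "tt":
                TT_DRAM[key] = value  # copy the weight into device DRAM
            else:
                TT_DRAM.pop(key, None)  # assumption: leaving the device frees the copy
        self.device = device
        return self  # like nn.Module.to, return the module itself


tt_model = CompiledModule("model", {"w": 1.0}).to("tt")
tt_loss = CompiledModule("loss", {"scale": 0.5}).to("tt")
resident = sorted(TT_DRAM)
print(resident)  # ['loss.scale', 'model.w']
tt_loss.to("cpu")  # now only the model's weights stay in DRAM
```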

We should try and stick to these semantics as closely as we can: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.to

nsmithtt commented 4 months ago

Nikola, this looks great! Once we have a POC it'd be great to turn this into a documentation spec for training.

nsmithtt commented 4 months ago

I think an R&D effort is needed to see how we can hook up:

    loss.backward()      # executes bwd graph
    tt_optimizer.step()  # executes optimizer graph

An additional challenge, if the loss is compiled separately (regardless of whether it's scheduled on CPU or TT), is how we can propagate gradients between graphs. At first glance it seems like it'd just be a matter of wiring up loss backward outputs -> model backward inputs, but there are likely some finicky details to work out. For example, when loss.backward() is called it implicitly runs the tt_model backward, so I think the runtime will need to keep track of this and schedule it automatically.

Optimizer step is essentially the same issue, but it feels slightly more explicit, since the outputs are given up front and it's somewhat decoupled from the loss and model graphs.

A potentially reasonable intermediate step could be to force users to embed loss into their module directly:

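    from torch import nn


    class ModuleAndLoss(nn.Module):
        def __init__(self, module, loss):
            super().__init__()  # required before assigning submodules
            self.module = module
            self.loss = loss

        def forward(self, inputs, targets):
            # Compute the loss directly on the wrapped module's outputs
            return self.loss(self.module(inputs), targets)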

This has the benefit of not having to propagate gradients across flatbuffer boundaries; however, in the long term, supporting that propagation would provide the most flexibility.

nvukobratTT commented 4 months ago

> details regarding .to("tt")

Similar to this discussion, I agree!

> Nikola, this looks great! Once we have a POC it'd be great to turn this into a documentation spec for training.

For sure!

> An additional challenge, if the loss is compiled separately (regardless of whether it's scheduled on CPU or TT), is how we can propagate gradients between graphs. At first glance it seems like it'd just be a matter of wiring up loss backward outputs -> model backward inputs, but there are likely some finicky details to work out. For example, when loss.backward() is called it implicitly runs the tt_model backward, so I think the runtime will need to keep track of this and schedule it automatically.

I see your reasoning. The tricky part of backward() is not propagating outputs, but rather making the runtime aware of how they should be mapped. More precisely, when we call loss.backward(), the runtime needs to:

  1. Invoke the loss bwd and store its gradients in the desired location
  2. Invoke the tt_model bwd, fetching the loss bwd output and running the rest of the bwd graph
  3. Store the full bwd output in the desired location
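These three steps could be orchestrated by a runtime that records which forward graphs ran and, on `loss.backward()`, replays their backward graphs in reverse order. This is a hypothetical sketch: `TrainingRuntime`, `record` and `grad_store` are illustrative names, the graphs are scalar Python callables, and `grad_store` stands in for "the desired location". Accumulating into `grad_store` rather than overwriting is a design choice that matches the accumulated gradients the optimizer is expected to consume.

```python
from typing import Callable, Dict, List, Tuple

# Backward callable: upstream gradient -> (gradient w.r.t. the graph's input, parameter grads)
BwdFn = Callable[[float], Tuple[float, Dict[str, float]]]


class TrainingRuntime:
    """Hypothetical runtime that schedules backward graphs on loss.backward()."""

    def __init__(self) -> None:
        self.tape: List[Tuple[str, BwdFn]] = []  # forward graphs, in execution order
        self.grad_store: Dict[str, float] = {}   # stand-in for "the desired location"

    def record(self, name: str, bwd: BwdFn) -> None:
        self.tape.append((name, bwd))

    def backward(self, seed: float = 1.0) -> None:
        upstream = seed
        # Run bwd graphs in reverse order; each consumes the previous bwd output
        for name, bwd in reversed(self.tape):
            upstream, param_grads = bwd(upstream)
            for pname, g in param_grads.items():
                key = f"{name}.{pname}"
                self.grad_store[key] = self.grad_store.get(key, 0.0) + g
        self.tape.clear()


x, w, target = 2.0, 3.0, 5.0
rt = TrainingRuntime()
pred = w * x  # "model" forward graph
rt.record("model", lambda up: (up * w, {"w": up * x}))
sign = float((pred > target) - (pred < target))  # "loss" forward: abs(pred - target)
rt.record("loss", lambda up: (up * sign, {}))
rt.backward()  # loss bwd first, then model bwd fed with its output
print(rt.grad_store)  # {'model.w': 2.0}
```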

The optimizer is a bit simpler, as it doesn't need to run different modules. However, it has to be aware of where the accumulated gradients are stored and where the parameters that need to be updated live.

Let's see during R&D what the tricky points are.