pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Export training model to StableHlo #8366

Zantares opened this issue 2 weeks ago

Zantares commented 2 weeks ago

❓ Questions and Help

The export API only supports torch.nn.Module as input. Is there any way to export a training model, including its step_fn, to StableHLO?

Here is a simple training case from the examples:

  def __init__(self):
    ...
    self.device = torch_xla.device()
    self.model = torchvision.models.resnet50().to(self.device)
    self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
    self.loss_fn = nn.CrossEntropyLoss()
    ...

  def run_optimizer(self):
    self.optimizer.step()

  def step_fn(self, data, target):
    self.optimizer.zero_grad()
    output = self.model(data)
    loss = self.loss_fn(output, target)
    loss.backward()
    self.run_optimizer()
    return loss

The guide at https://pytorch.org/xla/master/features/stablehlo.html#torch-export-to-stablehlo only covers how to export the original self.model; it doesn't explain how to export the model together with the optimizer and loss function.
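
For reference, this is roughly the inference-only flow that guide describes (a minimal sketch, assuming the torch.export + torch_xla.stablehlo API shown there); it only captures self.model.forward, with no optimizer or loss:

  import torch
  import torchvision
  from torch.export import export
  from torch_xla.stablehlo import exported_program_to_stablehlo

  # Export the bare model (inference graph only), as in the guide.
  model = torchvision.models.resnet50().eval()
  sample_input = (torch.randn(4, 3, 224, 224),)
  exported = export(model, sample_input)

  # Lower the ExportedProgram to StableHLO and print the text form.
  stablehlo_program = exported_program_to_stablehlo(exported)
  print(stablehlo_program.get_stablehlo_text('forward'))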

JackCaoG commented 2 weeks ago

@qihqi not sure if exporting for training is something we support today.

Zantares commented 2 weeks ago

To add more background:

Compared with Torch-XLA, I found that JAX has a convenient API that takes a jitted function as input. Here is an example from the JAX repo:

...
def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
  return -jnp.mean(jnp.sum(preds * targets, axis=1))
...

if __name__ == "__main__":
  @jit
  def update(params, batch):
    grads = grad(loss)(params, batch)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]

  ...
  params = update(params, next(batches))
  ...

Then it can be easily exported as below:

  # Export the function to StableHLO
  sh_exported = export.export(update)(params, batch)
  sh_text = get_stablehlo_asm(sh_exported.mlir_module())
  print(sh_text)

I can execute the generated StableHLO and get the expected results, so I'm wondering whether Torch-XLA can export a training model in a similar way.
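
One workaround I can think of (not verified) is to skip torch.export and rely on Torch-XLA's lazy tracing instead: run a single training step without syncing, then dump the pending graph. This is a sketch under the assumption that xm.get_stablehlo is available in the installed version and that the pending graph includes the backward pass and optimizer update:

  import torch
  import torch.nn as nn
  import torch.optim as optim
  import torchvision
  import torch_xla
  import torch_xla.core.xla_model as xm

  device = torch_xla.device()
  model = torchvision.models.resnet50().to(device)
  optimizer = optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)
  loss_fn = nn.CrossEntropyLoss()

  # Dummy batch, just to trace one step.
  data = torch.randn(4, 3, 224, 224, device=device)
  target = torch.randint(0, 1000, (4,), device=device)

  # Run one training step with lazy tensors; do not mark the step.
  optimizer.zero_grad()
  loss = loss_fn(model(data), target)
  loss.backward()
  optimizer.step()

  # Dump the still-pending computation (forward + loss + backward + update)
  # as StableHLO text.
  print(xm.get_stablehlo([loss]))

This wouldn't give an ExportedProgram like the documented flow, just the StableHLO of one traced step, so I'm not sure it is what the export API intends.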