Zantares opened this issue 2 weeks ago
@qihqi not sure if exporting for training is something we support today.
Adding more background:
Compared with Torch-XLA, I found that JAX has a convenient API that takes a jitted function as input. Here is an example from the JAX repo:
```python
...
def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
  return -jnp.mean(jnp.sum(preds * targets, axis=1))
...
if __name__ == "__main__":
  @jit
  def update(params, batch):
    grads = grad(loss)(params, batch)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]
  ...
  params = update(params, next(batches))
...
```
Then it can be easily exported as below:
```python
# Export the jitted function to StableHLO
from jax import export

sh_exported = export.export(update)(params, batch)
# get_stablehlo_asm: helper used here to pretty-print the module text
sh_text = get_stablehlo_asm(sh_exported.mlir_module())
print(sh_text)
```
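One way to execute the exported artifact is a round trip through `jax.export`'s serialize/deserialize/call API; the snippet below is a sketch of that workflow, not necessarily the exact steps used here:

```python
# Sketch: serialize the exported artifact, reload it, and run it.
serialized = sh_exported.serialize()          # bytes holding the exported module
rehydrated = export.deserialize(serialized)   # back to a jax.export.Exported object
result = rehydrated.call(params, batch)       # executes the exported computation
```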
I can execute the generated StableHLO and get the expected results, so I'm wondering whether Torch-XLA can export a training model like this.
❓ Questions and Help
The export API only supports `torch.nn.Module` as input. Is there any method to export a training model with a step_fn to StableHLO? Here is a simple training case from the example:
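(The original snippet is not reproduced above; the following is a minimal placeholder sketch of such a training step, with illustrative model, loss, and optimizer choices.)

```python
# Placeholder sketch of a typical Torch-XLA training step (illustrative names only).
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

def step_fn(inputs, targets):
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    xm.optimizer_step(optimizer)  # Torch-XLA wrapper that reduces gradients and calls optimizer.step()
    return loss
```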
The guidance at https://pytorch.org/xla/master/features/stablehlo.html#torch-export-to-stablehlo only introduces how to export the original `self.model`; it doesn't explain how to export the model together with the optimizer and loss functions.
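To frame the question, one partial direction I can imagine (my assumption, not something the guide describes) is to wrap the model and the loss in a single `nn.Module`, so that the documented `torch.export` + `exported_program_to_stablehlo` path sees both. This still leaves out the backward pass and the optimizer update, which is exactly the part I'm asking about. The `TrainStep` wrapper below is made up for illustration:

```python
# Illustrative only: wrap model + loss so the documented export path captures both.
# This does NOT export the backward pass or the optimizer step.
import torch
import torch.nn as nn
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo

class TrainStep(nn.Module):  # hypothetical wrapper name
    def __init__(self, model, loss_fn):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn

    def forward(self, inputs, targets):
        return self.loss_fn(self.model(inputs), targets)

model = nn.Linear(10, 2)
wrapper = TrainStep(model, nn.CrossEntropyLoss())
sample_args = (torch.randn(4, 10), torch.randint(0, 2, (4,)))

exported = export(wrapper, sample_args)
stablehlo_program = exported_program_to_stablehlo(exported)
print(stablehlo_program.get_stablehlo_text('forward'))
```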