Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Export to ONNX #18568

Closed MarcoPrassel closed 11 months ago

MarcoPrassel commented 11 months ago

Bug description

Hi, I'm trying to convert my Detr model to ONNX. I followed the guide, but I get this error:

f"torch>=2.0 requires onnx to be installed to use {type(self).__name__}.to_onnx()"

I am working on Google Colab.

Can you help me?

What version are you seeing the problem on?

master

How to reproduce the bug

import torch

checkpoint_path = '/content/lightning_logs/version_0/checkpoints/max.ckpt'
checkpoint = torch.load(checkpoint_path)

# Create a new model instance and load the checkpoint weights
model = Detr(lr=1e-4, lr_backbone=1e-5, weight_decay=1e-4)
model.load_state_dict(checkpoint['state_dict'])
model.to_onnx("test.onnx",  export_params=True)

Error messages and logs


ModuleNotFoundError                       Traceback (most recent call last)
in <cell line: 1>()
----> 1 model.to_onnx("test.onnx", export_params=True)

1 frames
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/core/module.py in to_onnx(self, file_path, input_sample, **kwargs)
   1358         """
   1359         if _TORCH_GREATER_EQUAL_2_0 and not _ONNX_AVAILABLE:
-> 1360             raise ModuleNotFoundError(
   1361                 f"torch>=2.0 requires onnx to be installed to use {type(self).__name__}.to_onnx()"
   1362             )

ModuleNotFoundError: torch>=2.0 requires onnx to be installed to use Detr.to_onnx()

Environment

Current environment

```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
```

More info

No response

awaelchli commented 11 months ago

@MarcoPrassel You need to install the onnx package as the error tells you:

pip install onnx

Or, in Google Colab, run this in a cell:

! pip install onnx

MarcoPrassel commented 11 months ago

Thank you @awaelchli, I had installed onnx but the error was still present; now everything seems to be OK, although I get this error: Could not export to ONNX since neither `input_sample` nor `model.example_input_array` attribute is set. I am not an expert, this is the first time I use PyTorch Lightning. My goal is to export the model trained with Detr to the ONNX format. Can you explain how I can do this?

awaelchli commented 11 months ago

@MarcoPrassel Just pass in the input sample argument as suggested by the error message. Here are the docs with examples: https://lightning.ai/docs/pytorch/stable/deploy/production_advanced.html#compile-your-model-to-onnx You can find these by just typing "onnx" into the search bar.
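
Roughly something like this, reusing your snippet from above (a minimal sketch; the 800x800 input shape is just a placeholder and has to match whatever your Detr forward actually expects):

import torch

model = Detr(lr=1e-4, lr_backbone=1e-5, weight_decay=1e-4)
model.load_state_dict(torch.load(checkpoint_path)["state_dict"])
model.eval()

# Option 1: pass a dummy tensor with the same shape as a real input batch
input_sample = torch.randn(1, 3, 800, 800)
model.to_onnx("test.onnx", input_sample, export_params=True)

# Option 2: set example_input_array on the module and drop the input_sample argument
# model.example_input_array = torch.randn(1, 3, 800, 800)
# model.to_onnx("test.onnx", export_params=True)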

awaelchli commented 11 months ago

Let me know if you have any more issues :)

Lookforworld commented 7 months ago

@awaelchli
I get an error when using the to_onnx() method.

The model's forward looks like this:

def forward(self, windows_batch):
    insample_y = windows_batch["insample_y"]
    insample_mask = windows_batch["insample_mask"]
    futr_exog = windows_batch["futr_exog"]
    hist_exog = windows_batch["hist_exog"]
    stat_exog = windows_batch["stat_exog"]
    ...

I am using the to_onnx() method like this:

trained_md.eval()
# Take one sample from the dataset and keep only the last window of each non-None field
data_ = trained_md.dataset
in_ = iter(data_)
in_ = next(in_)
input_names = []
input_dict = {}
for itm in in_:
    if in_[itm] is not None:
        input_dict[itm] = in_[itm][-1:]
        input_names.append(itm)
in_ = input_dict
trained_md.to_onnx(file_path='Nbx.onnx',
                   input_sample=in_,
                   input_names=input_names)

Then it gives me an error:

TypeError: NBEATSx.forward() missing 1 required positional argument: 'windows_batch'

How can I fix this issue?
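
If I understand the torch.onnx.export docs correctly, a dict passed as the input sample is interpreted as keyword arguments rather than as a single positional argument, so windows_batch never gets a value. Would wrapping the dict in a tuple with an empty dict at the end (so it stays a positional argument) be the right fix? A sketch, using input_dict from above:

# Keep the dict as the single positional `windows_batch` argument;
# the trailing empty dict tells torch.onnx.export there are no keyword arguments
trained_md.to_onnx(file_path='Nbx.onnx',
                   input_sample=(input_dict, {}),
                   input_names=input_names)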