Additional investigation: I can load my checkpoint created just before this error and run `to_onnx` or `summarize` and it works fine. This only happens during training.
It turns out that to fix it I had to modify the `forward` of my module to accept either a dictionary or a Tensor:
```python
def forward(self, x) -> torch.Tensor:
    if isinstance(x, torch.Tensor):
        x = self.backbone(x)  # <---- initially only had this line here
    elif isinstance(x, dict):
        x = self.backbone(x[DefaultDataKeys.INPUT])
    return x
```
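With that change both call styles go through, e.g. (the `model` instance and input shape here are only illustrative, not my actual setup):

```python
out = model(torch.rand(1, 3, 224, 224))                           # raw Tensor, as to_onnx / summarize pass it
out = model({DefaultDataKeys.INPUT: torch.rand(1, 3, 224, 224)})  # dict batch, as the training steps pass it
```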
Not exactly sure whether to classify this as a bug or just inconsistent behavior, but I think a fix is in order. In my mind, there's no reason for `to_onnx` to fail on data when I pass it in using the Lightning mechanism of `example_input_array`.
Hi @dlangerm, thanks for your work on this! Glad you got something working :smiley: Internally in Flash we represent everything as a dict. Usually in our tasks we unpack the dict in the `*_step` method and then implement the `forward` to just expect a tensor. So your `common_step` or similar method would have:

```python
input, target = x[DefaultDataKeys.INPUT], x[DefaultDataKeys.TARGET]
pred = self(input)
```
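Spelled out as a rough sketch, the pattern looks something like this (the class and attribute names here are only illustrative, not our actual Task code):

```python
class MyTask(flash.Task):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # forward only ever sees a plain tensor, which keeps ONNX / TorchScript export happy
        return self.backbone(x)

    def training_step(self, batch, batch_idx):
        # the dict batch is unpacked here, before forward is called
        input, target = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]
        pred = self(input)
        return self.loss_fn(pred, target)
```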
Or something like that. That way things like ONNX or TorchScript, which assume your `forward` works with just an input tensor, will still work without the need for your modification. Really interested to hear your thoughts or suggestions about how we can make this cleaner / smoother to use. Thanks :smiley:
@ethanwharris So the thing is I have all my `*_step` methods for that, but it seems like `to_onnx` and `summarize` don't call them; when the `example_input_array` is passed through the preprocessor, it gets wrapped in the dict, which causes this error.
Ahhhh, that's interesting, so the ONNX export includes the preprocess but skips the step method? That's definitely something we should address. Let me know if you have any ideas to solve it; I'll try to look into it and see what options there are.
It seems like that is what's happening. Since my `to_onnx` call is in a callback itself, that may be where it's getting tripped up? It still doesn't make sense why `summarize` is failing, though...
@ethanwharris I was able to confirm that every other stage of training is passing a dict to `_apply_batch_transfer_handler`, except for the `to_onnx` and `summarize` calls, which are passing raw Tensors.
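One quick way to see this for yourself is a throwaway override on the task that just prints the batch type before delegating (debugging only, nothing to do with the fix):

```python
def _apply_batch_transfer_handler(self, batch, device=None, dataloader_idx=None):
    print(type(batch))  # dict during fit / validate, raw Tensor from to_onnx and summarize
    return super()._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)
```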
@ethanwharris To fix the `to_onnx` call I had to change this line to this:

```python
input_sample = self._apply_batch_transfer_handler({'input': input_sample})['input']
```
I think it could be fixed by overriding `to_onnx` in `Task`, since the behavior exists specifically because everything in a Task is expected to be a dictionary. Lightning doesn't seem to have this behavior, so I think inserting this code directly into Lightning would have unintended consequences.
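Roughly what I have in mind, as a sketch only (I'm assuming Lightning's `to_onnx(file_path, input_sample, **kwargs)` call signature and the `DefaultDataKeys.INPUT` key here):

```python
class Task(pl.LightningModule):
    ...

    def to_onnx(self, file_path, input_sample=None, **kwargs):
        # Wrap the raw Tensor so it can go through the preprocess pipeline,
        # then unwrap it again before handing it back to Lightning's export.
        if isinstance(input_sample, torch.Tensor):
            input_sample = self._apply_batch_transfer_handler(
                {DefaultDataKeys.INPUT: input_sample}
            )[DefaultDataKeys.INPUT]
        return super().to_onnx(file_path, input_sample, **kwargs)
```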
The hunt for this bug has led to other, unrelated issues; I'm going to go ahead and open a new issue about documentation for the `Preprocess` module.
BTW, this is still an issue. For anyone else who runs into this, a hotfix is to override `_apply_batch_transfer_handler` in your `Task`:
```python
def _apply_batch_transfer_handler(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: Optional[int] = None) -> Any:
    if isinstance(batch, torch.Tensor):
        # to_onnx / summarize hand over a raw Tensor: wrap it in the dict the preprocess expects, then unwrap the result
        return super()._apply_batch_transfer_handler(batch={DefaultDataKeys.INPUT: batch}, device=device, dataloader_idx=dataloader_idx)[DefaultDataKeys.INPUT]
    else:
        # normal dict batches go through unchanged
        return super()._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)
```
But I'm not comfortable making this a pull request, as `_`-prefixed functions are generally not meant to be overridden per convention, and I'm not exactly sure what other side effects this would have.
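For what it's worth, with that override dropped into the task subclass, the calls that were failing go through again (the class name, checkpoint path, and input shape below are just placeholders):

```python
model = MyTask.load_from_checkpoint("my_checkpoint.ckpt")
model.summarize()
model.to_onnx("model.onnx", input_sample=torch.rand(1, 3, 224, 224))
```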
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
🐛 Bug
It seems that there is some sort of state issue with custom preprocessors and Flash modules. The basic code of my custom preprocessor is below; it's based on the examples given in the documentation and is very simple. I can't get the error to appear consistently, but it has happened at least once in my past few training iterations.
`to_onnx` seems to trigger it as well, but not always, which makes me think it has something to do with the preprocessor state (`current_transform`, perhaps). I am happy to help debug this issue, but it's really annoying and I do need a custom preprocessor for my data.
To Reproduce
Steps to reproduce the behavior:
Stack trace:
Code sample
Expected behavior
Expect `to_onnx` or `model.summarize` to use the correct datatypes for inference.
Environment
How you installed it (`conda`, `pip`, source): conda
Call `to_onnx` on a model with a custom preprocessor. It could be something else, but my gut says it's probably the preprocess code.
Additional context
FWIW, the Task that I'm using is very basic but initializes the `example_input_array` in the constructor like this:
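(As a rough stand-in for that snippet, a constructor that sets this typically looks like the following; the backbone and input shape are purely hypothetical, not my actual code.)

```python
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.backbone = ...  # whatever model the task wraps
    # Lightning picks this up for summarize() and as the default input_sample for to_onnx()
    self.example_input_array = torch.rand(1, 3, 224, 224)
```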