hugofloresgarcia closed this issue 1 year ago.
hi!
If I export a model that outputs a Tensor with a `long` dtype and open it in Max with `nn~`, Max will crash.
```python
import torch
import nn_tilde

class IntegerModel(nn_tilde.Module):
    @torch.jit.export
    def return_an_int(self, x):
        # Returns an int64 (long) tensor instead of float32
        return x.long()
```
The issue occurs here: https://github.com/acids-ircam/nn_tilde/blob/e144b5d7b7b769edb19dd3a271eaea59e0b1867b/src/backend/backend.cpp#L77, when the Tensor data is cast back into a float pointer.
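To illustrate why the cast is a problem: reading an int64 buffer through a float pointer does not convert the values, it reinterprets their bits. The backend is C++, so the snippet below is only a rough pure-Python analogy (a sketch, assuming little-endian byte order), using the standard `struct` module to pack int64 values and read the same bytes back as float32. The function name `reinterpret_int64_as_float32` is hypothetical and not part of `nn~`.

```python
import struct

def reinterpret_int64_as_float32(values):
    """Hypothetical helper: pack values as little-endian int64, then read
    the same bytes back as float32 (no conversion, just a bit reinterpretation),
    loosely mimicking a cast of int64 data to a float pointer."""
    raw = struct.pack(f"<{len(values)}q", *values)  # 8 bytes per int64
    n_floats = len(raw) // 4                        # 4 bytes per float32
    return list(struct.unpack(f"<{n_floats}f", raw))

ints = [1, 2, 3]
floats = reinterpret_int64_as_float32(ints)
print(len(floats))  # each int64 spans two float32 slots, so twice as many values
print(floats[:3])   # tiny denormal numbers and zeros, not [1.0, 2.0, 3.0]
```

Each int64 occupies two float32 slots, and the resulting values bear no relation to the original integers.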
I think an easy fix would be to add a check here to ensure that all `nn~` models have float outputs: https://github.com/acids-ircam/nn_tilde/blob/e144b5d7b7b769edb19dd3a271eaea59e0b1867b/python_tools/__init__.py#L75
```python
# Reject models whose output tensor is not float32 (torch.float)
if y.dtype != torch.float:
    raise ValueError(f"Output tensor must be of type float32, got {y.dtype}")
```
Let me know if you would be interested in me submitting a PR! Have been wanting to contribute to the RAVE/nn~ universe of projects :)
Hi! Good catch! Please submit a PR and I'll merge it :)