acids-ircam / nn_tilde

`nn~` crashes if model output dtype is not `float` #52

Closed: hugofloresgarcia closed this issue 1 year ago

hugofloresgarcia commented 1 year ago

hi!

If I export a model whose method returns a tensor of dtype `long` (int64) and load it in Max with nn~, Max crashes.

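```python
import torch
import nn_tilde


class IntegerModel(nn_tilde.Module):

    @torch.jit.export
    def return_an_int(self, x: torch.Tensor) -> torch.Tensor:
        return x.long()  # int64 output, which nn~ later reads as float
```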
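The wrong dtype can be confirmed without opening Max. This minimal sketch substitutes a plain `torch.nn.Module` for `nn_tilde.Module` (an assumption, so it runs without nn~ installed), adds a trivial `forward` so TorchScript has an entry point, and calls the exported method on a hypothetical dummy buffer.

```python
import torch


class IntegerModel(torch.nn.Module):
    # Stand-in for nn_tilde.Module (assumption: same exported method)

    @torch.jit.export
    def return_an_int(self, x: torch.Tensor) -> torch.Tensor:
        return x.long()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


scripted = torch.jit.script(IntegerModel())
dummy = torch.zeros(1, 1, 8)  # hypothetical (batch, channels, samples) buffer
out = scripted.return_an_int(dummy)
print(out.dtype)  # torch.int64
```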

The crash happens here: https://github.com/acids-ircam/nn_tilde/blob/e144b5d7b7b769edb19dd3a271eaea59e0b1867b/src/backend/backend.cpp#L77, where the output tensor's data is cast back to a float pointer.
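For illustration only, here is a Python sketch (not the C++ backend) of why an int64 buffer cannot simply be read as float32. The issue does not say whether the linked line reinterprets the raw bytes or fails on the type mismatch; the sketch only shows that the two layouts are incompatible. `Tensor.view(dtype)` with a different element size assumes a reasonably recent PyTorch.

```python
import torch

ints = torch.tensor([1, 2, 3])  # int64 by default: 8 bytes per element
print(ints.element_size())      # 8
print(torch.zeros(1).element_size())  # float32: 4

# Reinterpret the same bytes as float32 without converting them
as_float = ints.view(torch.float32)
print(as_float.numel())  # 6: each int64 spans two float32 slots
print(as_float)          # tiny denormals and zeros, not 1., 2., 3.

# An actual conversion produces the float values nn~ expects
print(ints.float())  # tensor([1., 2., 3.])
```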

I think an easy fix would be to add a check here, in `python_tools`, to ensure that all nn~ models have float outputs:
https://github.com/acids-ircam/nn_tilde/blob/e144b5d7b7b769edb19dd3a271eaea59e0b1867b/python_tools/__init__.py#L75

```python
# Reject any method output that nn~ would misread as float32
if y.dtype != torch.float:
    raise ValueError(f"Output tensor must be of type float32, got {y.dtype}")
```
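As a self-contained version of the proposed check, the sketch below wraps it in a hypothetical helper, `check_float_output`, which is not part of nn~. The assumption is that the export code can run each method once on a dummy buffer and pass the result (and the method name, for a clearer message) to the helper.

```python
import torch


def check_float_output(y: torch.Tensor, method_name: str) -> None:
    """Hypothetical export-time guard: nn~ expects float32 outputs."""
    if y.dtype != torch.float:
        raise ValueError(
            f"Output of method '{method_name}' must be of type float32, got {y.dtype}"
        )


# Usage: validate each method's output on a dummy buffer before saving
check_float_output(torch.zeros(1, 1, 16), "forward")  # passes silently
try:
    check_float_output(torch.zeros(1, 1, 16).long(), "return_an_int")
except ValueError as err:
    print(err)  # Output of method 'return_an_int' must be of type float32, got torch.int64
```

Raising at export time, rather than silently calling `.float()` on the output, surfaces the mistake to the model author instead of changing what the method returns.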

Let me know if you would be interested in me submitting a PR! I have been wanting to contribute to the RAVE/nn~ universe of projects :)

caillonantoine commented 1 year ago

Hi! Good catch! Please submit a PR and I'll merge it :)