pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Cannot access storage of TensorWrapper #629

Open vfdev-5 opened 2 years ago

vfdev-5 commented 2 years ago

Another issue with `functional_autograd_benchmark` when run with functorch (code): the `vjp` task fails on the `deepspeech` model:

```shell
python functional_autograd_benchmark.py --model-filter=deepspeech --task-filter=vjp
```

```
  File "/tmp/functorch/functorch/_src/eager_transforms.py", line 254, in vjp
    primals_out = func(*diff_primals)
  File "/pytorch/benchmarks/functional_autograd_benchmark/audio_text_models.py", line 71, in forward
    out, out_sizes = model(inputs, inputs_sizes)
  File "/pytorch/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/pytorch/benchmarks/functional_autograd_benchmark/torchaudio_models.py", line 257, in forward
    x = rnn(x, output_lengths)
  File "/pytorch/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/pytorch/benchmarks/functional_autograd_benchmark/torchaudio_models.py", line 161, in forward
    x, h = self.rnn(x)
  File "/pytorch/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/pytorch/torch/nn/modules/rnn.py", line 772, in forward
    result = _VF.lstm(input, batch_sizes, hx, self._flat_weights, self.bias,
NotImplementedError: Cannot access storage of TensorWrapper
```

zou3519 commented 2 years ago

Classic RNNs are tricky to handle.
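The traceback fails inside the fused `_VF.lstm` call that `nn.LSTM` dispatches to. One possible workaround, not proposed in this thread and offered only as a sketch, is to write the LSTM recurrence with ordinary tensor ops (matmul, `sigmoid`, `tanh`, `chunk`) that the transforms can trace through. The helpers `lstm_step` and `manual_lstm` are hypothetical names. The sketch assumes a single combined bias `b` instead of `nn.LSTM`'s separate `b_ih`/`b_hh`. It imports `vjp` from `torch.func`, where functorch's transforms were upstreamed, and falls back to `functorch` on older installs:

```python
import torch

# functorch's transforms live in torch.func in newer PyTorch releases;
# fall back to the standalone functorch package on older installs.
try:
    from torch.func import vjp
except ImportError:
    from functorch import vjp


def lstm_step(x_t, h, c, w_ih, w_hh, b):
    # Hypothetical helper: one LSTM step built from plain ops instead of the
    # fused _VF.lstm kernel. Gate order (i, f, g, o) follows nn.LSTM's layout;
    # b is an assumed single combined bias.
    gates = x_t @ w_ih.T + h @ w_hh.T + b
    i, f, g, o = gates.chunk(4, dim=-1)
    c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
    h = torch.sigmoid(o) * torch.tanh(c)
    return h, c


def manual_lstm(x, w_ih, w_hh, b):
    # Hypothetical helper. x has shape (seq_len, batch, input_size);
    # returns the hidden state for every step, shape (seq_len, batch, hidden).
    seq_len, batch, _ = x.shape
    hidden = w_hh.shape[1]
    h = torch.zeros(batch, hidden, dtype=x.dtype, device=x.device)
    c = torch.zeros(batch, hidden, dtype=x.dtype, device=x.device)
    outputs = []
    for t in range(seq_len):
        h, c = lstm_step(x[t], h, c, w_ih, w_hh, b)
        outputs.append(h)
    return torch.stack(outputs)


torch.manual_seed(0)
seq_len, batch, input_size, hidden = 5, 3, 4, 8
x = torch.randn(seq_len, batch, input_size)
w_ih = torch.randn(4 * hidden, input_size)
w_hh = torch.randn(4 * hidden, hidden)
b = torch.randn(4 * hidden)

# vjp returns the output plus a function mapping a cotangent to one
# gradient per primal input (x, w_ih, w_hh, b).
out, vjp_fn = vjp(manual_lstm, x, w_ih, w_hh, b)
grad_x, grad_w_ih, grad_w_hh, grad_b = vjp_fn(torch.ones_like(out))
print(out.shape, grad_x.shape, grad_w_ih.shape)
```

Because the recurrence is a Python loop over `seq_len` steps, it will likely be slower than the fused kernel. The trade-off is that every op it uses is one the function transforms can handle.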

cyyever commented 2 years ago

I also encountered this issue today, so I had to switch from an RNN to a Transformer...
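If functorch's composability (e.g. with `vmap`) is not needed, another option is to keep the stock `nn.LSTM` and compute the VJP with `torch.autograd.functional.vjp`. This is an assumption, not something suggested in the thread. That function runs on regular autograd, so it never wraps inputs in functorch's `TensorWrapper`. The layer sizes below are arbitrary, and the sketch differentiates only with respect to the input `x`:

```python
import torch
from torch.autograd.functional import vjp

torch.manual_seed(0)
lstm = torch.nn.LSTM(input_size=4, hidden_size=8)
x = torch.randn(5, 3, 4)  # (seq_len, batch, input_size)


def forward(inp):
    # Stock nn.LSTM, including its fused _VF.lstm call.
    out, _ = lstm(inp)
    return out


# The cotangent v must match forward's output shape (seq_len, batch, hidden).
v = torch.ones(5, 3, 8)
out, grad_x = vjp(forward, x, v)
print(out.shape, grad_x.shape)
```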