pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Functionalization doesn't work with torch.nn.functional.ctc_loss #86384

Closed tugsbayasgalan closed 1 year ago

tugsbayasgalan commented 1 year ago

🐛 Describe the bug

To repro:

from functorch.experimental import functionalize
import torch

def f(x):
    # log_probs has shape (T, N, C) = (50, 16, 20), as ctc_loss expects
    log_probs = x.log_softmax(2)
    targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
    input_lengths = torch.full((16,), 50, dtype=torch.long)
    target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
    loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)
    return loss

f_functional = functionalize(f, remove="mutations_and_views")
f_functional(torch.randn(50, 16, 20))  # raises the RuntimeError below

which gives the following error:

The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-7-c62f15727642> in <module>
     11 
     12 f_functional = functionalize(f, remove="mutations_and_views")
---> 13 f_functional(torch.randn(50, 16, 20))
/mnt/xarfuse/uid-26336/1f88dbba-seed-nspid4026533609_cgpid30323950-ns-4026533606/functorch/_src/vmap.py in fn(*args, **kwargs)
     33     def fn(*args, **kwargs):
     34         with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 35             return f(*args, **kwargs)
     36     return fn
     37 
/mnt/xarfuse/uid-26336/1f88dbba-seed-nspid4026533609_cgpid30323950-ns-4026533606/functorch/_src/eager_transforms.py in wrapped(*args, **kwargs)
   1458             flattened_wrapped_kwargs, _ = tree_flatten(func_kwargs)
   1459 
-> 1460             func_outputs = func(*func_args, **func_kwargs)
   1461             outputs = _unwrap_all_tensors_from_functional(func_outputs, reapply_views=reapply_views)
   1462             flat_outputs, func_out_spec = tree_flatten(outputs)
<ipython-input-7-c62f15727642> in f(x)
      7     input_lengths = torch.full((16,), 50, dtype=torch.long)
      8     target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
----> 9     loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)
     10     return loss
     11 
/mnt/xarfuse/uid-26336/1f88dbba-seed-nspid4026533609_cgpid30323950-ns-4026533606/torch/nn/functional.py in ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity)
   2622             blank=blank, reduction=reduction, zero_infinity=zero_infinity
   2623         )
-> 2624     return torch.ctc_loss(
   2625         log_probs, targets, input_lengths, target_lengths, blank, _Reduction.get_enum(reduction), zero_infinity
   2626     )
RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.

I think this is because ctc_loss is still not composite compliant even after this diff (https://github.com/pytorch/pytorch/pull/84752/files).
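
For reference, here is a minimal sketch (assuming torch._decomp is available on this build) of how one might check whether a Python decomposition is registered for the aten ctc_loss overloads; the result may well be empty if the op is only covered by its C++ CompositeImplicitAutograd kernel:

import torch
from torch._decomp import decomposition_table

# Look for any Python decompositions registered for ctc_loss overloads.
# An empty dict here would suggest only the C++ kernel covers this op.
ctc_decomps = {op: fn for op, fn in decomposition_table.items() if "ctc_loss" in str(op)}
print(ctc_decomps)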

Versions

latest master

cc @bdhirsh @ezyang @soumith @SherlockNoMad @ngimel

tugsbayasgalan commented 1 year ago

cc: @bdhirsh @Chillee @gmagogsfm @larryliu0820

bdhirsh commented 1 year ago

This is a composite compliance issue. It looks like we actually have a composite-compliant decomposition for ctc_loss, but it currently only runs for tensor subclasses, and doesn't run when functionalization is active (which I think we should fix). More discussion here: https://github.com/pytorch/pytorch/pull/84752/files#r989417906
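
In the meantime, one possible user-level workaround (just a sketch, not an official fix) is to keep only the part that actually needs functionalization inside functionalize and call ctc_loss eagerly on the result, since nothing in the loss call itself mutates or aliases in this repro:

import torch
from functorch.experimental import functionalize

def log_probs_fn(x):
    # only this part runs under functionalization
    return x.log_softmax(2)

x = torch.randn(50, 16, 20)
log_probs = functionalize(log_probs_fn, remove="mutations_and_views")(x)

# ctc_loss is called eagerly, outside the transform, so the
# non-composite-compliant kernel never sees functional tensors
targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
input_lengths = torch.full((16,), 50, dtype=torch.long)
target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)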