CorentinJ / Real-Time-Voice-Cloning


### 🚀 The feature, motivation and pitch #1142

Closed: ImanuillKant1 closed this issue 1 year ago

ImanuillKant1 commented 1 year ago

🚀 The feature, motivation and pitch

I'm trying to write a custom `torch.autograd.Function` that can take in and return a list of tensors, like:

```python
import torch

class FooFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        indices,
        offsets,
        weights,
    ) -> torch.Tensor:
        with torch.enable_grad():
            return torch.ones(2, requires_grad=True)

    @staticmethod
    def backward(ctx, dout):
        # Return a list of gradients for the list-of-tensors input.
        return None, None, [torch.ones(1)]
```

And then

```python
weights = [torch.ones(1, requires_grad=True)]
out = FooFunction.apply(torch.ones(1), torch.ones(1), weights)
out_sum = out.sum()
print("out sum", out_sum)
out_sum.backward()
print(weights[0].grad)  # expected: tensor([1.]), i.e. torch.ones(1)
```

Right now I'm getting this error:

```
File ~/anaconda3/envs/compilers/lib/python3.10/site-packages/torch/_tensor.py:473, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    463 if has_torch_function_unary(self):
    464     return handle_torch_function(
    465         Tensor.backward,
    466         (self,),
   (...)
    471         inputs=inputs,
    472     )
--> 473 torch.autograd.backward(
    474     self, gradient, retain_graph, create_graph, inputs=inputs
    475 )

File ~/anaconda3/envs/compilers/lib/python3.10/site-packages/torch/autograd/__init__.py:197, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    192     retain_graph = create_graph
    194 # The reason we repeat same the comment below is that
    195 # some Python versions print out the first line of a multi-line function
    196 # calls in the traceback and some print out the last line
--> 197 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    198     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    199     allow_unreachable=True, accumulate_grad=True)

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```

I'm guessing this is because autograd doesn't support `List[torch.Tensor]` inputs.
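That guess lines up with the traceback: the list argument isn't tracked by autograd, and the other two tensors don't require grad, so the output comes back detached from the graph. A quick sanity check, assuming the repro above:

```python
# With only an untracked list and two no-grad tensors as inputs, autograd
# never attaches a grad_fn to the Function's output:
print(out.requires_grad)  # False
print(out.grad_fn)        # None -> out_sum.backward() raises the RuntimeError above
```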

Is there a way to do this? (Is there some kind of TensorList class? Even something like NestedTensor would be useful.)

Thanks!
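One workaround, to the best of my understanding: unpack the list into variadic positional arguments, since `Function.apply` only tracks tensors passed directly (a list appears to be treated as an opaque, non-differentiable argument). A minimal sketch adapting the repro above:

```python
import torch

class FooFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, indices, offsets, *weights) -> torch.Tensor:
        # Remember how many weight tensors arrived so that backward
        # can return a matching number of gradients.
        ctx.num_weights = len(weights)
        return torch.ones(2)

    @staticmethod
    def backward(ctx, dout):
        # backward must return one gradient per forward argument:
        # indices, offsets, then one per unpacked weight tensor.
        weight_grads = tuple(torch.ones(1) for _ in range(ctx.num_weights))
        return (None, None) + weight_grads

weights = [torch.ones(1, requires_grad=True)]
out = FooFunction.apply(torch.ones(1), torch.ones(1), *weights)
out.sum().backward()
print(weights[0].grad)  # tensor([1.])
```

The key point is that `backward` must return exactly one gradient per `forward` argument, which the variadic unpacking makes possible; each tensor in the original list then gets its own gradient slot and accumulates `.grad` as usual.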

Alternatives

No response

Additional context

No response

Originally posted by @YLGH in https://github.com/pytorch/pytorch/issues/89801