pytorch / executorch

On-device AI across mobile, embedded and edge for PyTorch
https://pytorch.org/executorch/

Support for torch.fft #1362

Open gorkemalgan opened 6 months ago

gorkemalgan commented 6 months ago

I have a custom model that uses torch.fft.rfftn and torch.fft.irfftn. I can successfully run capture_pre_autograd_graph and export (only with static sizes, though), but when I run to_edge I get the following error:

Operator torch._ops.aten._fft_r2c.default is not Aten Canonical.
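
For reference, a minimal stand-in along these lines (a simplified sketch, not my actual model; the class name and shapes are made up) hits the same error once it reaches to_edge:

import torch

class FFTFilter(torch.nn.Module):
    # Forward real FFT followed by its inverse over the last two dims.
    def forward(self, x):
        spec = torch.fft.rfftn(x, dim=(-2, -1))                       # lowers to aten._fft_r2c
        return torch.fft.irfftn(spec, s=x.shape[-2:], dim=(-2, -1))   # lowers to aten._fft_c2r

model = FFTFilter().eval()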

cccclai commented 6 months ago

To temporarily unblock, try to_edge(..., compile_config=exir.EdgeCompileConfig(_check_ir_validity=False)) while we work on addressing it.

cc: @larryliu0820

gorkemalgan commented 6 months ago

Thanks for the suggestion but now it throws the following error:

Exception has occurred: SpecViolationError
These operators are taking Tensor inputs with mismatched dtypes:

aten::_fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor: {'self': torch.float32, '__ret_0': torch.complex64}
aten::view_as_real_copy(Tensor self) -> Tensor: {'self': torch.complex64, '__ret_0': torch.float32}
aten::complex(Tensor real, Tensor imag) -> Tensor: {'real': torch.float32, 'imag': torch.float32, '__ret_0': torch.complex64}
aten::_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor: {'self': torch.complex64, '__ret_0': torch.float32}

larryliu0820 commented 6 months ago

From which step did you get this error?

gorkemalgan commented 6 months ago

I get it in the last step:

import torch
from torch._export import capture_pre_autograd_graph
from torch.export import export, ExportedProgram
from executorch.exir import to_edge, EdgeCompileConfig, EdgeProgramManager
example_args = (torch.randn(1, 4, 1000, 1504),)
pre_autograd_aten_dialect = capture_pre_autograd_graph(model, example_args)
aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
edge_program: EdgeProgramManager = to_edge(aten_dialect, compile_config=EdgeCompileConfig(_check_ir_validity=False))  # originally: to_edge(aten_dialect)
larryliu0820 commented 6 months ago

I think this is a valid error. It is saying that the tensor dtype torch.complex64 is not accepted by the edge dialect (and hence not accepted by ExecuTorch). We don't support this dtype yet and are working on getting it supported.
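
If you want to see where the complex dtype shows up, something like this should print the offending nodes (a sketch only, assuming aten_dialect is the ExportedProgram from your snippet above):

import torch

# Print every node whose faked output value carries a complex dtype.
for node in aten_dialect.graph.nodes:
    val = node.meta.get("val")
    if isinstance(val, torch.Tensor) and val.dtype in (torch.complex64, torch.complex128):
        print(node.target, val.dtype)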

gorkemalgan commented 6 months ago

torch.fft.rfft returns a torch.complex64 tensor, so as I understand it there is no way to run it with ExecuTorch for now? Correct me if I'm wrong, but it sounds like I will need to wait until the torch.complex64 type is supported by ExecuTorch.

larryliu0820 commented 6 months ago

torch.fft.rfft returns a torch.complex64 tensor, so as I understand it there is no way to run it with ExecuTorch for now? Correct me if I'm wrong, but it sounds like I will need to wait until the torch.complex64 type is supported by ExecuTorch.

Yep, I don't think we have complex number support in the ExecuTorch runtime right now.
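
In the meantime, one possible model-side workaround (a sketch of the general idea only, not an ExecuTorch feature) is to express the real FFT as a matmul against a precomputed DFT basis and return separate real and imaginary tensors, so the exported graph never contains complex64. This only covers a 1-D rfft along the last dimension; rfftn over several dims would need to apply it per axis, and whether a matmul of this size is acceptable for a 1000x1504 input is a separate question:

import math
import torch

def rfft_basis(n: int, dtype=torch.float32):
    # Real/imag DFT basis matrices of shape (n, n // 2 + 1).
    k = torch.arange(n // 2 + 1, dtype=dtype)
    t = torch.arange(n, dtype=dtype).unsqueeze(1)
    ang = 2.0 * math.pi * t * k / n
    return torch.cos(ang), -torch.sin(ang)

class RFFTLastDim(torch.nn.Module):
    # Returns (real, imag) as two float32 tensors instead of one complex64 tensor.
    def __init__(self, n: int):
        super().__init__()
        cos_b, sin_b = rfft_basis(n)
        self.register_buffer("cos_b", cos_b)
        self.register_buffer("sin_b", sin_b)

    def forward(self, x):  # x: (..., n)
        return x @ self.cos_b, x @ self.sin_b

# Eager-mode sanity check against torch.fft.rfft:
x = torch.randn(2, 8)
re, im = RFFTLastDim(8)(x)
ref = torch.fft.rfft(x, dim=-1)
assert torch.allclose(re, ref.real, atol=1e-4)
assert torch.allclose(im, ref.imag, atol=1e-4)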

kimishpatel commented 6 months ago

@guangy10 if we have a triaging and follow-up meeting, I would like to follow up on this issue specifically.

@gorkemalgan note that there are a couple of issues here: 1) the complex dtype is not supported, and 2) implementations of the *fft ops are not available.

cccclai commented 6 months ago

I think we already have the complex data type issue filed separately here: https://github.com/pytorch/executorch/issues/886

cbilgin commented 6 months ago

@gorkemalgan Closing because it's a duplicate of https://github.com/pytorch/executorch/issues/886. If this is a separate issue, please reopen it.

gorkemalgan commented 6 months ago

This issue is not a duplicate of #886. In addition to the complex data type support requested in #886, this issue requires an ExecuTorch implementation of the torch.fft operators as well.

cccclai commented 6 months ago

I believe torch.fft is not part of core aten. If it is not, it'll be decomposed to the core aten ops. Maybe @larryliu0820 or @SS-JIA can confirm.
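
A quick way to check (assuming your PyTorch build exposes op tags):

import torch

op = torch.ops.aten._fft_r2c.default
print(torch.Tag.core in op.tags)  # False here would mean _fft_r2c is not tagged as core ATen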

SS-JIA commented 6 months ago

With the FFT-related ops, we decided to defer the decision for now. However, seeing that there is a use case for them, we will need to revisit whether they should be core. To my knowledge you can't really decompose fft.

However, even before considering adding fft ops, we need to think about how we're going to support complex data types. Is there a plan in place for this currently?

alealv commented 6 months ago

Hi, I'm also eager to have support for the FFT functions. In my case, I'm getting:

    raise UnsupportedOperatorException(func)
torch._subclasses.fake_tensor.UnsupportedOperatorException: aten._fft_c2r.default

The above exception was the direct cause of the following exception:

...

    raise UnsupportedOperatorException(func)
RuntimeError: Failed running call_function <built-in method istft of type object at 0x7f3704de8a40>(*(FakeTensor(..., size=(s0, 513, 50), dtype=torch.complex64), 1024, 256, 1024, FakeTensor(..., size=(1024,))), **{'center': True}):
aten._fft_c2r.default

During handling of the above exception, another exception occurred:

....

    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: unsupported operator: aten._fft_c2r.default (see https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0 for how to fix)

The above exception was the direct cause of the following exception:

....

    raise UserError(UserErrorType.DYNAMIC_CONTROL_FLOW, str(e)) from e
torch._dynamo.exc.UserError: speculate_subgraph: while introspecting cond, we were unable to trace function `tf_istft` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown. Scroll up for the stack trace of the initial exception. The reason was: unsupported operator: aten._fft_c2r.default (see https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0 for how to fix)
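
For what it's worth, a minimal snippet along these lines (a simplified sketch, not my full model) should hit the same aten._fft_c2r.default failure, since torch.istft lowers to it:

import torch
from torch.export import export

class ISTFT(torch.nn.Module):
    def forward(self, spec, window):
        # spec: (batch, n_fft // 2 + 1, frames), complex64
        return torch.istft(spec, n_fft=1024, hop_length=256, win_length=1024,
                           window=window, center=True)

spec = torch.randn(1, 513, 50, dtype=torch.complex64)
window = torch.hann_window(1024)
# Expected to raise on aten._fft_c2r.default, matching the trace above.
export(ISTFT(), (spec, window))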