deep-floyd / IF


torch.compile for improved performance/AssertionError #98

Open phalexo opened 1 year ago

phalexo commented 1 year ago

In the IF documentation there is a suggestion that `torch.compile` can improve performance. I have tried:

```python
if_I = torch.compile(IFStageI('IF-I-XL-v1.0', device='cuda:1'))
if_II = torch.compile(IFStageII('IF-II-L-v1.0', device='cuda:2'))
if_III = torch.compile(StableStageIII('stable-diffusion-x4-upscaler', device='cuda:3'))
```

It is not clear from the documentation which type of objects can be compiled.

```
AssertionError                            Traceback (most recent call last)
Cell In[2], line 1
----> 1 if_I = torch.compile(IFStageI('IF-I-XL-v1.0', device='cuda:1'))
      4 if_II = IFStageII('IF-II-L-v1.0', device='cuda:2')
      7 if_III = StableStageIII('stable-diffusion-x4-upscaler', device='cuda:3')

File ~/.local/lib/python3.8/site-packages/torch/__init__.py:1441, in compile(model, fullgraph, dynamic, backend, mode, options, disable)
   1439 if backend == "inductor":
   1440     backend = _TorchCompileInductorWrapper(mode, options, dynamic)
-> 1441 return torch._dynamo.optimize(backend=backend, nopython=fullgraph, dynamic=dynamic, disable=disable)(model)

File ~/.local/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py:182, in _TorchDynamoContext.__call__(self, fn)
    179     new_mod._torchdynamo_orig_callable = mod.forward
    180     return new_mod
--> 182 assert callable(fn)
    184 callback = self.callback
    185 on_enter = self.on_enter

AssertionError:
```

MohsenSadeghi commented 1 year ago

I got the same error when I tried to compile something other than a callable. Make sure that the model you use in `torch.compile(model)` can actually be called on some data, as in `y = model(x)`.
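To illustrate the condition behind the traceback (a minimal, framework-free sketch — the class names here are hypothetical, not from IF): `torch._dynamo` does `assert callable(fn)`, so anything you pass to `torch.compile` must be a function or an object implementing `__call__`, which `nn.Module` instances do. A plain wrapper object that is not callable trips exactly this assertion:

```python
# Hypothetical sketch of the check torch.compile performs internally.
# A plain object with no __call__ fails `assert callable(fn)`;
# anything callable on data (like an nn.Module) passes it.

class NotCallable:
    """Stands in for a wrapper object that is not callable."""
    pass

class CallableModel:
    """Stands in for an nn.Module, which implements __call__."""
    def __call__(self, x):
        return x * 2

print(callable(NotCallable()))    # False -> torch.compile raises AssertionError
print(callable(CallableModel()))  # True  -> safe to pass to torch.compile
```

If the IF stage objects fail this check, they are wrappers rather than callables, and you would need to compile whatever callable module they hold internally instead.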

phalexo commented 1 year ago

If stages are not callable then it is not clear to me which parts are callable. Did you manage to compile anything?


MohsenSadeghi commented 1 year ago

yeah, I did manage to compile the whole model, but did not get much of a performance boost! :') The bottleneck apparently was in the dataloader pipeline.

What are you trying to compile?

phalexo commented 1 year ago

I just want the inference to be faster.

Can you paste the snippet of code that worked for you?


MohsenSadeghi commented 1 year ago

I'm afraid it's just the vanilla `opt_net = torch.compile(net)`, where `net` is an `nn.Module` with a `forward()` method. Alternatively, I could add the decorator `@torch.compile` to the `forward()` method itself.
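The vanilla pattern described above can be sketched on a toy module (this is not the IF pipeline — `TinyNet` is an illustrative stand-in; `backend="eager"` is used here only so the sketch runs without a compiler toolchain, and you would omit it to get real inductor speedups):

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Toy stand-in for a real model: any nn.Module with forward() works."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

net = TinyNet()

# The "vanilla" call: wrap the module itself. backend="eager" runs
# TorchDynamo without codegen, purely to keep this sketch lightweight;
# drop the argument to use the default inductor backend.
opt_net = torch.compile(net, backend="eager")

y = opt_net(torch.randn(3, 4))
print(y.shape)  # torch.Size([3, 2])
```

The alternative mentioned above is decorating `forward()` directly with `@torch.compile`, which compiles just that method rather than wrapping the module object.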

phalexo commented 1 year ago

Regardless of where I put it, it either complains about parallelism or simply hangs. Where specifically did you put `torch.compile` or `@torch.compile`?