icon-lab / SynDiff

Official PyTorch implementation of SynDiff described in the paper (https://arxiv.org/abs/2207.08208).

Wrong signature in forward function #38

Closed: sRassmann closed this issue 9 months ago

sRassmann commented 9 months ago

Hi, I am trying to train SynDiff on my own data using the command-line arguments given in the README. I am encountering the following error in the forward function:

Traceback (most recent call last):
  File "/home/rassmanns/miniconda3/envs/icon/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1519, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/home/rassmanns/miniconda3/envs/icon/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1355, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/home/rassmanns/miniconda3/envs/icon/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/rassmanns/miniconda3/envs/icon/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/rassmanns/diffusion/SynDiff/backbones/ncsnpp_generator_adagn.py", line 312, in forward
    h = modules[m_idx](hs[-1], temb, zemb)
  File "/home/rassmanns/miniconda3/envs/icon/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/rassmanns/miniconda3/envs/icon/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: forward() takes 2 positional arguments but 4 were given

The issue seems to be that the module at modules[m_idx] is an AttnBlockpp, whose forward accepts only (self, x), while the call also passes the temb and zemb arguments.

Could you please point out what I am doing wrong?

Thanks
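A minimal sketch of how this TypeError arises, using a hypothetical plain-Python stand-in (AttnBlockStandIn, not the repository's AttnBlockpp) whose forward accepts only x. Calling it the way the ResNet-block call site in the traceback does, with temb and zemb as extra positional arguments, gives Python four positional arguments (the bound self plus three) for a method that accepts two.

```python
class AttnBlockStandIn:
    """Hypothetical stand-in for AttnBlockpp: forward accepts only the feature map x."""

    def forward(self, x):
        return x


def call_like_resnet_site(module, h, temb, zemb):
    # Mirrors the call `modules[m_idx](hs[-1], temb, zemb)` from the traceback:
    # three positional arguments plus the bound `self` make four in total.
    return module.forward(h, temb, zemb)


message = ""
try:
    call_like_resnet_site(AttnBlockStandIn(), "h", "temb", "zemb")
except TypeError as err:
    message = str(err)
print(message)
```

On Python 3.8 the printed message matches the traceback; newer Python versions may prefix the method's qualified name (AttnBlockStandIn.forward()).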

sRassmann commented 9 months ago

Never mind: I found that the network is built using image_size from the config (32 by default), while my own dataloaders use image_size set to 256. Passing the correct image size to the network constructors solved the issue.
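Why a size mismatch surfaces as a signature error can be sketched with a simplified model, assuming the generator follows the NCSN++ layout of the score-SDE codebase: the constructor decides where to insert attention blocks from the resolutions implied by the configured image_size, while forward decides when to call one from the actual feature-map width (h.shape[-1] in attn_resolutions). If the two sizes differ, the running index m_idx drifts and a ResNet-style call (x, temb, zemb) lands on an attention block. The names build_modules, call, run_forward, ARITY and the defaults (4 resolutions, 2 blocks per level, attn_resolutions=(16,)) are hypothetical and not SynDiff's actual code.

```python
# Positional arguments each hypothetical block kind accepts (excluding self).
ARITY = {"resnet": 3, "down": 3, "attn": 1}


def build_modules(image_size, num_resolutions=4, num_res_blocks=2, attn_resolutions=(16,)):
    """Constructor side: place attention blocks using the *configured* image_size."""
    modules = []
    for i_level in range(num_resolutions):
        res = image_size // (2 ** i_level)
        for _ in range(num_res_blocks):
            modules.append("resnet")
            if res in attn_resolutions:
                modules.append("attn")
        if i_level != num_resolutions - 1:
            modules.append("down")  # strided ResNet-style block, called with (x, temb, zemb)
    return modules


def call(modules, m_idx, n_args):
    """Simulate modules[m_idx](...) with n_args positional arguments."""
    kind = modules[m_idx]
    if ARITY[kind] != n_args:
        raise TypeError(f"{kind} block at index {m_idx} takes {ARITY[kind]} "
                        f"argument(s) but {n_args} were given")


def run_forward(modules, input_size, num_resolutions=4, num_res_blocks=2, attn_resolutions=(16,)):
    """Forward side: decide on attention from the *actual* feature-map width."""
    m_idx, width = 0, input_size
    for i_level in range(num_resolutions):
        for _ in range(num_res_blocks):
            call(modules, m_idx, 3)  # like modules[m_idx](hs[-1], temb, zemb)
            m_idx += 1
            if width in attn_resolutions:
                call(modules, m_idx, 1)  # like modules[m_idx](h)
                m_idx += 1
        if i_level != num_resolutions - 1:
            call(modules, m_idx, 3)  # downsampling block
            m_idx += 1
            width //= 2
    return m_idx  # number of modules consumed


modules = build_modules(image_size=32)      # config default
print(run_forward(modules, input_size=32))  # sizes agree: every module is consumed
try:
    run_forward(modules, input_size=256)    # 256-pixel data with a 32-pixel config
except TypeError as err:
    print("mismatch:", err)
print(run_forward(build_modules(image_size=256), input_size=256))  # sizes agree again
```

In this model the 32-pixel config inserts attention at the 16-pixel level, but 256-pixel inputs pass through widths 256, 128, 64 and 32, so the second ResNet call at the second level hits an attention block. Building the modules from the real input size keeps the two sides aligned.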