openai / DALL-E

PyTorch package for the discrete VAE used for DALL·E.

Upsample recompute_scale_factor related error when running the notebook #71

Open sammo opened 2 years ago

sammo commented 2 years ago

Is anyone else hitting this error when running the notebook? If not, which version of PyTorch are you using?

I see that this method was changed in the most recent PyTorch release (1.11.0), so my guess is that it has something to do with that.

AttributeError: 'Upsample' object has no attribute 'recompute_scale_factor'

digitalShaman commented 2 years ago

The error 'Upsample' object has no attribute 'recompute_scale_factor' is related to a change in the torch Upsample class between 1.10 and 1.11.

It appears that 'old' Upsample objects are stored within the model returned by this line of code:

model = load_model("https://cdn.openai.com/dall-e/decoder.pkl", 'cuda')

I used the following code immediately after the load_model call to patch this:

import torch

# Patch for torch 1.11 and higher: replace each old Upsample object with a new one
# that exposes recompute_scale_factor, keeping the original scale_factor and mode.
for group in (model.blocks.group_1, model.blocks.group_2, model.blocks.group_3):
    old = group.upsample
    group.upsample = torch.nn.Upsample(scale_factor=old.scale_factor, mode=old.mode)

With this patch it runs fine on torch 1.12.1!
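For anyone wondering why a loaded model can carry a "stale" layer at all, here is a minimal sketch of the mechanism as I understand it. It uses a toy class named Upsampler (a made-up stand-in, not torch.nn.Upsample) and a hypothetical helper restore_like_pickle that mimics what default unpickling does: create the object without calling __init__, then restore the saved attributes. An attribute that a newer __init__ would set is therefore never created on the restored object, while rebuilding through the constructor brings it back.

```python
class Upsampler:
    """Toy stand-in for a layer class; this is NOT torch.nn.Upsample."""

    def __init__(self, scale_factor, recompute_scale_factor=None):
        self.scale_factor = scale_factor
        # Attribute that only the "new" version of the class sets.
        self.recompute_scale_factor = recompute_scale_factor

    def describe(self):
        # Like a forward pass in the new version, this reads the new attribute.
        return (self.scale_factor, self.recompute_scale_factor)


def restore_like_pickle(cls, saved_state):
    """Hypothetical helper: rebuild an object the way default unpickling does,
    i.e. skip __init__ and restore the saved __dict__ contents."""
    obj = cls.__new__(cls)
    obj.__dict__.update(saved_state)
    return obj


# State as an older version of the class would have saved it.
old_state = {"scale_factor": 2}
stale = restore_like_pickle(Upsampler, old_state)

try:
    stale.describe()
    error_message = None
except AttributeError as exc:
    error_message = str(exc)

# Rebuilding through the constructor (the idea behind the patch above)
# creates every attribute the current class expects.
fresh = Upsampler(scale_factor=stale.scale_factor)
```

Here error_message ends up naming the missing recompute_scale_factor attribute, while fresh.describe() works because __init__ actually ran.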

kamwoh commented 1 year ago

My solution is to build a fresh Encoder and Decoder and load the state dict from the downloaded encoder/decoder into them:

import torch
from dall_e import Encoder, Decoder, load_model

# This can be changed to a GPU, e.g. 'cuda:0'.
dev = torch.device('cpu')

# For faster load times, download these files locally and use the local paths instead.
enc_old = load_model("https://cdn.openai.com/dall-e/encoder.pkl", dev)
dec_old = load_model("https://cdn.openai.com/dall-e/decoder.pkl", dev)

# Build fresh modules with the installed torch version and copy only the weights over.
enc = Encoder()
enc.load_state_dict(enc_old.state_dict())
enc.eval()
dec = Decoder()
dec.load_state_dict(dec_old.state_dict())
dec.eval()

lzh107u commented 1 year ago

I wrapped the access to "self.recompute_scale_factor" in upsampling.py in an exception handler. If the attribute doesn't exist at execution time, my modified code assigns False to it.

Interestingly, the exception isn't raised every time, and I'm still not sure about the mechanism behind it.
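A possible alternative to editing upsampling.py inside the installed PyTorch package is to set the missing attribute on the loaded objects themselves. The sketch below is only an illustration: backfill_attribute is a hypothetical helper name, and the demonstration uses made-up stand-in classes (FakeUpsample, FakeConv) so that it runs without PyTorch. Which default value is appropriate (False as described above, or None) is an assumption you would need to check against your torch version.

```python
def backfill_attribute(objects, cls, name, default):
    """Set attribute `name` to `default` on every instance of `cls` that lacks it.

    Returns the number of objects that were patched.
    """
    patched = 0
    for obj in objects:
        if isinstance(obj, cls) and not hasattr(obj, name):
            setattr(obj, name, default)
            patched += 1
    return patched


# Toy demonstration with stand-in classes (not real torch modules).
class FakeUpsample:
    pass


class FakeConv:
    pass


layers = [FakeUpsample(), FakeConv(), FakeUpsample()]
# This one already has the attribute, so it should be left untouched.
layers[2].recompute_scale_factor = True

patched_count = backfill_attribute(
    layers, FakeUpsample, "recompute_scale_factor", None
)
```

With the model from this thread, the call would presumably be backfill_attribute(model.modules(), torch.nn.Upsample, "recompute_scale_factor", None) right after load_model. I have not tested that against the DALL·E checkpoints.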