tohinz / multiple-objects-gan

Implementation for "Generating Multiple Objects at Spatially Distinct Locations" (ICLR 2019)
MIT License

AttnGAN evaluation error #5

Open rracinskij opened 5 years ago

rracinskij commented 5 years ago

Hi,

I get an error while trying to generate images with the pretrained AttnGAN model via sh sample.sh coco-attngan:

  File "main.py", line 86, in gen_example
    algo.gen_example(data_dic)
  File ".../multiple-objects-gan-master/code/coco/attngan/trainer.py", line 604, in gen_example
    netG.load_state_dict(state_dict)
RuntimeError: Error(s) in loading state_dict for G_NET:
Unexpected key(s) in state_dict: "netG".
tohinz commented 5 years ago

Can you give me a printout of the state_dict? As a quick fix, try netG.load_state_dict(state_dict["netG"]) and see if that works.
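For context, a minimal sketch (mock data, not the actual checkpoint contents) of why the direct load fails: the checkpoint wraps the generator weights under a single "netG" key, so the top-level keys do not match the layer names G_NET expects:

```python
# Mock checkpoint structure (assumption: the real values are tensors,
# and the real inner keys are G_NET layer names).
checkpoint = {"netG": {"fc.weight": "...", "fc.bias": "..."}}

# Loading `checkpoint` directly fails: its only top-level key is "netG",
# not layer names, hence "Unexpected key(s) in state_dict: 'netG'".
# The quick fix is to unwrap the inner dict first:
state_dict = checkpoint["netG"]
```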

rracinskij commented 5 years ago

Thank you, the quick fix solved the problem, as the state_dict contains only a netG key. So line 604 of code/coco/attngan/trainer.py should probably be netG.load_state_dict(state_dict["netG"]).

tohinz commented 5 years ago

Thanks for spotting this, I have just fixed this in the code.

rracinskij commented 5 years ago

Sorry for reopening - one more issue with sh sample.sh coco-attngan:

File "main.py", line 160, in <module>
    gen_example(dataset.wordtoix, algo)  # generate images for customized captions
  File "main.py", line 86, in gen_example
    algo.gen_example(data_dic)
  File "/home/.../multiple-objects-gan-master/code/coco/attngan/trainer.py", line 636, in gen_example
    fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
  File "/.../lib/python2.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
TypeError: forward() takes exactly 7 arguments (5 given)
tohinz commented 5 years ago

Hi, the method you're trying to call is from the original AttnGAN implementation and is used to generate images from "novel" sentences (i.e. not from the validation set). To avoid this, make sure you have B_VALIDATION: True in the cfg_eval.yml file.

If you want to generate images from novel sentences, you'll need to provide the values of transf_matrices_inv and label_one_hot for each caption and call fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask, transf_matrices_inv, label_one_hot). Essentially, you need to decide where in the image (via transf_matrices_inv) and what kind of objects (via label_one_hot) you want to generate. See datasets.py lines 331-339 for how to obtain the transformation matrices from a scaled bounding box. The bbox values are [x-coordinate, y-coordinate, width, height], each scaled to lie between 0 and 1, e.g. [0.0, 0.0, 0.5, 0.5] for an object that covers the top-left quarter of the image.
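For readers who can't consult datasets.py lines 331-339 directly, here is a hypothetical sketch (not the repository's code) of how a 2x3 affine matrix and its inverse could be derived from a scaled bbox, assuming the spatial-transformer convention with normalized coordinates in [-1, 1]:

```python
import numpy as np

def affine_from_bbox(bbox):
    """Sketch (assumption): 2x3 affine matrix mapping the full [-1, 1]
    grid onto the bbox region. bbox = [x, y, w, h], each in [0, 1]."""
    x, y, w, h = bbox
    return np.array([[w,   0.0, 2 * x + w - 1],
                     [0.0, h,   2 * y + h - 1]])

def affine_inverse(theta):
    """Invert a 2x3 affine matrix (pure scale + translation here)."""
    a, b = theta[0, 0], theta[1, 1]
    tx, ty = theta[0, 2], theta[1, 2]
    return np.array([[1 / a, 0.0,   -tx / a],
                     [0.0,   1 / b, -ty / b]])

# Top-left quarter of the image, as in the example above:
bbox = [0.0, 0.0, 0.5, 0.5]
theta = affine_from_bbox(bbox)
theta_inv = affine_inverse(theta)
```

Stacking one such matrix per object would give the [batch size, num objects, 2, 3] shape for transf_matrices_inv.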

The labels are: 0 person, 1 bicycle, 2 car, 3 motorcycle, 4 airplane, 5 bus, 6 train, 7 truck, 8 boat, 9 traffic light, 10 fire hydrant, 11 stop sign, 12 parking meter, 13 bench, 14 bird, 15 cat, 16 dog, 17 horse, 18 sheep, 19 cow, 20 elephant, 21 bear, 22 zebra, 23 giraffe, 24 backpack, 25 umbrella, 26 handbag, 27 tie, 28 suitcase, 29 frisbee, 30 skis, 31 snowboard, 32 sports ball, 33 kite, 34 baseball bat, 35 baseball glove, 36 skateboard, 37 surfboard, 38 tennis racket, 39 bottle, 40 wine glass, 41 cup, 42 fork, 43 knife, 44 spoon, 45 bowl, 46 banana, 47 apple, 48 sandwich, 49 orange, 50 broccoli, 51 carrot, 52 hot dog, 53 pizza, 54 donut, 55 cake, 56 chair, 57 couch, 58 potted plant, 59 bed, 60 dining table, 61 toilet, 62 tv, 63 laptop, 64 mouse, 65 remote, 66 keyboard, 67 cell phone, 68 microwave, 69 oven, 70 toaster, 71 sink, 72 refrigerator, 73 book, 74 clock, 75 vase, 76 scissors, 77 teddy bear, 78 hair drier, 79 toothbrush

rracinskij commented 5 years ago

Thank you for the detailed explanation. Could you please specify the correct dimensions of bbox, transf_matrices_inv and label_one_hot?

tohinz commented 5 years ago

The dimensions should be the following (num objects = how many objects you want to specify for the object pathway, in our case usually 3):

bbox: [batch size, num objects, 4]
transf_matrices_inv and transf_matrices: [batch size, num objects, 2, 3]
label_one_hot: [batch size, num objects, 81]
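Putting the shapes together, a hypothetical sketch of assembling the extra generator inputs for one caption with three objects (all concrete values are illustrative, not from the repository):

```python
import numpy as np

batch_size, num_objects, num_classes = 1, 3, 81

# Scaled bounding boxes [x, y, w, h], one row per object (illustrative values).
bbox = np.zeros((batch_size, num_objects, 4))
bbox[0] = [[0.0, 0.0, 0.5, 0.5],   # top-left quarter of the image
           [0.5, 0.5, 0.4, 0.4],
           [0.1, 0.6, 0.3, 0.3]]

# One-hot object labels; indices follow the COCO list above (16 = dog, etc.).
label_one_hot = np.zeros((batch_size, num_objects, num_classes))
for i, label in enumerate([16, 0, 2]):   # dog, person, car
    label_one_hot[0, i, label] = 1.0

# Per-object 2x3 affine matrices, to be filled in from the bboxes
# (see datasets.py lines 331-339 for the actual computation).
transf_matrices_inv = np.zeros((batch_size, num_objects, 2, 3))
```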