amzn / convolutional-handwriting-gan

ScrabbleGAN: Semi-Supervised Varying Length Handwritten Text Generation (CVPR20)
https://www.amazon.science/publications/scrabblegan-semi-supervised-varying-length-handwritten-text-generation
MIT License

Error while running train.py #14

Closed: vijay1131 closed this issue 3 years ago

vijay1131 commented 3 years ago

I am trying to train on the RIMES dataset. I have successfully run create_text_dataset.py.

When I run train.py with the following command:

python train.py --name_prefix demo --dataname RIMEScharH32W16 --capitalize --no_html --gpu_ids 0 --batch_size 32

I get the error below:

    self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake, 'z': self.z}), self.len_text_fake.detach(), self.opt.mask_loss)
    AttributeError: 'ScrabbleGANModel' object has no attribute 'fake'

I am not sure why.

sharonFogel commented 3 years ago

I wasn't able to reproduce this error. Are you sure you are using the right environment?

vijay1131 commented 3 years ago

I am using the same conda environment file shared with this project. Could a PyTorch/GPU compatibility problem be the cause? The issue occurs on an RTX 3090; with an RTX 2060 and the same environment, training works fine.

One more question: when training from scratch, how many epochs would you recommend?

rlit commented 3 years ago

Do you have the latest version? I cannot find this line in the code:

self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake, 'z': self.z}), self.len_text_fake.detach(), self.opt.mask_loss)

vijay1131 commented 3 years ago

It's here: https://github.com/amzn/convolutional-handwriting-gan/blob/f7daa5045a281be23c1d20c5b74f12ffbddf69f9/models/ScrabbleGAN_baseModel.py#L340

I think it may be because self.fake never gets assigned here:

    if self.opt.one_hot:
        self.one_hot_fake = make_one_hot(self.text_encode_fake, self.len_text_fake, self.opt.n_classes).to(self.device)
        try:
            self.fake = self.netG(self.z, self.one_hot_fake)
        except:
            print("hh", words)
    else:
        self.fake = self.netG(self.z, self.text_encode_fake)  # generate output image given the input data_A

The full error is below:

    one hot True
    hh [b'over', b'with', b'using', b'Mudlarks', b'a', b'430040', b'the', b'pots']
    Traceback (most recent call last):
      File "train.py", line 78, in <module>
        model.optimize_G()
      File "convolutional-handwriting-gan/models/ScrabbleGAN_baseModel.py", line 425, in optimize_G
        self.backward_G()
      File "convolutional-handwriting-gan/models/ScrabbleGAN_baseModel.py", line 341, in backward_G
        self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake, 'z': self.z}), self.len_text_fake.detach(), self.opt.mask_loss)
    AttributeError: 'ScrabbleGANModel' object has no attribute 'fake'

Could this be related to the one-hot encoding?
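The output above is consistent with that guess about self.fake: the "hh [...]" line is printed by the bare except branch, which suggests self.netG(...) raised, the exception was swallowed, and self.fake was never assigned, so the real failure only surfaces later in backward_G as an AttributeError. The sketch below reproduces that pattern in isolation. SwallowingModel, RaisingModel, GeneratorError and failing_generator are hypothetical stand-ins, not the repository's classes; the second class shows one way to let the original error propagate with context instead.

```python
class GeneratorError(Exception):
    """Hypothetical stand-in for whatever netG actually raised."""


def failing_generator(z, text):
    # Hypothetical stub: always fails, like netG apparently did on the RTX 3090.
    raise GeneratorError("simulated generator failure")


class SwallowingModel:
    """Mirrors the quoted snippet: a bare except hides the generator failure."""

    def __init__(self, generator):
        self.netG = generator

    def forward(self, z, text, words):
        try:
            self.fake = self.netG(z, text)
        except:  # noqa: E722 - deliberately reproduces the bare except
            print("hh", words)  # the only visible trace of the failure

    def backward(self):
        # Fails here, far from the real cause, because self.fake was never set.
        return self.fake


class RaisingModel(SwallowingModel):
    """Same forward pass, but the original exception propagates as the cause."""

    def forward(self, z, text, words):
        try:
            self.fake = self.netG(z, text)
        except Exception as exc:
            raise RuntimeError(f"netG failed for words {words!r}") from exc


if __name__ == "__main__":
    model = SwallowingModel(failing_generator)
    model.forward(None, None, [b"over"])  # prints: hh [b'over']
    try:
        model.backward()
    except AttributeError as err:
        print("late error:", err)
```

With the raising variant, the traceback points at the netG call itself and shows the underlying exception under "The above exception was the direct cause of the following exception", which would have exposed the GPU problem immediately.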

vijay1131 commented 3 years ago

OK, I found the issue: it is related to the GPU architecture. Installing the PyTorch nightly build solves it. Reference: https://github.com/NVIDIA/apex/issues/580
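For context, the RTX 3090 is an Ampere GPU with compute capability 8.6, and CUDA 10.x toolkits predate Ampere support, so PyTorch builds compiled against them cannot contain native kernels for it. That would be consistent with a newer (nightly) build fixing the problem, although the thread does not say which CUDA version the original environment used. The sketch below is one hedged way to check whether an installed build targets the current GPU. classify_support and parse_arch are hypothetical helpers; torch.cuda.get_device_capability and torch.cuda.get_arch_list are PyTorch functions, but get_arch_list is missing from older releases, so the script checks for it first. The rule encoded is CUDA's: a cubin built for sm_XY runs on GPUs with the same major version and an equal or higher minor version, and PTX (compute_XY) can be JIT-compiled for newer GPUs.

```python
import re
from typing import Iterable, Tuple

# Matches tags such as "sm_86", "sm_100", "sm_90a" or "compute_75".
_ARCH_RE = re.compile(r"^(sm|compute)_(\d+)(\d)[a-z]?$")


def parse_arch(tag: str) -> Tuple[str, int, int]:
    """Split an arch tag into (kind, major, minor), e.g. "sm_86" -> ("sm", 8, 6)."""
    m = _ARCH_RE.match(tag)
    if m is None:
        raise ValueError(f"unrecognised arch tag: {tag!r}")
    return m.group(1), int(m.group(2)), int(m.group(3))


def classify_support(capability: Tuple[int, int], arch_list: Iterable[str]) -> str:
    """Return "native", "ptx-jit" or "unsupported" for a device capability."""
    dev_major, dev_minor = capability
    result = "unsupported"
    for tag in arch_list:
        kind, major, minor = parse_arch(tag)
        if kind == "sm" and major == dev_major and minor <= dev_minor:
            return "native"  # a binary-compatible cubin is shipped
        if kind == "compute" and (major, minor) <= (dev_major, dev_minor):
            result = "ptx-jit"  # usable, but JIT-compiled on first launch
    return result


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        torch = None
        print("PyTorch is not installed")
    if torch is not None and torch.cuda.is_available():
        cap = torch.cuda.get_device_capability(0)
        get_arch_list = getattr(torch.cuda, "get_arch_list", None)
        print("torch", torch.__version__, "cuda", torch.version.cuda, "device", cap)
        if get_arch_list is None:
            print("this PyTorch release has no torch.cuda.get_arch_list()")
        else:
            print("support:", classify_support(cap, get_arch_list()))
```

On an RTX 3090, a result of "unsupported" would point to the same class of problem as in this issue; "ptx-jit" means kernels should run but the first launch may be slow while they compile.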