sahajgarg / image_transformer

Pytorch implementation of the image transformer for unconditional image generation

Two problems with your code #4

Closed zhshi0816 closed 3 years ago

zhshi0816 commented 3 years ago

Hi,

Thanks for sharing your code, but it seems there are two problems with it. Please ignore this if I am wrong.

It seems you forgot to add positional embeddings to the input representations. Maybe your add_timing_signal handles this, but I am not sure. By the way, what is the add_timing_signal function for?

In your training code, you forgot to generate a mask that masks off the pixels that have not yet been generated. This means an earlier pixel can receive information from a later pixel.

sahajgarg commented 3 years ago

The timing signal is a type of positional embedding used with the image transformer; this repository uses the default from the original implementation. See https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/image_transformer.py#L206, https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py#L504, and https://github.com/tensorflow/tensor2tensor/blob/5623deb79cfcd28f8f8c5463b58b5bd76a81fd0d/tensor2tensor/layers/common_attention.py#L408
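For readers unfamiliar with the timing signal: it is the sinusoidal positional encoding from tensor2tensor, where each position is encoded with sines and cosines at geometrically spaced frequencies and added to the input embeddings. Below is a minimal PyTorch sketch of the 1D version (the function name and exact scaling here are my own paraphrase, not necessarily identical to this repo's add_timing_signal):

```python
import math
import torch

def timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    # Each position maps to sin/cos values at channels//2 geometrically
    # spaced timescales, so relative offsets are recoverable without
    # learned position embeddings.
    position = torch.arange(length, dtype=torch.float32)
    num_timescales = channels // 2
    log_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float32) * -log_increment)
    scaled = position.unsqueeze(1) * inv_timescales.unsqueeze(0)  # (length, channels//2)
    # First half of channels holds sines, second half cosines.
    return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=1)  # (length, channels)

# "Adding the timing signal": broadcast over the batch dimension.
x = torch.zeros(2, 16, 64)           # (batch, sequence, channels)
x = x + timing_signal_1d(16, 64)
```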

For the masking in the attention, the masking occurs here: https://github.com/sahajgarg/image_transformer/blob/d33b8d007299b434c62e068e1dad35b8a2688212/image_transformer.py#L303 This applies an upper triangular mask to the attention logits, preventing any information from future pixels from reaching the current pixel. With this mask in place, the training code can evaluate the conditional probability of each pixel given all previous pixels simultaneously.
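To illustrate the idea (this is a generic sketch of causal self-attention, not the repo's actual implementation at the linked line): the strictly upper triangle of the logits is filled with -inf before the softmax, so position i can only attend to positions <= i.

```python
import torch
import torch.nn.functional as F

def masked_self_attention(q, k, v):
    # q, k, v: (batch, length, depth)
    logits = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    length = q.size(1)
    # Strictly upper triangular mask: True where key position > query position.
    mask = torch.triu(torch.ones(length, length, dtype=torch.bool), diagonal=1)
    # -inf logits become zero attention weight after the softmax.
    logits = logits.masked_fill(mask, float("-inf"))
    weights = F.softmax(logits, dim=-1)
    return weights @ v
```

Because of this mask, a single forward pass over the whole training image scores every conditional p(pixel_i | pixels_<i) in parallel, which is what makes teacher-forced training efficient.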