google / compare_gan

Compare GAN code.
Apache License 2.0

Need trainable version for pretrained SSGAN in tfhub #37

Closed: RuiLiFeng closed this issue 5 years ago.

RuiLiFeng commented 5 years ago

@Marvin182 Hi, the TF Hub team has just uploaded your SSGAN module. It's wonderful, but it does not seem to have a trainable version. I set `m = hub.Module(spec_name, name="gen_module", tags={"gen", "bsNone"}, trainable=True)`, but the module provides no gradients when the optimizer is applied. Below is part of my code.

```python
import tensorflow as tf  # TF 1.x graph mode
import tensorflow_hub as hub


class Generator(object):

    def __init__(self, module_spec, trainable=True):
        self._module_spec = module_spec
        self._trainable = trainable
        # Load the generator graph of the SSGAN Hub module.
        self._module = hub.Module(self._module_spec, name="gen_module",
                                  tags={"gen", "bsNone"}, trainable=self._trainable)
        self.input_info = self._module.get_input_info_dict()

    def build_graph(self, input_dict):
        """Build the TensorFlow graph for the generator.

        :param input_dict: {'z': <hub.ParsedTensorInfo shape=(?, 120) dtype=float32 is_sparse=False>,
                            'labels': None or (?,)}
        :return: {'generated': <hub.ParsedTensorInfo shape=(?, 128, 128, 3) dtype=float32 is_sparse=False>}
        """
        inv_input = {}
        # G_mapping_ND is a mapping network defined elsewhere (not shown).
        inv_input['z'] = G_mapping_ND(input_dict['z'], 120, 120)
        inv_input['labels'] = input_dict.get('labels', None)

        self.samples = self._module(inputs=inv_input, as_dict=True)['generated']
        return self.samples

    @property
    def trainable_variables(self):
        return [var for var in tf.trainable_variables() if 'generator' in var.name]
```

I wonder whether my implementation is wrong or the module itself is not trainable.
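One way to narrow this down, sketched here as an assumption-laden diagnostic rather than an official tool: in TF 1.x, `Optimizer.compute_gradients(loss, var_list=...)` returns `(gradient, variable)` pairs and uses `None` as the gradient for any variable the loss does not depend on. The hypothetical helper `split_by_gradient` (not part of TensorFlow or TF Hub) separates the variables that receive a gradient from those that do not.

```python
def split_by_gradient(grads_and_vars, name_of=lambda var: var.name):
    """Split (gradient, variable) pairs by whether a gradient was produced.

    grads_and_vars: pairs such as those returned by a TF 1.x
        Optimizer.compute_gradients(loss, var_list=...).
    name_of: how to get a printable name from a variable (defaults to .name).
    Returns (names_with_gradient, names_without_gradient).
    """
    with_grad, without_grad = [], []
    for grad, var in grads_and_vars:
        # compute_gradients reports None for variables the loss does not reach.
        if grad is None:
            without_grad.append(name_of(var))
        else:
            with_grad.append(name_of(var))
    return with_grad, without_grad
```

In a TF 1.x graph this would be called as `split_by_gradient(optimizer.compute_gradients(loss, var_list=generator.trainable_variables))`. If the variables under `gen_module/` all land in the second list, no gradient is flowing into the module's weights, regardless of the `trainable=True` flag.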

Marvin182 commented 5 years ago

The Hub module is for inference only. If you want to continue training, I would advise you to run the code and load the variable values from the checkpoint in the Hub module. It's a bit tricky but doable.
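A rough sketch of that approach, under stated assumptions. A TF1-format Hub module directory typically stores its weights as an ordinary TensorFlow checkpoint with the prefix `variables/variables` (worth verifying against the downloaded module). In that case, `tf.train.list_variables(path)` lists the stored names and shapes, and `tf.train.init_from_checkpoint(path, assignment_map)` copies them into the variables of a freshly built, trainable generator. The missing piece is the assignment map. The hypothetical helper `build_assignment_map` builds one by swapping a name prefix and checking shapes. The prefixes are placeholders, because the actual variable names inside the SSGAN module are not given in this thread.

```python
def build_assignment_map(ckpt_vars, graph_vars, ckpt_prefix, graph_prefix):
    """Map checkpoint variable names to variable names in the training graph.

    ckpt_vars:  (name, shape) pairs, e.g. from tf.train.list_variables(path).
    graph_vars: {name: shape} for the training graph, e.g.
                {v.op.name: v.shape.as_list() for v in tf.global_variables()}.
    ckpt_prefix, graph_prefix: name prefixes to swap (placeholders; read the
                real ones from the checkpoint listing).
    Returns (assignment_map, unmatched_checkpoint_names).
    """
    assignment_map = {}
    unmatched = []
    for name, shape in ckpt_vars:
        if not name.startswith(ckpt_prefix):
            continue  # not part of the sub-network we want to restore
        target = graph_prefix + name[len(ckpt_prefix):]
        # Only map variables that exist in the graph with the same shape;
        # loading a checkpoint value into a differently shaped variable fails.
        if target in graph_vars and list(graph_vars[target]) == list(shape):
            assignment_map[name] = target
        else:
            unmatched.append(name)
    return assignment_map, unmatched
```

With `path = module_dir + "/variables/variables"`, the map would be passed to `tf.train.init_from_checkpoint(path, amap)` before running the variable initializer, since `init_from_checkpoint` works by replacing variable initializers. `module_dir` and the prefixes are illustrative. Check `unmatched` before training, because anything listed there keeps its fresh initialization.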

RuiLiFeng commented 5 years ago

Thank you very much. However, there might be minor mistakes in the TF Hub documentation for SSGAN.

In https://tfhub.dev/google/compare_gan/ssgan_128x128/1, the example usage suggests passing labels to the module, which in my case raises an error: `TypeError: Cannot convert dict_inputs: missing ['images'], extra given ['labels'].`

The Colab notebook also suggests a conditional version of SSGAN, but `get_input_info_dict()` reports that the model accepts only `'z'` as input.
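The `TypeError` quoted above appears to come from `hub.Module` requiring the input dict keys to match the signature's inputs exactly: no expected key missing and no extra key given. The generator code earlier in this thread likely hits the same check, because it always sets `inv_input['labels']`, even to `None`, while the module lists only `'z'`. A defensive sketch follows. It assumes the expected keys come from `get_input_info_dict()` (for example `set(module.get_input_info_dict())`). `match_signature_inputs` is a hypothetical helper modeled on the error message, not TF Hub's own code.

```python
def match_signature_inputs(candidate_inputs, expected_keys):
    """Keep only the inputs a module signature expects.

    candidate_inputs: dict of name -> tensor (None values are treated as absent).
    expected_keys: names the signature accepts, e.g. set(module.get_input_info_dict()).
    Returns (inputs, dropped_keys). Raises KeyError if an expected input is missing.
    """
    present = {k: v for k, v in candidate_inputs.items() if v is not None}
    missing = sorted(set(expected_keys) - set(present))
    if missing:
        raise KeyError("missing inputs for this signature: %s" % missing)
    # Pass exactly the expected keys, so no "extra given" keys reach the module.
    inputs = {k: present[k] for k in expected_keys}
    dropped = sorted(set(candidate_inputs) - set(expected_keys))
    return inputs, dropped
```

The filtered dict would then be passed as `module(inputs, as_dict=True)['generated']`. Returning `dropped` rather than discarding it silently makes it visible when, as with `'labels'` here, the module ignores an input the documentation suggests it uses.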