google-deepmind / sonnet

TensorFlow-based neural network library
https://sonnet.dev/
Apache License 2.0
9.75k stars 1.29k forks source link

VQ-VAE training example(v2) returned NAN loss #198

Open EBGU opened 3 years ago

EBGU commented 3 years ago

Dear Team Deepmind,

I am really grateful that you shared a vqvae_example for Sonnet 2. However, when running it, I get a NaN vqvae loss from the very beginning. The output is: 100 train loss: nan recon_error: 1.010 perplexity: 1.031 vqvae loss: nan, and so on. The plot of the training set is fine, but the reconstruction is pure grey. I tried both vq_use_ema = False and True and got the same results. I have slightly modified your code by replacing the downloading and data loading with the previous version (https://github.com/deepmind/sonnet/blob/master/sonnet/examples/vqvae_example.ipynb), which reads from a local directory. Also, I'm using TensorFlow version 2.2.0 and Sonnet version 2.0.0. My code didn't raise any error; it just produced a NaN loss. I wonder if you could kindly help me with this problem. Thanks a lot!

Sincerely, Harold

My code: import os import subprocess import tempfile

import matplotlib.pyplot as plt import numpy as np import tensorflow.compat.v2 as tf import tensorflow_datasets as tfds import tree

try: import sonnet.v2 as snt tf.enable_v2_behavior() except ImportError: import sonnet as snt

from six.moves import cPickle from six.moves import urllib from six.moves import xrange

for plt display

# Point matplotlib at the local X display. os.system('export ...') would only
# set the variable inside a short-lived child shell, so set it on this process.
os.environ['DISPLAY'] = ':0'

# Report the library versions in use.
print("TensorFlow version {}".format(tf.__version__))
print("Sonnet version {}".format(snt.__version__))

# Directory that contains the extracted 'cifar-10-batches-py' folder.
local_data_dir = '/home/harold/Documents/VQ-VAE'

# Disabled: the tfds download below is replaced by loading from
# local_data_dir. The string must start on its own line, otherwise it is
# implicitly concatenated onto the path above.
'''
# Downloading cifar10
cifar10 = tfds.as_numpy(tfds.load("cifar10:3.0.2", split="train+test", batch_size=-1))
cifar10.pop("id", None)
cifar10.pop("label")
tree.map_structure(lambda x: f'{x.dtype.name}{list(x.shape)}', cifar10)
'''

Data loading

''' train_data_dict = tree.map_structure(lambda x: x[:40000], cifar10) valid_data_dict = tree.map_structure(lambda x: x[40000:50000], cifar10) test_data_dict = tree.map_structure(lambda x: x[50000:], cifar10)

def cast_and_normalise_images(data_dict): """Convert images to floating point with the range [-0.5, 0.5]""" images = data_dict['image'] data_dict['image'] = (tf.cast(images, tf.float32) / 255.0) - 0.5 return data_dict

train_data_variance = np.var(train_data_dict['image'] / 255.0) print('train data variance: %s' % train_data_variance) '''

def unpickle(filename):
  """Load a pickled batch file and return the object stored in it.

  Args:
    filename: Path of the pickle file to read.

  Returns:
    The unpickled object. 8-bit strings written by Python 2 are decoded
    as latin-1.
  """
  # The `encoding` keyword only exists on Python 3, where six's cPickle is
  # the standard pickle module, so import that directly.
  import pickle

  # NOTE(review): unpickling can execute arbitrary code; only use this on
  # files from a trusted source.
  with open(filename, 'rb') as fo:
    return pickle.load(fo, encoding='latin1')

def reshape_flattened_image_batch(flat_image_batch):
  """Turn flat image rows into a batch of 32x32, 3-channel NHWC images."""
  channels_first = flat_image_batch.reshape(-1, 3, 32, 32)
  # Move the channel axis last: NCHW -> NHWC.
  return channels_first.transpose([0, 2, 3, 1])

def combine_batches(batch_list):
  """Merge several unpickled batches into one images/labels dictionary.

  Args:
    batch_list: Iterable of dicts, each with a flat 'data' array and a
      'labels' sequence.

  Returns:
    A dict with 'images' (all batches stacked along the first axis, NHWC)
    and 'labels' (a column vector with one row per label).
  """
  images = np.vstack([reshape_flattened_image_batch(batch['data'])
                      for batch in batch_list])
  # Concatenate rather than vstack so batches of different lengths also work.
  labels = np.concatenate([np.asarray(batch['labels'])
                           for batch in batch_list]).reshape(-1, 1)
  return {'images': images, 'labels': labels}

# Batches 1-4 are used for training, batch 5 for validation, and the separate
# test batch for testing.
train_data_dict = combine_batches([
    unpickle(os.path.join(local_data_dir,
                          'cifar-10-batches-py/data_batch_%d' % i))
    for i in range(1, 5)
])

valid_data_dict = combine_batches([
    unpickle(os.path.join(local_data_dir,
                          'cifar-10-batches-py/data_batch_5'))])

test_data_dict = combine_batches([
    unpickle(os.path.join(local_data_dir,
                          'cifar-10-batches-py/test_batch'))])

def cast_and_normalise_images(data_dict):
  """Convert images to floating point with the range [-0.5, 0.5].

  Args:
    data_dict: Mapping with an 'images' entry (assumed to hold 0-255 pixel
      values); it is updated in place.

  Returns:
    The same `data_dict`, with 'images' replaced by the float32 version.
  """
  images = data_dict['images']
  data_dict['images'] = (tf.cast(images, tf.float32) / 255.0) - 0.5
  return data_dict

# Variance of the training images after scaling to [0, 1]; passed to the
# model as `data_variance` to normalise the reconstruction error.
train_data_variance = np.var(train_data_dict['images'] / 255.0)
print('train data variance: %s' % train_data_variance)

Encoder & Decoder Architecture

class ResidualStack(snt.Module):
  """Stack of residual blocks, each a 3x3 conv followed by a 1x1 conv."""

  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    """Creates the convolution pairs.

    Args:
      num_hiddens: Output channels of each 1x1 conv.
      num_residual_layers: Number of residual blocks.
      num_residual_hiddens: Output channels of each 3x3 conv.
      name: Optional module name passed to `snt.Module`.
    """
    super(ResidualStack, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    self._layers = []
    for i in range(num_residual_layers):
      conv3 = snt.Conv2D(
          output_channels=num_residual_hiddens,
          kernel_shape=(3, 3),
          stride=(1, 1),
          name="res3x3_%d" % i)
      conv1 = snt.Conv2D(
          output_channels=num_hiddens,
          kernel_shape=(1, 1),
          stride=(1, 1),
          name="res1x1_%d" % i)
      self._layers.append((conv3, conv1))

  def __call__(self, inputs):
    """Applies every residual block in turn, then a final ReLU."""
    h = inputs
    for conv3, conv1 in self._layers:
      conv3_out = conv3(tf.nn.relu(h))
      conv1_out = conv1(tf.nn.relu(conv3_out))
      # Residual connection; assumes `inputs` has num_hiddens channels so the
      # shapes match.
      h += conv1_out
    return tf.nn.relu(h)  # Resnet V1 style

class Encoder(snt.Module):
  """Convolutional encoder: two stride-2 convs, a 3x3 conv, a residual stack."""

  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    """Creates the encoder layers.

    Args:
      num_hiddens: Output channels of the last two convs and of the
        residual stack; the first conv uses half of this.
      num_residual_layers: Number of blocks in the residual stack.
      num_residual_hiddens: Hidden channels inside each residual block.
      name: Optional module name passed to `snt.Module`.
    """
    super(Encoder, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    self._enc_1 = snt.Conv2D(
        output_channels=self._num_hiddens // 2,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="enc_1")
    self._enc_2 = snt.Conv2D(
        output_channels=self._num_hiddens,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="enc_2")
    self._enc_3 = snt.Conv2D(
        output_channels=self._num_hiddens,
        kernel_shape=(3, 3),
        stride=(1, 1),
        name="enc_3")
    self._residual_stack = ResidualStack(
        self._num_hiddens,
        self._num_residual_layers,
        self._num_residual_hiddens)

  def __call__(self, x):
    """Encodes `x` through the three convs (ReLU after each) and the stack."""
    h = tf.nn.relu(self._enc_1(x))
    h = tf.nn.relu(self._enc_2(h))
    h = tf.nn.relu(self._enc_3(h))
    return self._residual_stack(h)

class Decoder(snt.Module):
  """Convolutional decoder: 3x3 conv, residual stack, two transposed convs."""

  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    """Creates the decoder layers.

    Args:
      num_hiddens: Output channels of the first conv and the residual
        stack; the first transposed conv uses half of this.
      num_residual_layers: Number of blocks in the residual stack.
      num_residual_hiddens: Hidden channels inside each residual block.
      name: Optional module name passed to `snt.Module`.
    """
    super(Decoder, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    self._dec_1 = snt.Conv2D(
        output_channels=self._num_hiddens,
        kernel_shape=(3, 3),
        stride=(1, 1),
        name="dec_1")
    self._residual_stack = ResidualStack(
        self._num_hiddens,
        self._num_residual_layers,
        self._num_residual_hiddens)
    self._dec_2 = snt.Conv2DTranspose(
        output_channels=self._num_hiddens // 2,
        output_shape=None,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="dec_2")
    # Final layer emits 3 channels, matching the input images.
    self._dec_3 = snt.Conv2DTranspose(
        output_channels=3,
        output_shape=None,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="dec_3")

  def __call__(self, x):
    """Decodes `x` into a reconstruction; no activation on the output."""
    h = self._dec_1(x)
    h = self._residual_stack(h)
    h = tf.nn.relu(self._dec_2(h))
    x_recon = self._dec_3(h)
    return x_recon

class VQVAEModel(snt.Module):
  """Wires encoder, pre-VQ conv, vector quantizer and decoder together."""

  def __init__(self, encoder, decoder, vqvae, pre_vq_conv1, data_variance,
               name=None):
    """Stores the sub-modules.

    Args:
      encoder: Module mapping images to features.
      decoder: Module mapping quantized features back to images.
      vqvae: Vector quantizer, called as `vqvae(z, is_training=...)` and
        returning a dict with at least 'quantize' and 'loss'.
      pre_vq_conv1: Module applied to the encoder output before quantizing.
      data_variance: Scalar the mean squared error is divided by.
      name: Optional module name passed to `snt.Module`.
    """
    super(VQVAEModel, self).__init__(name=name)
    self._encoder = encoder
    self._decoder = decoder
    self._vqvae = vqvae
    self._pre_vq_conv1 = pre_vq_conv1
    self._data_variance = data_variance

  def __call__(self, inputs, is_training):
    """Runs the full model.

    Args:
      inputs: Batch of images.
      is_training: Forwarded to the vector quantizer.

    Returns:
      A dict with 'z' (pre-quantization features), 'x_recon', 'loss'
      (recon_error plus the quantizer loss), 'recon_error' and 'vq_output'.
    """
    z = self._pre_vq_conv1(self._encoder(inputs))
    vq_output = self._vqvae(z, is_training=is_training)
    x_recon = self._decoder(vq_output['quantize'])
    # Mean squared error normalised by the data variance.
    recon_error = tf.reduce_mean((x_recon - inputs) ** 2) / self._data_variance
    loss = recon_error + vq_output['loss']
    return {
        'z': z,
        'x_recon': x_recon,
        'loss': loss,
        'recon_error': recon_error,
        'vq_output': vq_output,
    }

Build Model and train

%%time

# Set hyper-parameters.
batch_size = 32
image_size = 32

# 100k steps should take < 30 minutes on a modern (>= 2017) GPU.
# 10k steps gives reasonable accuracy with VQVAE on Cifar10.
num_training_updates = 10000

num_hiddens = 128
num_residual_hiddens = 32
num_residual_layers = 2
# These hyper-parameters define the size of the model (number of parameters
# and layers). The hyper-parameters in the paper were (For ImageNet):
# batch_size = 128
# image_size = 128
# num_hiddens = 128
# num_residual_hiddens = 32
# num_residual_layers = 2

# This value is not that important, usually 64 works.
# This will not change the capacity in the information-bottleneck.
embedding_dim = 64

# The higher this value, the higher the capacity in the information bottleneck.
num_embeddings = 512

# commitment_cost should be set appropriately. It's often useful to try a
# couple of values. It mostly depends on the scale of the reconstruction cost
# (log p(x|z)). So if the reconstruction cost is 100x higher, the
# commitment_cost should also be multiplied with the same amount.
commitment_cost = 0.25

# Use EMA updates for the codebook (instead of the Adam optimizer).
# This typically converges faster, and makes the model less dependent on choice
# of the optimizer. In the VQ-VAE paper EMA updates were not used (but was
# developed afterwards). See Appendix of the paper for more details.
vq_use_ema = False

# This is only used for EMA updates.
decay = 0.99

learning_rate = 3e-4

Data Loading.

# Training pipeline: normalise, shuffle, loop forever, emit full batches only.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_data_dict)
    .map(cast_and_normalise_images)
    .shuffle(10000)
    .repeat(-1)  # repeat indefinitely
    .batch(batch_size, drop_remainder=True)
    .prefetch(-1))

# Validation pipeline: a single ordered pass; the last batch may be smaller.
valid_dataset = (
    tf.data.Dataset.from_tensor_slices(valid_data_dict)
    .map(cast_and_normalise_images)
    .repeat(1)  # 1 epoch
    .batch(batch_size)
    .prefetch(-1))

'''

train_batch = next(iter(train_dataset))

def convert_batch_to_image_grid(image_batch): reshaped = (image_batch.reshape(4, 8, 32, 32, 3) .transpose(0, 2, 1, 3, 4) .reshape(4 32, 8 32, 3)) return reshaped + 0.5

f = plt.figure(figsize=(16,8)) ax = f.add_subplot(2,2,1) ax.imshow(convert_batch_to_image_grid(train_batch['images'].numpy()), interpolation='nearest') ax.set_title('training data originals') plt.axis('off') plt.show()

'''

Build modules.

encoder = Encoder(num_hiddens, num_residual_layers, num_residual_hiddens)
decoder = Decoder(num_hiddens, num_residual_layers, num_residual_hiddens)
# 1x1 conv that maps the encoder output to embedding_dim channels for the
# quantizer.
pre_vq_conv1 = snt.Conv2D(
    output_channels=embedding_dim,
    kernel_shape=(1, 1),
    stride=(1, 1),
    name="to_vq")

if vq_use_ema:
  vq_vae = snt.nets.VectorQuantizerEMA(
      embedding_dim=embedding_dim,
      num_embeddings=num_embeddings,
      commitment_cost=commitment_cost,
      decay=decay)
else:
  vq_vae = snt.nets.VectorQuantizer(
      embedding_dim=embedding_dim,
      num_embeddings=num_embeddings,
      commitment_cost=commitment_cost)

model = VQVAEModel(encoder, decoder, vq_vae, pre_vq_conv1,
                   data_variance=train_data_variance)

optimizer = snt.optimizers.Adam(learning_rate=learning_rate)

@tf.function
def train_step(data):
  """Runs one optimisation step on a batch and returns the model outputs.

  Args:
    data: Batch dict with an 'images' entry.

  Returns:
    The dict produced by `model` for this batch.
  """
  # Only the forward pass needs to be recorded on the tape.
  with tf.GradientTape() as tape:
    model_output = model(data['images'], is_training=True)
  trainable_variables = model.trainable_variables
  grads = tape.gradient(model_output['loss'], trainable_variables)
  optimizer.apply(grads, trainable_variables)

  return model_output

# Per-step metrics, kept for the progress report and the plots.
train_losses = []
train_recon_errors = []
train_perplexities = []
train_vqvae_loss = []

for step_index, data in enumerate(train_dataset):
  train_results = train_step(data)
  train_losses.append(train_results['loss'])
  train_recon_errors.append(train_results['recon_error'])
  train_perplexities.append(train_results['vq_output']['perplexity'])
  train_vqvae_loss.append(train_results['vq_output']['loss'])

  # Report the averages over the last 100 steps.
  if (step_index + 1) % 100 == 0:
    print('%d train loss: %f ' % (step_index + 1,
                                   np.mean(train_losses[-100:])) +
          ('recon_error: %.3f ' % np.mean(train_recon_errors[-100:])) +
          ('perplexity: %.3f ' % np.mean(train_perplexities[-100:])) +
          ('vqvae loss: %.3f' % np.mean(train_vqvae_loss[-100:])))
  # step_index is 0-based, so compare against the number of completed steps
  # to stop after exactly num_training_updates updates.
  if step_index + 1 >= num_training_updates:
    break

Plot loss

f = plt.figure(figsize=(16, 8))

# Left panel: reconstruction error per step, on a log scale.
ax = f.add_subplot(1, 2, 1)
ax.plot(train_recon_errors)
ax.set_yscale('log')
ax.set_title('NMSE.')

# Right panel: perplexity per step.
ax = f.add_subplot(1, 2, 2)
ax.plot(train_perplexities)
ax.set_title('Average codebook usage (perplexity).')
plt.show()

Visualization

Reconstructions

train_batch = next(iter(train_dataset))
valid_batch = next(iter(valid_dataset))

# Put data through the model with is_training=False, so that in the case of
# using EMA the codebook is not updated.
train_reconstructions = model(train_batch['images'],
                              is_training=False)['x_recon'].numpy()
valid_reconstructions = model(valid_batch['images'],
                              is_training=False)['x_recon'].numpy()

def convert_batch_to_image_grid(image_batch):
  """Tile a batch of 32 images (32x32x3 each) into one 4x8 grid image.

  Args:
    image_batch: NumPy array reshapeable to (4, 8, 32, 32, 3).

  Returns:
    Array of shape (128, 256, 3) with 0.5 added to every value, which undoes
    the -0.5 shift applied by `cast_and_normalise_images`.
  """
  reshaped = (image_batch.reshape(4, 8, 32, 32, 3)
              # Interleave grid rows with image rows so the final reshape
              # lays the images out side by side.
              .transpose(0, 2, 1, 3, 4)
              .reshape(4 * 32, 8 * 32, 3))
  return reshaped + 0.5

# One (title, image batch) pair per panel, in subplot order.
panels = [
    ('training data originals', train_batch['images'].numpy()),
    ('training data reconstructions', train_reconstructions),
    ('validation data originals', valid_batch['images'].numpy()),
    ('validation data reconstructions', valid_reconstructions),
]

f = plt.figure(figsize=(16, 8))
for position, (title, images) in enumerate(panels, start=1):
  ax = f.add_subplot(2, 2, position)
  ax.imshow(convert_batch_to_image_grid(images), interpolation='nearest')
  ax.set_title(title)
  plt.axis('off')
plt.show()

tomhennigan commented 3 years ago

Hi @EBGU , there's quite a lot of code there! I recognize at least some of this from our vqvae example notebook? Rather than printing the whole file it might be more useful if you could highlight what you have changed?

I've just run our vqvae notebook using a free GPU instance on Google Colab, with TF 2.4.1; you can see the results in the gist below:

https://colab.research.google.com/gist/tomhennigan/62edee62a4638e0d0ab9738a757043ed/tf2_vq_vae_training_example.ipynb

As far as I can tell things are working correctly?

EBGU commented 3 years ago

Hi @tomhennigan! I also tried your original code without any changes. The result was still NaN. I thought it could be an environmental problem, but there was no error coming up.

EBGU commented 3 years ago

I upgraded my tf to 2.4.1 and it worked! I guess tf 2.2.0 is somehow incompatible with the code. Thanks a lot!

abhilash1910 commented 3 years ago

Hi, I tried to run the notebook with TF 2.2.0 and it works. Please find the notebook here: https://colab.research.google.com/drive/18GT4HVkjDwHB4e2AEU2G8XYU__A2t-F6?usp=sharing Hope this helps.