yiyixuxu / denoising-diffusion-flax

Implementing the Denoising Diffusion Probabilistic Model in Flax
Apache License 2.0

dtype assertion bug in train (possible JAX-version issue) #2

Open xmax1 opened 1 year ago

xmax1 commented 1 year ago

Hi, thanks for the great work!

There is an assertion error when the dataset batch is checked in the train step, which is confusing because, as far as I understand it, the check should fail for everyone.

Possibly a version issue (maybe some versions of JAX recognise tf dtypes as jnp dtypes?).

AssertionError                            Traceback (most recent call last)
/home/amawi/projects/denoising-diffusion-flax/denoising_diffusion_flax/ddpm_flax_oxford102_end_to_end.ipynb Cell 5 in <cell line: 2>()
      1 work_dir = './fashion_mnist'
----> 2 state = train.train(my_config, work_dir)

File ~/projects/denoising-diffusion-flax/denoising_diffusion_flax/train.py:436, in train(config, workdir, wandb_artifact)
    434 rng, *train_step_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
    435 train_step_rng = jnp.asarray(train_step_rng)
--> 436 state, metrics = p_train_step(train_step_rng, state, batch)
    437 for h in hooks:
    438     h(step)

    [... skipping hidden 17 frame]

File ~/projects/denoising-diffusion-flax/denoising_diffusion_flax/train.py:252, in p_loss(rng, state, batch, ddpm_params, loss_fn, self_condition, is_pred_x0, pmap_axis)
    248 def p_loss(rng, state, batch, ddpm_params, loss_fn, self_condition=False, is_pred_x0=False, pmap_axis='batch'):
    249     
    250     # run the forward diffusion process to generate noisy image x_t at timestep t
    251     x = batch['image']
--> 252     assert x.dtype in [jnp.float32, jnp.float64]
    254     # create batched timesteps: t with shape (B,)
    255     B, H, W, C = x.shape

AssertionError:

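For context, the failing line is just a dtype membership test on the batch image array. A quick way to see what the iterator is actually yielding is to reproduce the check with the dtype reported on failure (a debugging sketch, not repo code; check_image_dtype is a hypothetical helper name):

import jax.numpy as jnp

def check_image_dtype(x):
    # mirrors the assertion at train.py:252, but reports the offending dtype
    if x.dtype not in [jnp.float32, jnp.float64]:
        raise AssertionError(f'unexpected image dtype: {x.dtype} (array type {type(x)})')
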
The relevant part of get_dataset is shown below; the commented-out lines show the dtype swap in question.

import jax
import jax.numpy as jnp
import tensorflow as tf

def get_dataset(rng, config):

    if config.data.batch_size % jax.device_count() > 0:
        raise ValueError('Batch size must be divisible by the number of devices')

    batch_size = config.data.batch_size // jax.process_count()

    # pick the input dtype for the tf.data pipeline based on platform and precision
    platform = jax.local_devices()[0].platform
    if config.training.half_precision:
        if platform == 'tpu':
            # input_dtype = tf.bfloat16
            input_dtype = jnp.bfloat16
        else:
            # input_dtype = tf.float16
            input_dtype = jnp.float16
    else:
        input_dtype = tf.float32
        # input_dtype = jnp.float32
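An alternative workaround, if changing the dataset dtypes is not desirable, would be to cast the images to a float32 JAX array before they reach the assertion (a sketch under the assumption that the batch arrives as a host array or tf tensor; to_float32 is a hypothetical helper, not repo code):

import numpy as np
import jax.numpy as jnp

def to_float32(batch):
    # force the image array into a float32 NumPy/JAX dtype so the
    # `x.dtype in [jnp.float32, jnp.float64]` check in p_loss passes
    batch['image'] = jnp.asarray(np.asarray(batch['image']), dtype=jnp.float32)
    return batch
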
For anyone reading: I'm using jax 0.3.21 with CUDA (not TPU).

yiyixuxu commented 1 year ago

Hi @xmax1

JAX doesn't recognise tf types as jnp types. This line in get_dataset prefetches your batches to the devices and shards them for you, so you shouldn't need to change input_dtype as long as prefetch_to_device runs correctly (I just ran the notebook with a GPU on Colab and it seems fine):

it = jax_utils.prefetch_to_device(it, 2)
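For illustration, the pattern that line is part of looks roughly like the sketch below (the shapes and the dummy iterator are made up; flax.jax_utils.prefetch_to_device expects each array's leading axis to match the number of local devices, and after prefetching the batches are JAX device arrays with NumPy-style dtypes):

import jax
import numpy as np
from flax import jax_utils

def shard(xs):
    # split the leading batch axis across the local devices
    # (assumes the batch size divides the local device count)
    return jax.tree_util.tree_map(
        lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs)

def dummy_batches():
    # stand-in for the tf.data iterator; yields host-side float32 batches
    while True:
        yield {'image': np.zeros((8, 32, 32, 3), dtype=np.float32)}

it = map(shard, dummy_batches())
it = jax_utils.prefetch_to_device(it, 2)   # moves the next 2 batches onto the devices

batch = next(it)
print(type(batch['image']), batch['image'].dtype)   # a JAX array with dtype float32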