microsoft / VQ-Diffusion

Official implementation of VQ-Diffusion
MIT License

device-side assert triggered #29

Open zideliu opened 1 year ago

zideliu commented 1 year ago

I trained taming-transformers on my own dataset and obtained the ckpt file and the corresponding yaml file. When I apply it to VQ-Diffusion, a device-side assert error is reported. I followed configs/imagenet.yaml and only replaced the ckpt file path and the corresponding yaml file path.

I suspect that some parameters need to be adjusted accordingly, but I have not been able to debug it myself. My suspicion is that help_folder/statistics/taming_vqvae_974.pt may not match the parameters I used to train taming-transformers. If you could provide the training details for the ifhq dataset, I would greatly appreciate it.
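
For reference, here is a rough sketch (not from the repo) of how such a mapping could be rebuilt for a custom VQGAN, assuming the statistics file is simply the set of codebook indices that actually occur on the training data; the `dataloader` variable and the output path are placeholders.

import torch
from omegaconf import OmegaConf
from image_synthesis.taming.models.vqgan import VQModel

# Load the custom VQGAN from the yaml/ckpt referenced in configs/mydataset.yaml
config = OmegaConf.load('OUTPUT/pretrained_model/mydataset/mydataset.yaml')
model = VQModel(**config.model.params).eval().cuda()
state = torch.load('OUTPUT/pretrained_model/mydataset/last.ckpt', map_location='cpu')['state_dict']
model.load_state_dict(state, strict=False)

# Mark every codebook entry that is actually used on the training set
used = torch.zeros(config.model.params.n_embed, dtype=torch.bool)
with torch.no_grad():
    for img in dataloader:                       # placeholder: (B, 3, 256, 256) batches in [-1, 1]
        _, _, (_, _, indices) = model.encode(img.cuda())
        used[indices.flatten().cpu()] = True

mapping = torch.nonzero(used).flatten()          # indices of the used codebook entries
print('used entries:', mapping.numel())          # quantize_number / num_embed should match this count
torch.save(mapping, './help_folder/statistics/mydataset_vqvae.pt')   # hypothetical output path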

configs/mydataset.yaml

# change from o4
model:
  target: image_synthesis.modeling.models.conditional_dalle.C_DALLE
  params:
    content_info: {key: image}
    condition_info: {key: label}
    content_codec_config: 
      target: image_synthesis.modeling.codecs.image_codec.taming_gumbel_vqvae.TamingVQVAE
      params:
        trainable: False
        token_shape: [16, 16]
        config_path: 'OUTPUT/pretrained_model/mydataset/mydataset.yaml'
        ckpt_path: 'OUTPUT/pretrained_model/mydataset/last.ckpt'
        num_tokens: 1024
        quantize_number: 974
        mapping_path: './help_folder/statistics/taming_vqvae_974.pt'
        # return_logits: True
    diffusion_config:      
      target: image_synthesis.modeling.transformers.diffusion_transformer.DiffusionTransformer
      params:
        diffusion_step: 100
        alpha_init_type: 'alpha1'        
        auxiliary_loss_weight: 1.0e-3
        adaptive_auxiliary_loss: True
        mask_weight: [1, 1]    # the loss weight on mask region and non-mask region

        transformer_config:
          target: image_synthesis.modeling.transformers.transformer_utils.Condition2ImageTransformer
          params:
            attn_type: 'selfcondition'
            n_layer: 24
            class_type: 'adalayernorm'
            class_number: 15
            content_seq_len: 256  # 16 x 16
            content_spatial_size: [16, 16]
            n_embd: 512 # the dim of embedding dims   # both this and content_emb_config
            n_head: 16 
            attn_pdrop: 0.0
            resid_pdrop: 0.0
            block_activate: GELU2
            timestep_type: 'adalayernorm'    # adainsnorm or adalayernorm and abs
            mlp_hidden_times: 4
            mlp_type: 'conv_mlp'
        condition_emb_config:
          target: image_synthesis.modeling.embeddings.class_embedding.ClassEmbedding
          params:
            num_embed: 15 # 
            embed_dim: 512
            identity: True
        content_emb_config:
          target: image_synthesis.modeling.embeddings.dalle_mask_image_embedding.DalleMaskImageEmbedding
          params:
            num_embed: 974
            spatial_size: !!python/tuple [32, 32]
            embed_dim: 512
            trainable: True
            pos_emb_type: embedding

solver:
  base_lr: 3.0e-6
  adjust_lr: none # not adjust lr according to total batch_size
  max_epochs: 100
  save_epochs: 2
  validation_epochs: 100
  sample_iterations: epoch  # epoch #30000      # how many iterations to perform sampling once ?
  print_specific_things: True

  # config for ema
  ema:
    decay: 0.99
    update_interval: 25
    device: cpu

  clip_grad_norm:
    target: image_synthesis.engine.clip_grad_norm.ClipGradNorm
    params:
      start_iteration: 0
      end_iteration: 5000
      max_norm: 0.5
  optimizers_and_schedulers: # a list of configures, so we can config several optimizers and schedulers
  - name: none # default is None
    optimizer:
      target: torch.optim.AdamW
      params: 
        betas: !!python/tuple [0.9, 0.96]
        weight_decay: 4.5e-2
    scheduler:
      step_iteration: 1
      target: image_synthesis.engine.lr_scheduler.ReduceLROnPlateauWithWarmup
      params:
        factor: 0.5
        patience: 100000
        min_lr: 1.0e-6
        threshold: 1.0e-1
        threshold_mode: rel
        warmup_lr: 4.5e-4 # the lr to be touched after warmup
        warmup: 5000 

dataloader:
........

OUTPUT/pretrained_model/mydataset/mydataset.yaml

model:
  base_learning_rate: 4.5e-06
  target: image_synthesis.taming.models.vqgan.VQModel
  params:
    embed_dim: 256
    n_embed: 1024
    monitor: val/rec_loss
    ddconfig:
      double_z: false
      z_channels: 256
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 1
      - 2
      - 2
      - 4
      num_res_blocks: 2
      attn_resolutions:
      - 16
      dropout: 0.0
    lossconfig:
      target: image_synthesis.taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        disc_conditional: false
        disc_in_channels: 3
        disc_start: 0
        disc_weight: 0.8
        codebook_weight: 1.0
        # ssim_loss: true
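
As a side note for anyone debugging this: a CUDA "device-side assert triggered" during training usually comes from an out-of-range index in an nn.Embedding lookup, which is exactly what a mismatch between this VQGAN (n_embed: 1024) and the shipped ImageNet mapping (quantize_number: 974) would produce. A quick sanity check, sketched below under the assumption that the statistics file is a tensor of indices (not code from the repo):

import torch

# 1) Launching training with the environment variable CUDA_LAUNCH_BLOCKING=1 makes
#    CUDA kernels synchronous, so the device-side assert points at the exact Python
#    line (typically an nn.Embedding lookup fed an out-of-range token index).
# 2) Inspect the shipped mapping file and compare it with your own codebook size:
mapping = torch.load('./help_folder/statistics/taming_vqvae_974.pt', map_location='cpu')
if torch.is_tensor(mapping):
    print('entries:', mapping.numel(), 'max index:', int(mapping.max()))
    # The number of entries should match quantize_number (974) and every index should
    # be valid for *your* VQGAN codebook (n_embed: 1024 above); if not, the ImageNet
    # mapping cannot be reused and needs to be rebuilt for your own dataset.
else:
    print('unexpected format:', type(mapping))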
zideliu commented 1 year ago

num_class of my dataset is 15

aj-113 commented 1 year ago

Did you find any solution to this? My guess: the mapping file (of shape [embed_dim], i.e. 1D) consists of indices in [0, 2887] arranged so that latent variables are mapped to codebook entries. The author is not using Gumbel quantization here, but uses this mapping to quantize the vectors. So one option is probably to comment out the quantization step that uses the mapping file and instead use the learnt Gumbel quantization of the trained VQGAN via quant.emb.weights. A sketch of that idea follows.
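
A rough sketch of the suggestion above (not code from the repo; it assumes the standard taming VQModel interface): skip the mapping-based re-quantization entirely and let the token indices address the full learnt codebook, which also means setting num_tokens / num_embed to the full codebook size and dropping quantize_number and mapping_path from the config.

import torch

@torch.no_grad()
def encode_tokens(model, img):
    # model: the trained taming VQModel; img: (B, 3, H, W) in [-1, 1]
    _, _, (_, _, indices) = model.encode(img)
    # indices address model.quantize.embedding.weight directly (all 1024 entries),
    # so no mapping file is involved.
    return indices.reshape(img.shape[0], -1)

@torch.no_grad()
def decode_tokens(model, tokens, token_shape=(16, 16)):
    codebook = model.quantize.embedding.weight    # (n_embed, embed_dim), the learnt quantizer weights
    z_q = codebook[tokens]                        # (B, L, embed_dim)
    b, l, c = z_q.shape
    z_q = z_q.reshape(b, *token_shape, c).permute(0, 3, 1, 2).contiguous()
    return model.decode(z_q)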

mounchiliu commented 4 months ago

Any ideas on how to solve this problem?