justinlovelace / latent-diffusion-for-language

MIT License

How to train on multi-gpu #3

Open prettyprettyboy opened 1 year ago

prettyprettyboy commented 1 year ago

Hi! I noticed that you use the "accelerate" repo but only train the model on a single GPU. I changed it to multi-GPU and replaced the corresponding code with "model.module". However, I hit an error when computing perplexity:

Traceback (most recent call last):
  File "train_text_diffusion.py", line 203, in <module>
    main(args)
  File "train_text_diffusion.py", line 91, in main
    trainer.train()
  File "/data/latent-diffusion-for-language/diffusion/denoising_diffusion.py", line 760, in train
    self.sample(num_samples=num_samples, class_id=class_id)
  File "/data/latent-diffusion-for-language/evaluation/evaluation.py", line 14, in compute_perplexity
    results = perplexity.compute(predictions=all_texts_list, model_id=model_id, device='cuda')
  File "/data//anaconda3/envs/latent-diffusion/lib/python3.8/site-packages/evaluate/module.py", line 444, in compute
    output = self._compute(**inputs, **compute_kwargs)
  File "/home/.cache/huggingface/modules/evaluate_modules/metrics/evaluate-metric--perplexity/8ab643ad86f568b7d1d5f7822373fa7401ff5ff0297ccf114b0ca6a33be96bc0/perplexity.py", line 117, in _compute
    tokenizer = AutoTokenizer.from_pretrained(model_id)
  File "/data/anaconda3/envs/latent-diffusion/lib/python3.8/site-packages/transformers/models/auto/tokenization_auto.py", line 549, in from_pretrained
    tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
  File "/data/anaconda3/envs/latent-diffusion/lib/python3.8/site-packages/transformers/models/auto/tokenization_auto.py", line 418, in get_tokenizer_config
    commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
  File "/data/anaconda3/envs/latent-diffusion/lib/python3.8/site-packages/transformers/utils/hub.py", line 225, in extract_commit_hash
    search = re.search(r"snapshots/([^/]+)/", resolved_file)
  File "/data/anaconda3/envs/latent-diffusion/lib/python3.8/re.py", line 201, in search
    return _compile(pattern, flags).search(string)
TypeError: expected string or bytes-like object

How can I fix this, or could you release a multi-GPU version? Thanks!

justinlovelace commented 1 year ago

Hello, thank you for your interest! As you noticed, we trained all of our models with a single GPU and our codebase doesn't currently support distributed training. I can look into adding support for multi-GPU training, but I don't have the bandwidth to work on it immediately.
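
For reference, the usual Accelerate multi-GPU pattern looks roughly like the sketch below. This is not this repo's code; the toy model, data, and variable names are placeholders purely for illustration. Note that `accelerator.unwrap_model(model)` is generally preferred over reaching into `model.module` when you need the raw module.

```python
# Illustrative sketch of the standard Hugging Face Accelerate multi-GPU loop,
# not code from this repo. The toy model and data below are placeholders.
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

model = torch.nn.Linear(16, 1)  # stand-in for the actual diffusion model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataset = TensorDataset(torch.randn(64, 16), torch.randn(64, 1))
loader = DataLoader(dataset, batch_size=8)

# prepare() wraps the model in DDP and shards the dataloader across processes
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for x, y in loader:
    loss = torch.nn.functional.mse_loss(model(x), y)
    accelerator.backward(loss)  # replaces loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Prefer unwrap_model over model.module when the raw module is needed,
# e.g. for sampling or saving checkpoints on the main process.
if accelerator.is_main_process:
    unwrapped = accelerator.unwrap_model(model)
```

Such a script would be launched with `accelerate launch train_script.py` so that one process is spawned per GPU.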

It's hard for me to comment on the error if you've made local changes within the codebase. If there is a problem with computing perplexity, then it may be useful to manually inspect the text generations before the perplexity.compute() function call. It seems like the input may be malformed in some way?
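
For instance, a quick sanity check along these lines (illustrative only; `all_texts_list` and `model_id` mirror the names in your traceback, and the `"gpt2"` default is just an example) would confirm whether the generations and the model ID are well-formed strings before the metric is called:

```python
# Illustrative sanity check before computing perplexity; `all_texts_list` and
# `model_id` mirror the names in the traceback above. The "gpt2" default is an example.
from evaluate import load

def check_and_compute_perplexity(all_texts_list, model_id="gpt2"):
    # Confirm the model ID and every generation are non-empty strings
    # before handing them to the metric.
    assert isinstance(model_id, str), f"model_id is {type(model_id)}, expected str"
    for i, text in enumerate(all_texts_list):
        assert isinstance(text, str), f"generation {i} is {type(text)}, expected str"
        assert len(text.strip()) > 0, f"generation {i} is empty"

    perplexity = load("perplexity", module_type="metric")
    return perplexity.compute(predictions=all_texts_list, model_id=model_id, device="cuda")
```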

prettyprettyboy commented 1 year ago

Thanks for replying. At first I thought LD4LG was a post-processing or "plug-and-play" method like Diffusion-LM. Post-processing methods usually use a classifier to control the model, whereas you use a class embedding instead. So is LD4LG more like a prefix-tuning method that learns a small vector? Another question: the number of sampling steps in DDIM is usually smaller than T, but the pseudocode still shows T sampling steps.

justinlovelace commented 1 year ago

For our class-conditional language generation models, we train them similarly to class-conditional vision models (e.g. [1] [2]) and explicitly condition the network on a learnable embedding that specifies the class. Our approach should also be compatible with classifier guidance, although we did not explore that in this work.
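
To make "explicitly condition the network on a learnable embedding" concrete, a generic sketch is given below. This is not the actual LD4LG architecture; the dimensions, layer choices, and names are placeholders. The key point is that each class gets a learned vector that is injected into the denoiser alongside the timestep embedding and trained jointly with the rest of the network.

```python
# Generic sketch of class-embedding conditioning for a diffusion denoiser;
# not the actual LD4LG architecture. Dimensions and names are placeholders.
import torch
import torch.nn as nn

class ClassConditionalDenoiser(nn.Module):
    def __init__(self, num_classes=4, dim=64):
        super().__init__()
        self.class_emb = nn.Embedding(num_classes, dim)  # one learnable vector per class
        self.time_mlp = nn.Sequential(nn.Linear(1, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, z_t, t, class_id):
        # Condition by adding the class embedding to the timestep embedding;
        # the class vectors are trained jointly with the denoiser.
        cond = self.time_mlp(t.float().unsqueeze(-1)) + self.class_emb(class_id)
        return self.net(z_t + cond)
```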

As you mentioned, DDIM generally improves generation quality when the timesteps are downsampled. For simplicity, we omitted the optional downsampling from the pseudocode.
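
For concreteness, timestep downsampling in a DDIM-style sampler typically amounts to iterating over a strided subsequence of the T training timesteps, roughly as in the sketch below (a generic illustration, not the repo's sampler; T and the number of steps are arbitrary values here).

```python
# Generic illustration of DDIM timestep downsampling, not the repo's sampler.
# With T training timesteps, sampling can visit a strided subsequence instead of all T.
T = 1000                 # number of diffusion timesteps used during training
num_sampling_steps = 50  # DDIM can use far fewer steps at inference

# Evenly spaced subsequence of 0..T-1, reversed: [980, 960, ..., 20, 0]
timesteps = list(range(0, T, T // num_sampling_steps))[::-1]

for i, t in enumerate(timesteps):
    t_prev = timesteps[i + 1] if i + 1 < len(timesteps) else -1
    # a deterministic DDIM update from t to t_prev would go here
```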