lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License

experiment tracker agnostic #73

Closed · lucidrains closed this issue 2 years ago

lucidrains commented 2 years ago

ran into a bunch of issues with wandb and distributed training for https://github.com/lucidrains/dalle-pytorch, so we should refactor any training script to be experiment tracker agnostic this time around

lucidrains commented 2 years ago

a good template would be how the logs are accumulated here: https://github.com/lucidrains/tf-bind-transformer/blob/main/tf_bind_transformer/training_utils.py. additional helper functions can be brought in for "maybe transforms" on certain keys in the log
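
For reference, the accumulation pattern referred to above boils down to something like the following sketch. The helper names `accum_log` and `maybe_transform` are illustrative, not necessarily the exact API in that repo:

# hypothetical helpers, sketched after the log-accumulation pattern in
# tf-bind-transformer's training_utils.py; names are illustrative only

def accum_log(log, new_logs):
    # sum newly reported values into the running log dict
    for key, new_value in new_logs.items():
        log[key] = log.get(key, 0.) + new_value
    return log

def maybe_transform(log, key, fn):
    # apply fn to log[key] only if that key is present ("maybe transform")
    if key in log:
        log[key] = fn(log[key])
    return log

# usage: accumulate sub-batch losses, then post-process once per step
log = {}
for chunk_loss in (0.25, 0.31):
    log = accum_log(log, {'loss': chunk_loss})
log = maybe_transform(log, 'loss', lambda total: total / 2)  # e.g. average over 2 chunks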

rom1504 commented 2 years ago

What's the overall idea behind being experiment tracker agnostic? Do you want to support other trackers, or do you mostly want to be able to disable it?

Regarding distributed training, I figure there are two things to support:

  1. Logging only on node=0
  2. Not even logging on node 0, but letting nodes report through some custom channel (e.g. the disk), so that some other node (e.g. a login node) can retrieve that information and log it to the tracker (this is needed, for example, on JUWELS, where compute nodes don't have access to the internet)

How would you want to implement this? What's the main goal?
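
For point 1, a minimal sketch of rank-zero-only logging, assuming a torchrun-style launcher that sets the RANK environment variable (the function names here are hypothetical):

import os
import wandb

def is_main_process():
    # RANK is set by torchrun and most distributed launchers; default to 0 for single-process runs
    return int(os.environ.get('RANK', '0')) == 0

def rank_zero_log(log):
    # only the main process talks to the experiment tracker
    if is_main_process():
        wandb.log(log)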

lucidrains commented 2 years ago

@rom1504 both support other trackers and be able to disable it. i've done this successfully for some other projects by now; here is an example of what i have for https://github.com/lucidrains/video-diffusion-pytorch:

import wandb

# assumes `diffusion` is the GaussianDiffusion instance built beforehand,
# with the Trainer class coming from video-diffusion-pytorch
from video_diffusion_pytorch import Trainer

wandb.init(project = 'video-diffusion')
wandb.run.name = 'resnet'
wandb.run.save()

trainer = Trainer(
    diffusion,
    '/home/phil/dl/nuwa-pytorch/gif-moving-mnist/',
    results_folder = './results-new-focus-present',
    train_batch_size = 4,
    train_lr = 2e-5,
    save_and_sample_every = 1000,
    max_grad_norm = 0.5,
    train_num_steps = 700000,         # total training steps
    gradient_accumulate_every = 8,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    amp = True                        # turn on mixed precision
)

trainer.load(-1)                      # resume from the latest saved milestone

def log_fn(log):
    # wrap raw sample tensors so wandb renders them as video, then forward the log
    if 'sample' in log:
        log['sample'] = wandb.Video(log['sample'])
    wandb.log(log)

trainer.train(log_fn = log_fn, prob_focus_present = 0.)

lucidrains commented 2 years ago

the log_fn can be made more composable for sure, as you may want to exclude certain keys from being logged, wrap other ones, derive other keys from the available ones in the set, etc.
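
A sketch of what such a composable log_fn could look like, assuming wandb has already been initialized (all helper names here are hypothetical, not part of the repo):

# hypothetical composable log_fn helpers; names are illustrative only
import wandb

def exclude_keys(*keys):
    # drop the given keys from the log before it reaches the tracker
    def fn(log):
        return {k: v for k, v in log.items() if k not in keys}
    return fn

def wrap_key(key, wrapper):
    # wrap a value for the tracker, e.g. turn a tensor into wandb.Video
    def fn(log):
        if key in log:
            log = {**log, key: wrapper(log[key])}
        return log
    return fn

def compose(*fns):
    # apply the transforms left to right, then hand the result to wandb
    def log_fn(log):
        for fn in fns:
            log = fn(log)
        wandb.log(log)
    return log_fn

log_fn = compose(
    exclude_keys('debug_tensor'),
    wrap_key('sample', wandb.Video)
)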

lucidrains commented 2 years ago

started https://github.com/lucidrains/DALLE2-pytorch/commit/89de5af63ec76b9aad402d3925529b06612fa1f3