a good template would be how the logs are accumulated here: https://github.com/lucidrains/tf-bind-transformer/blob/main/tf_bind_transformer/training_utils.py. additional helper functions can be brought in for "maybe transforms" on certain keys in the log (sketched below)
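for reference, a rough sketch of the pattern: `accum_log` follows the accumulation helper of the same name in tf-bind-transformer, while `maybe_transform` is just a hypothetical name for the kind of "maybe transform" helper meant here

```python
def accum_log(log, new_logs):
    # sum each incoming value into the running log,
    # e.g. across gradient accumulation steps
    for key, new_value in new_logs.items():
        old_value = log.get(key, 0.)
        log[key] = old_value + new_value
    return log

def maybe_transform(log, key, fn):
    # hypothetical helper: apply fn to log[key] only if the key is present
    if key in log:
        log[key] = fn(log[key])
    return log
```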
What's the overall idea behind being experiment tracker agnostic? Do you want to support other trackers, or do you mostly want to be able to disable it?
Regarding distributed training, I figure there are two things to support:
How would you want to implement this? What's the main goal?
@rom1504 both: support other trackers, and be able to disable tracking altogether. i've done this successfully for some other projects by now. here is an example of what i have for https://github.com/lucidrains/video-diffusion-pytorch
```python
import wandb
from video_diffusion_pytorch import Unet3D, GaussianDiffusion, Trainer

# model / diffusion construction was elided in the original comment;
# these values are illustrative, following the repo's readme
model = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8))

diffusion = GaussianDiffusion(
    model,
    image_size = 64,
    num_frames = 10,
    timesteps = 1000
)

wandb.init(project = 'video-diffusion')
wandb.run.name = 'resnet'
wandb.run.save()

trainer = Trainer(
    diffusion,
    '/home/phil/dl/nuwa-pytorch/gif-moving-mnist/',
    results_folder = './results-new-focus-present',
    train_batch_size = 4,
    train_lr = 2e-5,
    save_and_sample_every = 1000,
    max_grad_norm = 0.5,
    train_num_steps = 700000,          # total training steps
    gradient_accumulate_every = 8,     # gradient accumulation steps
    ema_decay = 0.995,                 # exponential moving average decay
    amp = True                         # turn on mixed precision
)

trainer.load(-1)  # resume from the latest checkpoint

def log_fn(log):
    # video tensors need to be wrapped before wandb can render them
    if 'sample' in log:
        log['sample'] = wandb.Video(log['sample'])
    wandb.log(log)

trainer.train(log_fn = log_fn, prob_focus_present = 0.)
```
the log_fn can be made more composable for sure, as you may want to exclude certain keys from being logged, wrap other ones, derive new keys from the available ones, etc. something like the sketch below
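a rough sketch of what that composition could look like. all of the helper names here are hypothetical, not existing APIs

```python
import wandb

def compose(*fns):
    # run the log dict through each transform, left to right
    def inner(log):
        for fn in fns:
            log = fn(log)
        return log
    return inner

def exclude_keys(*keys):
    # drop keys that should never reach the tracker
    def inner(log):
        return {k: v for k, v in log.items() if k not in keys}
    return inner

def wrap_key(key, wrapper):
    # wrap a value (e.g. frames -> wandb.Video) only if the key is present
    def inner(log):
        if key in log:
            log = {**log, key: wrapper(log[key])}
        return log
    return inner

# hypothetical composition mirroring the log_fn above;
# 'raw_batch' is an illustrative key name
prepare = compose(
    exclude_keys('raw_batch'),
    wrap_key('sample', wandb.Video)
)

def log_fn(log):
    wandb.log(prepare(log))
```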
ran into a bunch of issues with wandb and distributed training for https://github.com/lucidrains/dalle-pytorch, so we should refactor any training script to be experiment tracker agnostic this time around
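a minimal sketch of the tracker agnostic idea, assuming the trainer only ever talks to an injected log_fn (all names here are illustrative):

```python
def noop(*args, **kwargs):
    pass

class Trainer:
    # the trainer never imports wandb; it just calls whatever log_fn it
    # was handed, and the no-op default disables tracking entirely
    def __init__(self, log_fn = noop):
        self.log_fn = log_fn

    def train_step(self, step):
        loss = 0.  # stand-in for the real forward / backward pass
        self.log_fn(dict(step = step, loss = loss))

trainer = Trainer()                      # tracking disabled
# trainer = Trainer(log_fn = wandb.log)  # or wired to wandb from the outside
```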