berenslab / t-simcne

Unsupervised visualization of image datasets using contrastive learning
https://t-simcne.readthedocs.io/en/latest/

Unable to use multiple devices #9

Open erikhagendorn opened 9 months ago

erikhagendorn commented 9 months ago

It appears the number of devices is hardcoded:

https://github.com/berenslab/t-simcne/blob/e2988f5cb2c86c321e5bad3bc1c968200097def0/tsimcne/tsimcne.py#L481

https://github.com/berenslab/t-simcne/blob/e2988f5cb2c86c321e5bad3bc1c968200097def0/tsimcne/tsimcne.py#L517

Given that this is a contrastive learning task, it would be preferable to use more resources for training and to allow a devices argument to be passed to your model classes.

I'm unsure to what extent other aspects of the model would require changes, but I believe the learning rate calculation in lr_from_batchsize would need to be updated as well, so that it uses the effective batch size (batch_size * devices).
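A minimal sketch of what scaling the learning rate with the effective batch size could look like, assuming a linear scaling rule (lr proportional to the global batch size) and data-parallel training where every device processes its own batch. The function names `effective_batch_size` and `effective_lr`, the base learning rate of 0.03 and the reference batch size of 256 are illustrative assumptions, not the actual formula or constants used by lr_from_batchsize in t-SimCNE.

```python
def effective_batch_size(batch_size: int, devices: int, num_nodes: int = 1) -> int:
    """Global batch size when each device processes `batch_size` samples (DDP-style)."""
    return batch_size * devices * num_nodes


def effective_lr(
    batch_size: int,
    devices: int,
    num_nodes: int = 1,
    base_lr: float = 0.03,  # assumed constant, for illustration only
    reference_batch_size: int = 256,  # assumed constant, for illustration only
) -> float:
    """Hypothetical linear scaling rule: lr grows in proportion to the global batch size."""
    global_bs = effective_batch_size(batch_size, devices, num_nodes)
    return base_lr * global_bs / reference_batch_size
```

For example, 128 samples per GPU on 4 GPUs gives a global batch of 512, so effective_lr(128, 4) returns 0.06 under these assumed constants instead of the 0.015 a single device with batch size 128 would get.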

erikhagendorn commented 9 months ago

After further inspection, I think this issue can be expanded to simply request that you allow any arguments to be passed to the Trainer. For example, I'd like to include model checkpointing and logging.

jnboehm commented 8 months ago

Thanks for bringing this up!

Yeah, one issue I see is that the parallelization strategy then also becomes a factor to consider. So I would agree that the best course of action would be to have the option to pass in kwargs for the trainer. I'm not entirely sure whether this could cause any conflicts between the inference and training kwargs, though.

I actually wasn't aware that the number of GPUs also plays a role in the lr calculation. Out of interest, do you have a reference for that? This also requires the number of devices to be set explicitly instead of using the default value of "auto" for the Trainer. Maybe it's best to have a devices and a trainer_kwargs option. Would that work for you?
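One way the proposed devices and trainer_kwargs options could be combined, sketched in plain Python so it does not depend on Lightning: library defaults are merged with user-supplied keyword arguments, with the user's values taking precedence. The helper name `build_trainer_kwargs` and the chosen defaults are hypothetical; `max_epochs`, `devices`, `accelerator`, `callbacks` and `logger` are standard Lightning `Trainer` arguments.

```python
from typing import Any, Dict, Optional


def build_trainer_kwargs(
    max_epochs: int,
    devices: int = 1,
    trainer_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: merge library defaults with user-supplied Trainer kwargs.

    `devices` is an explicit integer (not "auto") so that the learning rate
    can be scaled with the number of devices. User-supplied entries override
    the defaults, so callbacks, a logger, a strategy, etc. can be added
    without the library having to know about them.
    """
    defaults: Dict[str, Any] = {
        "max_epochs": max_epochs,
        "devices": devices,
        "accelerator": "auto",
    }
    # Later keys win: user settings replace the defaults.
    return {**defaults, **(trainer_kwargs or {})}
```

With Lightning installed, the result would be passed on as `Trainer(**build_trainer_kwargs(...))`, for instance with `trainer_kwargs={"callbacks": [ModelCheckpoint()], "logger": CSVLogger("logs")}`, where `ModelCheckpoint` comes from `lightning.pytorch.callbacks` and `CSVLogger` from `lightning.pytorch.loggers`.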

(By the way, the transform() function needs to run on only one device; otherwise it gives an error because it cannot return the calculated output. So I'll leave it there as is.)
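A rough illustration of why multi-device inference does not simply return the full output, assuming the default behaviour of PyTorch's `DistributedSampler` with `shuffle=False` and `drop_last=False`: the index list is padded to a multiple of the number of processes and then split in a strided fashion, so each process only sees an interleaved subset, possibly with duplicated samples. The function `distributed_shard` is a hypothetical plain-Python re-implementation of that split for illustration, not t-SimCNE or PyTorch code.

```python
import math
from typing import List


def distributed_shard(n_samples: int, num_replicas: int, rank: int) -> List[int]:
    """Indices processed by `rank`, mimicking DistributedSampler(shuffle=False, drop_last=False)."""
    num_samples = math.ceil(n_samples / num_replicas)  # samples per rank
    total_size = num_samples * num_replicas
    indices = list(range(n_samples))
    # Pad by repeating indices from the start so every rank gets the same count.
    padding = total_size - n_samples
    indices += (indices * math.ceil(padding / n_samples))[:padding]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    return indices[rank:total_size:num_replicas]
```

With 10 samples on 4 devices, rank 0 processes [0, 4, 8] and rank 2 processes [2, 6, 0], so returning the embeddings in their original order would require gathering across ranks and undoing both the interleaving and the padding. Keeping transform() on a single device avoids that bookkeeping.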

erikhagendorn commented 8 months ago

Hi, thanks for the quick response!

Here is the reference for the learning rate calculation based on device count.

> (By the way, in the transform() function it needs to be only one device, otherwise it gives an error because it cannot return the calculated output. So I'll leave it there as is.)

Makes sense, thank you.

I think trainer_kwargs would be a good addition; I'd very much appreciate being able to pass some callbacks and a logger, as I mentioned previously.

Regarding the devices argument, I think that makes sense as well. However, after submitting this, I started wondering whether there are any other nuances of multi-GPU training that would need to be addressed here. It might be best for me to test it out and run the CIFAR10 example. I can probably tell whether it works just by looking at the plot, but if there are any empirical metrics you'd like, let me know.

I'll try to test this out soon, but don't feel like you have to leave this open. If I find success with the multi-GPU training, I can submit a PR with the changes, if that works for you.

jnboehm commented 8 months ago

Cool, yeah, if you want to draft a pull request, I'd be happy to merge it. To be fair, I am not sure whether there are any subtle issues with multi-GPU training. The most important question is whether the batch size will change, because it has quite an important influence. But yeah, it should become apparent if the visualization of CIFAR-10 looks off after the default training. Other than that, I am not aware of any other issues, but I also do not have much experience with multiple GPUs.

jnboehm commented 8 months ago

I implemented the functionality for passing in custom kwargs for the trainer. So far I haven't changed the lr calculation, as I am not really using more than 1 GPU at the moment. Currently, it issues a warning when multiple devices are used but the lr is left at its default. I would be happy to merge in some code that accounts for how the lr changes when using multiple devices, though.
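A minimal sketch of the kind of check described here, assuming a hypothetical `check_lr_for_devices` helper and a hypothetical `"auto"` sentinel for "use the default lr"; the actual parameter names and warning text in t-SimCNE may differ.

```python
import warnings
from typing import Sequence, Union


def check_lr_for_devices(lr: Union[str, float], devices: Union[str, int, Sequence[int]]) -> None:
    """Warn when several devices are used but the default (single-device) lr is kept.

    Hypothetical: `lr == "auto"` stands for "derive lr from the per-device batch size".
    """
    if isinstance(devices, (list, tuple)):
        n_devices = len(devices)  # e.g. devices=[0, 1] selects two GPUs
    elif isinstance(devices, int):
        n_devices = devices
    else:
        return  # e.g. devices="auto": the device count is not known at this point
    if n_devices > 1 and lr == "auto":
        warnings.warn(
            f"Training on {n_devices} devices, but the learning rate is the "
            "single-device default; consider setting lr explicitly.",
            stacklevel=2,
        )
```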

erikhagendorn commented 8 months ago

Great, thanks!

I'm finding that the implementation is more complex than just configuring the devices and adjusting the learning rate, as I initially proposed. Here are a few of my findings so far:

Firstly, it's crucial to gather the embeddings across all devices before calculating the loss, so that each sample is contrasted against the negatives from the full global batch rather than only those in its own device's mini-batch. Ensuring that gradients are synchronized across the devices is also essential here.
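To make the effect of gathering concrete, here is a plain-Python sketch of a SimCLR-style NT-Xent loss evaluated once per device and once on the gathered batch; the "devices" are just Python lists, and the temperature of 0.5 and all function names are assumptions for illustration. In an actual multi-GPU run the embeddings would be collected with a gradient-preserving all-gather, for example Lightning's `LightningModule.all_gather(..., sync_grads=True)` or `torch.distributed.nn.functional.all_gather`; the sketch uses neither and does not claim to match t-SimCNE's loss implementation.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def cosine(a: Vector, b: Vector) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def nt_xent(z1: List[Vector], z2: List[Vector], temperature: float = 0.5) -> float:
    """NT-Xent loss; z1[i] and z2[i] are the two augmented views of sample i.

    Every embedding in the batch other than the anchor and its positive is a negative.
    """
    z = list(z1) + list(z2)
    n = len(z1)
    total = 0.0
    for i in range(2 * n):
        pos = (i + n) % (2 * n)  # index of the other view of the same sample
        sims = {k: math.exp(cosine(z[i], z[k]) / temperature) for k in range(2 * n) if k != i}
        total += -math.log(sims[pos] / sum(sims.values()))
    return total / (2 * n)


def negatives_per_anchor(per_device_batch: int, world_size: int, gathered: bool) -> int:
    """Negatives per anchor: 2N - 2, where N is the local or the global batch size."""
    n = per_device_batch * world_size if gathered else per_device_batch
    return 2 * n - 2


# Two simulated devices with two samples each; both views are identical one-hot vectors.
dev0 = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
dev1 = [[0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
local = (nt_xent(dev0, dev0) + nt_xent(dev1, dev1)) / 2  # each device sees only its own batch
gathered = nt_xent(dev0 + dev1, dev0 + dev1)  # every anchor sees the whole global batch
```

With a batch size of 128 on each of 4 GPUs, negatives_per_anchor gives 254 negatives per anchor without gathering and 1022 with it. In the toy example the gathered loss is larger because each anchor is compared against six negatives instead of two.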

Secondly, since the model uses batch normalization, the batch-norm layers need to be synchronized as well. It's important that the normalization statistics are computed over the mini-batches of all devices, not just within each device's own mini-batch.
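The sketch below illustrates the difference for a single feature: per-device statistics are reduced into global ones from counts, sums and sums of squares, which is roughly the kind of reduction synchronized batch norm performs across processes. Here the "devices" are plain Python lists and the helper names are hypothetical. In PyTorch this is usually enabled with `torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)` or, in Lightning, with `Trainer(sync_batchnorm=True)`; which of these was used for this issue is not stated.

```python
from typing import List, Sequence, Tuple


def local_stats(x: Sequence[float]) -> Tuple[int, float, float]:
    """Per-device count, sum and sum of squares of one feature's activations."""
    return len(x), sum(x), sum(v * v for v in x)


def synced_mean_var(per_device: List[Tuple[int, float, float]]) -> Tuple[float, float]:
    """Combine per-device statistics (as an all-reduce would) into the global mean and biased variance."""
    count = sum(c for c, _, _ in per_device)
    total = sum(s for _, s, _ in per_device)
    total_sq = sum(q for _, _, q in per_device)
    mean = total / count
    # E[x^2] - E[x]^2 is used for brevity; it can lose precision for large values.
    var = total_sq / count - mean * mean
    return mean, var


def mean_var(x: Sequence[float]) -> Tuple[float, float]:
    """Unsynchronized statistics computed on one device's mini-batch only."""
    m = sum(x) / len(x)
    return m, sum((v - m) ** 2 for v in x) / len(x)
```

For activations [1.0, 2.0] on one device and [5.0, 6.0] on another, each device on its own would normalize with a variance of 0.25 (means 1.5 and 5.5), whereas the synchronized statistics are a mean of 3.5 and a variance of 4.25, so the two setups scale the activations very differently.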

Thirdly, when the batch size is scaled up significantly (reportedly beyond about 1024), switching from a standard SGD optimizer to LARS becomes necessary.
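For reference, a minimal single-layer sketch of the LARS update as described by You et al. (2017): each layer's step is rescaled by a trust ratio eta * ||w|| / (||g|| + weight_decay * ||w||), so layers whose gradients are large relative to their weights take proportionally smaller steps. Plain Python lists stand in for tensors, the default hyperparameters are assumptions, and the function names are hypothetical. Practical implementations (SimCLR's, for instance) typically also exclude biases and batch-norm parameters from the adaptation, which this sketch does not handle.

```python
import math
from typing import List, Tuple


def lars_trust_ratio(w: List[float], g: List[float], eta: float = 0.001,
                     weight_decay: float = 0.0) -> float:
    """Layer-wise trust ratio: eta * ||w|| / (||g|| + weight_decay * ||w||)."""
    w_norm = math.sqrt(sum(x * x for x in w))
    g_norm = math.sqrt(sum(x * x for x in g))
    if w_norm == 0.0 or g_norm == 0.0:
        return 1.0  # fall back to a plain SGD step, a common convention
    return eta * w_norm / (g_norm + weight_decay * w_norm)


def lars_step(w: List[float], g: List[float], velocity: List[float], lr: float,
              momentum: float = 0.9, eta: float = 0.001,
              weight_decay: float = 0.0) -> Tuple[List[float], List[float]]:
    """One SGD-with-momentum step whose size is scaled by the layer's trust ratio."""
    trust = lars_trust_ratio(w, g, eta, weight_decay)
    new_velocity = [momentum * v + lr * trust * (gi + weight_decay * wi)
                    for v, gi, wi in zip(velocity, g, w)]
    new_w = [wi - vi for wi, vi in zip(w, new_velocity)]
    return new_w, new_velocity
```

For a layer with weights [3.0, 4.0] (norm 5) and gradient [0.6, 0.8] (norm 1), the trust ratio is 0.001 * 5 / 1 = 0.005, so with lr = 1 and no momentum the weights move only to about [2.997, 3.996].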

I have implemented the first two points, but my testing was slowed down by the holiday break.

Currently, I'm running a training session with a batch size of 128 on each of 4 GPUs (an effective batch size of 512) to replicate the results from the paper. The initial signs are promising: I did observe some clustering after a reduced number of epochs just the other day, so I'm cautiously optimistic.

Assuming this training run is successful, the next step will be to implement and test with increased batch sizes and LARS. So, there's still a fair bit of work ahead.