MICS-Lab / scyan

Biology-driven deep generative model for cell-type annotation in cytometry. Scyan is an interpretable model that also corrects batch-effect and can be used for debarcoding or population discovery.
https://mics-lab.github.io/scyan/
BSD 3-Clause "New" or "Revised" License

GPU training is slow #31

Closed grst closed 6 months ago

grst commented 6 months ago

Description

I'm honestly quite impressed by the speed achieved by Scyan even when working with millions of cells -- it's the first time I can work with flow data without excessive subsampling. Still, I wanted to see whether it could be sped up even further by using a GPU for training.

I tried setting model.fit(trainer=pl.Trainer(accelerator='gpu', devices=1)) and running model.fit() on an Nvidia A6000 48GB. The result was disappointing: training ended up being slower than on CPU. While the actual training of a single minibatch seems faster, a lot of time is spent between batches. I suspect that much of that time goes to copying data between RAM and the GPU, so I was wondering whether anything can be optimized in that regard.

CPU Training ![2024-03-18_13-21-33](https://github.com/MICS-Lab/scyan/assets/7051479/f0073385-6db0-458f-8e9e-a9b213d98a4f)
GPU Training ![2024-03-18_13-24-47](https://github.com/MICS-Lab/scyan/assets/7051479/b3bed5f0-9125-43ff-a17a-f3a7eaa1efce)

System

quentinblampey commented 6 months ago

Hello @grst, I'm surprised to see how slow it is on GPU. I don't often use GPUs to train Scyan, but it should be much faster.

Can you maybe try profiling the training? The following lines should print two tables with some basic profiling:

model.fit(accelerator="cpu", profiler="simple")
model.fit(accelerator="gpu", profiler="simple")

(Btw, the kwargs of the model.fit method are given to the Trainer, so you can directly pass accelerator="gpu" instead of re-creating a Trainer)
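The forwarding mentioned in the note above can be pictured with a small stand-in. This is only a sketch of the general pattern, not Scyan's actual implementation: `StandInTrainer` and the free function `fit` are hypothetical names used for illustration, and the real `Trainer` comes from PyTorch Lightning.

```python
# Hypothetical sketch of keyword-argument forwarding; not Scyan's real code.
class StandInTrainer:
    """Stand-in for pl.Trainer that only records the options it receives."""

    def __init__(self, accelerator="cpu", profiler=None, **other):
        self.accelerator = accelerator
        self.profiler = profiler
        self.other = other


def fit(trainer=None, **trainer_kwargs):
    """Use the given trainer, or build one from the extra keyword arguments."""
    if trainer is None:
        trainer = StandInTrainer(**trainer_kwargs)
    return trainer


# Both calls end up with an equivalent trainer configuration.
short = fit(accelerator="gpu", profiler="simple")
explicit = fit(trainer=StandInTrainer(accelerator="gpu", profiler="simple"))
print(short.accelerator, short.profiler)
```

Under this pattern, passing the options directly avoids constructing a trainer by hand for the common case.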

grst commented 6 months ago

ok, new day, new node on the HPC and the problem seems gone. While GPU is still not (significantly) faster, it's also not slower than CPU. The training is super fast anyway (1:30min for 15M cells), feel free to close.

I am having more GPU issues with batch effect correction, but I'll open a separate issue for that.

Here are the profiling results anyway:

CPU

[INFO] (scyan.model) Training scyan with the following hyperparameters:
"batch_key":       sample_id
"batch_size":      8192
"hidden_size":     16
"lr":              0.0001
"max_samples":     200000
"modulo_temp":     3
"n_hidden_layers": 6
"n_layers":        7
"prior_std":       0.25
"temperature":     0.5
"warm_up":         (0.35, 4)

/cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/lightning_fabric/plugins/environments/slurm.py:165: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3.9 /cfs/sturmgre/conda/envs/1403-0001_scyan/lib/pyth ...
  rank_zero_warn(
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/pytorch_lightning/trainer/setup.py:176: PossibleUserWarning: GPU available but not used. Set `accelerator` and `devices` using `Trainer(accelerator='gpu', devices=1)`.
  rank_zero_warn(

  | Name   | Type        | Params
---------------------------------------
0 | module | ScyanModule | 139 K 
---------------------------------------
139 K     Trainable params
0         Non-trainable params
139 K     Total params
0.558     Total estimated model params size (MB)

[INFO] (scyan.model) Ended warm up epochs
FIT Profiler Report

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                   |  Mean duration (s)    |  Num calls        |  Total time (s)   |  Percentage %     |
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                    |  -                |  24578            |  98.738           |  100 %            |
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  run_training_epoch                                                                       |  4.2859           |  23               |  98.576           |  99.836           |
|  [TrainingEpochLoop].train_dataloader_next                                                |  0.1071           |  552              |  59.119           |  59.875           |
|  run_training_batch                                                                       |  0.068117         |  552              |  37.601           |  38.081           |
|  [LightningModule]Scyan.optimizer_step                                                    |  0.067583         |  552              |  37.306           |  37.783           |
|  [Strategy]SingleDeviceStrategy.backward                                                  |  0.030875         |  552              |  17.043           |  17.261           |
|  [Strategy]SingleDeviceStrategy.training_step                                             |  0.026719         |  552              |  14.749           |  14.937           |
|  [Callback]TQDMProgressBar.on_train_batch_end                                             |  0.0011405        |  552              |  0.62954          |  0.63759          |
|  [LightningModule]Scyan.optimizer_zero_grad                                               |  0.00052472       |  552              |  0.28965          |  0.29335          |
|  [Strategy]SingleDeviceStrategy.batch_to_device                                           |  6.914e-05        |  552              |  0.038165         |  0.038653         |
|  [LightningModule]Scyan.transfer_batch_to_device                                          |  4.4503e-05       |  552              |  0.024565         |  0.024879         |
|  [LightningModule]Scyan.configure_gradient_clipping                                       |  2.6763e-05       |  552              |  0.014773         |  0.014962         |
|  [Callback]TQDMProgressBar.on_train_start                                                 |  0.014255         |  1                |  0.014255         |  0.014437         |
|  [Callback]TQDMProgressBar.on_train_epoch_start                                           |  0.00055201       |  23               |  0.012696         |  0.012858         |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_train_epoch_end       |  0.00035052       |  23               |  0.008062         |  0.008165         |
|  [Callback]ModelSummary.on_fit_start                                                      |  0.0046358        |  1                |  0.0046358        |  0.0046951        |
|  [Callback]TQDMProgressBar.on_train_epoch_end                                             |  0.00018335       |  23               |  0.0042171        |  0.004271         |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_after_backward        |  5.9648e-06       |  552              |  0.0032926        |  0.0033346        |
|  [Callback]ModelSummary.on_train_batch_end                                                |  3.8899e-06       |  552              |  0.0021472        |  0.0021747        |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_train_batch_end       |  2.7839e-06       |  552              |  0.0015367        |  0.0015563        |
|  [LightningModule]Scyan.on_before_batch_transfer                                          |  2.7143e-06       |  552              |  0.0014983        |  0.0015175        |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_train_batch_start     |  2.2071e-06       |  552              |  0.0012183        |  0.0012339        |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_before_zero_grad      |  2.1704e-06       |  552              |  0.001198         |  0.0012134        |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_before_optimizer_step |  2.1691e-06       |  552              |  0.0011973        |  0.0012126        |
|  [LightningModule]Scyan.training_step_end                                                 |  1.8963e-06       |  552              |  0.0010468        |  0.0010601        |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_before_backward       |  1.8639e-06       |  552              |  0.0010289        |  0.001042         |
|  [LightningModule]Scyan.on_train_epoch_end                                                |  4.4397e-05       |  23               |  0.0010211        |  0.0010342        |
|  [LightningModule]Scyan.on_after_backward                                                 |  1.5122e-06       |  552              |  0.00083472       |  0.00084539       |
|  [Callback]TQDMProgressBar.on_after_backward                                              |  1.4897e-06       |  552              |  0.00082233       |  0.00083284       |
|  [Strategy]SingleDeviceStrategy.training_step_end                                         |  1.301e-06        |  552              |  0.00071817       |  0.00072735       |
|  [LightningModule]Scyan.on_train_batch_end                                                |  1.2629e-06       |  552              |  0.00069712       |  0.00070603       |
|  [LightningModule]Scyan.on_before_zero_grad                                               |  1.2586e-06       |  552              |  0.00069475       |  0.00070363       |
|  [Callback]TQDMProgressBar.on_before_zero_grad                                            |  1.1967e-06       |  552              |  0.00066055       |  0.000669         |
|  [Callback]TQDMProgressBar.on_before_optimizer_step                                       |  1.1492e-06       |  552              |  0.00063436       |  0.00064246       |
|  [Callback]GradientAccumulationScheduler.on_train_batch_end                               |  1.1389e-06       |  552              |  0.00062869       |  0.00063673       |
|  [LightningModule]Scyan.on_before_backward                                                |  1.124e-06        |  552              |  0.00062043       |  0.00062836       |
|  [LightningModule]Scyan.on_after_batch_transfer                                           |  1.0841e-06       |  552              |  0.0005984        |  0.00060605       |
|  [Callback]TQDMProgressBar.on_train_batch_start                                           |  1.0625e-06       |  552              |  0.00058649       |  0.00059398       |
|  [LightningModule]Scyan.configure_optimizers                                              |  0.00058481       |  1                |  0.00058481       |  0.00059228       |
|  [Callback]GradientAccumulationScheduler.on_after_backward                                |  1.0052e-06       |  552              |  0.0005549        |  0.00056199       |
|  [Callback]TQDMProgressBar.on_before_backward                                             |  1.0043e-06       |  552              |  0.00055435       |  0.00056144       |
|  [Callback]ModelSummary.on_before_optimizer_step                                          |  1.0019e-06       |  552              |  0.00055303       |  0.00056009       |
|  [LightningModule]Scyan.on_before_optimizer_step                                          |  9.9129e-07       |  552              |  0.00054719       |  0.00055419       |
|  [LightningModule]Scyan.on_train_batch_start                                              |  9.7585e-07       |  552              |  0.00053867       |  0.00054555       |
|  [Callback]ModelSummary.on_after_backward                                                 |  9.6932e-07       |  552              |  0.00053506       |  0.0005419        |
|  [Callback]GradientAccumulationScheduler.on_before_optimizer_step                         |  9.6308e-07       |  552              |  0.00053162       |  0.00053842       |
|  [Callback]ModelSummary.on_before_zero_grad                                               |  9.2344e-07       |  552              |  0.00050974       |  0.00051625       |
|  [Callback]ModelSummary.on_train_batch_start                                              |  8.7778e-07       |  552              |  0.00048453       |  0.00049073       |
|  [Callback]GradientAccumulationScheduler.on_before_zero_grad                              |  8.7498e-07       |  552              |  0.00048299       |  0.00048916       |
|  [Strategy]SingleDeviceStrategy.on_train_batch_start                                      |  8.6541e-07       |  552              |  0.00047771       |  0.00048381       |
|  [Callback]ModelSummary.on_before_backward                                                |  8.3722e-07       |  552              |  0.00046214       |  0.00046805       |
|  [Callback]GradientAccumulationScheduler.on_train_batch_start                             |  8.2894e-07       |  552              |  0.00045758       |  0.00046343       |
|  [Callback]GradientAccumulationScheduler.on_before_backward                               |  7.9955e-07       |  552              |  0.00044135       |  0.00044699       |
|  [LightningModule]Scyan.train_dataloader                                                  |  0.00032903       |  1                |  0.00032903       |  0.00033323       |
|  [Callback]TQDMProgressBar.on_train_end                                                   |  0.00023542       |  1                |  0.00023542       |  0.00023842       |
|  [Callback]GradientAccumulationScheduler.on_train_epoch_start                             |  3.8853e-06       |  23               |  8.9362e-05       |  9.0504e-05       |
|  [Callback]ModelSummary.on_train_epoch_end                                                |  2.2251e-06       |  23               |  5.1178e-05       |  5.1832e-05       |
|  [Callback]ModelSummary.on_train_epoch_start                                              |  2.1535e-06       |  23               |  4.9531e-05       |  5.0164e-05       |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_train_epoch_start     |  1.9644e-06       |  23               |  4.518e-05        |  4.5758e-05       |
|  [LightningModule]Scyan.on_train_epoch_start                                              |  1.1131e-06       |  23               |  2.56e-05         |  2.5927e-05       |
|  [Callback]GradientAccumulationScheduler.on_train_epoch_end                               |  9.7927e-07       |  23               |  2.2523e-05       |  2.2811e-05       |
|  [Callback]TQDMProgressBar.setup                                                          |  4.4554e-06       |  1                |  4.4554e-06       |  4.5124e-06       |
|  [Callback]ModelSummary.on_train_start                                                    |  4.0233e-06       |  1                |  4.0233e-06       |  4.0747e-06       |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_fit_end               |  3.6135e-06       |  1                |  3.6135e-06       |  3.6597e-06       |
|  [LightningModule]Scyan.configure_callbacks                                               |  3.3602e-06       |  1                |  3.3602e-06       |  3.4032e-06       |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_train_start           |  3.092e-06        |  1                |  3.092e-06        |  3.1315e-06       |
|  [LightningModule]Scyan.on_fit_start                                                      |  2.9802e-06       |  1                |  2.9802e-06       |  3.0183e-06       |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.setup                    |  2.6599e-06       |  1                |  2.6599e-06       |  2.6938e-06       |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_fit_start             |  2.5779e-06       |  1                |  2.5779e-06       |  2.6108e-06       |
|  [Callback]GradientAccumulationScheduler.on_fit_start                                     |  2.4885e-06       |  1                |  2.4885e-06       |  2.5203e-06       |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_train_end             |  2.2724e-06       |  1                |  2.2724e-06       |  2.3015e-06       |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.teardown                 |  1.9372e-06       |  1                |  1.9372e-06       |  1.9619e-06       |
|  [Callback]ModelSummary.on_train_end                                                      |  1.8626e-06       |  1                |  1.8626e-06       |  1.8864e-06       |
|  [LightningModule]Scyan.prepare_data                                                      |  1.6689e-06       |  1                |  1.6689e-06       |  1.6903e-06       |
|  [Callback]ModelSummary.setup                                                             |  1.6019e-06       |  1                |  1.6019e-06       |  1.6223e-06       |
|  [LightningModule]Scyan.on_train_start                                                    |  1.4305e-06       |  1                |  1.4305e-06       |  1.4488e-06       |
|  [LightningModule]Scyan.configure_sharded_model                                           |  1.2442e-06       |  1                |  1.2442e-06       |  1.2601e-06       |
|  [Callback]TQDMProgressBar.on_fit_start                                                   |  1.2219e-06       |  1                |  1.2219e-06       |  1.2375e-06       |
|  [Strategy]SingleDeviceStrategy.on_train_start                                            |  1.1995e-06       |  1                |  1.1995e-06       |  1.2149e-06       |
|  [LightningModule]Scyan.teardown                                                          |  1.1995e-06       |  1                |  1.1995e-06       |  1.2149e-06       |
|  [Callback]GradientAccumulationScheduler.setup                                            |  1.125e-06        |  1                |  1.125e-06        |  1.1394e-06       |
|  [LightningModule]Scyan.setup                                                             |  1.1176e-06       |  1                |  1.1176e-06       |  1.1319e-06       |
|  [Callback]TQDMProgressBar.on_fit_end                                                     |  1.1101e-06       |  1                |  1.1101e-06       |  1.1243e-06       |
|  [Callback]GradientAccumulationScheduler.on_train_start                                   |  1.1027e-06       |  1                |  1.1027e-06       |  1.1168e-06       |
|  [Callback]TQDMProgressBar.teardown                                                       |  1.0878e-06       |  1                |  1.0878e-06       |  1.1017e-06       |
|  [LightningModule]Scyan.on_fit_end                                                        |  1.0654e-06       |  1                |  1.0654e-06       |  1.079e-06        |
|  [LightningModule]Scyan.on_train_end                                                      |  1.0058e-06       |  1                |  1.0058e-06       |  1.0187e-06       |
|  [Callback]GradientAccumulationScheduler.on_train_end                                     |  9.9838e-07       |  1                |  9.9838e-07       |  1.0111e-06       |
|  [Strategy]SingleDeviceStrategy.on_train_end                                              |  9.9093e-07       |  1                |  9.9093e-07       |  1.0036e-06       |
|  [Callback]GradientAccumulationScheduler.teardown                                         |  8.7917e-07       |  1                |  8.7917e-07       |  8.904e-07        |
|  [Callback]ModelSummary.on_fit_end                                                        |  8.6427e-07       |  1                |  8.6427e-07       |  8.7531e-07       |
|  [Callback]GradientAccumulationScheduler.on_fit_end                                       |  8.4937e-07       |  1                |  8.4937e-07       |  8.6022e-07       |
|  [Callback]ModelSummary.teardown                                                          |  8.4192e-07       |  1                |  8.4192e-07       |  8.5268e-07       |
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

[INFO] (scyan.model) Successfully ended training.

Scyan model with N=14730000 cells, P=12 populations and M=14 markers.
   ├── Covariates: sample_id
   ├── No continuum-marker provided
   └── Batch correction mode: True
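As a rough aid for reading the two profiler reports, the short script below recomputes a few ratios from the top rows of the CPU table above and the GPU table below. All figures are copied from those tables; the dictionary keys and helper names are made up for this sketch. Note that CUDA kernels run asynchronously, so the per-step GPU timings may not reflect pure compute time.

```python
# Figures copied from the two "FIT Profiler Report" tables in this thread (seconds).
cpu = {"total": 98.738, "dataloader_next": 59.119, "training_batch": 37.601}
gpu = {"total": 89.533, "dataloader_next": 47.239, "training_batch": 19.198}


def share(report, key):
    """Percentage of the total profiled time spent in one action."""
    return 100 * report[key] / report["total"]


def speedup(key):
    """How many times faster the GPU run was than the CPU run for one action."""
    return cpu[key] / gpu[key]


# 552 run_training_batch calls over 23 run_training_epoch calls.
batches_per_epoch = 552 // 23
# Batch size taken from the logged hyperparameters ("batch_size": 8192).
cells_per_epoch = batches_per_epoch * 8192

for name, report in (("CPU", cpu), ("GPU", gpu)):
    print(f"{name}: {share(report, 'dataloader_next'):.1f}% of time fetching batches")
for key in ("training_batch", "dataloader_next", "total"):
    print(f"{key}: GPU is {speedup(key):.2f}x faster")
print(f"{batches_per_epoch} batches per epoch, {cells_per_epoch} cells per epoch")
```

On these numbers the batch computation is roughly twice as fast on the GPU, while fetching the next batch accounts for more than half of the time in both runs, which would explain why the end-to-end gain is small. The 196608 cells per epoch are consistent with the logged `max_samples` of 200000 if the incomplete last batch is dropped, which is an assumption rather than something the logs confirm.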

GPU

[INFO] (scyan.model) Training scyan with the following hyperparameters:
"batch_key":       sample_id
"batch_size":      8192
"hidden_size":     16
"lr":              0.0001
"max_samples":     200000
"modulo_temp":     3
"n_hidden_layers": 6
"n_layers":        7
"prior_std":       0.25
"temperature":     0.5
"warm_up":         (0.35, 4)

/cfs/sturmgre/conda/envs/1403-0001_scyan/lib/python3.9/site-packages/lightning_fabric/plugins/environments/slurm.py:165: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3.9 /cfs/sturmgre/conda/envs/1403-0001_scyan/lib/pyth ...
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type        | Params
---------------------------------------
0 | module | ScyanModule | 139 K 
---------------------------------------
139 K     Trainable params
0         Non-trainable params
139 K     Total params
0.558     Total estimated model params size (MB)

[INFO] (scyan.model) Ended warm up epochs
FIT Profiler Report

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                   |  Mean duration (s)    |  Num calls        |  Total time (s)   |  Percentage %     |
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                    |  -                |  24578            |  89.533           |  100 %            |
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  run_training_epoch                                                                       |  2.9712           |  23               |  68.337           |  76.326           |
|  [TrainingEpochLoop].train_dataloader_next                                                |  0.085578         |  552              |  47.239           |  52.762           |
|  run_training_batch                                                                       |  0.034779         |  552              |  19.198           |  21.442           |
|  [LightningModule]Scyan.optimizer_step                                                    |  0.034386         |  552              |  18.981           |  21.2             |
|  [Strategy]SingleDeviceStrategy.training_step                                             |  0.012117         |  552              |  6.6885           |  7.4704           |
|  [Strategy]SingleDeviceStrategy.backward                                                  |  0.0095293        |  552              |  5.2602           |  5.8751           |
|  [LightningModule]Scyan.optimizer_zero_grad                                               |  0.00099661       |  552              |  0.55013          |  0.61445          |
|  [Callback]TQDMProgressBar.on_train_batch_end                                             |  0.00094529       |  552              |  0.5218           |  0.5828           |
|  [Strategy]SingleDeviceStrategy.batch_to_device                                           |  0.00040811       |  552              |  0.22528          |  0.25161          |
|  [LightningModule]Scyan.transfer_batch_to_device                                          |  0.00038269       |  552              |  0.21125          |  0.23594          |
|  [Callback]TQDMProgressBar.on_train_epoch_start                                           |  0.00052991       |  23               |  0.012188         |  0.013613         |
|  [LightningModule]Scyan.configure_gradient_clipping                                       |  1.5236e-05       |  552              |  0.0084101        |  0.0093932        |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_train_epoch_end       |  0.00030582       |  23               |  0.0070338        |  0.0078561        |
|  [Callback]TQDMProgressBar.on_train_start                                                 |  0.0068856        |  1                |  0.0068856        |  0.0076905        |
|  [Callback]TQDMProgressBar.on_train_epoch_end                                             |  0.00020069       |  23               |  0.0046159        |  0.0051556        |
|  [Callback]ModelSummary.on_fit_start                                                      |  0.0045977        |  1                |  0.0045977        |  0.0051352        |
|  [Callback]ModelSummary.on_train_batch_end                                                |  3.1021e-06       |  552              |  0.0017124        |  0.0019126        |
|  [LightningModule]Scyan.on_train_epoch_end                                                |  5.5071e-05       |  23               |  0.0012666        |  0.0014147        |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_train_batch_end       |  2.1692e-06       |  552              |  0.0011974        |  0.0013374        |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_train_batch_start     |  2.0465e-06       |  552              |  0.0011296        |  0.0012617        |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_before_zero_grad      |  2.0154e-06       |  552              |  0.0011125        |  0.0012426        |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_after_backward        |  1.9603e-06       |  552              |  0.0010821        |  0.0012086        |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_before_backward       |  1.7567e-06       |  552              |  0.00096972       |  0.0010831        |
|  [LightningModule]Scyan.on_before_batch_transfer                                          |  1.7427e-06       |  552              |  0.00096194       |  0.0010744        |
|  [LightningModule]Scyan.configure_optimizers                                              |  0.00093456       |  1                |  0.00093456       |  0.0010438        |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_before_optimizer_step |  1.4337e-06       |  552              |  0.00079141       |  0.00088393       |
|  [LightningModule]Scyan.training_step_end                                                 |  1.3929e-06       |  552              |  0.00076887       |  0.00085876       |
|  [Callback]TQDMProgressBar.on_train_batch_start                                           |  1.1784e-06       |  552              |  0.00065047       |  0.00072652       |
|  [Callback]GradientAccumulationScheduler.on_train_batch_end                               |  1.1435e-06       |  552              |  0.00063121       |  0.000705         |
|  [Callback]TQDMProgressBar.on_after_backward                                              |  1.0541e-06       |  552              |  0.00058185       |  0.00064987       |
|  [Callback]TQDMProgressBar.on_before_zero_grad                                            |  1.0167e-06       |  552              |  0.00056122       |  0.00062683       |
|  [Callback]TQDMProgressBar.on_before_backward                                             |  1.0155e-06       |  552              |  0.00056058       |  0.00062612       |
|  [LightningModule]Scyan.on_train_batch_end                                                |  1.0081e-06       |  552              |  0.00055645       |  0.0006215        |
|  [LightningModule]Scyan.on_after_batch_transfer                                           |  9.7866e-07       |  552              |  0.00054022       |  0.00060337       |
|  [LightningModule]Scyan.on_before_zero_grad                                               |  9.6682e-07       |  552              |  0.00053369       |  0.00059608       |
|  [Callback]TQDMProgressBar.on_before_optimizer_step                                       |  9.5189e-07       |  552              |  0.00052544       |  0.00058687       |
|  [Strategy]SingleDeviceStrategy.training_step_end                                         |  9.5097e-07       |  552              |  0.00052494       |  0.00058631       |
|  [Callback]ModelSummary.on_train_batch_start                                              |  9.2016e-07       |  552              |  0.00050793       |  0.00056731       |
|  [LightningModule]Scyan.on_train_batch_start                                              |  9.0994e-07       |  552              |  0.00050229       |  0.00056101       |
|  [Callback]ModelSummary.on_before_optimizer_step                                          |  8.7463e-07       |  552              |  0.0004828        |  0.00053924       |
|  [LightningModule]Scyan.on_before_backward                                                |  8.5711e-07       |  552              |  0.00047313       |  0.00052844       |
|  [LightningModule]Scyan.on_after_backward                                                 |  8.4502e-07       |  552              |  0.00046645       |  0.00052098       |
|  [Callback]GradientAccumulationScheduler.on_train_batch_start                             |  8.3891e-07       |  552              |  0.00046308       |  0.00051721       |
|  [Callback]ModelSummary.on_before_zero_grad                                               |  8.1844e-07       |  552              |  0.00045178       |  0.0005046        |
|  [Callback]GradientAccumulationScheduler.on_before_backward                               |  8.1089e-07       |  552              |  0.00044761       |  0.00049994       |
|  [Callback]ModelSummary.on_before_backward                                                |  8.0993e-07       |  552              |  0.00044708       |  0.00049935       |
|  [Callback]ModelSummary.on_after_backward                                                 |  8.045e-07        |  552              |  0.00044408       |  0.000496         |
|  [Callback]GradientAccumulationScheduler.on_before_optimizer_step                         |  7.9901e-07       |  552              |  0.00044105       |  0.00049261       |
|  [Strategy]SingleDeviceStrategy.on_train_batch_start                                      |  7.9103e-07       |  552              |  0.00043665       |  0.0004877        |
|  [Callback]GradientAccumulationScheduler.on_before_zero_grad                              |  7.8168e-07       |  552              |  0.00043149       |  0.00048193       |
|  [Callback]GradientAccumulationScheduler.on_after_backward                                |  7.7882e-07       |  552              |  0.00042991       |  0.00048016       |
|  [LightningModule]Scyan.train_dataloader                                                  |  0.00041352       |  1                |  0.00041352       |  0.00046187       |
|  [LightningModule]Scyan.on_before_optimizer_step                                          |  6.9523e-07       |  552              |  0.00038376       |  0.00042863       |
|  [Callback]TQDMProgressBar.on_train_end                                                   |  0.00021302       |  1                |  0.00021302       |  0.00023792       |
|  [Callback]GradientAccumulationScheduler.on_train_epoch_start                             |  3.6845e-06       |  23               |  8.4743e-05       |  9.465e-05        |
|  [Callback]ModelSummary.on_train_epoch_end                                                |  2.2689e-06       |  23               |  5.2184e-05       |  5.8285e-05       |
|  [Callback]ModelSummary.on_train_epoch_start                                              |  2.0855e-06       |  23               |  4.7967e-05       |  5.3574e-05       |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_train_epoch_start     |  2.0457e-06       |  23               |  4.705e-05        |  5.2551e-05       |
|  [LightningModule]Scyan.on_train_epoch_start                                              |  1.1953e-06       |  23               |  2.7493e-05       |  3.0707e-05       |
|  [Callback]GradientAccumulationScheduler.on_train_epoch_end                               |  9.9158e-07       |  23               |  2.2806e-05       |  2.5472e-05       |
|  [LightningModule]Scyan.configure_sharded_model                                           |  6.482e-06        |  1                |  6.482e-06        |  7.2398e-06       |
|  [Callback]TQDMProgressBar.setup                                                          |  5.5283e-06       |  1                |  5.5283e-06       |  6.1746e-06       |
|  [Callback]ModelSummary.setup                                                             |  4.6641e-06       |  1                |  4.6641e-06       |  5.2093e-06       |
|  [LightningModule]Scyan.configure_callbacks                                               |  4.4927e-06       |  1                |  4.4927e-06       |  5.0179e-06       |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_fit_end               |  4.3139e-06       |  1                |  4.3139e-06       |  4.8182e-06       |
|  [Callback]ModelSummary.on_train_start                                                    |  3.6359e-06       |  1                |  3.6359e-06       |  4.0609e-06       |
|  [LightningModule]Scyan.prepare_data                                                      |  3.5986e-06       |  1                |  3.5986e-06       |  4.0193e-06       |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.setup                    |  3.2559e-06       |  1                |  3.2559e-06       |  3.6365e-06       |
|  [LightningModule]Scyan.on_fit_start                                                      |  3.159e-06        |  1                |  3.159e-06        |  3.5284e-06       |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_fit_start             |  2.9802e-06       |  1                |  2.9802e-06       |  3.3286e-06       |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_train_start           |  2.943e-06        |  1                |  2.943e-06        |  3.287e-06        |
|  [Callback]GradientAccumulationScheduler.on_fit_start                                     |  2.712e-06        |  1                |  2.712e-06        |  3.0291e-06       |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.on_train_end             |  2.1234e-06       |  1                |  2.1234e-06       |  2.3717e-06       |
|  [Callback]EarlyStopping{'monitor': 'loss_epoch', 'mode': 'min'}.teardown                 |  2.0042e-06       |  1                |  2.0042e-06       |  2.2385e-06       |
|  [Callback]ModelSummary.on_train_end                                                      |  1.8328e-06       |  1                |  1.8328e-06       |  2.0471e-06       |
|  [LightningModule]Scyan.setup                                                             |  1.3337e-06       |  1                |  1.3337e-06       |  1.4896e-06       |
|  [Strategy]SingleDeviceStrategy.on_train_end                                              |  1.3262e-06       |  1                |  1.3262e-06       |  1.4812e-06       |
|  [LightningModule]Scyan.teardown                                                          |  1.3262e-06       |  1                |  1.3262e-06       |  1.4812e-06       |
|  [LightningModule]Scyan.on_fit_end                                                        |  1.3113e-06       |  1                |  1.3113e-06       |  1.4646e-06       |
|  [LightningModule]Scyan.on_train_start                                                    |  1.274e-06        |  1                |  1.274e-06        |  1.423e-06        |
|  [LightningModule]Scyan.on_train_end                                                      |  1.1921e-06       |  1                |  1.1921e-06       |  1.3315e-06       |
|  [Callback]TQDMProgressBar.on_fit_start                                                   |  1.1846e-06       |  1                |  1.1846e-06       |  1.3231e-06       |
|  [Callback]TQDMProgressBar.on_fit_end                                                     |  1.1846e-06       |  1                |  1.1846e-06       |  1.3231e-06       |
|  [Callback]GradientAccumulationScheduler.setup                                            |  1.1772e-06       |  1                |  1.1772e-06       |  1.3148e-06       |
|  [Callback]GradientAccumulationScheduler.on_train_start                                   |  1.1101e-06       |  1                |  1.1101e-06       |  1.2399e-06       |
|  [Strategy]SingleDeviceStrategy.on_train_start                                            |  9.5367e-07       |  1                |  9.5367e-07       |  1.0652e-06       |
|  [Callback]TQDMProgressBar.teardown                                                       |  9.5367e-07       |  1                |  9.5367e-07       |  1.0652e-06       |
|  [Callback]GradientAccumulationScheduler.on_train_end                                     |  9.3877e-07       |  1                |  9.3877e-07       |  1.0485e-06       |
|  [Callback]ModelSummary.on_fit_end                                                        |  9.3132e-07       |  1                |  9.3132e-07       |  1.0402e-06       |
|  [Callback]GradientAccumulationScheduler.teardown                                         |  9.1642e-07       |  1                |  9.1642e-07       |  1.0236e-06       |
|  [Callback]GradientAccumulationScheduler.on_fit_end                                       |  8.7172e-07       |  1                |  8.7172e-07       |  9.7363e-07       |
|  [Callback]ModelSummary.teardown                                                          |  8.6427e-07       |  1                |  8.6427e-07       |  9.6531e-07       |
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Scyan model with N=14730000 cells, P=12 populations and M=14 markers.
   ├── Covariates: sample_id
   ├── No continuum-marker provided
   └── Batch correction mode: True
quentinblampey commented 6 months ago

Good to hear that the GPU is faster now! The speedup over CPU is probably minor because the model is small, so there is little computation per batch to accelerate. Increasing the mini-batch size might also help.
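To illustrate why the mini-batch size matters here, below is a rough sketch, not a measurement. The two hook counts are read from the profiler table above (552 batch-level calls, 23 epoch-level calls). Everything else is an assumption made for illustration: the `epoch_time` cost model, the per-step overhead, the per-sample cost, and the sample and batch sizes are hypothetical and not taken from Scyan.

```python
import math

# Counts read from the profiler table above.
batch_hook_calls = 552  # e.g. [LightningModule]Scyan.on_train_batch_start
epoch_hook_calls = 23   # e.g. [LightningModule]Scyan.on_train_epoch_start

# Number of optimizer steps per epoch in the profiled run.
batches_per_epoch = batch_hook_calls // epoch_hook_calls
print(batches_per_epoch)


def steps_per_epoch(samples_per_epoch: int, batch_size: int) -> int:
    """Number of mini-batches needed to cover one epoch."""
    return math.ceil(samples_per_epoch / batch_size)


def epoch_time(samples_per_epoch: int, batch_size: int,
               overhead_s: float, per_sample_s: float) -> float:
    """Hypothetical cost model: a fixed overhead per step (hooks, data
    transfer, kernel launches) plus a compute cost per sample."""
    steps = steps_per_epoch(samples_per_epoch, batch_size)
    return steps * overhead_s + samples_per_epoch * per_sample_s


# Hypothetical values, chosen only so that the small batch gives 24 steps.
samples = 240_000
small = epoch_time(samples, 10_000, overhead_s=0.01, per_sample_s=1e-7)
large = epoch_time(samples, 40_000, overhead_s=0.01, per_sample_s=1e-7)
print(f"batch 10k: {small:.3f}s per epoch, batch 40k: {large:.3f}s per epoch")
```

Under this toy model, when the per-sample cost is tiny (as it would be for a small model on a GPU), the fixed per-step overhead dominates, so fewer and larger batches shorten the epoch. Whether that holds in practice would need to be checked by profiling again.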

Closing for now; don't hesitate to re-open if it happens again.