gerkone / painn-jax

PaiNN in jax
MIT License

Performance question and batching #1

Open · tisabe opened this issue 6 months ago

tisabe commented 6 months ago

Hi there!

Thanks for this great repository and for sharing your implementation of PaiNN! I have a question/issue regarding the advertised performance: is the inference time mentioned in the README measured on the validation split when you run validate.py? When I ran validate.py myself on a newer GPU (an A100 on a cluster), I did not get quite the same performance:

Starting 100 epochs with 587137 parameters.
Jitting...
...
[Epoch 11] train loss 0.019037, epoch 227907.39ms - val loss 0.079965 (best), infer 18.01ms

I haven't managed to check GPU usage yet, but this timing is actually similar to what I got with the original implementation. To make sure JAX found the GPU, I printed:

Jax devices: [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]

What has sped up my training in the past was dynamic batching, i.e. collecting graphs into a batch until a maximum number of edges, nodes or graphs is reached. It would be interesting to see whether this speeds up training in this minimal example as well. If I understand correctly, the batches from schnetpack.QM9 are batched naively, with a fixed number of graphs per batch?
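The dynamic batching idea can be sketched as a greedy loop: keep adding graphs to the current batch and close it as soon as the next graph would push it over any of the node, edge or graph budgets. This is a hypothetical helper working only on `(n_nodes, n_edges)` sizes, not code from painn-jax or schnetpack. If you already use jraph, `jraph.dynamically_batch` implements the same idea for `GraphsTuple`s. For jitted code each batch would still be padded up to the same budgets, so the input shapes stay fixed.

```python
from typing import Iterable, Iterator, List, Tuple

# Hypothetical representation: each graph is described only by (n_nodes, n_edges).
# In practice the graph data would be carried alongside its sizes.
Graph = Tuple[int, int]


def dynamic_batches(
    graphs: Iterable[Graph],
    max_nodes: int,
    max_edges: int,
    max_graphs: int,
) -> Iterator[List[Graph]]:
    """Greedily fill each batch until the next graph would exceed a budget."""
    batch: List[Graph] = []
    n_nodes = n_edges = 0
    for g in graphs:
        g_nodes, g_edges = g
        if g_nodes > max_nodes or g_edges > max_edges:
            raise ValueError(f"graph {g} does not fit in an empty batch")
        # Close the current batch if adding this graph would overflow any budget.
        if batch and (
            n_nodes + g_nodes > max_nodes
            or n_edges + g_edges > max_edges
            or len(batch) + 1 > max_graphs
        ):
            yield batch
            batch, n_nodes, n_edges = [], 0, 0
        batch.append(g)
        n_nodes += g_nodes
        n_edges += g_edges
    if batch:
        # Emit the last, partially filled batch.
        yield batch
```

For example, sizes `[(5, 10), (6, 12), (4, 8), (9, 20)]` with `max_nodes=12, max_edges=25, max_graphs=3` give the batches `[[(5, 10), (6, 12)], [(4, 8)], [(9, 20)]]`: the number of graphs per batch varies, while the padded size stays bounded by the budgets.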

gerkone commented 6 months ago

Hey Tim, glad you found it useful. Looking at it now, the timings seem a bit sketchy, especially the torch one. I reran the validation, and after jitting I still get similar values for JAX: around 8ms on my GPU (RTX 4000), so your result is unexpected. It should not be compiling again after the first epoch, since the validation loader is not shuffled. Just to be sure, could you check whether that's the case?

Starting 100 epochs with 587137 parameters.
Jitting...
[Epoch    1] train loss 0.334538, epoch 357033.30ms - val loss 0.232655 (best), infer 33.78ms
[Epoch    2] train loss 0.053471, epoch 287106.63ms - val loss 0.163645 (best), infer 8.08ms
[Epoch    3] train loss 0.035966, epoch 277099.84ms - val loss 0.106103 (best), infer 8.10ms

This is in line with what I get on egnn. Torch is probably off, though; quickly rerunning it got me to ~13ms. Thanks for pointing this out, I'll update the READMEs once I have time to try it properly.
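One way to check for recompilation without touching the model is to look at the shapes the loader produces. `jax.jit` compiles once per distinct combination of input shapes and dtypes (for a fixed pytree structure and fixed static arguments), so a correctly padded loader should yield exactly one signature. The sketch below assumes each batch is a flat dict of arrays exposing `.shape` and `.dtype`; `shape_signature` and `distinct_signatures` are hypothetical helpers, not part of painn-jax. JAX can also report every compilation directly via `jax.config.update("jax_log_compiles", True)`.

```python
from typing import Any, Dict, Iterable, Set, Tuple

# (key, shape, dtype) for every array in a batch, sorted by key.
Signature = Tuple[Tuple[str, Tuple[int, ...], str], ...]


def shape_signature(batch: Dict[str, Any]) -> Signature:
    """Hashable summary of the shapes and dtypes in one batch dict."""
    return tuple(
        sorted((k, tuple(v.shape), str(v.dtype)) for k, v in batch.items())
    )


def distinct_signatures(loader: Iterable[Dict[str, Any]]) -> Set[Signature]:
    """Collect the distinct shape/dtype signatures a loader yields.

    More than one entry means a jitted function fed by this loader
    will be traced and compiled more than once.
    """
    return {shape_signature(b) for b in loader}
```

Running `len(distinct_signatures(val_loader))` on a (hypothetical) already-padded validation loader should return 1; anything larger points to extra compilations during evaluation.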

As for batching, one could definitely improve it, for example by solving a special knapsack problem. Right now the batches are padded to 1.3 times the worst case (here)

  max_batch_nodes = int(
      1.3 * max(sum(d["_n_atoms"]) for d in dataset.val_dataloader())
  )

which is clearly not very smart. On the other hand, the point of these experiments was not performance, but rather validating PaiNN and confirming it does what it should. On top of this, GPU utilization during training hardly goes above 30% on QM9. If you look around, especially in the QM9 code, you'll see how hacky it is.
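In its simplest form, the "special knapsack problem" mentioned above is bin packing: group graphs into as few batches as possible without exceeding a node budget, so less of each padded batch is wasted. A minimal sketch using first-fit decreasing, a standard bin-packing heuristic swapped in here for illustration (it is not what painn-jax does). It only considers node counts, and `max_nodes` plays the role of `max_batch_nodes` in the snippet above; an edge budget could be checked the same way.

```python
from typing import List, Sequence


def first_fit_decreasing(n_nodes: Sequence[int], max_nodes: int) -> List[List[int]]:
    """Pack graph indices into batches whose total node count is <= max_nodes."""
    bins: List[List[int]] = []
    loads: List[int] = []
    # Place the largest graphs first (the "decreasing" part of the heuristic).
    order = sorted(range(len(n_nodes)), key=lambda i: n_nodes[i], reverse=True)
    for i in order:
        size = n_nodes[i]
        if size > max_nodes:
            raise ValueError(f"graph {i} has {size} nodes > max_nodes={max_nodes}")
        # Put the graph into the first existing batch with enough room...
        for b, load in enumerate(loads):
            if load + size <= max_nodes:
                bins[b].append(i)
                loads[b] += size
                break
        else:
            # ...or open a new batch if none has room.
            bins.append([i])
            loads.append(size)
    return bins
```

For node counts `[9, 8, 2, 2, 5, 4]` and `max_nodes=10` this returns `[[0], [1, 2], [4, 5], [3]]`. Because it reorders graphs by size, it suits a fixed, unshuffled split such as the validation set better than a shuffled training loader.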

tisabe commented 6 months ago

I used the advanced profiling technique of putting a print statement in the jitted functions, and they were all compiled just once:

Jax devices:  [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
Target:  mu
Starting 100 epochs with 587137 parameters.
Jitting...
unjitted call to update
unjitted call to train_mse
unjitted call to predict
[Epoch    1] train loss 0.315345, epoch 359965.70msunjitted call to eval_mae
unjitted call to eval_mae
unjitted call to predict
 - val loss 0.196682 (best), infer 16.31ms
[Epoch    2] train loss 0.052425, epoch 281621.66ms
[Epoch    3] train loss 0.038055, epoch 252686.31ms
[Epoch    4] train loss 0.026014, epoch 248861.45ms
[Epoch    5] train loss 0.022775, epoch 232528.28ms
[Epoch    6] train loss 0.019563, epoch 233256.68ms
[Epoch    7] train loss 0.017503, epoch 243905.89ms
[Epoch    8] train loss 0.013741, epoch 236907.86ms
[Epoch    9] train loss 0.015446, epoch 243664.94ms
[Epoch   10] train loss 0.012798, epoch 248185.17ms
[Epoch   11] train loss 0.012427, epoch 233851.51ms - val loss 0.088224 (best), infer 17.92ms
[Epoch   12] train loss 0.009694, epoch 236405.61ms

gerkone commented 6 months ago

Weird. What's not normal about your times is that the inference run at epoch 1 takes about as long as the second one at epoch 11. This should not be the case, since jitting alone will push the average runtime upwards (you can see this in the output I printed). Did you modify the inference function? The runtimes I report are for the model forward pass only, not the graph transform and data loading.
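The argument about jitting can be made concrete. Assuming the reported infer time is the mean over the $N$ validation batches, and that a compiling run pays a one-off cost $t_{\text{compile}}$ on its first batch on top of the steady-state forward times $t_i$:

```latex
\bar{t}_{\text{reported}}
  = \frac{t_{\text{compile}} + \sum_{i=1}^{N} t_i}{N}
  = \bar{t}_{\text{steady}} + \frac{t_{\text{compile}}}{N}
```

In gerkone's log, epoch 1 (33.78 ms) versus epoch 2 (8.08 ms) puts the compile term at roughly $t_{\text{compile}}/N \approx 25.7$ ms. A run without recompilation should therefore be clearly faster than the first evaluation; tisabe's epoch 1 (16.31 ms) and epoch 11 (17.92 ms) being about equal is consistent with both runs carrying such a term.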

tisabe commented 6 months ago

I seem to get different timings when I change the val-freq argument. When I evaluate after every epoch, I get much better times:

[Epoch    2] train loss 0.052915, epoch 258117.03ms - val loss 0.124259 (best), infer 2.76ms
[Epoch    3] train loss 0.035268, epoch 240552.22ms - val loss 0.132137, infer 2.77ms
[Epoch    4] train loss 0.028790, epoch 233492.50ms - val loss 0.138793, infer 2.79ms

Infer time is the time per graph, right? I did not change the infer function.

gerkone commented 6 months ago

Time is per batch of 100 graphs (the default), after padding. This has gotten even stranger. The time I put in the README is from evaluating every graph, but at the moment I honestly don't know why the times are like this. I still think it's somehow jitting again every time if you evaluate at the 10th epoch. Of course, this should not happen since the shapes are always the same. A small check would be to discard the first runtime in the evaluate function, for example by dry-running the model on the first batch from next(iter(loader)). There is likely something wrong in the validation experiments, so don't take them as a good starting point. I might spend some time investigating this next week; I'll let you know if I find something.