mli / transformers-benchmarks

real Transformer TeraFLOPS on various GPUs
Apache License 2.0
861 stars 105 forks source link

About the theoretical value of the GPU #3

Open fingertap opened 1 year ago

fingertap commented 1 year ago

请问沐神:

  1. 在notebook中指向的wiki里,3090ti的理论值40是从表中的core boosted value(39.997)得到的吗?
  2. 我在自己的很多块3090上,用CUDA11.7和nvidia-driver 525跑出来的TFLOPS都只有24,距离base(29.3)和boost(35.6)的理论值都有一定的差距。请问notebook中用3090ti跑的TFLOPS是经过超频的吗?要想达到接近理论值的FLOPS需要做怎样的设置呢?
H-Jamieu commented 3 months ago

直接上结论:这个问题是pytorch官方默认的distro设定是torch.backends.cuda.matmul.allow_tf32 = True导致的。可以将这个变量设置为true解决问题。用torch.set_float32_matmul_precision('high')也可做到,不过这个是用bf16加速的。

(opinion)也就是说4090在这个microbench里面80T的FP32算力是用tensorcore加速实现的,如果用cuda硬算大概是54T。

参考1:https://pytorch.org/docs/stable/notes/cuda.html 参考2:https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html

----------------过时内容--------------------- 我也遇到了这个问题,目前的发现是:

  1. 如果直接用pip install torch的官方命令,无论Windows还是linux下FP32都约为理论值的62.5%,炼丹实测数据与理论跑分一至。手上的3090/4090/4080/4060都一样,cuda版本从11.2-12.1都试过,pytorch从1.13-2.1都试过。
  2. FP16与理论值接近。
  3. 用nvidia pytorch可以跑出理论成绩。
  4. 自己偶然编译过一版pytorch,cuda用的是10.2,3090在Windows原生下Fp32达到了标称算力。 猜想:该问题可能和pytorch官方编译的轮子有关.
H-Jamieu commented 2 months ago

experiments:

  ENV: windows 11, python3.9
  Pytorch version   : 2.3.1+cu118
  CUDA version  : 11.8
  GPU       : NVIDIA GeForce RTX 4090

  default

  n=128 n=512   n=2048  n=8192  n=16384
  torch.float32 0.224   14.299  55.590  54.695  54.470
  torch.float16 0.118   12.691  168.418 163.888 178.617

  torch.backends.cuda.matmul.allow_tf32 =True

  n=128 n=512   n=2048  n=8192  n=16384
  torch.float32 0.216   13.843  83.765  88.001  87.673
  torch.float16 0.215   13.581  168.568 164.090 178.862

  torch.set_float32_matmul_precision('highest')

  n=128 n=512   n=2048  n=8192  n=16384
  torch.float32 0.226   14.882  55.620  54.704  54.511
  torch.float16 0.217   13.613  168.589 163.722 178.860

  torch.set_float32_matmul_precision('high')
  n=128 n=512   n=2048  n=8192  n=16384
  torch.float32 0.213   13.797  86.896  91.506  91.438
  torch.float16 0.215   13.434  175.768 167.794 184.831