Oneflow-Inc / oneflow

OneFlow is a deep learning framework designed to be user-friendly, scalable and efficient.
http://www.oneflow.org
Apache License 2.0
5.78k stars 658 forks

flash_attention_v2_backward #10495

Closed. cccddd77 closed this 1 month ago

cccddd77 commented 2 months ago

Adds the flash attention v2 backward operator.

github-actions[bot] commented 2 months ago

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

github-actions[bot] commented 2 months ago

View latest API docs preview at: https://oneflow-staging.oss-cn-beijing.aliyuncs.com/docs/Oneflow-Inc/oneflow/pr/10495/

github-actions[bot] commented 1 month ago
Speed stats:
```
GPU Name: NVIDIA GeForce RTX 3080 Ti

❌ OneFlow resnet50 time: 43.7ms (= 4371.3ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 58.0ms (= 5797.6ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.33 (= 58.0ms / 43.7ms)

OneFlow resnet50 time: 26.2ms (= 2616.5ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 38.1ms (= 3812.5ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.46 (= 38.1ms / 26.2ms)

OneFlow resnet50 time: 19.7ms (= 3932.1ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 35.3ms (= 7060.7ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.80 (= 35.3ms / 19.7ms)

OneFlow resnet50 time: 17.9ms (= 3571.5ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 31.5ms (= 6297.8ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.76 (= 31.5ms / 17.9ms)

OneFlow resnet50 time: 16.8ms (= 3353.0ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 29.5ms (= 5903.2ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.76 (= 29.5ms / 16.8ms)

OneFlow swin dataloader time: 0.201s (= 40.171s / 200, num_workers=1)
PyTorch swin dataloader time: 0.127s (= 25.467s / 200, num_workers=1)
Relative speed: 0.634 (= 0.127s / 0.201s)

OneFlow swin dataloader time: 0.054s (= 10.830s / 200, num_workers=4)
PyTorch swin dataloader time: 0.033s (= 6.583s / 200, num_workers=4)
Relative speed: 0.608 (= 0.033s / 0.054s)

OneFlow swin dataloader time: 0.031s (= 6.216s / 200, num_workers=8)
PyTorch swin dataloader time: 0.017s (= 3.320s / 200, num_workers=8)
Relative speed: 0.534 (= 0.017s / 0.031s)

❌ OneFlow resnet50 time: 49.2ms (= 4924.0ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 65.9ms (= 6586.6ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.34 (= 65.9ms / 49.2ms)

OneFlow resnet50 time: 36.4ms (= 3638.1ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 47.1ms (= 4710.0ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.29 (= 47.1ms / 36.4ms)

OneFlow resnet50 time: 27.9ms (= 5587.4ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 42.5ms (= 8501.2ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.52 (= 42.5ms / 27.9ms)

OneFlow resnet50 time: 25.5ms (= 5100.6ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 39.0ms (= 7800.5ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.53 (= 39.0ms / 25.5ms)

OneFlow resnet50 time: 24.5ms (= 4901.7ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 35.7ms (= 7149.4ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.46 (= 35.7ms / 24.5ms)
```
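
For readers checking the numbers: each "Relative speed" above is the PyTorch per-iteration time divided by the OneFlow per-iteration time, so values above 1 mean OneFlow was faster. The bot's actual script is not shown in this thread; the sketch below is only an illustration, and the function name `relative_speed` and its parameters are hypothetical. It assumes the ratio is taken from the unrounded totals, which seems consistent with the log (for example, the rounded 38.1ms / 26.2ms would give 1.45, while the log reports 1.46).

```python
def relative_speed(pytorch_total: float, oneflow_total: float, iters: int) -> float:
    """Return PyTorch time per iteration divided by OneFlow time per iteration.

    Hypothetical helper, not part of OneFlow's CI tooling.
    Both totals must use the same unit (ms or s) and the same iteration count.
    """
    pytorch_per_iter = pytorch_total / iters
    oneflow_per_iter = oneflow_total / iters
    return pytorch_per_iter / oneflow_per_iter


# resnet50, input_shape=[8, 3, 224, 224]: totals in ms over 100 iterations
print(round(relative_speed(3812.5, 2616.5, 100), 2))   # 1.46, as in the log

# swin dataloader, num_workers=1: totals in s over 200 iterations
print(round(relative_speed(25.467, 40.171, 200), 3))   # 0.634, as in the log
```

Read this way, the resnet50 rows show OneFlow ahead of PyTorch (ratios from 1.29 to 1.80), while the swin dataloader rows show it behind (0.534 to 0.634).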