pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

torch.compile of simple loop takes 34 seconds #111441

Open yaroslavvb opened 1 year ago

yaroslavvb commented 1 year ago

πŸ› Describe the bug

I'm converting some NumPy code; it's basically this for loop. It takes 30 seconds to compile.

Is it too much to ask to make it faster? :)

Works great otherwise!

    for step_idx in range(num_steps):
        X = hsqrt * np.random.randn(B, d)
        losses = np.einsum("BD,BD->B", E, X)
        E -= alpha * np.einsum("BD,B->BD", X, losses)
        traj[step_idx] = E

    return np.sum(traj * traj, axis=2)

Error logs

No response

Minified repro

import time
import torch

import numpy as np

class timeit:
    def __init__(self, tag=""):
        self.tag = tag

    def __enter__(self):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        self.end = time.perf_counter()
        interval_ms = 1000 * (self.end - self.start)
        print(f"{interval_ms:8.2f}   {self.tag}")

@torch.compile
def getNormsSq0(errors):
    return (errors*errors).sum(axis=1)

@torch.compile
def trajH0(h, B, alpha, num_steps):
    """Simulate trajectories of the fixed-point iteration

        e = e - alpha * x * <e, x>

    where x is a Gaussian RV with diagonal covariance entries h
    (normalize h so that sum(h) == 1).

    h: diagonal covariance entries, e.g. a power law h_i = i**-1.1
    B: batch size (number of trajectories to sample)
    alpha: step size
    num_steps: how many steps to run for

    The number of dimensions d is read from the module-level global d == len(h).

    Returns (num_steps, B) array of ||e||^2
    """

    # h = h / np.sum(h)
    hsqrt = np.sqrt(h)
    E = np.random.randn(B, d)  # d is the module-level global, equal to len(h)
    traj = np.zeros((num_steps, B, d))

    for step_idx in range(num_steps):
        X = hsqrt * np.random.randn(B, d)
        losses = np.einsum("BD,BD->B", E, X)
        E -= alpha * np.einsum("BD,B->BD", X, losses)
        traj[step_idx] = E

    return np.sum(traj * traj, axis=2)

d = 1000
h = np.arange(1, d + 1)
h = h ** -1.1

num_steps = 100
d = h.shape[0]

a = 2/(2*h.max() + h.sum())

#with timeit("regular"):
#        means = getNormsSq0(trajH(h, B=1000, alpha=a, num_steps=num_steps))

with timeit("compiled"):
    means = getNormsSq0(trajH0(h, B=1000, alpha=a, num_steps=num_steps))

with timeit("compiled+cuda"):
    with torch.device("cuda"):
        means = getNormsSq0(trajH0(h, B=1000, alpha=a, num_steps=num_steps))

Versions

This is on Google Colab and PyTorch nightly

PyTorch version: 2.2.0.dev20231017+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.27.6
Libc version: glibc-2.35

Python version: 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.120+-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-40GB
Nvidia driver version: 525.105.17
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   46 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          12
On-line CPU(s) list:             0-11
Vendor ID:                       GenuineIntel
Model name:                      Intel(R) Xeon(R) CPU @ 2.20GHz
CPU family:                      6
Model:                           85
Thread(s) per core:              2
Core(s) per socket:              6
Socket(s):                       1
Stepping:                        7
BogoMIPS:                        4400.41
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat avx512_vnni md_clear arch_capabilities
Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       192 KiB (6 instances)
L1i cache:                       192 KiB (6 instances)
L2 cache:                        6 MiB (6 instances)
L3 cache:                        38.5 MiB (1 instance)
NUMA node(s):                    1
NUMA node0 CPU(s):               0-11
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Vulnerable; SMT Host state unknown
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Vulnerable
Vulnerability Retbleed:          Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1:        Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2:        Vulnerable, IBPB: disabled, STIBP: disabled, PBRSB-eIBRS: Vulnerable
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Vulnerable

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] pytorch-triton==2.1.0+6e4932cda8
[pip3] torch==2.2.0.dev20231017+cu121
[pip3] torchaudio==2.0.2+cu118
[pip3] torchdata==0.6.1
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.15.2
[pip3] torchvision==0.15.2+cu118
[pip3] triton==2.0.0
[conda] Could not collect

cc @ezyang @anijain2305 @chauhang @penguinwu @oulgen @jamesjwu @aorenste @laithsakka @zou3519 @ydwu4 @bdhirsh @msaroufim @wconstab

lezcano commented 1 year ago

Dynamo unrolls loops, and it's not particularly fast at tracing through them. This is not only an issue with NumPy; it also happens when tracing PyTorch code.

For NumPy, what we've found often works best is to compile just the body of the loop. You can even do some partial unrolling, where you compile a few iterations of the loop at a time.
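
A sketch of both workarounds, assuming NumPy support in `torch.compile` (PyTorch 2.1+). `run_chunk` and `traj_chunked` are hypothetical names: `run_chunk` is compiled once and advances `k` steps, so Dynamo unrolls only `k` iterations, and `chunk=1` compiles just the loop body. `backend="eager"` keeps the sketch light by exercising only Dynamo's tracing; remove it to get Inductor-generated kernels. The noise is drawn with `np.random.default_rng` outside the compiled region, which is a change from the repro's `np.random.randn`.

```python
import numpy as np
import torch


@torch.compile(backend="eager")  # drop backend= to use the default Inductor backend
def run_chunk(E, Xs, alpha):
    # Dynamo unrolls only the k = Xs.shape[0] iterations in this chunk.
    states = []
    for i in range(Xs.shape[0]):
        X = Xs[i]
        losses = np.einsum("BD,BD->B", E, X)
        E = E - alpha * np.einsum("BD,B->BD", X, losses)
        states.append(E)
    return E, np.stack(states)


def traj_chunked(h, B, alpha, num_steps, chunk=10, seed=0):
    """Same simulation as trajH0, compiled `chunk` steps at a time.

    chunk=1 compiles only the loop body. Noise is sampled with NumPy outside
    the compiled region, one (k, B, d) block per chunk.
    """
    rng = np.random.default_rng(seed)
    d = h.shape[0]
    hsqrt = np.sqrt(h)
    E = rng.standard_normal((B, d))
    pieces = []
    for start in range(0, num_steps, chunk):
        # A shorter final chunk (num_steps not a multiple of chunk) triggers a recompile.
        k = min(chunk, num_steps - start)
        Xs = hsqrt * rng.standard_normal((k, B, d))
        E, states = run_chunk(E, Xs, alpha)
        pieces.append(states)
    traj = np.concatenate(pieces)
    return np.sum(traj * traj, axis=2)
```

With the repro's `h` and `a`, `traj_chunked(h, B=1000, alpha=a, num_steps=100, chunk=10)` traces 10 iterations instead of 100. Larger chunks may give Inductor more to fuse per call, at the cost of longer compiles.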

Of course, this is just a workaround for the actual issue, but yeah.

zou3519 commented 1 year ago

cc @Chillee for compile time. Also, as @Chillee mentioned elsewhere, a host-side control flow operator (e.g. something like jax.lax.fori_loop) could help here.

Fidget-Spinner commented 1 year ago

I'm interested in poking around at this issue. Let me talk with others first though to get a feel of how much effort a possible solution could take.

yaroslavvb commented 1 year ago

@zou3519 a fori_loop primitive that was substituted automatically for Python functions involving for loops would resolve the issue.

Fidget-Spinner commented 12 months ago

πŸ› Describe the bug

I'm converting some numpy code, basically it's this for loop. It takes 30 seconds to compile.

Is it too much to ask to make it faster? :)

Works great otherwise!


    for step_idx in range(num_steps):
        X = hsqrt * np.random.randn(B, d)
        losses = np.einsum("BD,BD->B", E, X)
        E -= alpha * np.einsum("BD,B->BD", X, losses)
        traj[step_idx] = E

Converting the above to a higher-order fori_loop function automatically isn't too hard, ~with the exception of the line traj[step_idx] = E. That line alone has side-effects and would probably require recursion. However, the IR passed to Inductor is already recursive. So there would be no benefit there.~ (I forgot that the non-functional stuff is pulled out of Inductor, so this could work too).
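
A sketch of the functional form, in plain NumPy. `fori_loop` here is a hypothetical helper that only mimics the semantics of `jax.lax.fori_loop(lower, upper, body_fun, init_val)`; it is not a PyTorch API and does no compilation. The point is the rewrite: state is passed explicitly through a carry, the write `traj[step_idx] = E` becomes an update of a carried array, and the noise is pre-sampled so the body is a pure function of `(i, carry)`. Since the repro only returns `np.sum(traj * traj, axis=2)`, the carry keeps the squared norms rather than the full `(num_steps, B, d)` trajectory.

```python
import numpy as np


def fori_loop(lower, upper, body_fun, init_val):
    # Hypothetical helper with jax.lax.fori_loop-style semantics. Because the
    # body is a pure function of (i, carry), a compiler could trace it once
    # instead of unrolling every iteration; here it just runs in Python.
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def traj_fori(h, alpha, E0, noise):
    """noise has shape (num_steps, B, d); returns (num_steps, B) squared norms."""
    hsqrt = np.sqrt(h)

    def body(i, carry):
        E, norms = carry
        X = hsqrt * noise[i]
        losses = np.einsum("BD,BD->B", E, X)
        E = E - alpha * np.einsum("BD,B->BD", X, losses)
        # The side effect traj[step_idx] = E becomes a functional update of a
        # carried array (copied here to keep the body pure).
        norms = norms.copy()
        norms[i] = np.sum(E * E, axis=1)
        return E, norms

    num_steps, B, _ = noise.shape
    _, norms = fori_loop(0, num_steps, body, (E0, np.zeros((num_steps, B))))
    return norms
```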

Fidget-Spinner commented 12 months ago

One other possible way: after one or two iterations and a fixpoint is reached (which should be quite fast since this is a trace), express the loop as a recursive definition of state. Then compile that recursive definition. This would be a more efficient way of evaluating the loop than the current method, which evaluates the loop fully and then evaluates the recursive expressions fully.

(see: cyclic term graph (Ariola & Klop 1996))
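
For reference, the recurrence the loop computes, in notation introduced here: b indexes rows of `E`, t is `step_idx`, T is `num_steps`, and e_0 is the initial `E`. A recursive definition of state in the sense above would compile the single step from e_t to e_{t+1} once, rather than all T unrolled copies of it.

```math
\begin{aligned}
x^{(b)}_t &= \sqrt{h} \odot z^{(b)}_t, \qquad z^{(b)}_t \sim \mathcal{N}(0, I_d),\\
e^{(b)}_{t+1} &= e^{(b)}_t - \alpha\, x^{(b)}_t \,\langle e^{(b)}_t, x^{(b)}_t \rangle, \qquad t = 0, \dots, T-1,\\
\mathrm{out}_{t,b} &= \lVert e^{(b)}_{t+1} \rVert^2 .
\end{aligned}
```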

ezyang commented 11 months ago

An approach that requires less static analysis juice is to require some user annotation. For example, you could annotate a function with a higher order op that instructs Dynamo to never inline it; instead, it must be possible to compile it once and reuse at all call sites. If the loop here can be rewritten to call such a function, it would also resolve your problem. The simplest implementation of this op would not allow side effects on Python.

Fidget-Spinner commented 11 months ago

One other possible way: after one or two iterations and a fixpoint is reached (which should be quite fast since this is a trace) ...

Never mind; after reasoning this out, it's not possible. That would require projecting traces or some form of symbolic analysis, which goes completely against the Dynamo architecture.

bhack commented 2 months ago

What is the current state of this?

tengyifei commented 2 months ago

Does the torch._higher_order_ops.while_loop in https://github.com/pytorch/pytorch/blob/62311257adb902d6a4ea98809c88895af1dbbf2b/torch/_higher_order_ops/while_loop.py#L66 help with this case?

anijain2305 commented 2 months ago

TorchDynamo still unrolls the loop aggressively. There is no easy workaround here. If it makes sense in your codebase, instead of applying torch.compile at the very top, you could lift the loop body into a separate function and then apply torch.compile manually on that lifted function. This might not be very user-friendly.

laithsakka commented 1 month ago

I wonder whether, in the long term, we could do that lifting implicitly during compilation.