Joshuaalbert / DSA2000-Cal

DSA-2000 Calibration and Forward Modelling
https://www.deepsynoptic.org/overview
MIT License
1 stars 0 forks source link

Reorder vmap(scan) into scan(vmap) #72

Closed Joshuaalbert closed 5 days ago

Joshuaalbert commented 2 months ago

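This benchmark shows that scan(vmap(f)) is faster than vmap(scan(f)) in all but one of the configurations measured below (typically by about 1.05–1.8x), probably due to the contiguous memory access pattern. Thus, for performance, accumulation code should scan over a vmap rather than vmap over a scan.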

import time

import jax
import numpy as np
import jax.numpy as jnp
from jax import lax
from jax import random, jit

def main(batch_size1, batch_size2, length, num_repeats=1):
    """Benchmark ``scan(vmap(f))`` against ``vmap(scan(f))`` and print a summary.

    Both variants apply the elementwise function ``f`` to a random input of
    shape ``[batch_size1, batch_size2, length]``. ``scan(vmap(f))`` scans over
    axis 0 and vmaps over axis 1; ``vmap(scan(f))`` vmaps over axis 0 and scans
    over axis 1. The results are checked for agreement before printing.

    Args:
        batch_size1: size of axis 0 of the random input.
        batch_size2: size of axis 1 of the random input.
        length: size of axis 2, the trailing vector passed to ``f``.
        num_repeats: number of timed executions per variant; the fastest one
            is reported to reduce noise. Defaults to 1 (a single timed run).

    Returns:
        ``(dt_scan_vmap, dt_vmap_scan)``: best wall-clock seconds per variant.

    Raises:
        ValueError: if ``num_repeats < 1``.
        AssertionError: if the two variants produce different results.
    """
    if num_repeats < 1:
        raise ValueError(f"num_repeats must be >= 1, got {num_repeats}")

    # Simple elementwise function being benchmarked.
    def f(x):
        return jnp.sin(x ** 2 + 1)

    key = random.PRNGKey(0)
    xs = random.normal(key, (batch_size1, batch_size2, length))
    # JAX dispatch is asynchronous: without this, the first timed call would
    # also wait for `xs` to be generated, biasing against whichever runs first.
    xs.block_until_ready()

    def scan_vmap_f(xs):
        # xs: [batch_size1, batch_size2, length]; scan over axis 0, vmap over axis 1.
        def body_fn(carry, x):
            # x: [batch_size2, length]
            return (), jax.vmap(f)(x)

        _, result = lax.scan(body_fn, (), xs)  # [batch_size1, batch_size2, length]
        return result

    def vmap_scan_f(xs):
        # xs: [batch_size1, batch_size2, length]; vmap over axis 0, scan over axis 1.
        def scan(x):
            # x: [batch_size2, length]
            def body_fn(carry, x):
                return (), f(x)

            _, result = lax.scan(body_fn, (), x)  # [batch_size2, length]
            return result

        return jax.vmap(scan)(xs)  # [batch_size1, batch_size2, length]

    # Ahead-of-time compile so compilation is excluded from the timings.
    scan_vmap_f_jit = jit(scan_vmap_f).lower(xs).compile()
    vmap_scan_f_jit = jit(vmap_scan_f).lower(xs).compile()

    def time_compiled(fn):
        """Run ``fn(xs)`` once untimed, then return (output, best time over repeats)."""
        # Warm-up so one-off first-execution costs are not attributed to the timed runs.
        ys = fn(xs)
        ys.block_until_ready()
        best = float('inf')
        for _ in range(num_repeats):
            start = time.perf_counter()
            ys = fn(xs)
            ys.block_until_ready()
            best = min(best, time.perf_counter() - start)
        return ys, best

    ys_scan_vmap, dt_scan_vmap = time_compiled(scan_vmap_f_jit)
    ys_vmap_scan, dt_vmap_scan = time_compiled(vmap_scan_f_jit)

    # Both orderings must compute the same values.
    np.testing.assert_allclose(ys_scan_vmap, ys_vmap_scan)

    print(f"batch_size1: {batch_size1}, batch_size2: {batch_size2}, length: {length}\n"
          f"\t> scan(vmap(f)): {dt_scan_vmap:.2e} s\n"
          f"\t> vmap(scan(f)): {dt_vmap_scan:.2e} s\n"
          f"\t\t> Winner: {'scan(vmap(f))' if dt_scan_vmap < dt_vmap_scan else 'vmap(scan(f))'}\n"
          f"\t\t> Speedup: {dt_vmap_scan / dt_scan_vmap if dt_scan_vmap < dt_vmap_scan else dt_scan_vmap / dt_vmap_scan:.2f}x")

    return dt_scan_vmap, dt_vmap_scan

if __name__ == '__main__':
    # Full sweep over both batch sizes and the trailing length:
    # batch_size1 varies slowest, length varies fastest.
    sizes = [100, 1000, 5000]
    lengths = [1, 10, 100]
    configs = [(b1, b2, n) for b1 in sizes for b2 in sizes for n in lengths]
    for b1, b2, n in configs:
        main(batch_size1=b1, batch_size2=b2, length=n)
batch_size1: 100, batch_size2: 100, length: 1
    > scan(vmap(f)): 2.99e-04 s
    > vmap(scan(f)): 2.25e-04 s
        > Winner: vmap(scan(f))
        > Speedup: 1.33x
batch_size1: 100, batch_size2: 100, length: 10
    > scan(vmap(f)): 1.22e-03 s
    > vmap(scan(f)): 1.28e-03 s
        > Winner: scan(vmap(f))
        > Speedup: 1.05x
batch_size1: 100, batch_size2: 100, length: 100
    > scan(vmap(f)): 7.20e-03 s
    > vmap(scan(f)): 8.25e-03 s
        > Winner: scan(vmap(f))
        > Speedup: 1.15x
batch_size1: 100, batch_size2: 1000, length: 1
    > scan(vmap(f)): 1.11e-03 s
    > vmap(scan(f)): 1.97e-03 s
        > Winner: scan(vmap(f))
        > Speedup: 1.78x
batch_size1: 100, batch_size2: 1000, length: 10
    > scan(vmap(f)): 8.16e-03 s
    > vmap(scan(f)): 9.87e-03 s
        > Winner: scan(vmap(f))
        > Speedup: 1.21x
batch_size1: 100, batch_size2: 1000, length: 100
    > scan(vmap(f)): 6.64e-02 s
    > vmap(scan(f)): 8.43e-02 s
        > Winner: scan(vmap(f))
        > Speedup: 1.27x
batch_size1: 100, batch_size2: 5000, length: 1
    > scan(vmap(f)): 3.44e-03 s
    > vmap(scan(f)): 4.99e-03 s
        > Winner: scan(vmap(f))
        > Speedup: 1.45x
batch_size1: 100, batch_size2: 5000, length: 10
    > scan(vmap(f)): 3.59e-02 s
    > vmap(scan(f)): 4.38e-02 s
        > Winner: scan(vmap(f))
        > Speedup: 1.22x
batch_size1: 100, batch_size2: 5000, length: 100
    > scan(vmap(f)): 3.40e-01 s
    > vmap(scan(f)): 4.23e-01 s
        > Winner: scan(vmap(f))
        > Speedup: 1.25x
batch_size1: 1000, batch_size2: 100, length: 1
    > scan(vmap(f)): 9.19e-04 s
    > vmap(scan(f)): 1.07e-03 s
        > Winner: scan(vmap(f))
        > Speedup: 1.17x
batch_size1: 1000, batch_size2: 100, length: 10
    > scan(vmap(f)): 7.77e-03 s
    > vmap(scan(f)): 8.34e-03 s
        > Winner: scan(vmap(f))
        > Speedup: 1.07x
batch_size1: 1000, batch_size2: 100, length: 100
    > scan(vmap(f)): 6.70e-02 s
    > vmap(scan(f)): 8.30e-02 s
        > Winner: scan(vmap(f))
        > Speedup: 1.24x
batch_size1: 1000, batch_size2: 1000, length: 1
    > scan(vmap(f)): 6.47e-03 s
    > vmap(scan(f)): 9.61e-03 s
        > Winner: scan(vmap(f))
        > Speedup: 1.48x
batch_size1: 1000, batch_size2: 1000, length: 10
    > scan(vmap(f)): 7.94e-02 s
    > vmap(scan(f)): 9.42e-02 s
        > Winner: scan(vmap(f))
        > Speedup: 1.19x
batch_size1: 1000, batch_size2: 1000, length: 100
    > scan(vmap(f)): 7.23e-01 s
    > vmap(scan(f)): 8.66e-01 s
        > Winner: scan(vmap(f))
        > Speedup: 1.20x
batch_size1: 1000, batch_size2: 5000, length: 1
    > scan(vmap(f)): 3.26e-02 s
    > vmap(scan(f)): 4.47e-02 s
        > Winner: scan(vmap(f))
        > Speedup: 1.37x
batch_size1: 1000, batch_size2: 5000, length: 10
    > scan(vmap(f)): 3.76e-01 s
    > vmap(scan(f)): 4.84e-01 s
        > Winner: scan(vmap(f))
        > Speedup: 1.29x
batch_size1: 1000, batch_size2: 5000, length: 100
    > scan(vmap(f)): 3.59e+00 s
    > vmap(scan(f)): 4.37e+00 s
        > Winner: scan(vmap(f))
        > Speedup: 1.22x
batch_size1: 5000, batch_size2: 100, length: 1
    > scan(vmap(f)): 3.36e-03 s
    > vmap(scan(f)): 4.26e-03 s
        > Winner: scan(vmap(f))
        > Speedup: 1.27x
batch_size1: 5000, batch_size2: 100, length: 10
    > scan(vmap(f)): 3.47e-02 s
    > vmap(scan(f)): 4.44e-02 s
        > Winner: scan(vmap(f))
        > Speedup: 1.28x
batch_size1: 5000, batch_size2: 100, length: 100
    > scan(vmap(f)): 3.44e-01 s
    > vmap(scan(f)): 4.32e-01 s
        > Winner: scan(vmap(f))
        > Speedup: 1.26x
batch_size1: 5000, batch_size2: 1000, length: 1
    > scan(vmap(f)): 3.08e-02 s
    > vmap(scan(f)): 5.32e-02 s
        > Winner: scan(vmap(f))
        > Speedup: 1.72x
batch_size1: 5000, batch_size2: 1000, length: 10
    > scan(vmap(f)): 3.74e-01 s
    > vmap(scan(f)): 4.61e-01 s
        > Winner: scan(vmap(f))
        > Speedup: 1.23x
batch_size1: 5000, batch_size2: 1000, length: 100
    > scan(vmap(f)): 3.30e+00 s
    > vmap(scan(f)): 4.37e+00 s
        > Winner: scan(vmap(f))
        > Speedup: 1.33x
batch_size1: 5000, batch_size2: 5000, length: 1
    > scan(vmap(f)): 1.82e-01 s
    > vmap(scan(f)): 2.87e-01 s
        > Winner: scan(vmap(f))
        > Speedup: 1.57x
batch_size1: 5000, batch_size2: 5000, length: 10
    > scan(vmap(f)): 1.82e+00 s
    > vmap(scan(f)): 2.39e+00 s
        > Winner: scan(vmap(f))
        > Speedup: 1.31x
Joshuaalbert commented 2 months ago

It looks like the scan could even be moved outside the partitioned frequency axis. To be clear, partitioning N frequencies over M devices means that each device is responsible for a contiguous chunk of N//M frequencies; N must be evenly divisible by M. These N//M frequencies are then processed with vectorised operations locally on each device, so no inter-device communication is required. It therefore makes sense to consider scan(vmap(vmap)), as long as this ordering does not introduce unintentional communication between devices.

Joshuaalbert commented 5 days ago

Done with `multi_vmap`, which is much more complex, and better. It could also be extended to automatically choose the best ordering of scan/vmap to minimise FLOPs, etc.