LouisDesdoigts / dLux

Differentiable optical models as parameterised neural networks in Jax using Zodiax
https://louisdesdoigts.github.io/dLux/
BSD 3-Clause "New" or "Revised" License

Profiling on GPU. #179

Closed: Jordan-Dennis closed this issue 1 year ago

Jordan-Dennis commented 1 year ago

Hi all, as I strive to make my forward models faster, I am finding results that are counter-intuitive: for example, a vectorised version running slower than a non-vectorised version. I am starting to wonder if this is a CPU thing, so I would like to do some testing on GPU. Hit me up if you have a free GPU to run some code. Regards, Jordan.

benjaminpope commented 1 year ago

We have GPUs at SMP, or you can try Colab (but it will only work in 32-bit).
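For context on the 32-bit remark: JAX uses 32-bit floats by default on every backend, and 64-bit mode has to be switched on explicitly, ideally at startup before any arrays are created. A minimal sketch of enabling and checking it, assuming a recent JAX where `jax.config.update` is available; whether float64 is actually usable depends on the hardware backend.

```python
import jax

# Enable 64-bit mode; do this at startup, before creating any JAX arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.zeros(3)
print(x.dtype)  # reports the default float dtype now in effect
```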

——————————————— Dr Benjamin Pope (he/him) Lecturer in Astrophysics University of Queensland benjaminpope.github.io


From: Jordan Dennis. Sent: Thursday, January 5, 2023 10:50:28 AM. To: LouisDesdoigts/dLux. Subject: [LouisDesdoigts/dLux] Profiling on GPU. (Issue #179)


LouisDesdoigts commented 1 year ago

Colab is probably the easiest, but comparing performance across different hardware is kind of comparing apples to oranges...

I'm a little suspicious that a vectorised version is running slower, though. Where is this in the code?

Jordan-Dennis commented 1 year ago

Here are two implementations,

import functools as ft

import jax

# get_pixel_scale(ccoords) is defined elsewhere and is not shown in this thread.

@ft.partial(jax.jit, inline=True)
def soft_square_aperture(width: float, ccoords: jax.Array) -> jax.Array:
    # Non-vectorised: test |x| and |y| separately, then combine with a logical AND.
    pixel_scale = get_pixel_scale(ccoords)
    acoords = jax.lax.abs(ccoords)
    x = jax.lax.index_in_dim(acoords, 0)  # keeps a leading axis of length 1
    y = jax.lax.index_in_dim(acoords, 1)
    square = ((x < width) & (y < width)).astype(float)
    edges = ((x < (width + pixel_scale)) & (y < (width + pixel_scale))).astype(float)
    return ((square + edges) / 2.).squeeze()

@ft.partial(jax.jit, inline=True)
def _soft_square_aperture(width: float, ccoords: jax.Array) -> jax.Array:
    # Vectorised: test both axes at once, then AND them via a product over axis 0.
    pixel_scale = get_pixel_scale(ccoords)
    acoords = jax.lax.abs(ccoords)
    square = (acoords < width).prod(axis = 0).astype(float)
    edges = (acoords < (width + pixel_scale)).prod(axis = 0).astype(float)
    return (square + edges) / 2.
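Before comparing timings, it is worth confirming that the two implementations compute the same thing. A minimal sketch, assuming `ccoords` is a `(2, n, n)` array of x/y coordinates on a uniform grid. `get_pixel_scale` is not shown in the thread, so the version below is a hypothetical stand-in, and the grid is made up for illustration.

```python
import functools as ft

import jax
import jax.numpy as jnp


# Hypothetical stand-in for get_pixel_scale: assumes a uniform (2, n, n) grid
# built with meshgrid and returns the spacing between adjacent x samples.
def get_pixel_scale(ccoords: jax.Array) -> jax.Array:
    return ccoords[0, 0, 1] - ccoords[0, 0, 0]


@ft.partial(jax.jit, inline=True)
def soft_square_aperture(width, ccoords):
    pixel_scale = get_pixel_scale(ccoords)
    acoords = jax.lax.abs(ccoords)
    x = jax.lax.index_in_dim(acoords, 0)
    y = jax.lax.index_in_dim(acoords, 1)
    square = ((x < width) & (y < width)).astype(float)
    edges = ((x < width + pixel_scale) & (y < width + pixel_scale)).astype(float)
    return ((square + edges) / 2.0).squeeze()


@ft.partial(jax.jit, inline=True)
def _soft_square_aperture(width, ccoords):
    pixel_scale = get_pixel_scale(ccoords)
    acoords = jax.lax.abs(ccoords)
    square = (acoords < width).prod(axis=0).astype(float)
    edges = (acoords < width + pixel_scale).prod(axis=0).astype(float)
    return (square + edges) / 2.0


# Hypothetical 2 x 64 x 64 coordinate grid spanning [-1, 1] in x and y.
xs = jnp.linspace(-1.0, 1.0, 64)
ccoords = jnp.stack(jnp.meshgrid(xs, xs))

a = soft_square_aperture(0.5, ccoords)
b = _soft_square_aperture(0.5, ccoords)
print(a.shape, b.shape, bool(jnp.array_equal(a, b)))
```

Both versions should return the same `(n, n)` mask, so any timing gap between them comes from how XLA compiles them, not from different arithmetic.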

Running this on my machine, I find a difference of nearly a factor of ten (timing screenshot not preserved).

Jordan-Dennis commented 1 year ago

Similar results on Colab (screenshot not preserved).

Jordan-Dennis commented 1 year ago

The vectorisation thing seems to be really prevalent. Here is another example with similar results:

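import jax
import jax.numpy as np  # must be jax.numpy: NumPy's hypot cannot take traced values

@jax.jit
def hypotenuse_v0(coords: jax.Array) -> jax.Array:
    # Vectorised: square every element, then reduce over the leading (x, y) axis.
    return jax.lax.sqrt(jax.lax.integer_pow(coords, 2).sum(axis = 0))

@jax.jit
def hypotenuse_v1(ccoords: jax.Array) -> jax.Array:
    # Built-in reference.
    x = ccoords[0]
    y = ccoords[1]
    return np.hypot(x, y)

@jax.jit
def hypotenuse_v2(ccoords: jax.Array) -> jax.Array:
    # Unpacked via ordinary indexing.
    x = ccoords[0]
    y = ccoords[1]
    x_sq = jax.lax.integer_pow(x, 2)
    y_sq = jax.lax.integer_pow(y, 2)
    return jax.lax.sqrt(x_sq + y_sq)

@jax.jit
def hypotenuse_v3(ccoords: jax.Array) -> jax.Array:
    # Unpacked via lax.index_in_dim (keeps a leading axis of length 1).
    x = jax.lax.index_in_dim(ccoords, 0)
    y = jax.lax.index_in_dim(ccoords, 1)
    x_sq = jax.lax.integer_pow(x, 2)
    y_sq = jax.lax.integer_pow(y, 2)
    return jax.lax.sqrt(x_sq + y_sq)

# Hypothetical input, since the original ccoords is not shown: a (2, 256, 256) x/y grid.
xs = np.linspace(-1., 1., 256)
ccoords = np.stack(np.meshgrid(xs, xs))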

%%timeit
hypotenuse_v0(ccoords)

%%timeit
hypotenuse_v1(ccoords)

%%timeit
hypotenuse_v2(ccoords)

%%timeit
hypotenuse_v3(ccoords)

The built-in version (v1, using hypot) served as a benchmark and showed similar performance to v2 and v3. The built-in version actually checks for zeros to avoid calling integer_pow on them. This is all really confusing me, and I am thinking that I might open a JAX issue, since it is very counter-intuitive.
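The variants should also agree numerically before any of their timings are trusted. A minimal sketch checking the vectorised v0 against the built-in `jnp.hypot` on a hypothetical grid, plus the all-zeros input that the thread notes the built-in guards against; the grid is an assumption, as the thread does not show how `ccoords` was built.

```python
import jax
import jax.numpy as jnp


@jax.jit
def hypotenuse_v0(ccoords):
    # Vectorised: square, sum over the (x, y) axis, then take the square root.
    return jax.lax.sqrt(jax.lax.integer_pow(ccoords, 2).sum(axis=0))


@jax.jit
def hypotenuse_v1(ccoords):
    # Built-in reference.
    return jnp.hypot(ccoords[0], ccoords[1])


# Hypothetical (2, 65, 65) grid over [-1, 1].
xs = jnp.linspace(-1.0, 1.0, 65)
ccoords = jnp.stack(jnp.meshgrid(xs, xs))

r0 = hypotenuse_v0(ccoords)
r1 = hypotenuse_v1(ccoords)
print(bool(jnp.allclose(r0, r1)))

# Edge case: every coordinate is zero.
zeros = jnp.zeros((2, 3, 3))
print(hypotenuse_v0(zeros), hypotenuse_v1(zeros))
```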

LouisDesdoigts commented 1 year ago

Have you been using .block_until_ready() when profiling?
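The reason this matters: JAX dispatches work asynchronously, so a call returns as soon as the computation is queued, and without `.block_until_ready()` a timer mostly measures dispatch. The first call to a jitted function also includes tracing and XLA compilation. A minimal sketch of a timing helper that accounts for both; `time_jitted` is a hypothetical name, not a dLux or JAX API, and the input array is made up.

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, repeats: int = 100):
    """Return (first_call_seconds, mean_steady_state_seconds) for a jitted fn.

    Assumes fn returns a single JAX array. The first call includes tracing
    and compilation, so it is timed separately. Each call is followed by
    block_until_ready() so the timer waits for the computation to finish.
    """
    start = time.perf_counter()
    fn(*args).block_until_ready()
    first = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()
    steady = (time.perf_counter() - start) / repeats
    return first, steady


@jax.jit
def hypotenuse_v0(ccoords):
    return jax.lax.sqrt(jax.lax.integer_pow(ccoords, 2).sum(axis=0))


ccoords = jnp.ones((2, 256, 256))  # hypothetical input
first, steady = time_jitted(hypotenuse_v0, ccoords)
print(f"first call: {first:.2e} s, steady state: {steady:.2e} s")
```

Reporting the first call separately keeps compilation cost from inflating the per-call figure, which is what the `%%timeit` comparisons are after.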

Jordan-Dennis commented 1 year ago

Yes, mostly. I often forget it the first time, though.

Jordan-Dennis commented 1 year ago

Here are my times (screenshot not preserved). You can see there is a significant difference. What's more, the jaxpr for v0 looks really clean at just three lines, far fewer than the others.
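The jaxpr can be printed with `jax.make_jaxpr`, which traces a function for a given input shape and dtype. A minimal sketch comparing v0 and v2 on a hypothetical input. The equation count is only a rough proxy: XLA optimises and fuses these operations after tracing, so a shorter jaxpr does not by itself guarantee faster compiled code.

```python
import jax
import jax.numpy as jnp


def hypotenuse_v0(ccoords):
    # Vectorised: one integer_pow, one reduction, one sqrt.
    return jax.lax.sqrt(jax.lax.integer_pow(ccoords, 2).sum(axis=0))


def hypotenuse_v2(ccoords):
    # Unpacked: index each component, square, add, then sqrt.
    x = ccoords[0]
    y = ccoords[1]
    return jax.lax.sqrt(jax.lax.integer_pow(x, 2) + jax.lax.integer_pow(y, 2))


ccoords = jnp.ones((2, 8, 8))  # hypothetical input; only shape and dtype matter

jaxpr_v0 = jax.make_jaxpr(hypotenuse_v0)(ccoords)
jaxpr_v2 = jax.make_jaxpr(hypotenuse_v2)(ccoords)
print(jaxpr_v0)
print(jaxpr_v2)

# Number of top-level equations in each traced program.
print(len(jaxpr_v0.jaxpr.eqns), len(jaxpr_v2.jaxpr.eqns))
```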

Jordan-Dennis commented 1 year ago

So @jakevdp ran my code on a GPU and they all had very similar performance. He implied that the XLA compiler has seen a lot more work on GPU than on CPU. @benjaminpope, is there any chance I can use some of the SMP GPUs (or just send the code to you)?
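When moving between machines, it helps to confirm which backend JAX actually picked, since it silently falls back to CPU if no accelerator is found. A minimal sketch using `jax.default_backend()` and `jax.devices()`; it also commits the input to the CPU device with `jax.device_put`, which makes the jitted call run there, so CPU and accelerator timings can be compared in one session. The function is copied from the thread and the input array is hypothetical.

```python
import jax
import jax.numpy as jnp

# Which platform JAX chose by default, and the devices available on it.
print("default backend:", jax.default_backend())
print("devices:", jax.devices())


@jax.jit
def hypotenuse_v0(ccoords):
    return jax.lax.sqrt(jax.lax.integer_pow(ccoords, 2).sum(axis=0))


# Commit a hypothetical input to the first CPU device; jit follows committed inputs.
cpu = jax.devices("cpu")[0]
x_cpu = jax.device_put(jnp.ones((2, 256, 256)), cpu)

out = hypotenuse_v0(x_cpu).block_until_ready()
print(out.shape)
```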

LouisDesdoigts commented 1 year ago

Honestly, Colab would be the quickest and easiest.

Jordan-Dennis commented 1 year ago

Last time I tried Colab I got the same results, but that might have been instance-specific.

LouisDesdoigts commented 1 year ago

Colab CPU:

(Screenshot of timings, 2023-01-06 11:31 am, not preserved.)

Colab GPU:

(Screenshot of timings, 2023-01-06 11:32 am, not preserved.)

LouisDesdoigts commented 1 year ago

Relative times appear the same.
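Since absolute times are not comparable across hardware, it is the ratios that carry over between CPU and GPU runs. A minimal sketch of normalising each variant's timing to the fastest one; `relative_times` is a hypothetical helper, and the numbers in the usage line are arbitrary placeholders, not measurements from this thread.

```python
def relative_times(timings: dict[str, float]) -> dict[str, float]:
    """Divide each timing by the fastest one, so the fastest variant maps to 1.0."""
    fastest = min(timings.values())
    return {name: t / fastest for name, t in timings.items()}


# Placeholder timings in arbitrary units, for illustration only.
example = {"v0": 3.0, "v1": 1.5, "v2": 1.0, "v3": 1.0}
print(relative_times(example))
```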

Jordan-Dennis commented 1 year ago

I mustn't have set up Colab correctly. I'll look into it a bit more, since I don't normally use it.

Jordan-Dennis commented 1 year ago

OK, I'm hooked: Colab has Vim keybindings.