Closed: Jordan-Dennis closed this issue 1 year ago.
We have GPUs at SMP, or you can try Colab (but it will only work in 32 bit).
— Dr Benjamin Pope (he/him), Lecturer in Astrophysics, University of Queensland, benjaminpope.github.io
From: Jordan Dennis, Thursday, January 5, 2023. Subject: [LouisDesdoigts/dLux] Profiling on GPU (Issue #179)
> Hi all, as I strive to make my forward models faster I am finding results that are counter-intuitive, for example a vectorised version running slower than a non-vectorised version. I am starting to wonder if this is a CPU thing, so I would like to do some testing on GPU. Hit me up if you have a free GPU to run some code. Regards, Jordan.
Colab is probably the easiest, but comparing performance across different hardware is kind of apples to oranges...
I'm a little suspicious that a vectorised version is running slower though, where is this in the code?
Here are two implementations:

```python
import functools as ft

import jax


@ft.partial(jax.jit, inline=True)
def soft_square_aperture(width: float, ccoords: float) -> float:
    pixel_scale: float = get_pixel_scale(ccoords)
    acoords: float = jax.lax.abs(ccoords)
    x: float = jax.lax.index_in_dim(acoords, 0)
    y: float = jax.lax.index_in_dim(acoords, 1)
    square: float = ((x < width) & (y < width)).astype(float)
    edges: float = ((x < (width + pixel_scale)) & (y < (width + pixel_scale))).astype(float)
    return ((square + edges) / 2.).squeeze()


@ft.partial(jax.jit, inline=True)
def _soft_square_aperture(width: float, ccoords: float) -> float:
    pixel_scale: float = get_pixel_scale(ccoords)
    acoords: float = jax.lax.abs(ccoords)
    square: float = (acoords < width).prod(axis=0).astype(float)
    edges: float = (acoords < (width + pixel_scale)).prod(axis=0).astype(float)
    return (square + edges) / 2.
```
Running this on my machine, the vectorised version is slower by nearly a factor of ten (timing screenshots omitted). Similar results on Colab.
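Since `get_pixel_scale` isn't shown in the thread, here is a self-contained sanity check that the two implementations at least agree numerically, using a hypothetical `get_pixel_scale` (an assumption: it just reads the grid spacing off the coordinate array) and an assumed `(2, N, N)` coordinate grid:

```python
import functools as ft

import jax
import jax.numpy as jnp


def get_pixel_scale(ccoords):
    # Hypothetical stand-in for the helper used above: the grid spacing,
    # read off the first coordinate plane.
    return ccoords[0, 0, 1] - ccoords[0, 0, 0]


@ft.partial(jax.jit, inline=True)
def soft_square_aperture(width, ccoords):
    # Indexed version from the thread.
    pixel_scale = get_pixel_scale(ccoords)
    acoords = jax.lax.abs(ccoords)
    x = jax.lax.index_in_dim(acoords, 0)
    y = jax.lax.index_in_dim(acoords, 1)
    square = ((x < width) & (y < width)).astype(float)
    edges = ((x < (width + pixel_scale)) & (y < (width + pixel_scale))).astype(float)
    return ((square + edges) / 2.).squeeze()


@ft.partial(jax.jit, inline=True)
def _soft_square_aperture(width, ccoords):
    # Vectorised version from the thread.
    pixel_scale = get_pixel_scale(ccoords)
    acoords = jax.lax.abs(ccoords)
    square = (acoords < width).prod(axis=0).astype(float)
    edges = (acoords < (width + pixel_scale)).prod(axis=0).astype(float)
    return (square + edges) / 2.


# A (2, N, N) coordinate grid like the ccoords used in the timings.
xs = jnp.linspace(-1., 1., 64)
ccoords = jnp.stack(jnp.meshgrid(xs, xs, indexing="xy"))

same = jnp.allclose(soft_square_aperture(0.5, ccoords),
                    _soft_square_aperture(0.5, ccoords))
```

If the two disagree on your grid, the speed comparison isn't meaningful in the first place.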
The vectorisation thing seems to be really prevalent. Here is another example with similar results:

```python
import jax
import jax.numpy as np  # alias assumed, since np.hypot below is jax.numpy.hypot


@jax.jit
def hypotenuse_v0(coords: float) -> float:
    return jax.lax.sqrt(jax.lax.integer_pow(coords, 2).sum(axis=0))


@jax.jit
def hypotenuse_v1(ccoords: float) -> float:
    x: float = ccoords[0]
    y: float = ccoords[1]
    return np.hypot(x, y)


@jax.jit
def hypotenuse_v2(ccoords: float) -> float:
    x: float = ccoords[0]
    y: float = ccoords[1]
    x_sq: float = jax.lax.integer_pow(x, 2)
    y_sq: float = jax.lax.integer_pow(y, 2)
    return jax.lax.sqrt(x_sq + y_sq)


@jax.jit
def hypotenuse_v3(ccoords: float) -> float:
    x: float = jax.lax.index_in_dim(ccoords, 0)
    y: float = jax.lax.index_in_dim(ccoords, 1)
    x_sq: float = jax.lax.integer_pow(x, 2)
    y_sq: float = jax.lax.integer_pow(y, 2)
    return jax.lax.sqrt(x_sq + y_sq)
```

Timing each in its own cell:

```python
%%timeit
hypotenuse_v0(ccoords)
```

```python
%%timeit
hypotenuse_v1(ccoords)
```

```python
%%timeit
hypotenuse_v2(ccoords)
```

```python
%%timeit
hypotenuse_v3(ccoords)
```
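Before comparing speeds, it's worth confirming the variants actually compute the same thing. A minimal check, assuming a `(2, N, N)` coordinate grid like the `ccoords` timed above (the functions are left unjitted here, which doesn't change the values):

```python
import jax
import jax.numpy as jnp

# Assumed (2, N, N) coordinate grid, similar to the ccoords timed above.
xs = jnp.linspace(-1., 1., 128)
ccoords = jnp.stack(jnp.meshgrid(xs, xs, indexing="xy"))


def hypotenuse_v0(coords):
    # Reduce over the leading (x, y) axis.
    return jax.lax.sqrt(jax.lax.integer_pow(coords, 2).sum(axis=0))


def hypotenuse_v1(ccoords):
    # The builtin: handles edge cases like zeros internally.
    return jnp.hypot(ccoords[0], ccoords[1])


agree = jnp.allclose(hypotenuse_v0(ccoords), hypotenuse_v1(ccoords))
```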
The builtin `np.hypot`, used as a benchmark, showed similar performance to `hypotenuse_v2` and `hypotenuse_v3`. The builtin version actually checks for zeros to avoid calling `integer_pow` on them. This is all really confusing me, and I am thinking that I might open a `jax` issue since it is very counter-intuitive.
Have you been using `.block_until_ready()` when profiling?
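For reference, a minimal sketch of the pattern `.block_until_ready()` guards against: jitted calls return as soon as the work is dispatched, so timing without it can measure dispatch rather than compute. The matmul here is just a stand-in workload:

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    return (x @ x).sum()


x = jnp.ones((1000, 1000))
f(x).block_until_ready()  # warm-up call, so compilation isn't included in the timing

start = time.perf_counter()
result = f(x).block_until_ready()  # wait for the result, not just the dispatch
elapsed = time.perf_counter() - start
```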
Yes mostly. I often forget it the first time though.
Here are my times (screenshot omitted). You can see it is a significant difference. What's more, the `jaxpr` for `hypotenuse_v0` looks really nice at just three lines, many fewer than the others.
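For anyone wanting to reproduce the comparison, the jaxpr can be inspected with `jax.make_jaxpr`. One caveat: call it on the unjitted function, since the jitted wrapper collapses everything into a single `pjit` equation. The grid shape here is just an assumption:

```python
import jax
import jax.numpy as jnp


def hypotenuse_v0(coords):
    return jax.lax.sqrt(jax.lax.integer_pow(coords, 2).sum(axis=0))


ccoords = jnp.ones((2, 16, 16))
jaxpr = jax.make_jaxpr(hypotenuse_v0)(ccoords)
print(jaxpr)  # three equations: integer_pow, reduce_sum, sqrt
```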
So @jakevdp ran my code on a GPU and all the versions had very similar performance. He implied that the XLA compiler has seen a lot more optimisation work on GPU than on CPU. @benjaminpope is there any chance I can use some of the SMP GPUs (or just send the code to you)?
Honestly, Colab would be the quickest and easiest.
Last time I tried Colab I got the same results, but that might have been instance-specific.
Similar results on Colab. Colab CPU and Colab GPU timings (screenshots omitted): the relative times appear the same.
I mustn't have set up Colab correctly. I'll look into it a bit more since I don't normally use it.
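One thing worth checking on a fresh Colab instance is which backend JAX actually picked up; if the GPU runtime isn't active, everything silently runs on CPU. A quick check:

```python
import jax

# Which backend is JAX using, and which devices can it see?
print(jax.default_backend())  # "cpu", "gpu", or "tpu"
print(jax.devices())
```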
OK, I'm hooked: Colab has vim keybindings.