jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

io_callback does not work with custom_vjp #23614

Closed GeophyAI closed 2 weeks ago

GeophyAI commented 2 weeks ago

I want to save some data to disk in the forward pass and reload it in the backward pass, but I found that the docs only provide a JVP example. Can io_callback be used with custom_vjp, and if so, how? The following is my implementation, but it does not work. Could anyone help me with this?

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import io_callback

def f(x):
    return x*2+1

def fwd(x, cstep):
    y = f(x)
    io_callback(save2file, x, x, cstep)
    return y, None

def bwd(res, g):
    x, cstep = res
    io_callback(loadfromfile, x, x, cstep)
    return g, None

f = jax.custom_vjp(f)
f.defvjp(fwd, bwd)

def save2file(data, cstep):
    jnp.save(f'tmp/x{cstep}.npy', data)
    return data

def loadfromfile(cstep):
    return jnp.load(f'tmp/x{cstep}.npy')

def step(x, cstep):
    # _f.defvjp(fwd, bwd)
    y = f(x)
    return y, None

x = jnp.array([1., 2.])
nt = 5

# loop over using scan
initial_carry = x
csteps = jnp.arange(10)
final_carry, _ = lax.scan(step, initial_carry, jnp.arange(nt))

def loss(x):
    return jnp.sum(x)
# compute grad
@jax.jit
def cal_grad(x):
    return jax.value_and_grad(loss)(x)

loss, grad = cal_grad(final_carry)
jakevdp commented 2 weeks ago

io_callback has no autodiff rule, and so any time you use it in autodiff, you need to wrap it in custom_vjp or custom_jvp. In your case, you used io_callback in the fwd and bwd rules, not in f, and so when the autodiff transformation hits it, it doesn't know what to do.

If this is what you're hoping to do (use io_callback in the fwd and bwd rule of another function), then both the fwd function and bwd function would also have to be wrapped in custom_jvp or custom_vjp, each with their own custom autodiff rules describing the gradient of the callback functions you're using.

Does that make sense?

mattjj commented 2 weeks ago

In the pasted code, f isn't being run under autodiff. The function we're differentiating, loss, just calls jnp.sum. So there's no way we would call the autodiff rules associated with f.

mattjj commented 2 weeks ago

and I found that only jvp example is provided in the doc

Did you find this tutorial? There are several custom_vjp examples in there, like this one on gradient clipping.

def f(x):
    return x*2+1

def fwd(x, cstep):
    y = f(x)
    io_callback(save2file, x, x, cstep)
    return y, None

fwd and f need to have the same argument signature, but here fwd takes an extra argument cstep.

def bwd(res, g):
    x, cstep = res
    io_callback(loadfromfile, cstep)
    return g, None

The first input of bwd comes from the second output of fwd, but here fwd is just returning a None as its second output.
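A minimal sketch of the custom_vjp contract described above, using the thread's `f(x) = x*2 + 1`: the fwd rule takes exactly the same arguments as `f` and returns `(output, residuals)`, and the bwd rule receives those residuals as its first argument and returns a tuple with one cotangent per primal input. The residual `x` is included only to show how data flows; the derivative of this particular `f` does not need it.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return x * 2 + 1

def f_fwd(x):
    # Same signature as f; returns (primal output, residuals for bwd).
    return f(x), (x,)

def f_bwd(res, g):
    # res is exactly the second output of f_fwd.
    (x,) = res  # unpacked for illustration; not needed since f is linear
    # One cotangent per primal input of f, as a tuple: d(2x + 1)/dx = 2.
    return (2.0 * g,)

f.defvjp(f_fwd, f_bwd)

grad = jax.grad(lambda x: jnp.sum(f(x)))(jnp.array([1.0, 2.0]))
```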

GeophyAI commented 2 weeks ago

Hi, thank you for your reply. I had read the doc you mentioned before; what I mean is that I cannot find any information about how to combine custom_vjp with io_callback. The doc describes jax.experimental.io_callback() as "appropriate for impure functions: e.g. functions which read or write data to disk", but I cannot find an example.

GeophyAI commented 2 weeks ago

Thank you for your reply. Following your comments, I modified my code, but it still does not write to disk as expected. The following is my new implementation. Is there any misunderstanding on my side?

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import io_callback
import os

def f(x, cstep):
    return x * 2 + 1

# Save data to file
def save2file(data, cstep):
    jnp.save(f'tmp/x{cstep}.npy', data)
    return data

# Load data from file
def loadfromfile(cstep):
    return jnp.load(f'tmp/x{cstep}.npy')

@jax.custom_vjp
def fwd(x, cstep):
    y = f(x)
    # io_callback(save2file, x, x, cstep)
    return y

def fwd_fwd(x, cstep):
    y = f(x)
    io_callback(save2file, x, x, cstep)
    return y, (x, cstep)

def fwd_bwd(res, g):
    x, cstep = res
    # x_loaded = loadfromfile(cstep)
    return (g * 2,), None

fwd.defvjp(fwd_fwd, fwd_bwd)

@jax.custom_vjp
def bwd(res, g):
    x, cstep = res
    # x_loaded = loadfromfile(cstep)
    return (g * 2,)

def bwd_fwd(res, g):
    x, cstep = res
    # x_loaded = loadfromfile(cstep)
    return g, (x, cstep)

def bwd_bwd(res, g):
    return None

bwd.defvjp(bwd_fwd, bwd_bwd)

f_vjp = jax.custom_vjp(f)
f_vjp.defvjp(fwd, bwd)

def step(x, cstep):
    y = f_vjp(x, cstep)
    return y, None

x = jnp.array([1., 2.])
nt = 5

initial_carry = x
csteps = jnp.arange(10)
final_carry, _ = lax.scan(step, initial_carry, jnp.arange(nt))

def loss(x):
    return jnp.sum(x)

@jax.jit
def cal_grad(x):
    return jax.value_and_grad(loss)(x)

loss_value, grad = cal_grad(final_carry)
print(f"Loss: {loss_value}, Gradient: {grad}")

assert os.path.exists('tmp/x0.npy')
jakevdp commented 2 weeks ago

Hi, I think I may have been unclear in my suggestion. The problem previously was that you were calling io_callback in fwd and bwd, which do not have custom JVP/VJP rules, and so the autodiff machinery doesn't know how to trace through your callback.

In your update, you created custom autodiff rules for fwd and bwd, but now you are calling io_callback within fwd_fwd, which does not have a custom autodiff rule, and so the autodiff machinery still doesn't know how to trace through your callback.

This may be an XY problem: perhaps we should step back, and you can describe what larger problem you're trying to solve here?

Also, since you mentioned docs: although we don't have documentation of io_callback with custom_jvp, we do have an example for pure_callback: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#example-pure-callback-with-custom-jvp Conceptually, io_callback works more or less the same way.

GeophyAI commented 2 weeks ago

Hi, I misunderstood your suggestions at first. So I need to wrap my callback functions with custom_vjp, i.e. both save2file and loadfromfile need their own fwd and bwd, right?

Actually, my function f needs several inputs to calculate the final result z=f(x,y,z) iteratively, so I use lax.scan here. Instead of returning (x,y,z) from fwd at each step, I want to save (x,y,z) to disk in the forward pass to save GPU memory, then reload them from disk at every step of the backward pass and use jax.vjp to calculate vjp_func and the gradients.

yashk2810 commented 2 weeks ago

Out of curiosity, why are you saving the residuals to disk? What problem are you trying to solve? Maybe we can suggest something else?

GeophyAI commented 2 weeks ago

I am working on full waveform inversion, an inverse problem based on the wave equation. I need to run a time-stepping forward scheme to calculate the final results and compare them with the target. I don't want fwd to return its inputs as residuals for bwd; I want to save them to disk and reload them. I think it's something like checkpointing, but a save-to-disk version.

yashk2810 commented 2 weeks ago

Maybe offloading the residuals to host memory is what you want here. You can achieve that via a different API, which is more efficient than using io_callback.

Take a look at this: https://github.com/google/jax/blob/0daca4646428abbdb728c48f73429460c5456f87/tests/memories_test.py#L1447-L1488 to do this via remat.

You can also manually call device_put on the residuals in your custom_vjp fwd function to offload them to the pinned_host memory space, and do the opposite in your bwd function to reload them.

GeophyAI commented 2 weeks ago

Thank you, @yashk2810. Yes, transferring to the host is another option. Does the following code make sense? Also, how can I tell which device a DynamicJaxprTracer is on?


import jax
import jax.numpy as jnp

def f(x, y):
    z = x**2 + y**2
    return jnp.sum(z)

def fwd(x, y):
    return f(x, y), (jax.device_get(x), jax.device_get(y))

def bwd(res, g):
    _, vjp_fun = jax.vjp(f, *res)
    grads = vjp_fun(g)
    return grads

f_vjp = jax.custom_vjp(f)
f_vjp.defvjp(fwd, bwd)

f_vjp = jax.jit(f_vjp)

x = jnp.array([1.,2.])
y = jnp.array([2.,3.])

def loss1(x, y):
    return (f_vjp(x, y)**2).sum()
def loss2(x, y):
    return (f(x, y)**2).sum()

print('f=', f(x, y))
print('f_vjp=', f_vjp(x, y))
print(jax.grad(loss1)(x, y))
print(jax.grad(loss2)(x, y))
yashk2810 commented 2 weeks ago

No, like this:

import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

def f(x, y):
  z = x**2 + y**2
  return jnp.sum(z)

def fwd(x, y):
  # Offload the residuals (x, y) to the pinned_host memory space.
  return f(x, y), jax.device_put(
      (x, y),
      SingleDeviceSharding(jax.devices()[0], memory_kind='pinned_host'))

def bwd(res, g):
  # Move the residuals back to device memory before using them.
  reloaded_res = jax.device_put(
      res, SingleDeviceSharding(jax.devices()[0], memory_kind='device'))
  _, vjp_fun = jax.vjp(f, *reloaded_res)
  grads = vjp_fun(g)
  return grads

f_vjp = jax.custom_vjp(f)
f_vjp.defvjp(fwd, bwd)

f_vjp = jax.jit(f_vjp)

x = jnp.array([1.,2.])
y = jnp.array([2.,3.])

def loss1(x, y):
  return (f_vjp(x, y)**2).sum()
def loss2(x, y):
  return (f(x, y)**2).sum()

print('f=', f(x, y))
print('f_vjp=', f_vjp(x, y))
print(jax.grad(loss1)(x, y))
print(jax.grad(loss2)(x, y))

This is the output I got:

f= 18.0
f_vjp= 18.0
[ 72. 144.]
[ 72. 144.]

Note that I ran this on TPU. Which hardware are you planning to run on?

GeophyAI commented 2 weeks ago

@yashk2810, I'm coding on a GPU. Thank you very much! I have implemented this in my original code, and it does reduce memory usage, at the cost of slower execution (I have to accept that because of the OOM problem). I'm still new to JAX, and I will try to find a more efficient strategy for my problem.

yashk2810 commented 2 weeks ago

Yeah, this should work on a GPU too. The efficiency might be affected because you incur transfer costs.

Can you tell me how much of a slowdown you see? If you have enough compute, we should be able to hide the transfers behind compute.

GeophyAI commented 2 weeks ago

It's 3.7 s/it vs. 0.7 s/it on an RTX 3090, around 5x slower. I think that's acceptable for me, since without this the code cannot run at all.

yashk2810 commented 2 weeks ago

cc @jaro-sevcik @nouiz can you improve the speed of transfers to pinned_host and back here? I wonder how fast this is on the latest GPUs.

yashk2810 commented 2 weeks ago

I ran your test internally on an H100 GPU and got these timings:

f= 18.0
f_vjp= 18.0
[ 72. 144.]
0.17064547538757324
[ 72. 144.]
0.17866754531860352
GeophyAI commented 2 weeks ago

Oh, sorry, I thought you were referring to my original code; the numbers I provided before come from my original implementation for the inverse problem.

yashk2810 commented 2 weeks ago

the values I provided before are based on my original implementation in inverse problem.

Yeah, I got that. Can you try running the test here to see the timing difference? That way we'll know whether this is a GPU-generation problem or not. In other words, if you run your original code on an H100, what speed difference do you see?

GeophyAI commented 2 weeks ago

For jax.grad(loss1)(x,y), it's 3.02 ms ± 7.83 μs per loop (mean ± std. dev. of 7 runs, 100 loops each), For jax.grad(loss2)(x, y), it's 4.39 ms ± 34 μs per loop (mean ± std. dev. of 7 runs, 100 loops each).
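When comparing timings like these outside IPython's `%timeit`, keep in mind that JAX compiles on the first call and dispatches work asynchronously, so a fair measurement warms up once and then blocks on the result. A minimal sketch using only the standard library; `bench` is a hypothetical helper, and `loss2` mirrors the non-offloaded loss from the test above.

```python
import time

import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sum(x**2 + y**2)

def loss2(x, y):
    return (f(x, y) ** 2).sum()

def bench(fn, *args, n=100):
    """Average wall-clock seconds per call of fn, which returns one array."""
    fn(*args).block_until_ready()  # first call compiles; keep it out of the timing
    start = time.perf_counter()
    for _ in range(n):
        out = fn(*args)
    out.block_until_ready()  # wait for the queued asynchronous work to finish
    return (time.perf_counter() - start) / n

grad_fn = jax.jit(jax.grad(loss2))
x = jnp.array([1.0, 2.0])
y = jnp.array([2.0, 3.0])
print(grad_fn(x, y), f"{bench(grad_fn, x, y) * 1e3:.3f} ms/call")
```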

GeophyAI commented 2 weeks ago

@yashk2810, sorry, I don't have an H100, which is why I'm doing it this way.

yashk2810 commented 2 weeks ago

Yeah, I was just wondering how the GPU generation affected the speed of transfers :)

But I am glad that your problem is resolved. Feel free to close this issue if so :)

GeophyAI commented 2 weeks ago

@yashk2810 Thank you very much. And thanks for everyone here.

qGentry commented 1 week ago

Hi guys, @yashk2810 @GeophyAI, sorry for tagging you on an already closed issue. I'm trying to achieve similar results and was wondering how to easily combine the proposed solution with scan. I have a transformer block, and currently I only save the residual stream between transformer blocks by specifying a remat_policy for the block and then scanning it over the inputs. Instead of keeping these residuals in GPU memory, I want to offload them to the CPU. I tried to replicate your results by implementing an "offload" function, but it looks like it doesn't work for me:

def offload(x):
    return x

def offload_fwd(x):
    return x, jax.device_put(x, sharding.with_memory_kind("pinned_host"))

def offload_bwd(res, g):
    reloaded_res = jax.device_put(res, sharding.with_memory_kind("device"))
    _, vjp_fun = jax.vjp(offload, reloaded_res)
    grads = vjp_fun(g)
    return grads

f_vjp = jax.custom_vjp(offload)
f_vjp.defvjp(offload_fwd, offload_bwd)

def transformer_block(x):
    # transformer calculation: 1 attention + 1 MLP
    # x - block's output
    x = f_vjp(x)
    return x

# embeddings - inputs
rematted_transformer_block = nn.remat(transformer_block, policy=jax.checkpoint_policies.nothing_saveable())
scanned_transformer_block = nn.scan(transformer_block)

outputs = scanned_transformer_block(embeddings)

but it doesn't seem to work: memory consumption hasn't changed, speed is the same, and there are no D2H or H2D copies in the trace. Am I doing something wrong?

yashk2810 commented 1 week ago

Just try using this? https://cs.opensource.google/jax/jax/+/main:tests/memories_test.py;l=1513?q=memories_test.py&ss=jax%2Fjax

A remat policy should be enough I think.

qGentry commented 1 week ago

@yashk2810 Yeah, a remat policy was my first idea, but for some reason it doesn't work either: same results, no decrease in memory usage, nothing in the traces. I've tried something like this:

def transformer_block(x):
    # transformer calculation: 1 attention + 1 MLP
    # x - block's output
    x = jax.ad_checkpoint.checkpoint_name(x, "embeddings")
    return x

rematted_transformer_block = nn.remat(
    transformer_block,
    policy=jax.checkpoint_policies.save_and_offload_only_these_names(
        names_which_can_be_saved=[],
        names_which_can_be_offloaded=["embeddings"],
        offload_src="device",
        offload_dst="pinned_host",
))
scanned_transformer_block = nn.scan(transformer_block)
# embeddings - inputs
outputs = scanned_transformer_block(embeddings)

My guess is that it doesn't work because this remat policy only affects a single transformer block instead of the entire scanned function. I've also tried applying this remat policy to the entire scanned_transformer_block, but that leads to 5-10x memory consumption and OOM, so I was looking for alternative solutions to my problem.

yashk2810 commented 1 week ago

You need to call grad for remat to take effect. Can you try that? Also, the residuals you are offloading should show up in the jaxpr, so you can check for them there.
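A small sketch of the points above, with a toy `block` standing in for a transformer block (hypothetical, not qGentry's model): the checkpoint policy is applied to the per-step function, the rematerialized function is the one passed to `lax.scan`, and the policy only matters once the scan is differentiated. It uses `save_only_these_names` so it runs on any backend; `save_and_offload_only_these_names`, as in the snippet above, would go in the same place on hardware that exposes a `pinned_host` memory space. `jax.ad_checkpoint.print_saved_residuals` can be used to list what is actually saved.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def block(x, _):
    # Toy stand-in for one transformer block (attention + MLP).
    out = jnp.sin(jnp.tanh(x) * 1.5)
    # Tag the block output so a name-based policy can refer to it.
    out = checkpoint_name(out, "embeddings")
    return out, None

# Keep only the tagged values as residuals; recompute the rest in bwd.
policy = jax.checkpoint_policies.save_only_these_names("embeddings")
block_remat = jax.checkpoint(block, policy=policy)

def loss_remat(x):
    # Scan the rematerialized block, not the plain one.
    y, _ = jax.lax.scan(block_remat, x, None, length=4)
    return jnp.sum(y)

def loss_plain(x):
    y, _ = jax.lax.scan(block, x, None, length=4)
    return jnp.sum(y)

x = jnp.linspace(-1.0, 1.0, 8)
# The policy changes what is stored, not the gradient values.
g_remat = jax.grad(loss_remat)(x)
g_plain = jax.grad(loss_plain)(x)
```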

qGentry commented 1 week ago

Yeah, the provided snippet lacks the loss and grad, but they are indeed called in my code. I'll try to write a small repro script from scratch tomorrow to see if I can reproduce this without sharing my entire codebase. Thanks!

qGentry commented 1 week ago

@yashk2810 Hi, I've implemented and tested a simple repro, got the same results with JAX's remat, and thought it would be better to open another issue: https://github.com/jax-ml/jax/issues/23869