sbodenstein opened this issue 6 days ago
@kaixih
I just created a PR to fix/work around this long-standing cuDNN limitation with dbias.
cuDNN only supports dbias when the bias batch size is 1 (here). For batch sizes greater than 1, the cuDNN SDPA API returns an all-zero dbias tensor (which isn't ideal, but that's the current behavior). When using vmap, the API only sees the singleton batch dimension before vmap is applied, so it mistakenly sets has_dbias to True, which leads to this segfault.
This PR resolves the issue by resetting has_dbias to False and returning an all-zero dbias tensor, as in the non-vmap case.
To summarize, the behavior is:
no vmap:
bias (1, ...) => has_dbias=True => OK
bias (B, ...) => has_dbias=False => OK, but dbias is all-zero
vmap:
bias (1, ...) => has_dbias=True => OK
bias (B, ...) => has_dbias=True => Segfault # which is fixed to be
bias (B, ...) => has_dbias=False => OK, but dbias is all-zero
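To make the flag logic concrete, here is a tiny illustrative sketch (hypothetical names, not the actual `jax._src.cudnn` code): the point is that `has_dbias` must be derived from the bias's post-vmap batch size rather than from the shape seen before vmap.

```python
# Illustrative only -- not the real jax._src.cudnn implementation.
def has_dbias(post_vmap_bias_batch: int) -> bool:
    # cuDNN computes dbias only when the bias batch size is 1; for any
    # larger batch the backward pass returns an all-zero dbias instead.
    return post_vmap_bias_batch == 1

# Pre-fix bug: a (1, ...) bias checked *before* vmap reported batch 1
# and set has_dbias=True, even though vmap stacked it to (B, 1, ...),
# which segfaulted. The fix re-derives the flag after batching.
assert has_dbias(1) is True    # dbias is computed
assert has_dbias(8) is False   # OK, but dbias is all-zero
```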
Also, cc @Cjkkkk for comments on the dbias support in cuDNN.
Having d_bias be zeroes when there is a batch dimension is definitely wrong behaviour: it's a silent failure that would be extremely difficult for users to debug when their training curves just don't look right. I think we should fail in this case instead, until cuDNN supports it.
We actually had some internal discussions about this earlier. The dilemma is that some models don't require d_bias but still want to benefit from cuDNN flash attention. If we simply throw an error whenever the bias has a batch size other than 1, those users might complain. Ideally, if we could detect whether d_bias is actually needed, we could decide whether to error out or proceed; however, no such mechanism seems to exist in JAX. Instead, we currently return a silent all-zero dbias when it isn't supported (which is very bad, I know...). Do you think issuing a warning would help? Or should we just throw an error whenever cuDNN is used with a bias whose batch size is larger than 1?
I think this should either work correctly or throw an error. It should definitely not segfault or give incorrect gradients (even with a warning). The latter is just too dangerous for users, who expect JAX APIs to do what they think they will do, and who would otherwise waste massive compute on runs with a major bug. Would you agree @hawkinsp?
I do. Wrong outputs aren't ok, because they are the kind of thing that makes people lose trust in a library.
Sure, I will essentially move this logic to the public API to throw an error for the cudnn exec path.
By the way, do you think we should apply this error-throwing behavior to the public API or the cuDNN API? Perhaps it should only be applied to the public API, allowing power users who are certain they don't have d_bias to use the private cuDNN API.
Private APIs (jax._src, etc.) can do whatever you like. If a private API breaks, you get to keep both pieces.
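For concreteness, the hoisted check might look roughly like this (a minimal sketch; the abbreviated signature and the error message are assumptions, not the merged code):

```python
def dot_product_attention(query, key, value, bias=None, *,
                          implementation=None, **kwargs):
    # Proposed public-API check: fail loudly instead of silently
    # returning an all-zero dbias on the cuDNN path.
    if (implementation == "cudnn" and bias is not None
            and bias.shape[0] != 1):
        raise NotImplementedError(
            "cuDNN SDPA does not support dbias when the bias batch size "
            "is greater than 1; use a batch-1 bias or implementation='xla'.")
    ...  # dispatch to the chosen implementation
```

Power users who are certain they don't need d_bias could then still call the private cuDNN entry point directly, which keeps the current all-zero behavior.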
Description
Run
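(The original snippet isn't shown above; the following is a plausible minimal repro based on the discussion, with shapes, dtypes, and the vmap/grad composition as assumptions.)

```python
import jax
import jax.numpy as jnp

V, T, N, H = 2, 128, 4, 64  # vmap axis, seq len, heads, head dim

q = jnp.zeros((V, 1, T, N, H), jnp.bfloat16)
k = jnp.zeros((V, 1, T, N, H), jnp.bfloat16)
v = jnp.zeros((V, 1, T, N, H), jnp.bfloat16)
bias = jnp.zeros((V, 1, N, T, T), jnp.bfloat16)

def loss(q, k, v, b):
    out = jax.nn.dot_product_attention(q, k, v, b, implementation="cudnn")
    return out.astype(jnp.float32).sum()

# Each per-example bias has batch size 1, so has_dbias is set to True,
# but the stacked bias has batch size V > 1.
jax.vmap(jax.grad(loss, argnums=3))(q, k, v, bias)
```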
produces a segmentation fault.
System info (python version, jaxlib version, accelerator, etc.)