nshepperd / flash_attn_jax

JAX bindings for Flash Attention v2
BSD 3-Clause "New" or "Revised" License

Encountered edge_padding_low error while testing #7

Closed. sh0416 closed this 1 month ago

sh0416 commented 1 month ago
tests/test_flash.py:200: in func
    o, bwd = jax.vjp(fwd,q,k,v)
../anaconda3/envs/flash-jax/lib/python3.9/site-packages/jax/_src/api.py:2169: in vjp
    return _vjp(
../anaconda3/envs/flash-jax/lib/python3.9/site-packages/jax/_src/api.py:2178: in _vjp
    out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
../anaconda3/envs/flash-jax/lib/python3.9/site-packages/jax/_src/interpreters/ad.py:143: in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
../anaconda3/envs/flash-jax/lib/python3.9/site-packages/jax/_src/interpreters/ad.py:132: in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
../anaconda3/envs/flash-jax/lib/python3.9/site-packages/jax/_src/profiler.py:335: in wrapper
    return func(*args, **kwargs)
../anaconda3/envs/flash-jax/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py:774: in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
../anaconda3/envs/flash-jax/lib/python3.9/site-packages/jax/_src/linear_util.py:192: in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
tests/test_flash.py:199: in fwd
    return mha(q,k,v, is_causal=bool(causal), window_size=window_size)
src/flash_attn_jax/flash.py:237: in flash_mha
    o = _flash_mha_vjp(q,k,v,dict(softmax_scale=softmax_scale, is_causal=is_causal, window_size=window_size))
src/flash_attn_jax/flash.py:213: in fwd
    out, lse = _flash_mha_fwd(q,k,v, **config)
src/flash_attn_jax/flash.py:53: in _flash_mha_fwd
    return tuple(_flash_mha_fwd_p.bind(q,k,v, softmax_scale=softmax_scale, is_causal=is_causal, window_size=window_size))
src/flash_attn_jax/flash.py:135: in mha_fwd_batch
    out, lse = _flash_mha_fwd_p.bind(q, k, v, **kwargs)
../anaconda3/envs/flash-jax/lib/python3.9/site-packages/jax/experimental/custom_partitioning.py:500: in _custom_partitioning_lowering_rule
    return mlir.lower_fun(
src/flash_attn_jax/flash_hlo.py:111: in _flash_mha_fwd_hlo_lowering
    q_padded = mlir.hlo.PadOp(q,z,[0,0,0,0],[0,0,0,dpad],[0,0,0,0]).result
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <jaxlib.mlir.dialects._stablehlo_ops_gen.PadOp object at 0x7fbb843c3a40>, operand = <jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x7fbb40581e30>
padding_value = <jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fbb483102b0>, edge_padding_low = [0, 0, 0, 0], edge_padding_high = [0, 0, 0, 5]
interior_padding = [0, 0, 0, 0]

    def __init__(self, operand, padding_value, edge_padding_low, edge_padding_high, interior_padding, *, loc=None, ip=None):
      operands = []
      results = []
      attributes = {}
      regions = None
      operands.append(_get_op_result_or_value(operand))
      operands.append(_get_op_result_or_value(padding_value))
      _ods_context = _ods_get_default_loc_context(loc)
      attributes["edge_padding_low"] = (edge_padding_low if (
      isinstance(edge_padding_low, _ods_ir.Attribute) or
      not _ods_ir.AttrBuilder.contains('GenericDenseI64ArrayAttr')) else
        _ods_ir.AttrBuilder.get('GenericDenseI64ArrayAttr')(edge_padding_low, context=_ods_context))
      attributes["edge_padding_high"] = (edge_padding_high if (
      isinstance(edge_padding_high, _ods_ir.Attribute) or
      not _ods_ir.AttrBuilder.contains('GenericDenseI64ArrayAttr')) else
        _ods_ir.AttrBuilder.get('GenericDenseI64ArrayAttr')(edge_padding_high, context=_ods_context))
      attributes["interior_padding"] = (interior_padding if (
      isinstance(interior_padding, _ods_ir.Attribute) or
      not _ods_ir.AttrBuilder.contains('GenericDenseI64ArrayAttr')) else
        _ods_ir.AttrBuilder.get('GenericDenseI64ArrayAttr')(interior_padding, context=_ods_context))
      _ods_successors = None
>     super().__init__(self.build_generic(attributes=attributes, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip))
E     RuntimeError: Invalid attribute value for the key "edge_padding_low" when attempting to create the operation "stablehlo.pad" (Unable to cast Python instance to C++ type (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details))

../anaconda3/envs/flash-jax/lib/python3.9/site-packages/jaxlib/mlir/dialects/_stablehlo_ops_gen.py:3466: RuntimeError

I got this error while testing with pytest. Is there any simple solution to resolve this error?
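
For context on where this blows up: the lowering rule in flash_hlo.py passes plain Python lists as the pad amounts to the stablehlo PadOp, and on this jaxlib the generated builder forwards them to build_generic unconverted, which is what produces the "Invalid attribute value" RuntimeError. Below is a minimal sketch of one possible workaround, building the padding attributes as DenseI64ArrayAttr up front; the helper name is made up for illustration and this is not necessarily the fix that was later committed:

    from jaxlib.mlir import ir
    from jaxlib.mlir.dialects import stablehlo as hlo

    def pad_head_dim(q, z, dpad):
        # Hypothetical replacement for the call at flash_hlo.py:111.
        # Passing Python lists directly, as in
        #   hlo.PadOp(q, z, [0,0,0,0], [0,0,0,dpad], [0,0,0,0])
        # trips the attribute cast on some jaxlib versions, so build the
        # DenseI64ArrayAttr values explicitly (this needs an active MLIR
        # context, which a JAX lowering rule already provides).
        low      = ir.DenseI64ArrayAttr.get([0, 0, 0, 0])
        high     = ir.DenseI64ArrayAttr.get([0, 0, 0, dpad])
        interior = ir.DenseI64ArrayAttr.get([0, 0, 0, 0])
        return hlo.PadOp(q, z, low, high, interior).result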

sh0416 commented 1 month ago

I installed from source, FYI.

nshepperd commented 1 month ago

Huhh. What jax and jaxlib versions?
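
(For reference, a quick way to collect this on a reasonably recent JAX release is:

    import jax
    jax.print_environment_info()

which prints the jax, jaxlib, numpy, and Python versions plus the visible devices.)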

sh0416 commented 1 month ago
jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.9.19 (main, May  6 2024, 19:43:03)  [GCC 11.2.0]
jax.devices (2 total, 2 local): [cuda(id=0) cuda(id=1)]
process_count: 1
platform: uname_result(system='Linux', node='gpusystem', release='5.15.0-100-generic', version='#110-Ubuntu SMP Wed Feb 7 13:27:48 UTC 2024', machine='x86_64')

This one... it seems the problem is in jax, not in your code.

nshepperd commented 1 month ago

Okie, I'll do some testing, should be easy to fix.

nshepperd commented 1 month ago

Tested this on 0.4.28; it should work now.
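
A quick way to confirm the fix on an affected setup is to exercise the padding path directly, i.e. use a head dimension that is not a multiple of 8. The sketch below assumes the flash_mha signature shown in the traceback, an [n, l, h, d] input layout, an available CUDA device, and a fresh install from source; the shapes and dtype are illustrative only:

    import jax
    import jax.numpy as jnp
    from flash_attn_jax import flash_mha

    key = jax.random.PRNGKey(0)
    # d=59 is not a multiple of 8, so the lowering has to pad the head dim.
    q = jax.random.normal(key, (2, 128, 4, 59), dtype=jnp.float16)
    k = jax.random.normal(key, (2, 128, 4, 59), dtype=jnp.float16)
    v = jax.random.normal(key, (2, 128, 4, 59), dtype=jnp.float16)

    o, vjp = jax.vjp(lambda q, k, v: flash_mha(q, k, v, is_causal=True), q, k, v)
    dq, dk, dv = vjp(jnp.ones_like(o))
    print(o.shape, dq.shape)  # previously raised the stablehlo.pad RuntimeError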