google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
29.76k stars 2.72k forks source link

Unexpected AllReduce in backward pass with shard_map, custom_vjp, and pallas #21855

Open southfreebird opened 2 months ago

southfreebird commented 2 months ago

Description

Hello team, I have found unexpected behavior when using shard_map, custom_vjp, and Pallas together under autodiff (a Triton kernel behaves the same way). Please take a look at the code snippet:

import jax
import jax.numpy as jnp

import flax.linen as nn

from jax.experimental.shard_map import shard_map
from jax.experimental import mesh_utils, pallas as pl
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def _copy_kernel_pallas(operand_a_ref, operand_b_ref):
    operand_b_ref[...] = operand_a_ref[...]

@jax.custom_vjp
def simple_func(inputs):
    """Copy `inputs` through the `_copy_kernel_pallas` Pallas kernel.

    The output has the same shape and dtype as `inputs`. A custom VJP is
    attached to this function via `defvjp`.
    """
    result_struct = jax.ShapeDtypeStruct(inputs.shape, inputs.dtype)
    copy_op = pl.pallas_call(_copy_kernel_pallas, out_shape=result_struct)
    return copy_op(inputs)

def _simple_func_fwd(x):
    # Forward rule: compute the primal output; `x` is stored as the residual
    # (the backward rule below does not read it).
    return simple_func(x), x


def _simple_func_bwd(_residual, out_grad):
    # The kernel is a pure copy, so the cotangent is passed through the same
    # copy; custom_vjp expects a tuple with one entry per primal input.
    return (simple_func(out_grad),)


simple_func.defvjp(_simple_func_fwd, _simple_func_bwd)

if __name__ == "__main__":
    # Reproducer: differentiate a shard_map'd custom_vjp/Pallas copy and print
    # the compiled HLO so the inserted collectives can be inspected.
    rng = jax.random.PRNGKey(0)

    batch_size = 4
    embedding = 128
    mp = 2  # size of the "model" mesh axis

    inputs = jax.random.normal(
        rng, (batch_size, embedding), dtype=jnp.bfloat16
    )

    num_devices = len(jax.devices())
    # Fail early with a clear message instead of an opaque reshape error below.
    if num_devices % mp != 0:
        raise ValueError(
            f"number of devices ({num_devices}) must be divisible by mp ({mp})"
        )
    # mesh_shape must be a sequence of ints; `(num_devices)` is just an int.
    device_mesh = mesh_utils.create_device_mesh(mesh_shape=(num_devices,))

    mesh = Mesh(
        device_mesh.reshape(num_devices // mp, mp),
        axis_names=("data", "model"),
    )
    # Logical-axis rules: "batch" -> "data" mesh axis, "hidden" -> "model".
    rules = (("batch", "data"), ("hidden", "model"))

    # Shard dim 0 over "data"; dim 1 is not sharded, so each shard is
    # replicated across the "model" axis.
    inputs = jax.device_put(
        inputs,
        NamedSharding(
            mesh,
            nn.logical_to_mesh_axes(("batch", None), rules),
        ),
    )

    func = shard_map(
        simple_func,
        in_specs=P("data", None),
        out_specs=P("data", None),
        mesh=mesh,
        check_rep=False,
    )

    def grad_func(*args, **kwargs):
        # Scalar loss so jax.grad can be applied.
        return func(*args, **kwargs).sum()

    # The name is reused on purpose: the jitted gradient keeps the function
    # name, so the compiled module is reported as `jit_grad_func`.
    grad_func = jax.jit(jax.grad(grad_func))

    out = grad_func(inputs)
    out.block_until_ready()

    compiled_model = grad_func.lower(inputs).compile()
    print(compiled_model.as_text())

Output:

HloModule jit_grad_func, is_scheduled=true, entry_computation_layout={()->bf16[1,128]{1,0}}, allow_spmd_sharding_propagation_to_output={true}, num_partitions=8, frontend_attributes={fingerprint_before_lhs="3856bd19cf20f377dd2d1fa37620e4f3"}

%fused_broadcast () -> bf16[1,128] {
  %constant_4 = bf16[] constant(0.5)
  ROOT %broadcast.3.1 = bf16[1,128]{1,0} broadcast(bf16[] %constant_4), dimensions={}, metadata={op_name="jit(grad_func)/jit(main)/div" source_file="/test_shard.py" source_line=66}
}

%region_0.8 (Arg_0.9: bf16[], Arg_1.10: bf16[]) -> bf16[] {
  %Arg_1.10 = bf16[] parameter(1)
  %Arg_0.9 = bf16[] parameter(0)
  ROOT %add.1 = bf16[] add(bf16[] %Arg_0.9, bf16[] %Arg_1.10), metadata={op_name="jit(grad_func)/jit(main)/transpose(jvp(jit(shmap_body)))/add" source_file="/test_shard.py" source_line=66}
}

ENTRY %main.19_spmd () -> bf16[1,128] {
  %loop_broadcast_fusion = bf16[1,128]{1,0} fusion(), kind=kLoop, calls=%fused_broadcast, metadata={op_name="jit(grad_func)/jit(main)/div" source_file="/test_shard.py" source_line=66}
  %custom-call.2.0 = bf16[1,128]{1,0} custom-call(bf16[1,128]{1,0} %loop_broadcast_fusion), custom_call_target="__gpu$xla.gpu.triton", operand_layout_constraints={bf16[1,128]{1,0}}, api_version=API_VERSION_TYPED_FFI, metadata={op_name="jit(grad_func)/jit(main)/transpose(jvp(jit(shmap_body)))/jit(wrapped)/pallas_call[name=_copy_kernel_pallas which_linear=(False,) in_shapes=(ShapeDtypeStruct(shape=(1, 128), dtype=bfloat16),) out_shapes=(ShapeDtypeStruct(shape=(1, 128), dtype=bfloat16),) debug=False interpret=False grid_mapping=GridMapping(grid=(), block_mappings=(None, None), mapped_dims=(), num_index_operands=0, num_scratch_operands=0) input_output_aliases=() compiler_params={}]" source_file="/test_shard.py" source_line=66}, backend_config={debug = false, grid_x = 1 : i32, grid_y = 1 : i32, grid_z = 1 : i32, ir = "ML\EFR\0DMLIR19.0.0git\00\01/\07\01\05\09\19\01\03\0F\03\13\13\17\1B\1F#'+/3\05\057;\03uK\17\01I\0F\0F\0F\0F\07\0B\13\13\13\0B\0F\0F\13\0B\0F\0B\0B\0B\0F\13\0F\0F\0B\13\0F\0F\0B\13\0F\0F\0B\13\0F\0B\13\0B\05\03Y\01\15\0F\1B\1B\13\17\17\1B\17\07\0B\03\039\02\DE\02\1DG\15\1D#\15\11\01\05\11\01\01\1F\05\1F\11\01\02\04#\01\01\01\03\03\1B\1D\05!\15%)\11\13\00\01\05\11\11\05#\11\01\81\0D\0F\05%\05'\1D\13'\17\0B\19\01\15+1\1D-/\05)\17\0B#\01\1539\1D57\05+\17\0B1\01\15;A\1D=?\05-\17\0B\85\01\1DCE\05/\17\0B\8D\01\051#arith.overflow<none>\00\01\02\02\1B\05\05\02\04\01\1B\05\05\02\04\15\1B\03\05\01\1B\05\05\05\01\1B\03\02\04\01\1B\05\05\02\04\11\05\05\15\15\01\07\01\09!tt.ptr<bf16>\00\04\06\04\05\01P\09\01\07\04\E2\03\03\01\05\0FP\09\03\07\04\B6\03\03A\83\05))\00\05B\03\05\03\07\07F\03\07\03\09\03\05\0B\06\03\03\03\03\07\15B\03\09\03\01\03\06\03\03\03\03\0B\17F\03\0B\03\03\05\09\0D\05B\03\0D\03\0B\07F\03\0F\03\03\03\11\15B\03\07\03\01\03\06\03\03\03\03\15\17F\03\0B\03\03\05\13\17\03\06\03\03\05\03\01\09\06\03\03\05\05\1B\0F\09\06\03\03\05\05\1D\19\0DF\03\11\03\0D\03\1F\05B\01\05\03\07\07F\01\07\03\09\03#\0B\06\01\03\03\03%\15B\01\09\03\01\03\06\01\03\03\03)\17F\01\0B\
03\03\05'+\05B\01\0D\03\0B\07F\01\0F\03\03\03/\15B\01\07\03\01\03\06\01\03\03\033\17F\01\0B\03\03\0515\03\06\01\03\05\03\03\09\06\01\03\05\059-\09\06\01\03\05\05;7\0DF\01\11\03\0D\03=\11D\01\13\05=!\13\00\09\06\03\01\05\01\00\1A\093\AE\02\13\15\13\19\AA\02\0F!)/\0B\13\0F\0D\0B\0B\15\0F\19\17\0D\0F\0D\07\11builtin\00tt\00arith\00module\00splat\00make_range\00expand_dims\00addptr\00broadcast\00load\00func\00store\00return\00constant\00muli\00/test_shard.py\00_copy_kernel_pallas\00tt.divisibility\00public\00/get[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((CustomNode(Slice[(0, 1, 1)], [None, None]), CustomNode(Slice[(0, 128, 1)], [None, None]))), (1, 128), ())], []),))]\00simple_func\00<lambda>\00grad_func\00<module>\00/swap[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((CustomNode(Slice[(0, 1, 1)], [None, None]), CustomNode(Slice[(0, 128, 1)], [None, None]))), (1, 128), ())], []),))]\00\08K\15\05\01\01\0B3\1F\01\13C\05\05\07\03\05\03\0D\03\93\05\0D\07\03\07\11\1F\0B\0B/\01\07\01\03\07\1F\0B\0B", name = "_copy_kernel_pallas", num_stages = 3 : i32, num_warps = 4 : i32}
  %all-reduce-start = bf16[1,128]{1,0} all-reduce-start(bf16[1,128]{1,0} %custom-call.2.0), channel_id=1, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=%region_0.8, metadata={op_name="jit(grad_func)/jit(main)/transpose(jvp(jit(shmap_body)))/psum[axes=(\'model\',) axis_index_groups=None]" source_file="/test_shard.py" source_line=66}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"collective_backend_config":{"is_sync":true,"no_parallel_custom_call":false},"force_earliest_schedule":false}
  ROOT %all-reduce-done = bf16[1,128]{1,0} all-reduce-done(bf16[1,128]{1,0} %all-reduce-start), metadata={op_name="jit(grad_func)/jit(main)/transpose(jvp(jit(shmap_body)))/psum[axes=(\'model\',) axis_index_groups=None]" source_file="/test_shard.py" source_line=66}
}

As you can see, the backward pass contains an unexpected AllReduce. It appears when my output is replicated across 2 GPUs: the output sharding does not include the logical "hidden" axis, which maps to the "model" mesh axis, so the output is not split along "model". As I understand it, this AllReduce synchronizes data along the "model" mesh axis between the 2 GPUs. In my particular case I don't need this synchronization and want to get rid of it, but I can't find a good way to do so. The AllReduce does not appear when the program contains no custom_call. Could you help me with that?

Thank you!

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.28 jaxlib: 0.4.28 numpy: 1.24.3 python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)] process_count: 1 platform: uname_result(system='Linux', release='5.4.0-155-generic', version='#172-Ubuntu SMP Fri Jul 7 16:10:02 UTC 2023', machine='x86_64')

southfreebird commented 2 months ago

I've found the documentation about unmapped inputs and outputs, and it seems like it addresses the problem I'm fighting against. However, I can't figure out if there is a way to make it work with Pallas.

mattjj commented 1 month ago

Sorry for the slow response!

One workaround is not to use unmapped inputs or outputs in shard_map: instead, do any broadcasts or sums you need outside the shard_map, and rely on XLA to generate the right code automatically (i.e. turning explicit broadcasting or jnp.tile calls on inputs into no-ops, and turning jnp.sums to reduction collectives).

Another workaround is not to use autodiff at all, and to use a custom_vjp around the shard_map.

I find both of those workarounds pretty unsatisfactory!

A better solution would be to, as you describe, make the pallas_call work with shard_map's check_rep=True. Logically that means telling shard_map's internals how your particular pallas_call affects cross-device replication; then, once shard_map can track cross-device replication through it, we can set check_rep=True and thus avoid defensive allreduces (as described in the doc you linked). (We'll also avoid the error "No replication rule for pallas_call", which I mention here so that this issue is more discoverable.) Unfortunately we don't yet have public APIs for registering that information.

Until we have a public API for this, here's a sketch of how to do it using JAX-internal APIs, which may change without warning:

MeshAxisName = str

from jax._src.pallas.pallas_call import pallas_call_p
from jax.experimental.shard_map import register_norewrite, register_check
# `Mesh` is evaluated in the annotation below at definition time, so it must
# be in scope even when this snippet is pasted into a file that lacks it.
from jax.sharding import Mesh

@register_check(pallas_call_p)
def pallas_check(
    mesh: Mesh, *in_rep: set[MeshAxisName], name: str, **params
) -> list[set[MeshAxisName]]:
  """Replication-check rule for pallas_call under shard_map(check_rep=True).

  Tells shard_map how the kernel's outputs are replicated given its inputs'
  replication. Only `_copy_kernel_pallas` is supported: it copies its single
  input, so the single output keeps the input's replication unchanged.

  Args:
    mesh: the shard_map mesh (unused).
    *in_rep: one replication set per operand (presumably the mesh axes the
      operand is replicated over).
    name: the pallas_call kernel name.
    **params: remaining pallas_call params (ignored).

  Returns:
    A list with one replication set per output (pallas_call yields a
    sequence of outputs).

  Raises:
    NotImplementedError: for any kernel other than `_copy_kernel_pallas`.
  """
  del mesh, params  # not needed for this rule
  if name != '_copy_kernel_pallas':
    raise NotImplementedError(name)
  inputs_rep, = in_rep  # single input in this case
  return [inputs_rep]  # no change in replication, single output
register_norewrite(pallas_call_p)

I inlined it into your file and also set check_rep=True on the shard_map call (check_rep is a shard_map argument, not a pallas_call one), then got this output:

HloModule jit_grad_func, is_scheduled=true, entry_computation_layout={()->bf16[1,128]{1,0}}, allow_spmd_sharding_propagation_to_output={true}, num_partitions=8

ENTRY %main.36_spmd () -> bf16[1,128] {
  %constant.9 = bf16[] constant(1)
  ROOT %broadcast.4 = bf16[1,128]{1,0} broadcast(bf16[] %constant.9), dimensions={}, metadata={op_name="jit(grad_func)/jit(main)/convert_element_type[new_dtype=bfloat16 weak_type=False sharding=None]" source_file="/usr/local/google/home/mattjj/packages/jax/28155.py" source_line=82}
}

What do you think?

mattjj commented 1 month ago

Actually, a better rule for pallas calls that don't communicate might look something like this:

from typing import Sequence

import jax
from jax.sharding import Mesh

MeshAxisName = str

from jax._src.pallas.pallas_call import pallas_call_p
from jax.experimental.shard_map import (register_rewrite, register_check,
                                        _standard_rewrite_rule)

@register_check(pallas_call_p)
def pallas_check(
    mesh: Mesh, *in_rep: set[MeshAxisName],
    name: str, out_shapes: Sequence[jax.ShapeDtypeStruct], **_,
) -> list[set[MeshAxisName]]:
  """Replication-check rule for pallas_call kernels that don't communicate.

  Every operand whose replication is known (not None) must have the same
  replication. Each output is then given that replication.

  Args:
    mesh: the shard_map mesh (unused).
    *in_rep: one replication entry per operand; None entries are ignored.
    name: the pallas_call kernel name.
    out_shapes: the kernel's output shapes; only their count is used.
    **_: remaining pallas_call params (ignored).

  Returns:
    One replication entry per output.

  Raises:
    NotImplementedError: for kernels other than `_copy_kernel_pallas`.
    ValueError: if the known operand replications differ.
  """
  del mesh  # unused
  num_outputs = len(out_shapes)
  if name != '_copy_kernel_pallas':
    raise NotImplementedError(name)
  known_reps = [r for r in in_rep if r is not None]
  if not known_reps:
    # No operand constrains replication (no inputs, or all None). Previously
    # this indexed an empty list and raised IndexError. Propagate None, which
    # the filter above treats as "no constraint". NOTE(review): confirm this
    # is the intended semantics for input-free kernels.
    return [None] * num_outputs
  first = known_reps[0]
  if any(r != first for r in known_reps[1:]):
    raise ValueError(
        f"pallas_call {name!r} requires all operands to have the same "
        f"replication, but got {in_rep}")
  return [first] * num_outputs

@register_rewrite(pallas_call_p)
def pallas_rewrite(mesh, in_rep, *args, **params):
  """Rewrite rule for pallas_call under shard_map(check_rep=True).

  Supported kernels are delegated to shard_map's standard rewrite rule,
  passing pallas_call_p and every argument through unchanged. Any other
  kernel name raises NotImplementedError.
  """
  kernel_name = params['name']
  if kernel_name != '_copy_kernel_pallas':
    raise NotImplementedError(kernel_name)
  return _standard_rewrite_rule(pallas_call_p, mesh, in_rep, *args, **params)
mattjj commented 1 month ago

Oops, this issue isn't closed; I closed the wrong issue with that PR.