google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Cannot interpret 'key<fry>' as a data type #21351

Closed chenyue-max closed 1 month ago

chenyue-max commented 3 months ago

Description

2024-05-21 07:41:32,181 ERROR worker.py:405 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::MeshHostWorker.run_executable
  File "/root/cy/temp/geesibling/python/geesibling/adapters/jax/pipeline/devicecontext.py", line 365, in run_executable
    self.do_recv(instruction.micro_batch_id)
  File "/root/cy/temp/geesibling/python/geesibling/adapters/jax/pipeline/devicecontext.py", line 392, in do_recv
    recv_buffer = cupy.zeros(var.aval.shape, dtype=var.aval.dtype)
  File "/root/miniconda3/envs/framework-cy/lib/python3.9/site-packages/cupy/_creation/basic.py", line 248, in zeros
    a = cupy.ndarray(shape, dtype, order=order)
  File "cupy/_core/core.pyx", line 132, in cupy._core.core.ndarray.__new__
  File "cupy/_core/core.pyx", line 204, in cupy._core.core._ndarray_base._init
  File "cupy/_core/_dtype.pyx", line 61, in cupy._core._dtype.get_dtype_with_itemsize
TypeError: Cannot interpret 'key<fry>' as a data type

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.7
jaxlib: 0.4.7
numpy: 1.23.0
python: 3.9.13 (main, Oct 13 2022, 21:15:33) [GCC 11.2.0]
jax.devices (8 total, 8 local): [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0) StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0) ... StreamExecutorGpuDevice(id=6, process_index=0, slice_index=0) StreamExecutorGpuDevice(id=7, process_index=0, slice_index=0)]
process_count: 1

superbobry commented 3 months ago

Can you provide a way for us to reproduce the error you're seeing?

The stack trace suggests the error is coming from cupy, so my guess would be that cupy doesn't accept a custom JAX dtype, but having a reproducer would still help to diagnose this.
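That guess can be checked outside Ray and cupy entirely: NumPy rejects the extended key dtype the same way, since both funnel it through np.dtype(...). A minimal sketch, assuming a jax version with the typed-key API (jax.random.key):

```python
import jax
import numpy as np

# Typed PRNG keys carry a JAX-specific extended dtype (shown as
# key<fry> for the default threefry implementation) that is not a
# NumPy dtype.
key = jax.random.key(0)

# Any library that converts the dtype via np.dtype(...) fails the same
# way cupy.zeros does in the reported traceback.
try:
    np.zeros(key.shape, dtype=key.dtype)
except TypeError as err:
    print(err)  # message of the form: Cannot interpret ... as a data type
```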

chenyue-max commented 3 months ago
def do_recv(self, micro_batch_id, input_vars, src_rank, group_name='default'):
    src_gpu_idx = 0
    for var in input_vars:
        with cupy.cuda.Device(0):
            # Bool buffers are received as int32 and converted back below.
            if var.aval.dtype == np.bool_:
                recv_buffer = cupy.zeros(var.aval.shape, dtype=np.int32)
            else:
                recv_buffer = cupy.zeros(var.aval.shape, dtype=var.aval.dtype)
        col.recv_multigpu(recv_buffer, src_rank, src_gpu_idx, group_name)
        cupy.cuda.Device(0).synchronize()
        recv_buffer = recv_buffer.get()
        if var.aval.dtype == np.bool_:
            recv_buffer = recv_buffer.astype(np.bool_)
        val = jax.device_put(recv_buffer)
        if var in self.buffers[-1]:
            self.buffers[-1][var] = val
        else:
            self.buffers[micro_batch_id][var] = val

When the incoming data has a PRNG key dtype like key<fry>, how should it be received and processed here?

jakevdp commented 3 months ago

It looks like you need something like this at the beginning of your function

if jax.dtypes.issubdtype(var.dtype, jax.dtypes.prng_key):
    impl = jax.random.key_impl(var)  # record the implementation while var is still a typed key
    var = jax.random.key_data(var)   # now an ordinary uint32 array

And if you need to convert back to a typed key, use var = jax.random.wrap_key_data(var, impl=impl).
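The round trip above can be sketched end to end; a minimal sketch, where jax.random.key just creates an example typed key and the actual send/receive step is elided:

```python
import jax
import jax.numpy as jnp

# Example typed PRNG key; its dtype displays as key<fry> for the
# default threefry implementation.
key = jax.random.key(0)

# Before shipping the key through cupy/NCCL, record its implementation
# and unwrap it to an ordinary uint32 array.
impl = jax.random.key_impl(key)
raw = jax.random.key_data(key)   # plain uint32 buffer any backend can move

# ... transport `raw` like any other array ...

# On the receiving side, rebuild the typed key.
restored = jax.random.wrap_key_data(raw, impl=impl)
assert restored.dtype == key.dtype
assert jnp.array_equal(jax.random.key_data(restored), raw)
```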