xai-org / grok-1

Grok open release
Apache License 2.0
49.42k stars, 8.33k forks

Nr. of devices needed #38

Open zcobol opened 6 months ago

zcobol commented 6 months ago

Running python run.py on a single NVIDIA GPU fails with ValueError: Number of devices 1 must equal the product of mesh_shape (1, 8)

Can the number of devices be reduced to just 1?
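The error message states the invariant directly: the entries of the mesh shape must multiply out to exactly the number of devices JAX can see. A minimal stdlib sketch of that check, assuming nothing about JAX internals (check_mesh is a hypothetical helper that mirrors the message, not JAX's actual implementation):

```python
import math


def check_mesh(num_devices: int, mesh_shape: tuple) -> None:
    """Raise if the mesh shape does not use exactly num_devices devices.

    Hypothetical helper: it reproduces the invariant from the error
    message, not the code inside jax.experimental.mesh_utils.
    """
    product = math.prod(mesh_shape)
    if num_devices != product:
        raise ValueError(
            f"Number of devices {num_devices} must equal the product "
            f"of mesh_shape {mesh_shape}"
        )


check_mesh(8, (1, 8))  # passes: 1 * 8 == 8
try:
    check_mesh(1, (1, 8))  # one GPU against the default (1, 8) mesh
except ValueError as err:
    print(err)  # Number of devices 1 must equal the product of mesh_shape (1, 8)
```

So with one GPU the mesh has to become (1, 1), and with two GPUs something like (1, 2).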

nickorlabs commented 6 months ago

INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:rank:Initializing mesh for self.local_mesh_config=(1, 8) self.between_hosts_config=(1, 1)...
INFO:rank:Detected 2 devices in mesh
Traceback (most recent call last):
  File "/opt/grok-1/run.py", line 72, in <module>
    main()
  File "/opt/grok-1/run.py", line 63, in main
    inference_runner.initialize()
  File "/opt/grok-1/runners.py", line 282, in initialize
    runner.initialize(
  File "/opt/grok-1/runners.py", line 181, in initialize
    self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)
  File "/opt/grok-1/runners.py", line 586, in make_mesh
    device_mesh = mesh_utils.create_hybrid_device_mesh(
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/site-packages/jax/experimental/mesh_utils.py", line 373, in create_hybrid_device_mesh
    per_granule_meshes = [create_device_mesh(mesh_shape, granule)
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/site-packages/jax/experimental/mesh_utils.py", line 373, in <listcomp>
    per_granule_meshes = [create_device_mesh(mesh_shape, granule)
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/site-packages/jax/experimental/mesh_utils.py", line 302, in create_device_mesh
    raise ValueError(f'Number of devices {len(devices)} must equal the product '
ValueError: Number of devices 2 must equal the product of mesh_shape (1, 8)

Is this what you get?

yarodevuci commented 6 months ago

I put 1 instead of 8.

yarodevuci commented 6 months ago

I keep getting the same error: PermissionError: [WinError 32] The process cannot access the file because it is being used by another process: 'D:\dev\shm\tmpp53ohpcl'

KHARAPSY commented 6 months ago

I have the same issue. Is there a way to resolve this?

zRzRzRzRzRzRzR commented 6 months ago

Same issue, even with all requirements installed. I am using 8 GPUs.

nickorlabs commented 6 months ago

I have 2 GPUs, and everything installed OK as well.

bluevisor commented 6 months ago

In run.py, I changed line 60 from local_mesh_config=(1, 8), to local_mesh_config=(1, 1),

(I have one 3090.)
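Instead of hard-coding the second entry, it could be derived from the number of visible devices. A hedged sketch: mesh_config_for is a hypothetical helper (not part of grok-1); in run.py you would pass it jax.local_device_count(), a real JAX function, and use the result as local_mesh_config. This only satisfies the device-count check; the weight shapes still have to divide evenly across the devices.

```python
def mesh_config_for(num_devices: int, data_parallel: int = 1) -> tuple:
    """Return a (data, model) mesh shape that uses every visible device.

    Hypothetical helper: devices are split into `data_parallel` groups and
    the remainder goes on the model axis, so the product equals num_devices.
    """
    if num_devices < 1 or data_parallel < 1:
        raise ValueError("num_devices and data_parallel must be positive")
    if num_devices % data_parallel != 0:
        raise ValueError(
            f"{num_devices} devices cannot be split into {data_parallel} data groups"
        )
    return (data_parallel, num_devices // data_parallel)


print(mesh_config_for(1))  # (1, 1): the single-3090 setup above
print(mesh_config_for(8))  # (1, 8): the repository's default
```

Usage in run.py would then look like local_mesh_config=mesh_config_for(jax.local_device_count()), assuming jax is already imported there.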

nickorlabs commented 6 months ago

OK, got a little further this time!

Traceback (most recent call last):
  File "/opt/grok-1/run.py", line 72, in <module>
    main()
  File "/opt/grok-1/run.py", line 63, in main
    inference_runner.initialize()
  File "/opt/grok-1/runners.py", line 294, in initialize
    params = runner.load_or_init(dummy_data)
  File "/opt/grok-1/runners.py", line 238, in load_or_init
    state = xai_checkpoint.restore(
  File "/opt/grok-1/checkpoint.py", line 196, in restore
    loaded_tensors = load_tensors(ckpt_shapes_flat, ckpt_path, between_hosts_config)
  File "/opt/grok-1/checkpoint.py", line 107, in load_tensors
    return [f.result() for f in fs]
  File "/opt/grok-1/checkpoint.py", line 107, in <listcomp>
    return [f.result() for f in fs]
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/concurrent/futures/_base.py", line 449, in result
    return self.__get_result()
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/opt/grok-1/checkpoint.py", line 72, in fast_unpickle
    with copy_to_shm(path) as tmp_path:
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/contextlib.py", line 137, in __enter__
    return next(self.gen)
  File "/opt/grok-1/checkpoint.py", line 52, in copy_to_shm
    shutil.copyfile(file, tmp_path)
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/shutil.py", line 269, in copyfile
    _fastcopy_sendfile(fsrc, fdst)
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/shutil.py", line 158, in _fastcopy_sendfile
    raise err from None
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/shutil.py", line 144, in _fastcopy_sendfile
    sent = os.sendfile(outfd, infd, offset, blocksize)
OSError: [Errno 28] No space left on device: './checkpoints/ckpt-0/tensor00000_000' -> '/dev/shm/tmpi8_qagu5'

I have 2 Quadro 5000s; I guess we don't have enough VRAM. Doh.

bluevisor commented 6 months ago

I'm at the same point. GPT told me /dev/shm is a RAM disk, which means we don't have enough RAM, not VRAM.

I have 64 GB; not sure how much we need. Would 128 GB be enough?
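/dev/shm is a tmpfs mount backed by memory, so its capacity is effectively a RAM limit (on many Linux distributions it defaults to half of physical RAM). A small stdlib sketch for comparing the checkpoint size against free /dev/shm space and installed RAM; the function names are hypothetical, the ./checkpoints/ckpt-0 default comes from the traceback above, and the sysconf names used for RAM are Linux-oriented.

```python
import os
import shutil


def dir_size_bytes(path: str) -> int:
    """Total size of all regular files under `path`, recursively."""
    total = 0
    for root, _dirs, files in os.walk(path):
        for name in files:
            full = os.path.join(root, name)
            if os.path.isfile(full):
                total += os.path.getsize(full)
    return total


def physical_ram_bytes() -> int:
    """Installed RAM via sysconf (available on Linux, not on Windows)."""
    return os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES")


def report(ckpt_dir: str = "./checkpoints/ckpt-0", shm_dir: str = "/dev/shm") -> dict:
    """Print and return checkpoint size, free shm space and RAM, in GiB."""
    gib = 1024 ** 3
    info = {
        "checkpoint_gib": dir_size_bytes(ckpt_dir) / gib,
        "shm_free_gib": shutil.disk_usage(shm_dir).free / gib,
        "ram_gib": physical_ram_bytes() / gib,
    }
    for key, value in info.items():
        print(f"{key:>15}: {value:8.1f}")
    return info
```

If checkpoint_gib comes out far above ram_gib, as with the roughly 300 GB checkpoint discussed in this thread on a 64 GB machine, enlarging /dev/shm alone is unlikely to be enough.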

nickorlabs commented 6 months ago

I have 128 GB of RAM on this rig, and the two cards have about 32 GB of VRAM combined, which is why I assumed VRAM. I could be wrong.

bluevisor commented 6 months ago

Bummer. Guess we'll just have to wait for a GGUF version.

nickorlabs commented 6 months ago

Possibly. I might spin up a RunPod instance or wait for GGUF; I've been reading that people need 8 GPUs.

thisIsLoading commented 6 months ago

After changing the mesh to (1, 6), I get this error:

INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:rank:Initializing mesh for self.local_mesh_config=(1, 6) self.between_hosts_config=(1, 1)...
INFO:rank:Detected 6 devices in mesh
2024-03-18 15:58:10.001688: W external/xla/xla/service/gpu/nvptx_compiler.cc:742] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.99). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
INFO:rank:partition rules: <bound method LanguageModelConfig.partition_rules of LanguageModelConfig(model=TransformerConfig(emb_size=6144, key_size=128, num_q_heads=48, num_kv_heads=8, num_layers=64, vocab_size=131072, widening_factor=8, attn_output_multiplier=0.08838834764831845, name=None, num_experts=8, capacity_factor=1.0, num_selected_experts=2, init_scale=1.0, shard_activations=True, data_axis='data', model_axis='model'), vocab_size=131072, pad_token=0, eos_token=2, sequence_len=8192, model_size=6144, embedding_init_scale=1.0, embedding_multiplier_scale=78.38367176906169, output_multiplier_scale=0.5773502691896257, name=None, fprop_dtype=<class 'jax.numpy.bfloat16'>, model_type=None, init_scale_override=None, shard_embeddings=True)>
INFO:rank:(1, 256, 6144)
INFO:rank:(1, 256, 131072)
INFO:rank:State sharding type: <class 'model.TrainingState'>
INFO:rank:(1, 256, 6144)
INFO:rank:(1, 256, 131072)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/loading/PycharmProjects/grok-1/run.py", line 72, in <module>
    main()
  File "/home/loading/PycharmProjects/grok-1/run.py", line 63, in main
    inference_runner.initialize()
  File "/home/loading/PycharmProjects/grok-1/runners.py", line 294, in initialize
    params = runner.load_or_init(dummy_data)
  File "/home/loading/PycharmProjects/grok-1/runners.py", line 235, in load_or_init
    state_shapes = jax.eval_shape(self.init_fn, rng, init_data)
ValueError: One of pjit outputs with pytree key path .params['transformer/decoder_layer_0/moe/linear']['w'] was given the sharding of NamedSharding(mesh=Mesh('data': 1, 'model': 6), spec=PartitionSpec(None, 'data', 'model')), which implies that the global size of its dimension 2 should be divisible by 6, but it is equal to 32768 (full shape: (8, 6144, 32768))

Looks like it doesn't like 6 either.
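This second error is a different constraint: any array dimension sharded over the 'model' axis must be divisible by that axis's size. The failing expert weight has shape (8, 6144, 32768), and 32768 = 2^15, so only power-of-two axis sizes divide it. A sketch that lists which model-axis sizes evenly divide a set of dimensions; valid_axis_sizes is a hypothetical helper, and the dimension list is taken from the logs in this thread rather than being a complete list of the model's sharded dimensions.

```python
def valid_axis_sizes(dims, max_devices: int = 8) -> list:
    """Axis sizes n in 1..max_devices that evenly divide every dimension.

    Hypothetical helper for checking sharding divisibility by hand.
    """
    return [n for n in range(1, max_devices + 1) if all(d % n == 0 for d in dims)]


# Expert FFN width, emb_size and vocab_size, as printed in the logs above.
sharded_dims = [32768, 6144, 131072]
print(valid_axis_sizes(sharded_dims))  # [1, 2, 4, 8]
print(32768 % 6)                       # 2, hence the error with a (1, 6) mesh
```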

thisIsLoading commented 6 months ago

Looks like I have to set

        widening_factor=6,
        num_kv_heads=6,

in the TransformerConfig to the number of devices as well.
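Why changing widening_factor makes the shape check pass: assuming the feed-forward width in grok-1's model.py is int(widening_factor * emb_size) * 2 // 3, rounded up to a multiple of 8 (an assumption to verify against the repository), the default factor of 8 reproduces the 32768 in the error, while a factor of 6 gives 24576, which 6 divides. The pretrained checkpoint, however, stores tensors with the original shapes, so editing the config only changes freshly initialized shapes, not the saved weights.

```python
def ffn_size(emb_size: int, widening_factor: float) -> int:
    """Feed-forward width: 2/3 of widening_factor * emb_size, rounded up to a multiple of 8.

    Assumed to mirror grok-1's model.py; check the repository before relying on it.
    """
    size = int(widening_factor * emb_size) * 2 // 3
    return size + (-size) % 8  # round up to the next multiple of 8


print(ffn_size(6144, 8))  # 32768: matches the (8, 6144, 32768) weight in the error
print(ffn_size(6144, 6))  # 24576: divisible by 6
```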

nickorlabs commented 6 months ago

Did you get it up and running?

yarodevuci commented 6 months ago

Did it work after that?

thisIsLoading commented 6 months ago

@yarodevuci Still downloading the weights.

I was under the impression that the test would download the weights itself (looks like I'm spoiled by the Hugging Face API, which does that). Will report tomorrow. Right now it says 17 more hours (don't know why it's so long; I'm on 750 Mbit, but the magnet download is painfully slow).

nickorlabs commented 6 months ago

I'm seeding (again). It took me most of the evening last night to download, and I have a 2000 Mbit connection.

ad1tyac0des commented 6 months ago

My system has 192 GB of RAM, and I encountered the same error: OSError: [Errno 28] No space left on device: './checkpoints/ckpt-0/tensor00000_000' -> '/dev/shm/tmpbeofn6hn

yarodevuci commented 6 months ago

@ad1tyac0des It creates a temp folder with over 300 GB in it; do you have that much space on the hard drive?

toughcoding commented 6 months ago

Has anybody here seen a live presentation where the X developers ran it using these exact commands, or are we all just testing it for them?

pwxpwxtop commented 6 months ago

What a letdown. It took me a whole day's effort just to download it.

KHARAPSY commented 6 months ago

I succeeded in increasing the space and got rid of this error: "OSError: [Errno 28] No space left on device: './checkpoints/ckpt-0/tensor00000_000' -> '/dev/shm/tmpi8_qagu5'"

But in exchange, my system crashed instead, so I will give up for now. I don't have enough RAM to run Grok-1, nor enough money to upgrade my hardware.

zRzRzRzRzRzRzR commented 6 months ago

I switched to 8 x A100 GPUs, and running this model took 65 GB of memory per GPU. The resources required to run this model are quite large. The requirements installed successfully.

Finally, I ran it with this command:

JAX_TRACEBACK_FILTERING=off python run.py

and it works.

ad1tyac0des commented 6 months ago

I had about 100 GB of storage left, but when the error occurred my system's RAM was completely used up, which seems to be why the program stopped. So the problem looks like high RAM usage rather than storage space.

thisIsLoading commented 6 months ago

I'm at 272/300 GB right now. Excitement is starting to kick in; let's hope this thing runs.

I have "only" 6x 4090 (144 GB VRAM) and 512 GB RAM. If this isn't enough to at least run it, regardless of speed, then something is off.
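Rough arithmetic with numbers from this thread: the checkpoint is about 300 GB on disk, 6x 4090 gives 6 x 24 GB = 144 GB of pooled VRAM, and 8x A100 80 GB gives 640 GB. The back-of-the-envelope sketch below assumes weights shard evenly and ignores activations and KV cache; fits_in_vram and the 10% headroom are illustrative assumptions, not measurements.

```python
def fits_in_vram(model_gb: float, num_gpus: int, gb_per_gpu: float,
                 headroom: float = 0.10) -> bool:
    """True if model_gb fits in pooled GPU memory after reserving headroom.

    Assumes perfectly even sharding and ignores activation and KV-cache
    memory, so True means "possibly fits", while False means "cannot fit".
    """
    usable = num_gpus * gb_per_gpu * (1.0 - headroom)
    return model_gb <= usable


CHECKPOINT_GB = 300  # approximate on-disk size quoted in this thread

print(fits_in_vram(CHECKPOINT_GB, 6, 24))  # False: 6 x 4090 = 144 GB total
print(fits_in_vram(CHECKPOINT_GB, 8, 80))  # True: 8 x A100 80 GB = 640 GB total
```

Per-GPU usage seen in nvidia-smi is not a direct measure of the model's needs: by default JAX preallocates 75% of each GPU's memory when it starts running computations.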

thisIsLoading commented 6 months ago

OK, got a little further, but still no cigar:

(.venv) loading@ai:~/PycharmProjects/grok-1$ python run.py
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:rank:Initializing mesh for self.local_mesh_config=(1, 6) self.between_hosts_config=(1, 1)...
INFO:rank:Detected 6 devices in mesh
2024-03-19 07:55:00.536833: W external/xla/xla/service/gpu/nvptx_compiler.cc:742] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.99). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
INFO:rank:partition rules: <bound method LanguageModelConfig.partition_rules of LanguageModelConfig(model=TransformerConfig(emb_size=6144, key_size=128, num_q_heads=48, num_kv_heads=6, num_layers=64, vocab_size=131072, widening_factor=6, attn_output_multiplier=0.08838834764831845, name=None, num_experts=8, capacity_factor=1.0, num_selected_experts=2, init_scale=1.0, shard_activations=True, data_axis='data', model_axis='model'), vocab_size=131072, pad_token=0, eos_token=2, sequence_len=8192, model_size=6144, embedding_init_scale=1.0, embedding_multiplier_scale=78.38367176906169, output_multiplier_scale=0.5773502691896257, name=None, fprop_dtype=<class 'jax.numpy.bfloat16'>, model_type=None, init_scale_override=None, shard_embeddings=True)>
INFO:rank:(1, 256, 6144)
INFO:rank:(1, 256, 131072)
INFO:rank:State sharding type: <class 'model.TrainingState'>
INFO:rank:(1, 256, 6144)
INFO:rank:(1, 256, 131072)
INFO:rank:Loading checkpoint at ./checkpoints/ckpt-0
Traceback (most recent call last):
  File "/home/loading/PycharmProjects/grok-1/run.py", line 72, in <module>
    main()
  File "/home/loading/PycharmProjects/grok-1/run.py", line 63, in main
    inference_runner.initialize()
  File "/home/loading/PycharmProjects/grok-1/runners.py", line 294, in initialize
    params = runner.load_or_init(dummy_data)
  File "/home/loading/PycharmProjects/grok-1/runners.py", line 238, in load_or_init
    state = xai_checkpoint.restore(
  File "/home/loading/PycharmProjects/grok-1/checkpoint.py", line 218, in restore
    state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/experimental/multihost_utils.py", line 342, in host_local_array_to_global_array
    out_flat = [
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/experimental/multihost_utils.py", line 343, in <listcomp>
    host_local_array_to_global_array_p.bind(inp, global_mesh=global_mesh,
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/core.py", line 420, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/core.py", line 423, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/core.py", line 913, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/experimental/multihost_utils.py", line 250, in host_local_array_to_global_array_impl
    for d, index in local_sharding.devices_indices_map(arr.shape).items()]
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 110, in devices_indices_map
    return common_devices_indices_map(self, global_shape)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 59, in common_devices_indices_map
    return gspmd_sharding.devices_indices_map(global_shape)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 898, in devices_indices_map
    return gspmd_sharding_devices_indices_map(self, global_shape)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 826, in gspmd_sharding_devices_indices_map
    self.shard_shape(global_shape)  # raises a good error message
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 122, in shard_shape
    return _common_shard_shape(self, global_shape)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 77, in _common_shard_shape
    raise ValueError(
ValueError: Sharding GSPMDSharding({devices=[1,1,6]<=[6]}) implies that array axis 2 is partitioned 6 times, but the dimension size is 32768 (full shape: (8, 6144, 32768), per-dimension tiling factors: [1, 1, 6] should evenly divide the shape)
(.venv) loading@ai:~/PycharmProjects/grok-1$

Every 2.0s: nvidia-smi                                        ai: Tue Mar 19 07:56:37 2024

Tue Mar 19 07:56:37 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090        On  | 00000000:16:00.0 Off |                  Off |
|  0%   31C    P8              23W / 450W |      3MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        On  | 00000000:34:00.0 Off |                  Off |
|  0%   30C    P8              25W / 450W |      3MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA GeForce RTX 4090        On  | 00000000:52:00.0 Off |                  Off |
|  0%   30C    P8              25W / 450W |      3MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA GeForce RTX 4090        On  | 00000000:70:00.0 Off |                  Off |
|  0%   30C    P8              20W / 450W |      3MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA GeForce RTX 4090        On  | 00000000:AC:00.0 Off |                  Off |
|  0%   32C    P8              29W / 450W |      3MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA GeForce RTX 4090        On  | 00000000:CA:00.0 Off |                  Off |
|  0%   30C    P8              16W / 450W |      3MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+
malinichev commented 6 months ago

But my MacBook M1 Pro (16 GB RAM, 512 GB SSD) freezes after that, because Python eats up all my memory. I assume there is not enough memory for the SSD.

surak commented 6 months ago

I'm at 272/300 GB right now. Excitement is starting to kick in; let's hope this thing runs. I have "only" 6x 4090 (144 GB VRAM) and 512 GB RAM. If this isn't enough to at least run it, regardless of speed, then something is off.

It probably isn't enough. I have 4 A100s and 512 GB per node as well, and I'm not sure I can run it. It has been stuck loading checkpoints for a while now.

Christmas-Wong commented 6 months ago

You should install jaxlib with CUDA support so that your 8 GPUs are detected. Alternatively, you can set local_mesh_config=(1, 1), and Grok will run on the CPU.
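To confirm whether the CUDA build of jaxlib is actually in use, you can ask JAX which backend and devices it found. jax.default_backend() and jax.devices() are real JAX functions; the wrapper and its fallback message are assumptions added for illustration.

```python
def describe_jax_backend() -> str:
    """Summarize which JAX backend is active and how many devices it sees."""
    try:
        import jax
    except ImportError:
        return "jax is not installed"
    backend = jax.default_backend()  # e.g. "gpu" with a CUDA jaxlib, "cpu" otherwise
    devices = jax.devices()
    return f"backend={backend}, devices={len(devices)}: {devices}"


print(describe_jax_backend())
```

If this reports cpu on a machine with 8 GPUs, the CUDA-enabled jaxlib is missing. The device count it prints is the number the entries of local_mesh_config must multiply to.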

SamKnightV commented 6 months ago
  • In checkpoint.py, I changed /dev/shm/ to './dev/shm/'
  • In the terminal, run mkdir -p ./dev/shm/
  • After that, run python run.py

But my MacBook M1 Pro (16 GB RAM, 512 GB SSD) freezes after that, because Python eats up all my memory. I assume there is not enough memory for the SSD.

After doing that, I got this error: zsh: killed python run.py
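Rather than editing the hard-coded path, the temporary directory could be made configurable. The sketch below is a hypothetical copy-to-temp-dir context manager in the spirit of checkpoint.py's copy_to_shm, not the repository's code, and the GROK_SHM_DIR environment variable is invented for this example. It closes the descriptor returned by tempfile.mkstemp before copying and deleting; on Windows, deleting a file that is still open fails with PermissionError [WinError 32], which is one possible cause of the Windows error reported earlier in this thread. Moving the directory onto disk avoids the /dev/shm size limit, but not the RAM needed to hold the loaded weights, which matches the "killed" result above.

```python
import contextlib
import os
import shutil
import tempfile

# Hypothetical override; falls back to /dev/shm as in the original code path.
SHM_DIR = os.environ.get("GROK_SHM_DIR", "/dev/shm")


@contextlib.contextmanager
def copy_to_tmp(file: str, tmp_dir: str = SHM_DIR):
    """Copy `file` into `tmp_dir`, yield the copy's path, then delete it."""
    os.makedirs(tmp_dir, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
    os.close(fd)  # release our handle so copying and deleting also work on Windows
    try:
        shutil.copyfile(file, tmp_path)
        yield tmp_path
    finally:
        os.remove(tmp_path)
```

If checkpoint.py were changed to use a helper like this, GROK_SHM_DIR=./dev/shm python run.py would reproduce the manual edit without touching the source again.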

surak commented 6 months ago
We are talking about machines with 512 GB of RAM and hundreds of GB of VRAM not being able to run it, let alone a laptop. You will have to wait for a WAY smaller version of it to run on a small machine.

SamKnightV commented 6 months ago

I have an iMac with a 3.6 GHz 10-core Intel Core i9, AMD Radeon Pro 5300 4 GB graphics, and 16 GB of 2667 MHz DDR4 memory. What do I need to change?

surak commented 6 months ago

))))

You need a real data center GPU compute node with at least 8x A100 80 GB to run Grok at this point. I doubt any quantized version will fit on a Mac anytime soon, but who knows?

Na-Yun1990 commented 6 months ago

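raise ValueError(f'Number of devices {len(devices)} must equal the product ' ValueError: Number of devices 1 must equal the product of mesh_shape (1, 8). Oh well ~.~ Musk really doesn't cut anyone any slack. Ordinary people should just stick to playing with Grok or GPT or whatever. If your hardware isn't up to spec, this thing flat out won't let you use it. Once I strike it rich, I'll build a server out of 24 4090s and have some fun!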

Na-Yun1990 commented 6 months ago

I have an iMac with a 3.6 GHz 10-core Intel Core i9, AMD Radeon Pro 5300 4 GB graphics, and 16 GB of 2667 MHz DDR4 memory. What do I need to change?

You would need to change everything. Ordinary consumer hardware cannot run this. Maybe an Amazon cloud server could run Grok-1, but the price would definitely be high.