alpa-projects / alpa

Training and serving large-scale neural networks with auto parallelization.
https://alpa.ai
Apache License 2.0

OOM while serving language models #759

Closed (zhanyuanucb closed this issue 1 year ago)

zhanyuanucb commented 2 years ago

Please describe the bug

I deployed Bloom-7b1 with Alpa and KubeRay and got an out-of-memory (OOM) error while Bloom-7b1 was running inference.

:actor_name:DeviceMeshGroupManager
2022-10-26 11:54:12 | INFO | stdout | Load model alpa/bloom-7b1 ... (This can take several minutes for very large models)
2022-10-26 11:54:12 | INFO | stdout |  - Compile executables for encoder_chunk_sizes=[1, 64].
2022-10-26 11:54:19,113    WARNING worker.py:1805 -- Using blocking ray.get inside async actor. This blocks the event loop. Please use `await` on object ref with asyncio.gather if you want to yield execution to the event loop instead.
2022-10-26 11:54:47 | INFO | stdout | elapsed: 34.99 second.
2022-10-26 11:54:47 | INFO | stdout |  - Load parameters.
2022-10-26 11:55:02,213    ERROR worker.py:94 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::MeshHostWorker.load_bloom_params_worker_func() (pid=501, ip=10.1.181.198, repr=<alpa.device_mesh.MeshHostWorker object at 0x7f76f5f29b20>)
  File "/home/ray/src/llm-serving/examples/llm_serving/model/bloom_model.py", line 821, in load_bloom_params_worker_func
    load_param(param_prefix + "self_attention.query_key_value.bias",
  File "/home/ray/src/llm-serving/examples/llm_serving/model/bloom_model.py", line 797, in load_param
    self.put_buffers(uuid, datas)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/alpa/device_mesh.py", line 178, in put_buffers
    arys[batch_id][device_id] = (self.backend.buffer_from_pyval(
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 24576 bytes.
2022-10-26 11:55:12,900   ERROR worker.py:94 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::MeshHostWorker.put_buffers() (pid=501, ip=10.1.181.198, repr=<alpa.device_mesh.MeshHostWorker object at 0x7f76f5f29b20>)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/alpa/device_mesh.py", line 178, in put_buffers
    arys[batch_id][device_id] = (self.backend.buffer_from_pyval(
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 33554432 bytes.
2022-10-26 11:55:22,936   ERROR worker.py:94 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::MeshHostWorker.put_buffers() (pid=501, ip=10.1.181.198, repr=<alpa.device_mesh.MeshHostWorker object at 0x7f76f5f29b20>)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/alpa/device_mesh.py", line 178, in put_buffers
    arys[batch_id][device_id] = (self.backend.buffer_from_pyval(
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 33554432 bytes.

Please describe the expected behavior

System information and environment

To Reproduce

Steps to reproduce the behavior:

  1. Follow this link to install the KubeRay operator on a k8s cluster.
  2. Build a Docker image that captures the runtime environment. I put the details of the Dockerfile and build context below.
  3. Prepare the YAML file for the RayJob by replacing <My Docker Image> with the image tag.
  4. You may also need to set up imagePullSecrets, if necessary.
  5. Create the RayJob with kubectl apply -f <RayJob yaml>. I put the content of the YAML below.
  6. (Optional) Port-forward port 8265 of component-alpa-service-raycluster-xxx-head-svc to monitor the RayJob status through the Ray dashboard, where xxx is an auto-generated ID.
  7. Port-forward component-alpa-service-raycluster-xxx-head-svc on port 8899.
  8. Send a query to the endpoint with curl -d '{"prompt":"Hello world, ","max_tokens":"128","temperature":"0.7","top_p":"0.5","model":"default"}' localhost:8899/completions
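For reference, the request in step 8 can also be sent from Python. This is a minimal sketch using only the standard library. It assumes the port-forward from step 7 is active on localhost:8899 and simply mirrors the JSON body of the curl command (values are sent as strings, exactly as in the curl example). `build_completion_request` is a hypothetical helper, not part of Alpa, and setting a JSON `Content-Type` header is an assumption (curl -d sends a form content type by default).

```python
import json
import urllib.request


def build_completion_request(prompt, max_tokens="128", temperature="0.7",
                             top_p="0.5", model="default",
                             url="http://localhost:8899/completions"):
    """Hypothetical helper: build a POST request mirroring the curl command in step 8."""
    payload = {
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "model": model,
    }
    return urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},  # assumption, see lead-in
        method="POST",
    )


if __name__ == "__main__":
    req = build_completion_request("Hello world, ")
    # Requires the port-forward from step 7; generation can take a while.
    with urllib.request.urlopen(req, timeout=300) as resp:
        print(resp.read().decode("utf-8"))
```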

Screenshots

If applicable, add screenshots to help explain your problem.

Code snippet to reproduce the problem

Additional information

```python
import argparse
import time

import ray

# CONTROLLER_NAME, run_controller, LangaugeModelWorker, NUM_BEAMS, NUM_RETURN_SEQ
# and USE_RECAPTCHA come from Alpa and the llm_serving example (imports not shown).

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="alpa/opt-125m")
    parser.add_argument("--path", type=str, default="~/opt_weights/")
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8899)
    parser.add_argument("--torch-device", type=str, default="cpu")
    parser.add_argument("--tokenizer", type=str)
    parser.add_argument("--no-recaptcha", action="store_true")
    parser.add_argument("--register-name", type=str, default="default")
    parser.add_argument("--ssl-keyfile", type=str)
    parser.add_argument("--ssl-certfile", type=str)
    args = parser.parse_args()

    # Connect to the existing Ray cluster (address taken from RAY_ADDRESS).
    ray.init()

    # Reuse the controller actor if it already exists; otherwise start one.
    try:
        controller = ray.get_actor(CONTROLLER_NAME)
    except ValueError:
        controller = run_controller(args.host, args.port, "/",
                                    args.ssl_keyfile, args.ssl_certfile)

    # Launch a device mesh group, register the model, and create a replica on it.
    group_id = 0
    controller.launch_mesh_group_manager.remote(group_id)
    t = controller.register_model.remote(
        args.register_name, LangaugeModelWorker,
        (args.model, args.path, args.torch_device, args.tokenizer, NUM_BEAMS,
         NUM_RETURN_SEQ, False if args.no_recaptcha else USE_RECAPTCHA),
        override=True)
    ray.get(t)
    t = controller.create_replica.remote(args.register_name, group_id)
    ray.get(t)

    # Keep the driver alive so the actors keep serving; sleep instead of busy-waiting.
    while True:
        time.sleep(1)
```
merrymercy commented 2 years ago

What's the output of `ray status`? Could you also try a smaller batch size? See https://github.com/alpa-projects/alpa/blob/a38bfde29e2c1ece5faf5bc59cc4189dde852091/examples/llm_serving/generator.py#L25-L26

zhanyuanucb commented 2 years ago

@merrymercy The RayJob status is RUNNING. Sure, I can try tuning the batch size. I also encountered an OOM error during parameter loading. Are there any related parameters I can tune for that?

merrymercy commented 2 years ago

I mean running `ray status` on the command line to make sure all nodes and GPUs are connected.
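The same check can also be scripted. This is a minimal sketch, assuming it runs in a pod that is already part of the Ray cluster so that `ray.init(address="auto")` can attach to it. `ray.cluster_resources()` reports the cluster totals (the denominators in the Usage section of `ray status`) and `ray.available_resources()` reports what is currently free. The helper `has_enough_gpus` and the required GPU count are illustrative assumptions, not part of Ray or Alpa.

```python
def has_enough_gpus(resources, required_gpus):
    """Return True if a Ray resource dict reports at least `required_gpus` GPUs."""
    return resources.get("GPU", 0.0) >= required_gpus


if __name__ == "__main__":
    import ray  # imported here so the helper above works without Ray installed

    # Attach to the running cluster instead of starting a new local one.
    ray.init(address="auto")
    total = ray.cluster_resources()        # totals, e.g. {"CPU": 8.0, "GPU": 2.0, ...}
    available = ray.available_resources()  # currently unused resources
    print("GPUs total:", total.get("GPU", 0.0),
          "available:", available.get("GPU", 0.0))
    print("enough GPUs:", has_enough_gpus(total, required_gpus=2))
```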

zhanyuanucb commented 2 years ago

@merrymercy I set max_batch_size=1 but still saw the error. I was serving bloom-7b1, which requires ~14GB of GPU memory, on 2 NVIDIA GeForce RTX 2080 Ti cards with 11GB of GPU memory each. Here is the ray status output right after parameter loading finished, along with a screenshot of the Ray dashboard:

(base) ray@component-alpa-service-raycluster-b5rz6-head-tklrq:~$ ray status
======== Autoscaler status: 2022-10-30 20:00:49.839511 ========
Node status
---------------------------------------------------------------
Healthy:
 1 node_8ce06bf9278b5972f8977b333d4c0b3406e6f7b4af8d1c1774804d14
 1 node_83894c262e9ed5351adb200250534feae1304a6348c35c152fcdd342
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 1.0/8.0 CPU (0.0 used of 2.0 reserved in placement groups)
 2.0/2.0 GPU (2.0 used of 2.0 reserved in placement groups)
 0.0/2.0 accelerator_type:G
 0.00/32.000 GiB memory
 0.00/4.865 GiB object_store_memory

Demands:
 (no resource demands)

[Screenshot: Ray dashboard, 2022-10-30 8:03:45 PM]

After the OOM error:

(base) ray@component-alpa-service-raycluster-b5rz6-head-tklrq:~$ ray status
======== Autoscaler status: 2022-10-30 20:05:10.316602 ========
Node status
---------------------------------------------------------------
Healthy:
 1 node_8ce06bf9278b5972f8977b333d4c0b3406e6f7b4af8d1c1774804d14
 1 node_83894c262e9ed5351adb200250534feae1304a6348c35c152fcdd342
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 1.0/8.0 CPU (0.0 used of 2.0 reserved in placement groups)
 1.0/2.0 GPU (1.0 used of 2.0 reserved in placement groups)
 0.0/2.0 accelerator_type:G
 0.00/32.000 GiB memory
 0.00/4.865 GiB object_store_memory

Demands:
 (no resource demands)

[Screenshot: Ray dashboard, 2022-10-30 8:04:59 PM]

Here is the error log:

2022-10-30 19:59:20,794    INFO worker.py:1223 -- Using address 10.1.15.177:6379 set in the environment variable RAY_ADDRESS
2022-10-30 19:59:20,794    INFO worker.py:1333 -- Connecting to existing Ray cluster at address: 10.1.15.177:6379...
2022-10-30 19:59:20,802    INFO worker.py:1509 -- Connected to Ray cluster. View the dashboard at http://10.1.15.177:8265
(Controller pid=244) INFO:     Started server process [244]
(Controller pid=244) INFO:uvicorn.error:Started server process [244]
(DeviceMeshGroupManager pid=277) 2022-10-30 19:59:26 | INFO | stdout | Load model alpa/bloom-7b1 ... (This can take several minutes for very large models)
(DeviceMeshGroupManager pid=277) 2022-10-30 19:59:26 | INFO | stdout |  - Compile executables for encoder_chunk_sizes=[1, 64].
(DeviceMeshGroupManager pid=277) 2022-10-30 19:59:33,168 WARNING worker.py:2249 -- Using blocking ray.get inside async actor. This blocks the event loop. Please use `await` on object ref with asyncio.gather if you want to yield execution to the event loop instead.
(DeviceMeshGroupManager pid=277) 2022-10-30 19:59:51 | INFO | stdout | elapsed: 25.41 second.
(DeviceMeshGroupManager pid=277) 2022-10-30 19:59:51 | INFO | stdout |  - Load parameters.
(DeviceMeshGroupManager pid=277) 2022-10-30 20:00:15 | INFO | stdout | elapsed: 24.24 second.
(DeviceMeshGroupManager pid=277) 2022-10-30 20:00:16 | ERROR | stderr | Downloading:   0%|          | 0.00/222 [00:00<?, ?B/s]
(DeviceMeshGroupManager pid=277) 2022-10-30 20:00:16 | ERROR | stderr | Downloading: 100%|██████████| 222/222 [00:00<00:00, 83.4kB/s]
(DeviceMeshGroupManager pid=277) 2022-10-30 20:00:16 | ERROR | stderr |
(DeviceMeshGroupManager pid=277) 2022-10-30 20:00:16 | ERROR | stderr | Downloading:   0%|          | 0.00/14.5M [00:00<?, ?B/s]
(DeviceMeshGroupManager pid=277) 2022-10-30 20:00:16 | ERROR | stderr | Downloading:  15%|█▍        | 2.15M/14.5M [00:00<00:00, 21.5MB/s]
(DeviceMeshGroupManager pid=277) 2022-10-30 20:00:17 | ERROR | stderr | Downloading:  32%|███▏      | 4.61M/14.5M [00:00<00:00, 23.3MB/s]
(DeviceMeshGroupManager pid=277) 2022-10-30 20:00:17 | ERROR | stderr | Downloading:  52%|█████▏    | 7.61M/14.5M [00:00<00:00, 26.4MB/s]
(DeviceMeshGroupManager pid=277) 2022-10-30 20:00:17 | ERROR | stderr | Downloading:  71%|███████   | 10.3M/14.5M [00:00<00:00, 15.9MB/s]
(DeviceMeshGroupManager pid=277) 2022-10-30 20:00:17 | ERROR | stderr | Downloading:  84%|████████▍ | 12.2M/14.5M [00:00<00:00, 16.2MB/s]
(DeviceMeshGroupManager pid=277) 2022-10-30 20:00:17 | ERROR | stderr | Downloading:  97%|█████████▋| 14.1M/14.5M [00:00<00:00, 15.5MB/s]
(DeviceMeshGroupManager pid=277) 2022-10-30 20:00:17 | ERROR | stderr | Downloading: 100%|██████████| 14.5M/14.5M [00:00<00:00, 17.1MB/s]
(DeviceMeshGroupManager pid=277) 2022-10-30 20:00:17 | ERROR | stderr |
(DeviceMeshGroupManager pid=277) 2022-10-30 20:00:18 | ERROR | stderr | Downloading:   0%|          | 0.00/85.0 [00:00<?, ?B/s]
(DeviceMeshGroupManager pid=277) 2022-10-30 20:00:18 | ERROR | stderr | Downloading: 100%|██████████| 85.0/85.0 [00:00<00:00, 31.5kB/s]
(DeviceMeshGroupManager pid=277) 2022-10-30 20:00:18 | ERROR | stderr |
(DeviceMeshGroupManager pid=277) 2022-10-30 20:00:19 | INFO | alpa.llm_serving | Loading model time: 49.65
(DeviceMeshGroupManager pid=277) 2022-10-30 20:04:33 | INFO | alpa.llm_serving | Received new generate request: prompt length [7], max_len: 128, temperature: 0.7, top_p: 0.5, api_key: None, ip: 127.0.0.1, tstamp: 1667185473.083195
(DeviceMeshGroupManager pid=277) 2022-10-30 20:04:34 | INFO | alpa.llm_serving | Generate begin. batch id: 0, batch size: 1
(DeviceMeshGroupManager pid=277) 2022-10-30 20:04:34 | INFO | alpa.llm_serving | Call generate. batch id: 0, padded bs: 1, original bs: 1, generator_args: {'min_length': 7, 'max_length': 135, 'temperature': 0.7, 'do_sample': True, 'top_p': 0.5, 'num_beams': 1, 'num_return_sequences': 1, 'early_stopping': True, 'repetition_penalty': 1.0, 'no_repeat_ngram_size': 8}.
(DeviceMeshGroupManager pid=277) 2022-10-30 20:04:44 | ERROR | asyncio | Task exception was never retrieved
(DeviceMeshGroupManager pid=277) future: <Task finished name='Task-5' coro=<LangaugeModelWorker.batch_loop() done, defined at /home/ray/src/llm-serving/examples/llm_serving/launch_model_worker.py:82> exception=RayActorError()>
(DeviceMeshGroupManager pid=277) Traceback (most recent call last):
(DeviceMeshGroupManager pid=277)   File "/home/ray/src/llm-serving/examples/llm_serving/launch_model_worker.py", line 128, in batch_loop
(DeviceMeshGroupManager pid=277)     results = self.generator.generate(**args)
(DeviceMeshGroupManager pid=277)   File "/home/ray/src/llm-serving/examples/llm_serving/generator.py", line 175, in generate
(DeviceMeshGroupManager pid=277)     output_ids = self.model_wrapper.generate(input_ids=input_ids, **generator_args)
(DeviceMeshGroupManager pid=277)   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
(DeviceMeshGroupManager pid=277)     return func(*args, **kwargs)
(DeviceMeshGroupManager pid=277)   File "/home/ray/anaconda3/lib/python3.8/site-packages/transformers/generation_utils.py", line 1422, in generate
(DeviceMeshGroupManager pid=277)     return self.sample(
(DeviceMeshGroupManager pid=277)   File "/home/ray/anaconda3/lib/python3.8/site-packages/transformers/generation_utils.py", line 2035, in sample
(DeviceMeshGroupManager pid=277)     outputs = self(
(DeviceMeshGroupManager pid=277)   File "/home/ray/src/llm-serving/examples/llm_serving/model/wrapper.py", line 110, in __call__
(DeviceMeshGroupManager pid=277)     ret = self.inference_func(input_ids,
(DeviceMeshGroupManager pid=277)   File "/home/ray/src/llm-serving/examples/llm_serving/model/wrapper.py", line 585, in inference_func
(DeviceMeshGroupManager pid=277)     logits_step = torch.from_numpy(np.array(output.logits)).to(torch_device).float()
(DeviceMeshGroupManager pid=277)   File "/home/ray/anaconda3/lib/python3.8/site-packages/alpa/device_mesh.py", line 1610, in __array__
(DeviceMeshGroupManager pid=277)     return np.asarray(self._value, dtype=dtype)
(DeviceMeshGroupManager pid=277)   File "/home/ray/anaconda3/lib/python3.8/site-packages/alpa/device_mesh.py", line 1596, in _value
(DeviceMeshGroupManager pid=277)     fetched_np_buffers = self.device_mesh.get_remote_buffers(
(DeviceMeshGroupManager pid=277)   File "/home/ray/anaconda3/lib/python3.8/site-packages/alpa/device_mesh.py", line 1170, in get_remote_buffers
(DeviceMeshGroupManager pid=277)     ret = [ray.get(refs) for refs in obj_refs]
(DeviceMeshGroupManager pid=277)   File "/home/ray/anaconda3/lib/python3.8/site-packages/alpa/device_mesh.py", line 1170, in <listcomp>
(DeviceMeshGroupManager pid=277)     ret = [ray.get(refs) for refs in obj_refs]
(DeviceMeshGroupManager pid=277)   File "/home/ray/anaconda3/lib/python3.8/site-packages/ray/_private/client_mode_hook.py", line 105, in wrapper
(DeviceMeshGroupManager pid=277)     return func(*args, **kwargs)
(DeviceMeshGroupManager pid=277)   File "/home/ray/anaconda3/lib/python3.8/site-packages/ray/_private/worker.py", line 2277, in get
(DeviceMeshGroupManager pid=277)     raise value
(DeviceMeshGroupManager pid=277) ray.exceptions.RayActorError: The actor died unexpectedly before finishing this task.
(DeviceMeshGroupManager pid=277)    class_name: MeshHostWorker
(DeviceMeshGroupManager pid=277)    actor_id: cc3b0057d518879736db5b9202000000
(DeviceMeshGroupManager pid=277)    pid: 489
(DeviceMeshGroupManager pid=277)    namespace: 94d8d73e-37c7-485b-a4fe-073ac09749f9
(DeviceMeshGroupManager pid=277)    ip: 10.1.15.177
(DeviceMeshGroupManager pid=277) The actor is dead because its worker process has died. Worker exit type: INTENDED_USER_EXIT Worker exit detail: Worker exits by an user request. exit_actor() is called.
(MeshHostWorker pid=489) 2022-10-30 20:04:44.195439: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 3.83GiB (rounded to 4111470848)requested by op
(MeshHostWorker pid=489) 2022-10-30 20:04:44.196033: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:491] *************************************************************************************_______________
(MeshHostWorker pid=489) 2022-10-30 20:04:44.196949: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2134] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4111470712 bytes.
(MeshHostWorker pid=489) BufferAssignment OOM Debugging.
(MeshHostWorker pid=489) BufferAssignment stats:
(MeshHostWorker pid=489)              parameter allocation:    7.78GiB
(MeshHostWorker pid=489)               constant allocation:         0B
(MeshHostWorker pid=489)         maybe_live_out allocation:  301.25MiB
(MeshHostWorker pid=489)      preallocated temp allocation:    3.83GiB
(MeshHostWorker pid=489)   preallocated temp fragmentation:         0B (0.00%)
(MeshHostWorker pid=489)                  total allocation:   11.90GiB
(MeshHostWorker pid=489)               total fragmentation:     5.5KiB (0.00%)
(MeshHostWorker pid=489) Peak buffers:
(MeshHostWorker pid=489)    Buffer 1:
(MeshHostWorker pid=489)        Size: 3.83GiB
(MeshHostWorker pid=489)        Operator: op_name="parallelize(inference_step_with_cache_pipeshard_parallel_mesh_1)/jit(main)/FlaxBloomForCausalLMModule/lm_head/convert_element_type[new_dtype=float32 weak_type=False]" source_file="/home/ray/anaconda3/lib/python3.8/site-packages/flax/linen/dtypes.py" source_line=97
(MeshHostWorker pid=489)        XLA Label: convert
(MeshHostWorker pid=489)        Shape: f32[250880,4096]
(MeshHostWorker pid=489)        ==========================
(MeshHostWorker pid=489)
(MeshHostWorker pid=489)    Buffer 2:
(MeshHostWorker pid=489)        Size: 1.91GiB
(MeshHostWorker pid=489)        Operator: op_name="layer_1$start"
(MeshHostWorker pid=489)        Entry Parameter Subshape: f16[250880,4096]
(MeshHostWorker pid=489)        ==========================
(MeshHostWorker pid=489)
(MeshHostWorker pid=489)    Buffer 3:
(MeshHostWorker pid=489)        Size: 128.00MiB
(MeshHostWorker pid=489)        Operator: op_name="layer_1$start"
(MeshHostWorker pid=489)        Entry Parameter Subshape: f16[16384,4096]
(MeshHostWorker pid=489)        ==========================
(MeshHostWorker pid=489)
(MeshHostWorker pid=489)    Buffer 4:
(MeshHostWorker pid=489)        Size: 128.00MiB
(MeshHostWorker pid=489)        Operator: op_name="layer_1$start"
(MeshHostWorker pid=489)        Entry Parameter Subshape: f16[4096,16384]
(MeshHostWorker pid=489)        ==========================
(MeshHostWorker pid=489)
(MeshHostWorker pid=489)    Buffer 5:
(MeshHostWorker pid=489)        Size: 128.00MiB
(MeshHostWorker pid=489)        Operator: op_name="layer_1$start"
(MeshHostWorker pid=489)        Entry Parameter Subshape: f16[16384,4096]
(MeshHostWorker pid=489)        ==========================
(MeshHostWorker pid=489)
(MeshHostWorker pid=489)    Buffer 6:
(MeshHostWorker pid=489)        Size: 128.00MiB
(MeshHostWorker pid=489)        Operator: op_name="layer_1$start"
(MeshHostWorker pid=489)        Entry Parameter Subshape: f16[4096,16384]
(MeshHostWorker pid=489)        ==========================
(MeshHostWorker pid=489)
(MeshHostWorker pid=489)    Buffer 7:
(MeshHostWorker pid=489)        Size: 128.00MiB
(MeshHostWorker pid=489)        Operator: op_name="layer_1$start"
(MeshHostWorker pid=489)        Entry Parameter Subshape: f16[16384,4096]
(MeshHostWorker pid=489)        ==========================
(MeshHostWorker pid=489)
(MeshHostWorker pid=489)    Buffer 8:
(MeshHostWorker pid=489)        Size: 128.00MiB
(MeshHostWorker pid=489)        Operator: op_name="layer_1$start"
(MeshHostWorker pid=489)        Entry Parameter Subshape: f16[4096,16384]
(MeshHostWorker pid=489)        ==========================
(MeshHostWorker pid=489)
(MeshHostWorker pid=489)    Buffer 9:
(MeshHostWorker pid=489)        Size: 128.00MiB
(MeshHostWorker pid=489)        Operator: op_name="layer_1$start"
(MeshHostWorker pid=489)        Entry Parameter Subshape: f16[16384,4096]
(MeshHostWorker pid=489)        ==========================
(MeshHostWorker pid=489)
(MeshHostWorker pid=489)    Buffer 10:
(MeshHostWorker pid=489)        Size: 128.00MiB
(MeshHostWorker pid=489)        Operator: op_name="layer_1$start"
(MeshHostWorker pid=489)        Entry Parameter Subshape: f16[4096,16384]
(MeshHostWorker pid=489)        ==========================
(MeshHostWorker pid=489)
(MeshHostWorker pid=489)    Buffer 11:
(MeshHostWorker pid=489)        Size: 128.00MiB
(MeshHostWorker pid=489)        Operator: op_name="layer_1$start"
(MeshHostWorker pid=489)        Entry Parameter Subshape: f16[16384,4096]
(MeshHostWorker pid=489)        ==========================
(MeshHostWorker pid=489)
(MeshHostWorker pid=489)    Buffer 12:
(MeshHostWorker pid=489)        Size: 128.00MiB
(MeshHostWorker pid=489)        Operator: op_name="layer_1$start"
(MeshHostWorker pid=489)        Entry Parameter Subshape: f16[4096,16384]
(MeshHostWorker pid=489)        ==========================
(MeshHostWorker pid=489)
(MeshHostWorker pid=489)    Buffer 13:
(MeshHostWorker pid=489)        Size: 128.00MiB
(MeshHostWorker pid=489)        Operator: op_name="layer_1$start"
(MeshHostWorker pid=489)        Entry Parameter Subshape: f16[16384,4096]
(MeshHostWorker pid=489)        ==========================
(MeshHostWorker pid=489)
(MeshHostWorker pid=489)    Buffer 14:
(MeshHostWorker pid=489)        Size: 128.00MiB
(MeshHostWorker pid=489)        Operator: op_name="layer_1$start"
(MeshHostWorker pid=489)        Entry Parameter Subshape: f16[4096,16384]
(MeshHostWorker pid=489)        ==========================
(MeshHostWorker pid=489)
(MeshHostWorker pid=489)    Buffer 15:
(MeshHostWorker pid=489)        Size: 128.00MiB
(MeshHostWorker pid=489)        Operator: op_name="layer_1$start"
(MeshHostWorker pid=489)        Entry Parameter Subshape: f16[16384,4096]
(MeshHostWorker pid=489)        ==========================
(MeshHostWorker pid=489)
(MeshHostWorker pid=489)
(MeshHostWorker pid=489) E1030 20:04:44.198079222     540 chttp2_transport.cc:1103]   Received a GOAWAY with error code ENHANCE_YOUR_CALM and debug data equal to "too_many_pings"
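The ~14GB figure quoted earlier in this comment follows from storing the weights in float16 (2 bytes per parameter). Below is a back-of-the-envelope sketch, assuming roughly 7.07e9 parameters for bloom-7b1 (a rounded figure) and an even split across GPUs. It counts weights only and ignores activations, the KV cache, and XLA temporary buffers, so it is a lower bound rather than a prediction of what each GPU actually needs.

```python
GB = 10**9
GiB = 2**30

# Bytes per element for common parameter dtypes.
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2}


def weight_bytes(num_params, dtype="float16"):
    """Memory needed just to hold the weights, in bytes."""
    return num_params * BYTES_PER_ELEMENT[dtype]


# Assumption: ~7.07e9 parameters for bloom-7b1 (rounded).
bloom_7b1_params = 7.07e9
fp16 = weight_bytes(bloom_7b1_params, "float16")
print(f"fp16 weights: {fp16 / GB:.1f} GB ({fp16 / GiB:.1f} GiB)")

# Even split across N GPUs, weights only (no activations or temp buffers).
for n_gpus in (2, 4):
    print(f"{n_gpus} GPUs: {fp16 / n_gpus / GiB:.2f} GiB of weights per GPU")
```

In practice the split is not perfectly even: the log above reports 7.78GiB of parameters on a single GPU, more than the even-split figure for 2 GPUs.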
merrymercy commented 2 years ago

The error message shows that this is not a bug: 2 x 2080 Ti (11GB) is simply not enough to run bloom-7b1 with Alpa, so you probably need more GPUs. Another thing you can try is replacing this line with dtype=jnp.float16 (https://github.com/alpa-projects/alpa/blob/93dbb447e624479f12bbe97590d4748cef53145e/examples/llm_serving/model/bloom_model.py#L499). This can reduce the memory usage of the largest peak buffer in your error message (Buffer 1, Size: 3.83GiB). Could you try this and report back with the error message?
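The numbers in the OOM report can be reproduced from the buffer shapes. Buffer 1 is the float32 output of `lm_head/convert_element_type`, with shape f32[250880,4096], while Buffer 2 holds the same shape in float16 and is therefore half the size. The sketch below redoes that arithmetic and adds up the BufferAssignment stats from the MeshHostWorker (pid=489) log. The "with f16 temp" total assumes the 3.83GiB temp buffer simply shrinks to its float16 size; that is an approximation, not what XLA is guaranteed to allocate after the dtype change.

```python
GiB = 2**30
MiB = 2**20


def buffer_bytes(shape, bytes_per_element):
    """Size in bytes of a dense buffer with the given shape and element width."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_element


vocab, hidden = 250880, 4096                 # shape reported for Buffers 1 and 2
f32_copy = buffer_bytes((vocab, hidden), 4)  # Buffer 1: f32[250880,4096]
f16_copy = buffer_bytes((vocab, hidden), 2)  # Buffer 2: f16[250880,4096]
print(f"f32: {f32_copy / GiB:.2f} GiB, f16: {f16_copy / GiB:.2f} GiB")

# BufferAssignment stats copied from the log (one GPU).
params = 7.78 * GiB
live_out = 301.25 * MiB
total_now = params + live_out + f32_copy
# Assumption: the temp buffer shrinks to f16 size after the dtype change.
total_f16 = params + live_out + f16_copy
print(f"now: {total_now / GiB:.2f} GiB, with f16 temp: {total_f16 / GiB:.2f} GiB")
```

The first total reproduces the 11.90GiB reported in the log. Under the stated assumption the peak drops to roughly 10GiB, which still leaves little headroom on an 11GB card.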

zhanyuanucb commented 2 years ago

@merrymercy I increased the number of GPUs to 4 x 2080Ti (11GB) and it works! Changing the data type to jnp.float16 also helped.

Out of curiosity, I also tried 6 and 8 GPUs and got different errors.

For 6 GPUs:

2022-10-31 15:45:36,040    INFO worker.py:1223 -- Using address 10.1.181.255:6379 set in the environment variable RAY_ADDRESS
2022-10-31 15:45:36,040    INFO worker.py:1333 -- Connecting to existing Ray cluster at address: 10.1.181.255:6379...
2022-10-31 15:45:36,047    INFO worker.py:1509 -- Connected to Ray cluster. View the dashboard at http://10.1.181.255:8265
(Controller pid=244) INFO:     Started server process [244]
(Controller pid=244) INFO:uvicorn.error:Started server process [244]
(DeviceMeshGroupManager pid=276) 2022-10-31 15:45:40 | INFO | stdout | Load model alpa/bloom-7b1 ... (This can take several minutes for very large models)
(DeviceMeshGroupManager pid=276) 2022-10-31 15:45:40 | INFO | stdout |  - Compile executables for encoder_chunk_sizes=[1, 64].
(DeviceMeshGroupManager pid=276) 2022-10-31 15:45:47,265 WARNING worker.py:2249 -- Using blocking ray.get inside async actor. This blocks the event loop. Please use `await` on object ref with asyncio.gather if you want to yield execution to the event loop instead.
Traceback (most recent call last):
  File "start.py", line 49, in <module>
    ray.get(t)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/ray/_private/client_mode_hook.py", line 105, in wrapper
    return func(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/ray/_private/worker.py", line 2275, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(NotImplementedError): ray::Controller.create_replica() (pid=244, ip=10.1.181.255, repr=<alpa.serve.controller.Controller object at 0x7fce5638b8e0>)
  File "/home/ray/anaconda3/lib/python3.8/concurrent/futures/_base.py", line 432, in result
    return self.__get_result()
  File "/home/ray/anaconda3/lib/python3.8/concurrent/futures/_base.py", line 388, in __get_result
    raise self._exception
  File "/home/ray/anaconda3/lib/python3.8/site-packages/alpa/serve/controller.py", line 164, in create_replica
    await manager.create_replica.remote(name, create_info)
ray.exceptions.RayTaskError(NotImplementedError): ray::DeviceMeshGroupManager.create_replica() (pid=276, ip=10.1.181.255, repr=<alpa.serve.controller.DeviceMeshGroupManager object at 0x7f7140914280>)
  File "/home/ray/anaconda3/lib/python3.8/concurrent/futures/_base.py", line 432, in result
    return self.__get_result()
  File "/home/ray/anaconda3/lib/python3.8/concurrent/futures/_base.py", line 388, in __get_result
    raise self._exception
  File "/home/ray/anaconda3/lib/python3.8/site-packages/alpa/serve/controller.py", line 78, in create_replica
    self.replicas[name] = model_def(*args, **kwargs)
  File "/home/ray/src/llm-serving/examples/llm_serving/launch_model_worker.py", line 65, in __init__
    self.generator = Generator(model_name,
  File "/home/ray/src/llm-serving/examples/llm_serving/generator.py", line 55, in __init__
    self.load_model()
  File "/home/ray/src/llm-serving/examples/llm_serving/generator.py", line 62, in load_model
    self.model_wrapper = get_model(self.model_name, self.path,
  File "/home/ray/src/llm-serving/examples/llm_serving/model/wrapper.py", line 643, in get_model
    return get_alpa_model(
  File "/home/ray/src/llm-serving/examples/llm_serving/model/wrapper.py", line 464, in get_alpa_model
    executables, params_aval = m.get_pipeshard_executable(
  File "/home/ray/src/llm-serving/examples/llm_serving/model/bloom_model.py", line 728, in get_pipeshard_executable
    executable = alpa.parallelize(
  File "/home/ray/anaconda3/lib/python3.8/site-packages/alpa/api.py", line 127, in get_executable
    executable, _, _, _ = self._decode_args_and_get_executable(*args)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/alpa/api.py", line 191, in _decode_args_and_get_executable
    executable = _compile_parallel_executable(f, in_tree, out_tree_hashable,
  File "/home/ray/anaconda3/lib/python3.8/site-packages/jax/linear_util.py", line 295, in memoized_fun
    ans = call(fun, *args)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/alpa/api.py", line 218, in _compile_parallel_executable
    return method.compile_executable(fun, in_tree, out_tree_thunk,
  File "/home/ray/anaconda3/lib/python3.8/site-packages/alpa/parallel_method.py", line 233, in compile_executable
    return compile_pipeshard_executable(
  File "/home/ray/anaconda3/lib/python3.8/site-packages/alpa/pipeline_parallel/compile_executable.py", line 92, in compile_pipeshard_executable
    pipeshard_config = compile_pipeshard_executable_internal(
  File "/home/ray/anaconda3/lib/python3.8/site-packages/alpa/pipeline_parallel/compile_executable.py", line 253, in compile_pipeshard_executable_internal
    pipeshard_config = PipelineInstEmitter(
  File "/home/ray/anaconda3/lib/python3.8/site-packages/alpa/pipeline_parallel/runtime_emitter.py", line 421, in compile
    self._compile_exec_one_tick(sched, donation_mapping,
  File "/home/ray/anaconda3/lib/python3.8/site-packages/alpa/pipeline_parallel/runtime_emitter.py", line 535, in _compile_exec_one_tick
    raise NotImplementedError(
NotImplementedError: Not support resharding replicated

And with 8 GPUs, I got:

2022-10-31 15:32:01,993    INFO worker.py:1223 -- Using address 10.1.181.253:6379 set in the environment variable RAY_ADDRESS
2022-10-31 15:32:01,993    INFO worker.py:1333 -- Connecting to existing Ray cluster at address: 10.1.181.253:6379...
2022-10-31 15:32:01,998    INFO worker.py:1509 -- Connected to Ray cluster. View the dashboard at http://10.1.181.253:8265
(Controller pid=83, ip=10.1.15.137) INFO:     Started server process [83]
(Controller pid=83, ip=10.1.15.137) INFO:uvicorn.error:Started server process [83]
(DeviceMeshGroupManager pid=112, ip=10.1.15.137) 2022-10-31 15:32:06 | INFO | stdout | Load model alpa/bloom-7b1 ... (This can take several minutes for very large models)
(DeviceMeshGroupManager pid=112, ip=10.1.15.137) 2022-10-31 15:32:06 | INFO | stdout |  - Compile executables for encoder_chunk_sizes=[1, 64].
Traceback (most recent call last):
  File "start.py", line 49, in <module>
    ray.get(t)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/ray/_private/client_mode_hook.py", line 105, in wrapper
    return func(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/ray/_private/worker.py", line 2275, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AssertionError): ray::Controller.create_replica() (pid=83, ip=10.1.15.137, repr=<alpa.serve.controller.Controller object at 0x7fa760842af0>)
  File "/home/ray/anaconda3/lib/python3.8/concurrent/futures/_base.py", line 432, in result
    return self.__get_result()
  File "/home/ray/anaconda3/lib/python3.8/concurrent/futures/_base.py", line 388, in __get_result
    raise self._exception
  File "/home/ray/anaconda3/lib/python3.8/site-packages/alpa/serve/controller.py", line 164, in create_replica
    await manager.create_replica.remote(name, create_info)
ray.exceptions.RayTaskError(AssertionError): ray::DeviceMeshGroupManager.create_replica() (pid=112, ip=10.1.15.137, repr=<alpa.serve.controller.DeviceMeshGroupManager object at 0x7f79cab0f430>)
  File "/home/ray/anaconda3/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/home/ray/anaconda3/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/ray/anaconda3/lib/python3.8/asyncio/base_events.py", line 570, in run_forever
    self._run_once()
  File "/home/ray/anaconda3/lib/python3.8/asyncio/base_events.py", line 1859, in _run_once
    handle._run()
  File "/home/ray/anaconda3/lib/python3.8/asyncio/events.py", line 81, in _run
    self._context.run(self._callback, *self._args)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/alpa/serve/controller.py", line 78, in create_replica
    self.replicas[name] = model_def(*args, **kwargs)
  File "/home/ray/src/llm-serving/examples/llm_serving/launch_model_worker.py", line 65, in __init__
    self.generator = Generator(model_name,
  File "/home/ray/src/llm-serving/examples/llm_serving/generator.py", line 55, in __init__
    self.load_model()
  File "/home/ray/src/llm-serving/examples/llm_serving/generator.py", line 62, in load_model
    self.model_wrapper = get_model(self.model_name, self.path,
  File "/home/ray/src/llm-serving/examples/llm_serving/model/wrapper.py", line 643, in get_model
    return get_alpa_model(
  File "/home/ray/src/llm-serving/examples/llm_serving/model/wrapper.py", line 464, in get_alpa_model
    executables, params_aval = m.get_pipeshard_executable(
  File "/home/ray/src/llm-serving/examples/llm_serving/model/bloom_model.py", line 695, in get_pipeshard_executable
    model, params = init_model_aval(config)
  File "/home/ray/src/llm-serving/examples/llm_serving/model/bloom_model.py", line 584, in init_model_aval
    params = jax.eval_shape(model.init, rngkey, input_ids, attention_mask=attention_mask)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/flax/linen/module.py", line 1227, in init
    _, v_out = self.init_with_output(
  File "/home/ray/anaconda3/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/flax/linen/module.py", line 1194, in init_with_output
    return self.apply(
  File "/home/ray/anaconda3/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/flax/linen/module.py", line 1159, in apply
    return apply(
  File "/home/ray/anaconda3/lib/python3.8/site-packages/flax/core/scope.py", line 831, in wrapper
    y = fn(root, *args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/flax/linen/module.py", line 1535, in scope_fn
    return fn(module.clone(parent=scope), *args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/flax/linen/transforms.py", line 1235, in wrapped_fn
    return jax.named_call(class_fn, name=full_name)(self, *args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/home/ray/src/llm-serving/examples/llm_serving/model/bloom_model.py", line 513, in __call__
    outputs = self.transformer(
  File "/home/ray/anaconda3/lib/python3.8/site-packages/flax/linen/transforms.py", line 1235, in wrapped_fn
    return jax.named_call(class_fn, name=full_name)(self, *args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/home/ray/src/llm-serving/examples/llm_serving/model/bloom_model.py", line 463, in __call__
    outputs = self.h(
  File "/home/ray/anaconda3/lib/python3.8/site-packages/flax/linen/transforms.py", line 1235, in wrapped_fn
    return jax.named_call(class_fn, name=full_name)(self, *args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/home/ray/src/llm-serving/examples/llm_serving/model/bloom_model.py", line 378, in __call__
    assert self.config.num_hidden_layers % self.config.num_pp_stages == 0
jax._src.traceback_util.UnfilteredStackTrace: AssertionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

ray::DeviceMeshGroupManager.create_replica() (pid=112, ip=10.1.15.137, repr=<alpa.serve.controller.DeviceMeshGroupManager object at 0x7f79cab0f430>)
  File "/home/ray/anaconda3/lib/python3.8/concurrent/futures/_base.py", line 432, in result
    return self.__get_result()
  File "/home/ray/anaconda3/lib/python3.8/concurrent/futures/_base.py", line 388, in __get_result
    raise self._exception
  File "/home/ray/anaconda3/lib/python3.8/site-packages/alpa/serve/controller.py", line 78, in create_replica
    self.replicas[name] = model_def(*args, **kwargs)
  File "/home/ray/src/llm-serving/examples/llm_serving/launch_model_worker.py", line 65, in __init__
    self.generator = Generator(model_name,
  File "/home/ray/src/llm-serving/examples/llm_serving/generator.py", line 55, in __init__
    self.load_model()
  File "/home/ray/src/llm-serving/examples/llm_serving/generator.py", line 62, in load_model
    self.model_wrapper = get_model(self.model_name, self.path,
  File "/home/ray/src/llm-serving/examples/llm_serving/model/wrapper.py", line 643, in get_model
    return get_alpa_model(
  File "/home/ray/src/llm-serving/examples/llm_serving/model/wrapper.py", line 464, in get_alpa_model
    executables, params_aval = m.get_pipeshard_executable(
  File "/home/ray/src/llm-serving/examples/llm_serving/model/bloom_model.py", line 695, in get_pipeshard_executable
    model, params = init_model_aval(config)
  File "/home/ray/src/llm-serving/examples/llm_serving/model/bloom_model.py", line 584, in init_model_aval
    params = jax.eval_shape(model.init, rngkey, input_ids, attention_mask=attention_mask)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/jax/_src/api.py", line 3105, in eval_shape
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
  File "/home/ray/anaconda3/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 693, in abstract_eval_fun
    _, avals_out, _ = trace_to_jaxpr_dynamic(
  File "/home/ray/anaconda3/lib/python3.8/site-packages/jax/_src/profiler.py", line 294, in wrapper
    return func(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2074, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/ray/anaconda3/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2089, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/ray/anaconda3/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/ray/anaconda3/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/ray/src/llm-serving/examples/llm_serving/model/bloom_model.py", line 513, in __call__
    outputs = self.transformer(
  File "/home/ray/anaconda3/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/ray/src/llm-serving/examples/llm_serving/model/bloom_model.py", line 463, in __call__
    outputs = self.h(
  File "/home/ray/anaconda3/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/ray/src/llm-serving/examples/llm_serving/model/bloom_model.py", line 378, in __call__
    assert self.config.num_hidden_layers % self.config.num_pp_stages == 0
AssertionError

merrymercy commented 2 years ago

Could you tell me the organization of the GPUs and the output of ray status? How many nodes? How many GPUs per node?

zhisbug commented 2 years ago

Is it caused by assert self.config.num_hidden_layers % self.config.num_pp_stages == 0? It seems your number of layers is not divisible by the number of stages.

zhanyuanucb commented 2 years ago

@zhisbug Thanks, I now understand the assertion error.

@merrymercy As for the NotImplementedError, the cluster has 1 head node and 2 worker nodes, with 2 GPUs per node. Here is the output of ray status:

(base) ray@component-alpa-service-raycluster-5w6td-head-54d6b:~$ ray status
======== Autoscaler status: 2022-10-31 21:06:48.450777 ========
Node status
---------------------------------------------------------------
Healthy:
 1 node_5169ad129c6f1e6288afa04cb5e6029e792c9d3fa5761d6c7cc6b308
 1 node_aae9ddf2a31875e0962f6acc733ef4bfadeff1fdb2e634869b7f7442
 1 node_e9c371cf9e6e3598ccc271f38e3ff1ee6342680673e110a4d2c2ad75
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/12.0 CPU
 0.0/6.0 GPU
 0.0/3.0 accelerator_type:G
 0.00/48.000 GiB memory
 0.00/9.638 GiB object_store_memory

Demands:
 (no resource demands)
zhisbug commented 2 years ago

I believe that once you fix the stage-indivisibility error, all the other errors will disappear.

zhanyuanucb commented 2 years ago

@zhisbug Is the NotImplementedError I encountered with 6 GPUs also related to the indivisibility issue? I checked that bloom-7b1 has 30 hidden layers (see https://github.com/alpa-projects/alpa/blob/98df634fdf97c82f016195f74a4d4965420a7d17/examples/llm_serving/model/bloom_model.py#L557-L561), and the number of stages is inferred by https://github.com/alpa-projects/alpa/blob/98df634fdf97c82f016195f74a4d4965420a7d17/examples/llm_serving/model/wrapper.py#L446-L449

So if I have 3 nodes (2 GPUs on each node), the number of stages is 3, which divides the number of hidden layers (30) evenly.
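The check behind the assertion can be sketched in a few lines. This is a hypothetical illustration, not the code in bloom_model.py or wrapper.py. It assumes, as in the comment above, that one pipeline stage is created per node, so num_pp_stages equals the number of nodes. The helper name stages_divide_layers is made up for this example.

```python
def stages_divide_layers(num_hidden_layers: int, num_nodes: int) -> bool:
    """Hypothetical helper mirroring the assertion
    `num_hidden_layers % num_pp_stages == 0`.

    Assumes one pipeline stage per node (num_pp_stages == num_nodes);
    this is an illustrative assumption, not Alpa's actual rule.
    """
    num_pp_stages = num_nodes
    return num_hidden_layers % num_pp_stages == 0


# bloom-7b1 has 30 hidden layers; the cluster above has 3 nodes.
print(stages_divide_layers(30, 3))  # True: 3 stages of 10 layers each
print(stages_divide_layers(30, 4))  # False: 30 % 4 == 2
```

Under that assumption the 3-node case passes the check. That is consistent with the 6-GPU run failing with a NotImplementedError rather than the AssertionError.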

ddxxdd-code commented 2 years ago

Hi, I migrated the Bloom model to be served on Alpa. From the backtrace of the failed assertion in the 6-GPU case, I suspect it might be related to the model's configuration (layers, attention heads, shape of weights). Does it happen for other Bloom models, say bloom-1b7 (which has 24 layers) or bloom-3b (which has 30 layers)? I'll try to reproduce this error and look into it this week or next.

zhanyuanucb commented 2 years ago

@ddxxdd-code So far I've only seen this NotImplementedError on bloom-7b1

zhisbug commented 1 year ago

@ddxxdd-code any insight on this?

ddxxdd-code commented 1 year ago

Not much yet. I'm still working on installing Alpa and running it on the cluster.

zhanyuanucb commented 1 year ago

@ddxxdd-code I included the YAML for k8s cluster in the Additional Information section. Let me know if you need any help.

merrymercy commented 1 year ago

Closed due to inactivity.

pascalwhoop commented 1 year ago

I have the same error on a cluster with 4x4xA10G plus an 8xV100 head node.

ddxxdd-code commented 1 year ago

@pascalwhoop Could you confirm whether the error you found is the OOM or the NotImplementedError?

pascalwhoop commented 1 year ago

@ddxxdd-code Actually, it's the AssertionError:

assert self.config.num_hidden_layers % self.config.num_pp_stages == 0

and

jax._src.traceback_util.UnfilteredStackTrace: AssertionError
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

Note that I was able to run bloom 7b and opt 6.7b, but not opt 30b or the full Bloom.

I believe the conversion to the Alpa data type completed successfully, though: RAM peaked at around 300 GB and the progress bar finished.

ddxxdd-code commented 1 year ago

@pascalwhoop Thank you for sharing the errors! The full Bloom model has 70 hidden layers, so it's possible that your setup yields a num_pp_stages that doesn't divide 70 and hence fails the assertion. Bloom-7b1 has 30 hidden layers, so num_pp_stages in your setup may divide 30 but not 70. Similarly, the failure for opt-30b (48 layers) likely results from the same indivisibility. Since num_pp_stages divides 30 but neither 48 nor 70, I'd guess num_pp_stages is 15 (30 also fits). Please print the value of num_pp_stages to check whether indivisibility is indeed the cause.
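The guess above can be checked mechanically. This small sketch lists every stage count that divides the layer count of the model that ran (bloom-7b1, 30 layers) but none of the layer counts of the models that failed (opt-30b with 48 layers, full Bloom with 70). The layer counts are the ones used in this thread. The function name is made up for illustration.

```python
# Layer counts used in this thread: bloom-7b1 ran, while
# opt-30b and full Bloom failed the divisibility assertion.
WORKING = [30]        # bloom-7b1
FAILING = [48, 70]    # opt-30b, full Bloom


def candidate_stage_counts(working, failing, max_stages=70):
    """Hypothetical helper: stage counts that divide every working
    model's layer count but none of the failing ones."""
    return [
        s for s in range(1, max_stages + 1)
        if all(n % s == 0 for n in working)
        and all(n % s != 0 for n in failing)
    ]


print(candidate_stage_counts(WORKING, FAILING))  # [15, 30]
```

Both 15 and 30 are consistent with the observations, so printing the actual num_pp_stages is still needed to confirm.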

pascalwhoop commented 1 year ago

@ddxxdd-code I turned off the cluster for the weekend to save resources, so I'll check once I boot it up on Monday. In the meantime, could you help me understand (a docs reference is fine) what num_pp_stages is and why it depends on the infrastructure?

For context, I'm just getting into this world. Coming from Spark, I'm used to treating cluster resources like cattle. I understood from some presentations that data parallelism is relatively "easy", but I'd still like to understand what I'm missing.

ddxxdd-code commented 1 year ago

Hi @pascalwhoop, no worries!

num_pp_stages is the number of pipeline stages (pp stands for pipeline).

At execution time, the hidden layers are divided among the stages of the pipeline (e.g., layers 1 to n go to stage 1, layers n+1 to m to stage 2, etc.), and the model (inference or training) is parallelized through the pipeline. To make the pipeline efficient, the resources given to Alpa need to be divided to match the stages. All resources (basically all the GPUs) are managed by Ray, which is organized by node: a Ray cluster consists of a head node and worker nodes, each with a certain number of GPUs.
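The layer-to-stage assignment described above can be sketched as follows. This is a simplified illustration, not Alpa's implementation. It assumes each stage receives an equal, contiguous block of layers, and it raises an error when the split is uneven, analogous to the assertion in bloom_model.py. The name partition_layers is hypothetical, and layer indices are 0-based here.

```python
from typing import List


def partition_layers(num_layers: int, num_stages: int) -> List[range]:
    """Hypothetical helper: assign equal, contiguous blocks of layer
    indices (0-based) to pipeline stages.

    Raises ValueError when the layers cannot be split evenly, analogous
    to `assert num_hidden_layers % num_pp_stages == 0`.
    """
    if num_layers % num_stages != 0:
        raise ValueError(
            f"{num_layers} layers cannot be split evenly into {num_stages} stages")
    per_stage = num_layers // num_stages
    return [range(i * per_stage, (i + 1) * per_stage) for i in range(num_stages)]


# bloom-7b1 (30 layers) on 3 stages.
for stage, layers in enumerate(partition_layers(30, 3)):
    print(f"stage {stage}: layers {layers.start}-{layers.stop - 1}")
# stage 0: layers 0-9
# stage 1: layers 10-19
# stage 2: layers 20-29
```

With 70 layers and 15 stages, the same call raises ValueError, which mirrors the failure mode discussed for the full Bloom model.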

In language model serving (OPT/Bloom), the number of pipeline stages is determined by the code below: https://github.com/alpa-projects/alpa/blob/98df634fdf97c82f016195f74a4d4965420a7d17/examples/llm_serving/model/wrapper.py#L446-L449

I'm not a Ray expert, so I'm not familiar with how Ray computes these values (devices in a mesh, devices per node, etc.). For the Alpa side, please refer to this file in Alpa for the definition of get_global_cluster() and how its attributes are calculated. The file is long, but num_hosts and num_devices are the attributes to focus on, and please note the assumptions made there, such as all nodes being identical. For concepts such as nodes and workers, please check the Ray documentation.
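To illustrate why the "all nodes are identical" assumption matters, here is a hypothetical sketch that summarizes a cluster as (number of hosts, GPUs per host) and refuses heterogeneous clusters. The helper name and node names are made up. Alpa's real get_global_cluster() may compute and validate these values differently.

```python
from typing import Dict, Tuple


def summarize_cluster(gpus_per_node: Dict[str, int]) -> Tuple[int, int]:
    """Hypothetical helper: return (num_hosts, num_devices_per_host).

    Requires every node to have the same GPU count, mirroring the
    identical-nodes assumption; this is not Alpa's actual code.
    """
    if not gpus_per_node:
        raise ValueError("cluster has no nodes")
    counts = set(gpus_per_node.values())
    if len(counts) != 1:
        raise ValueError(f"nodes are not identical: GPU counts {sorted(counts)}")
    return len(gpus_per_node), counts.pop()


# 1 head + 2 workers with 2 GPUs each, as in the 6-GPU cluster above.
print(summarize_cluster({"head": 2, "worker-1": 2, "worker-2": 2}))  # (3, 2)
```

A cluster mixing an 8-GPU head node with 4-GPU workers would raise ValueError here, so any per-host calculation built on a uniform-node assumption would not describe it correctly.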

If you would like to see some examples of how to parallelize with Alpa, please check this section on parallel training in the Alpa documentation.

Please feel free to ask me about any issue with serving the Bloom model on Alpa. For other Alpa issues, I think Hao (zhisbug), Lianmin (merrymercy), and Zhuohan (zhuohan123) will have more insight as creators of the project.

Happy weekend~

pascalwhoop commented 1 year ago

Hey @ddxxdd-code, I can share more about trying to get the bloom(z) model running.

I have:

And I get the same NotImplementedError

2023-03-08 00:32:52,815 INFO worker.py:1230 -- Using address 127.0.0.1:6379 set in the environment variable RAY_ADDRESS
2023-03-08 00:32:52,816 INFO worker.py:1342 -- Connecting to existing Ray cluster at address: 172.29.4.126:6379...
2023-03-08 00:32:52,827 INFO worker.py:1519 -- Connected to Ray cluster. View the dashboard at 172.29.4.126:8265
Load model alpa/bloomz ... (This can take several minutes for very large models)
 - Compile executables for encoder_chunk_sizes=[1, 64].
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ray/anaconda3/lib/python3.9/site-packages/llm_serving/model/wrapper.py", line 544, in get_model
    return get_alpa_model(
  File "/home/ray/anaconda3/lib/python3.9/site-packages/llm_serving/model/wrapper.py", line 359, in get_alpa_model
    executables, params_aval = m.get_pipeshard_executable(
  File "/home/ray/anaconda3/lib/python3.9/site-packages/llm_serving/model/bloom_model.py", line 728, in get_pipeshard_executable
    executable = alpa.parallelize(
  File "/home/ray/anaconda3/lib/python3.9/site-packages/alpa/api.py", line 127, in get_executable
    executable, _, _, _ = self._decode_args_and_get_executable(*args)
  File "/home/ray/anaconda3/lib/python3.9/site-packages/alpa/api.py", line 191, in _decode_args_and_get_executable
    executable = _compile_parallel_executable(f, in_tree, out_tree_hashable,
  File "/home/ray/anaconda3/lib/python3.9/site-packages/jax/linear_util.py", line 309, in memoized_fun
    ans = call(fun, *args)
  File "/home/ray/anaconda3/lib/python3.9/site-packages/alpa/api.py", line 218, in _compile_parallel_executable
    return method.compile_executable(fun, in_tree, out_tree_thunk,
  File "/home/ray/anaconda3/lib/python3.9/site-packages/alpa/parallel_method.py", line 233, in compile_executable
    return compile_pipeshard_executable(
  File "/home/ray/anaconda3/lib/python3.9/site-packages/alpa/pipeline_parallel/compile_executable.py", line 93, in compile_pipeshard_executable
    pipeshard_config = compile_pipeshard_executable_internal(
  File "/home/ray/anaconda3/lib/python3.9/site-packages/alpa/pipeline_parallel/compile_executable.py", line 254, in compile_pipeshard_executable_internal
    pipeshard_config = PipelineInstEmitter(
  File "/home/ray/anaconda3/lib/python3.9/site-packages/alpa/pipeline_parallel/runtime_emitter.py", line 421, in compile
    self._compile_exec_one_tick(sched, donation_mapping,
  File "/home/ray/anaconda3/lib/python3.9/site-packages/alpa/pipeline_parallel/runtime_emitter.py", line 535, in _compile_exec_one_tick
    raise NotImplementedError(
NotImplementedError: Not support resharding replicated
(base) ray@test-cluster-head-rk8pj:~$ ray status
======== Autoscaler status: 2023-03-08 00:38:44.316272 ========
Node status
---------------------------------------------------------------
Healthy:
 1 head-group
 4 size-l-group
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/360.0 CPU (0.0 used of 5.0 reserved in placement groups)
 35.0/35.0 GPU (35.0 used of 35.0 reserved in placement groups)
 0.0/5.0 accelerator_type:A10G
 0.00/1440.000 GiB memory
 0.00/396.161 GiB object_store_memory

Demands:
 (no resource demands)

I'd be happy to work together on making this work and helping commoditise bloom models and potentially other future models that can allow groups to run these very desirable models on their "commodity" hardware. I know we're not there yet but this all feels very much like the early days of spark which helped tremendously in democratising big data.

ddxxdd-code commented 1 year ago

Hi @pascalwhoop, thanks for the information! As shown in the backtrace, the error occurs during compilation in alpa/pipeline_parallel/compile_executable.py. The same thing happened when zhanyuanucb first reported this error, when running a 30-layer model (bloom-7b1) on 6 GPUs. One possibility is that odd numbers (integers congruent to 1 mod 2) cause a problem somewhere in this process: 30/6 = 5, and 35 is itself odd. But I'm not sure that is the problem. Since I don't have sufficient computational resources to run the full model, would you consider raising this problem in a new issue? This issue's title, "OOM", is somewhat misleading, and the issue was closed last year, so I think a new one is more appropriate. Please mention me in the new issue so we can discuss it further and find the source of the problem. Thanks!

ddxxdd-code commented 1 year ago

One possible next step that might help here is to look into how Alpa plans the parallelization in this setup. As mentioned in issue #891, I think "Inspect the parallelization strategy" will be useful for debugging, since it shows how Alpa planned the parallelization and where the assertion failure comes from.

ddxxdd-code commented 1 year ago

Hi @pascalwhoop, I looked into this assertion failure yesterday. I found that this assertion is commented out in the latest release of Alpa, with a "TODO" on it as a workaround. I suppose it might work even in the resharding case? Please try updating to release 0.2.3 (released this week) and check whether the error persists.