Closed: shenzhuo closed this issue 11 months ago.
Hi @shenzhuo,
Thank you for reporting this issue and providing details for reproducing it!
I've created PR https://github.com/microsoft/DeepSpeed/pull/3580 in the DeepSpeed repo, updating the BLOOM container to inherit the HybridEngineContainer feature and adding a corresponding set_lora_params() function.
I've been able to test on my end and see the BLOOM container working now.
Could you please test on your end as well?
Thanks, Lev
Hi @lekurile,
Thanks for the fix. I tried it out.
First, the error is:
Traceback (most recent call last):
File "DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 562, in <module>
main()
File "DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 471, in main
out = trainer.generate_experience(prompts)
File "DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 97, in generate_experience
seq = self._generate_sequence(prompts)
File "DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 73, in _generate_sequence
seq = self.actor_model.module.generate(prompts,
File "/dcv/lib/python3.9/site-packages/deepspeed/runtime/hybrid_engine.py", line 245, in generate
generate_ret_vals = self._generate(*inputs, **kwargs)
File "/dcv/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/dcv/lib/python3.9/site-packages/transformers/generation/utils.py", line 1437, in generate
return self.greedy_search(
File "/dcv/lib/python3.9/site-packages/transformers/generation/utils.py", line 2248, in greedy_search
outputs = self(
File "/dcv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1208, in _call_impl
result = forward_call(*input, **kwargs)
File "/dcv/lib/python3.9/site-packages/transformers/models/bloom/modeling_bloom.py", line 913, in forward
transformer_outputs = self.transformer(
File "/dcv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1208, in _call_impl
result = forward_call(*input, **kwargs)
File "/dcv/lib/python3.9/site-packages/transformers/models/bloom/modeling_bloom.py", line 786, in forward
outputs = block(
File "/dcv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1208, in _call_impl
result = forward_call(*input, **kwargs)
File "/dcv/lib/python3.9/site-packages/deepspeed/model_implementations/transformers/ds_transformer.py", line 147, in forward
self.attention(input,
File "/dcv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "/dcv/lib/python3.9/site-packages/deepspeed/ops/transformer/inference/ds_attention.py", line 160, in forward
context_layer, key_layer, value_layer = self.compute_attention(qkv_out=qkv_out,
File "/dcv/lib/python3.9/site-packages/deepspeed/ops/transformer/inference/ds_attention.py", line 253, in compute_attention
attn_mask=((1 - input_mask).half() * minus_inf),
File "/dcv/lib/python3.9/site-packages/torch/_tensor.py", line 39, in wrapped
return f(*args, **kwargs)
File "/dcv/lib/python3.9/site-packages/torch/_tensor.py", line 833, in __rsub__
return _C._VariableFunctions.rsub(self, other)
RuntimeError: Subtraction, the `-` operator, with a bool tensor is not supported. If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.
More details can be seen here: issue
So I changed the DeepSpeed source, specifically DeepSpeed/deepspeed/ops/transformer/inference/ds_attention.py. The change is:
As a result, the bug above was resolved.
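The exact patch isn't shown above, but the error message itself points at the likely fix: invert the boolean mask with `~` (or cast it to a numeric dtype) before the subtraction. A minimal standalone sketch of the idea, using names that mirror `ds_attention.py` (`input_mask`, `minus_inf`); this is an illustration, not the actual DeepSpeed patch:

```python
import torch

minus_inf = -10000.0
# A boolean attention mask: True marks visible tokens.
input_mask = torch.tensor([[True, True, False]])

# The original expression fails on a bool tensor:
#   attn_mask = ((1 - input_mask).half() * minus_inf)
#   RuntimeError: Subtraction, the `-` operator, with a bool tensor ...

# Inverting with `~` first (as the error message suggests), then casting
# to the compute dtype, produces the intended additive mask:
attn_mask = (~input_mask).half() * minus_inf
print(attn_mask)  # masked position becomes -10000.0, visible positions 0.0
```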
Second, there is another error:
Traceback (most recent call last):
File "DeepSpeedRLHF/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 678, in <module>
main()
File "DeepSpeedRLHF/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 502, in main
out = trainer.generate_experience(batch_prompt['prompt'],
File "DeepSpeedRLHF/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 107, in generate_experience
seq = self._generate_sequence(prompts, mask)
File "DeepSpeedRLHF/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 81, in _generate_sequence
seq = self.actor_model.module.generate(input_ids=prompts,
File "/dcv/lib/python3.9/site-packages/deepspeed/runtime/hybrid_engine.py", line 266, in generate
generate_ret_vals = self._generate(*inputs, **kwargs)
File "/dcv/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/dcv/lib/python3.9/site-packages/transformers/generation/utils.py", line 1607, in generate
return self.beam_search(
File "/dcv/lib/python3.9/site-packages/transformers/generation/utils.py", line 2905, in beam_search
outputs = self(
File "/dcv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/dcv/lib/python3.9/site-packages/transformers/models/bloom/modeling_bloom.py", line 913, in forward
transformer_outputs = self.transformer(
File "/dcv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/dcv/lib/python3.9/site-packages/transformers/models/bloom/modeling_bloom.py", line 786, in forward
outputs = block(
File "/dcv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/dcv/lib/python3.9/site-packages/deepspeed/model_implementations/transformers/ds_transformer.py", line 157, in forward
self.attention(input,
File "/dcv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/dcv/lib/python3.9/site-packages/deepspeed/ops/transformer/inference/ds_attention.py", line 158, in forward
context_layer, key_layer, value_layer = self.compute_attention(qkv_out=qkv_out,
File "/dcv/lib/python3.9/site-packages/deepspeed/ops/transformer/inference/ds_attention.py", line 247, in compute_attention
matmul_result = torch.matmul(query_layer, key_layer)
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k, (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, b, CUDA_R_16BF, (int)ldb, strideb, (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec, (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`
I can't understand this error; it's extremely weird.
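A cuBLAS error like this is often opaque for two reasons: CUDA kernel launches are asynchronous, so the failure can be reported at an unrelated call site, and a GEMM can also fail on a dtype or shape mismatch between its operands. Two common first debugging steps (a general sketch, not a fix specific to this issue; the helper name `check_batched_matmul` is hypothetical):

```python
import os
import torch

# 1) CUDA launches are asynchronous; forcing synchronous execution makes the
#    reported stack trace point at the kernel that actually failed.
#    (Must be set before CUDA is initialized.)
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

def check_batched_matmul(query_layer: torch.Tensor, key_layer: torch.Tensor) -> None:
    """Sanity-check the operands of torch.matmul(query_layer, key_layer).

    A dtype mismatch or incompatible inner dimensions can surface as an
    opaque CUBLAS_STATUS_EXECUTION_FAILED instead of a clear Python error.
    """
    assert query_layer.dtype == key_layer.dtype, (
        f"dtype mismatch: {query_layer.dtype} vs {key_layer.dtype}")
    assert query_layer.shape[-1] == key_layer.shape[-2], (
        f"inner dimensions differ: {query_layer.shape} vs {key_layer.shape}")

# Example with bf16 operands, matching the CUDA_R_16BF type in the error above.
# Shapes follow the (batch*heads, seq, head_dim) layout used for attention scores.
q = torch.randn(4, 8, 16, dtype=torch.bfloat16)
k = torch.randn(4, 16, 8, dtype=torch.bfloat16)
check_batched_matmul(q, k)
scores = torch.matmul(q, k)  # shape (4, 8, 8)
```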
same error
Hi @shenzhuo, @SabrinaZhuangxx,
After adding the changes in the following PRs:
I was able to get bigscience/bloomz-1b7 to train in DeepSpeed-Chat step 3. However, the critic model must first be trained through step 2.
The command I used to run this looks as follows:
DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning$ bash training_scripts/bloom/single_node/run_bloom.sh bigscience/bloomz-1b7 ../step2_reward_model_finetuning/bloom_7b_output/ 2 2 output_bloom7b_actor_hf_critic_step2
Can you please try running again with the latest changes? Instead of using bigscience/bloomz-1b7 for the critic model in step 3, please use a critic model trained through step 2 of DeepSpeed-Chat training.
Thanks, Lev
Hi @shenzhuo,
Closing the issue for now since a solution was provided. If any issues persist, feel free to open another issue.
Describe the bug: When using hybrid_engine + bloomz with ZeRO stage 2, an error is reported; it seems to indicate that bloomz does not support hybrid_engine.
Log output
To Reproduce
The run.sh is:
The run_bloom_1b7.sh is:
Expected behavior: DS_BloomContainer has the attribute 'set_params_wo_copy' and can use the hybrid engine to train.
ds_report output
Screenshots: no. The error is in the Log output.
System info (please complete the following information):
Docker context: no
Additional context: no
@cmikeh2 @jeffra @lekurile @awan-10