rmihaylov / mpttune

Tune MPTs
Apache License 2.0

Can this run on a multi-GPU setup? #9

Closed: UmarJawad closed this issue 1 year ago

UmarJawad commented 1 year ago

Hi, I am using an EC2 'g5.12xlarge' instance with 4 NVIDIA A10G GPUs (4 × 24 GB). I was able to successfully fine-tune and generate on a single GPU, but on the multi-GPU machine I get this error:

Traceback (most recent call last):
  File "/home/ec2-user/venv/bin/mpttune", line 33, in <module>
    sys.exit(load_entry_point('mpttune==0.1.0', 'console_scripts', 'mpttune')())
  File "/home/ec2-user/venv/lib/python3.10/site-packages/mpttune-0.1.0-py3.10.egg/mpttune/run.py", line 87, in main
    args.func(args)
  File "/home/ec2-user/venv/lib/python3.10/site-packages/mpttune-0.1.0-py3.10.egg/mpttune/generate.py", line 71, in generate
    generated_ids = model.generate(
  File "/home/ec2-user/venv/lib/python3.10/site-packages/mpttune-0.1.0-py3.10.egg/mpttune/generate.py", line 27, in autocast_generate
    return self.model.non_autocast_generate(*args, **kwargs)
  File "/home/ec2-user/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ec2-user/venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 1565, in generate
    return self.sample(
  File "/home/ec2-user/venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 2612, in sample
    outputs = self(
  File "/home/ec2-user/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ec2-user/venv/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/ec2-user/venv/lib/python3.10/site-packages/mpttune-0.1.0-py3.10.egg/mpttune/model/mpt/model.py", line 864, in forward
    outputs = self.transformer(
  File "/home/ec2-user/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ec2-user/venv/lib/python3.10/site-packages/mpttune-0.1.0-py3.10.egg/mpttune/model/mpt/model.py", line 772, in forward
    layer_outputs = decoder_layer(
  File "/home/ec2-user/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ec2-user/venv/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/ec2-user/venv/lib/python3.10/site-packages/mpttune-0.1.0-py3.10.egg/mpttune/model/mpt/model.py", line 443, in forward
    (b, self_attn_weights, present_key_value) = self.attn(
  File "/home/ec2-user/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ec2-user/venv/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/ec2-user/venv/lib/python3.10/site-packages/mpttune-0.1.0-py3.10.egg/mpttune/model/mpt/model.py", line 373, in forward
    qkv = self.Wqkv(hidden_states)
  File "/home/ec2-user/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ec2-user/venv/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/ec2-user/venv/lib/python3.10/site-packages/mpttune-0.1.0-py3.10.egg/mpttune/backend/triton/quantlinear.py", line 17, in forward
    out = self._forward_no_grad(x)
  File "/home/ec2-user/venv/lib/python3.10/site-packages/mpttune-0.1.0-py3.10.egg/mpttune/backend/triton/quantlinear.py", line 26, in _forward_no_grad
    return tu.triton_matmul(x, self.qweight, self.scales, self.qzeros, self.g_idx, self.bits, self.maxq)
  File "/home/ec2-user/venv/lib/python3.10/site-packages/mpttune-0.1.0-py3.10.egg/mpttune/backend/triton/triton_utils.py", line 246, in triton_matmul
    matmul_248_kernel[grid](input, qweight, output,
  File "/home/ec2-user/venv/lib/python3.10/site-packages/mpttune-0.1.0-py3.10.egg/mpttune/backend/triton/custom_autotune.py", line 110, in run
    return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
  File "<string>", line 23, in matmul_248_kernel
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered

rmihaylov commented 1 year ago

Yes.
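A hedged note for anyone hitting the same traceback: the `accelerate/hooks.py` frames show the model's layers being dispatched by an Accelerate `device_map`, i.e. spread across several GPUs. Assuming the Triton kernel is launched on the *current* CUDA device (the default behaviour of Triton's 2.x launcher), a layer whose weights live on `cuda:1` can be run while `cuda:0` is still current, which is one plausible way to get "an illegal memory access". The sketch below is a workaround under that assumption, not mpttune's own code and not a fix confirmed in this thread: `on_input_device` and `scale` are hypothetical names. It temporarily makes the input tensor's device current with the standard `torch.cuda.device` context manager before calling the wrapped function.

```python
import contextlib
import functools

import torch


def on_input_device(fn):
    """Hypothetical decorator (not part of mpttune).

    Runs `fn` with the CUDA device of its first argument set as the
    current device, so a kernel launched on the "current" device runs
    on the GPU that actually holds the input tensor.
    """

    @functools.wraps(fn)
    def wrapper(x, *args, **kwargs):
        # CUDA tensor: switch the current device to x.device for the call.
        # CPU tensor: nothing to switch, use a no-op context.
        ctx = torch.cuda.device(x.device) if x.is_cuda else contextlib.nullcontext()
        with ctx:
            return fn(x, *args, **kwargs)

    return wrapper


if __name__ == "__main__":
    @on_input_device
    def scale(t, k):
        # Stand-in for a kernel launcher such as a quantized matmul.
        return t * k

    x = torch.arange(6.0).reshape(2, 3)
    print(scale(x, 2))  # CPU path: behaves exactly like the undecorated call

    if torch.cuda.device_count() > 1:
        y = torch.ones(4, device="cuda:1")
        # While scale runs, torch.cuda.current_device() is 1, not 0.
        print(scale(y, 3).device)
```

In this traceback the kernel is launched from `triton_matmul` in `mpttune/backend/triton/triton_utils.py`, so that call (or the `QuantLinear` forward that invokes it) is where such a wrapper would be applied. Switching devices per call is cheap compared with the matmul itself, and the CPU branch keeps the wrapper harmless for non-CUDA inputs.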