Closed fabianlim closed 1 month ago
So, because of the casting, we are now facing this error in #25:
```
  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/fms-acceleration/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py", line 75, in _forward_q
    _fused_op(X)
  File "/workspace/fms-acceleration/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py", line 64, in _fused_op
    Q, K, V = fused_operation(attn, X)
  File "/workspace/fms-acceleration/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py", line 620, in apply_lora_qkv
    Q, K, V = LoRA_QKV.apply(
  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 115, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/workspace/fms-acceleration/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py", line 464, in forward
    QW = dequant248(Q_qweight, Q_scales, Q_qzeros, Q_g_idx, Q_bits)
  File "/workspace/fms-acceleration/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/kernels.py", line 137, in dequant248
    dequant_kernel_248[grid](
  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 143, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 122, in _bench
    return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/testing.py", line 102, in do_bench
    fn()
  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 110, in kernel_call
    self.fn.run(
  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/runtime/jit.py", line 532, in run
    self.cache[device][key] = compile(
  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/compiler/compiler.py", line 543, in compile
    next_module = compile_kernel(module)
  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/compiler/compiler.py", line 435, in <lambda>
    ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1237, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
```
GPTQ-LoRA depends on the AutoGPTQ package, but there are issues that prevent the base GPTQ model from being wrapped by FSDP.

The issue is that the `QuantLinear` class stores its parameters (i.e. `qweight`, `qzeros`) in `torch.int32`. Because of the integer dtype, these cannot be registered as `torch.nn.Parameter` and remain plain `torch.Tensor`s, so FSDP does not shard them.
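As a quick illustration of why the integer dtype is the obstacle: autograd only supports gradients on floating point and complex dtypes, so wrapping an `int32` buffer in an `nn.Parameter` with the default `requires_grad=True` raises immediately. This is a minimal sketch, not code from the plugin:

```python
import torch

# Stand-in for a packed GPTQ buffer such as qweight.
qweight = torch.zeros(4, 2, dtype=torch.int32)

# nn.Parameter defaults to requires_grad=True, and autograd only
# supports floating point and complex dtypes, so this raises.
try:
    torch.nn.Parameter(qweight)
except RuntimeError as err:
    print("cannot wrap int32 buffer:", err)
```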
The fix is then to use `torch.Tensor.view`, which performs a C++-style reinterpret cast, in `QuantLinear.forward` before calling the `QuantLinearFunction` autograd function. We create the `nn.Parameter` in the same vein, by doing `qweight.view(torch_type)` to force the parameter to be of `torch_type` (which will be a float type).
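The reinterpret-cast trick can be sketched as follows. This is a toy illustration rather than the plugin's actual implementation, and `torch_type` is assumed here to be `torch.bfloat16`:

```python
import torch

# Stand-in for a packed GPTQ buffer such as qweight.
packed = torch.arange(8, dtype=torch.int32).reshape(2, 4)

# View the int32 storage as bfloat16: same bytes, different dtype
# (each 4-byte int32 becomes two 2-byte bfloat16 elements).
as_float = packed.view(torch.bfloat16)

# Now the buffer has a float dtype and can live in an nn.Parameter,
# which FSDP is able to shard.
param = torch.nn.Parameter(as_float, requires_grad=False)

# Inside forward, view back to int32 to recover the packed values exactly.
restored = param.data.view(torch.int32)
assert torch.equal(restored, packed)
```

Because `view(dtype)` reinterprets the underlying bytes without copying, the round trip is exact: no quantized bits are lost by storing the buffer under a float dtype.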
## Reproduce

To reproduce the fix, consider the command:
## Losses and Throughputs AFTER FIX
TODO:

- Add `g_idx` and `scales` as parameters so they can be sharded. Update: the code is quite flexible now, and it is easy to add more parameters.
- Handle `low_cpu_mem_usage` properly. The model currently loads the full model into GPU memory unnecessarily before `prepare`, which should be avoided.
- Support other `QuantLinear` implementations like marlin, etc. (may not do).