pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.

fixing GPTQ #148

Open HDCharles opened 3 months ago

HDCharles commented 3 months ago

Stack from ghstack (oldest at bottom):

Summary:

trying to fix the issue with the kv_cache update by changing tracing to a tensor subclass. However, it seems we have less success than with the fx tracer. The fx tracer breaks because

`k_out[:, :, input_pos] = k_val`

getting traced as

`new_var = torch.ops.aten.index_put(k_out, [None, None, input_pos], k_val)`

with `new_var` never being accessed afterward. `new_var` becomes the correct MultiInput value, but it is then lost.
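
For illustration, a minimal sketch (not the repo's code; the shapes are made up) of why the functional form has to be re-bound: `aten.index_put` returns a fresh tensor instead of mutating `k_out`, so dropping its result drops the cache update.

```python
import torch

# Illustrative shapes only: (batch, heads, seq_len, head_dim)
k_out = torch.zeros(1, 2, 8, 4)
k_val = torch.randn(1, 2, 1, 4)
input_pos = torch.tensor([3])

# Eager in-place update, as written in the model code:
k_out[:, :, input_pos] = k_val

# Functional form a tracer may emit: it leaves its first argument untouched
# and returns a new tensor, so the result must replace k_out downstream.
new_var = torch.ops.aten.index_put(
    torch.zeros(1, 2, 8, 4), [None, None, input_pos], k_val
)
assert torch.equal(new_var, k_out)  # same values, but only useful if new_var is actually consumed
```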

The subclass, on the other hand, tries to use the func `<slot wrapper '__setitem__' of 'torch._C.TensorBase' objects>`, which seems to not want to mutate `k_out`, and so the attempt to make it a MultiTensor fails.
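
For context, a minimal `__torch_function__` subclass (an assumed stand-in, not the PR's MultiTensor) shows what the subclass sees: the `__setitem__` slot wrapper mutates the target in place and returns nothing, so there is no output to wrap or re-bind.

```python
import torch

class LoggingTensor(torch.Tensor):
    """Toy subclass that just logs which torch function it intercepts."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # For slice assignment this prints something like:
        # intercepted: <slot wrapper '__setitem__' of 'torch._C.TensorBase' objects>
        print(f"intercepted: {func}")
        return super().__torch_function__(func, types, args, kwargs)

k_out = torch.zeros(1, 2, 8, 4).as_subclass(LoggingTensor)
k_val = torch.randn(1, 2, 1, 4)
input_pos = torch.tensor([3])

# Dispatches through __torch_function__ with func == Tensor.__setitem__;
# the call mutates k_out in place and returns None, so a subclass has no
# return value it could swap out for a wrapped multi-input tensor.
k_out[:, :, input_pos] = k_val
```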

Test Plan: `sh run.sh`

Reviewers:

Subscribers:

Tasks:

Tags: