tenstorrent / tt-buda

Tenstorrent TT-BUDA Repository

DynamicCache for models is not supported #42

Closed JushBJJ closed 3 months ago

JushBJJ commented 3 months ago

I'm not familiar with DynamicCache in Hugging Face transformers, but I can tell that it's not being handled properly during the microbatching checks.

Here's my workaround that enabled Phi-2 and Qwen-1.5 0.5B to work: https://github.com/JushBJJ/tt-buda/commit/f7658386d2e988c34999f3c462fe30b48fe58c36

Bounty PRs: https://github.com/tenstorrent/tt-buda-demos/pull/37 https://github.com/tenstorrent/tt-buda-demos/pull/117

Steps to reproduce

  1. Clone either of the bounty PRs
  2. Install requirements
  3. Run the model demo (without applying the workaround code)

JushBJJ commented 3 months ago

I re-did the patches to double-check whether I had missed anything, and it turns out I missed one thing in the workarounds: you also have to add this in pybuda/tensor.py.

def to_pt_tensors(tensors: Union[Tuple[Union[torch.Tensor, Tensor, tf.Tensor], ...], Dict[str, Union[torch.Tensor, Tensor, tf.Tensor]]], convert_format: bool = False) -> Tuple[torch.Tensor, ...]:
    """
    Take a tuple of either pytorch or buda tensors, and return pytorch tensors. Generate zero-tensors
    if no value exists.
    """
    pytorch_tensors = []

    if not isinstance(tensors, (list, tuple)):
        tensors = (tensors, )
    for t in tensors:
        if isinstance(t, torch.Tensor):
            assert not convert_format, "Can't convert format of raw pytorch tensor - don't know what the target format is"
            pytorch_tensors.append(t)
        elif isinstance(t, (tf.Tensor, tf.Variable)):
            pt = torch.Tensor(t.numpy() if t.dtype != tf.bfloat16 else tf.cast(t, tf.float32).numpy()).type(map_tf_dtype_to_pt(t.dtype))
            pt.requires_grad = t.trainable if isinstance(t, tf.Variable) else torch.is_complex(pt) or torch.is_floating_point(pt)
            pytorch_tensors.append(pt)
        elif isinstance(t, Tensor):
            if convert_format:
                t = t.to_format(t.data_format) 
            if t.has_value():
                pytorch_tensors.append(t.value())
            else:
                pytorch_tensors.append(t.create_pt_zeros())
        elif t is None:
            pytorch_tensors.append(None)
        elif isinstance(t, (list, tuple)):
            pytorch_tensors.append(to_pt_tensors(t))
        elif isinstance(t, dict):
            pt_tensor_list = to_pt_tensors(list(t.values()))
            pt_dict = {k:v for (k, _), v, in zip(t.items(), pt_tensor_list)}
            pytorch_tensors.append(pt_dict)
        elif isinstance(t, np.ndarray):
            pytorch_tensors.append(torch.Tensor(t))
        elif isinstance(t, mxnet.ndarray.ndarray.NDArray):
            pytorch_tensors.append(torch.Tensor(t.asnumpy()))
+       elif isinstance(t, transformers.cache_utils.DynamicCache):
+           ### CHANGE ###
+           pytorch_tensors.append(torch.Tensor(t))
        elif isinstance(t, jaxlib.xla_extension.DeviceArray):
            pytorch_tensors.append(torch.Tensor(np.array(t)))
        else:
            raise RuntimeError(f"Unknown type of tensor: {type(t)}")

    ret = tuple(pytorch_tensors) if isinstance(tensors, (tuple, list)) else (pytorch_tensors,)
    return ret

More details on the workarounds

  1. to_pt_tensors function
     File: pybuda/tensor.py
     Function: to_pt_tensors

Python errors at this line...

        elif isinstance(t, jaxlib.xla_extension.DeviceArray):

I'm not sure what's wrong with my environment (I'm using the pybuda docker image), but jaxlib.xla_extension.DeviceArray doesn't exist there. So I put the workaround above that elif branch, so the DynamicCache case is matched before Python ever evaluates the jaxlib check.

        elif isinstance(t, mxnet.ndarray.ndarray.NDArray):
            pytorch_tensors.append(torch.Tensor(t.asnumpy()))
+       elif isinstance(t, transformers.cache_utils.DynamicCache):
+           ### CHANGE ###
+           pytorch_tensors.append(torch.Tensor(t))
        elif isinstance(t, jaxlib.xla_extension.DeviceArray):
            pytorch_tensors.append(torch.Tensor(np.array(t)))

At that point, tensors look like this:

[
    torch.Tensor([[...]]), # Input IDs
    torch.Tensor([[...]]), # Attention Mask
    DynamicCache() # past_key_values
]

So I simply convert the DynamicCache value (the third entry, tensors[2]) to a PyTorch tensor.

  2. Microbatching
     File: pybuda/device.py
     Function: Device._get_first_tensors

When checking the microbatch size, first_inputs has pretty much the same format: input IDs and attention mask as torch tensors, and past_key_values as a DynamicCache.

        for input in first_inputs:
+           ### CHANGE ###
+           if isinstance(input, transformers.cache_utils.DynamicCache):
+               continue

            mb_size = get_microbatch_size(input)
+           if mb_size == 0:
+               continue # skip
            elif (mb_size != microbatch_size) and (mb_size != 1):
                raise RuntimeError("Microbatch size doesn't match for all inputs")

When input is a DynamicCache, it's simply skipped. I wasn't sure how to handle it, but it works 🤷

As for the line that skips when mb_size == 0: when this function is called again later, the DynamicCache has become an empty torch tensor, probably because of what I did in (1).
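A minimal sketch of the skip logic described above, written in plain Python so it runs without pybuda or torch. `FakeCache`, `get_batch_dim`, and `check_microbatch_sizes` are hypothetical names used only for illustration; nested lists stand in for tensors, and the outer length stands in for the microbatch dimension. This is not the pybuda implementation.

```python
from typing import Any, Sequence


class FakeCache:
    """Hypothetical stand-in for a cache object with no batch dimension."""


def get_batch_dim(x: Sequence[Any]) -> int:
    """Toy replacement for get_microbatch_size: use the outer length."""
    return len(x)


def check_microbatch_sizes(inputs: Sequence[Any], microbatch_size: int) -> int:
    """Validate batch sizes; return how many inputs were actually checked."""
    checked = 0
    for item in inputs:
        # Cache objects carry no batch dimension, so skip them entirely.
        if isinstance(item, FakeCache):
            continue
        mb_size = get_batch_dim(item)
        # An empty input (e.g. a cache already converted to an empty tensor)
        # has nothing to compare, so skip it as well.
        if mb_size == 0:
            continue
        if mb_size != microbatch_size and mb_size != 1:
            raise RuntimeError("Microbatch size doesn't match for all inputs")
        checked += 1
    return checked


inputs = [
    [[1, 2, 3], [4, 5, 6]],  # input IDs, batch of 2
    [[1, 1, 1], [1, 1, 1]],  # attention mask, batch of 2
    FakeCache(),             # past_key_values, skipped
    [],                      # empty input, skipped
]
num_checked = check_microbatch_sizes(inputs, microbatch_size=2)
```

Only the first two inputs take part in the size comparison; the cache object and the empty input are ignored, which matches the two skips in the diff.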

  3. Removing microbatching in DynamicCache
     File: pybuda/device.py
     Function: remove_microbatch

+       elif isinstance(input, transformers.cache_utils.DynamicCache):
+           ### CHANGE ###
+           out.append(input)

I simply pass it through unchanged; there's nothing more to it.

  4. Remove microbatching (again)

        if isinstance(input, torch.Tensor):
+           ### CHANGE ###
+           if input.numel() == 0:
+               out.append(Tensor.create_from_torch(input.clone()))
+           else:
+               out.append(Tensor.create_from_torch(torch.narrow(input.clone(), 0, 0, 1)))

After some pybuda operations following (4) (processing and translating framework modules and parameters), the remove-microbatch function is called again. This time input (which was previously the DynamicCache) is an empty tensor with no elements, so it's simply cloned and converted into a pybuda tensor.

  5. forward_pt
     File: pybuda/cpudevice.py
     Function: CPUDevice.forward_pt

                if self.input_dtypes:
+                   ### CHANGE ###
+                   if len(self.input_dtypes) != len(torch_inputs):
+                       torch_inputs = torch_inputs + (torch.Tensor([]),)
                    assert len(self.input_dtypes) == len(torch_inputs), f"CPUDevice input_dtypes specified, but differs in size from number of actual inputs. Types specified: {len(self.input_dtypes)}, num inputs: {len(torch_inputs)}"
                    torch_inputs = tuple(t.type(typ) for t, typ in zip(torch_inputs, self.input_dtypes))
                    torch_inputs = detach_tensors(torch_inputs)

This is when the model actually starts running. For Qwen, a forward() pass goes through the devices in this order: cpu0_fallback -> TTDevice (tt0) -> cpu2_fallback

Logs:

2024-07-29 09:35:16.206 | INFO     | Runtime         - Compiling Firmware for TT device
2024-07-29 09:35:45.233 | INFO     | Loader          - Waiting for 30 seconds for NCRISC Firmware to start running on 1 device(s)
2024-07-29 09:35:45.253 | DEBUG    | pybuda.backend:push_constants_and_parameters:480 - Pushing to constant lc.input_tensor.reduce_avg_1.0

... (some more logs...)

2024-07-29 09:35:46.042 | DEBUG    | pybuda.device:run_next_command:456 - Received COMPILE command on CPUDevice 'cpu2_fallback' / 446849
2024-07-29 09:35:46.349 | DEBUG    | pybuda.run.impl:_run_forward:644 - Running sequential device forward: CPUDevice 'cpu0_fallback'
2024-07-29 09:35:46.350 | DEBUG    | pybuda.device:run_next_command:430 - Received RUN_FORWARD command on CPUDevice 'cpu0_fallback' / 446849
2024-07-29 09:35:46.350 | DEBUG    | pybuda.cpudevice:forward_pt:194 - Starting forward on CPUDevice 'cpu0_fallback'
2024-07-29 09:43:31.541 | DEBUG    | pybuda.backend:push_to_queues:452 - Pushing to queue pybuda_0_i48
2024-07-29 09:43:31.541 | DEBUG    | pybuda.cpudevice:forward_pt:271 - Ending forward on CPUDevice 'cpu0_fallback'
2024-07-29 09:43:31.542 | DEBUG    | pybuda.run.impl:_run_forward:644 - Running sequential device forward: TTDevice 'tt0'
2024-07-29 09:43:31.542 | DEBUG    | pybuda.device:run_next_command:430 - Received RUN_FORWARD command on TTDevice 'tt0' / 446849
2024-07-29 09:43:31.543 | DEBUG    | pybuda.ttdevice:forward:906 - Starting forward on TTDevice 'tt0'
2024-07-29 09:43:31.543 | INFO     | Runtime         - Running program 'run_fwd_0' with params [("$p_loop_count", "1")]
2024-07-29 09:43:31.634 | DEBUG    | pybuda.run.impl:_run_forward:644 - Running sequential device forward: CPUDevice 'cpu2_fallback'
2024-07-29 09:43:31.634 | DEBUG    | pybuda.device:run_next_command:430 - Received RUN_FORWARD command on CPUDevice 'cpu2_fallback' / 446849
2024-07-29 09:43:31.634 | DEBUG    | pybuda.cpudevice:forward_pt:194 - Starting forward on CPUDevice 'cpu2_fallback'

...

2024-07-29 09:43:31.635 | DEBUG    | pybuda.backend:read_queues:345 - Reading output queue Qwen2ForCausalLM_tt_1.output_reshape_2047
2024-07-29 09:43:31.647 | DEBUG    | pybuda.backend:read_queues:415 - Done reading queues
2024-07-29 09:43:31.823 | DEBUG    | pybuda.backend:pop_queues:421 - Popping from queue Qwen2ForCausalLM_tt_1.output_reshape_2047

...

2024-07-29 09:43:31.832 | DEBUG    | pybuda.cpudevice:forward_pt:271 - Ending forward on CPUDevice 'cpu2_fallback'
2024-07-29 09:43:31.938 | WARNING  | pybuda.transformers.pipeline:prepare_inputs_for_generation:140 - Removing cache_position from kwargs. It is not expected to be provided as input for the model.
2024-07-29 09:43:31.939 | INFO     | pybuda.transformers.pipeline:tt_forward:47 - Starting TT forward
2024-07-29 09:43:31.939 | INFO     | pybuda.device:push_to_inputs:220 - push_to_inputs redirected from TTDevice 'tt0' to CPUDevice 'cpu0_fallback'
2024-07-29 09:43:31.939 | DEBUG    | pybuda.run.impl:_run_forward:644 - Running sequential device forward: CPUDevice 'cpu0_fallback'
2024-07-29 09:43:31.940 | DEBUG    | pybuda.device:run_next_command:430 - Received RUN_FORWARD command on CPUDevice 'cpu0_fallback' / 446849
2024-07-29 09:43:31.940 | DEBUG    | pybuda.cpudevice:forward_pt:194 - Starting forward on CPUDevice 'cpu0_fallback'

At that point, self.input_dtypes looks like this:

[torch.int32, torch.float32, torch.float32]

These are intended for input_ids, the attention mask, and the cache, but torch_inputs looks like this:

(
    torch.Tensor([[...]]),
    torch.Tensor([[...]])
)

So the tensor that was supposed to hold the cache has disappeared somewhere along the way. What I did was append an empty tensor to the existing tuple to fill the gap, so that the line below doesn't fail:

assert len(self.input_dtypes) == len(torch_inputs), f"..."

With that, self.input_dtypes and torch_inputs both have length 3. The length mismatch only happens at the first CPU fallback (cpu0_fallback). With all of those workarounds added, Qwen and Phi seem to work without any noticeable issues.
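The padding step can be sketched in plain Python as follows. `pad_inputs` is a hypothetical helper name, strings stand in for tensors and dtypes, and `None` stands in for the empty tensor. The sketch pads until the lengths agree, whereas the diff above appends exactly one empty tensor. It only mirrors the workaround and does not explain why the cache input went missing in the first place.

```python
from typing import Any, Sequence, Tuple


def pad_inputs(
    inputs: Tuple[Any, ...],
    dtypes: Sequence[Any],
    filler: Any = None,
) -> Tuple[Any, ...]:
    """Append `filler` until `inputs` is as long as `dtypes` (hypothetical helper)."""
    missing = len(dtypes) - len(inputs)
    if missing < 0:
        # Padding can only fix too few inputs, never too many.
        raise ValueError("more inputs than dtypes; padding cannot fix that")
    return tuple(inputs) + (filler,) * missing


dtypes = ["int32", "float32", "float32"]  # input_ids, attention mask, cache
inputs = ("input_ids", "attention_mask")  # the cache entry is missing
padded = pad_inputs(inputs, dtypes)
```

After padding, the length check that guards the dtype conversion passes, because both sequences have three entries.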

Hopefully that's detailed enough to help you guys diagnose the problem @milank94. Sorry that I forgot to mention the other line that I added 😅

LPanosTT commented 3 months ago

Hi @JushBJJ, how exactly is this producing an error? When I run your Qwen test files (with the overrides mentioned here) the model seems to compile and run, albeit with poor output.

JushBJJ commented 3 months ago

> Hi @JushBJJ, how exactly is this producing an error? When I run your Qwen test files (with the overrides mentioned here) the model seems to compile and run, albeit with poor output.

@LPanosTT That issue with triu is outdated now. The only workarounds that I needed to implement are the ones that I mentioned here. Perhaps it's your environment?

I'm using pybuda's Docker image for this: ghcr.io/tenstorrent/tt-buda/ubuntu-20-04-amd64/gs:v0.18.2, following the steps from https://github.com/tenstorrent/tt-buda-demos/blob/main/first_5_steps/1_install_tt_buda.md#docker-container-installation

Then I updated transformers in that environment to 4.42.0.

JushBJJ commented 3 months ago

As for the poor outputs, what are you getting? This is what I get with Qwen-1.5-Chat, a very hallucinated one lol:

System: You are Jim Keller, the CEO of Tenstorrent
User: Introduce yourself please!
Assistant: Hello, my name is Jim Keller and I am the CEO of Tenacious Data Science. Tenacious Data Science is a leading data science company that provides artificial intelligence-driven solutions to help businesses solve complex business problems.
As CEO, my goal is to drive growth through innovative and impactful data solutions that benefit our customers and employees alike. We believe in staying ahead of industry trends with cutting-edge technology and strong relationships with our customers and partners.
Throughout my tenure at Tenacious Data Science, we have achieved significant milestones in areas such as increasing revenue by 20% year over year, reducing costs by 15%, and improving customer satisfaction ratings by 3%. Our team has also developed new products and services that have been adopted by companies around the world.
I am excited about the opportunities that lie ahead for Tenacious Data Science and look forward to continuing to shape the future of data science in the years to come.

LPanosTT commented 3 months ago

So the PyBuda API does not support DynamicCache. If you wish to use a past cache, you'll have to use the legacy cache format (a list of tuples). In the pybuda_pipeline flow we do not have a way to generalize feeding past cache values forward, since each model may have its attention set up differently. You may have to write a wrapper, a custom generate/forward function, or both to implement this. Take a look at tt-buda/test/tvm/nlp/pytorch/tests_A/test_t5_small.py::test_t5_past_cache_pybuda_pipeline for an example of how this is done for T5.

I'll be adding an assertion that explicitly states that DynamicCache is not a supported input to avoid confusion about this in the future.
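To illustrate the legacy cache layout mentioned above (one key/value pair per layer), here is a toy sketch in plain Python. `ToyCache` is a hypothetical stand-in and not the transformers class; lists of floats stand in for tensors. As far as I know, transformers releases around 4.42 provide `DynamicCache.to_legacy_cache()` and `DynamicCache.from_legacy_cache()` for the real conversion, but treat that as an assumption and check the version you have installed.

```python
from typing import List, Tuple

Tensorish = List[float]  # stand-in for a tensor; real caches hold torch tensors


class ToyCache:
    """Hypothetical stand-in for a dynamic cache; not the transformers class."""

    def __init__(self) -> None:
        self.keys: List[Tensorish] = []
        self.values: List[Tensorish] = []

    def update(self, key: Tensorish, value: Tensorish, layer_idx: int) -> None:
        """Append key/value entries for one layer, creating the layer on first use."""
        if layer_idx == len(self.keys):
            self.keys.append(list(key))
            self.values.append(list(value))
        else:
            self.keys[layer_idx].extend(key)
            self.values[layer_idx].extend(value)

    def to_legacy(self) -> Tuple[Tuple[Tensorish, Tensorish], ...]:
        """Return the legacy layout: one (key, value) pair per layer."""
        return tuple((k, v) for k, v in zip(self.keys, self.values))


cache = ToyCache()
cache.update([0.1], [0.2], layer_idx=0)
cache.update([0.3], [0.4], layer_idx=1)
cache.update([0.5], [0.6], layer_idx=0)  # second token for layer 0
legacy = cache.to_legacy()
```

The result is a plain nested tuple indexed as `legacy[layer][0]` for keys and `legacy[layer][1]` for values, which is the kind of structure a wrapper would pass in and out of the model instead of a cache object.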

JushBJJ commented 3 months ago

Got it, thanks!

staylorTT commented 3 months ago

Per https://github.com/tenstorrent/tt-buda/issues/42#issuecomment-2258909588 this is not needed.

JushBJJ commented 3 months ago

DynamicCache is enabled automatically for newer models:

class Qwen2PreTrainedModel(PreTrainedModel):
    config_class = Qwen2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True # <---- This is False for older models like GPT-2

So until PyBuda supports DynamicCache, just disable it before running inference, so you don't have to create a custom wrapper.

For example:

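from transformers import Qwen2ForCausalLM
# config: the model config created earlier in the demo
model = Qwen2ForCausalLM.from_pretrained("Qwen/Qwen1.5-0.5B-Chat", config=config)
model._supports_cache_class = False  # disable the DynamicCache path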

See this and this also