tenstorrent / tt-forge-fe

The TT-Forge FE is a graph compiler designed to optimize and transform computational graphs for deep learning models, enhancing their performance and efficiency.
https://docs.tenstorrent.com/tt-forge-fe/
Apache License 2.0

[Bug] RuntimeError: Tensor 1 - data type mismatch: expected Float32, got BFloat16 #369

Open JushBJJ opened 1 week ago

JushBJJ commented 1 week ago

System Information
OS: Ubuntu 22.04 (via Docker)
Arch: Grayskull e75

Commit: https://github.com/tenstorrent/tt-forge-fe/commit/685c8954bc0bd93e00bf84fe68a2cd65063a67c0 (latest)

Problem

When running a simple PyTorch module, such as this one:

import forge
import torch

from torch import nn

class foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(1024, 2048, dtype=torch.bfloat16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.linear(x, self.weights)

def test_bug():
    model = foo()
    inputs = torch.rand((1, 12, 2048), dtype=torch.bfloat16)
    golden_output = model(inputs) # passes
    forge_model = forge.compile(model, sample_inputs=inputs)

    out = forge_model(inputs)
    print(out)

test_bug()

The model compiles, but when TT-Forge runs the compiled model binary, this error occurs:

2024-10-06 15:23:32.996 | FATAL    | TTDevice        - Tensor 1 - data type mismatch: expected Float32, got BFloat16

Full log here: https://pastebin.com/raw/CNzjP0jU

I also get this debug log, which might hint at the cause:

2024-10-06 15:57:48.637 | DEBUG    | forge.tvm_to_python:_determine_node_dtype:1718 - Node 'weights' does not have a framework dtype specified. Using TVM generated dtype.

That makes sense, because only the weights default to float32 instead of bfloat16.

If you change the populate_transpose_args() function in tt-forge/forge/forge/tvm_to_python.py as follows:

def populate_transpose_args(graph, nid, compiler_cfg):
    node = graph["nodes"][nid]
    axes = [int(axis) for axis in node["attrs"]["axes"][0]]
    transpose_shape = list(graph["nodes"][nid]["forge_shape"])

    assert int(node["attrs"]["num_inputs"]) == 1

    for i, axis in enumerate(axes):
        if axis < 0:
            axes[i] += len(transpose_shape)

    node["attrs"]["axes"] = axes

    transpose_axes = []
    for idx, axis in enumerate(axes):
        if axis != idx:
            transpose_axes.insert(0, axis)

    # Tmp. Needs to be removed after full Jax Bert support
    if len(transpose_axes) == 0:
        transpose_axes = axes

    assert (
        len(transpose_axes) == 2
    ), "only single axis transpose supported at this time, decompose in tvm"

    transpose_axes = [axis - len(transpose_shape) for axis in transpose_axes]

    args = []
    args.append(("dim0", f"{transpose_axes[0]}"))
    args.append(("dim1", f"{transpose_axes[1]}"))

    # If transpose unpadded Z dim, record the original shape
    if (transpose_axes[0] == -3 and transpose_axes[1] != -4) or (transpose_axes[0] == -4 and transpose_axes[1] != -3):
        args.append(("z_dim_slice", f"{transpose_shape[transpose_axes[0]]}"))
    elif (transpose_axes[1] == -3 and transpose_axes[0] != -4) or (transpose_axes[1] == -4 and transpose_axes[0] != -3):
        args.append(("z_dim_slice", f"{transpose_shape[transpose_axes[1]]}"))

-   args.append(("out_dtype", "torch." + node['attrs']['dtype'][0][0]))
+   args.append(("out_dtype", "torch." + "bfloat16"))

    return args

then the error goes away.
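For readers tracing the patch: before it reaches out_dtype, populate_transpose_args() reduces the permutation to the one pair of swapped axes and expresses them as negative indices. A minimal standalone sketch of just that step follows; transpose_dims is a hypothetical name, and the temporary JAX fallback for identity permutations is deliberately left out.

```python
def transpose_dims(axes: list[int], rank: int) -> tuple[int, int]:
    """Return (dim0, dim1) as negative indices for a permutation that swaps two axes."""
    # Normalize negative axes to their positive equivalents.
    axes = [axis + rank if axis < 0 else axis for axis in axes]
    # Keep only the axes that moved. For a two-axis swap, sorting gives the
    # same order as the insert(0, ...) loop in populate_transpose_args().
    moved = sorted(axis for idx, axis in enumerate(axes) if axis != idx)
    if len(moved) != 2:
        raise ValueError("only single axis transpose supported at this time, decompose in tvm")
    # Express both dims as negative indices counted from the last dimension.
    dim0, dim1 = (axis - rank for axis in moved)
    return dim0, dim1
```

For example, transpose_dims([0, 2, 1], 3) gives (-2, -1). Note that the patch above forces every transpose to bfloat16, so it is a workaround for this repro rather than a general fix.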

But then this error occurs:

Always | FATAL    | Unable to reshape a tensor in TILE_LAYOUT to non-tile height and width! Please convert the tensor to ROW_MAJOR_LAYOUT first.

Any idea why this is happening?
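The second error's message states the rule: a tensor in TILE_LAYOUT can only be reshaped to a height and width made of whole tiles. A sketch of that condition, assuming TT-Metal's 32x32 tile size; is_tile_aligned is a hypothetical helper, not a TT-Metal API:

```python
TILE_HEIGHT = 32  # assumption: TT-Metal tile height
TILE_WIDTH = 32   # assumption: TT-Metal tile width


def is_tile_aligned(shape: tuple[int, ...]) -> bool:
    """True if the last two dims (height, width) are whole multiples of the tile size."""
    if len(shape) < 2:
        raise ValueError("need at least two dimensions")
    height, width = shape[-2], shape[-1]
    return height % TILE_HEIGHT == 0 and width % TILE_WIDTH == 0
```

For example, the repro's input shape (1, 12, 2048) is not tile-aligned, since 12 is not a multiple of 32.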

JushBJJ commented 1 week ago

This could be related to https://github.com/tenstorrent/tt-forge-fe/issues/309, since you guys are getting very similar errors there.

dgolubovicTT commented 1 week ago

This last error is a filed bug on Metal, and it is P0: https://github.com/tenstorrent/tt-metal/issues/13120.

dgolubovicTT commented 1 week ago

For the first one, can you check the following in compile_tvm_to_python? First: does the weight parameter have dtype bfloat16 when you pass framework_mod to load_tvm_graph? Second: does load_tvm_graph return json_graphs in which that node still has framework_dtype defined and equal to bfloat16? Also check that the weights (also returned by load_tvm_graph) match that dtype...

Probably the 'framework_dtype' attribute is not properly populated inside load_tvm_graph, or something like that...
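The second check can be scripted. A minimal sketch, assuming each node in the json_graph is a dict with "name" and "framework_dtype" keys and that a missing dtype shows up as "N/A" (the placeholder reported later in this thread); find_dtype_mismatches is a hypothetical helper:

```python
def find_dtype_mismatches(json_graph: dict, expected: str, names: set[str]) -> dict[str, str]:
    """Map each named node whose framework_dtype differs from `expected` to that dtype."""
    mismatches = {}
    for node in json_graph["nodes"]:
        name = node.get("name")
        if name not in names:
            continue
        # Assumption: framework_dtype is stored directly on the node dict,
        # and a missing value is treated the same as "N/A".
        dtype = node.get("framework_dtype", "N/A")
        if dtype != expected:
            mismatches[name] = dtype
    return mismatches
```

An empty result for {"weights"} with expected="bfloat16" would mean the dtype survives load_tvm_graph.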

nvukobratTT commented 1 week ago

Thanks for reporting this issue @JushBJJ! 🙌 @dgolubovicTT already provided a few valid points! :))

In short, regarding the latter issue, we expect a PR to be merged soon on the Metal project:

Regarding the data format (DF) issues you're facing on FFE, here are a few points:

  1. Transpose op shouldn't even have an output DF attribute. I think we should remove it :))
  2. As @dgolubovicTT mentioned, it seems the weights default to float32 instead of the proper data type. My assumption is that dev_data_format isn't set during ForgeModule codegen of add_parameter, so we fall back to float32. Here is part of the generated ForgeModule (output of TVM):
    class Foo(ForgeModule):
        def __init__(self, name):
            super().__init__(name)
            self.add_parameter("weights", forge.Parameter(*(1024, 2048), requires_grad=True, dev_data_format=forge.DataFormat.Float32))

You can see the invalid data type (Float32 instead of bfloat16).
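To illustrate the hypothesis, here is a sketch of how codegen could turn a node's framework dtype into the dev_data_format argument, falling back to Float32 when the dtype is missing. The helper emit_add_parameter and the string mapping are hypothetical, and Float16_b as the DataFormat name for bfloat16 is an assumption to verify against forge.DataFormat:

```python
# Hypothetical mapping from framework dtype strings to DataFormat member names.
# Assumption: bfloat16 maps to Float16_b; check forge.DataFormat before relying on it.
FRAMEWORK_TO_DATA_FORMAT = {
    "float32": "Float32",
    "bfloat16": "Float16_b",
    "float16": "Float16",
}


def emit_add_parameter(name: str, shape: tuple[int, ...], framework_dtype: str,
                       default: str = "float32") -> str:
    """Emit an add_parameter line; unknown or "N/A" dtypes fall back to `default`."""
    dtype = framework_dtype if framework_dtype in FRAMEWORK_TO_DATA_FORMAT else default
    data_format = FRAMEWORK_TO_DATA_FORMAT[dtype]
    # repr(tuple(...)) keeps single-dim shapes valid, e.g. "(5,)".
    return (f'self.add_parameter("{name}", forge.Parameter(*{tuple(shape)!r}, '
            f'requires_grad=True, dev_data_format=forge.DataFormat.{data_format}))')
```

With framework_dtype set to "N/A", it reproduces the Float32 line above, which is consistent with the dtype being lost before codegen.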

P.S. For easier tracking, I'm adding issues for removing data type attributes from the Transpose op:

nvukobratTT commented 1 week ago

@JushBJJ let me know if you would like to take a look into this issue with our help :))

If this isn't on your priority list, I'll try to find someone on our side to take a look (probably next week).

JushBJJ commented 1 week ago

> let me know if you would like to take a look into this issue with our help :)) If this isn't on your priority list, I'll see to find someone from our side to take a look (probably next week).

Don't worry, it isn't a priority at the moment. This is just me playing around with Forge early on. I was originally trying to run a simple feedforward block from Llama 3.2 1B first, before getting into other parts of the model like attention 😄

> For the first one, can you check compile_tvm_to_python: First: Does the weight parameter have dtype bfloat16 when you pass framework_mod to the method load_tvm_graph. Second: That load_tvm_graph returns json_graphs where that node still has framework_dtype defined and equal to bfloat16. Also check that weights (also returned by load_tvm_graph) also matches with dtype...

> Probably attribute 'framework_dtype' is not properly populated inside load_tvm_graph or something like that...

Yes, the parameter does have dtype bfloat16 when passed to load_tvm_graph. This is the json_graph I get back: https://pastebin.com/raw/1Vcer2Wi. The generated module is pretty much what @nvukobratTT showed above.

I noticed that framework_dtype is "N/A" in the json_graph. Is that normal? In that case, _determine_node_dtype(node) would always return float32 instead of bfloat16. It looks like framework_dtype is saved correctly but then gets lost somewhere along the way in TVM.
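The fallback described here, and in the earlier debug log ("Using TVM generated dtype"), can be sketched as follows. This is a hypothetical reconstruction, not the actual _determine_node_dtype; it assumes "N/A" marks a missing framework dtype and that the TVM dtype sits at node["attrs"]["dtype"][0][0], where populate_transpose_args() reads it:

```python
def determine_node_dtype(node: dict) -> str:
    """Prefer the framework dtype; fall back to the TVM-generated dtype."""
    framework_dtype = node.get("framework_dtype", "N/A")
    if framework_dtype != "N/A":
        return framework_dtype
    # Assumption: same location populate_transpose_args() reads the dtype from.
    return node["attrs"]["dtype"][0][0]
```

If TVM generated float32 for the weights, a node carrying "N/A" resolves to float32 even though the PyTorch parameter is bfloat16, which matches the mismatch in the original error.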

nvukobratTT commented 1 week ago

> I noticed that in the json_graph that framework_dtype is "N/A", is that normal? Because in _determine_node_dtype(node) it would always return float32 in that case instead of bfloat16, looks like framework_dtype is correctly saved but then gets lost along the way in TVM

Yeah, this is probably a bug that is worth sorting out.

> Don't worry it isn't priority at the moment, this is just me playing around forge early, I was originally trying to do a simple feedforward in Llama 3.2 1B first before getting into other parts of the model like attention 😄

Thanks for the details! We're currently pushing a few other priorities, but I'll try to find someone to take a look in the next few weeks :))