JushBJJ opened this issue 1 week ago
Could be related to https://github.com/tenstorrent/tt-forge-fe/issues/309, since you are getting very similar errors.
The last error is a filed bug on Metal, and it is P0: https://github.com/tenstorrent/tt-metal/issues/13120.
For the first one, can you check `compile_tvm_to_python`:

First: Does the weight parameter have dtype bfloat16 when you pass `framework_mod` to the method `load_tvm_graph`?

Second: Does `load_tvm_graph` return `json_graphs` where that node still has `framework_dtype` defined and equal to bfloat16? Also check that `weights` (also returned by `load_tvm_graph`) matches that dtype.

Probably the attribute `framework_dtype` is not properly populated inside `load_tvm_graph`, or something like that.
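The second check can be sketched with a stdlib-only mock. Note the structure of the `json_graph` below is an assumption for illustration, not the exact output of `load_tvm_graph`:

```python
import json

# Stdlib-only sketch of the check described above. The shape of this
# json_graph is an assumption, not the exact structure that
# load_tvm_graph returns.
json_graph = json.loads("""
{
  "nodes": [
    {"name": "weights",   "attrs": {"framework_dtype": "bfloat16"}},
    {"name": "transpose", "attrs": {"framework_dtype": "N/A"}}
  ]
}
""")

# Flag every node whose framework_dtype was lost on the way through TVM.
missing = [n["name"] for n in json_graph["nodes"]
           if n["attrs"].get("framework_dtype", "N/A") == "N/A"]
print(missing)  # ['transpose']
```

If any node shows up in `missing`, the dtype information was dropped before codegen.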
Thanks for reporting this issue @JushBJJ! 🙌 @dgolubovicTT already provided a few valid points! :))
In sum, regarding the latter issue, we expect a PR to be merged soon on the Metal project:
Regarding DF issues you're facing on FFE, here are a few points:
`dev_data_format` isn't set during ForgeModule codegen of `add_parameter`. As a consequence, we're defaulting to float32. Here is part of the generated ForgeModule (output of TVM):
```python
class Foo(ForgeModule):
    def __init__(self, name):
        super().__init__(name)
        self.add_parameter("weights", forge.Parameter(*(1024, 2048), requires_grad=True, dev_data_format=forge.DataFormat.Float32))
```
You can see the invalid data format (Float32 instead of bfloat16).
P.S. For easier tracking, I'm adding issues for removing data type attributes from the Transpose op:
@JushBJJ let me know if you would like to take a look into this issue with our help :))
If this isn't on your priority list, I'll try to find someone from our side to take a look (probably next week).
> let me know if you would like to take a look into this issue with our help :)) If this isn't on your priority list, I'll try to find someone from our side to take a look (probably next week).
Don't worry, it isn't a priority at the moment; this is just me playing around with Forge early. I was originally trying to do a simple feedforward in Llama 3.2 1B before getting into other parts of the model like attention :smile:
> For the first one, can you check `compile_tvm_to_python`: First: Does the weight parameter have dtype bfloat16 when you pass `framework_mod` to the method `load_tvm_graph`? Second: Does `load_tvm_graph` return `json_graphs` where that node still has `framework_dtype` defined and equal to bfloat16? Also check that `weights` (also returned by `load_tvm_graph`) matches that dtype. Probably the attribute `framework_dtype` is not properly populated inside `load_tvm_graph`, or something like that.
Yes, the parameter does have bfloat16 dtype when passed to `load_tvm_graph`. This is the json_graph I get when returned: https://pastebin.com/raw/1Vcer2Wi. The generated module is pretty much what @nvukobratTT said above.
I noticed that in the json_graph, `framework_dtype` is "N/A". Is that normal? Because in that case `_determine_node_dtype(node)` would always return `float32` instead of bfloat16. It looks like `framework_dtype` is correctly saved but then gets lost along the way in TVM.
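The fallback described above can be sketched roughly as follows. The function name and node structure here are assumptions based on the description, not the actual forge source:

```python
# Rough sketch of the silent float32 fallback described above
# (hypothetical names, not the actual forge implementation).
def determine_node_dtype(node):
    dtype = node.get("framework_dtype", "N/A")
    if dtype == "N/A":
        # Silent default: any bfloat16 information is lost here.
        return "float32"
    return dtype

print(determine_node_dtype({"framework_dtype": "bfloat16"}))  # bfloat16
print(determine_node_dtype({"framework_dtype": "N/A"}))       # float32
```

So once the attribute degrades to "N/A" in TVM, the float32 default is unavoidable downstream.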
> I noticed that in the json_graph, `framework_dtype` is "N/A". Is that normal? Because in that case `_determine_node_dtype(node)` would always return `float32` instead of bfloat16. It looks like `framework_dtype` is correctly saved but then gets lost along the way in TVM.
Yeah, this is probably a bug that is worth sorting out.
> Don't worry, it isn't a priority at the moment; this is just me playing around with Forge early. I was originally trying to do a simple feedforward in Llama 3.2 1B before getting into other parts of the model like attention 😄
Thanks for the details! We're currently pushing a few other priorities, but I'll try to find someone to take a look in the next few weeks :))
System Information
OS: Ubuntu 22.04 (via Docker)
Arch: Grayskull e75
Commit: https://github.com/tenstorrent/tt-forge-fe/commit/685c8954bc0bd93e00bf84fe68a2cd65063a67c0 (latest)
Problem
When running a simple PyTorch module like this, for example:
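The original module snippet isn't preserved in this copy of the issue; a minimal hypothetical stand-in of the kind described (bfloat16 weights, matching the (1024, 2048) parameter shape seen in the generated ForgeModule) might look like:

```python
import torch

# Hypothetical stand-in for the module from the report (the original
# snippet isn't preserved here): a single bfloat16 linear layer whose
# weight has shape (1024, 2048).
class FeedForward(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(2048, 1024, bias=False, dtype=torch.bfloat16)

    def forward(self, x):
        return self.proj(x)

model = FeedForward()
print(model.proj.weight.dtype)  # torch.bfloat16 going into compilation
```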
The model compiles, but when TT-Forge runs the compiled model binary, this error occurs:
Full log here: https://pastebin.com/raw/CNzjP0jU
I also get this debug log, which might hint at something:
which makes sense, because it's only the weights that are defaulted to float32 instead of bfloat16.
If you change the `populate_transpose_args()` function in `tt-forge/forge/forge/tvm_to_python.py`:

then the error goes away.
But then this error occurs:
Any idea why this is happening?