Closed Shubhamsaboo closed 1 month ago
@Shubhamsaboo I want to have a crack at this but I don't have any Tenstorrent hardware, can I still do it? Alternatively, do you guys provide dev grants so I can test it on there?
Update: I am working with @JonathanALevine to get this done for Qwen since he has a grayskull we can test it on
Hey all. So I do not have a Grayskull, but I've been able to get the model to compile with decent results running with PYBUDA_DEVMODE=1.
PyTorch can use seemingly any tensor as a mask for masked_fill; it just converts 0 to False and all other numbers to True. The Qwen implementation seems to use a mask in this manner at some call to masked_fill.
So the big problem is that Qwen uses the lowest finite value a 32-bit float can represent as its mask filler. TVM decomposes masked_fill on the assumption that the condition mask is a boolean, but for this model that is not the case. This becomes problematic because, as part of the decomposition of masked_fill in TVM, we multiply the mask with the values of the input and also with the replacement value. At the end of the decomposition we end up adding together inf and -inf, which creates a nan. The following softmax then returns all nans.
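To make the failure mode concrete, here is a minimal NumPy sketch of the multiply-and-add decomposition described above. It is not the TVM code itself: the function names and sample values are assumptions chosen for illustration. It only shows how a non-binary mask can push the two terms to opposite infinities, and how a single nan then spreads through softmax.

```python
import numpy as np

F32_MIN = np.finfo(np.float32).min  # lowest finite float32, used as the mask filler


def masked_fill_decomposed(x, mask, value):
    """Arithmetic form of masked_fill: (1 - mask) * x + mask * value.

    Only correct when mask contains exactly 0s and 1s.
    """
    x = np.asarray(x, dtype=np.float32)
    mask = np.asarray(mask, dtype=np.float32)
    value = np.float32(value)
    # Overflow to +/-inf and inf - inf are expected here, so silence the warnings.
    with np.errstate(over="ignore", invalid="ignore"):
        return (np.float32(1) - mask) * x + mask * value


def softmax(row):
    """Plain softmax; any nan in the row propagates to every output."""
    shifted = row - np.max(row)
    exp = np.exp(shifted)
    return exp / exp.sum()


x = np.array([1.0, -2.0, 3.0], dtype=np.float32)

# A proper 0/1 mask behaves as expected.
ok = masked_fill_decomposed(x, [0, 1, 0], -1e4)

# A mask that holds F32_MIN where it should hold 1 overflows both terms.
bad = masked_fill_decomposed(x, [0, F32_MIN, 0], F32_MIN)
probs = softmax(bad)
```

With the 0/1 mask the middle element is simply replaced. With the float-valued mask the middle element becomes nan, and every softmax probability is nan as a result.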
So I've solved this by editing the decomposition of masked_fill so that it can handle the case where the mask is not a boolean/one-hot tensor. This function already resides in third_party/tvm/python/tvm/relay/frontend/pytorch.py::PyTorchOpConverter, so you can just paste the body.
def masked_fill(self, inputs, input_types):
    mask = inputs[1]

    def replace_inf(inp, replacement_val=1e4):
        if isinstance(inp, _expr.Call) and list(_analysis.free_vars(inp)) == []:
            inp = _infer_value(inp, {}).asnumpy()
            inp[np.isneginf(inp)] = -replacement_val
            inp[np.isposinf(inp)] = replacement_val
            inp = _expr.const(tvm.nd.array(inp))
        elif isinstance(inp, float) and np.isinf(inp):
            inp = np.sign(inp)*replacement_val if np.isinf(inp) else inp
        return inp

    # Handle constant inputs where we can determine -inf/inf values from it
    inputs[0] = replace_inf(inputs[0])
    inputs[2] = replace_inf(inputs[2])

    if isinstance(mask, _expr.Call) and list(_analysis.free_vars(mask)) == []:
        mask = _expr.const(_infer_value(mask, {}))

    mask = _op.cast(mask, "float32")

    # Pybuda will convert all cast ops to Identity. This can be an issue.
    # In pytorch, when the mask has float values, it treats 0 as False and
    # everything else as True. We cannot assume that mask will contain only
    # zeroes and ones, as multiplication with the mask will yield incorrect
    # results.
    mask = _op.abs(mask)
    mask = _op.clip(mask, 0, 1.0)

    # In the event some values actually lie between 0 and 1 we use greater
    # to ceil those all to 1.
    # NOTE: The cast here is just to stop TVM from complaining. It will be
    # replaced with the identity function further down the line
    mask = _op.cast(_op.greater(mask, _expr.const(0, "float32")), "float32")

    value = _op.cast(_wrap_const(inputs[2]), input_types[0])
    value = _op.broadcast_to_like(value, mask)

    one_const = _expr.const(1, dtype="float32")
    inverse_mask = _op.subtract(one_const, mask)

    # (!mask * x) + (mask*y)
    return _op.add(_op.multiply(inputs[0], inverse_mask), _op.multiply(value, mask))
I will admit that masked_fill ends up becoming more ops than it was before. In most cases, the values and mask passed to masked_fill are static, so these ops should be constevaled away during the consteval phase of compilation.
The reason I clip the values between 0 and 1 before using greater is because of the way greater is decomposed later in the compile stage. Long story short: when the values we're comparing are the largest finite values representable by some floating-point standard, we get nans. And so, I clip before comparing.
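The mask normalization can be sketched in NumPy as follows. This illustrates the abs, clip, greater sequence rather than the Relay graph itself. The helper names are made up here, and np.where with mask != 0 stands in for the "0 is False, everything else is True" rule described earlier.

```python
import numpy as np


def normalize_mask(mask):
    """Turn an arbitrary numeric mask into a float 0/1 mask (nonzero -> 1)."""
    mask = np.abs(np.asarray(mask, dtype=np.float32))
    mask = np.clip(mask, 0.0, 1.0)        # keep magnitudes small before comparing
    return (mask > 0).astype(np.float32)  # fractional values become 1 as well


def masked_fill_normalized(x, mask, value):
    """(1 - m) * x + m * value, with m forced to contain only 0s and 1s."""
    x = np.asarray(x, dtype=np.float32)
    m = normalize_mask(mask)
    return (np.float32(1) - m) * x + m * np.float32(value)


x = np.array([1.0, -2.0, 3.0, 4.0], dtype=np.float32)
# A mask mixing zeros, the float32 minimum and a fractional value.
mask = np.array([0.0, np.finfo(np.float32).min, 0.25, 0.0], dtype=np.float32)

result = masked_fill_normalized(x, mask, -1e4)
expected = np.where(mask != 0, np.float32(-1e4), x)
```

Both the huge negative entry and the fractional entry are treated as "masked", and the result stays finite.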
I've also edited the decomposition of tril so that it should be able to decompose tril ops with a dynamic diagonal argument. Keyword: should. I have not thoroughly tested this:
def tril(self, inputs, input_types):
    x = inputs[0]
    x_shape = _infer_shape(x)
    diagonal = inputs[1]

    count = np.arange(np.prod(x_shape)).reshape(x_shape)
    comp = count.transpose(-1, -2)

    count = tvm.relay.Constant(tvm.nd.array(count))
    count = _op.cast(count, self.infer_type(diagonal).dtype)
    comp = tvm.relay.Constant(tvm.nd.array(comp))
    comp = _op.cast(comp, self.infer_type(diagonal).dtype)

    tril = _op.less_equal(comp, _op.add(count, diagonal))
    tril = _op.cast(tril, self.infer_type(x).dtype)
    return _op.multiply(tril, x)
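Since the dynamic-diagonal path is untested, a quick NumPy comparison against np.tril may be useful. The sketch below is a stand-in for the Relay ops, not the patch itself: it covers square 2-D inputs only, and the helper names are invented here. It suggests that the flattened-index comparison matches np.tril for diagonal 0 but not for diagonal 1, whereas comparing row and column indices directly matches for every offset tried.

```python
import numpy as np


def tril_mask_flat_index(n, diagonal):
    """Mask built like the patch: flattened indices compared with their transpose."""
    count = np.arange(n * n).reshape(n, n)
    comp = count.T  # 2-D equivalent of transpose(-1, -2)
    return comp <= count + diagonal


def tril_mask_row_col(n, diagonal):
    """Mask built from explicit indices: keep entries with col - row <= diagonal."""
    rows = np.arange(n).reshape(-1, 1)
    cols = np.arange(n).reshape(1, -1)
    return (cols - rows) <= diagonal


def reference(n, diagonal):
    """Boolean lower-triangular mask as produced by np.tril."""
    return np.tril(np.ones((n, n)), k=diagonal).astype(bool)


n = 4
agree_at_zero = np.array_equal(tril_mask_flat_index(n, 0), reference(n, 0))
agree_at_one = np.array_equal(tril_mask_flat_index(n, 1), reference(n, 1))
row_col_ok = all(
    np.array_equal(tril_mask_row_col(n, k), reference(n, k))
    for k in range(-n, n + 1)
)
```

If the model only ever calls tril with diagonal 0, the difference would not show up, which may be why compilation results looked fine.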
The model fully compiles and passes through net2pipe, as I've set PYBUDA_VERIFY_NET2PIPE=1 (only required since I do not actually have TT hardware). The results I've been able to obtain are:
Input: My name is Lewis, nice to meet you.
PT answer: My name is Lewis, nice to meet you. I am a 12-year-old boy. I am from the USA. I am in Class 3, Grade 7. I have a good friend. His name is Tom. He is from England. He is in Class 2,
TT answer: My name is Lewis, nice to meet you. I am a 12-year-old boy. I am from the USA. I am in Class 3, Grade 7. I have a good friend. His name is Tom. He is考点','');
a a a there there there you
I am unsure where the model goes wrong at the end there. I'm confident it's a separate issue though. For all I know this might be a bug with devmode and will work just fine on a Grayskull.
Good luck to whoever picks this up!
@JushBJJ and I took a look at this and I think this is on the right track. Some comments below on implementation and findings running this on a Grayskull e150:
triu patch from here: https://github.com/tenstorrent/tt-tvm/pull/2
masked_fill and tril patches from @Lewis300.
os.environ['TT_BACKEND_TIMEOUT'] = '0'
os.environ["PYBUDA_FORK_JOIN_EXPAND_FORK_OUTPUT_BUF"] = "0"
os.environ["TT_BACKEND_OVERLAY_MAX_EXTRA_BLOB_SIZE"] = "65536"
# Set PyBuda configurations
compiler_cfg = pybuda.config._get_global_compiler_config()
compiler_cfg.default_df_override = pybuda._C.DataFormat.Float16_b
compiler_cfg.amp_level = 2
Takes a while to get an output, and I get this:
Prefix text: My name is Thomas and my main
Generated text:
dict_values(['My name is Thomas and my main!!"!#!$!%!&!\'!(!)!*!+!,!-!.毒性!/!0等奖!1!2!3等奖""#"等奖##$"$#%"%#&"&#瘰!4!5 Threads!6!腥!7!圈a嫁给!8fbe!9!:!;PAD!<!=!>毒性蚓!?!@十字等奖omb!A!B!C!D!E!F!G!H!PAD"\'"(")"*"+","等人!I!J!K降价!L!M塍!N!O!P!Q!R等奖$$等奖%$enia!S!T!U!毒性"-".岛上!V!W!X!Y!Z![!等奖&$%%&%\'#\'$&等人"毒性.Gen!射!\\!]!^!_!`!a!b!c!d!e毒性#(#)#*#+#,#-#.等奖\'PAD#/等奖等奖($PAD$\'等奖等人#0!f等奖)$(%(&&\'毒性等奖*$)等奖PAD%)毒性$*%*等奖+$+%+毒性%,$,%-$-%.等人$.!g!h!i!j等人等奖,&(\'等人%/"/#等奖毒性&)%0"0#1"1#2"2#3毒性\'%1$/$0$毒性(()&*&+&,等奖-&-\'&.adm!k等奖."3等人&/%等奖/&0%等人\'\'�!等人(*\'(+\')\'adm"4"5!l!m!n!病毒感染PAD&等奖0毒性)(等奖1%2$等人))觎!o!p!q等奖2%3!r!adm#4#5"6等奖3"7"8!s毒性毒性*(,\'*)等人**等人+(-(.#6"9":"PAD\'+)*+*,(/\',)++,*-)病毒感染!t!u!v!w!x等奖4$病毒感染等奖5#等人,毒性+-等人-*.$1&1\'-+.%4毒性等人.&毒性,+/(等人等人/),,-,.病毒感染";!y等奖6#7#8等人0&2&3#9毒性--.�毒性.\'病毒感染毒性/*/+0\'.(0等人1(1)-/等人2等奖病毒感染#:#;等奖7$2\'/,/wing!z!{!|胗!}等奖8"<"adm毒性0(2(毒性1*0).不足以毒性2)/-0*1+1等人3病毒感染�!~!�!�等奖adm$3adm%毒性3$4%5$5%6$6毒性4&4\'0+2*毒性5&5\'1,0病毒感染等人4(病毒感染$7%adm&6%7&7\'2+3�"="病毒感染%8PAD(adm\'3%9#<#=等奖9等奖:$8#>!�!�等奖;毒性6&8$9$:%:&�#?等奖<$;">"�$<%;#】【等奖=毒性7(3不足以!不足以"?"@等人5(�等人6\'4)adm(4*2毒性8等奖�病毒感染adm等奖>等奖?等人7)0,1-1等奖不足以#@!的年轻人!�!条评论!�!�"A等人8%<&9%=#A毒性9&:\'5)1.)2PAD)3的年轻人"B"C"D"E"F"G等奖@"H等奖A"不足以$=$>#PAD*3&;$?#B#C#D#E等人毒性:毒性adm)�%>等人adm*4+4等奖B$adm+等奖觎等奖不说等人病毒感染病毒感染&<\'6(5*5+5,等人9\'7*6)4,2,3\'8&=等人:(6等人;%?$@#毒性;&>adm,4-毒性<(7+等人�&?%@等奖C$A#F#G"I"J"K!等工作觎"L等人<)5-2-3(不足以%A$�#病毒感染\'9(的年轻人等奖的年轻人等人=病毒感染(8\':)6*7,5.*8毒性=�\';\'<*9)7-4.的年轻人#H"M!�!虐待!�!apped!�!�!�!监护!�!�!�!�! 
ad!�!�!�等奖D$B%B&adm-5等奖E#adm.+6+7.条评论"N"O"P"Q"R!竞争!�等奖F菁!五年!�!�!�!�!�等奖G#演!�!�!�!�!�!�!�!�!斗等奖条评论#I#�(9*觎#J#K"S"T"U等奖等工作!�!�!�!�!�!�!�!适合自己!�!�!�!�!adi!�!�!�!�!�!州区!�!带走!觎$C%C&病毒感染)不足以等奖H#L"的年轻人$�adm/毒性>$D%D&@$E$不足以&A%E%病毒感染*:*病毒感染膳毒性?&B\'=adm0-6,6-7/.,70.等工作"V等人>%�)8病毒感染+8(:+9+:,8)9觎%不足以\'不足以(;(<+;)的年轻人%F$F等奖I$G$H$I%G%H%I&C\'>病毒感染,9,:等奖J毒性@%J$的年轻人&D\'的年轻人\'?毒性A&E毒性B(=%K#M"条评论$J%L#N#O#P#Q#R"W"X"Y"Z"["\\"]"^"觎&F%M#S#T#U"_等奖K$K%N$L$条评论%O$M等奖L等奖虐待毒性病毒感染-8adm1//0/1001毒性C(>&G&H等人不足以):-9等人?\'条评论&I\'@&不足以*;*<,;等人@\'A\'等工作#V"`占!�毒性D(条评论\'B)条评论(?(@(A扫!�等奖M$N%的年轻人(B*=&J&K&L%P$虐待"虐待#W等奖N&的年轻人);+<-:.-觎\'C)<../2等人A(C*>\'D等奖竞争"a"b"c"d"e!�EEE!�!�等奖O%Q$O&M等人的年轻人*?)=不足以+=\'E等奖P%条评论)>(D)?*@)@*A)等工作$P等人PAD+>)A*B+?+@+A+B,</3)B-;,=的年轻人+C+D*C,>*D+E病毒感染.02.虐待$等工作毒性E&N\'F&O\'G\'H&P&Q%R#X#Y等人B.1adm2/4/5等人C-等奖Q&R毒性F\'虐待%等工作%S$Q\'I(虐待&S%T$R$S&T%虐待\'J\'K\'L&条评论毒性G(E\'M%U#Z等奖R等人D,?,@,病毒感染/病毒感染03*E(等工作等奖S\'N( ad"等工作&U$Tadm3+F(F毒性�*F)C.204等人E)D-病毒感染1病毒感染2112adm4病毒感染3,A等奖T&V#[#\\#]#^#_"f!�等奖U%V$U&等工作\'O(G)E*G*H\'P毒性H(H)F*I)G+G,B/6.3-<05/713.406/XB!�!�等奖V%W#`"g等奖W$V等奖X$W%X%监护PAD,C等奖Y#a#b#c#d#e"h"i毒性I*J等奖Z#不足以,D.5072病毒感染41身体!�!�!�!�!�!�!�!越好!ismo!�!�!�!�!滕!�毒性J(I等奖[$X毒性KPAD-adm等人F等人G-=(竞争.PostMapping毒性L\'Q(J)虐待(K等奖监护"coni!�!�!�!�!�!�!�!�!�!的食物!�!�!�!�!'])
During inference I think the grayskull is active as I see this in sensors:
grayskull-pci-0200
Adapter: PCI adapter
vcore: 870.00 mV (max = +0.93 V)
asic_temp: +30.1°C (high = +75.0°C)
power: 54.00 W (max = 170.00 W)
current: 62.00 A (max = +300.00 A)
However, it is also curious because I see log messages like this:
2024-04-13 19:48:58.417 | INFO | pybuda.transformers.pipeline:tt_forward:47 - Starting TT forward
2024-04-13 19:48:58.417 | INFO | pybuda.device:push_to_inputs:219 - push_to_inputs redirected from TTDevice 'tt0' to CPUDevice 'cpu0_fallback'
2024-04-13 19:48:58.417 | DEBUG | pybuda.run.impl:_run_forward:644 - Running sequential device forward: CPUDevice 'cpu0_fallback'
2024-04-13 19:48:58.417 | DEBUG | pybuda.device:run_next_command:429 - Received RUN_FORWARD command on CPUDevice 'cpu0_fallback' / 97760
2024-04-13 19:48:58.417 | DEBUG | pybuda.cpudevice:forward_pt:194 - Starting forward on CPUDevice 'cpu0_fallback'
2024-04-13 19:48:58.418 | DEBUG | pybuda.backend:push_to_queues:421 - Pushing to queue pybuda_0_i48
2024-04-13 19:48:58.419 | DEBUG | pybuda.backend:push_to_queues:421 - Pushing to queue attention_mask_1
2024-04-13 19:48:58.419 | DEBUG | pybuda.cpudevice:forward_pt:268 - Ending forward on CPUDevice 'cpu0_fallback'
2024-04-13 19:48:58.419 | DEBUG | pybuda.run.impl:_run_forward:644 - Running sequential device forward: TTDevice 'tt0'
2024-04-13 19:48:58.419 | DEBUG | pybuda.device:run_next_command:429 - Received RUN_FORWARD command on TTDevice 'tt0' / 97760
2024-04-13 19:48:58.419 | DEBUG | pybuda.ttdevice:forward:906 - Starting forward on TTDevice 'tt0'
2024-04-13 19:48:58.419 | INFO | Runtime - Running program 'run_fwd_0' with params [("$p_loop_count", "1")]
2024-04-13 19:48:58.434 | DEBUG | pybuda.run.impl:_run_forward:644 - Running sequential device forward: CPUDevice 'cpu2_fallback'
2024-04-13 19:48:58.434 | DEBUG | pybuda.device:run_next_command:429 - Received RUN_FORWARD command on CPUDevice 'cpu2_fallback' / 97760
2024-04-13 19:48:58.434 | DEBUG | pybuda.cpudevice:forward_pt:194 - Starting forward on CPUDevice 'cpu2_fallback'
2024-04-13 19:48:58.435 | DEBUG | pybuda.backend:read_queues:324 - Reading output queue Qwen2ForCausalLM_tt_1.output_reshape_2070
2024-04-13 19:48:58.470 | DEBUG | pybuda.backend:read_queues:384 - Done reading queues
2024-04-13 19:48:58.692 | DEBUG | pybuda.backend:pop_queues:390 - Popping from queue Qwen2ForCausalLM_tt_1.output_reshape_2070
2024-04-13 19:48:58.693 | DEBUG | pybuda.cpudevice:forward_pt:268 - Ending forward on CPUDevice 'cpu2_fallback'
This suggests all compute is taking place on the CPU.
Any comments on this? @Lewis300 @Shubhamsaboo
Ran this again setting PYBUDA_DEVMODE=1
I get this output:
Prefix text: My name is Thomas and my main
Generated text:
dict_values(['My name is Thomas and my main goal is to help you find the best deals on your next vacation. I have been in the travel industry for over 10 years and have worked with many different types of travel agencies. My goal with this website is for you to find a travel agency that will help make your vacation a success. We will work with you and your family to make sure that you have a great vacation and that your travel experience is as enjoyable as possible.'])
Ah so the log messages you are seeing involving cpu_fallback are just for a few operations near the start and end of the model which cannot execute on the device and must be run on CPU. The vast majority of the heavy lifting is done on the TTDevice (grayskull). However, when you use PYBUDA_DEVMODE=1 it may actually run the TTDevice code on your cpu as well, I’m unsure. It’s best to go without using that if you have the card.
If running the exact same code with PYBUDA_DEVMODE=1 is what gave you the better answer there then you might have to tinker with some other env variables, default data format, amp level, etc…
I wonder if the Float16_b conversion of the min float32 mask filler is causing problems…
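One way to see why this could matter: bfloat16 keeps only the top 16 bits of a float32, so the lowest finite float32 has no exact bfloat16 counterpart. The pure-Python sketch below emulates two possible conversions. How PyBuda or the Grayskull actually converts to Float16_b is not stated in this thread, so both the rounding modes and the helper names here are assumptions.

```python
import math
import struct

F32_MIN = -3.4028234663852886e38  # lowest finite float32


def f32_to_bits(x):
    """Reinterpret a float32 value as its 32-bit pattern."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_f32(bits):
    """Reinterpret a 32-bit pattern as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def to_bfloat16_nearest_even(x):
    """Round a finite float32 to bfloat16 precision (round-to-nearest-even)."""
    bits = f32_to_bits(x)
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return bits_to_f32((bits + rounding_bias) & 0xFFFF0000)


def to_bfloat16_truncate(x):
    """Drop the low 16 bits of the float32 pattern without rounding."""
    return bits_to_f32(f32_to_bits(x) & 0xFFFF0000)


rounded = to_bfloat16_nearest_even(F32_MIN)
truncated = to_bfloat16_truncate(F32_MIN)
```

Under round-to-nearest-even the filler becomes -inf, while truncation keeps it finite. If the device path rounds, a -inf filler could bring back the inf/nan problem that the masked_fill patch works around.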
Removed this: compiler_cfg.default_df_override = pybuda._C.DataFormat.Float16_b
Seems to have fixed the output. This is what I get:
Prefix text: My name is Thomas and my main
Generated text:
dict_values(['My name is Thomas and my main goal is to help you find the best deals on your next vacation. I have been in the travel industry for over 10 years and have worked with many different types of travel agencies. My goal with this website is for you to find a travel agency that will help make your vacation a success. We will work with you and your family to make sure that you have a great vacation and that your travel experience is as enjoyable as possible.'])
Updating PR and adding new patches
@Lewis300 thanks for that PYBUDA_VERIFY_NET2PIPE=1 var, it's been incredibly hard for me to test out TVM patches without a card
@JonathanALevine You were able to get that output without PYBUDA_DEVMODE?
That's correct!
os.environ['PYBUDA_DEVMODE'] = '0'
@JushBJJ @JonathanALevine Addressing your slowness concerns...
CPU fallback is necessary for performing the embeddings and the inverse of embeddings at the end of the model (it depends...?).
I believe most transformers models will output their hidden states as well as the final output. The compiler can't know that you don't care about the hidden states that are output, so it will keep them around in memory to output at the end. You should add output_hidden_states=False to the model config when retrieving it from transformers. This model also outputs some cached values and intermediary attention scores, so get rid of those when you retrieve the original PyTorch model:
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-0.5B", return_dict=False, output_hidden_states=False, use_cache=False, output_attentions=False)
Hopefully this helps!
PS: Look for the generated_modules folder in the project root. It should contain some Python files that illustrate which parts of the model get run on CPU and which on your Grayskull. Also, you can check what gets returned at the end of the forward function there to double check you aren't returning hidden states you don't need.
For me, I've noticed that moving the result between the TT section of the model and the tail-end CPU fallback section takes some time. Moreover, the tail-end CPU fallback is a matmul and a reshape, both of which could be run on the Grayskull. They get put on CPU because the weights of that matmul are the embedding weights. I can see why you'd want them on the same device for training purposes, but for inference it seems to me that it would be alright to just clone them onto the device.
Go to third_party/tvm/python/tvm/relay/op/contrib/buda/buda.py line 1273 and replace
fallback_nodes = add_shared_weights_to_fallback(graph_constructor.graph, graph_constructor.fallback_nodes, input_names)
with
fallback_nodes = graph_constructor.fallback_nodes
This ends up removing the tail-end CPU fallback entirely for this model and with that, the delay. Maybe adding shared weight operations to fallback can be added to the compiler configuration. I guess that would be for TT to figure out.
To summarize a bit what was discovered today:
(1) pybuda/op/eval/pybuda/tm.py patch
@Lewis300 and I thought that the slow queue and dequeue between the TT part of the model and the tail-end CPU fallback part could be due to the precision of the model weights. To test this with model.to(torch.bfloat16), we had to patch tm.py like this:
def eval(type, attr, ops):
    assert len(ops) == 1 or (type == "adv_index" and len(ops) == 2), f"Tensor manipulation ops should have one input {len(ops)} {attr}"
    t_ops = to_torch_operands(*ops)
    dtype = ops[0].dtype

    if type == "transpose":
        assert len(attr) == 3, "Transpose should have 3 attributes"
        dim0, dim1, orig_size = attr
        return torch.transpose(t_ops[0], dim0, dim1)

    if type == "reshape":
        return t_ops[0].reshape(attr)

    if type == "select":
        assert len(attr) == 4, "Select should have 4 attributes"
        dim, begin, length, stride = attr
        zero_shape = list(t_ops[0].shape)
        zero_shape[dim] = 1
        zero_slice = torch.zeros(zero_shape, dtype=dtype).squeeze(dim)
        result = []
        for offset in range(0, t_ops[0].shape[dim] - begin, stride):
            for i in range(begin, begin + length):
                if offset + i < t_ops[0].shape[dim] or stride == t_ops[0].shape[dim]:
                    result.append(t_ops[0].select(dim, offset + i))
                else:
                    result.append(zero_slice)
        return torch.stack(result, dim=dim)

    if type == "gather":
        assert len(attr) == 5, "Gather should have 5 attributes"
        dim, begin, length, stride, orig_size = attr
        x = t_ops[0]
        result = []
        zero_shape = list(x.shape)
        if dim > 0:
            dim -= 4
        while len(zero_shape) <= abs(dim):
            zero_shape = [1] + zero_shape
            x = x.unsqueeze(0)
        zero_shape[dim] = 1
        zero_slice = torch.zeros(zero_shape, dtype=dtype).squeeze(dim)
        offset = 0
        for i in range(0, orig_size):
            range_i = (i - begin) % stride
            if i >= begin and range_i < length:
                result.append(x.select(dim, offset))
                offset += 1
            else:
                result.append(zero_slice)
        return torch.stack(result, dim=dim)

    if type == "index":
        assert len(attr) == 4, "Index should have 4 attributes"
        dim, start, stop, stride = attr
        if dim >= 0:
            dim -= len(ops[0].shape)
        if dim == -5:
            return t_ops[0][..., start:stop:stride, :, :, :, :]
        elif dim == -4:
            return t_ops[0][..., start:stop:stride, :, :, :]
        elif dim == -3:
            return t_ops[0][..., start:stop:stride, :, :]
        elif dim == -2:
            return t_ops[0][..., start:stop:stride, :]
        elif dim == -1:
            return t_ops[0][..., start:stop:stride]
        else:
            raise NotImplementedError(f"Dim={dim}")

    if type == "adv_index":
        assert len(attr) == 1, "AdvIndex should have 1 attributes"
        dim = attr[0]
        assert dim == 0, "Currently not supported"
        if len(t_ops[1].shape) > 1:
            if len(t_ops[0].shape) > len(t_ops[1].shape) and t_ops[0].shape[0] == 1:
                # Padded
                # ret = torch.unsqueeze(t_ops[0][0][t_ops[1].numpy()], 0)
                ret = t_ops[0][0][t_ops[1].int()]
            else:
                # ret = torch.unsqueeze(t_ops[0][t_ops[1].numpy()], 0)
                ret = t_ops[0][t_ops[1].int()]
        else:
            # ret = t_ops[0][t_ops[1].numpy()]
            ret = t_ops[0][t_ops[1].int()]
        return ret
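The thread does not include a diff, so exactly which lines of eval changed is an assumption. The dtype=dtype arguments on the zero padding look like the part that matters for reduced-precision weights. The NumPy sketch below illustrates the general concern only: NumPy has no bfloat16, so float16 is used as a stand-in, and NumPy's default zeros dtype (float64) differs from PyTorch's.

```python
import numpy as np

# Two "real" slices in reduced precision (float16 standing in for bfloat16).
slices = [np.ones(3, dtype=np.float16), np.ones(3, dtype=np.float16)]

default_zero = np.zeros(3)                         # library default dtype
matched_zero = np.zeros(3, dtype=slices[0].dtype)  # follows the input dtype

# Stacking with default-dtype padding changes the result dtype.
mixed = np.stack(slices + [default_zero])
# Stacking with dtype-matched padding keeps the reduced precision.
matched = np.stack(slices + [matched_zero])
```

Here the default-dtype padding silently widens the stacked result, while the matched padding preserves the input precision.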
@JushBJJ can you link your new PR for adding Qwen 0.5 to this issue?
PR for Qwen: https://github.com/tenstorrent/tt-buda-demos/pull/37
Claimed by @JushBJJ. It will be closed once it's merged into main.
Congrats Jush!
Closing this one as merged to main. Congrats again @JushBJJ.
Background:
TT-Buda model demos, developed by Tenstorrent, is a growing collection of model demos showcasing the capabilities of AI models running on Tenstorrent hardware. These demonstrations cover a wide range of applications, aiming to provide insights and inspiration for developers and researchers interested in advanced AI implementations.
Bounty Objective:
We are excited to announce a bounty for contributing a new AI model demonstration to the TT-Buda repository. This is an opportunity for AI enthusiasts, researchers, and developers to showcase their skills, contribute to cutting-edge AI research, and earn rewards.
Task Details:
Integrate Qwen-1.5 (0.5B) into the TT-Buda model demonstrations.
Requirements:
Contribution Guidelines:
model_demos folder following the naming convention: model_yourModelName.
CONTRIBUTING.md file.
Evaluation Criteria:
Rewards:
Contributions will be evaluated by the Tenstorrent team, and the best contribution will be eligible for $500 bounty.
Get Started with Grayskull DevKit
Dive into AI development with the Grayskull DevKit, your gateway to exploring Tenstorrent's hardware. Paired with TT-Buda and TT-Metalium software approaches, it offers a solid foundation for AI experimentation. Secure your kit here.
Connect on Discord
Join our Discord to talk AI, share your journey, and get support from the Tenstorrent community and team. Let's innovate together!