technillogue opened 2 weeks ago
I ran `torch._dynamo.explain` on the (FP8) decoder and it reported no graph breaks.

@yorickvP interesting that there are no graph breaks; I wonder if the compiler is smart enough to realize that the if statements will always evaluate to true/false depending on the loop iteration.
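For reference, a minimal sketch of the kind of check described above (the module and input shape are placeholders, not this PR's code; the call style assumes PyTorch 2.1+):

```python
import torch

# Placeholder stand-in for the FP8 decoder; any nn.Module works here.
decoder = torch.nn.Sequential(
    torch.nn.Conv2d(4, 64, 3, padding=1),
    torch.nn.SiLU(),
    torch.nn.Conv2d(64, 3, 3, padding=1),
)
latents = torch.randn(1, 4, 32, 32)

# explain() traces the module and reports where Dynamo had to break the graph.
explanation = torch._dynamo.explain(decoder)(latents)
print(explanation.graph_break_count)  # 0 => a single captured graph
print(explanation.break_reasons)      # empty when there are no breaks
```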
agreed that we should get perf numbers & test that outputs are unchanged
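A minimal sketch of that output check, assuming we simply compare eager vs. compiled outputs on the same input (tolerances are a guess; compiled/FP8 kernels rarely match eager bit-for-bit):

```python
import torch

def assert_compile_parity(model: torch.nn.Module, example: torch.Tensor) -> None:
    """Check that torch.compile does not change the module's outputs."""
    model.eval()
    with torch.no_grad():
        expected = model(example)
        actual = torch.compile(model)(example)
    # Placeholder tolerances; tighten or loosen based on the dtype in use.
    torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3)
```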
unfortunately this changes the `state_dict` keys:

```
Got 180 missing keys: encoder.down.0.norm1.weight encoder.down.0.norm1.bias encoder.down.0.conv1.weight encoder.down.0.conv1.bias
Got 180 unexpected keys: encoder.down.0.block.0.conv1.bias encoder.down.0.block.0.conv1.weight encoder.down.0.block.0.conv2.bias encoder.down.0.block.0.conv2.weight
```
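If old checkpoints should keep loading, one option is to remap keys at load time. A rough sketch, assuming the refactor flattened `down.{i}.block.{j}.*` into `down.{i}.*` (an inference from the key pairs above; the real rule should be derived from the full lists of 180 keys):

```python
import re
import torch

def load_legacy_checkpoint(model: torch.nn.Module, path: str) -> None:
    """Load an old-format checkpoint into the refactored autoencoder.

    Assumes the rename is `down.{i}.block.{j}.*` -> `down.{i}.*`, which
    only holds if every level has a single block; verify against the
    full missing/unexpected key lists before relying on it.
    """
    old_sd = torch.load(path, map_location="cpu")
    new_sd = {re.sub(r"\.block\.\d+\.", ".", k): v for k, v in old_sd.items()}
    model.load_state_dict(new_sd)
```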
The current autoencoder implementation causes graph breaks, likely due to Python control flow. The `hs` list constructed in the encoder is also egregious; see the sketch below. Performance improvement numbers TBD.
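For context, the pattern in question looks roughly like the CompVis-style encoder loop below (a paraphrase, not this repo's exact code):

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    # Paraphrased CompVis-style forward; attributes like conv_in, down,
    # num_resolutions, and num_res_blocks are assumed to exist.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # `hs` grows on every iteration even though only hs[-1] is ever
        # read, keeping all intermediate activations alive.
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                hs.append(self.down[i_level].block[i_block](hs[-1]))
            # This `if` depends only on the Python loop index, so
            # torch.compile can specialize it away at trace time, which
            # may explain the "no graph breaks" report above.
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))
        return hs[-1]
```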