replicate / cog-flux

Cog inference for flux models
https://replicate.com/black-forest-labs/flux-dev
Apache License 2.0

use nn.Sequential to remove python control flow from autoencoder up/downsampling #33

Open technillogue opened 2 weeks ago

technillogue commented 2 weeks ago

The current autoencoder implementation causes graph breaks under `torch.compile`, likely due to Python control flow. The `hs` list constructed in the encoder is also egregious. Performance improvement numbers TBD.
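Here is a minimal sketch of the idea that does not depend on torch. The nested level/block loops and the "no downsample after the last level" check run only once, while the layer list is built. The resulting sequence could then be passed to `nn.Sequential(*modules)`, so `forward()` becomes a straight chain with no branches and no `hs` list. `LayerSpec` and `plan_encoder_down` are hypothetical names. The defaults (`ch=128`, `ch_mult=(1, 2, 4, 4)`, `num_res_blocks=2`) are assumed to be the Flux autoencoder's configuration.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass(frozen=True)
class LayerSpec:
    """Hypothetical stand-in for a real module (ResnetBlock or Downsample)."""
    kind: str    # "resnet" or "downsample"
    in_ch: int
    out_ch: int


def plan_encoder_down(
    ch: int = 128,
    ch_mult: Sequence[int] = (1, 2, 4, 4),
    num_res_blocks: int = 2,
) -> List[LayerSpec]:
    """Unroll the encoder's nested level/block loops into one flat, ordered list.

    All branching happens here, once, at construction time. In PyTorch each
    LayerSpec would become a module and the list would be wrapped as
    nn.Sequential(*modules), leaving forward() without any `if` statements.
    """
    layers: List[LayerSpec] = []
    block_in = ch
    last_level = len(ch_mult) - 1
    for i_level, mult in enumerate(ch_mult):
        block_out = ch * mult
        for _ in range(num_res_blocks):
            # A block with in_ch != out_ch gets a nin_shortcut in the real ResnetBlock.
            layers.append(LayerSpec("resnet", block_in, block_out))
            block_in = block_out
        # Resolved at build time: every level except the last one downsamples.
        if i_level != last_level:
            layers.append(LayerSpec("downsample", block_in, block_in))
    return layers
```

With those assumed defaults, the plan has 11 entries. Downsamples sit at indices 2, 5 and 8, and channel changes occur at 3 and 6. This matches the `encoder.down.N.conv.*` and `encoder.down.N.nin_shortcut.*` entries in the missing-key list later in the thread.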

yorickvP commented 2 weeks ago
daanelson commented 2 weeks ago

@yorickvP Interesting that there are no graph breaks; I wonder if the compiler is smart enough to realize that the if statements will always evaluate to true/false depending on the loop iteration.

Agreed that we should get perf numbers and test that outputs are unchanged.

technillogue commented 1 week ago

Unfortunately, this changes the state_dict keys:

Got 180 missing keys:

encoder.down.0.norm1.weight encoder.down.0.norm1.bias encoder.down.0.conv1.weight encoder.down.0.conv1.bias
encoder.down.0.norm2.weight encoder.down.0.norm2.bias encoder.down.0.conv2.weight encoder.down.0.conv2.bias encoder.down.1.norm1.weight encoder.down.1.norm1.bias encoder.down.1.conv1.weight encoder.down.1.conv1.bias encoder.down.1.norm2.weight encoder.down.1.norm2.bias encoder.down.1.conv2.weight encoder.down.1.conv2.bias encoder.down.2.conv.weight encoder.down.2.conv.bias encoder.down.3.norm1.weight encoder.down.3.norm1.bias encoder.down.3.conv1.weight encoder.down.3.conv1.bias encoder.down.3.norm2.weight encoder.down.3.norm2.bias encoder.down.3.conv2.weight encoder.down.3.conv2.bias encoder.down.3.nin_shortcut.weight encoder.down.3.nin_shortcut.bias encoder.down.4.norm1.weight encoder.down.4.norm1.bias encoder.down.4.conv1.weight encoder.down.4.conv1.bias encoder.down.4.norm2.weight encoder.down.4.norm2.bias encoder.down.4.conv2.weight encoder.down.4.conv2.bias encoder.down.5.conv.weight encoder.down.5.conv.bias encoder.down.6.norm1.weight encoder.down.6.norm1.bias encoder.down.6.conv1.weight encoder.down.6.conv1.bias encoder.down.6.norm2.weight encoder.down.6.norm2.bias encoder.down.6.conv2.weight encoder.down.6.conv2.bias encoder.down.6.nin_shortcut.weight encoder.down.6.nin_shortcut.bias encoder.down.7.norm1.weight encoder.down.7.norm1.bias encoder.down.7.conv1.weight encoder.down.7.conv1.bias encoder.down.7.norm2.weight encoder.down.7.norm2.bias encoder.down.7.conv2.weight encoder.down.7.conv2.bias encoder.down.8.conv.weight encoder.down.8.conv.bias encoder.down.9.norm1.weight encoder.down.9.norm1.bias encoder.down.9.conv1.weight encoder.down.9.conv1.bias encoder.down.9.norm2.weight encoder.down.9.norm2.bias encoder.down.9.conv2.weight encoder.down.9.conv2.bias encoder.down.10.norm1.weight encoder.down.10.norm1.bias encoder.down.10.conv1.weight encoder.down.10.conv1.bias encoder.down.10.norm2.weight encoder.down.10.norm2.bias encoder.down.10.conv2.weight encoder.down.10.conv2.bias decoder.up.0.norm1.weight decoder.up.0.norm1.bias decoder.up.0.conv1.weight 
decoder.up.0.conv1.bias decoder.up.0.norm2.weight decoder.up.0.norm2.bias decoder.up.0.conv2.weight decoder.up.0.conv2.bias decoder.up.0.nin_shortcut.weight decoder.up.0.nin_shortcut.bias decoder.up.1.norm1.weight decoder.up.1.norm1.bias decoder.up.1.conv1.weight decoder.up.1.conv1.bias decoder.up.1.norm2.weight decoder.up.1.norm2.bias decoder.up.1.conv2.weight decoder.up.1.conv2.bias decoder.up.2.norm1.weight decoder.up.2.norm1.bias decoder.up.2.conv1.weight decoder.up.2.conv1.bias decoder.up.2.norm2.weight decoder.up.2.norm2.bias decoder.up.2.conv2.weight decoder.up.2.conv2.bias decoder.up.3.norm1.weight decoder.up.3.norm1.bias decoder.up.3.conv1.weight decoder.up.3.conv1.bias decoder.up.3.norm2.weight decoder.up.3.norm2.bias decoder.up.3.conv2.weight decoder.up.3.conv2.bias decoder.up.3.nin_shortcut.weight decoder.up.3.nin_shortcut.bias decoder.up.4.norm1.weight decoder.up.4.norm1.bias decoder.up.4.conv1.weight decoder.up.4.conv1.bias decoder.up.4.norm2.weight decoder.up.4.norm2.bias decoder.up.4.conv2.weight decoder.up.4.conv2.bias decoder.up.5.norm1.weight decoder.up.5.norm1.bias decoder.up.5.conv1.weight decoder.up.5.conv1.bias decoder.up.5.norm2.weight decoder.up.5.norm2.bias decoder.up.5.conv2.weight decoder.up.5.conv2.bias decoder.up.6.conv.weight decoder.up.6.conv.bias decoder.up.7.norm1.weight decoder.up.7.norm1.bias decoder.up.7.conv1.weight decoder.up.7.conv1.bias decoder.up.7.norm2.weight decoder.up.7.norm2.bias decoder.up.7.conv2.weight decoder.up.7.conv2.bias decoder.up.8.norm1.weight decoder.up.8.norm1.bias decoder.up.8.conv1.weight decoder.up.8.conv1.bias decoder.up.8.norm2.weight decoder.up.8.norm2.bias decoder.up.8.conv2.weight decoder.up.8.conv2.bias decoder.up.9.norm1.weight decoder.up.9.norm1.bias decoder.up.9.conv1.weight decoder.up.9.conv1.bias decoder.up.9.norm2.weight decoder.up.9.norm2.bias decoder.up.9.conv2.weight decoder.up.9.conv2.bias decoder.up.10.conv.weight decoder.up.10.conv.bias decoder.up.11.norm1.weight 
decoder.up.11.norm1.bias decoder.up.11.conv1.weight decoder.up.11.conv1.bias decoder.up.11.norm2.weight decoder.up.11.norm2.bias decoder.up.11.conv2.weight decoder.up.11.conv2.bias decoder.up.12.norm1.weight decoder.up.12.norm1.bias decoder.up.12.conv1.weight decoder.up.12.conv1.bias decoder.up.12.norm2.weight decoder.up.12.norm2.bias decoder.up.12.conv2.weight decoder.up.12.conv2.bias decoder.up.13.norm1.weight decoder.up.13.norm1.bias decoder.up.13.conv1.weight decoder.up.13.conv1.bias decoder.up.13.norm2.weight decoder.up.13.norm2.bias decoder.up.13.conv2.weight decoder.up.13.conv2.bias decoder.up.14.conv.weight decoder.up.14.conv.bias

Got 180 unexpected keys:

encoder.down.0.block.0.conv1.bias encoder.down.0.block.0.conv1.weight encoder.down.0.block.0.conv2.bias encoder.down.0.block.0.conv2.weight
encoder.down.0.block.0.norm1.bias encoder.down.0.block.0.norm1.weight encoder.down.0.block.0.norm2.bias encoder.down.0.block.0.norm2.weight encoder.down.0.block.1.conv1.bias encoder.down.0.block.1.conv1.weight encoder.down.0.block.1.conv2.bias encoder.down.0.block.1.conv2.weight encoder.down.0.block.1.norm1.bias encoder.down.0.block.1.norm1.weight encoder.down.0.block.1.norm2.bias encoder.down.0.block.1.norm2.weight encoder.down.0.downsample.conv.bias encoder.down.0.downsample.conv.weight encoder.down.1.block.0.conv1.bias encoder.down.1.block.0.conv1.weight encoder.down.1.block.0.conv2.bias encoder.down.1.block.0.conv2.weight encoder.down.1.block.0.nin_shortcut.bias encoder.down.1.block.0.nin_shortcut.weight encoder.down.1.block.0.norm1.bias encoder.down.1.block.0.norm1.weight encoder.down.1.block.0.norm2.bias encoder.down.1.block.0.norm2.weight encoder.down.1.block.1.conv1.bias encoder.down.1.block.1.conv1.weight encoder.down.1.block.1.conv2.bias encoder.down.1.block.1.conv2.weight encoder.down.1.block.1.norm1.bias encoder.down.1.block.1.norm1.weight encoder.down.1.block.1.norm2.bias encoder.down.1.block.1.norm2.weight encoder.down.1.downsample.conv.bias encoder.down.1.downsample.conv.weight encoder.down.2.block.0.conv1.bias encoder.down.2.block.0.conv1.weight encoder.down.2.block.0.conv2.bias encoder.down.2.block.0.conv2.weight encoder.down.2.block.0.nin_shortcut.bias encoder.down.2.block.0.nin_shortcut.weight encoder.down.2.block.0.norm1.bias encoder.down.2.block.0.norm1.weight encoder.down.2.block.0.norm2.bias encoder.down.2.block.0.norm2.weight encoder.down.2.block.1.conv1.bias encoder.down.2.block.1.conv1.weight encoder.down.2.block.1.conv2.bias encoder.down.2.block.1.conv2.weight encoder.down.2.block.1.norm1.bias encoder.down.2.block.1.norm1.weight encoder.down.2.block.1.norm2.bias encoder.down.2.block.1.norm2.weight encoder.down.2.downsample.conv.bias encoder.down.2.downsample.conv.weight encoder.down.3.block.0.conv1.bias encoder.down.3.block.0.conv1.weight 
encoder.down.3.block.0.conv2.bias encoder.down.3.block.0.conv2.weight encoder.down.3.block.0.norm1.bias encoder.down.3.block.0.norm1.weight encoder.down.3.block.0.norm2.bias encoder.down.3.block.0.norm2.weight encoder.down.3.block.1.conv1.bias encoder.down.3.block.1.conv1.weight encoder.down.3.block.1.conv2.bias encoder.down.3.block.1.conv2.weight encoder.down.3.block.1.norm1.bias encoder.down.3.block.1.norm1.weight encoder.down.3.block.1.norm2.bias encoder.down.3.block.1.norm2.weight decoder.up.0.block.0.conv1.bias decoder.up.0.block.0.conv1.weight decoder.up.0.block.0.conv2.bias decoder.up.0.block.0.conv2.weight decoder.up.0.block.0.nin_shortcut.bias decoder.up.0.block.0.nin_shortcut.weight decoder.up.0.block.0.norm1.bias decoder.up.0.block.0.norm1.weight decoder.up.0.block.0.norm2.bias decoder.up.0.block.0.norm2.weight decoder.up.0.block.1.conv1.bias decoder.up.0.block.1.conv1.weight decoder.up.0.block.1.conv2.bias decoder.up.0.block.1.conv2.weight decoder.up.0.block.1.norm1.bias decoder.up.0.block.1.norm1.weight decoder.up.0.block.1.norm2.bias decoder.up.0.block.1.norm2.weight decoder.up.0.block.2.conv1.bias decoder.up.0.block.2.conv1.weight decoder.up.0.block.2.conv2.bias decoder.up.0.block.2.conv2.weight decoder.up.0.block.2.norm1.bias decoder.up.0.block.2.norm1.weight decoder.up.0.block.2.norm2.bias decoder.up.0.block.2.norm2.weight decoder.up.1.block.0.conv1.bias decoder.up.1.block.0.conv1.weight decoder.up.1.block.0.conv2.bias decoder.up.1.block.0.conv2.weight decoder.up.1.block.0.nin_shortcut.bias decoder.up.1.block.0.nin_shortcut.weight decoder.up.1.block.0.norm1.bias decoder.up.1.block.0.norm1.weight decoder.up.1.block.0.norm2.bias decoder.up.1.block.0.norm2.weight decoder.up.1.block.1.conv1.bias decoder.up.1.block.1.conv1.weight decoder.up.1.block.1.conv2.bias decoder.up.1.block.1.conv2.weight decoder.up.1.block.1.norm1.bias decoder.up.1.block.1.norm1.weight decoder.up.1.block.1.norm2.bias decoder.up.1.block.1.norm2.weight 
decoder.up.1.block.2.conv1.bias decoder.up.1.block.2.conv1.weight decoder.up.1.block.2.conv2.bias decoder.up.1.block.2.conv2.weight decoder.up.1.block.2.norm1.bias decoder.up.1.block.2.norm1.weight decoder.up.1.block.2.norm2.bias decoder.up.1.block.2.norm2.weight decoder.up.1.upsample.conv.bias decoder.up.1.upsample.conv.weight decoder.up.2.block.0.conv1.bias decoder.up.2.block.0.conv1.weight decoder.up.2.block.0.conv2.bias decoder.up.2.block.0.conv2.weight decoder.up.2.block.0.norm1.bias decoder.up.2.block.0.norm1.weight decoder.up.2.block.0.norm2.bias decoder.up.2.block.0.norm2.weight decoder.up.2.block.1.conv1.bias decoder.up.2.block.1.conv1.weight decoder.up.2.block.1.conv2.bias decoder.up.2.block.1.conv2.weight decoder.up.2.block.1.norm1.bias decoder.up.2.block.1.norm1.weight decoder.up.2.block.1.norm2.bias decoder.up.2.block.1.norm2.weight decoder.up.2.block.2.conv1.bias decoder.up.2.block.2.conv1.weight decoder.up.2.block.2.conv2.bias decoder.up.2.block.2.conv2.weight decoder.up.2.block.2.norm1.bias decoder.up.2.block.2.norm1.weight decoder.up.2.block.2.norm2.bias decoder.up.2.block.2.norm2.weight decoder.up.2.upsample.conv.bias decoder.up.2.upsample.conv.weight decoder.up.3.block.0.conv1.bias decoder.up.3.block.0.conv1.weight decoder.up.3.block.0.conv2.bias decoder.up.3.block.0.conv2.weight decoder.up.3.block.0.norm1.bias decoder.up.3.block.0.norm1.weight decoder.up.3.block.0.norm2.bias decoder.up.3.block.0.norm2.weight decoder.up.3.block.1.conv1.bias decoder.up.3.block.1.conv1.weight decoder.up.3.block.1.conv2.bias decoder.up.3.block.1.conv2.weight decoder.up.3.block.1.norm1.bias decoder.up.3.block.1.norm1.weight decoder.up.3.block.1.norm2.bias decoder.up.3.block.1.norm2.weight decoder.up.3.block.2.conv1.bias decoder.up.3.block.2.conv1.weight decoder.up.3.block.2.conv2.bias decoder.up.3.block.2.conv2.weight decoder.up.3.block.2.norm1.bias decoder.up.3.block.2.norm1.weight decoder.up.3.block.2.norm2.bias decoder.up.3.block.2.norm2.weight 
decoder.up.3.upsample.conv.bias decoder.up.3.upsample.conv.weight

The unexpected keys are the ones in the ae.sft file. The missing (expected) keys are the ones produced by using a flat `Sequential[ResnetBlock]` instead of `ModuleList[ModuleList[ResnetBlock]]`. There's probably a way to remap the state_dict keys, but it would be easier if we could modify the file.
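Here is a minimal sketch of the remapping. It rests on assumptions read off the two key lists above. First, the flat Sequential numbers modules in ascending level order, with a level's ResnetBlocks followed by its Downsample/Upsample. Second, encoder levels have 2 blocks each and decoder levels have 3. Third, the encoder downsamples on every level except the last, and the decoder upsamples on every level except level 0. `build_flat_index` and `remap_nested_to_flat` are hypothetical helpers, not part of cog-flux or flux.

```python
import re
from typing import Callable, Dict, Mapping, TypeVar

V = TypeVar("V")

# Splits nested keys such as "encoder.down.1.block.0.conv1.weight" or
# "decoder.up.2.upsample.conv.bias" into (stack, nested sub-path, param name).
_NESTED_KEY = re.compile(
    r"^(encoder\.down|decoder\.up)\.(\d+\.(?:block\.\d+|downsample|upsample))\.(.+)$"
)


def build_flat_index(
    num_levels: int,
    num_blocks: int,
    has_sampler: Callable[[int], bool],
    sampler: str,
) -> Dict[str, int]:
    """Map nested sub-paths ('1.block.0', '1.downsample') to consecutive flat indices."""
    index: Dict[str, int] = {}
    flat = 0
    for level in range(num_levels):
        for block in range(num_blocks):
            index[f"{level}.block.{block}"] = flat
            flat += 1
        if has_sampler(level):
            index[f"{level}.{sampler}"] = flat
            flat += 1
    return index


def remap_nested_to_flat(state_dict: Mapping[str, V], num_levels: int = 4) -> Dict[str, V]:
    """Rename ae.sft-style nested keys to the assumed flat nn.Sequential layout."""
    # Encoder: 2 ResnetBlocks per level, Downsample on every level but the last.
    enc = build_flat_index(num_levels, 2, lambda lvl: lvl != num_levels - 1, "downsample")
    # Decoder: 3 ResnetBlocks per level, Upsample on every level but level 0.
    dec = build_flat_index(num_levels, 3, lambda lvl: lvl != 0, "upsample")
    remapped: Dict[str, V] = {}
    for key, value in state_dict.items():
        match = _NESTED_KEY.match(key)
        if match is None:
            # conv_in, mid blocks, norm_out, etc. keep their names.
            remapped[key] = value
            continue
        stack, sub_path, param = match.groups()
        table = enc if stack == "encoder.down" else dec
        remapped[f"{stack}.{table[sub_path]}.{param}"] = value
    return remapped
```

The function works on a plain dict. It could therefore be applied to the dict that safetensors' `load_file` returns for ae.sft, before calling `load_state_dict`. Alternatively, the result could be written back with `save_file` to produce a file already in the new layout.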