Status: closed (kyakuno closed this issue 2 years ago)
License: MIT
エクスポートのための修正 (modifications required for ONNX export)
○ ldm/modules/diffusionmodules/model.py
# Upstream (before): the attention-logit scale casts the channel count `c`
# with int() before raising it to the power -0.5.
class AttnBlock(nn.Module):
...
def forward(self, x):
...
w_ = w_ * (int(c)**(-0.5))
↓
# After: the int() cast is dropped, so the scale is computed from `c` as is.
# NOTE(review): presumably `c` comes from the input tensor's shape, and int()
# would freeze it into a constant during tracing. Confirm.
class AttnBlock(nn.Module):
...
def forward(self, x):
...
w_ = w_ * (c**(-0.5))
○ ldm/modules/diffusionmodules/util.py
# Upstream (before): the `if flag:` branch (presumably the gradient
# checkpointing path) is taken whenever the caller passes a truthy `flag`.
def checkpoint(func, inputs, params, flag):
...
if flag:
...
↓
def checkpoint(func, inputs, params, flag):
...
# Force `flag` off so the `if flag:` branch is never taken during export,
# regardless of what the caller passed.
# NOTE(review): presumably that branch goes through a custom autograd
# Function that torch.onnx.export cannot trace. Confirm.
flag = False
if flag:
...
○ ldm/modules/encoders/modules.py
# Upstream (before): in the non-tokenizer branch, the transformer is called
# with return_embeddings=True and the result is stored in `z`.
class BERTEmbedder(AbstractEncoder):
...
def forward(self, text):
if self.use_tknz_fn:
...
else:
...
z = self.transformer(tokens, return_embeddings=True)
↓
# Text-encoder export, stage 1: the original `z = self.transformer(...)` call
# is replaced by an ONNX export of self.transformer to transformer_emb.onnx.
class BERTEmbedder(AbstractEncoder):
...
def forward(self, text):
if self.use_tknz_fn:
...
else:
...
print("------>")
import functools
# torch.autograd.Variable is a deprecated API; on current PyTorch it returns
# the tensor unchanged.
from torch.autograd import Variable
# torch.onnx.export calls the module with the sample input only, so bind
# return_embeddings=True onto forward. This rebinds forward on the instance
# every time this method runs.
self.transformer.forward = functools.partial(self.transformer.forward, return_embeddings=True)
# Move the module and the sample token ids to CPU for the export.
self.transformer.cpu()
x = Variable(tokens.cpu())
# Only axis 0 of the input and the output is dynamic.
torch.onnx.export(
self.transformer, x, 'transformer_emb.onnx',
input_names=["x"],
output_names=["out"],
dynamic_axes={'x' : {0 : 'n'}, 'out' : {0 : 'n'}},
verbose=False, opset_version=12
)
print("<------")
# NOTE(review): `z` is no longer assigned here. Any later code in this
# method (elided) that reads `z` would fail; recompute it after the export
# if inference must continue.
# Text-encoder export, stage 2: export the part of the transformer that runs
# after the embedding projection (TransformerWrapper.forward2) to
# transformer_attn.onnx.
class BERTEmbedder(AbstractEncoder):
...
def forward(self, text):
if self.use_tknz_fn:
...
else:
...
# Eager call that produces the sample input for the second graph.
# Assumes TransformerWrapper.forward has been truncated to return right
# after project_emb, so `x` holds the projected embeddings.
x = self.transformer(tokens, return_embeddings=True)
print("------>")
from torch.autograd import Variable
# Point the instance's forward at forward2 so export traces only that half.
self.transformer.forward = self.transformer.forward2
self.transformer.cpu()
x = Variable(x.cpu())
torch.onnx.export(
self.transformer, x, 'transformer_attn.onnx',
input_names=["x"],
output_names=["out"],
dynamic_axes={'x' : {0 : 'n'}, 'out' : {0 : 'n'}},
verbose=False, opset_version=12
)
print("<------")
○ ldm/modules/x_transformer.py
# Upstream (before): forward continues past the embedding projection into
# the num_mem handling (remaining code elided).
class TransformerWrapper(nn.Module):
def forward(
...
):
...
x = self.project_emb(x)
...
if num_mem > 0:
...
↓
class TransformerWrapper(nn.Module):
def forward(
...
):
...
x = self.project_emb(x)
return x
if num_mem > 0:
...
def forward2(self, x, mask=None, mems=None, **kwargs):
num_mem = 0
x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
x = self.norm(x)
mem, x = x[:, :num_mem], x[:, num_mem:]
return_embeddings = True
out = self.to_logits(x) if not return_embeddings else x
return out
○ ldm/models/diffusion/ddpm.py
# Upstream (before): the cross-attention branch concatenates the context
# tensors along dim 1 and passes them to the UNet as `context`.
class DiffusionWrapper(pl.LightningModule):
...
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
...
elif self.conditioning_key == 'crossattn':
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(x, t, context=cc)
↓
class DiffusionWrapper(pl.LightningModule):
...
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
...
elif self.conditioning_key == 'crossattn':
cc = torch.cat(c_crossattn, 1)
print("------>")
from torch.autograd import Variable
xx = (Variable(x), Variable(t), Variable(cc))
torch.onnx.export(
self.diffusion_model, xx, 'diffusion_emb.onnx',
input_names=["x", "timesteps", "context"],
output_names=["h", "emb", "h0", "h1", "h2", "h3", "h4", "h5", "h6", "h7", "h8", "h9", "h10", "h11"],
dynamic_axes={'x' : {0 : 'n', 2:'h',3:'w'}, 'timesteps' : {0 : 'n'}, 'context' : {0 : 'n'}, 'h' : {0 : 'n', 2:'h1',3:'w1'}, 'emb' : {0 : 'n'}, 'h0' : {0 : 'n', 2:'h1',3:'w1'}, 'h1' : {0 : 'n', 2:'h1',3:'w1'}, 'h2' : {0 : 'n', 2:'h1',3:'w1'}, 'h3' : {0 : 'n', 2:'h2',3:'w2'}, 'h4' : {0 : 'n', 2:'h2',3:'w2'}, 'h5' : {0 : 'n', 2:'h2',3:'w2'}, 'h6' : {0 : 'n', 2:'h3',3:'w3'}, 'h7' : {0 : 'n', 2:'h3',3:'w3'}, 'h8' : {0 : 'n', 2:'h3',3:'w3'}, 'h9' : {0 : 'n', 2:'h4',3:'w4'}, 'h10' : {0 : 'n', 2:'h4',3:'w4'}, 'h11' : {0 : 'n', 2:'h4',3:'w4'}},
verbose=False, opset_version=12
)
print("<------")
# UNet export, stage 2: export forward2 (output_blocks[:6]) to
# diffusion_mid.onnx.
class DiffusionWrapper(pl.LightningModule):
...
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
...
elif self.conditioning_key == 'crossattn':
cc = torch.cat(c_crossattn, 1)
# Assumes UNetModel.forward returns (h, emb, h0, ..., h11), so hs[k]
# corresponds to output "h{k}" of diffusion_emb.onnx.
out = self.diffusion_model(x, t, context=cc)
h = out[0]
emb = out[1]
hs = out[2:]
print("------>")
from torch.autograd import Variable
# Point the instance's forward at forward2 so only that half is traced.
self.diffusion_model.forward = self.diffusion_model.forward2
# Sample inputs in forward2's argument order: middle-block output,
# timestep embedding, context, then the six deepest skips hs[6]..hs[11].
xx = (
Variable(h), Variable(emb), Variable(cc),
Variable(hs[6]), Variable(hs[7]), Variable(hs[8]), Variable(hs[9]),
Variable(hs[10]), Variable(hs[11]))
torch.onnx.export(
self.diffusion_model, xx, 'diffusion_mid.onnx',
input_names=[
"h", "emb", "context", "h6", "h7", "h8", "h9", "h10", "h11"],
output_names=["out"],
dynamic_axes={'h' : {0 : 'n', 2:'h4',3:'w4'}, 'emb' : {0 : 'n'}, 'context' : {0 : 'n'}, 'h6' : {0 : 'n', 2:'h3',3:'w3'}, 'h7' : {0 : 'n', 2:'h3',3:'w3'}, 'h8' : {0 : 'n', 2:'h3',3:'w3'}, 'h9' : {0 : 'n', 2:'h4',3:'w4'}, 'h10' : {0 : 'n', 2:'h4',3:'w4'}, 'h11' : {0 : 'n', 2:'h4',3:'w4'}, 'out' : {0 : 'n', 2:'h2',3:'w2'}},
verbose=False, opset_version=12
)
print("<------")
# UNet export, stage 3: run stages 1 and 2 eagerly to obtain the input of
# the last six output blocks, then export forward3 to diffusion_out.onnx.
class DiffusionWrapper(pl.LightningModule):
...
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
...
elif self.conditioning_key == 'crossattn':
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(x, t, context=cc)
h = out[0]
emb = out[1]
hs = out[2:]
# Eager call of forward2 (output_blocks[:6]); its result is the input
# `h` of forward3.
h = self.diffusion_model.forward2(
h, emb, cc,
hs[6], hs[7], hs[8], hs[9], hs[10], hs[11])
print("------>")
from torch.autograd import Variable
# Point the instance's forward at forward3 so only the last half is traced.
self.diffusion_model.forward = self.diffusion_model.forward3
# Sample inputs in forward3's argument order; hs[0]..hs[5] are the
# shallowest skips consumed by output_blocks[6:].
xx = (
Variable(h), Variable(emb), Variable(cc),
Variable(hs[0]), Variable(hs[1]), Variable(hs[2]), Variable(hs[3]),
Variable(hs[4]), Variable(hs[5]))
torch.onnx.export(
self.diffusion_model, xx, 'diffusion_out.onnx',
input_names=[
"h", "emb", "context", "h0", "h1", "h2", "h3", "h4", "h5"],
output_names=["out"],
dynamic_axes={'h' : {0 : 'n', 2:'h2',3:'w2'}, 'emb' : {0 : 'n'}, 'context' : {0 : 'n'}, 'h0' : {0 : 'n', 2:'h1',3:'w1'}, 'h1' : {0 : 'n', 2:'h1',3:'w1'}, 'h2' : {0 : 'n', 2:'h1',3:'w1'}, 'h3' : {0 : 'n', 2:'h2',3:'w2'}, 'h4' : {0 : 'n', 2:'h2',3:'w2'}, 'h5' : {0 : 'n', 2:'h2',3:'w2'}, 'out' : {0 : 'n', 2:'h',3:'w'}},
verbose=False, opset_version=12
)
print("<------")
○ ldm/modules/diffusionmodules/openaimodel.py
# Upstream (before): forward runs the middle block and then continues
# (remaining code elided).
class UNetModel(nn.Module):
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
...
h = self.middle_block(h, emb, context)
↓
# The UNet is split into three traceable pieces: forward (encoder and middle
# block), forward2 (output_blocks[:6]) and forward3 (output_blocks[6:] plus
# the output head).
class UNetModel(nn.Module):
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
...
h = self.middle_block(h, emb, context)
# Stop after the middle block and return its output, the embedding `emb`
# and all twelve skip tensors hs[0]..hs[11].
return h, emb, hs[0], hs[1], hs[2], hs[3], hs[4], hs[5], hs[6], hs[7], hs[8], hs[9], hs[10], hs[11]
# Run output_blocks[:6] using the six deepest skips. hs.pop() takes from
# the end of the list, so h11 is consumed first and h6 last.
def forward2(self, h, emb, context, h6, h7, h8, h9, h10, h11):
...
hs = [h6, h7, h8, h9, h10, h11]
# The loop index `i` is unused.
for i, module in enumerate(self.output_blocks[:6]):
# Concatenate the next skip tensor onto h along dim 1 before each block.
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
return h
# Run output_blocks[6:] using the six shallowest skips (h5 consumed first,
# h0 last), then apply the output head.
def forward3(self, h, emb, context, h0, h1, h2, h3, h4, h5):
hs = [h0, h1, h2, h3, h4, h5]
for i, module in enumerate(self.output_blocks[6:]):
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
# predict_codebook_ids selects id_predictor instead of the `out` head.
if self.predict_codebook_ids:
return self.id_predictor(h)
else:
return self.out(h)
○ ldm/models/diffusion/ddpm.py
# Upstream (before): when the first-stage model is not a VQModelInterface,
# decode_first_stage returns first_stage_model.decode(z).
class LatentDiffusion(DDPM):
...
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
if hasattr(self, "split_input_params"):
...
else:
if isinstance(self.first_stage_model, VQModelInterface):
...
else:
return self.first_stage_model.decode(z)
↓
class LatentDiffusion(DDPM):
...
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
if hasattr(self, "split_input_params"):
...
else:
if isinstance(self.first_stage_model, VQModelInterface):
...
else:
print("------>")
self.first_stage_model.forward = self.first_stage_model.decode
from torch.autograd import Variable
x = Variable(z)
torch.onnx.export(
self.first_stage_model, x, 'autoencoder.onnx',
input_names=["input"],
output_names=["output"],
dynamic_axes={'input' : {0 : 'n', 2:'h',3:'w'}, 'output' : {0 : 'n', 2:'ho',3:'wo'}},
verbose=False, opset_version=11
)
print("<------")
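○ scripts/inpaint.py (the sampling loop below uses `model.` and `batch[...]`, so it belongs to the inpainting script, not to a method in ldm/models/diffusion/ddpm.py)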
# Upstream (before): inside the per-image loop, the masked image is encoded
# with cond_stage_model.encode and stored in `c`.
with torch.no_grad():
with model.ema_scope():
for image, mask in tqdm(zip(images, masks)):
...
c = model.cond_stage_model.encode(batch["masked_image"])
↓
# Export cond_stage_model.encode to cond_stage_model.onnx. The export sits
# inside the per-image loop, so it runs (and overwrites the file) once per
# image.
with torch.no_grad():
with model.ema_scope():
for image, mask in tqdm(zip(images, masks)):
...
print("------>")
# Point the instance's forward at encode() so export traces the encoder.
model.cond_stage_model.forward = model.cond_stage_model.encode
from torch.autograd import Variable
x = Variable(batch["masked_image"])
# Only axes 2 and 3 are dynamic; the batch size is fixed by the sample.
torch.onnx.export(
model.cond_stage_model, x, 'cond_stage_model.onnx',
input_names=["masked_image"],
output_names=["out"],
dynamic_axes={'masked_image' : {2 : 'h', 3 : 'w'}, 'out' : {2 : 'oh', 3 : 'ow'}},
verbose=False, opset_version=11
)
print("<------")
# NOTE(review): this replaces `c = model.cond_stage_model.encode(...)`. If
# later code in the loop (elided) reads `c`, recompute it after the export.
○ ldm/models/diffusion/ddpm.py
class LatentDiffusion(DDPM):
...
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
...
if hasattr(self, "split_input_params"):
..
else:
if isinstance(self.first_stage_model, VQModelInterface):
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
↓
class LatentDiffusion(DDPM):
...
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
...
if hasattr(self, "split_input_params"):
...
else:
if isinstance(self.first_stage_model, VQModelInterface):
print("------>")
from torch.autograd import Variable
x = Variable(z)
self.first_stage_model.forward = self.first_stage_model.decode
torch.onnx.export(
self.first_stage_model, x, 'autoencoder.onnx',
input_names=["z"],
output_names=["dec"],
dynamic_axes={'z' : {2 : 'h', 3 : 'w'}, 'dec' : {2 : 'oh', 3 : 'ow'}},
verbose=False, opset_version=12
)
print("<------")
○ ldm/models/diffusion/ddpm.py
# Upstream (before): the 'concat' branch concatenates x with c_concat along
# dim 1 and calls the UNet without a context argument.
class DiffusionWrapper(pl.LightningModule):
...
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
...
elif self.conditioning_key == 'concat':
xc = torch.cat([x] + c_concat, dim=1)
out = self.diffusion_model(xc, t)
↓
class DiffusionWrapper(pl.LightningModule):
...
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
...
elif self.conditioning_key == 'concat':
print("------>")
from torch.autograd import Variable
xx = (Variable(xc), Variable(t))
torch.onnx.export(
self.diffusion_model, xx, 'diffusion_model.onnx',
input_names=["xc", "t"],
output_names=["out"],
dynamic_axes={'xc' : {2 : 'h', 3 : 'w'}, 'out' : {2 : 'h', 3 : 'w'}},
verbose=False, opset_version=12
)
print("<------")
○ ldm/models/diffusion/ddpm.py
# Upstream (before): with patch_distributed_vq, each patch along the last
# axis of z is decoded separately, and the results are collected into
# output_list.
class LatentDiffusion(DDPM):
...
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
...
if hasattr(self, "split_input_params"):
if self.split_input_params["patch_distributed_vq"]:
...
if isinstance(self.first_stage_model, VQModelInterface):
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
force_not_quantize=predict_cids or force_not_quantize)
for i in range(z.shape[-1])]
↓
# Export the VQ decoder used by the patch-wise decoding path to
# first_stage_decode.onnx.
class LatentDiffusion(DDPM):
...
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
...
if hasattr(self, "split_input_params"):
if self.split_input_params["patch_distributed_vq"]:
...
if isinstance(self.first_stage_model, VQModelInterface):
print("------>")
from torch.autograd import Variable
# Use the first patch (index 0 of the last axis) as the sample input.
x = Variable(z[:, :, :, :, 0])
# Point forward at decode(). Export passes only `x`, so
# force_not_quantize is not part of the traced call.
self.first_stage_model.forward = self.first_stage_model.decode
torch.onnx.export(
self.first_stage_model, x, 'first_stage_decode.onnx',
input_names=["x"],
output_names=["out"],
dynamic_axes={'x' : {2 : 'h', 3 : 'w'}, 'out' : {2 : 'oh', 3 : 'ow'}},
verbose=False, opset_version=12
)
print("<------")
# NOTE(review): this replaces the `output_list = [...]` comprehension that
# decoded every patch. If later code (elided) reads `output_list`, rebuild
# it after the export.
○ ldm/modules/diffusionmodules/openaimodel.py
# Upstream (before): scale is 1 / ch ** 0.25, computed with two nested
# math.sqrt calls.
class QKVAttentionLegacy(nn.Module):
...
def forward(self, qkv):
...
scale = 1 / math.sqrt(math.sqrt(ch))
↓
class QKVAttentionLegacy(nn.Module):
...
def forward(self, qkv):
...
# Mathematically the same as 1 / math.sqrt(math.sqrt(ch)), written with
# ** instead of math.sqrt.
# NOTE(review): presumably this keeps `ch` (derived from the input shape)
# traceable instead of converting it to a Python float. Confirm.
scale = 1 / ((ch**0.5)**0.5)
○ ldm/models/diffusion/ddpm.py (repeats the 'concat' DiffusionWrapper change listed earlier)
# Upstream (before): the 'concat' branch concatenates x with c_concat along
# dim 1 and calls the UNet without a context argument.
class DiffusionWrapper(pl.LightningModule):
...
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
...
elif self.conditioning_key == 'concat':
xc = torch.cat([x] + c_concat, dim=1)
out = self.diffusion_model(xc, t)
↓
class DiffusionWrapper(pl.LightningModule):
...
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
...
elif self.conditioning_key == 'concat':
print("------>")
from torch.autograd import Variable
xx = (Variable(xc), Variable(t))
torch.onnx.export(
self.diffusion_model, xx, 'diffusion_model.onnx',
input_names=["xc", "t"],
output_names=["out"],
dynamic_axes={'xc' : {2 : 'h', 3 : 'w'}, 'out' : {2 : 'h', 3 : 'w'}},
verbose=False, opset_version=12
)
print("<------")
Upstream repository: https://github.com/CompVis/latent-diffusion