axinc-ai / ailia-models

The collection of pre-trained, state-of-the-art AI models for ailia SDK
2.02k stars 320 forks source link

ADD LatentDiffusionModel #830

Closed kyakuno closed 2 years ago

kyakuno commented 2 years ago

https://github.com/CompVis/latent-diffusion

kyakuno commented 2 years ago

MIT

ooe1123 commented 2 years ago

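エクスポートのための修正 (Modifications for ONNX export; each pair of snippets below shows the original code first, followed by the modified code.) ○ ldm/modules/diffusionmodules/model.py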

# Original version, shown for reference; the export variant follows.
class AttnBlock(nn.Module):
    ...
    def forward(self, x):
        ...
        # Scale w_ by c**-0.5, with `c` first converted to a Python int.
        w_ = w_ * (int(c)**(-0.5))

# Export variant: the int() cast around `c` is dropped; the scale factor
# c**-0.5 is otherwise unchanged.
# NOTE(review): presumably the cast got in the way of ONNX export — confirm.
class AttnBlock(nn.Module):
    ...
    def forward(self, x):
        ...
        w_ = w_ * (c**(-0.5))

○ ldm/modules/diffusionmodules/util.py

# Original version, shown for reference: the caller-supplied `flag` decides
# whether the `if flag:` branch runs.
def checkpoint(func, inputs, params, flag):
    ...
    if flag:
        ...

# Export variant: `flag` is overwritten with False, so the `if flag:` branch
# is never taken, whatever the caller passed.
# NOTE(review): presumably that branch cannot be traced by the ONNX exporter — confirm.
def checkpoint(func, inputs, params, flag):
    ...
    flag = False  # discard the argument; always skip the branch below
    if flag:
        ...
ooe1123 commented 2 years ago

○ ldm/modules/encoders/modules.py

# Original version, shown for reference: the transformer is called directly
# with return_embeddings=True and the result is bound to `z`.
class BERTEmbedder(AbstractEncoder):
    ...
    def forward(self, text):
        if self.use_tknz_fn:
            ...
        else:
            ...
        z = self.transformer(tokens, return_embeddings=True)

# Export variant (step 1): instead of calling the transformer, forward()
# exports it to 'transformer_emb.onnx', using the current `tokens` as the
# example input. As shown, `z` is no longer computed in this variant.
class BERTEmbedder(AbstractEncoder):
    ...
    def forward(self, text):
        if self.use_tknz_fn:
            ...
        else:
            ...
        print("------>")
        import functools
        from torch.autograd import Variable
        # Bind return_embeddings=True into forward, because the exporter is
        # only given the positional input `x`. This overwrites the instance's
        # `forward` attribute and is not undone here.
        self.transformer.forward = functools.partial(self.transformer.forward, return_embeddings=True)
        # Both the module and the example input are moved to the CPU first.
        self.transformer.cpu()
        x = Variable(tokens.cpu())
        torch.onnx.export(
            self.transformer, x, 'transformer_emb.onnx',
            input_names=["x"],
            output_names=["out"],
            # Axis 0 of the input and of the output is dynamic, named 'n'.
            dynamic_axes={'x' : {0 : 'n'}, 'out' : {0 : 'n'}},
            verbose=False, opset_version=12
        )
        print("<------")
# Export variant (step 2): run the transformer once to obtain `x`, then
# switch its forward to `forward2` and export that to 'transformer_attn.onnx'
# with `x` as the example input.
# NOTE(review): this appears to rely on TransformerWrapper.forward having been
# patched to return right after project_emb — confirm.
class BERTEmbedder(AbstractEncoder):
    ...
    def forward(self, text):
        if self.use_tknz_fn:
            ...
        else:
            ...
        x = self.transformer(tokens, return_embeddings=True)

        print("------>")
        from torch.autograd import Variable
        # From here on the module's forward is forward2; this is not undone.
        self.transformer.forward = self.transformer.forward2
        # Both the module and the example input are moved to the CPU first.
        self.transformer.cpu()
        x = Variable(x.cpu())
        torch.onnx.export(
            self.transformer, x, 'transformer_attn.onnx',
            input_names=["x"],
            output_names=["out"],
            # Axis 0 of the input and of the output is dynamic, named 'n'.
            dynamic_axes={'x' : {0 : 'n'}, 'out' : {0 : 'n'}},
            verbose=False, opset_version=12
        )
        print("<------")

○ ldm/modules/x_transformer.py

# Original version, shown for reference: after project_emb, forward()
# continues with the remainder of the method (elided here).
class TransformerWrapper(nn.Module):
    def forward(
      ...
    ):
        ...
        x = self.project_emb(x)
        ...
        if num_mem > 0:
            ...

# Export variant: forward() is cut short so that it returns the result of
# project_emb; the rest of the computation lives in forward2.
class TransformerWrapper(nn.Module):
    def forward(
      ...
    ):
        ...
        x = self.project_emb(x)

        # Early return: everything below this line is unreachable here.
        return x

        if num_mem > 0:
            ...

    def forward2(self, x, mask=None, mems=None, **kwargs):
        """Apply the attention layers and the final norm to `x`.

        This is the second half of the split forward pass used for export.
        `x` is presumably the tensor returned by the truncated `forward`
        (the output of `project_emb`) — verify against the caller.

        Args:
            x: input passed to `self.attn_layers`.
            mask: forwarded to `self.attn_layers` as `mask`.
            mems: forwarded to `self.attn_layers` as `mems`.
            **kwargs: forwarded unchanged to `self.attn_layers`.

        Returns:
            The normalised output of the attention layers. The memory-token
            split and the `to_logits` projection are intentionally absent:
            with no memory tokens and embeddings always requested, both
            were no-ops.
        """
        # return_hiddens=True makes attn_layers return a pair; the second
        # element (the intermediates) is not needed here.
        x, _ = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
        return self.norm(x)
ooe1123 commented 2 years ago

○ ldm/models/diffusion/ddpm.py

# Original version, shown for reference: in the 'crossattn' branch the
# conditioning tensors are concatenated along dim 1 and passed to the
# diffusion model as `context`.
class DiffusionWrapper(pl.LightningModule):
    ...
    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        ...
        elif self.conditioning_key == 'crossattn':
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc)

# Export variant (step 1): instead of calling the diffusion model, export it
# to 'diffusion_emb.onnx' with (x, t, cc) as example inputs. As shown, `out`
# is no longer computed in this branch.
class DiffusionWrapper(pl.LightningModule):
    ...
    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        ...
        elif self.conditioning_key == 'crossattn':
            cc = torch.cat(c_crossattn, 1)
            print("------>")
            from torch.autograd import Variable
            xx = (Variable(x), Variable(t), Variable(cc))
            torch.onnx.export(
                self.diffusion_model, xx, 'diffusion_emb.onnx',
                input_names=["x", "timesteps", "context"],
                # Fourteen outputs: h, emb and twelve tensors h0..h11.
                output_names=["h", "emb", "h0", "h1", "h2", "h3", "h4", "h5", "h6", "h7", "h8", "h9", "h10", "h11"],
                # Axis 0 is dynamic ('n') everywhere; axes 2 and 3 are dynamic
                # and share names in groups of three: h0-h2, h3-h5, h6-h8, h9-h11.
                # NOTE(review): output 'h' is tagged 'h1'/'w1' here, but the
                # next export step tags its input 'h' as 'h4'/'w4' — confirm
                # which label is intended.
                dynamic_axes={'x' : {0 : 'n', 2:'h',3:'w'}, 'timesteps' : {0 : 'n'}, 'context' : {0 : 'n'}, 'h' : {0 : 'n', 2:'h1',3:'w1'}, 'emb' : {0 : 'n'}, 'h0' : {0 : 'n', 2:'h1',3:'w1'}, 'h1' : {0 : 'n', 2:'h1',3:'w1'}, 'h2' : {0 : 'n', 2:'h1',3:'w1'}, 'h3' : {0 : 'n', 2:'h2',3:'w2'}, 'h4' : {0 : 'n', 2:'h2',3:'w2'}, 'h5' : {0 : 'n', 2:'h2',3:'w2'}, 'h6' : {0 : 'n', 2:'h3',3:'w3'}, 'h7' : {0 : 'n', 2:'h3',3:'w3'}, 'h8' : {0 : 'n', 2:'h3',3:'w3'}, 'h9' : {0 : 'n', 2:'h4',3:'w4'}, 'h10' : {0 : 'n', 2:'h4',3:'w4'}, 'h11' : {0 : 'n', 2:'h4',3:'w4'}},
                verbose=False, opset_version=12
            )
            print("<------")
# Export variant (step 2): run the diffusion model once, split its result
# into h / emb / hs, then switch the model's forward to `forward2` and export
# it to 'diffusion_mid.onnx'.
class DiffusionWrapper(pl.LightningModule):
    ...
    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        ...
        elif self.conditioning_key == 'crossattn':
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc)
            # `out` is indexed as a sequence: first h, then emb, then the rest.
            h = out[0]
            emb = out[1]
            hs = out[2:]

            print("------>")
            from torch.autograd import Variable
            # From here on the model's forward is forward2; this is not undone.
            self.diffusion_model.forward = self.diffusion_model.forward2
            # Example inputs: h, emb, cc and entries 6..11 of hs.
            xx = (
                Variable(h), Variable(emb), Variable(cc), 
                Variable(hs[6]), Variable(hs[7]), Variable(hs[8]), Variable(hs[9]), 
                Variable(hs[10]), Variable(hs[11]))
            torch.onnx.export(
                self.diffusion_model, xx, 'diffusion_mid.onnx',
                input_names=[
                    "h", "emb", "context", "h6", "h7", "h8", "h9", "h10", "h11"],
                output_names=["out"],
                # Axis 0 is dynamic ('n') everywhere; axes 2 and 3 are dynamic
                # on every tensor except emb and context.
                dynamic_axes={'h' : {0 : 'n', 2:'h4',3:'w4'}, 'emb' : {0 : 'n'}, 'context' : {0 : 'n'}, 'h6' : {0 : 'n', 2:'h3',3:'w3'}, 'h7' : {0 : 'n', 2:'h3',3:'w3'}, 'h8' : {0 : 'n', 2:'h3',3:'w3'}, 'h9' : {0 : 'n', 2:'h4',3:'w4'}, 'h10' : {0 : 'n', 2:'h4',3:'w4'}, 'h11' : {0 : 'n', 2:'h4',3:'w4'}, 'out' : {0 : 'n', 2:'h2',3:'w2'}},
                verbose=False, opset_version=12
            )
            print("<------")
# Export variant (step 3): run the diffusion model and `forward2` eagerly to
# obtain `h`, then switch the model's forward to `forward3` and export it to
# 'diffusion_out.onnx'.
class DiffusionWrapper(pl.LightningModule):
    ...
    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        ...
        elif self.conditioning_key == 'crossattn':
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc)
            # `out` is indexed as a sequence: first h, then emb, then the rest.
            h = out[0]
            emb = out[1]
            hs = out[2:]
            # Middle stage, fed with entries 6..11 of hs.
            h = self.diffusion_model.forward2(
                h, emb, cc, 
                hs[6], hs[7], hs[8], hs[9], hs[10], hs[11])

            print("------>")
            from torch.autograd import Variable
            # From here on the model's forward is forward3; this is not undone.
            self.diffusion_model.forward = self.diffusion_model.forward3
            # Example inputs: h, emb, cc and entries 0..5 of hs.
            xx = (
                Variable(h), Variable(emb), Variable(cc), 
                Variable(hs[0]), Variable(hs[1]), Variable(hs[2]), Variable(hs[3]), 
                Variable(hs[4]), Variable(hs[5]))
            torch.onnx.export(
                self.diffusion_model, xx, 'diffusion_out.onnx',
                input_names=[
                    "h", "emb", "context", "h0", "h1", "h2", "h3", "h4", "h5"],
                output_names=["out"],
                # Axis 0 is dynamic ('n') everywhere; axes 2 and 3 are dynamic
                # on every tensor except emb and context.
                dynamic_axes={'h' : {0 : 'n', 2:'h2',3:'w2'}, 'emb' : {0 : 'n'}, 'context' : {0 : 'n'}, 'h0' : {0 : 'n', 2:'h1',3:'w1'}, 'h1' : {0 : 'n', 2:'h1',3:'w1'}, 'h2' : {0 : 'n', 2:'h1',3:'w1'}, 'h3' : {0 : 'n', 2:'h2',3:'w2'}, 'h4' : {0 : 'n', 2:'h2',3:'w2'}, 'h5' : {0 : 'n', 2:'h2',3:'w2'}, 'out' : {0 : 'n', 2:'h',3:'w'}},
                verbose=False, opset_version=12
            )
            print("<------")

○ ldm/modules/diffusionmodules/openaimodel.py

# Original version, shown for reference: forward() continues past the middle
# block (the remainder is not shown in this snippet).
class UNetModel(nn.Module):
    def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
        ...
        h = self.middle_block(h, emb, context)

class UNetModel(nn.Module):
    """Export variant of the U-Net, split into three separately callable stages.

    `forward` stops after the middle block, `forward2` runs the first six
    output blocks and `forward3` runs the remaining output blocks plus the
    final head.
    """

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """Run up to and including the middle block.

        Returns:
            A 14-tuple: the middle-block output, `emb`, and entries 0..11
            of `hs`.
        """
        ...
        h = self.middle_block(h, emb, context)
        # Index explicitly so that a short `hs` still raises IndexError.
        return (h, emb, *(hs[k] for k in range(12)))

    def forward2(self, h, emb, context, h6, h7, h8, h9, h10, h11):
        """Run the first six output blocks, consuming h11 first and h6 last."""
        ...
        skips_last_first = (h11, h10, h9, h8, h7, h6)
        for block, skip in zip(self.output_blocks[:6], skips_last_first):
            h = block(th.cat([h, skip], dim=1), emb, context)
        return h

    def forward3(self, h, emb, context, h0, h1, h2, h3, h4, h5):
        """Run the output blocks from index 6 onwards, then the final head.

        Returns:
            `self.id_predictor(h)` when `self.predict_codebook_ids` is truthy,
            otherwise `self.out(h)`.
        """
        pending = [h0, h1, h2, h3, h4, h5]
        for block in self.output_blocks[6:]:
            # pop() takes from the end of the list, so h5 is used first.
            h = block(th.cat([h, pending.pop()], dim=1), emb, context)

        head = self.id_predictor if self.predict_codebook_ids else self.out
        return head(h)
ooe1123 commented 2 years ago

○ ldm/models/diffusion/ddpm.py

# Original version, shown for reference: when there are no split_input_params
# and the first-stage model is not a VQModelInterface, `z` is decoded directly.
class LatentDiffusion(DDPM):
    ...
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        if hasattr(self, "split_input_params"):
            ...
        else:
            if isinstance(self.first_stage_model, VQModelInterface):
                ...
            else:
                return self.first_stage_model.decode(z)

# Export variant: in the non-VQ branch, the first-stage model's decode() is
# exported to 'autoencoder.onnx' with `z` as the example input. As shown,
# nothing is returned from this branch any more.
class LatentDiffusion(DDPM):
    ...
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        if hasattr(self, "split_input_params"):
            ...
        else:
            if isinstance(self.first_stage_model, VQModelInterface):
                ...
            else:
                print("------>")
                # Point forward at decode so that exporting the module traces
                # decode(); this overwrites `forward` and is not undone here.
                self.first_stage_model.forward = self.first_stage_model.decode
                from torch.autograd import Variable
                x = Variable(z)
                torch.onnx.export(
                    self.first_stage_model, x, 'autoencoder.onnx',
                    input_names=["input"],
                    output_names=["output"],
                    # Axes 0, 2 and 3 are dynamic on both input and output.
                    dynamic_axes={'input' : {0 : 'n', 2:'h',3:'w'}, 'output' : {0 : 'n', 2:'ho',3:'wo'}},
                    verbose=False, opset_version=11
                )
                print("<------")
ooe1123 commented 2 years ago

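○ scripts/inpaint.py (labelled ldm/models/diffusion/ddpm.py in the original comment, but the sampling loop below is from the inpainting script)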

    # Original version, shown for reference: the masked image of each batch is
    # encoded by the conditioning-stage model and the result is bound to `c`.
    with torch.no_grad():
        with model.ema_scope():
            for image, mask in tqdm(zip(images, masks)):
                ...
                c = model.cond_stage_model.encode(batch["masked_image"])

    # Export variant: instead of encoding, the conditioning-stage model's
    # encode() is exported to 'cond_stage_model.onnx' with the masked image as
    # the example input. As shown, `c` is no longer computed.
    with torch.no_grad():
        with model.ema_scope():
            for image, mask in tqdm(zip(images, masks)):
                ...
                print("------>")
                # Point forward at encode so that exporting the module traces
                # encode(); this overwrites `forward` and is not undone here.
                model.cond_stage_model.forward = model.cond_stage_model.encode
                from torch.autograd import Variable
                x = Variable(batch["masked_image"])
                torch.onnx.export(
                    model.cond_stage_model, x, 'cond_stage_model.onnx',
                    input_names=["masked_image"],
                    output_names=["out"],
                    # Only axes 2 and 3 are dynamic; axis 0 is left fixed.
                    dynamic_axes={'masked_image' : {2 : 'h', 3 : 'w'}, 'out' : {2 : 'oh', 3 : 'ow'}},
                    verbose=False, opset_version=11
                )
                print("<------")
ooe1123 commented 2 years ago

○ ldm/models/diffusion/ddpm.py

# Original version, shown for reference: when there are no split_input_params
# and the first-stage model is a VQModelInterface, `z` is decoded with
# force_not_quantize set if either predict_cids or force_not_quantize is truthy.
class LatentDiffusion(DDPM):
    ...
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        ...
        if hasattr(self, "split_input_params"):
            ...
        else:
            if isinstance(self.first_stage_model, VQModelInterface):
                return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)

# Export variant: in the VQModelInterface branch, the first-stage model's
# decode() is exported to 'autoencoder.onnx' with `z` as the example input.
# As shown, nothing is returned from this branch any more.
# NOTE(review): the export passes only `z`, so decode() runs with its default
# force_not_quantize rather than the value the original call supplied — confirm
# this is intended.
class LatentDiffusion(DDPM):
    ...
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        ...
        if hasattr(self, "split_input_params"):
            ...
        else:
            if isinstance(self.first_stage_model, VQModelInterface):
                print("------>")
                from torch.autograd import Variable
                x = Variable(z)
                # Point forward at decode so that exporting the module traces
                # decode(); this overwrites `forward` and is not undone here.
                self.first_stage_model.forward = self.first_stage_model.decode
                torch.onnx.export(
                    self.first_stage_model, x, 'autoencoder.onnx',
                    input_names=["z"],
                    output_names=["dec"],
                    # Only axes 2 and 3 are dynamic; axis 0 is left fixed.
                    dynamic_axes={'z' : {2 : 'h', 3 : 'w'}, 'dec' : {2 : 'oh', 3 : 'ow'}},
                    verbose=False, opset_version=12
                )
                print("<------")
ooe1123 commented 2 years ago

○ ldm/models/diffusion/ddpm.py

# Original version, shown for reference: in the 'concat' branch, x and the
# conditioning tensors are concatenated along dim 1 and the result is passed
# to the diffusion model together with t.
class DiffusionWrapper(pl.LightningModule):
    ...
    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        ...
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)

class DiffusionWrapper(pl.LightningModule):
    ...
    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        ...
        elif self.conditioning_key == 'concat':
            print("------>")
            from torch.autograd import Variable
            xx = (Variable(xc), Variable(t))
            torch.onnx.export(
                self.diffusion_model, xx, 'diffusion_model.onnx',
                input_names=["xc", "t"],
                output_names=["out"],
                dynamic_axes={'xc' : {2 : 'h', 3 : 'w'}, 'out' : {2 : 'h', 3 : 'w'}},
                verbose=False, opset_version=12
            )
            print("<------")
ooe1123 commented 2 years ago

○ ldm/models/diffusion/ddpm.py

# Original version, shown for reference: in the patch_distributed_vq branch
# with a VQModelInterface first stage, every slice along the last axis of `z`
# is decoded separately and the results are collected in a list.
class LatentDiffusion(DDPM):
    ...
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        ...
        if hasattr(self, "split_input_params"):
            if self.split_input_params["patch_distributed_vq"]:
                ...
                if isinstance(self.first_stage_model, VQModelInterface):
                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
                                                                 force_not_quantize=predict_cids or force_not_quantize)
                                   for i in range(z.shape[-1])]

# Export variant: in the patch_distributed_vq branch with a VQModelInterface
# first stage, decode() is exported to 'first_stage_decode.onnx'. As shown,
# `output_list` is no longer built.
class LatentDiffusion(DDPM):
    ...
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        ...
        if hasattr(self, "split_input_params"):
            if self.split_input_params["patch_distributed_vq"]:
                ...
                if isinstance(self.first_stage_model, VQModelInterface):
                    print("------>")
                    from torch.autograd import Variable
                    # Example input: the slice at index 0 of the last axis of z.
                    x = Variable(z[:, :, :, :, 0])
                    # Point forward at decode so that exporting the module
                    # traces decode(); this overwrites `forward` and is not
                    # undone here.
                    self.first_stage_model.forward = self.first_stage_model.decode
                    torch.onnx.export(
                        self.first_stage_model, x, 'first_stage_decode.onnx',
                        input_names=["x"],
                        output_names=["out"],
                        # Only axes 2 and 3 are dynamic; axis 0 is left fixed.
                        dynamic_axes={'x' : {2 : 'h', 3 : 'w'}, 'out' : {2 : 'oh', 3 : 'ow'}},
                        verbose=False, opset_version=12
                    )
                    print("<------")
ooe1123 commented 2 years ago

○ ldm/modules/diffusionmodules/openaimodel.py

# Original version, shown for reference: the scale is the reciprocal of the
# fourth root of `ch`, computed with math.sqrt.
class QKVAttentionLegacy(nn.Module):
    ...
    def forward(self, qkv):
        ...
        scale = 1 / math.sqrt(math.sqrt(ch))

# Export variant: the same reciprocal fourth root of `ch`, written with the
# ** operator instead of math.sqrt.
# NOTE(review): presumably math.sqrt could not be applied to `ch` during ONNX
# export — confirm.
class QKVAttentionLegacy(nn.Module):
    ...
    def forward(self, qkv):
        ...
        scale = 1 / ((ch**0.5)**0.5)

○ ldm/models/diffusion/ddpm.py

# Original version, shown for reference: in the 'concat' branch, x and the
# conditioning tensors are concatenated along dim 1 and the result is passed
# to the diffusion model together with t.
class DiffusionWrapper(pl.LightningModule):
    ...
    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        ...
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)

class DiffusionWrapper(pl.LightningModule):
    ...
    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        ...
        elif self.conditioning_key == 'concat':
            print("------>")
            from torch.autograd import Variable
            xx = (Variable(xc), Variable(t))
            torch.onnx.export(
                self.diffusion_model, xx, 'diffusion_model.onnx',
                input_names=["xc", "t"],
                output_names=["out"],
                dynamic_axes={'xc' : {2 : 'h', 3 : 'w'}, 'out' : {2 : 'h', 3 : 'w'}},
                verbose=False, opset_version=12
            )
            print("<------")