zhiyuanyou / UniAD

[NeurIPS 2022 Spotlight] A Unified Model for Multi-class Anomaly Detection
Apache License 2.0

Who can help me to export this model to onnx? #28

Closed Kaze816 closed 1 year ago

Kaze816 commented 1 year ago

When I try to use "torch.onnx.export(model, input, "test.onnx")", I get nothing but lots of errors. Can anyone show me the right code? Also, is it possible for the input to be just an N x 3 x 224 x 224 tensor, without a mask, when inferring?

zhiyuanyou commented 1 year ago

The neighbor-masked attention cannot be exported to ONNX.

You can remove this part and train the model again.
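Once that part is removed and the model retrained, one way to drive the export is through a small tensor-in / tensor-out wrapper, because UniAD.forward takes and returns dicts. A minimal sketch, assuming you export only the reconstruction module on pre-extracted features and that save_recon is disabled; the wrapper name, channel count, and feature size below are placeholders, not part of the repo:

import torch
import torch.nn as nn

class ExportWrapper(nn.Module):
    """Hypothetical adapter: plain feature tensor in, anomaly map out."""

    def __init__(self, uniad):
        super().__init__()
        self.uniad = uniad

    def forward(self, feature_align):
        # feature_align: B x C x H x W features, the same layout UniAD.forward expects
        out = self.uniad({"feature_align": feature_align})
        return out["pred"]  # B x 1 x H_img x W_img anomaly map

# `model` is the trained UniAD module; unwrap DistributedDataParallel first if needed.
model.eval()
wrapper = ExportWrapper(model).cuda()
dummy = torch.randn(1, 272, 14, 14).cuda()  # placeholder channel count / feature size; match your config
torch.onnx.export(
    wrapper,
    (dummy,),
    "uniad.onnx",
    input_names=["feature_align"],
    output_names=["pred"],
    opset_version=13,
)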

Kaze816 commented 10 months ago

I tried to do this as follows. It runs, but it still fails to export the model to ONNX.

######################### uniad.py, with removals marked by "ZFD --" #########################

import copy
import math
import os
import random
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from models.initializer import initialize_from_cfg
from torch import Tensor, nn

class UniAD(nn.Module):

def __init__(
    self,
    inplanes,
    instrides,
    feature_size,
    feature_jitter,
    neighbor_mask,
    hidden_dim,
    pos_embed_type,
    save_recon,
    initializer,
    **kwargs,
):
    super().__init__()
    assert isinstance(inplanes, list) and len(inplanes) == 1
    assert isinstance(instrides, list) and len(instrides) == 1
    self.feature_size = feature_size
    self.num_queries = feature_size[0] * feature_size[1]
    self.feature_jitter = feature_jitter
    self.pos_embed = build_position_embedding(
        pos_embed_type, feature_size, hidden_dim
    )
    self.save_recon = save_recon

    # self.transformer = Transformer(hidden_dim, feature_size, neighbor_mask, **kwargs)
    self.transformer = Transformer(hidden_dim, feature_size, **kwargs)

    self.input_proj = nn.Linear(inplanes[0], hidden_dim)
    self.output_proj = nn.Linear(hidden_dim, inplanes[0])

    self.upsample = nn.UpsamplingBilinear2d(scale_factor=instrides[0])

    initialize_from_cfg(self, initializer)

def add_jitter(self, feature_tokens, scale, prob):
    if random.uniform(0, 1) <= prob:
        num_tokens, batch_size, dim_channel = feature_tokens.shape
        feature_norms = (
            feature_tokens.norm(dim=2).unsqueeze(2) / dim_channel
        )  # (H x W) x B x 1
        jitter = torch.randn((num_tokens, batch_size, dim_channel)).cuda()
        jitter = jitter * feature_norms * scale
        feature_tokens = feature_tokens + jitter
    return feature_tokens

def forward(self, input):
    feature_align = input["feature_align"]  # B x C X H x W
    feature_tokens = rearrange(
        feature_align, "b c h w -> (h w) b c"
    )  # (H x W) x B x C
    if self.training and self.feature_jitter:
        feature_tokens = self.add_jitter(
            feature_tokens, self.feature_jitter.scale, self.feature_jitter.prob
        )
    feature_tokens = self.input_proj(feature_tokens)  # (H x W) x B x C
    pos_embed = self.pos_embed(feature_tokens)  # (H x W) x C
    output_decoder, _ = self.transformer(
        feature_tokens, pos_embed
    )  # (H x W) x B x C
    feature_rec_tokens = self.output_proj(output_decoder)  # (H x W) x B x C
    feature_rec = rearrange(
        feature_rec_tokens, "(h w) b c -> b c h w", h=self.feature_size[0]
    )  # B x C X H x W

    if not self.training and self.save_recon:
        clsnames = input["clsname"]
        filenames = input["filename"]
        for clsname, filename, feat_rec in zip(clsnames, filenames, feature_rec):
            filedir, filename = os.path.split(filename)
            _, defename = os.path.split(filedir)
            filename_, _ = os.path.splitext(filename)
            save_dir = os.path.join(self.save_recon.save_dir, clsname, defename)
            os.makedirs(save_dir, exist_ok=True)
            feature_rec_np = feat_rec.detach().cpu().numpy()
            np.save(os.path.join(save_dir, filename_ + ".npy"), feature_rec_np)

    pred = torch.sqrt(
        torch.sum((feature_rec - feature_align) ** 2, dim=1, keepdim=True)
    )  # B x 1 x H x W
    pred = self.upsample(pred)  # B x 1 x H x W
    return {
        "feature_rec": feature_rec,
        "feature_align": feature_align,
        "pred": pred,
    }

class Transformer(nn.Module):

def __init__(
    self,
    hidden_dim,
    feature_size,
    neighbor_mask,
    nhead,
    num_encoder_layers,
    num_decoder_layers,
    dim_feedforward,
    dropout=0.1,
    activation="relu",
    normalize_before=False,
    return_intermediate_dec=False,
):
    super().__init__()
    self.feature_size = feature_size
    # self.neighbor_mask = neighbor_mask

    encoder_layer = TransformerEncoderLayer(
        hidden_dim, nhead, dim_feedforward, dropout, activation, normalize_before
    )
    encoder_norm = nn.LayerNorm(hidden_dim) if normalize_before else None
    self.encoder = TransformerEncoder(
        encoder_layer, num_encoder_layers, encoder_norm
    )

    decoder_layer = TransformerDecoderLayer(
        hidden_dim,
        feature_size,
        nhead,
        dim_feedforward,
        dropout,
        activation,
        normalize_before,
    )
    decoder_norm = nn.LayerNorm(hidden_dim)
    self.decoder = TransformerDecoder(
        decoder_layer,
        num_decoder_layers,
        decoder_norm,
        return_intermediate=return_intermediate_dec,
    )

    self.hidden_dim = hidden_dim
    self.nhead = nhead

def generate_mask(self, feature_size, neighbor_size):
    """
    Generate a square mask for the sequence. The masked positions are filled with float('-inf').
    Unmasked positions are filled with float(0.0).
    """
    h, w = feature_size
    hm, wm = neighbor_size
    mask = torch.ones(h, w, h, w)
    for idx_h1 in range(h):
        for idx_w1 in range(w):
            idx_h2_start = max(idx_h1 - hm // 2, 0)
            idx_h2_end = min(idx_h1 + hm // 2 + 1, h)
            idx_w2_start = max(idx_w1 - wm // 2, 0)
            idx_w2_end = min(idx_w1 + wm // 2 + 1, w)
            mask[
                idx_h1, idx_w1, idx_h2_start:idx_h2_end, idx_w2_start:idx_w2_end
            ] = 0
    mask = mask.view(h * w, h * w)
    mask = (
        mask.float()
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, float(0.0))
        .cuda()
    )
    return mask
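
# Note: with the mask branch commented out in Transformer.forward below,
# generate_mask is no longer called anywhere; it is kept here only for reference.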

def forward(self, src, pos_embed):
    _, batch_size, _ = src.shape
    pos_embed = torch.cat(
        [pos_embed.unsqueeze(1)] * batch_size, dim=1
    )  # (H X W) x B x C

    # if self.neighbor_mask:
    #     mask = self.generate_mask(
    #         self.feature_size, self.neighbor_mask.neighbor_size
    #     )
    #     mask_enc = mask if self.neighbor_mask.mask[0] else None
    #     mask_dec1 = mask if self.neighbor_mask.mask[1] else None
    #     mask_dec2 = mask if self.neighbor_mask.mask[2] else None
    # else:
    #     mask_enc = mask_dec1 = mask_dec2 = None

    # output_encoder = self.encoder(src, mask=mask_enc, pos=pos_embed)  # (H X W) x B x C
    output_encoder = self.encoder(src, pos=pos_embed)  # (H X W) x B x C
    # output_decoder = self.decoder(
    #     output_encoder,
    #     tgt_mask=mask_dec1,
    #     memory_mask=mask_dec2,
    #     pos=pos_embed,
    # )  # (H X W) x B x C
    output_decoder = self.decoder(output_encoder, pos=pos_embed)  # (H X W) x B x C

    return output_decoder, output_encoder

class TransformerEncoder(nn.Module):

def __init__(self, encoder_layer, num_layers, norm=None):
    super().__init__()
    self.layers = _get_clones(encoder_layer, num_layers)
    self.num_layers = num_layers
    self.norm = norm

def forward(
    self,
    src,
    mask: Optional[Tensor] = None,
    src_key_padding_mask: Optional[Tensor] = None,
    pos: Optional[Tensor] = None,
):
    output = src

    for layer in self.layers:
        output = layer(
            output,
            src_mask=mask,
            src_key_padding_mask=src_key_padding_mask,
            pos=pos,
        )

    if self.norm is not None:
        output = self.norm(output)

    return output

class TransformerDecoder(nn.Module):

def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
    super().__init__()
    self.layers = _get_clones(decoder_layer, num_layers)
    self.num_layers = num_layers
    self.norm = norm
    self.return_intermediate = return_intermediate

def forward(
    self,
    memory,
    tgt_mask: Optional[Tensor] = None,
    memory_mask: Optional[Tensor] = None,
    tgt_key_padding_mask: Optional[Tensor] = None,
    memory_key_padding_mask: Optional[Tensor] = None,
    pos: Optional[Tensor] = None,
):
    output = memory

    intermediate = []

    for layer in self.layers:
        output = layer(
            output,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
            pos=pos,
        )
        if self.return_intermediate:
            intermediate.append(self.norm(output))

    if self.norm is not None:
        output = self.norm(output)
        if self.return_intermediate:
            intermediate.pop()
            intermediate.append(output)

    if self.return_intermediate:
        return torch.stack(intermediate)

    return output

class TransformerEncoderLayer(nn.Module):

def __init__(
    self,
    hidden_dim,
    nhead,
    dim_feedforward=2048,
    dropout=0.1,
    activation="relu",
    normalize_before=False,
):
    super().__init__()

    self.self_attn = nn.MultiheadAttention(hidden_dim, nhead, dropout=dropout)  # Is this the attention mechanism in question? ZFD

    # Implementation of Feedforward model
    self.linear1 = nn.Linear(hidden_dim, dim_feedforward)
    self.dropout = nn.Dropout(dropout)
    self.linear2 = nn.Linear(dim_feedforward, hidden_dim)

    self.norm1 = nn.LayerNorm(hidden_dim)
    self.norm2 = nn.LayerNorm(hidden_dim)
    self.dropout1 = nn.Dropout(dropout)
    self.dropout2 = nn.Dropout(dropout)

    self.activation = _get_activation_fn(activation)
    self.normalize_before = normalize_before

def with_pos_embed(self, tensor, pos: Optional[Tensor]):
    return tensor if pos is None else tensor + pos

def forward_post(
    self,
    src,
    src_mask: Optional[Tensor] = None,
    src_key_padding_mask: Optional[Tensor] = None,
    pos: Optional[Tensor] = None,
):
    # ZFD --
    # q = k = self.with_pos_embed(src, pos)
    # src2 = self.self_attn(
    #     q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
    # )[0]
    # src = src + self.dropout1(src2)
    # src = self.norm1(src)
    # src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
    # src = src + self.dropout2(src2)
    src = self.norm2(src)
    return src

def forward_pre(
    self,
    src,
    src_mask: Optional[Tensor] = None,
    src_key_padding_mask: Optional[Tensor] = None,
    pos: Optional[Tensor] = None,
):
    # ZFD --
    # src2 = self.norm1(src)
    # q = k = self.with_pos_embed(src2, pos)
    # src2 = self.self_attn(
    #     q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
    # )[0]
    # src = src + self.dropout1(src2)
    # src2 = self.norm2(src)
    # src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
    # src = src + self.dropout2(src2)
    return src

def forward(
    self,
    src,
    src_mask: Optional[Tensor] = None,
    src_key_padding_mask: Optional[Tensor] = None,
    pos: Optional[Tensor] = None,
):
    if self.normalize_before:
        return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
    return self.forward_post(src, src_mask, src_key_padding_mask, pos)

class TransformerDecoderLayer(nn.Module):

def __init__(
    self,
    hidden_dim,
    feature_size,
    nhead,
    dim_feedforward,
    dropout=0.1,
    activation="relu",
    normalize_before=False,
):
    super().__init__()
    num_queries = feature_size[0] * feature_size[1]
    self.learned_embed = nn.Embedding(num_queries, hidden_dim)  # (H x W) x C

    # ZFD --
    # self.self_attn = nn.MultiheadAttention(hidden_dim, nhead, dropout=dropout)
    # self.multihead_attn = nn.MultiheadAttention(hidden_dim, nhead, dropout=dropout)

    # Implementation of Feedforward model
    self.linear1 = nn.Linear(hidden_dim, dim_feedforward)
    self.dropout = nn.Dropout(dropout)
    self.linear2 = nn.Linear(dim_feedforward, hidden_dim)

    self.norm1 = nn.LayerNorm(hidden_dim)
    self.norm2 = nn.LayerNorm(hidden_dim)
    self.norm3 = nn.LayerNorm(hidden_dim)
    self.dropout1 = nn.Dropout(dropout)
    self.dropout2 = nn.Dropout(dropout)
    self.dropout3 = nn.Dropout(dropout)

    self.activation = _get_activation_fn(activation)
    self.normalize_before = normalize_before

def with_pos_embed(self, tensor, pos: Optional[Tensor]):
    return tensor if pos is None else tensor + pos

def forward_post(
    self,
    out,
    memory,
    tgt_mask: Optional[Tensor] = None,
    memory_mask: Optional[Tensor] = None,
    tgt_key_padding_mask: Optional[Tensor] = None,
    memory_key_padding_mask: Optional[Tensor] = None,
    pos: Optional[Tensor] = None,
):
    _, batch_size, _ = memory.shape
    tgt = self.learned_embed.weight
    tgt = torch.cat([tgt.unsqueeze(1)] * batch_size, dim=1)  # (H X W) x B x C

    # ZFD --
    # tgt2 = self.self_attn(
    #     query=self.with_pos_embed(tgt, pos),
    #     key=self.with_pos_embed(memory, pos),
    #     value=memory,
    #     attn_mask=tgt_mask,
    #     key_padding_mask=tgt_key_padding_mask,
    # )[0]
    # tgt = tgt + self.dropout1(tgt2)
    # tgt = self.norm1(tgt)

    # tgt2 = self.multihead_attn(
    #     query=self.with_pos_embed(tgt, pos),
    #     key=self.with_pos_embed(out, pos),
    #     value=out,
    #     attn_mask=memory_mask,
    #     key_padding_mask=memory_key_padding_mask,
    # )[0]
    # tgt = tgt + self.dropout2(tgt2)
    # tgt = self.norm2(tgt)

    # tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
    # tgt = tgt + self.dropout3(tgt2)
    tgt = self.norm3(tgt)
    return tgt

def forward_pre(
    self,
    out,
    memory,
    tgt_mask: Optional[Tensor] = None,
    memory_mask: Optional[Tensor] = None,
    tgt_key_padding_mask: Optional[Tensor] = None,
    memory_key_padding_mask: Optional[Tensor] = None,
    pos: Optional[Tensor] = None,
):
    _, batch_size, _ = memory.shape
    tgt = self.learned_embed.weight
    tgt = torch.cat([tgt.unsqueeze(1)] * batch_size, dim=1)  # (H X W) x B x C

    tgt = self.norm3(tgt)  # one line standing in for all the removed ones (ZFD)
    # ZFD --
    # tgt2 = self.norm1(tgt)
    # tgt2 = self.self_attn(
    #     query=self.with_pos_embed(tgt2, pos),
    #     key=self.with_pos_embed(memory, pos),
    #     value=memory,
    #     attn_mask=tgt_mask,
    #     key_padding_mask=tgt_key_padding_mask,
    # )[0]
    # tgt = tgt + self.dropout1(tgt2)

    # tgt2 = self.norm2(tgt)
    # tgt2 = self.multihead_attn(
    #     query=self.with_pos_embed(tgt2, pos),
    #     key=self.with_pos_embed(out, pos),
    #     value=out,
    #     attn_mask=memory_mask,
    #     key_padding_mask=memory_key_padding_mask,
    # )[0]
    # tgt = tgt + self.dropout2(tgt2)

    # tgt2 = self.norm3(tgt)
    # tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
    # tgt = tgt + self.dropout3(tgt2)
    return tgt

def forward(
    self,
    out,
    memory,
    tgt_mask: Optional[Tensor] = None,
    memory_mask: Optional[Tensor] = None,
    tgt_key_padding_mask: Optional[Tensor] = None,
    memory_key_padding_mask: Optional[Tensor] = None,
    pos: Optional[Tensor] = None,
):
    if self.normalize_before:
        return self.forward_pre(
            out,
            memory,
            tgt_mask,
            memory_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            pos,
        )
    return self.forward_post(
        out,
        memory,
        tgt_mask,
        memory_mask,
        tgt_key_padding_mask,
        memory_key_padding_mask,
        pos,
    )

def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")

class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """

def __init__(
    self,
    feature_size,
    num_pos_feats=128,
    temperature=10000,
    normalize=False,
    scale=None,
):
    super().__init__()
    self.feature_size = feature_size
    self.num_pos_feats = num_pos_feats
    self.temperature = temperature
    self.normalize = normalize
    if scale is not None and normalize is False:
        raise ValueError("normalize should be True if scale is passed")
    if scale is None:
        scale = 2 * math.pi
    self.scale = scale

def forward(self, tensor):
    not_mask = torch.ones((self.feature_size[0], self.feature_size[1]))  # H x W
    y_embed = not_mask.cumsum(0, dtype=torch.float32)
    x_embed = not_mask.cumsum(1, dtype=torch.float32)
    if self.normalize:
        eps = 1e-6
        y_embed = y_embed / (y_embed[-1:, :] + eps) * self.scale
        x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale

    dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32)
    dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack(
        (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
    ).flatten(2)
    pos_y = torch.stack(
        (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
    ).flatten(2)
    pos = torch.cat((pos_y, pos_x), dim=2).flatten(0, 1)  # (H X W) X C
    return pos.to(tensor.device)

class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """

def __init__(self, feature_size, num_pos_feats=128):
    super().__init__()
    self.feature_size = feature_size  # H, W
    self.row_embed = nn.Embedding(feature_size[0], num_pos_feats)
    self.col_embed = nn.Embedding(feature_size[1], num_pos_feats)
    self.reset_parameters()

def reset_parameters(self):
    nn.init.uniform_(self.row_embed.weight)
    nn.init.uniform_(self.col_embed.weight)

def forward(self, tensor):
    i = torch.arange(self.feature_size[1], device=tensor.device)  # W
    j = torch.arange(self.feature_size[0], device=tensor.device)  # H
    x_emb = self.col_embed(i)  # W x C // 2
    y_emb = self.row_embed(j)  # H x C // 2
    pos = torch.cat(
        [
            torch.cat(
                [x_emb.unsqueeze(0)] * self.feature_size[0], dim=0
            ),  # H x W x C // 2
            torch.cat(
                [y_emb.unsqueeze(1)] * self.feature_size[1], dim=1
            ),  # H x W x C // 2
        ],
        dim=-1,
    ).flatten(
        0, 1
    )  # (H X W) X C
    return pos

def build_position_embedding(pos_embed_type, feature_size, hidden_dim):
    if pos_embed_type in ("v2", "sine"):
        # TODO find a better way of exposing other arguments
        pos_embed = PositionEmbeddingSine(feature_size, hidden_dim // 2, normalize=True)
    elif pos_embed_type in ("v3", "learned"):
        pos_embed = PositionEmbeddingLearned(feature_size, hidden_dim // 2)
    else:
        raise ValueError(f"not supported {pos_embed_type}")
    return pos_embed

########################### the error:

Kaze816 commented 10 months ago

torch.onnx.export(model, input, "test.onnx")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "d:\pyenv\Torch116gpu\lib\site-packages\torch\onnx\utils.py", line 504, in export
    _export(
  File "d:\pyenv\Torch116gpu\lib\site-packages\torch\onnx\utils.py", line 1529, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "d:\pyenv\Torch116gpu\lib\site-packages\torch\onnx\utils.py", line 1111, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "d:\pyenv\Torch116gpu\lib\site-packages\torch\onnx\utils.py", line 987, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "d:\pyenv\Torch116gpu\lib\site-packages\torch\onnx\utils.py", line 891, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "d:\pyenv\Torch116gpu\lib\site-packages\torch\jit\_trace.py", line 1184, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "d:\pyenv\Torch116gpu\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "d:\pyenv\Torch116gpu\lib\site-packages\torch\jit\_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "d:\pyenv\Torch116gpu\lib\site-packages\torch\jit\_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "d:\pyenv\Torch116gpu\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "d:\pyenv\Torch116gpu\lib\site-packages\torch\nn\modules\module.py", line 1182, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "d:\pyenv\Torch116gpu\lib\site-packages\torch\nn\parallel\distributed.py", line 1040, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "d:\pyenv\Torch116gpu\lib\site-packages\torch\nn\parallel\distributed.py", line 1000, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
IndexError: tuple index out of range
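
The trace dies inside DistributedDataParallel._run_ddp_forward, before UniAD.forward is ever reached, so this particular IndexError looks like an artifact of exporting the DDP-wrapped model with an input it cannot scatter, rather than a problem in the attention itself. A sketch of an export call that may get past it (assumptions: `model` is the DDP wrapper from training, its forward takes the same input dict as training, and the "image" key and shape here are placeholders):

# Trace the underlying module, not the DistributedDataParallel wrapper.
core = model.module if hasattr(model, "module") else model
core.eval()

# torch.onnx.export treats a trailing dict in `args` as keyword arguments, so a
# dict-style positional input needs an empty dict appended after it.
dummy = {"image": torch.randn(1, 3, 224, 224).cuda()}  # placeholder key / shape
torch.onnx.export(core, (dummy, {}), "test.onnx", opset_version=13)

Even then, non-tensor entries such as clsname / filename (used when save_recon is on) will not trace, so a tensor-only wrapper like the sketch in the earlier comment is probably the cleaner route.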