shenghanlin / SeismicFoundationModel

Seismic Foundation Model
67 stars 16 forks source link

output size equal to input size #9

Closed ramdhan1989 closed 1 month ago

ramdhan1989 commented 1 month ago

how to get output size same as input size. my input 224 by 224 but the result gives me 196 by 768 `from functools import partial

import torch import torch.nn as nn

from timm.models.vision_transformer import PatchEmbed, Block

from util.pos_embed import get_2d_sincos_pos_embed

class MaskedAutoencoderViT(nn.Module): """ Masked Autoencoder with VisionTransformer backbone """ def init(self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=24, num_heads=16, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False): super().init()

    # --------------------------------------------------------------------------
    # MAE encoder specifics
    self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
    num_patches = self.patch_embed.num_patches

    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding
    self.in_chans = in_chans
    self.blocks = nn.ModuleList([
        Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
        for i in range(depth)])
    self.norm = norm_layer(embed_dim)
    # --------------------------------------------------------------------------

    # --------------------------------------------------------------------------
    # MAE decoder specifics
    self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

    self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

    self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding

    self.decoder_blocks = nn.ModuleList([
        Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
        for i in range(decoder_depth)])

    self.decoder_norm = norm_layer(decoder_embed_dim)
    self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
    # --------------------------------------------------------------------------

    self.norm_pix_loss = norm_pix_loss

    self.initialize_weights()

def initialize_weights(self):
    # initialization
    # initialize (and freeze) pos_embed by sin-cos embedding
    pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
    self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

    decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
    self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

    # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
    w = self.patch_embed.proj.weight.data
    torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
    torch.nn.init.normal_(self.cls_token, std=.02)
    torch.nn.init.normal_(self.mask_token, std=.02)

    # initialize nn.Linear and nn.LayerNorm
    self.apply(self._init_weights)

def _init_weights(self, m):
    if isinstance(m, nn.Linear):
        # we use xavier_uniform following official JAX ViT:
        torch.nn.init.xavier_uniform_(m.weight)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)

def patchify(self, imgs):
    """
    imgs: (N, 3, H, W)
    x: (N, L, patch_size**2 *3)
    """
    p = self.patch_embed.patch_size[0]
    assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

    h = w = imgs.shape[2] // p
    x = imgs.reshape(shape=(imgs.shape[0], self.in_chans, h, p, w, p))
    x = torch.einsum('nchpwq->nhwpqc', x)
    x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * self.in_chans))
    return x

def unpatchify(self, x):
    """
    x: (N, L, patch_size**2 *3)
    imgs: (N, 3, H, W)
    """
    p = self.patch_embed.patch_size[0]
    h = w = int(x.shape[1]**.5)
    assert h * w == x.shape[1]

    x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
    x = torch.einsum('nhwpqc->nchpwq', x)
    imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
    return imgs

def random_masking(self, x, mask_ratio):
    """
    Perform per-sample random masking by per-sample shuffling.
    Per-sample shuffling is done by argsort random noise.
    x: [N, L, D], sequence
    """
    N, L, D = x.shape  # batch, length, dim
    len_keep = int(L * (1 - mask_ratio))

    noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

    # sort noise for each sample
    ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([N, L], device=x.device)
    mask[:, :len_keep] = 0
    # unshuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ids_restore)

    return x_masked, mask, ids_restore

def forward_encoder(self, x, mask_ratio):
    # embed patches
    x = self.patch_embed(x)

    # add pos embed w/o cls token
    x = x + self.pos_embed[:, 1:, :]

    # masking: length -> length * mask_ratio
    x, mask, ids_restore = self.random_masking(x, mask_ratio)

    # append cls token
    cls_token = self.cls_token + self.pos_embed[:, :1, :]
    cls_tokens = cls_token.expand(x.shape[0], -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)

    # apply Transformer blocks
    for blk in self.blocks:
        x = blk(x)
    x = self.norm(x)

    return x, mask, ids_restore

def forward_decoder(self, x, ids_restore):
    # embed tokens
    x = self.decoder_embed(x)

    # append mask tokens to sequence
    mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
    x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
    x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
    x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

    # add pos embed
    x = x + self.decoder_pos_embed

    # apply Transformer blocks
    for blk in self.decoder_blocks:
        x = blk(x)
    x = self.decoder_norm(x)

    # predictor projection
    x = self.decoder_pred(x)

    # remove cls token
    x = x[:, 1:, :]

    return x

def forward_loss(self, imgs, pred, mask):
    """
    imgs: [N, 3, H, W]
    pred: [N, L, p*p*3]
    mask: [N, L], 0 is keep, 1 is remove, 
    """
    target = self.patchify(imgs)
    if self.norm_pix_loss:
        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + 1.e-6)**.5

    loss = (pred - target) ** 2
    loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

    loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
    return loss

def forward(self, imgs, mask_ratio=0.75):
    latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
    pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
    loss = self.forward_loss(imgs, pred, mask)
    return loss, pred, mask

def mae_vit_small_patch16_dec512d8b(kwargs): model = MaskedAutoencoderViT( patch_size=16, embed_dim=768, depth=6, num_heads=12, decoder_embed_dim=512, decoder_depth=4, decoder_num_heads=16, mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), kwargs) return model

def mae_vit_base_patch16_dec512d8b(kwargs): model = MaskedAutoencoderViT( patch_size=16, embed_dim=768, depth=12, num_heads=12, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), kwargs) return model

def mae_vit_large_patch16_dec512d8b(kwargs): model = MaskedAutoencoderViT( patch_size=16, embed_dim=1024, depth=24, num_heads=16, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), kwargs) return model

def mae_vit_huge_patch14_dec512d8b(kwargs): model = MaskedAutoencoderViT( patch_size=14, embed_dim=1280, depth=32, num_heads=16, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), kwargs) return model

def mae_vit_large_patch16_dec256d4b(kwargs): model = MaskedAutoencoderViT( patch_size=16, embed_dim=1024, depth=24, num_heads=16, decoder_embed_dim=256, decoder_depth=4, decoder_num_heads=16, mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), kwargs) return model

set recommended archs

mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b # decoder: 512 dim, 8 blocks mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks mae_vit_large_patch16D4d256 = mae_vit_large_patch16_dec256d4b # decoder: 512 dim, 8 blocks if name == 'main': model = mae_vit_large_patch16() inputs = torch.randn(2, 3, 224,224) outputs = model(inputs) print(outputs[1].shape) print('Done')`

shenghanlin commented 1 month ago

You can try to read code in Finetune fold, after you extract features (196*768), you should reshape the feature and UpSampling them to get the same size.