phamquiluan / ResidualMaskingNetwork

ICPR 2020: Facial Expression Recognition using Residual Masking Network
https://ieeexplore.ieee.org/document/9411919
MIT License
456 stars 93 forks source link

Resmasking Forward Function TypeError #53

Open tanzim10 opened 6 months ago

tanzim10 commented 6 months ago

import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    # torchvision.models.utils is absent from newer torchvision releases;
    # torch.hub exposes the same helper.
    from torch.hub import load_state_dict_from_url

# Download locations of the pretrained backbone checkpoints, keyed by
# architecture name.
model_urls = dict(
    resnet18="https://download.pytorch.org/models/resnet18-5c106cde.pth",
    resnet34="https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    resnet50="https://download.pytorch.org/models/resnet50-19c8e357.pth",
)

class ResMasking(ResNet):
    """ResNet (BasicBlock, layers [2, 2, 2, 2]) with a masking branch per stage.

    After each of ``layer1`` .. ``layer4`` a small convolutional branch
    computes a mask ``m`` from the stage output ``x`` and the output is
    rescaled as ``x * (1 + m)``. The classifier ``fc`` is replaced by a
    7-way linear layer.

    Parameters
    ----------
    weight_path : str
        Path of a checkpoint to initialise from. When empty, the
        ``resnet18`` entry of ``model_urls`` is downloaded instead.
    """

    def __init__(self, weight_path=""):
        super().__init__(block=BasicBlock, layers=[2, 2, 2, 2])

        if weight_path:
            # map_location keeps checkpoints that were saved on a GPU
            # loadable on CPU-only machines.
            state_dict = torch.load(weight_path, map_location="cpu")
        else:
            state_dict = load_state_dict_from_url(
                model_urls["resnet18"], progress=True
            )
        # The weights are loaded before the 7-way head and the mask branches
        # exist, so only tensors of the parent ResNet can be restored here;
        # strict=False tolerates missing and unexpected keys.
        # NOTE(review): a checkpoint saved from a full ResMasking model would
        # carry a 7-way fc that does not fit the parent head at this point;
        # confirm weight_path is only meant for backbone weights.
        self.load_state_dict(state_dict, strict=False)

        self.fc = nn.Linear(512, 7)

        self.mask1 = self._masking(64, 64, depth=4)
        self.mask2 = self._masking(128, 128, depth=3)
        self.mask3 = self._masking(256, 256, depth=2)
        self.mask4 = self._masking(512, 512, depth=1)

    def _masking(self, in_channels, out_channels, depth):
        """Build a mask branch made of ``depth`` Conv3x3-BatchNorm-ReLU units.

        The first unit maps ``in_channels`` to ``out_channels`` and sits
        directly in the returned ``nn.Sequential``; the remaining
        ``depth - 1`` units are nested ``nn.Sequential`` modules. The nesting
        is kept as is because it determines the parameter names in the
        state dict.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            *(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                )
                for _ in range(depth - 1)
            ),
        )

    def forward(self, x):
        """Run the masked backbone and return the output of ``self.fc``.

        Takes the input tensor only; ``in_channels`` / ``num_classes`` are
        arguments of the module-level factory functions, not of this call.
        Assumes ``x`` is an image batch accepted by the parent ResNet stem
        -- TODO confirm expected shape with the data loader.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        stages = (
            (self.layer1, self.mask1),
            (self.layer2, self.mask2),
            (self.layer3, self.mask3),
            (self.layer4, self.mask4),
        )
        for layer, mask in stages:
            x = layer(x)
            # Each mask branch ends in ReLU, so the factor (1 + m) is >= 1.
            x = x * (1 + mask(x))

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        return self.fc(x)

class ResMasking50(ResNet):
    """ResNet (Bottleneck, layers [3, 4, 6, 3]) with a masking branch per stage.

    After each of ``layer1`` .. ``layer4`` a small convolutional branch
    computes a mask ``m`` from the stage output ``x`` and the output is
    rescaled as ``x * (1 + m)``. The classifier ``fc`` is replaced by a
    7-way linear layer on 2048 features.

    Parameters
    ----------
    weight_path : str
        Path of a checkpoint to initialise from. When empty, the
        ``resnet50`` entry of ``model_urls`` is downloaded instead.
    """

    def __init__(self, weight_path=""):
        super().__init__(block=Bottleneck, layers=[3, 4, 6, 3])

        if weight_path:
            # map_location keeps checkpoints that were saved on a GPU
            # loadable on CPU-only machines.
            state_dict = torch.load(weight_path, map_location="cpu")
        else:
            state_dict = load_state_dict_from_url(
                model_urls["resnet50"], progress=True
            )
        # The weights are loaded before the 7-way head and the mask branches
        # exist, so only tensors of the parent ResNet can be restored here;
        # strict=False tolerates missing and unexpected keys.
        # NOTE(review): a checkpoint saved from a full ResMasking50 model
        # would carry a 7-way fc that does not fit the parent head at this
        # point; confirm weight_path is only meant for backbone weights.
        self.load_state_dict(state_dict, strict=False)

        self.fc = nn.Linear(2048, 7)

        self.mask1 = self._masking(256, 256, depth=4)
        self.mask2 = self._masking(512, 512, depth=3)
        self.mask3 = self._masking(1024, 1024, depth=2)
        self.mask4 = self._masking(2048, 2048, depth=1)

    def _masking(self, in_channels, out_channels, depth):
        """Build a mask branch made of ``depth`` Conv3x3-BatchNorm-ReLU units.

        The first unit maps ``in_channels`` to ``out_channels`` and sits
        directly in the returned ``nn.Sequential``; the remaining
        ``depth - 1`` units are nested ``nn.Sequential`` modules. The nesting
        is kept as is because it determines the parameter names in the
        state dict.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            *(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                )
                for _ in range(depth - 1)
            ),
        )

    def forward(self, x):
        """Run the masked backbone and return the output of ``self.fc``.

        Takes the input tensor only; ``in_channels`` / ``num_classes`` are
        arguments of the module-level factory functions, not of this call.
        Assumes ``x`` is an image batch accepted by the parent ResNet stem
        -- TODO confirm expected shape with the data loader.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        stages = (
            (self.layer1, self.mask1),
            (self.layer2, self.mask2),
            (self.layer3, self.mask3),
            (self.layer4, self.mask4),
        )
        for layer, mask in stages:
            x = layer(x)
            # Each mask branch ends in ReLU, so the factor (1 + m) is >= 1.
            x = x * (1 + mask(x))

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        return self.fc(x)

def resmasking(in_channels=3, num_classes=7, weight_path=""):
    """Build a ``ResMasking`` model.

    Parameters
    ----------
    in_channels : int
        Accepted for signature parity with the other factories; not used.
    num_classes : int
        Size of the classifier output. ``ResMasking`` itself creates a 7-way
        head, which is replaced only when another value is requested.
    weight_path : str
        Forwarded to ``ResMasking``.
    """
    model = ResMasking(weight_path)
    if num_classes != 7:
        # Previously num_classes was ignored and the head was always 7-way.
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

def resmasking50_dropout1(in_channels=3, num_classes=7, weight_path=""):
    """Build ``ResMasking50`` with a Dropout(0.4) -> Linear classifier.

    ``in_channels`` is accepted but not used; ``weight_path`` is forwarded to
    ``ResMasking50`` and ``num_classes`` sets the output width of the head.
    """
    net = ResMasking50(weight_path)
    head = nn.Sequential(
        nn.Dropout(0.4),
        nn.Linear(2048, num_classes),
    )
    net.fc = head
    return net

def resmasking_dropout1(in_channels=3, num_classes=7, weight_path=""):
    """Build ``ResMasking`` with a Dropout(0.4) -> Linear classifier.

    ``in_channels`` is accepted but not used; ``weight_path`` is forwarded to
    ``ResMasking`` and ``num_classes`` sets the output width of the head.
    """
    net = ResMasking(weight_path)
    head = nn.Sequential(
        nn.Dropout(0.4),
        nn.Linear(512, num_classes),
    )
    net.fc = head
    return net

def resmasking_dropout2(in_channels=3, num_classes=7, weight_path=""):
    """Build ``ResMasking`` with a two-layer classifier (512 -> 128 -> classes).

    The head is Linear, ReLU, Dropout(p=0.5), Linear. ``in_channels`` is
    accepted but not used; ``weight_path`` is forwarded to ``ResMasking``.
    """
    net = ResMasking(weight_path)
    head_layers = [
        nn.Linear(512, 128),
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(128, num_classes),
    ]
    net.fc = nn.Sequential(*head_layers)
    return net

def resmasking_dropout3(in_channels=3, num_classes=7, weight_path=""):
    """Build ``ResMasking`` with a three-layer classifier (512 -> 512 -> 128 -> classes).

    Each hidden Linear is followed by an in-place ReLU and a default
    ``nn.Dropout``. ``in_channels`` is accepted but not used; ``weight_path``
    is forwarded to ``ResMasking``.
    """
    net = ResMasking(weight_path)
    head_layers = [
        nn.Linear(512, 512),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(512, 128),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(128, num_classes),
    ]
    net.fc = nn.Sequential(*head_layers)
    return net

TypeError: ResMasking.forward() got an unexpected keyword argument 'in_channels'

# main.py
def main(config_path):
    """Load the config, build the model and datasets, and run training.

    Parameters:
    -----------
    config_path : str
        path to config file
    """
    # load configs; the context manager closes the file handle, which the
    # previous json.load(open(...)) left open
    with open(config_path, encoding="utf-8") as config_file:
        configs = json.load(config_file)
    configs["cwd"] = os.getcwd()

    # load model and data_loader
    model = get_model(configs)

    train_set, val_set, test_set = get_dataset(configs)

    # init trainer and make a training
    # from trainers.fer2013_trainer import FER2013Trainer

    # from trainers.centerloss_trainer import FER2013Trainer
    trainer = FER2013Trainer(model, train_set, val_set, test_set, configs)

    if configs["distributed"] == 1:
        # one worker process per visible CUDA device
        # NOTE(review): assuming mp is torch.multiprocessing, spawn passes the
        # process index as first argument to trainer.train -- confirm the
        # trainer accepts it.
        ngpus = torch.cuda.device_count()
        mp.spawn(trainer.train, nprocs=ngpus, args=())
    else:
        trainer.train()

def get_model(configs):
    """Return the model factory selected by ``configs["arch"]``.

    The factory itself is returned, not an instantiated network. The reported
    ``TypeError: ResMasking.forward() got an unexpected keyword argument
    'in_channels'`` arises when an already-built model is called with
    constructor keywords: calling an ``nn.Module`` instance dispatches to
    ``forward``.
    NOTE(review): this presumes FER2013Trainer builds the network itself via
    ``model(in_channels=..., num_classes=...)`` -- confirm against the trainer.

    Parameters
    ----------
    configs : dict
        Must contain the key ``"arch"``.

    Raises
    ------
    ValueError
        If ``configs["arch"]`` is not a supported architecture name.
    """
    # Only 'resmasking_dropout3' is wired up here.
    if configs["arch"] == "resmasking_dropout3":
        return resmasking_dropout3

    # Handle case where 'arch' does not match
    raise ValueError(f"Model architecture {configs['arch']} is not supported.")

def get_dataset(configs):
    """Create the train, validation and test datasets.

    Calls ``fer2013`` once per split with ``configs``; the test split is
    requested with ``tta=True`` and ``tta_size=10``.

    Returns
    -------
    tuple
        ``(train_set, val_set, test_set)``
    """
    # todo: add transform
    splits = (
        fer2013("train", configs),
        fer2013("val", configs),
        fer2013("test", configs, tta=True, tta_size=10),
    )
    return splits

# Script entry point: the dunder names are required, a bare `name` is an
# undefined variable and raises NameError.
if __name__ == "__main__":
    main("/content/drive/MyDrive/Resnet/fer2013_config.json")