import torch
import torch.nn as nn
from torchvision.models.utils import load_state_dict_from_url
from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck
# NOTE(review): these assignments use ``self`` at module level, where no
# ``self`` is defined in the code shown, so executing this file raises
# NameError here. They appear to be the mask set-up for the ResNet-18
# ``ResMasking.__init__`` (its header is not part of this paste) and belong
# inside that method. The channel counts presumably match the layer1..layer4
# output widths (64/128/256/512); earlier stages get deeper masks (4 -> 1).
self.mask1 = self._masking(64, 64, depth=4)
self.mask2 = self._masking(128, 128, depth=3)
self.mask3 = self._masking(256, 256, depth=2)
self.mask4 = self._masking(512, 512, depth=1)
def _masking(self, in_channels, out_channels, depth):
    """Build a convolutional mask head.

    The head is one Conv3x3-BN-ReLU unit mapping ``in_channels`` to
    ``out_channels``, followed by ``depth - 1`` further
    ``out_channels -> out_channels`` units, each wrapped in its own
    ``nn.Sequential``. ``self`` is not used. Because the last module is a
    ReLU, the produced mask is non-negative.
    """

    def conv_unit(c_in):
        # Conv3x3 (padding=1 keeps spatial size) -> BatchNorm -> in-place ReLU.
        return [
            nn.Conv2d(c_in, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]

    # Build the leading unit first so parameter initialisation consumes the
    # RNG in the same order as a single left-to-right nn.Sequential(...) call.
    layers = conv_unit(in_channels)
    layers.extend(
        nn.Sequential(*conv_unit(out_channels)) for _ in range(depth - 1)
    )
    return nn.Sequential(*layers)
def forward(self, x):
    """Run the ResNet stem, four masked stages, global pooling and the head.

    After each ``layerK`` the features are rescaled as
    ``out * (1 + maskK(out))``, so a zero mask leaves them unchanged.
    """
    # Stem.
    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)
    out = self.maxpool(out)

    # Residual stages, each followed by its residual mask.
    stages = (
        (self.layer1, self.mask1),
        (self.layer2, self.mask2),
        (self.layer3, self.mask3),
        (self.layer4, self.mask4),
    )
    for layer, mask in stages:
        out = layer(out)
        out = out * (1 + mask(out))

    # Pool, flatten to (batch, features) and classify.
    out = torch.flatten(self.avgpool(out), 1)
    return self.fc(out)
# NOTE(review): module-level statements that reference ``self``, which is not
# defined anywhere in the code shown, so executing this file raises NameError
# here. They appear to be the mask set-up for the ResNet-50 (Bottleneck)
# ``ResMasking50.__init__`` (header not in this paste); 2048 agrees with that
# class's ``nn.Linear(2048, 7)`` head. Move them into that method.
# NOTE(review): mask4 alone is a 2048->2048 3x3 conv (~37.7M weights).
self.mask1 = self._masking(256, 256, depth=4)
self.mask2 = self._masking(512, 512, depth=3)
self.mask3 = self._masking(1024, 1024, depth=2)
self.mask4 = self._masking(2048, 2048, depth=1)
def _masking(self, in_channels, out_channels, depth):
    """Return an ``nn.Sequential`` mask head of Conv3x3-BN-ReLU units.

    The first unit maps ``in_channels`` to ``out_channels`` (flattened into
    the outer Sequential). It is followed by ``depth - 1`` nested
    ``nn.Sequential`` units that keep ``out_channels``. ``self`` is unused.
    """
    modules = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    remaining = depth - 1
    for _ in range(remaining):
        modules.append(
            nn.Sequential(
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )
        )
    return nn.Sequential(*modules)
def forward(self, x):
    """Masked ResNet forward: stem -> 4 x (layerK, rescale by 1 + maskK) -> pool -> fc."""
    out = self.maxpool(self.relu(self.bn1(self.conv1(x))))
    for stage in range(1, 5):
        out = getattr(self, f"layer{stage}")(out)
        gate = getattr(self, f"mask{stage}")(out)
        # Residual masking: features are scaled by (1 + gate), not replaced.
        out = out * (1 + gate)
    pooled = torch.flatten(self.avgpool(out), 1)
    return self.fc(pooled)
TypeError: ResMasking.forward() got an unexpected keyword argument 'in_channels'
main.py
def main(config_path):
    """
    Load the JSON config, build the model and dataset splits, then train.

    Parameters
    ----------
    config_path : str
        Path to the JSON config file.
    """
    # Load configs; the context manager guarantees the file handle is closed.
    # (No random seed is set in this function.)
    with open(config_path, encoding="utf-8") as f:
        configs = json.load(f)
    configs["cwd"] = os.getcwd()

    # Build the model and the three dataset splits.
    model = get_model(configs)
    train_set, val_set, test_set = get_dataset(configs)

    trainer = FER2013Trainer(model, train_set, val_set, test_set, configs)
    # A config without a "distributed" key trains in a single process.
    if configs.get("distributed", 0) == 1:
        ngpus = torch.cuda.device_count()
        # NOTE(review): assuming ``mp`` is torch.multiprocessing, spawn() passes
        # the process index as trainer.train's first argument -- confirm.
        mp.spawn(trainer.train, nprocs=ngpus, args=())
    else:
        trainer.train()
def get_model(configs):
    """Build the model selected by ``configs["arch"]``.

    Parameters
    ----------
    configs : dict
        Must contain ``"arch"`` and, for a supported arch, ``"num_classes"``.

    Returns
    -------
    torch.nn.Module
        A ready-to-use model *instance* produced by the factory. Do not call it
        again with constructor keywords such as ``in_channels=...``:
        ``nn.Module.__call__`` forwards them to ``forward``. That is one way
        to get "ResMasking.forward() got an unexpected keyword argument
        'in_channels'".

    Raises
    ------
    ValueError
        If ``configs["arch"]`` is not ``"resmasking_dropout3"``, the only
        architecture wired up here.
    """
    arch = configs["arch"]
    if arch != "resmasking_dropout3":
        raise ValueError(f"Model architecture {arch} is not supported.")
    return resmasking_dropout3(num_classes=configs["num_classes"])
def get_dataset(configs):
    """
    Return the raw dataset splits; ``main`` unpacks the result as
    ``train_set, val_set, test_set``.

    NOTE(review): as shown, the body is only this docstring, so the function
    returns None and that three-way unpacking in ``main`` would raise
    TypeError. The loading code is presumably elided from this paste --
    confirm.
    """
import torch import torch.nn as nn from torchvision.models.utils import load_state_dict_from_url from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck
# Download URLs of pretrained ResNet state dicts. ResMasking / ResMasking50
# fetch the resnet18 / resnet50 entries when no weight_path is given.
model_urls = dict(
    resnet18="https://download.pytorch.org/models/resnet18-5c106cde.pth",
    resnet34="https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    resnet50="https://download.pytorch.org/models/resnet50-19c8e357.pth",
)
class ResMasking(ResNet):
    """ResNet-18 backbone (BasicBlock, layers=[2, 2, 2, 2]) with residual masks.

    After each residual stage ``layerK``, a small conv stack ``maskK``
    produces ``m`` and the features are rescaled as ``x * (1 + m)``. The
    classifier head is ``nn.Linear(512, 7)``.

    Parameters
    ----------
    weight_path : str, optional
        Path to a state dict read with ``torch.load``. When empty, the
        ``model_urls["resnet18"]`` checkpoint is downloaded instead.
    """

    def __init__(self, weight_path=""):
        # Must be the ``__init__`` dunder: spelled ``init``, the constructor
        # never runs and ResMasking(weight_path) falls through to
        # ResNet.__init__, which requires ``block`` and ``layers``.
        super().__init__(block=BasicBlock, layers=[2, 2, 2, 2])
        if weight_path:
            # map_location="cpu" lets a checkpoint saved on GPU load on a
            # CPU-only host.
            state_dict = torch.load(weight_path, map_location="cpu")
        else:
            # torch.hub provides this helper directly; torchvision.models.utils
            # only re-exported it and is absent from newer torchvision.
            state_dict = torch.hub.load_state_dict_from_url(
                model_urls["resnet18"], progress=True
            )
        # strict=False tolerates missing/unexpected keys (not shape
        # mismatches). The stock fc still exists here, so its weights load
        # before the head is replaced below.
        self.load_state_dict(state_dict, strict=False)
        self.fc = nn.Linear(512, 7)
        # One mask per stage, widths presumably matching the layer1..layer4
        # outputs for BasicBlock; earlier stages get deeper masks (4 -> 1).
        self.mask1 = self._masking(64, 64, depth=4)
        self.mask2 = self._masking(128, 128, depth=3)
        self.mask3 = self._masking(256, 256, depth=2)
        self.mask4 = self._masking(512, 512, depth=1)

    @staticmethod
    def _masking(in_channels, out_channels, depth):
        """Conv3x3-BN-ReLU unit (in -> out) plus ``depth - 1`` out -> out units.

        The stack ends in ReLU, so the produced mask is non-negative.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            *[
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                )
                for _ in range(depth - 1)
            ],
        )

    def forward(self, x):
        """Stem -> 4 x (layerK, then x * (1 + maskK(x))) -> avgpool -> fc."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        for layer, mask in (
            (self.layer1, self.mask1),
            (self.layer2, self.mask2),
            (self.layer3, self.mask3),
            (self.layer4, self.mask4),
        ):
            x = layer(x)
            # Residual masking: a zero mask leaves the features unchanged.
            x = x * (1 + mask(x))
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)
class ResMasking50(ResNet):
    """ResNet-50 backbone (Bottleneck, layers=[3, 4, 6, 3]) with residual masks.

    After each residual stage ``layerK``, a small conv stack ``maskK``
    produces ``m`` and the features are rescaled as ``x * (1 + m)``. The
    classifier head is ``nn.Linear(2048, 7)``.

    Parameters
    ----------
    weight_path : str, optional
        Path to a state dict read with ``torch.load``. When empty, the
        ``model_urls["resnet50"]`` checkpoint is downloaded instead.
    """

    def __init__(self, weight_path=""):
        # Must be the ``__init__`` dunder: spelled ``init``, the constructor
        # never runs and ResMasking50(weight_path) falls through to
        # ResNet.__init__, which requires ``block`` and ``layers``.
        super().__init__(block=Bottleneck, layers=[3, 4, 6, 3])
        if weight_path:
            # map_location="cpu" lets a checkpoint saved on GPU load on a
            # CPU-only host.
            state_dict = torch.load(weight_path, map_location="cpu")
        else:
            # torch.hub provides this helper directly; torchvision.models.utils
            # only re-exported it and is absent from newer torchvision.
            state_dict = torch.hub.load_state_dict_from_url(
                model_urls["resnet50"], progress=True
            )
        # strict=False tolerates missing/unexpected keys (not shape
        # mismatches). The stock fc still exists here, so its weights load
        # before the head is replaced below.
        self.load_state_dict(state_dict, strict=False)
        self.fc = nn.Linear(2048, 7)
        # One mask per stage, widths presumably matching the layer1..layer4
        # outputs for Bottleneck; earlier stages get deeper masks (4 -> 1).
        # NOTE(review): mask4 alone is a 2048->2048 3x3 conv (~37.7M weights).
        self.mask1 = self._masking(256, 256, depth=4)
        self.mask2 = self._masking(512, 512, depth=3)
        self.mask3 = self._masking(1024, 1024, depth=2)
        self.mask4 = self._masking(2048, 2048, depth=1)

    @staticmethod
    def _masking(in_channels, out_channels, depth):
        """Conv3x3-BN-ReLU unit (in -> out) plus ``depth - 1`` out -> out units.

        The stack ends in ReLU, so the produced mask is non-negative.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            *[
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                )
                for _ in range(depth - 1)
            ],
        )

    def forward(self, x):
        """Stem -> 4 x (layerK, then x * (1 + maskK(x))) -> avgpool -> fc."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        for layer, mask in (
            (self.layer1, self.mask1),
            (self.layer2, self.mask2),
            (self.layer3, self.mask3),
            (self.layer4, self.mask4),
        ):
            x = layer(x)
            # Residual masking: a zero mask leaves the features unchanged.
            x = x * (1 + mask(x))
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)
def resmasking(in_channels=3, num_classes=7, weight_path=""):
    """Build a ResMasking (ResNet-18) model with a plain linear head.

    Parameters
    ----------
    in_channels : int
        Accepted for a uniform factory signature; not used.
    num_classes : int
        Output size of the ``nn.Linear(512, num_classes)`` head.
    weight_path : str
        Forwarded to ``ResMasking``.
    """
    model = ResMasking(weight_path)
    # ResMasking builds a fixed 7-way head; rebuild it so num_classes is
    # honoured, as the sibling resmasking_dropout* factories do.
    model.fc = nn.Linear(512, num_classes)
    return model
def resmasking50_dropout1(in_channels=3, num_classes=7, weight_path=""):
    """ResMasking50 (ResNet-50) with a ``Dropout(0.4) -> Linear(2048, num_classes)`` head.

    ``in_channels`` is accepted for a uniform factory signature but unused.
    """
    net = ResMasking50(weight_path)
    classifier = [nn.Dropout(0.4), nn.Linear(2048, num_classes)]
    net.fc = nn.Sequential(*classifier)
    return net
def resmasking_dropout1(in_channels=3, num_classes=7, weight_path=""):
    """ResMasking (ResNet-18) with a ``Dropout(0.4) -> Linear(512, num_classes)`` head.

    ``in_channels`` is accepted for a uniform factory signature but unused.
    """
    net = ResMasking(weight_path)
    classifier = [nn.Dropout(0.4), nn.Linear(512, num_classes)]
    net.fc = nn.Sequential(*classifier)
    return net
def resmasking_dropout2(in_channels=3, num_classes=7, weight_path=""):
    """ResMasking (ResNet-18) with a 512 -> 128 -> num_classes MLP head.

    The head is Linear -> ReLU -> Dropout(p=0.5) -> Linear. ``in_channels``
    is accepted for a uniform factory signature but unused.
    """
    net = ResMasking(weight_path)
    hidden = 128
    net.fc = nn.Sequential(
        nn.Linear(512, hidden),
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(hidden, num_classes),
    )
    return net
def resmasking_dropout3(in_channels=3, num_classes=7, weight_path=""):
    """ResMasking (ResNet-18) with a 512 -> 512 -> 128 -> num_classes MLP head.

    Each hidden layer is Linear -> ReLU(inplace) -> Dropout() (default p).
    ``in_channels`` is accepted for a uniform factory signature but unused.
    """
    net = ResMasking(weight_path)
    layers = []
    width = 512
    for hidden in (512, 128):
        layers += [nn.Linear(width, hidden), nn.ReLU(True), nn.Dropout()]
        width = hidden
    layers.append(nn.Linear(width, num_classes))
    net.fc = nn.Sequential(*layers)
    return net
TypeError: ResMasking.forward() got an unexpected keyword argument 'in_channels'
# main.py
def main(config_path):
    """
    Load the JSON config, build the model and dataset splits, then train.

    Parameters
    ----------
    config_path : str
        Path to the JSON config file.
    """
    # Load configs; the context manager guarantees the file handle is closed.
    # (No random seed is set in this function.)
    with open(config_path, encoding="utf-8") as f:
        configs = json.load(f)
    configs["cwd"] = os.getcwd()

    # Build the model and the three dataset splits.
    model = get_model(configs)
    train_set, val_set, test_set = get_dataset(configs)

    trainer = FER2013Trainer(model, train_set, val_set, test_set, configs)
    # A config without a "distributed" key trains in a single process.
    if configs.get("distributed", 0) == 1:
        ngpus = torch.cuda.device_count()
        # NOTE(review): assuming ``mp`` is torch.multiprocessing, spawn() passes
        # the process index as trainer.train's first argument -- confirm.
        mp.spawn(trainer.train, nprocs=ngpus, args=())
    else:
        trainer.train()
def get_model(configs):
    """Build the model selected by ``configs["arch"]``.

    Parameters
    ----------
    configs : dict
        Must contain ``"arch"`` and, for a supported arch, ``"num_classes"``.

    Returns
    -------
    torch.nn.Module
        A ready-to-use model *instance* produced by the factory. Do not call it
        again with constructor keywords such as ``in_channels=...``:
        ``nn.Module.__call__`` forwards them to ``forward``. That is one way
        to get "ResMasking.forward() got an unexpected keyword argument
        'in_channels'".

    Raises
    ------
    ValueError
        If ``configs["arch"]`` is not ``"resmasking_dropout3"``, the only
        architecture wired up here.
    """
    arch = configs["arch"]
    if arch != "resmasking_dropout3":
        raise ValueError(f"Model architecture {arch} is not supported.")
    return resmasking_dropout3(num_classes=configs["num_classes"])
def get_dataset(configs):
    """
    Return the raw dataset splits; ``main`` unpacks the result as
    ``train_set, val_set, test_set``.

    NOTE(review): as shown, the body is only this docstring, so the function
    returns None and that three-way unpacking in ``main`` would raise
    TypeError. The loading code is presumably elided from this paste --
    confirm.
    """
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main("/content/drive/MyDrive/Resnet/fer2013_config.json")