KangjianWu opened this issue 4 months ago.
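import torch
from torch import nn
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_model():
    # ImageNet-pretrained ResNet-18; passing weights=True is deprecated since torchvision 0.13
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    # Freeze the pretrained backbone so only the new head is trained
    for param in model.parameters():
        param.requires_grad = False
    # Same as ResNet-18's default pooling layer; kept for explicitness
    model.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
    # Replace the 1000-class ImageNet head with a single-output binary classifier
    model.fc = nn.Sequential(
        nn.Flatten(),
        nn.Linear(512, 128),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(128, 1),
        nn.Sigmoid(),
    )
    # BCELoss expects probabilities, which the Sigmoid above provides
    loss_function = nn.BCELoss()
    # Only the new head has trainable parameters, so optimize just those
    optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
    return model.to(device), loss_function, optimizer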
Task Details:
- Build and train the ResNet model for classification.
- Save the trained model.

Files to Complete:
- src/train_resnet.py
- src/resnet_model.py (if a custom model architecture is required)

Task Instructions:
- Load the preprocessed data.
- Train the ResNet model for the classification task.
- Save the trained model to models/resnet/resnet_model.h5.
Note: This task can begin after the data preprocessing task is completed.
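The instructions above (train the model, then save it to models/resnet/resnet_model.h5) could be covered in src/train_resnet.py along the lines of the sketch below. It assumes `get_model()` returns the `(model, loss_function, optimizer)` triple shown earlier and that the preprocessed data arrives as a PyTorch `DataLoader` yielding `(images, labels)` batches with 0/1 labels; `train_model` and `save_model` are hypothetical names. Note that `torch.save` writes PyTorch's own serialization format, so the `.h5` suffix only names the file; it does not make it an HDF5/Keras file.

```python
import os

import torch
from torch import nn
from torch.utils.data import DataLoader

# Output path taken from the task instructions
MODEL_PATH = "models/resnet/resnet_model.h5"


def train_model(model: nn.Module, loader: DataLoader, loss_function: nn.Module,
                optimizer: torch.optim.Optimizer, device: torch.device,
                epochs: int = 5) -> float:
    """Hypothetical training loop; returns the mean loss of the last epoch."""
    epoch_loss = 0.0
    for epoch in range(epochs):
        model.train()  # enables Dropout in the classification head
        total_loss, correct, seen = 0.0, 0, 0
        for images, labels in loader:
            images = images.to(device)
            # BCELoss needs float targets shaped like the (N, 1) output
            labels = labels.float().view(-1, 1).to(device)
            optimizer.zero_grad()
            probs = model(images)
            loss = loss_function(probs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * images.size(0)
            correct += ((probs > 0.5).float() == labels).sum().item()
            seen += images.size(0)
        epoch_loss = total_loss / seen
        print(f"epoch {epoch + 1}/{epochs}: loss={epoch_loss:.4f} acc={correct / seen:.3f}")
    return epoch_loss


def save_model(model: nn.Module, path: str = MODEL_PATH) -> None:
    """Save the weights (state_dict), creating the parent directory if needed."""
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    # torch.save uses PyTorch's own format regardless of the .h5 extension
    torch.save(model.state_dict(), path)
```

Usage would then be `model, loss_function, optimizer = get_model()`, followed by `train_model(model, train_loader, loss_function, optimizer, device)` and `save_model(model)`, where `train_loader` stands in for whatever the preprocessing task produces. Because only the `state_dict` is saved, loading it later means rebuilding the architecture first (for example with `get_model()`) and then calling `model.load_state_dict(torch.load(MODEL_PATH))`.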