KangjianWu / BrainTumorDetection-MRI


ResNet Model Training #2

Open: KangjianWu opened this issue 4 months ago

KangjianWu commented 4 months ago

Task Details:

- Build and train the ResNet model for classification.
- Save the trained model.

Files to Complete:

- src/train_resnet.py
- src/resnet_model.py (if a custom model architecture is required)

Task Instructions:

- Load the preprocessed data.
- Train the ResNet model for the classification task.
- Save the trained model to models/resnet/resnet_model.h5.
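The issue does not say what format the preprocessing step produces. A minimal sketch of the "load the preprocessed data" step, assuming it yields NumPy arrays of images scaled to [0, 1] (grayscale (N, H, W) or RGB (N, H, W, 3)) plus 0/1 tumor labels; `make_loader` is a hypothetical name, not part of the repository. It reshapes the images into the 3-channel, NCHW, ImageNet-normalized form that a pretrained torchvision ResNet expects, and shapes the labels as float (N, 1) targets for a single-output sigmoid head.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# ImageNet channel statistics used by torchvision's pretrained ResNet weights
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


def make_loader(images: np.ndarray, labels: np.ndarray,
                batch_size: int = 32, shuffle: bool = True) -> DataLoader:
    """Hypothetical helper: wrap preprocessed arrays in a PyTorch DataLoader.

    Assumes `images` is (N, H, W) grayscale or (N, H, W, 3) RGB with values
    in [0, 1], and `labels` holds N binary (0/1) tumor labels.
    """
    x = torch.as_tensor(images, dtype=torch.float32)
    if x.ndim == 3:                      # grayscale MRI slice: add a channel axis
        x = x.unsqueeze(-1)
    x = x.permute(0, 3, 1, 2)            # NHWC -> NCHW, the layout PyTorch expects
    if x.shape[1] == 1:                  # pretrained ResNet expects 3 input channels
        x = x.repeat(1, 3, 1, 1)
    x = (x - IMAGENET_MEAN) / IMAGENET_STD
    # BCELoss needs float targets with the same (N, 1) shape as the model output
    y = torch.as_tensor(labels, dtype=torch.float32).view(-1, 1)
    return DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=shuffle)
```

Usage, with hypothetical array names: `train_loader = make_loader(X_train, y_train)`. Repeating the grayscale channel is one common way to reuse 3-channel ImageNet weights; it is an assumption here, not something the issue specifies.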

Note: This task can begin after the data preprocessing task is completed.

prijall commented 2 weeks ago

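# model building
import torch
import torch.nn as nn
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_model():
    # weights=True is a deprecated alias; pass the weights enum explicitly
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    for param in model.parameters():  # freeze the pretrained backbone
        param.requires_grad = False

    # same as ResNet's default pooling, kept explicit
    model.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
    # new head: 512 backbone features -> 1 tumor probability
    model.fc = nn.Sequential(nn.Flatten(),
                             nn.Linear(512, 128),
                             nn.ReLU(),
                             nn.Dropout(0.2),
                             nn.Linear(128, 1),
                             nn.Sigmoid())
    loss_function = nn.BCELoss()
    # only the new head is trainable, so optimize just its parameters
    optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
    return model.to(device), loss_function, optimizer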
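A minimal sketch of the remaining training and saving steps the issue asks for in src/train_resnet.py, assuming a DataLoader that yields (images, float labels of shape (N, 1)) batches. `train_and_save`, the epoch count, and the loop details are hypothetical, not the repository's code. It consumes the (model, loss_function, optimizer) triple returned by `get_model()` and writes the weights to the path given in the task.

```python
import os
import torch
from torch import nn
from torch.utils.data import DataLoader


def train_and_save(model: nn.Module, loss_function: nn.Module,
                   optimizer: torch.optim.Optimizer, loader: DataLoader,
                   device: torch.device, epochs: int = 10,
                   save_path: str = "models/resnet/resnet_model.h5") -> list:
    """Hypothetical training loop; returns the mean training loss per epoch."""
    history = []
    for _ in range(epochs):
        model.train()                    # training mode: Dropout active, BatchNorm uses batch stats
        total, count = 0.0, 0
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            probs = model(xb)            # (N, 1) sigmoid probabilities
            loss = loss_function(probs, yb)
            loss.backward()
            optimizer.step()
            total += loss.item() * xb.size(0)
            count += xb.size(0)
        history.append(total / count)
    # torch.save writes PyTorch's own format; the .h5 suffix does not make it HDF5
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    torch.save(model.state_dict(), save_path)
    return history
```

Saving the `state_dict` rather than the whole model keeps the file independent of how the class is imported; reloading then means calling `get_model()` again and passing `torch.load(path)` to `model.load_state_dict`. If a true Keras/HDF5 `.h5` file is required, that would need a different framework than the PyTorch code in this thread.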