netw0rkf10w / CRF

Conditional Random Fields
Apache License 2.0

The losses remain the same #8

Closed: 18972441546 closed this issue 2 years ago.

18972441546 commented 2 years ago

```python
import torch


class CNNCRF(torch.nn.Module):
    """
    Simple CNN-CRF model
    """
    def __init__(self, cnn, crf):
        super().__init__()
        self.cnn = cnn
        self.crf = crf

    def forward(self, x):
        """
        x is a batch of input images
        """
        logits = self.cnn(x)
        logits = self.crf(x, logits)
        return logits


# Create a CNN-CRF model from given cnn and crf
# This is a PyTorch module that can be used in a usual way
model = CNNCRF(cnn, crf)
```
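A minimal runnable sketch of the wrapper in use. The toy `torch.nn.Conv2d` stands in for the UNet, and `IdentityCRF` is a hypothetical placeholder that only mimics the `crf(image, logits)` call made in `forward`; it performs no CRF inference. With this stand-in, the model's output keeps the `(N, classes, H, W)` shape of the CNN logits.

```python
import torch


class CNNCRF(torch.nn.Module):
    """Simple CNN-CRF model: the CRF refines the CNN logits using the input image."""

    def __init__(self, cnn, crf):
        super().__init__()
        self.cnn = cnn
        self.crf = crf

    def forward(self, x):
        logits = self.cnn(x)          # (N, classes, H, W)
        return self.crf(x, logits)    # refined logits


class IdentityCRF(torch.nn.Module):
    """Hypothetical stand-in with the same call signature as the CRF layer."""

    def forward(self, image, logits):
        return logits  # no refinement: returns the CNN logits unchanged


cnn = torch.nn.Conv2d(3, 2, kernel_size=3, padding=1)  # toy CNN: 3 input channels, 2 classes
model = CNNCRF(cnn, IdentityCRF())

x = torch.randn(4, 3, 16, 16)  # batch of 4 RGB images, 16x16 pixels
out = model(x)
print(out.shape)  # torch.Size([4, 2, 16, 16])
```

Because `CNNCRF` assigns `cnn` and `crf` as attributes, PyTorch registers both as submodules, so `model.parameters()`, `.to(device)` and `state_dict()` cover both parts and the wrapper can be trained like any other module.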

First I trained the UNet and saved the model; then I loaded the trained UNet and trained the UNet and CRF together. The loss is stuck at 0.693147. Do you have any suggestions?
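The value 0.693147 is a useful clue: it is ln 2, which is exactly the binary cross-entropy obtained when the network outputs a logit of 0 (a probability of 0.5) for a pixel, whatever that pixel's label is. The sketch below assumes the criterion is a binary cross-entropy on logits, such as `torch.nn.BCEWithLogitsLoss`; the thread does not show which loss was actually used.

```python
import math


def bce_with_logits(logit: float, target: float) -> float:
    """Binary cross-entropy on a raw logit for a single element."""
    p = 1.0 / (1.0 + math.exp(-logit))  # sigmoid turns the logit into a probability
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


# A logit of 0 means p = 0.5, so the loss is ln 2 for either label.
print(bce_with_logits(0.0, 1.0))  # ≈ 0.693147
print(bce_with_logits(0.0, 0.0))  # ≈ 0.693147
print(math.log(2))                # ≈ 0.693147
```

A loss that sits at exactly this value batch after batch therefore points to predictions of about 0.5 everywhere (logits of about zero), rather than to training that is merely slow.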

netw0rkf10w commented 2 years ago

Could you tell me how you instantiated the CRF? Thanks.

18972441546 commented 2 years ago

My code is shown below. This version trains directly (the UNet and CRF together from the start), but the loss is still always constant.

```python
import torch
from tqdm import tqdm  # assumed source of `pbar` (not shown in the original post)

import CRF
from unet_model import *  # the user's UNet definition

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

params = CRF.FrankWolfeParams(scheme='fixed',  # constant stepsize
                              stepsize=1.0,
                              regularizer='l2',
                              lambda_=1.0,  # regularization weight (`lambda` is a Python keyword)
                              lambda_learnable=False,
                              x0_weight=0.5,  # useful for training, set to 0 if inference only
                              x0_weight_learnable=False)

crf = CRF.DenseGaussianCRF(classes=1,
                           alpha=160,
                           beta=0.05,
                           gamma=3.0,
                           spatial_weight=1.0,
                           bilateral_weight=1.0,
                           compatibility=1.0,
                           init='potts',
                           solver='fw',
                           iterations=5,
                           params=params)


class CNNCRF(torch.nn.Module):
    """
    Simple CNN-CRF model
    """
    def __init__(self, cnn, crf):
        super().__init__()
        self.cnn = cnn
        self.crf = crf

    def forward(self, x):
        """
        x is a batch of input images
        """
        logits = self.cnn(x)  # get the CNN output tensor
        logits = self.crf(x, logits)
        return logits


# Create a CNN-CRF model from given cnn and crf
# This is a PyTorch module that can be used in a usual way
cnn = UNet()
UnetCrfs = CNNCRF(cnn, crf).to(device)

if __name__ == '__main__':
    cnn = UNet(1, 1)  # first argument: input channels of the CNN; second: output channels
    model = CNNCRF(cnn, crf).to(device)
    data_path = './data'
    isbi_dataset = ISBI_Loader(data_path)  # the user's dataset class (import not shown in the post)
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset, batch_size=1, shuffle=False)

    # Not shown in the original post: these definitions are assumptions for illustration.
    epochs = 40
    optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-5, weight_decay=1e-8, momentum=0.9)
    criterion = torch.nn.BCEWithLogitsLoss()
    best_loss = float('inf')
    pbar = tqdm(total=epochs * len(train_loader))

    for epoch in range(epochs):
        model.train()
        for image, label in train_loader:
            optimizer.zero_grad()
            image = image.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)
            pred = model(image)
            loss = criterion(pred, label)
            print('{}/{}: Loss/train'.format(epoch + 1, epochs), loss.item())

            # Keep the weights with the lowest training loss seen so far
            if loss.item() < best_loss:
                best_loss = loss.item()
                torch.save(model.state_dict(), 'best_model.pth')
            loss.backward()
            optimizer.step()
            pbar.update(1)
    pbar.close()
```
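One hypothesis worth checking, not confirmed in this thread: the CRF above is created with `classes=1`. Assuming the layer works with per-pixel label distributions normalized over the class dimension (an assumption; check the library source), a single class admits only one possible distribution, so the CRF output cannot depend on the UNet scores. The plain-Python softmax below only illustrates that normalization argument; it is not the library's code.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of per-class scores for one pixel."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# One class: the normalized output is 1.0 whatever score the CNN produces,
# so it carries no information and has zero gradient with respect to the score.
for score in (-5.0, 0.0, 3.7):
    print(softmax([score]))  # [1.0]

# Two classes (background, foreground): the output follows the scores.
print(softmax([0.0, 2.0]))
```

If the layer's output were the logarithm of such a distribution, it would be log(1) = 0 at every pixel, and `torch.nn.BCEWithLogitsLoss` on a constant logit of 0 gives exactly ln 2 ≈ 0.693147, the value reported here. Setting up the binary task with two classes (background and foreground) and a multi-class loss such as `torch.nn.CrossEntropyLoss` is one way to rule this case out.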
netw0rkf10w commented 2 years ago

Are you sure that training is successful without the CRF? Could you try replacing the line `UnetCrfs = CNNCRF(cnn, crf).to(device)` with `UnetCrfs = cnn.to(device)` and see what happens?
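Alongside that suggestion, a quick way to see whether any learning signal reaches the network is to inspect gradient norms after one backward pass. A minimal sketch: `grad_norms` is a hypothetical helper, the small `torch.nn.Conv2d` and random batch stand in for the UNet and the ISBI data, and `torch.nn.BCEWithLogitsLoss` is assumed as the criterion.

```python
import torch


def grad_norms(model: torch.nn.Module) -> dict:
    """Map each parameter name to the L2 norm of its gradient (None if no gradient was computed)."""
    return {
        name: (p.grad.norm().item() if p.grad is not None else None)
        for name, p in model.named_parameters()
    }


torch.manual_seed(0)

# Toy stand-in for the UNet (hypothetical); swap in the real model and a real batch.
cnn = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)
criterion = torch.nn.BCEWithLogitsLoss()

image = torch.randn(1, 1, 8, 8)
label = (torch.rand(1, 1, 8, 8) > 0.5).float()

loss = criterion(cnn(image), label)
loss.backward()

norms = grad_norms(cnn)
for name, norm in norms.items():
    print(name, norm)
```

Running the same check on `CNNCRF(cnn, crf)` and reading the entries for the UNet's parameters (prefixed `cnn.`) shows whether gradients flow back through the CRF: zero or `None` norms there would be consistent with a loss that never moves.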

18972441546 commented 2 years ago

Thanks, it is OK.

netw0rkf10w commented 2 years ago

Great. Do not hesitate to let me know if you encounter any issues.