Closed 18972441546 closed 2 years ago
Could you tell me how you instantiated the CRF? Thanks.
My code is shown below. I am training the whole model directly, but the loss stays constant.

    import CRF
    import torch
    from unet_model import *

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    params = CRF.FrankWolfeParams(scheme='fixed',  # constant stepsize
                                  stepsize=1.0,
                                  regularizer='l2',
                                  lambda_=1.0,  # regularization weight
                                  lambda_learnable=False,
                                  x0_weight=0.5,  # useful for training, set to 0 if inference only
                                  x0_weight_learnable=False)

    crf = CRF.DenseGaussianCRF(classes=1,
                               alpha=160,
                               beta=0.05,
                               gamma=3.0,
                               spatial_weight=1.0,
                               bilateral_weight=1.0,
                               compatibility=1.0,
                               init='potts',
                               solver='fw',
                               iterations=5,
                               params=params)

    class CNNCRF(torch.nn.Module):
        """
        Simple CNN-CRF model
        """
        def __init__(self, cnn, crf):
            super().__init__()
            self.cnn = cnn
            self.crf = crf

        def forward(self, x):
            """
            x is a batch of input images
            """
            logits = self.cnn(x)  # output tensor of the CNN
            logits = self.crf(x, logits)
            return logits

    # cnn and crf
    cnn = UNet()
    UnetCrfs = CNNCRF(cnn, crf).to(device)

    if __name__ == '__main__':
        # arguments: number of input channels and number of output channels of the CNN
        cnn = UNet(1, 1)
        model = CNNCRF(cnn, crf).to(device)
        data_path = './data'
        isbi_dataset = ISBI_Loader(data_path)
        train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                                   batch_size=1,
                                                   shuffle=False)
        # net, epochs, optimizer, criterion, best_loss and pbar are not defined in this excerpt
        for epoch in range(epochs):
            net.train()
            for image, label in train_loader:
                optimizer.zero_grad()
                image = image.to(device=device, dtype=torch.float32)
                label = label.to(device=device, dtype=torch.float32)
                pred = net(image)
                loss = criterion(pred, label)
                if loss < best_loss:
                    best_loss = loss
                    torch.save(net.state_dict(), 'best_model.pth')
                loss.backward()
                optimizer.step()
                pbar.update(1)
Are you sure that training is successful without the CRF? Could you try replacing the line UnetCrfs = CNNCRF(cnn, crf).to(device)
with UnetCrfs = cnn.to(device)
and see what happens?
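
One way to compare the two runs is to record the loss at every step and check whether it moves at all. The following is only a minimal sketch: `loss_is_stuck` and its tolerance are hypothetical helpers, not part of the CRF library, and the loss values are made up for illustration.

```python
def loss_is_stuck(losses, tol=1e-6):
    """Return True if every recorded loss lies within `tol` of the first one."""
    if len(losses) < 2:
        return False  # not enough data to decide
    first = losses[0]
    return all(abs(value - first) <= tol for value in losses)


# Made-up loss histories, for illustration only (not measured results).
constant_run = [0.7, 0.7, 0.7]
decreasing_run = [0.7, 0.6, 0.5]

print(loss_is_stuck(constant_run))    # True
print(loss_is_stuck(decreasing_run))  # False
```

In a training loop like the one above, the history could be built by appending `loss.item()` after each call to `criterion`, once with the CRF and once with the CNN alone.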
Thanks, it is OK.
Great. Do not hesitate to let me know if you encounter any issues.

    import torch

    class CNNCRF(torch.nn.Module):
        """
        Simple CNN-CRF model
        """
        def __init__(self, cnn, crf):
            super().__init__()
            self.cnn = cnn
            self.crf = crf

    # Create a CNN-CRF model from given cnn and crf
    # This is a PyTorch module that can be used in a usual way
    model = CNNCRF(cnn, crf)

First I trained the UNet and saved the model. Then I loaded the trained UNet and trained the UNet and the CRF together. The loss is stuck at 0.693147. Do you have any suggestions?
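
A note on the number itself: 0.693147 is ln 2, which is the value binary cross-entropy takes when the predicted probability is exactly 0.5 regardless of the label. The post does not say which criterion is used, so the sketch below assumes a binary cross-entropy loss; `binary_cross_entropy` is a hypothetical helper written here in plain Python, not a function from PyTorch or the CRF library.

```python
import math


def binary_cross_entropy(p, y):
    """Binary cross-entropy for one predicted probability p and a binary label y."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


# A prediction of exactly 0.5 gives the same loss for either label: ln 2.
loss_pos = binary_cross_entropy(0.5, 1)
loss_neg = binary_cross_entropy(0.5, 0)

print(round(loss_pos, 6), round(loss_neg, 6))  # 0.693147 0.693147
```

If that assumption holds, a loss pinned at this value would be consistent with the model producing the same uninformative output (logit 0, probability 0.5) for every pixel, so inspecting the range of `pred` before the loss is computed may be a useful first check.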