netw0rkf10w / CRF

Conditional Random Fields
Apache License 2.0

How to deal with the size of x and logits #6

Closed: 18972441546 closed this issue 2 years ago

18972441546 commented 2 years ago

import torch


class CNNCRF(torch.nn.Module):
    """
    Simple CNN-CRF model
    """
    def __init__(self, cnn, crf):
        super().__init__()
        self.cnn = cnn
        self.crf = crf

    def forward(self, x):
        """
        x is a batch of input images
        """
        logits = self.cnn(x)
        logits = self.crf(x, logits)
        return logits


# Create a CNN-CRF model from given cnn and crf
# This is a PyTorch module that can be used in a usual way
model = CNNCRF(cnn, crf)

Following your usage example: my input image has shape (1, 3, 512, 512), and the label has shape (1, 2, 512, 512), where the 2 channels are the classes (including the background). The logits output by the CNN also have shape (1, 2, 512, 512). So self.crf receives x with shape (1, 3, 512, 512) and logits with shape (1, 2, 512, 512). Is that correct?

netw0rkf10w commented 2 years ago

> The output logits shape after CNN processing is also (1,2,512,512). At this point, (1,3,512,512) and (1,2,512,512) will be input self.crf. I wonder if I can.

Yes, those input shapes should be fine for self.crf.
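To see why those shapes work, here is a minimal sketch of the shape contract for a call like `crf(x, logits)`. It assumes `x` is laid out as (N, C_img, H, W) and `logits` as (N, K, H, W), and that the CRF returns refined logits with the same shape as its input logits. `check_crf_inputs` is a hypothetical helper written for illustration, not part of the CRF library's API. Under these assumptions, only the batch size and the spatial size need to agree. The number of image channels (3 for RGB) and the number of classes (2 here) are independent.

```python
from typing import Sequence, Tuple


def check_crf_inputs(x_shape: Sequence[int], logits_shape: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: validate shapes for a call like crf(x, logits).

    Assumes x is (N, C_img, H, W) and logits is (N, K, H, W), and that the
    CRF returns refined logits with the same shape as the input logits.
    """
    if len(x_shape) != 4 or len(logits_shape) != 4:
        raise ValueError("expected 4-D shapes (N, C, H, W)")
    n_x, _, h_x, w_x = x_shape
    n_l, _, h_l, w_l = logits_shape
    # Batch size must agree: each image is paired with its own logits.
    if n_x != n_l:
        raise ValueError(f"batch size mismatch: {n_x} vs {n_l}")
    # Spatial size must agree: pixel (i, j) of the image corresponds to pixel (i, j) of the logits.
    if (h_x, w_x) != (h_l, w_l):
        raise ValueError(f"spatial size mismatch: {(h_x, w_x)} vs {(h_l, w_l)}")
    # Channel counts are not compared: C_img (e.g. 3 for RGB) need not equal K (number of classes).
    return tuple(logits_shape)


# Shapes from the question: an RGB image and 2-class logits (background + foreground).
print(check_crf_inputs((1, 3, 512, 512), (1, 2, 512, 512)))  # prints (1, 2, 512, 512)
```

If the CNN changed the resolution (for example, returning logits at 128x128 for a 512x512 image), the spatial check would fail. In that case the logits would need upsampling before being passed to the CRF.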