ByungKwanLee / Causal-Unsupervised-Segmentation

Official PyTorch implementation code for the technical part of Causal Unsupervised Semantic sEgmentation (CAUSE), which improves the performance of unsupervised semantic segmentation. (Under Review)

Calculation of loss_linear #1

Closed applefl closed 11 months ago

applefl commented 11 months ago

The calculation of loss_linear requires labels, but the unsupervised training process has no labels. How can I train the model on unlabeled data?

ByungKwanLee commented 11 months ago
# linear probe loss
linear_logits = segment.linear(seg_feat_ema)
linear_logits = F.interpolate(linear_logits, label.shape[-2:], mode='bilinear', align_corners=False)
flat_linear_logits = linear_logits.permute(0, 2, 3, 1).reshape(-1, args.n_classes)
flat_label = label.reshape(-1)
flat_label_mask = (flat_label >= 0) & (flat_label < args.n_classes)
loss_linear = F.cross_entropy(flat_linear_logits[flat_label_mask], flat_label[flat_label_mask])

This is just the linear probe loss, used to validate the competitive quality of the dense representations learned in an unsupervised manner, as reported in Table 2 of the CAUSE main paper.

Not only my work but also the previous unsupervised segmentation works STEGO and HP do this. In addition, self-supervised learning frameworks such as DINO also evaluate with a linear probe, even though the representations themselves are learned without labels.

STEGO Code Link

linear_logits = self.linear_probe(detached_code)
linear_logits = F.interpolate(linear_logits, label.shape[-2:], mode='bilinear', align_corners=False)
linear_logits = linear_logits.permute(0, 2, 3, 1).reshape(-1, self.n_classes)
linear_loss = self.linear_probe_loss_fn(linear_logits[mask], flat_label[mask]).mean()
loss += linear_loss
self.log('loss/linear', linear_loss, **log_args)

HP Code Link

with torch.cuda.amp.autocast(enabled=True):
    linear_output = linear_model(detached_code)
    cluster_output = cluster_model(detached_code, None, is_direct=False)

    loss, loss_dict, corr_dict = criterion(model_input=model_input,
                                           model_output=model_output,
                                           linear_output=linear_output,
                                           cluster_output=cluster_output
                                           )

    loss = loss + loss_supcon + loss_consistency*opt["alpha"]

DINO Code Link

# forward
with torch.no_grad():
    if "vit" in args.arch:
        intermediate_output = model.get_intermediate_layers(inp, n)
        output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
        if avgpool:
            output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
            output = output.reshape(output.shape[0], -1)
    else:
        output = model(inp)
output = linear_classifier(output)

# compute cross entropy loss
loss = nn.CrossEntropyLoss()(output, target)

If you do not want to compute the linear probe loss, you can simply remove it.

applefl commented 11 months ago

Thank you. I tried to remove it, but the training process of fine_tuning_tr got an error at line 83: "No inf checks were recorded for this optimizer."

ByungKwanLee commented 11 months ago

It would be better to run the provided bash file without the change, since the linear probe loss serves a different purpose (validating dense representation quality).
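
For context, this PyTorch AMP error generally means that GradScaler.step() was called for an optimizer whose parameters never received scaled gradients in that iteration. If you still want to drop the probe loss, the corresponding optimizer step has to be skipped as well. A minimal sketch of that guard, assuming a separate optimizer for the linear probe head and a torch.cuda.amp.GradScaler (the names below are hypothetical, not the repo's exact training loop):

use_linear_probe = False  # hypothetical flag: set True to keep the probe loss

total_loss = loss_cause
if use_linear_probe:
    total_loss = total_loss + loss_linear

scaler.scale(total_loss).backward()
scaler.step(main_optimizer)        # the main optimizer always receives gradients
if use_linear_probe:
    scaler.step(linear_optimizer)  # only step the probe optimizer when its loss was computed
scaler.update()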

xouyang0079 commented 10 months ago

Hi,

I carefully checked the STEGO and HP implementations, and they both detach the code (i.e., the output of the segmentation head) from the computation graph that leads back through the "net". Hence, any operation done on the tensor after the detachment can no longer influence the "net" gradients. However, I didn't see seg_feat_ema being detached in your implementation. Please correct me if I am wrong.
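
For reference, a minimal standalone sketch of the detach semantics described above (the modules are hypothetical stand-ins, not the actual STEGO/HP code):

import torch
import torch.nn as nn

net = nn.Linear(8, 8)    # stand-in for the backbone ("net")
probe = nn.Linear(8, 4)  # stand-in for the linear probe

feat = net(torch.randn(2, 8))
logits = probe(feat.detach())  # detach cuts the graph leading back to "net"
logits.sum().backward()

print(all(p.grad is None for p in net.parameters()))        # True: the backbone receives no gradient
print(all(p.grad is not None for p in probe.parameters()))  # True: the probe itself still trains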

Thanks

ByungKwanLee commented 10 months ago

You can find the following lines in train_front_door_tr.py [Link]

# Bank and EMA
cluster.bank_init()
ema_init(segment.head, segment.head_ema)
ema_init(segment.projection_head, segment.projection_head_ema)

Here, the function ema_init is defined in modules/segment_module.py [Link]

def ema_init(x, x_ema):
    for param, param_ema in zip(x.parameters(), x_ema.parameters()):
        param_ema.data = param.data
        param_ema.requires_grad = False
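
For completeness, a tiny standalone check of what requires_grad = False on the EMA copy does and does not do (toy modules with hypothetical names, using ema_init as defined above):

import torch
import torch.nn as nn

head = nn.Linear(8, 4)
head_ema = nn.Linear(8, 4)
ema_init(head, head_ema)  # copy the weights and freeze the EMA copy

x = torch.randn(2, 8, requires_grad=True)
out = head_ema(x)
out.sum().backward()

print(all(p.grad is None for p in head_ema.parameters()))  # True: the frozen EMA parameters receive no gradient
print(x.grad is not None)                                  # True: gradients still flow through the EMA head to its input
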
xouyang0079 commented 10 months ago

Thanks very much for your clarification!