Yunfan-Li / Contrastive-Clustering

Code for the paper "Contrastive Clustering" (AAAI 2021)
MIT License

Problems encountered when adapting CC to DeepCluster #56

Open showmeyourcodehaha opened 5 months ago

showmeyourcodehaha commented 5 months ago

Hello! I am trying to use your CC clustering method to replace the K-means step in the classic DeepCluster. I mostly reuse your code, only adding some parts to cluster.py; the additions train the model in the DeepCluster fashion. I use the Market-1501 dataset, on which I have already pretrained the model with CC (1000 epochs). The code is below. P.S.: the parts framed by ### are my additions. However, this code does not work: the loss does not decrease, and the ACC even drops. I am only a first-year graduate student with limited coding ability and really cannot figure out where the problem is. Could you please give me some pointers?


```python
import os
import argparse
import torch
import torchvision
import numpy as np
from utils import yaml_config_hook
from modules import resnet, network, transform
from evaluation import evaluation
from torch.utils import data
import copy
###############################################################################
from CLIPReID.bases import ImageDataset
from CLIPReID.market1501 import Market1501
from utils import save_model  # yaml_config_hook is already imported above
###############################################################################

def inference(loader, model, device):
    model.eval()
    feature_vector = []
    labels_vector = []
    for step, (x, y) in enumerate(loader):
        x = x.to(device)
        with torch.no_grad():
            c = model.forward_cluster(x)
        feature_vector.extend(c.cpu().numpy())
        labels_vector.extend(y.numpy())
        if step % 20 == 0:
            print(f"Step [{step}/{len(loader)}]\t Computing features...")
    feature_vector = np.array(feature_vector)
    labels_vector = np.array(labels_vector)
    print("Features shape {}".format(feature_vector.shape))
    return feature_vector, labels_vector

###############################################################################
# Generate pseudo-labels for the whole dataset with the cluster head
def generate_pseudo_labels(loader, model):
    model.eval()
    pseudo_labels = []
    for step, (x, _) in enumerate(loader):
        x = x.to(device)
        with torch.no_grad():
            pseudo_labels_batch = model.forward_cluster(x)
        pseudo_labels.append(pseudo_labels_batch)
    return pseudo_labels

# DeepCluster-style training against the pseudo-labels
def train(loader, model, pseudo_labels, opt):
    model.train()
    total_loss = 0
    for step, (x, _) in enumerate(loader):
        x = x.to(device)
        pred = model.forward_compute(x)
        # pseudo_labels[step] is assumed to correspond to this batch
        loss = torch.nn.functional.cross_entropy(pred, pseudo_labels[step])
        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss += loss.item()
    print("total_loss:" + str(total_loss))

# Evaluation
def eval(eval_data_loader, model, device):
    print("### Creating features from model ###")
    X, Y = inference(eval_data_loader, model, device)
    if args.dataset == "CIFAR-100":  # super-class
        super_label = [
            [72, 4, 95, 30, 55],
            [73, 32, 67, 91, 1],
            [92, 70, 82, 54, 62],
            [16, 61, 9, 10, 28],
            [51, 0, 53, 57, 83],
            [40, 39, 22, 87, 86],
            [20, 25, 94, 84, 5],
            [14, 24, 6, 7, 18],
            [43, 97, 42, 3, 88],
            [37, 17, 76, 12, 68],
            [49, 33, 71, 23, 60],
            [15, 21, 19, 31, 38],
            [75, 63, 66, 64, 34],
            [77, 26, 45, 99, 79],
            [11, 2, 35, 46, 98],
            [29, 93, 27, 78, 44],
            [65, 50, 74, 36, 80],
            [56, 52, 47, 59, 96],
            [8, 58, 90, 13, 48],
            [81, 69, 41, 89, 85],
        ]
        Y_copy = copy.copy(Y)
        for i in range(20):
            for j in super_label[i]:
                Y[Y_copy == j] = i
    nmi, ari, f, acc = evaluation.evaluate(Y, X)
    print('NMI = {:.4f} ARI = {:.4f} F = {:.4f} ACC = {:.4f}'.format(nmi, ari, f, acc))
###############################################################################

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    config = yaml_config_hook("./config/config.yaml")
    for k, v in config.items():
        parser.add_argument(f"--{k}", default=v, type=type(v))
    args = parser.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

###############################################################################
    __factory = {
        'market1501': Market1501,
    }
    raw_dataset = __factory["market1501"](
        root=r"../Market-1501-v15.09.15")
    train_set_normal = ImageDataset(raw_dataset.train, transform.Transforms(size=args.image_size).test_transform)
###############################################################################

    if args.dataset == "CIFAR-10":
        train_dataset = torchvision.datasets.CIFAR10(
            root=args.dataset_dir,
            train=True,
            download=True,
            transform=transform.Transforms(size=args.image_size).test_transform,
        )
        test_dataset = torchvision.datasets.CIFAR10(
            root=args.dataset_dir,
            train=False,
            download=True,
            transform=transform.Transforms(size=args.image_size).test_transform,
        )
        dataset = data.ConcatDataset([train_dataset, test_dataset])
        class_num = 10
    elif args.dataset == "CIFAR-100":
        train_dataset = torchvision.datasets.CIFAR100(
            root=args.dataset_dir,
            download=True,
            train=True,
            transform=transform.Transforms(size=args.image_size).test_transform,
        )
        test_dataset = torchvision.datasets.CIFAR100(
            root=args.dataset_dir,
            download=True,
            train=False,
            transform=transform.Transforms(size=args.image_size).test_transform,
        )
        dataset = data.ConcatDataset([train_dataset, test_dataset])
        class_num = 20
    elif args.dataset == "STL-10":
        train_dataset = torchvision.datasets.STL10(
            root=args.dataset_dir,
            split="train",
            download=True,
            transform=transform.Transforms(size=args.image_size).test_transform,
        )
        test_dataset = torchvision.datasets.STL10(
            root=args.dataset_dir,
            split="test",
            download=True,
            transform=transform.Transforms(size=args.image_size).test_transform,
        )
        dataset = torch.utils.data.ConcatDataset([train_dataset, test_dataset])
        class_num = 10
    elif args.dataset == "ImageNet-10":
        dataset = torchvision.datasets.ImageFolder(
            root='datasets/imagenet-10',
            transform=transform.Transforms(size=args.image_size).test_transform,
        )
        class_num = 10
    elif args.dataset == "ImageNet-dogs":
        dataset = torchvision.datasets.ImageFolder(
            root='datasets/imagenet-dogs',
            transform=transform.Transforms(size=args.image_size).test_transform,
        )
        class_num = 15
    elif args.dataset == "tiny-ImageNet":
        dataset = torchvision.datasets.ImageFolder(
            root='datasets/tiny-imagenet-200/train',
            transform=transform.Transforms(size=args.image_size).test_transform,
        )
        class_num = 200

###############################################################################
    elif args.dataset == "market1501":
        dataset = train_set_normal
        class_num = 751
###############################################################################

    else:
        raise NotImplementedError
    # Dataloader for evaluation
    eval_data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=500,
        shuffle=False,
        drop_last=False,
        num_workers=args.workers,
    )

###############################################################################
    # Dataloader for training
    train_data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=256,
        num_workers=args.workers,
        shuffle=True,
        pin_memory=True,
    )
###############################################################################

    # Build the model and load the CC-pretrained checkpoint
    res = resnet.get_resnet(args.resnet)
    model = network.Network(res, args.feature_dim, class_num)
    model_fp = os.path.join(args.model_path, "checkpoint_{}.tar".format(args.start_epoch))
    model.load_state_dict(torch.load(model_fp, map_location=device.type)['net'])
    model.to(device)

###############################################################################
    # Optimizer
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=0.001,
        momentum=0.9,
        weight_decay=0.0005,
    )

    # Training loop: regenerate pseudo-labels each epoch, then fit them
    for epoch in range(100):
        print("epoch:" + str(epoch))
        pseudo_labels = generate_pseudo_labels(train_data_loader, model)
        train(train_data_loader, model, pseudo_labels, optimizer)
        if (epoch + 1) % 10 == 0:
            save_model(args, model, optimizer, epoch)
            eval(eval_data_loader, model, device)
###############################################################################
```
Yunfan-Li commented 5 months ago

Hello. First check whether the pretrained CC results are reasonable. For datasets with a large number of classes, consider increasing the batch size to keep the cluster-head contrast effective.
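For intuition about why a large class count strains the cluster head, here is a quick back-of-the-envelope check (my own illustrative sketch, assuming roughly balanced classes; the numbers are for Market-1501 with the batch size used above):

```python
# Illustrative estimate only (assumes uniformly distributed classes).
class_num = 751    # identities in the Market-1501 training split
batch_size = 256   # as in train_data_loader above

# Expected samples per class in one batch
expected_per_class = batch_size / class_num            # ~0.34
# Probability that a given class appears at least once in a batch
p_present = 1 - (1 - 1 / class_num) ** batch_size      # ~0.29

print(f"expected samples per class per batch: {expected_per_class:.2f}")
print(f"P(class present in a batch):          {p_present:.2f}")
# With these numbers, roughly 70% of the cluster head's columns see no
# sample of "their" class in a given batch, weakening the contrast.
```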

showmeyourcodehaha commented 5 months ago

> Hello. First check whether the pretrained CC results are reasonable. For datasets with a large number of classes, consider increasing the batch size to keep the cluster-head contrast effective.

Thank you very much for your reply. You said that increasing the batch size keeps the cluster head effective; what about the instance head? In my limited understanding, increasing the batch size in contrastive learning simply adds more negative samples and thereby improves performance, so what exactly does "keeping the cluster-head contrast effective" mean? Also, since you suggest increasing the batch size when there are many classes, should cluster_num / batch_size be kept roughly constant to reach a given level of performance? Have you run any experiments on this, and what is the underlying cause of this behavior?

Yunfan-Li commented 5 months ago

The cluster head performs class-level contrastive learning in the column space: for a given column to describe its class effectively, the current batch must contain samples of that column's class. The instance head contrasts at the instance level in the row space, so its requirement on the batch size is comparatively lower. I have not run controlled experiments on that exact ratio; empirically, it suffices to ensure that, on average, each class gets a certain number of samples per batch.
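For concreteness, a minimal sketch of the column-space contrast described above (my own simplification, not the exact implementation in this repository): the soft assignments of the two views are transposed so that each of the 2K columns becomes an N-dimensional cluster representation, and matching columns across views form the positive pairs.

```python
import torch
import torch.nn.functional as F

def cluster_head_contrast(c_a, c_b, temperature=1.0):
    """Class-level contrast in column space (illustrative sketch).
    c_a, c_b: (N, K) soft cluster assignments of two augmented views."""
    n, k = c_a.shape
    # Each of the 2K columns is an N-dim representation of one cluster.
    cols = torch.cat([c_a.t(), c_b.t()], dim=0)                        # (2K, N)
    sim = F.cosine_similarity(cols.unsqueeze(1), cols.unsqueeze(0), dim=2)
    logits = sim / temperature
    # Exclude self-similarity; the positive for column i of one view is
    # column i of the other view.
    mask = torch.eye(2 * k, dtype=torch.bool, device=cols.device)
    logits = logits.masked_fill(mask, float("-inf"))
    pos = torch.cat([torch.arange(k) + k, torch.arange(k)]).to(cols.device)
    return F.cross_entropy(logits, pos)

# If no sample of class i is present in the batch, column i is near-zero in
# both views and carries no class signal, which is why a large class count
# calls for a batch size large enough to represent each class on average.
```

Note that the full cluster-level loss in the paper additionally includes an entropy regularizer on the cluster assignment distribution to avoid degenerate solutions; the sketch above only shows the column-contrast term.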