fountaindream / DLAW

Code for "Dual Level Adaptive Weighting for Cloth-changing Person Re-identification"

can't run the code #1

Open 1024AILab opened 1 year ago

1024AILab commented 1 year ago

I can't run the code when training on the PRCC dataset.

fountaindream commented 1 year ago

Maybe you can use main_prcc.py instead.

1024AILab commented 1 year ago

I ran "!python main_prcc.py --mode train --data_path /root/Simple-CCReID-main/datas/prcc/rgb", but it throws an error:

/root/miniconda3/lib/python3.8/site-packages/torchvision/transforms/transforms.py:332: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.
  warnings.warn(
Traceback (most recent call last):
  File "main_prcc.py", line 176, in <module>
    data = Data()
  File "/root/DLAW/data_prcc.py", line 35, in __init__
    self.trainset = Market1501(train_transform, 'train', opt.data_path)
  File "/root/DLAW/data_prcc.py", line 72, in __init__
    self.parsings = [path for path in self.list_pictures(self.parsing_path) if self.id(path) != -1]
  File "/root/DLAW/data_prcc.py", line 187, in list_pictures
    assert os.path.isdir(directory), 'dataset is not exists!{}'.format(directory)
AssertionError: dataset is not exists!/root/Simple-CCReID-main/datas/prcc/rgbnew_train_parsing

1024AILab commented 1 year ago

I downloaded the dataset from Google Drive and then ran the code, but it doesn't seem to work.

1024AILab commented 1 year ago

There is also an error in matching_vector.py:
"""
from __future__ import absolute_import

import torch
import torch.nn.functional as F
from torch import nn
from .clothes_detector import ClothesDetector
from .TripletLoss import TripletLoss
"""
The utils folder doesn't have a clothes_detector file or a ClothesDetector class.

fountaindream commented 1 year ago

> I downloaded the dataset from Google Drive and then ran the code, but it doesn't seem to work.

It seems like you should check the data path.
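The missing directory in the traceback is ".../prcc/rgbnew_train_parsing", with no separator between "rgb" and "new_train_parsing". That suggests the folder name is appended to --data_path by plain string concatenation, though this is an assumption, since data_prcc.py is not shown in this thread. A minimal sketch of the difference (the helper name parsing_dir is hypothetical):

```python
import os

def parsing_dir(data_path, name="new_train_parsing"):
    # os.path.join inserts the separator whether or not data_path ends with "/".
    return os.path.join(data_path, name)

root = "/root/Simple-CCReID-main/datas/prcc/rgb"

# Plain concatenation reproduces the path from the AssertionError.
print(root + "new_train_parsing")

# Joining gives the same directory with or without a trailing slash.
print(parsing_dir(root))
print(parsing_dir(root + "/"))
```

If the script does concatenate strings, passing --data_path with a trailing "/" may be enough to get past this assertion.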

1024AILab commented 1 year ago

I changed the path, not only in the command "!python main_prcc.py --mode train --data_path /root/Simple-CCReID-main/datas/prcc/rgb" but also in opt.

fountaindream commented 1 year ago

> There is also an error in matching_vector.py:
> """
> from __future__ import absolute_import
>
> import torch
> import torch.nn.functional as F
> from torch import nn
> from .clothes_detector import ClothesDetector
> from .TripletLoss import TripletLoss
> """
> The utils folder doesn't have a clothes_detector file or a ClothesDetector class.

It can be deleted.

1024AILab commented 1 year ago

And the line "from data import Data" in main_prcc.py should be "from data_prcc import Data", right?
"""
import os
import numpy as np
from scipy.spatial.distance import cdist
from tqdm import tqdm
import matplotlib

matplotlib.use('agg')
import matplotlib.pyplot as plt

import torch
from torch.optim import lr_scheduler

from opt import opt
from data import Data
from network import MGN
from loss import Loss
from utils.get_optimizer import get_optimizer
from utils.extract_feature import extract_feature
"""

fountaindream commented 1 year ago

Yep, thank you for checking this.

1024AILab commented 1 year ago

I don't know how to fix this:

/root/miniconda3/lib/python3.8/site-packages/torchvision/transforms/transforms.py:332: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.
  warnings.warn(
Traceback (most recent call last):
  File "main_prcc.py", line 176, in <module>
    data = Data()
  File "/root/DLAW/data_prcc.py", line 35, in __init__
    self.trainset = Market1501(train_transform, 'train', opt.data_path)
  File "/root/DLAW/data_prcc.py", line 75, in __init__
    self._cam2label = {_id: idx for idx, _id in enumerate(self.unique_cams)}
  File "/root/DLAW/data_prcc.py", line 181, in unique_cams
    return sorted(set(self.cameras))
  File "/root/DLAW/data_prcc.py", line 174, in cameras
    return [self.camera(path) for path in self.imgs]
  File "/root/DLAW/data_prcc.py", line 174, in <listcomp>
    return [self.camera(path) for path in self.imgs]
  File "/root/DLAW/data_prcc.py", line 139, in camera
    return int(filepath.split('/')[-1].split('_')[1])-1
ValueError: invalid literal for int() with base 10: 'cropped'
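The ValueError means that the second field of some file name is the word "cropped" rather than a camera number, so at least one image does not follow the naming the loader expects. The sketch below assumes names are split on "_" and that field 1 holds a one-based camera number, as in "213_1_cropped_rgb358.png" (a file name that appears later in this thread). The function parse_camera and the example paths are hypothetical, not part of the repository.

```python
import os

def parse_camera(filepath):
    """Return a zero-based camera index from a name like '<id>_<cam>_...'."""
    name = os.path.basename(filepath)
    fields = name.split("_")
    # Fail with the offending file name instead of a bare int() error.
    if len(fields) < 2 or not fields[1].isdigit():
        raise ValueError(
            f"unexpected file name {name!r}: field 1 is not a camera number"
        )
    return int(fields[1]) - 1

# A name in the expected layout parses cleanly.
print(parse_camera("/data/prcc/rgb/train/213_1_cropped_rgb358.png"))

# A name without the camera field is reported explicitly.
try:
    parse_camera("/data/prcc/rgb/train/213_cropped_rgb358.png")
except ValueError as err:
    print(err)
```

Running a check like this over the image folder shows which files, if any, come from a differently named copy of the dataset.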

1024AILab commented 1 year ago

import os
import numpy as np
from scipy.spatial.distance import cdist
from tqdm import tqdm
import matplotlib

matplotlib.use('agg')
import matplotlib.pyplot as plt

import torch
from torch.optim import lr_scheduler

from opt import opt
from data_prcc import Data
from network import MGN
from loss import Loss
from utils.get_optimizer import get_optimizer
from utils.extract_feature import extract_feature

os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'

torch.distributed.init_process_group(backend='nccl', init_method='tcp://localhost:23456', rank=0, world_size=1)

class Main():
    def __init__(self, model, loss, data):
        self.train_loader = data.train_loader
        self.test_loader = data.test_loader
        self.query_loader = data.query_loader
        self.testset = data.testset
        self.queryset = data.queryset

        self.model = model.cuda()
        self.loss = loss
        self.optimizer = get_optimizer(model)
        self.scheduler = lr_scheduler.MultiStepLR(self.optimizer, milestones=opt.lr_scheduler, gamma=0.1)

    def train(self):

        self.scheduler.step()

        self.model.train()
        for batch, (inputs, labels, cams, clos, parsings) in enumerate(self.train_loader):
            inputs = inputs.cuda()
            labels = labels.cuda()
            clos = clos.cuda()
            parsings = parsings.cuda()
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.loss(outputs, labels, clos, parsings)
            loss.backward()
            self.optimizer.step()

    def evaluate(self):

        self.model.eval()

        def rank(qf,ql,qc,qh,gf,gl,gc,gh):
            query = qf.view(-1,1)
            # print(query.shape)
            score = torch.mm(gf,query)
            score = score.squeeze(1).cpu()
            score = score.numpy()
            # predict index
            index = np.argsort(score)  #from small to large
            index = index[::-1]
            # index = index[0:2000]
            # good index
            query_index = np.argwhere(gl==ql)
            camera_index = np.argwhere(gc==qc)
            cloth_index = np.argwhere(gh==qh)
            junk_index = np.argwhere(gl==-1)

            ap_tmp, CMC_tmp = compute_mAP(index, query_index, junk_index)
            return ap_tmp, CMC_tmp

        def compute_mAP(index, good_index, junk_index):
            ap = 0
            cmc = torch.IntTensor(len(index)).zero_()
            if good_index.size==0:   # if empty
                cmc[0] = -1
                return ap,cmc

            # remove junk_index
            mask = np.in1d(index, junk_index, invert=True)
            index = index[mask]

            # find good_index index
            ngood = len(good_index)
            mask = np.in1d(index, good_index)
            rows_good = np.argwhere(mask==True)
            rows_good = rows_good.flatten()

            cmc[rows_good[0]:] = 1
            for i in range(ngood):
                d_recall = 1.0/ngood
                precision = (i+1)*1.0/(rows_good[i]+1)
                if rows_good[i]!=0:
                    old_precision = i*1.0/rows_good[i]
                else:
                    old_precision=1.0
                ap = ap + d_recall*(old_precision + precision)/2

            return ap, cmc

        query_feature, query_label, query_cam, query_cloth = extract_feature(self.model, tqdm(self.query_loader))
        gallery_feature, gallery_label, gallery_cam, gallery_cloth = extract_feature(self.model, tqdm(self.test_loader))
        # query_feature = query_feature.cuda()
        # gallery_feature = gallery_feature.cuda()

        #print(query_feature.shape)
        CMC = torch.IntTensor(len(gallery_label)).zero_()
        ap = 0.0
        #print(query_label)
        for i in range(len(query_label)):
            ap_tmp, CMC_tmp = rank(query_feature[i],query_label[i],query_cam[i],query_cloth[i],gallery_feature,gallery_label,gallery_cam,gallery_cloth)
            if CMC_tmp[0]==-1:
                continue
            CMC = CMC + CMC_tmp
            ap += ap_tmp
            #print(i, CMC_tmp[0])

        CMC = CMC.float()
        CMC = CMC/len(query_label) #average CMC
        print('Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f'%(CMC[0],CMC[4],CMC[9],ap/len(query_label)))

    def vis(self):

        self.model.eval()

        gallery_path = data.testset.imgs
        gallery_label = data.testset.ids

        # Extract feature
        print('extract features, this may take a few minutes')
        query_feature = extract_feature(model, tqdm([(torch.unsqueeze(data.query_image, 0), 1)]))
        gallery_feature = extract_feature(model, tqdm(data.test_loader))

        # sort images
        query_feature = query_feature.view(-1, 1)
        score = torch.mm(gallery_feature, query_feature)
        score = score.squeeze(1).cpu()
        score = score.numpy()

        index = np.argsort(score)  # from small to large
        index = index[::-1]  # from large to small

        # # Remove junk images
        # junk_index = np.argwhere(gallery_label == -1)
        # mask = np.in1d(index, junk_index, invert=True)
        # index = index[mask]

        # Visualize the rank result
        fig = plt.figure(figsize=(16, 4))

        ax = plt.subplot(1, 11, 1)
        ax.axis('off')
        plt.imshow(plt.imread(opt.query_image))
        ax.set_title('query')

        print('Top 10 images are as follow:')

        for i in range(10):
            img_path = gallery_path[index[i]]
            print(img_path)

            ax = plt.subplot(1, 11, i + 2)
            ax.axis('off')
            plt.imshow(plt.imread(img_path))
            ax.set_title(img_path.split('/')[-1][:9])

        fig.savefig("show.png")
        print('result saved to show.png')

if __name__ == '__main__':

    data = Data()
    model = MGN()
    model = torch.nn.DataParallel(model)
    # model = model.cuda()
    # model = torch.nn.parallel.DistributedDataParallel(model)
    loss = Loss()
    main = Main(model, loss, data)

    if opt.mode == 'train':

        for epoch in range(1, opt.epoch + 1):
            print('\nepoch', epoch)
            main.train()
            if epoch % 50 == 0:
                print('\nstart evaluate')
                main.evaluate()
                os.makedirs('weights/PRCC/', exist_ok=True)
                torch.save(model.state_dict(), ('weights/PRCC/model_{}.pt'.format(epoch)))

    if opt.mode == 'evaluate':
        print('start evaluate')
        model.load_state_dict(torch.load(opt.weight))
        main.evaluate()

    if opt.mode == 'vis':
        print('visualize')
        model.load_state_dict(torch.load(opt.weight))
        main.vis()

1024AILab commented 1 year ago

In loss.py, there is also an error on the line "Weighted_Triplet_Loss.data.cpu().numpy()": you need to add "'," after that expression.

1024AILab commented 1 year ago

I downloaded the newest dataset and changed the path, but there is still an error:

/root/miniconda3/lib/python3.8/site-packages/torchvision/transforms/transforms.py:332: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.
  warnings.warn(
Traceback (most recent call last):
  File "main_prcc.py", line 176, in <module>
    data = Data()
  File "/root/DLAW/data_prcc.py", line 35, in __init__
    self.trainset = Market1501(train_transform, 'train', opt.data_path)
  File "/root/DLAW/data_prcc.py", line 72, in __init__
    self.parsings = [path for path in self.list_pictures(self.parsing_path) if self.id(path) != -1]
  File "/root/DLAW/data_prcc.py", line 187, in list_pictures
    assert os.path.isdir(directory), 'dataset is not exists!{}'.format(directory)
AssertionError: dataset is not exists!/root/DLAW/data/prcc/rgb/new_train_parsing

fountaindream commented 1 year ago

Please check the path to the parsing results.

1024AILab commented 1 year ago

I get it. I put the "PRCC_train_parsing" folder into the prcc folder and renamed it "new_train_parsing".

1024AILab commented 1 year ago

Is that correct?

fountaindream commented 1 year ago

> Is that correct?

yep

1024AILab commented 1 year ago

There is another error:

/root/miniconda3/lib/python3.8/site-packages/torchvision/transforms/transforms.py:332: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.
  warnings.warn(
Traceback (most recent call last):
  File "main_prcc.py", line 177, in <module>
    model = MGN()
  File "/root/DLAW/network.py", line 42, in __init__
    self.Matching_Vector = MatchingVector(num_classes)
  File "/root/DLAW/utils/matching_vector.py", line 45, in __init__
    self.base = ClothesDetector()
NameError: name 'ClothesDetector' is not defined

1024AILab commented 1 year ago

Should I just delete "ClothesDetector"?

fountaindream commented 1 year ago

> Should I just delete "ClothesDetector"?

yes

1024AILab commented 1 year ago

Traceback (most recent call last):
  File "main_prcc.py", line 188, in <module>
    main.train()
  File "main_prcc.py", line 41, in train
    for batch, (inputs, labels, cams, clos, parsings) in enumerate(self.train_loader):
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 530, in __next__
    data = self._next_data()
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1224, in _next_data
    return self._process_data(data)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1250, in _process_data
    data.reraise()
  File "/root/miniconda3/lib/python3.8/site-packages/torch/_utils.py", line 457, in reraise
    raise exception
_pickle.UnpicklingError: Caught UnpicklingError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.8/site-packages/numpy/lib/npyio.py", line 438, in load
    return pickle.load(fid, **pickle_kwargs)
_pickle.UnpicklingError: A load persistent id instruction was encountered,
but no persistent_load function was specified.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/root/DLAW/data_prcc.py", line 93, in __getitem__
    parsing_label, label_parsing1 = self.parsing(parsing_path)
  File "/root/DLAW/data_prcc.py", line 112, in parsing
    label_parsing1 = np.load(file_path, allow_pickle=True)
  File "/root/miniconda3/lib/python3.8/site-packages/numpy/lib/npyio.py", line 440, in load
    raise pickle.UnpicklingError(
_pickle.UnpicklingError: Failed to interpret file '/root/DLAW/data/prcc/rgb/new_train_parsing/213_1_cropped_rgb358.png' as a pickle

1024AILab commented 1 year ago

I don't know how to fix it.

1024AILab commented 1 year ago

[image] Here are the folders in prcc/rgb.

fountaindream commented 1 year ago

Please download the parsing results again later and use the .npy files instead.
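The traceback above shows np.load being called on a .png file, which NumPy cannot read, so it falls back to unpickling and fails. A minimal sketch of loading a parsing map stored as .npy follows. The helper load_parsing, the array shape and the temporary file are assumptions for illustration only. The repository's own loader passes allow_pickle=True, which a plain numeric array does not need.

```python
import os
import tempfile

import numpy as np

def load_parsing(file_path):
    """Load a parsing map, refusing anything that is not a .npy file."""
    if not file_path.endswith(".npy"):
        raise ValueError(f"expected a .npy parsing file, got {file_path!r}")
    return np.load(file_path)

with tempfile.TemporaryDirectory() as tmp:
    # Stand-in for a real parsing result; the shape is arbitrary.
    path = os.path.join(tmp, "213_1_cropped_rgb358.npy")
    np.save(path, np.zeros((4, 2), dtype=np.uint8))
    parsing = load_parsing(path)

print(parsing.shape, parsing.dtype)
```

Checking the extension up front turns the confusing UnpicklingError into a message that names the wrong file type.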

1024AILab commented 1 year ago

Should I download them through the Baidu drive? Thank you~

1024AILab commented 1 year ago

Traceback (most recent call last):
  File "main_prcc.py", line 188, in <module>
    main.train()
  File "main_prcc.py", line 47, in train
    outputs = self.model(inputs)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 166, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/DLAW/network.py", line 135, in forward
    part_pd_score, matching_vector = self.Matching_Vector(x)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/DLAW/utils/matching_vector.py", line 77, in forward
    y_part, y_global, y_fore, clustering_feat_map, part_pd_score = self.base(x) # (b, 2048, 1, 1)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'MatchingVector' object has no attribute 'base'

1024AILab commented 1 year ago

AttributeError: 'MatchingVector' object has no attribute 'base'. But we deleted the line "# self.base = ClothesDetector()". I am very curious about the experiments in the paper.

1024AILab commented 1 year ago

How were the experiments in the paper carried out?

fountaindream commented 1 year ago

> AttributeError: 'MatchingVector' object has no attribute 'base'. But we deleted the line "# self.base = ClothesDetector()". I am very curious about the experiments in the paper.

I've uploaded the clothes_detector file.
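The AttributeError follows directly from deleting the assignment: the traceback shows that forward still calls self.base(x), so removing the import and the "self.base = ClothesDetector()" line removes a module the network still depends on. A plain-Python sketch of the same failure, with hypothetical stand-in classes and no PyTorch:

```python
class Detector:
    """Stand-in for the real detector; it just doubles its input."""
    def __call__(self, x):
        return x * 2

class MatchingVectorSketch:
    def __init__(self, with_base=True):
        if with_base:
            # The assignment that was deleted in the thread.
            self.base = Detector()

    def forward(self, x):
        # forward still needs self.base, whether or not __init__ set it.
        return self.base(x)

print(MatchingVectorSketch().forward(3))

try:
    MatchingVectorSketch(with_base=False).forward(3)
except AttributeError as err:
    print("AttributeError:", err)
```

With the clothes_detector file now in the repository, restoring both the import and the assignment should be the fix, rather than deleting them.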

1024AILab commented 1 year ago

[image] [image] Hello, I finished the training process, but I don't know which folder to use for vis.

1024AILab commented 1 year ago

[image] How can I see the cloth-changing Rank@1 and mAP? Thank you~
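In the evaluate code posted earlier in this thread, rank computes cloth_index but passes only the identity matches and the label -1 junk entries to compute_mAP, so the printed Rank@1 and mAP do not separate the cloth-changing case. One common convention, offered here as an assumption and not as the authors' confirmed protocol, treats gallery images with the same identity and different clothes as correct matches, and ignores those with the same identity and the same clothes. The helper split_gallery and the toy labels below are hypothetical.

```python
import numpy as np

def split_gallery(ql, qh, gl, gh):
    """Return (good_index, junk_index) for one query, cloth-changing setting.

    ql, qh: identity and clothes label of the query.
    gl, gh: arrays of gallery identity and clothes labels.
    """
    same_id = gl == ql
    same_cloth = gh == qh
    # Correct matches: same person wearing different clothes.
    good_index = np.flatnonzero(same_id & ~same_cloth)
    # Ignored: same person in the same clothes, plus label -1 distractors.
    junk_index = np.flatnonzero((same_id & same_cloth) | (gl == -1))
    return good_index, junk_index

gallery_label = np.array([7, 7, 3, -1, 7])
gallery_cloth = np.array([0, 1, 0, 0, 1])
good, junk = split_gallery(ql=7, qh=0, gl=gallery_label, gh=gallery_cloth)
print(good.tolist(), junk.tolist())
```

These two index arrays have the same meaning as the good_index and junk_index arguments of the posted compute_mAP, so they could be passed to it in place of query_index and junk_index.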