damo-cv / TransReID-SSL

Self-Supervised Pre-Training for Transformer-Based Person Re-Identification
MIT License

speed of training #5

Closed: wengdunfang closed this issue 2 years ago

wengdunfang commented 2 years ago

Hi, when I run the following command:

python -W ignore -m torch.distributed.launch --nproc_per_node=8 main_dino.py \
--arch vit_small \
--data_path /my path/LUP \
--output_dir ./log/dino/lup/vit_small_full_lup \
--height 256 --width 128 \
--crop_height 128 --crop_width 64 \
--epochs 100

the code hangs at line 153 of main_dino.py. Is this caused by loading the LUPerson dataset? It has been running for six hours.

michuanhaohao commented 2 years ago

Did you decode the LUPerson dataset to images (jpg/png)?

wengdunfang commented 2 years ago


Yes. I used the following code to read data.mdb and convert each entry into an image file. My dataset folder is:

LUP
├── images
│   └── 34_01_0341_00001865.jpg
└── CFS_list.pkl

import os
import pickle

import cv2
import lmdb
import numpy as np
import torch
from PIL import Image  # only needed if the commented-out transform branch is re-enabled
from torch.utils.data import Dataset


class PersonDataset(Dataset):
    """Reads LUPerson images stored as encoded bytes in an LMDB database."""

    def __init__(self, data_dir, key_path, transform=None):
        super(PersonDataset, self).__init__()
        self.data_dir = data_dir
        self.key_path = key_path
        self.transform = transform
        if not os.path.exists(self.data_dir):
            raise IOError('dataset dir: {} does not exist'.format(self.data_dir))
        self.load_dataset_infos()
        # Opened lazily so that each DataLoader worker gets its own LMDB handle.
        self.env = None

    def load_dataset_infos(self):
        if not os.path.exists(self.key_path):
            raise IOError('key info file: {} does not exist'.format(self.key_path))
        with open(self.key_path, 'rb') as f:
            data = pickle.load(f)
        self.keys = data['keys']
        # np.int was removed in NumPy 1.24; use an explicit integer dtype instead.
        if 'pids' in data:
            self.labels = np.array(data['pids'], np.int64)
        elif 'vids' in data:
            self.labels = np.array(data['vids'], np.int64)
        else:
            self.labels = np.zeros(len(self.keys), np.int64)
        self.num_cls = len(set(self.labels))

    def __len__(self):
        return len(self.keys)

    def _init_lmdb(self):
        self.env = lmdb.open(self.data_dir, readonly=True, lock=False,
                             readahead=False, meminit=False)

    def __getitem__(self, index):
        if self.env is None:
            self._init_lmdb()

        key = self.keys[index]
        label = self.labels[index]

        with self.env.begin(write=False) as txn:
            buf = txn.get(key.encode('ascii'))
        # Decode the stored bytes into an H x W x 3 uint8 array in BGR order (OpenCV's default).
        im = np.frombuffer(buf, dtype=np.uint8)
        im = cv2.imdecode(im, cv2.IMREAD_COLOR)
        # im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        # if self.transform is not None:
        #     im = Image.fromarray(im)
        #     im = self.transform(im)
        # else:
        #     im = im / 255.

        return im, label, key

    def __repr__(self):
        format_string = self.__class__.__name__ + '(num_imgs='
        format_string += '{:d}, num_cls={:d})'.format(len(self), self.num_cls)
        return format_string


data_dir = './lmdb'
key_path = './keys.pkl'
out_dir = './img'
# cv2.imwrite returns False instead of raising when the target folder is missing.
os.makedirs(out_dir, exist_ok=True)

image_datasets = PersonDataset(data_dir, key_path)
dataloaders = torch.utils.data.DataLoader(image_datasets, batch_size=1, shuffle=True, num_workers=0)
index = 0
for batch in dataloaders:
    # With batch_size=1 the default collate adds a leading batch dimension:
    # im is a (1, H, W, 3) uint8 tensor and key is a sequence holding one string.
    im, label, key = batch
    index += 1
    im = im.numpy()
    print(index, im.shape, im[0].shape, label, key, key[0])
    path = os.path.join(out_dir, str(key[0]) + '.jpg')
    cv2.imwrite(path, im[0])  # im[0] is still BGR, which is what cv2.imwrite expects
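A possible simplification, assuming each LMDB value is already a complete encoded image file (for example JPEG bytes): write the stored bytes straight to disk instead of decoding them with cv2.imdecode and re-encoding them with cv2.imwrite. That skips a lossy second JPEG compression and the DataLoader round trip. `export_lmdb_images` is a hypothetical helper, not part of TransReID-SSL; it takes an already opened environment so it can reuse the same `lmdb.open(...)` call as `_init_lmdb`.

```python
import os


def export_lmdb_images(env, keys, out_dir, ext=".jpg"):
    """Write the value stored under each key to out_dir/<key><ext>.

    Hypothetical helper. Assumes the values are already-encoded image files,
    so they are written verbatim rather than decoded and re-encoded.
    Returns the number of files written.
    """
    os.makedirs(out_dir, exist_ok=True)
    written = 0
    # One read-only transaction for the whole export instead of one per image.
    with env.begin(write=False) as txn:
        for key in keys:
            buf = txn.get(key.encode("ascii"))
            if buf is None:  # key not present in the database
                continue
            with open(os.path.join(out_dir, key + ext), "wb") as f:
                f.write(buf)
            written += 1
    return written
```

For example, open the environment as in `_init_lmdb`, load `keys` from keys.pkl as in `load_dataset_infos`, and call `export_lmdb_images(env, keys, './img')`. If the stored values turn out not to be JPEG, the `.jpg` extension would be misleading, so check a few exported files first.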
michuanhaohao commented 2 years ago

I have not tried loading data directly from the lmdb file. I decoded the data from the lmdb file and saved the images into a new directory, so the image paths should look like '/xxx/data/LUP/images/xxxx.jpg'. Did the log show the line printed by print(f"Data loaded: there are {len(dataset)} images.")?

You can try killing all the related processes and running again. If the error occurs again, please post the error log and the config file. To kill the processes:

sudo ps aux|grep main|grep -v grep|awk '{print $2}'|xargs kill -9
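The one-liner above filters `ps aux` output with grep, drops the grep process itself, takes the second column (the PID) with awk, and passes those PIDs to `kill -9`. A rough Python equivalent of the filtering step can be handy for checking which PIDs would be killed before actually killing anything. `pids_matching` is a hypothetical helper, and it uses a plain substring match where grep uses a regular expression.

```python
def pids_matching(ps_aux_output, pattern):
    """Mimic `ps aux | grep PATTERN | grep -v grep | awk '{print $2}'`.

    Hypothetical helper: substring match instead of grep's regex match.
    """
    pids = []
    for line in ps_aux_output.splitlines():
        # grep PATTERN keeps matching lines; grep -v grep drops any line mentioning grep.
        if pattern in line and "grep" not in line:
            # awk '{print $2}': the second whitespace-separated field is the PID in `ps aux`.
            pids.append(int(line.split()[1]))
    return pids
```

Because `grep main` matches every process whose command line contains "main", filtering on a more specific string such as `main_dino.py` reduces the risk of killing unrelated jobs.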

wengdunfang commented 2 years ago

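Thanks for your reply. I did not load the data.mdb file directly either: I used the code I posted above to convert data.mdb into images, all saved in jpg format under LUP/images, more than 4 million in total. I then ran the command below, and the code is now stuck at line 153 of main_dino.py, that is: dataset = datasets.ImageFolder(args.data_path, transform=transform). I have confirmed that args.data_path = /data/LUP, and the LUP directory contains the images folder plus a cfs_list.pkl file. This seemed a bit strange to me, so I looked into datasets.ImageFolder, and it seems it cannot find the images inside the images folder directly. So I also tried --data_path /data/LUP/images, but got the same result.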

python -W ignore -m torch.distributed.launch --nproc_per_node=8 main_dino.py \
--arch vit_small \
--data_path /data/LUP \
--output_dir ../log/dino/lup/vit_small_full_lup \
--height 256 --width 128 \
--crop_height 128 --crop_width 64 \
--epochs 3
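torchvision's `datasets.ImageFolder` treats every subdirectory of its root as one class and collects image files beneath it. The helper below is a hypothetical, stdlib-only approximation of that layout rule (not torchvision's own code), useful for previewing what a given `--data_path` yields. With /data/LUP, `images` becomes the single class and CFS_list.pkl is ignored because it is not a folder; with /data/LUP/images there are no class subfolders, which recent torchvision versions reject with FileNotFoundError. TransReID-SSL's main_dino.py may wrap ImageFolder differently, so treat this only as a sketch.

```python
import os

# Same extension list torchvision uses for ImageFolder (matched case-insensitively).
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


def preview_image_folder(root):
    """Approximate ImageFolder's layout rule: each subdirectory of root is a class.

    Hypothetical helper. Returns {class_name: number_of_image_files}.
    """
    with os.scandir(root) as it:
        classes = sorted(entry.name for entry in it if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {root}.")
    counts = {}
    for cls in classes:
        n = 0
        for _, _, files in os.walk(os.path.join(root, cls)):
            n += sum(name.lower().endswith(IMG_EXTENSIONS) for name in files)
        counts[cls] = n
    return counts
```

On the layout described in this thread, `preview_image_folder('/data/LUP')` would report a single `images` class, while `preview_image_folder('/data/LUP/images')` raises, mirroring the two `--data_path` values that were tried.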

Below is what my run printed in the shell. Because the code hangs at dataset = datasets.ImageFolder(args.data_path, transform=transform), no log file was generated, and the line print(f"Data loaded: there are {len(dataset)} images.") was never executed either:

arch: vit_small
batch_size_per_gpu: 64
clip_grad: 3.0
crop_height: 128
crop_width: 64
data_path: /data/LUP
dist_url: env://
epochs: 3
filter_path:
freeze_last_layer: 1
global_crops_scale: (0.4, 1.0)
gpu: 0
height: 256
keep_num: 1281167
local_crops_number: 8
local_crops_scale: (0.05, 0.4)
local_rank: 0
lr: 0.0005
min_lr: 1e-06
momentum_teacher: 0.996
norm_last_layer: True
num_workers: 16
optimizer: adamw
out_dim: 65536
output_dir: ../log/dino/lup/vit_small_full_lup
patch_size: 16
rank: 0
saveckp_freq: 5
seed: 0
teacher_temp: 0.04
use_bn_in_head: False
use_fp16: True
warmup_epochs: 10
warmup_teacher_temp: 0.04
warmup_teacher_temp_epochs: 0
weight_decay: 0.04
weight_decay_end: 0.4
width: 128
world_size: 8
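Since the hang happens before `print(f"Data loaded: ...")`, one way to check whether simply listing 4 million files is the bottleneck is to time that scan on its own, outside the training script. This is a minimal sketch: `time_directory_scan` is a hypothetical helper, and it counts only the top level of a directory, which matches the flat LUP/images layout described above. Also note that with `--nproc_per_node=8`, DINO-style scripts typically build the dataset in every launched process, so the same scan may run eight times in parallel against the same disk.

```python
import os
import time


def time_directory_scan(path, report_every=500_000):
    """Count regular files directly under `path` and return (count, seconds).

    Hypothetical diagnostic helper; prints progress every `report_every` files.
    """
    start = time.perf_counter()
    count = 0
    with os.scandir(path) as it:
        for entry in it:
            if entry.is_file():
                count += 1
                if report_every and count % report_every == 0:
                    elapsed = time.perf_counter() - start
                    print(f"{count} files after {elapsed:.1f}s", flush=True)
    return count, time.perf_counter() - start
```

For example, `time_directory_scan('/data/LUP/images')`. If the standalone scan is itself very slow, a common workaround is to build the file list once, save it (for example with pickle), and have the dataset read that list instead of rescanning the folder on every launch.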
wengdunfang commented 2 years ago

Could you add me on WeChat? My account is wengdunfang.