amirhszd opened this issue 1 year ago
Could you provide more details from debugging? I've never come across anything like this. The code that differs between the training and validation parts is the evaluation code starting here. I recommend debugging there first to check for errors such as OOM (out-of-memory) issues.
It seems it is not stalling; it is just extremely slow for some reason. I put a print(i) at the end of the loop, where i is the validation data loader counter. It is 9:22 AM at the moment and the validation portion started at 1 AM (8 hours and still running). I had also put print statements in model.forward() before, and that is where things take much longer.
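Since CUDA kernels are launched asynchronously, timestamps printed around model.forward() can attribute GPU work to whichever later line happens to wait for it, unless the device is synchronized before each clock read. A minimal stdlib-only sketch of a per-stage timer that accounts for this (StageTimer and the stage names are hypothetical, not part of the repo); in the script one would pass sync=torch.cuda.synchronize:

```python
import time
from collections import defaultdict
from contextlib import contextmanager


class StageTimer:
    """Accumulate wall-clock time per named stage (hypothetical debugging helper)."""

    def __init__(self, sync=None):
        # sync: optional zero-argument callable run before each clock read,
        # e.g. torch.cuda.synchronize, so that queued GPU work is counted
        # in the stage that launched it rather than in a later one.
        self.sync = sync
        self.totals = defaultdict(float)
        self.counts = defaultdict(int)

    @contextmanager
    def stage(self, name):
        if self.sync is not None:
            self.sync()
        start = time.perf_counter()
        try:
            yield
        finally:
            if self.sync is not None:
                self.sync()
            self.totals[name] += time.perf_counter() - start
            self.counts[name] += 1

    def report(self):
        """Return (name, total_s, mean_s) tuples, slowest stage first."""
        rows = sorted(self.totals.items(), key=lambda kv: kv[1], reverse=True)
        return [(name, total, total / self.counts[name]) for name, total in rows]


if __name__ == "__main__":
    timer = StageTimer()  # on GPU: StageTimer(sync=torch.cuda.synchronize)
    for _ in range(3):
        with timer.stage("load"):
            time.sleep(0.001)  # stand-in for fetching a batch
        with timer.stage("forward"):
            time.sleep(0.005)  # stand-in for model([coord, feat, offset])
    for name, total, mean in timer.report():
        print(f"{name:>8}: total {total:.4f}s, mean {mean:.4f}s")
```

Timing the batch fetch and the forward call of validate() as separate stages would show whether the hours go into data loading or into the model itself.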
I am using the DALE dataset with all the augmentation turned off (there is only one feature, intensity), and I introduced dummy class weights (all 1s). I set get_loop to 6 for this dataset (I'm not sure what it does, but I went with the default value). Memory is not the issue, since I have access to 8 RTX 3090s. My first assumption was that this might be caused by multiprocessing_distributed, but I don't think that is the case.
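The log is consistent with get_loop repeating the dataset several times per epoch: the script logs len(TRAIN_DATASET) // args.loop as 342 training samples, while the train loader runs 2052 iterations at batch_size 1, and 342 × 6 = 2052. A minimal sketch of that behavior, assuming loop simply multiplies the dataset length (LoopedDataset is hypothetical, not the repo's dataset class):

```python
class LoopedDataset:
    """Hypothetical wrapper: present a map-style dataset `loop` times per epoch."""

    def __init__(self, base, loop):
        if loop < 1:
            raise ValueError("loop must be >= 1")
        self.base = base
        self.loop = loop

    def __len__(self):
        return len(self.base) * self.loop

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # idx, idx + len(base), idx + 2 * len(base), ... all map to the same
        # underlying sample, so each sample is visited `loop` times per epoch.
        return self.base[idx % len(self.base)]


train = LoopedDataset(list(range(342)), loop=6)  # 342 samples, as in the log
print(len(train))                # 2052, the per-epoch iteration count at batch_size=1
print(len(train) // train.loop)  # 342, what the script logs as "samples"
```

The script logs the validation set the same way (len(VAL_DATASET) // args.loop = 129), yet the validation counter in the log goes past 129, which suggests the validation set may be repeated as well; if so, every validation pass processes each sample loop times.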
I am changing the arguments by editing their default values here. Please take a look:
"""
Author: Haoxi Ran
Date: 06/30/2022
"""

import json
import os
import time
from functools import partial

import numpy as np
import argparse
from pathlib import Path
import datetime

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.multiprocessing as mp
import torch.distributed as dist
from tensorboardX import SummaryWriter

from modules.aug_utils import transform_point_cloud_coord, transform_point_cloud_rgb
from util.utils import AverageMeter, intersectionAndUnionGPU, find_free_port, get_dataset_description, get_optimizer, \
    get_scheduler, get_loop, get_aug_args, get_loss, get_dataset_obj, get_rgb_stat, worker_init_fn
from util.data_util import collate_fn
from util.utils import get_model, get_class_weights, set_seed, main_process, get_logger
def parse_args():
    parser = argparse.ArgumentParser('Model')
    # for debugging set multiprocessing_distributed to False and workers to 1

    # Basic
    parser.add_argument('--log_dir', type=str, default='${log_dir}', help='experiment root')
    parser.add_argument('--log_root', type=str, default='./log', help='log root dir')
    parser.add_argument('--dataset', type=str, default='DALE', help='dataset name')
    # available models are: 'pointnet2.pointnet2_ssg', 'repsurf.repsurf_umb_ssg', 'pointtransformer.pointtransformer'
    parser.add_argument('--model', default='repsurf.repsurf_umb_ssg', help='model name [default: pointnet2.pointnet2_ssg]')
    parser.add_argument('--gpus', nargs='+', type=int, default=[0], help='GPU IDs')
    parser.add_argument('--seed', type=int, default=2000, help='Training Seed')
    parser.add_argument('--world_size', type=int, default=1, help='Number of processes participating in the job')
    parser.add_argument('--rank', type=int, default=0, help='Rank of the current process')
    parser.add_argument('--multiprocessing_distributed', action='store_false', default=False,
                        help='Whether to use multiprocessing [default: True]')
    parser.add_argument('--sync_bn', action='store_true', default=False,
                        help='Whether to use sync bn [default: False]')

    # Training
    parser.add_argument('--epoch', default=5, type=int, help='number of epoch in training [default: 100]')
    parser.add_argument('--batch_size', type=int, default=1, help='batch size in training [default: 32]')
    parser.add_argument('--workers', type=int, default=1, help='DataLoader Workers Number [default: 4]')
    parser.add_argument('--optimizer', type=str, default='AdamW', help='optimizer for training [SGD, AdamW]')
    parser.add_argument('--momentum', type=float, default=0.9, help='optimizer momentum [default: 0.9]')
    parser.add_argument('--weight_decay', type=float, default=1e-2, help='decay rate [default: 1e-2]')
    parser.add_argument('--scheduler', type=str, default='step', help='scheduler for training [step]')
    parser.add_argument('--learning_rate', default=0.006, type=float, help='init learning rate [default: 0.5]')
    parser.add_argument('--lr_decay', type=float, default=0.1, help='decay rate [default: 0.1]')
    parser.add_argument('--data_norm', type=str, default='mean', help='initializer for model [mean, min]')
    parser.add_argument('--lr_decay_epochs', type=int, default=[60, 80], nargs='+',
                        help='for step scheduler. where to decay lr, can be a list')
    parser.add_argument('--start_epoch', type=int, default=0, help='Start Training Epoch [default: 0]')
    parser.add_argument('--train_freq', type=int, default=250, help='Training frequency [default: 250]')
    parser.add_argument('--resume', type=str, default=None, help='Trained checkpoint path')
    parser.add_argument('--pretrain', type=str, default=None, help='Pretrain model path')

    # Evaluation
    parser.add_argument('--batch_size_val', type=int, default=1, help='batch size in validation [default: 4]')
    # TODO: min val is where we start looking at the validation data at epoch level
    parser.add_argument('--min_val', type=int, default=0, help='Min val epoch [default: 60]')
    parser.add_argument('--val_freq', type=int, default=1, help='Val frequency [default: 1]')
    parser.add_argument('--test_area', type=int, default=4, help='Which area to use for test [default: 5]')

    # Augmentation
    #TODO: turning off all the augmentation stuff as we have only 1 feature
    parser.add_argument('--aug_scale', action='store_true', default=False,
                        help='Whether to augment by scaling [default: False]')
    parser.add_argument('--aug_rotate', type=str, default=None,
                        help='Type to augment by rotation [pert, pert_z, rot, rot_z]')
    parser.add_argument('--aug_jitter', action='store_true', default=False,
                        help='Whether to augment by shifting [default: False]')
    parser.add_argument('--aug_flip', action='store_true', default=False,
                        help='Whether to augment by flipping [default: False]')
    parser.add_argument('--aug_shift', action='store_true', default=False,
                        help='Whether to augment by shifting [default: False]')
    parser.add_argument('--color_contrast', action='store_true', default=False,
                        help='Whether to augment by RGB contrasting [default: False]')
    parser.add_argument('--color_shift', action='store_true', default=False,
                        help='Whether to augment by RGB shifting [default: False]')
    parser.add_argument('--color_jitter', action='store_true', default=False,
                        help='Whether to augment by RGB jittering [default: False]')
    parser.add_argument('--hs_shift', action='store_true', default=False,
                        help='Whether to augment by HueSaturation shifting [default: False]')
    parser.add_argument('--color_drop', action='store_true', default=False,
                        help='Whether to augment by RGB Dropout [default: False]')

    # RepSurf
    parser.add_argument('--group_size', type=int, default=8, help='Size of umbrella group [default: 8]')
    parser.add_argument('--return_polar', action='store_true', default=False,
                        help='Whether to return polar coordinate in surface abstraction [default: False]')
    parser.add_argument('--freeze_epoch', default=5, type=int,
                        help='number of epoch to freeze repsurf [default: 1e6]')

    return parser.parse_args()
def main_worker(gpu, ngpus_per_node, argss):
    global args, best_iou
    args, best_iou = argss, 0
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend='nccl', init_method=args.dist_url, world_size=args.world_size, rank=args.rank)

    set_seed(args.seed + args.rank)

    # model
    model = get_model(args)
    if main_process(args):
        global logger, writer
        logger = get_logger(args.log_dir, args.model)
        writer = SummaryWriter(args.log_dir)
        logger.info(json.dumps(vars(args), indent=4, sort_keys=True))  # print args
        logger.info("=> creating models ...")
        logger.info("Classes: {}".format(args.num_class))
        # print num of params
        num_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logger.info('Total Number of Parameters: {} M'.format(str(float(num_param) / 1e6)[:5]))

    if args.distributed:
        torch.cuda.set_device(gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
        if args.sync_bn:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model.cuda(gpu)
        # model parallel (Note: During DDP Training, enable 'find_unused_parameters' to freeze repsurf)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu],
                                                          find_unused_parameters='repsurf' in args.model)
    else:
        # model
        model.cuda()
        model = torch.nn.DataParallel(model)

    coord_transform = transform_point_cloud_coord(args)
    rgb_transform = transform_point_cloud_rgb(args)

    dataset_obj = get_dataset_obj(args)
    rgb_mean, rgb_std = get_rgb_stat(args)
    # TODO: get rid of rgb mean and std and, or add DALE rgb stats and get rid of feature min, max
    if 'trainval' not in args.dataset:
        TRAIN_DATASET = dataset_obj(args, 'train', coord_transform, rgb_transform, rgb_mean, rgb_std, True)
        VAL_DATASET = dataset_obj(args, 'val', None, None, rgb_mean, rgb_std, False, TRAIN_DATASET.features_min, TRAIN_DATASET.features_max)
        VAL_DATASET.stop_aug = True
        if main_process(args):
            logger.info("Totally {} samples in {} set.".format(len(TRAIN_DATASET) // args.loop, 'train'))
            logger.info("Totally {} samples in {} set.".format(len(VAL_DATASET) // args.loop, 'val'))
    else:
        TRAIN_DATASET = dataset_obj(args, 'trainval', coord_transform, rgb_transform, rgb_mean, rgb_std, True)
        VAL_DATASET = None
        if main_process(args):
            logger.info("Totally {} samples in {} set.".format(len(TRAIN_DATASET) // args.loop, 'trainval'))

    train_sampler = torch.utils.data.distributed.DistributedSampler(TRAIN_DATASET) if args.distributed else None
    train_loader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=train_sampler is None,
                                               num_workers=args.workers, pin_memory=True, sampler=train_sampler,
                                               drop_last=True, collate_fn=collate_fn,
                                               worker_init_fn=partial(worker_init_fn, seed=args.seed))
    if VAL_DATASET is not None:
        val_sampler = torch.utils.data.distributed.DistributedSampler(VAL_DATASET) if args.distributed else None
        val_loader = torch.utils.data.DataLoader(VAL_DATASET, batch_size=args.batch_size_val, shuffle=False,
                                                 num_workers=args.workers, pin_memory=True, sampler=val_sampler,
                                                 collate_fn=collate_fn,
                                                 worker_init_fn=partial(worker_init_fn, seed=args.seed + 100))

    # loss
    # TODO: introduce dummy class weights
    label_weight = get_class_weights(args.description).cuda()
    criterion = get_loss(label_weight, args.ignore_label).cuda()

    # optimizer
    optimizer = get_optimizer(args, model)

    # scheduler
    scheduler = get_scheduler(args, optimizer)

    if args.resume is not None:
        if os.path.isfile(args.resume):
            if main_process(args):
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            best_iou = checkpoint['best_iou']
            if main_process(args):
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            if main_process(args):
                logger.info("=> no checkpoint found at '{}'".format(args.resume))

    if args.pretrain is not None:
        if os.path.isfile(args.pretrain):
            if main_process(args):
                logger.info("=> loading pretrained model '{}'".format(args.pretrain))
            checkpoint = torch.load(args.pretrain, map_location=lambda storage, loc: storage.cuda())
            model.load_state_dict(checkpoint['state_dict'], strict=True)
    for epoch in range(args.start_epoch, args.epoch):
        # train
        if args.distributed:
            train_sampler.set_epoch(epoch)
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(train_loader, model, criterion, optimizer, epoch)
        scheduler.step()
        epoch_log = epoch + 1
        if main_process(args):
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        # validate
        is_best = False
        if args.min_val < epoch_log and (epoch_log % args.val_freq == 0) and args.is_eval:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(val_loader, model, criterion)
            if main_process(args):
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
                is_best = mIoU_val > best_iou
                best_iou = max(best_iou, mIoU_val)

        # save model
        if is_best and main_process(args):
            filename = args.ckpt_dir + '/model_best.pth'
            # save for publish
            torch.save({'state_dict': model.state_dict()}, filename)
            # save for training
            # torch.save({'epoch': epoch_log, 'state_dict': model.state_dict(),
            #             'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(),
            #             'best_iou': best_iou, 'is_best': is_best}, filename)
            logger.info('Best validation mIoU updated to: {:.2f}'.format(best_iou * 100))

    if main_process(args):
        writer.close()
        logger.info('==>Training done!\nBest Iou: {:.2f}'.format(best_iou * 100, ))
def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.train()

    # freeze weight
    if args.freeze_epoch < epoch + 1:
        # freeze params
        for n, p in model.module.named_parameters():
            if "surface_constructor" in n and p.requires_grad:
                p.requires_grad = False

    end = time.time()
    max_iter = args.epoch * len(train_loader)
    for i, (coord, feat, target, offset) in enumerate(train_loader):  # [N, 3], [N, C], [N], [B]
        data_time.update(time.time() - end)
        coord, target, feat, offset = \
            coord.cuda(non_blocking=True), target.cuda(non_blocking=True), feat.cuda(non_blocking=True), \
            offset.cuda(non_blocking=True)

        output = model([coord, feat, offset])
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n = coord.size(0)
        if args.multiprocessing_distributed:
            loss *= n
            count = target.new_tensor([n], dtype=torch.long)
            dist.all_reduce(loss), dist.all_reduce(count)
            n = count.item()
            loss /= n

        output = output[:, 1:].max(1)[1] + 1 if 'ScanNet' in args.dataset else output.max(1)[1]  # remove unclassified label
        intersection, union, target = intersectionAndUnionGPU(output, target, args.num_class, args.ignore_label)
        if args.multiprocessing_distributed:
            dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(target)
        intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy()
        intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target)

        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        loss_meter.update(loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()

        # calculate remain time
        current_iter = epoch * len(train_loader) + i + 1
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))

        if (i + 1) % args.train_freq == 0 and main_process(args):
            logger.info('Epoch: [{}/{}][{}/{}] '
                        'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Remain {remain_time} '
                        'Loss {loss_meter.val:.4f} '
                        'Accuracy {accuracy:.2f}'.format(epoch + 1, args.epoch, i + 1, len(train_loader),
                                                         batch_time=batch_time, remain_time=remain_time,
                                                         loss_meter=loss_meter, accuracy=accuracy * 100))
        if main_process(args):
            writer.add_scalar('loss_train_batch', loss_meter.val, current_iter)
            writer.add_scalar('mIoU_train_batch', np.mean(intersection / (union + 1e-10)), current_iter)
            writer.add_scalar('mAcc_train_batch', np.mean(intersection / (target + 1e-10)), current_iter)
            writer.add_scalar('allAcc_train_batch', accuracy, current_iter)

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    iou_class = iou_class[1:] if 'ScanNet' in args.dataset else iou_class
    accuracy_class = accuracy_class[1:] if 'ScanNet' in args.dataset else accuracy_class
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    if main_process(args):
        for class_idx, class_iou in enumerate(iou_class):
            writer.add_scalar(f'class_{class_idx}_train_iou', class_iou, epoch)
    if main_process(args):
        logger.info('Train result at epoch [{}/{}]: mIoU / mAcc / OA {:.2f} / {:.2f} / {:.2f}'.format(
            epoch + 1, args.epoch, mIoU * 100, mAcc * 100, allAcc * 100))
    return loss_meter.avg, mIoU, mAcc, allAcc
def validate(val_loader, model, criterion):
    if main_process(args):
        logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>')
    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.eval()
    end = time.time()
    for i, (coord, feat, target, offset) in enumerate(val_loader):
        data_time.update(time.time() - end)
        coord, target, feat, offset = \
            coord.cuda(non_blocking=True), target.cuda(non_blocking=True), feat.cuda(non_blocking=True), \
            offset.cuda(non_blocking=True)
        with torch.no_grad():
            output = model([coord, feat, offset])
            loss = criterion(output, target)

        n = coord.size(0)
        if args.multiprocessing_distributed:
            loss *= n
            count = target.new_tensor([n], dtype=torch.long)
            dist.all_reduce(loss), dist.all_reduce(count)
            n = count.item()
            loss /= n

        choice = output[:, 1:].max(1)[1] + 1 if 'ScanNet' in args.dataset else output.max(1)[1]  # remove unclassified label
        intersection, union, target = intersectionAndUnionGPU(choice, target, args.num_class, args.ignore_label)
        if args.multiprocessing_distributed:
            dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(target)
        intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy()
        intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target)

        loss_meter.update(loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()
        print(f"{i} done")

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    # remove unlabeled class
    iou_class = iou_class[1:] if 'ScanNet' in args.dataset else iou_class
    accuracy_class = accuracy_class[1:] if 'ScanNet' in args.dataset else accuracy_class
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

    if main_process(args):
        logger.info('Val result: mIoU / mAcc / OA {:.2f} / {:.2f} / {:.2f}'.format(
            mIoU * 100, mAcc * 100, allAcc * 100))
        logger.info('Val loss: {:.4f}'.format(loss_meter.avg))
        for i in range(len(iou_class)):
            logger.info('Class_{} Result: IoU / Acc {:.2f}/{:.2f}'.format(
                i, iou_class[i] * 100, accuracy_class[i] * 100))
        logger.info('<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<')

    return loss_meter.avg, mIoU, mAcc, allAcc
if __name__ == '__main__':
    import gc

    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in args.gpus])
    if 'A6000' in torch.cuda.get_device_name(0):
        os.environ["NCCL_P2P_DISABLE"] = '1'

    args.dist_url = 'tcp://localhost:8888'
    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])
    args.distributed = args.world_size > 1 or args.multiprocessing_distributed
    args.ngpus_per_node = len(args.gpus)
    if len(args.gpus) == 1:
        args.sync_bn = False
        args.distributed = False
        args.multiprocessing_distributed = False

    experiment_dir = Path(os.path.join(args.log_root, 'PointAnalysis', 'log', args.dataset.split('_')[0]))
    if args.log_dir is None:
        timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
        experiment_dir = experiment_dir.joinpath(timestr)
    else:
        experiment_dir = experiment_dir.joinpath(args.log_dir)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    args.ckpt_dir = str(checkpoints_dir)
    log_dir = experiment_dir.joinpath('logs/')
    args.log_dir = str(log_dir)

    dataset_obj = get_dataset_obj(args)
    args.loop = get_loop(args)
    #TODO: need to write something similar for DALES
    if args.dataset == 'S3DIS':
        args.num_class, args.voxel_max, args.voxel_size, args.in_channel, args.ignore_label = \
            13, 80000, 0.04, 6, 255
        args.data_dir = './data/S3DIS/trainval_fullarea'
        dataset_obj(args, 'train')
        dataset_obj(args, 'val')
    elif args.dataset == 'DALE':  # TODO in_channel here defines the model
        args.num_class, args.voxel_max, args.voxel_size, args.in_channel, args.ignore_label = \
            9, 10000, 0.04, 4, 0
        args.data_dir = '/home/axhcis/Projects/NGA/data/DL/DALE_chunks/'
        dataset_obj(args, 'train')
        dataset_obj(args, 'val')
    elif args.dataset == 'ScanNet_train':
        args.num_class, args.voxel_max, args.voxel_size, args.in_channel, args.ignore_label = \
            21, 120000, 0.02, 6, 0
        args.data_dir = './data/ScanNet'
        dataset_obj(args, 'train')
        dataset_obj(args, 'val')
    elif args.dataset == 'ScanNet_trainval':
        args.num_class, args.voxel_max, args.voxel_size, args.in_channel, args.ignore_label = \
            21, 120000, 0.02, 6, 0
        args.data_dir = './data/ScanNet'
        dataset_obj(args, 'trainval')
    else:
        raise Exception('Not Impl. Dataset')

    args.is_eval = 'trainval' not in args.dataset
    args.aug_args = get_aug_args(args)
    args.description = get_dataset_description(args)
    print('Train Model on %s' % args.description)

    if args.multiprocessing_distributed:
        port = find_free_port()
        args.dist_url = f"tcp://localhost:{port}"
        args.world_size = args.ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args.ngpus_per_node, args))
    else:
        main_worker(args.gpus, args.ngpus_per_node, args)
This is the log output:
[2023-07-29 01:08:49,084 INFO line 135 2146007] Total Number of Parameters: 0.976 M
[2023-07-29 01:08:56,084 INFO line 165 2146007] Totally 342 samples in train set.
[2023-07-29 01:08:56,084 INFO line 166 2146007] Totally 129 samples in val set.
[2023-07-29 01:10:38,784 INFO line 330 2146007] Epoch: [1/5][250/2052] Batch 0.401 (0.411) Remain 01:08:32 Loss 0.2491 Accuracy 99.58
[2023-07-29 01:12:21,116 INFO line 330 2146007] Epoch: [1/5][500/2052] Batch 0.398 (0.410) Remain 01:06:42 Loss 0.9422 Accuracy 88.42
[2023-07-29 01:14:03,041 INFO line 330 2146007] Epoch: [1/5][750/2052] Batch 0.412 (0.409) Remain 01:04:52 Loss 0.0555 Accuracy 99.46
[2023-07-29 01:15:45,270 INFO line 330 2146007] Epoch: [1/5][1000/2052] Batch 0.412 (0.409) Remain 01:03:09 Loss 0.7146 Accuracy 97.51
[2023-07-29 01:17:27,351 INFO line 330 2146007] Epoch: [1/5][1250/2052] Batch 0.387 (0.409) Remain 01:01:25 Loss 1.3645 Accuracy 64.15
[2023-07-29 01:19:09,220 INFO line 330 2146007] Epoch: [1/5][1500/2052] Batch 0.418 (0.409) Remain 00:59:40 Loss 0.6405 Accuracy 97.29
[2023-07-29 01:20:51,436 INFO line 330 2146007] Epoch: [1/5][1750/2052] Batch 0.397 (0.409) Remain 00:57:58 Loss 0.9489 Accuracy 0.44
[2023-07-29 01:22:33,614 INFO line 330 2146007] Epoch: [1/5][2000/2052] Batch 0.380 (0.409) Remain 00:56:16 Loss 0.7478 Accuracy 80.15
[2023-07-29 01:22:54,613 INFO line 349 2146007] Train result at epoch [1/5]: mIoU / mAcc / OA 15.71 / 20.07 / 74.64
[2023-07-29 01:22:54,617 INFO line 355 2146007] >>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>
0 done
1 done
2 done
3 done
4 done
5 done
6 done
7 done
8 done
9 done
10 done
11 done
12 done
13 done
14 done
15 done
16 done
17 done
18 done
19 done
20 done
21 done
22 done
23 done
24 done
25 done
26 done
27 done
28 done
29 done
30 done
31 done
32 done
33 done
34 done
35 done
36 done
37 done
38 done
39 done
40 done
41 done
42 done
43 done
44 done
45 done
46 done
47 done
48 done
271 done
272 done
273 done
274 done
275 done
276 done
277 done
I think I figured this out: it seems the voxelization is applied for train_loader when fetching the data, but not for val_loader. I set voxel_max = 10k, so the training chunks are 10K points, but in my validation split each chunk has about 1 million points, which explains it. These full chunks get passed to the model, and that's why it takes so much time. Any idea whether the data could be chopped into smaller chunks within the workflow? I can partition each file, which I think would help a lot with parallelization.
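One way to see why training and validation chunks differ so much in size is to reproduce grid sampling offline. A minimal NumPy sketch of voxel-grid subsampling that keeps one point per occupied voxel, assuming XYZ in the first three columns and voxel_size = 0.04 as in the script's DALE settings (grid_sample is a hypothetical helper; the repo's own voxelization may choose which point to keep differently). Grid sampling reduces density but does not cap the number of points, so a large, sparse chunk can stay large after it:

```python
import numpy as np


def grid_sample(coord, voxel_size):
    """Keep one point per occupied voxel (hypothetical helper, not the repo's).

    coord: (N, 3) array of XYZ coordinates.
    Returns the indices of the kept points in ascending order; for each voxel
    the first point encountered in the input is kept.
    """
    # Integer voxel coordinates relative to the cloud's minimum corner.
    grid = np.floor((coord - coord.min(axis=0)) / voxel_size).astype(np.int64)
    # With return_index=True, np.unique reports the first occurrence of each row.
    _, first_idx = np.unique(grid, axis=0, return_index=True)
    return np.sort(first_idx)


rng = np.random.default_rng(0)
chunk = rng.uniform(0.0, 40.0, size=(200_000, 3))  # synthetic, sparse 40 m cube
keep = grid_sample(chunk, voxel_size=0.04)
print(chunk.shape[0], "->", keep.shape[0])
```

Comparing keep.shape[0] with voxel_max for a few validation files shows how far above the training size they are.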
I still have the same problem even with smaller chunks: 24 training files take seconds, but 24 validation files take 25 minutes.
Thank you for the detailed info.
Here are two things that I think may cause this problem:
The more likely one concerns the manner of voxelization/grid sampling. The main preprocessing difference between training and validation is the 'dropping points when overflow' operation here. By default, I disabled it during validation, since the resulting number of input points is acceptable in terms of forward speed. However, a large number of input points can greatly slow down the forward pass.
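The 'dropping points when overflow' step itself is not shown in this thread. One common form of it, used in Point Transformer-style data pipelines, keeps only the voxel_max points nearest to a randomly chosen center whenever a cloud exceeds voxel_max. A minimal NumPy sketch under that assumption (crop_to_max is hypothetical, not the repo's code):

```python
import numpy as np


def crop_to_max(coord, voxel_max, rng=None, center_idx=None):
    """Return indices of at most voxel_max points nearest to a chosen center.

    If the cloud already has <= voxel_max points, every index is returned.
    Otherwise a center point is chosen (randomly unless center_idx is given)
    and the voxel_max points closest to it are kept; the rest are dropped.
    """
    n = coord.shape[0]
    if voxel_max is None or n <= voxel_max:
        return np.arange(n)
    if center_idx is None:
        rng = np.random.default_rng() if rng is None else rng
        center_idx = int(rng.integers(n))
    sq_dist = np.sum((coord - coord[center_idx]) ** 2, axis=1)
    # argpartition finds the voxel_max smallest distances without a full sort;
    # sorting the result gives a stable, ascending index order.
    return np.sort(np.argpartition(sq_dist, voxel_max - 1)[:voxel_max])


rng = np.random.default_rng(0)
chunk = rng.uniform(0.0, 50.0, size=(200_000, 3))
keep = crop_to_max(chunk, voxel_max=10_000, rng=rng)
print(chunk.shape[0], "->", keep.shape[0])  # 200000 -> 10000
```

If this is enabled for validation, the dropped points receive no prediction, so the validation metrics then cover only the kept subset of each chunk.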
I recommend reducing the number of input points, either by enabling the 'dropping points when overflow' operation if your data is dense enough (just change the code here by removing split != 'val'), or by chunking each scene into blocks before running the code if the scanned scenes are quite large and contain many complex objects (like a city with many pedestrians, trees and houses).
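For the chunking option, a scene can be split into square XY tiles offline, before running the code, so that each validation sample is closer to the training chunk size. A minimal NumPy sketch, assuming XYZ in the first three columns and a block_size chosen by hand so that a tile holds roughly voxel_max points (split_into_blocks is a hypothetical helper, not part of the repo):

```python
import numpy as np


def split_into_blocks(coord, block_size):
    """Group point indices into square XY tiles of side block_size.

    Returns a dict mapping (ix, iy) tile index -> array of point indices.
    Every point lands in exactly one tile; Z is ignored, which suits
    large, mostly flat outdoor scenes such as aerial scans.
    """
    if coord.shape[0] == 0:
        return {}
    xy = coord[:, :2]
    tile = np.floor((xy - xy.min(axis=0)) / block_size).astype(np.int64)
    # Sort by (ix, iy) so the points of each tile become contiguous.
    order = np.lexsort((tile[:, 1], tile[:, 0]))
    tile_sorted = tile[order]
    # A new block starts wherever the tile index changes.
    change = np.any(np.diff(tile_sorted, axis=0) != 0, axis=1)
    starts = np.concatenate(([0], np.nonzero(change)[0] + 1))
    ends = np.concatenate((starts[1:], [len(order)]))
    return {
        (int(tile_sorted[s, 0]), int(tile_sorted[s, 1])): order[s:e]
        for s, e in zip(starts, ends)
    }


rng = np.random.default_rng(0)
scene = rng.uniform(0.0, 100.0, size=(50_000, 3))  # synthetic 100 m x 100 m scene
blocks = split_into_blocks(scene, block_size=25.0)
print(len(blocks))                               # 16 tiles of 25 m x 25 m
print(sum(len(idx) for idx in blocks.values()))  # 50000, every point kept once
```

Each index array can then be saved as its own file; a tile that still holds far more than voxel_max points can be split again with a smaller block_size.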
The key to this problem is to reduce the number of input points as much as possible while having little or no impact on the prediction results.
Thank you for the great code! I have set up your code and am using it with my own dataset. The validation part takes significantly longer than training and never returns; the process gets terminated without any error. For simplicity, I am using 1 GPU, with multiprocessing_distributed off and batch_number = 1. Any ideas?