BUTSpeechFIT / EEND


Extensive memory consumption in infer.py #2

Open Jamiroquai88 opened 2 years ago

Jamiroquai88 commented 2 years ago

Hey, I have been using the codebase for quite some time and I noticed excessive memory consumption in the infer.py script. In my case it uses around 120 GB of RAM, which is a big deal: my machine has 2 TB of RAM, but I need to run several instances in parallel.

I ran a simple line-by-line RAM profiler, and this is the output:


Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   207  411.480 MiB  411.480 MiB           1   @profile
   208                                         def main():
   209  411.594 MiB    0.113 MiB           1       args = parse_arguments()
   210
   211                                             # For reproducibility
   212  411.602 MiB    0.008 MiB           1       torch.manual_seed(args.seed)
   213  411.602 MiB    0.000 MiB           1       torch.cuda.manual_seed(args.seed)
   214  411.602 MiB    0.000 MiB           1       torch.cuda.manual_seed_all(args.seed)  # if you are using multi-GPU.
   215  411.602 MiB    0.000 MiB           1       np.random.seed(args.seed)  # Numpy module.
   216  411.602 MiB    0.000 MiB           1       random.seed(args.seed)  # Python random module.
   217  411.602 MiB    0.000 MiB           1       torch.manual_seed(args.seed)
   218  411.602 MiB    0.000 MiB           1       torch.backends.cudnn.benchmark = False
   219  411.602 MiB    0.000 MiB           1       torch.backends.cudnn.deterministic = True
   220  411.602 MiB    0.000 MiB           1       os.environ['PYTHONHASHSEED'] = str(args.seed)
   221
   222  411.602 MiB    0.000 MiB           1       logging.info(args)
   223
   224  483.664 MiB  483.664 MiB           1       infer_loader = get_infer_dataloader(args)
   225
   226  483.664 MiB    0.000 MiB           1       if args.gpu >= 1:
   227                                                 gpuid = use_single_gpu(args.gpu)
   228                                                 logging.info('GPU device {} is used'.format(gpuid))
   229                                                 args.device = torch.device("cuda")
   230                                             else:
   231  483.664 MiB    0.000 MiB           1           gpuid = -1
   232  483.664 MiB    0.000 MiB           1           args.device = torch.device("cpu")
   233
   234  483.664 MiB    0.000 MiB           1       assert args.estimate_spk_qty_thr != -1 or \
   235                                                 args.estimate_spk_qty != -1, \
   236                                                 ("Either 'estimate_spk_qty_thr' or 'estimate_spk_qty' "
   237                                                  "arguments have to be defined.")
   238  483.664 MiB    0.000 MiB           1       if args.estimate_spk_qty != -1:
   239  483.664 MiB    0.000 MiB           3           out_dir = join(args.rttms_dir, f"spkqty{args.estimate_spk_qty}_\
   240  483.664 MiB    0.000 MiB           2               thr{args.threshold}_median{args.median_window_length}")
   241                                             elif args.estimate_spk_qty_thr != -1:
   242                                                 out_dir = join(args.rttms_dir, f"spkqtythr{args.estimate_spk_qty_thr}_\
   243                                                     thr{args.threshold}_median{args.median_window_length}")
   244
   245  510.508 MiB   26.844 MiB           1       model = get_model(args)
   246
   247  812.148 MiB  301.641 MiB           2       model = average_checkpoints(
   248  510.508 MiB    0.000 MiB           1           args.device, model, args.models_path, args.epochs)
   249  812.148 MiB    0.000 MiB           1       model.eval()
   250
   251  812.148 MiB    0.000 MiB           2       out_dir = join(
   252  812.148 MiB    0.000 MiB           1           args.rttms_dir,
   253  812.148 MiB    0.000 MiB           1           f"epochs{args.epochs}",
   254  812.148 MiB    0.000 MiB           1           f"timeshuffle{args.time_shuffle}",
   255  812.148 MiB    0.000 MiB           1           (f"spk_qty{args.estimate_spk_qty}_"
   256                                                     f"spk_qty_thr{args.estimate_spk_qty_thr}"),
   257  812.148 MiB    0.000 MiB           1           f"detection_thr{args.threshold}",
   258  812.148 MiB    0.000 MiB           1           f"median{args.median_window_length}",
   259  812.148 MiB    0.000 MiB           1           "rttms"
   260                                             )
   261  812.297 MiB    0.148 MiB           1       Path(out_dir).mkdir(parents=True, exist_ok=True)
   262
   263 34718.992 MiB -32000.816 MiB           4       for i, batch in enumerate(infer_loader):
   264 34689.719 MiB   48.301 MiB           3           input = torch.stack(batch['xs']).to(args.device)
   265 34689.719 MiB    0.000 MiB           3           name = batch['names'][0]
   266 34689.719 MiB    0.000 MiB           3           with torch.no_grad():
   267 34749.977 MiB 1166.988 MiB           3               y_pred = model.estimate_sequential(input, args)[0]
   268 34749.977 MiB 13027.621 MiB           6           post_y = postprocess_output(
   269 34749.977 MiB -32767.594 MiB           3               y_pred, args.subsampling,
   270 34749.977 MiB -32767.594 MiB           3               args.threshold, args.median_window_length)
   271 34749.977 MiB -32767.594 MiB           3           rttm_filename = join(out_dir, f"{name}.rttm")
   272 34749.977 MiB -32767.594 MiB           3           with open(rttm_filename, 'w') as rttm_file:
   273 34749.977 MiB 45795.215 MiB           3               hard_labels_to_rttm(post_y, name, rttm_file)

So it looks like the bulk of the memory consumption is coming from infer_loader. Any ideas on how we could improve this?

I am running on fairly long recordings (about 1 hour each).
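The per-line increments reported inside the loop above (including the negative ones) are hard to interpret. One way to double-check which call is responsible is to print the process's peak resident memory around each step. Below is a minimal sketch using the standard-library `resource` module (Unix only). The helper names `peak_rss_mib` and `report` are hypothetical and not part of the EEND codebase.

```python
import resource
import sys


def peak_rss_mib() -> float:
    """Peak resident set size of this process so far, in MiB.

    ru_maxrss is reported in kilobytes on Linux and in bytes on macOS.
    """
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    if sys.platform == "darwin":
        return peak / (1024 * 1024)
    return peak / 1024


def report(label: str) -> None:
    # The peak only grows, so a jump between two consecutive reports
    # points at the step that allocated the most memory.
    print(f"{label}: peak RSS {peak_rss_mib():.1f} MiB")
```

For example, calling `report("before estimate_sequential")` and `report("after estimate_sequential")` around the model call in the inference loop would show whether the jump happens there. Because `ru_maxrss` only records the peak, this approach cannot show memory being released afterwards.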

Jamiroquai88 commented 2 years ago

OK, after some debugging, it looks like the memory consumption comes from

y_pred = model.estimate_sequential(input, args)[0]

I guess this is because the audio is quite long. Would it be possible to do some kind of chunking?

fnlandini commented 2 years ago

Hi, thank you for the detailed information, and sorry for the delay. I was aware that inference could take quite a lot of memory but had not run a profiler, so thank you for that. I would need some time to understand where the main problem is and how to solve it (if possible), and at the moment I do not have that time. I will keep the issue open so I can look at it in the future.

In the meantime, chunking has been explored in "Advances in integration of end-to-end neural and clustering-based diarization for real conversational speech" and "Towards Neural Diarization for Unlimited Numbers of Speakers Using Global and Local Attractors". Neither approach is implemented here, and adding one would require some extra effort. Folks at NTT have an implementation, so perhaps you could try something along those lines.
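The sketch below shows a naive form of chunking. It is not the method of either paper. It runs the model on overlapping fixed-length chunks, so the attention cost depends on the chunk length rather than on the recording length. It then aligns each chunk's speaker order to the output so far by brute-force matching on the overlapping frames. The sketch makes three assumptions. First, `infer_fn` is a hypothetical wrapper around the model; it could wrap `estimate_sequential`, but the exact inputs and outputs of that call are not checked here. Second, every chunk returns the same number of speaker columns. Third, each speaker is active somewhere in the overlap. When a speaker is silent in the overlap, or appears in only some chunks, this alignment breaks down. That is the problem the clustering-based and global-attractor approaches above are designed to address.

```python
import itertools

import numpy as np


def best_permutation(ref: np.ndarray, hyp: np.ndarray) -> list:
    """Column order of hyp that best agrees with ref on the shared frames.

    Brute force over all permutations; fine for a handful of speakers.
    """
    n_spk = ref.shape[1]
    best, best_score = list(range(n_spk)), -np.inf
    for perm in itertools.permutations(range(n_spk)):
        score = float(np.sum(ref * hyp[:, list(perm)]))
        if score > best_score:
            best, best_score = list(perm), score
    return best


def chunked_inference(feats: np.ndarray, infer_fn, chunk_len: int,
                      overlap: int) -> np.ndarray:
    """Run infer_fn on overlapping chunks and stitch the outputs together.

    feats: (T, D) features. infer_fn (hypothetical): maps (t, D) features to
    (t, S) speaker activities, with the same S for every chunk. Frames in an
    overlap keep the values from the earlier chunk.
    """
    assert 0 < overlap < chunk_len
    step = chunk_len - overlap
    start = 0
    out = infer_fn(feats[:chunk_len])
    while start + chunk_len < len(feats):
        start += step
        probs = infer_fn(feats[start:start + chunk_len])
        shared = len(out) - start  # frames already covered by earlier chunks
        perm = best_permutation(out[start:], probs[:shared])
        # Reorder this chunk's speakers to match, then append only new frames.
        out = np.concatenate([out, probs[shared:][:, perm]], axis=0)
    return out
```

`chunk_len` and `overlap` are measured in feature frames. A larger overlap makes the speaker alignment more reliable at the cost of extra computation.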

With the current version of the code, a compromise could be to increase the subsampling parameter. You will lose temporal resolution in the output, but depending on the scenario this may not affect the final result much. We compared 50 ms and 100 ms resolution in Table 4, and there is a noticeable impact when using fine-tuning, but that was when evaluating with a 0 ms collar. In a more forgiving setup, the difference may not be that large.
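The following back-of-the-envelope estimate shows why a coarser subsampling helps so much on 1-hour recordings. It assumes a 10 ms feature frame shift, so that subsampling 5 and 10 correspond to the 50 ms and 100 ms resolutions mentioned above. It also assumes that the self-attention scores are materialized as a dense T × T float32 matrix per head and per layer. Neither assumption has been checked against this codebase. The real peak also includes softmax outputs and the other heads and layers, so treat this figure as a rough lower bound.

```python
GIB = 1024 ** 3


def attention_scores_bytes(duration_s: int, frame_shift_ms: int,
                           subsampling: int, bytes_per_elem: int = 4) -> int:
    """Bytes for one dense T x T attention-score matrix (one head, one layer).

    Assumes float32 scores (4 bytes) and T = number of subsampled frames.
    """
    input_frames = duration_s * 1000 // frame_shift_ms  # feature frames
    t = input_frames // subsampling  # frames seen by the encoder
    return t * t * bytes_per_elem


for subsampling in (5, 10):
    size = attention_scores_bytes(3600, 10, subsampling)
    print(f"subsampling={subsampling}: {size / GIB:.1f} GiB per head per layer")
```

For a 1-hour recording, this prints about 19.3 GiB for subsampling 5 and 4.8 GiB for subsampling 10. Doubling the subsampling halves T and therefore cuts this quadratic term by a factor of four, which is why it is an effective, if lossy, workaround.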