atong01 / conditional-flow-matching

TorchCFM: a Conditional Flow Matching library
https://arxiv.org/abs/2302.00482
MIT License

added multithreading to OTPlanSampler for "exact" solver #131

Closed yashizhang closed 3 months ago

yashizhang commented 3 months ago

What does this PR do?

This PR adds OpenMP multithreading support to OTPlanSampler: a num_threads argument can now be passed at initialization when method == "exact". For details, see Python OT's documentation of the pot.emd method.

On my HPC node (AMD EPYC 7742 CPU, A100 GPU), running with num_threads=2 or num_threads=4 gave up to a 2x speedup. Since this is hardware dependent, the default stays at num_threads=1, which matches the previous behavior.

If you would like to test threading performance yourself, here's the code that I used:

import argparse
import time

import numpy as np
import torch
from tqdm import tqdm

from torchcfm.optimal_transport import OTPlanSampler
from torchcfm.utils import sample_8gaussians, sample_moons

def parse_args():
   parser = argparse.ArgumentParser()
   parser.add_argument("--num_samples", type=int, default=500)
   parser.add_argument("--num_rounds", type=int, default=50)
   return parser.parse_args()

if __name__ == "__main__":
   args = parse_args()

   samplers = {}
   for num_threads in [1, 2, 4, 8, 16, 32, 64]:
       samplers[num_threads] = OTPlanSampler(method="exact", num_threads=num_threads)
   #samplers[-1] = OTPlanSampler(method="exact", num_threads="max")

   torch.manual_seed(42)
   np.random.seed(42)

   samples = []
   for _ in range(args.num_rounds):
       x0 = sample_8gaussians(args.num_samples)
       x1 = sample_moons(args.num_samples)
       samples.append((x0, x1))

   # Test with POT implementation
   times = {}
   for num_threads, sampler in samplers.items():
       # perf_counter is a monotonic, high-resolution clock suited to benchmarking
       start = time.perf_counter()
       for x0, x1 in tqdm(samples):
           sampler.get_map(x0, x1)
       times[num_threads] = time.perf_counter() - start
   print("="*50)
   print(f"Using POT Implementation, Num. Samples: {args.num_samples}, Num. Rounds: {args.num_rounds}") 
   print("="*50)
   for num_threads, _time in times.items():
       print(f"Num. Threads: {num_threads}, Time: {_time:.2f}s")
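
If the totals above turn out to be noisy on a shared machine, a per-round variant may help. This is a minimal sketch using only the standard library; time_rounds is a hypothetical helper (not part of torchcfm or POT), and the workload in the example is a stand-in for a call such as sampler.get_map(x0, x1).

```python
import statistics
import time
from typing import Callable, List


def time_rounds(fn: Callable[[], object], num_rounds: int = 5) -> List[float]:
    """Return the wall-clock duration of each call to fn, in seconds."""
    durations = []
    for _ in range(num_rounds):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations


if __name__ == "__main__":
    # Stand-in workload; in the benchmark above this would be something like
    # lambda: sampler.get_map(x0, x1)
    durations = time_rounds(lambda: sum(i * i for i in range(100_000)))
    print(f"median: {statistics.median(durations):.4f}s, min: {min(durations):.4f}s")
```

Reporting the median or minimum per round, rather than one total, makes a single slow round less likely to dominate the comparison between thread counts.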


yashizhang commented 3 months ago

I forgot that Python versions below 3.10 do not support the type | type union syntax, so my type hint int | str needed to change. I added from typing import Union and switched the hint to Union[int, str].
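
For reference, a minimal sketch of a hint that works on older Python versions. The helper name resolve_num_threads is hypothetical and is not the torchcfm implementation; it only illustrates accepting either an int or the string "max" (the value that appears in the commented-out benchmark line) through Union[int, str].

```python
import os
from typing import Union


def resolve_num_threads(num_threads: Union[int, str] = 1) -> int:
    """Turn an int-or-"max" setting into a concrete positive thread count.

    Hypothetical helper for illustration only; not the torchcfm implementation.
    """
    if isinstance(num_threads, str):
        if num_threads != "max":
            raise ValueError(f"unsupported num_threads string: {num_threads!r}")
        # os.cpu_count() may return None, so fall back to a single thread.
        return os.cpu_count() or 1
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(num_threads, bool) or not isinstance(num_threads, int):
        raise TypeError("num_threads must be an int or the string 'max'")
    if num_threads < 1:
        raise ValueError("num_threads must be >= 1")
    return num_threads
```

Union[int, str] and int | str describe the same type; the former is simply the spelling that also runs on interpreters older than 3.10.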

atong01 commented 3 months ago

LGTM thank you for the PR.

Just going to put the timing information from your system here for reference.

AMD EPYC 7742 CPU, A100 GPU

Using POT Implementation, Num. Samples: 500, Num. Rounds: 50

Num. Threads: 1, Time: 14.23s
Num. Threads: 2, Time: 7.58s
Num. Threads: 4, Time: 10.67s
Num. Threads: 8, Time: 16.80s
Num. Threads: 16, Time: 23.88s
Num. Threads: 32, Time: 42.76s
Num. Threads: 64, Time: 108.10s
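
As a quick way to read these numbers, the sketch below computes the speedup of each setting relative to the single-threaded run. The times are copied by hand from the results above; nothing here calls torchcfm or POT.

```python
# Wall-clock times (seconds) copied from the results above.
times = {1: 14.23, 2: 7.58, 4: 10.67, 8: 16.80, 16: 23.88, 32: 42.76, 64: 108.10}

baseline = times[1]
# Speedup > 1 means faster than one thread; < 1 means slower.
speedups = {n: baseline / t for n, t in times.items()}
best = max(speedups, key=speedups.get)

for n, s in speedups.items():
    print(f"Num. Threads: {n}, Speedup: {s:.2f}x")
print(f"Fastest setting: num_threads={best}")
```

On this machine only 2 and 4 threads beat the single-threaded run; from 8 threads upward the solver is slower than with one thread, which supports keeping num_threads=1 as the default.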