mle-infrastructure / mle-toolbox

Lightweight Tool to Manage Distributed ML Experiments 🛠
https://mle-infrastructure.github.io/mle_toolbox/toolbox/
MIT License

JAX environment variable setup helper #34

Closed: RobertTLange closed this issue 3 years ago

RobertTLange commented 3 years ago

Add a helper function to set up environment variables for JAX. This includes GPU memory preallocation, CPU cores for threading/multiprocessing, the number of visible devices, and TPU XLA bridge setup. Something along these lines:

import os

def get_jax_os_ready(num_devices, mem_prealloc_bool, mem_prealloc_fraction, device_type):
    # Call before JAX initializes its backends: JAX/XLA read these variables at backend creation
    # Number of XLA devices on the host (CPU) platform - on CPU this helps with pmap testing
    os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={num_devices}'
    # Set environment variables for memory preallocation on GPU
    if device_type == "GPU":
        # os.environ only accepts strings, so convert the bool to "true"/"false"
        os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = str(mem_prealloc_bool).lower()
        os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = f'{mem_prealloc_fraction}'
    # Set environment variables for TPU usage (Colab TPU runtime)
    elif device_type == "TPU":
        import jax.tools.colab_tpu
        jax.tools.colab_tpu.setup_tpu()
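
A possible variant, sketched under assumptions: the names jax_env_vars and setup_jax_env are hypothetical, the TPU branch is left out because it calls into JAX itself, and the 0.75 default mirrors JAX's documented default GPU preallocation fraction. Building the variables as a plain dict of strings keeps the logic testable without touching the process environment.

```python
import os


def jax_env_vars(num_devices: int = 1,
                 device_type: str = "CPU",
                 preallocate: bool = True,
                 mem_fraction: float = 0.75) -> dict:
    """Return JAX/XLA environment variables as a dict of strings.

    Hypothetical helper: it has no side effects, so it is easy to unit test.
    """
    if num_devices < 1:
        raise ValueError("num_devices must be >= 1")
    env = {
        # Only affects the host (CPU) platform: exposes `num_devices` XLA CPU devices
        "XLA_FLAGS": f"--xla_force_host_platform_device_count={num_devices}",
    }
    if device_type.upper() == "GPU":
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        # os.environ values must be strings, not bools/floats
        env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
        env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(mem_fraction)
    return env


def setup_jax_env(**kwargs) -> None:
    """Apply jax_env_vars(**kwargs) to os.environ; call before JAX initializes."""
    os.environ.update(jax_env_vars(**kwargs))
```

Example: setup_jax_env(num_devices=4) followed by import jax should make jax.local_device_count() report 4 on a CPU-only machine, as long as nothing has initialized JAX earlier in the process.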

Potentially also add a visible-devices option, either here or in a separate helper. In principle, the JAX setup function should run whenever model_type in log_config is set to JAX, but we could also leave it optional.
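
For the visible-devices option and the log_config trigger, a minimal sketch could look like this. It assumes log_config is a plain dict with a model_type key; maybe_setup_jax and its arguments are hypothetical. CUDA_VISIBLE_DEVICES is the standard CUDA variable for restricting which GPUs a process sees, and like the XLA variables it has to be set before the GPU backend initializes.

```python
import os
from typing import Optional, Sequence


def maybe_setup_jax(log_config: dict,
                    visible_devices: Optional[Sequence[int]] = None) -> bool:
    """Configure device visibility only if log_config["model_type"] is JAX.

    Hypothetical: the config layout and the case-insensitive match are assumptions.
    Returns True if the JAX setup path was taken.
    """
    if str(log_config.get("model_type", "")).lower() != "jax":
        return False
    if visible_devices is not None:
        # Restrict the GPUs the CUDA runtime (and therefore JAX) can see;
        # an empty sequence yields "" and hides all GPUs
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(d) for d in visible_devices)
    # The XLA_FLAGS / preallocation setup from the helper above would also go here
    return True
```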

RobertTLange commented 3 years ago

Also test different ways to limit JAX's thread/CPU usage. There seems to be an ongoing discussion on how best to do this. Something along these lines:

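import os
num_cpus = 4  # example value
# XLA_FLAGS is whitespace-separated, so append with " " instead of os.pathsep
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={num_cpus}"
os.environ["XLA_FLAGS"] += (" --xla_cpu_multi_thread_eigen=false"
                            " intra_op_parallelism_threads=1")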
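
XLA_FLAGS is a whitespace-separated list, so flags should be joined with spaces rather than os.pathsep (":" on Linux), which would glue two flags into a single token. Also, intra_op_parallelism_threads is a TensorFlow session option rather than a documented XLA flag, so it is worth checking whether it has any effect. A hypothetical add_xla_flags helper that appends flags and replaces earlier values of the same flag might look like this:

```python
import os


def add_xla_flags(*flags: str) -> str:
    """Append flags to XLA_FLAGS, replacing earlier values of the same flag.

    Hypothetical helper, not part of JAX. XLA_FLAGS is whitespace-separated,
    so entries are split on whitespace and joined with single spaces.
    """
    current = os.environ.get("XLA_FLAGS", "").split()
    # Flag "name" = everything before the first "=", e.g. "--xla_cpu_multi_thread_eigen"
    new_names = {f.split("=", 1)[0] for f in flags}
    # Keep existing entries (in order) unless a new flag overrides them
    kept = [f for f in current if f.split("=", 1)[0] not in new_names]
    os.environ["XLA_FLAGS"] = " ".join(kept + list(flags))
    return os.environ["XLA_FLAGS"]
```

For the snippet above this would be add_xla_flags(f"--xla_force_host_platform_device_count={num_cpus}", "--xla_cpu_multi_thread_eigen=false"). Whether a given XLA release still honors --xla_cpu_multi_thread_eigen should be checked against that release.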

RobertTLange commented 3 years ago

Addressed in 51d4233. This will probably need some changes once I get to setting up the GCP VM instances.