huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

notebook_launcher on kaggle tpu #2893

Open lhiqwj173 opened 5 days ago

lhiqwj173 commented 5 days ago

The current version is not detecting the TPU environment on Kaggle.

Example notebook: https://github.com/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_cv_example.ipynb

Running notebook_launcher(training_loop, args, num_processes=8) prints Launching training on 8 GPUs. and then fails with the following error:

RuntimeError: An issue was found when launching the training: 

-- Process 5 terminated with the following error:
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 75, in _wrap
    fn(i, *args)
  File "/usr/local/lib/python3.10/site-packages/accelerate/utils/launch.py", line 626, in __call__
    self.launcher(*args)
  File "/tmp/ipykernel_13/1582854104.py", line 2, in training_loop
    set_seed(seed)
  File "/usr/local/lib/python3.10/site-packages/accelerate/utils/random.py", line 58, in set_seed
    xm.set_rng_state(seed)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 1396, in set_rng_state
    device = torch_xla._XLAC._xla_get_default_device()
RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed: open(/dev/accel0): Operation not permitted: Operation not permitted; Couldn't open device: /dev/accel0; Unable to create Node RegisterInterface for node 0, config: go/debugonly    device_path: "/dev/accel0" mode: KERNEL debug_data_directory: "" dump_anomalies_only: true crash_in_debug_dump: false allow_core_dump: true; could not create driver instance
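
For reference, a quick diagnostic run in the same Kaggle kernel shows why the launcher falls back to the GPU branch (this snippet is only an illustration, not part of accelerate):

import os

# TPU-related variables set by the Kaggle TPU VM image.
print({k: v for k, v in os.environ.items() if k.startswith("TPU")})

# notebook_launcher keys its TPU branch on TPU_NAME, which is no longer set on Kaggle,
# so this prints None and the GPU code path is taken instead.
print(os.environ.get("TPU_NAME"))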

The following two code modifications should fix it:

def notebook_launcher(
    function,
    args=(),
    num_processes=None,
    mixed_precision="no",
    use_port="29500",
    master_addr="127.0.0.1",
    node_rank=0,
    num_nodes=1,
    rdzv_backend="static",
    rdzv_endpoint="",
    rdzv_conf=None,
    rdzv_id="none",
    max_restarts=0,
    monitor_interval=0.1,
):
    # Are we in a google colab or a Kaggle Kernel?
    in_colab = False
    in_kaggle = False
    if any(key.startswith("KAGGLE") for key in os.environ.keys()):
        in_kaggle = True
    elif "IPython" in sys.modules:
        in_colab = "google.colab" in str(sys.modules["IPython"].get_ipython())

    try:
        mixed_precision = PrecisionType(mixed_precision.lower())
    except ValueError:
        raise ValueError(
            f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}."
        )

  - if (in_colab or in_kaggle) and (os.environ.get("TPU_NAME", None) is not None):
  + if (in_colab or in_kaggle) and (os.environ.get("TPU_WORKER_ID", None) is not None):
        # TPU launch
        import torch_xla.distributed.xla_multiprocessing as xmp

        if len(AcceleratorState._shared_state) > 0:
            raise ValueError(
                "To train on TPU in Colab or Kaggle Kernel, the `Accelerator` should only be initialized inside "
                "your training function. Restart your notebook and make sure no cells initializes an "
                "`Accelerator`."
            )
        if num_processes is None:
            num_processes = 8

      - launcher = PrepareForLaunch(function, distributed_type="TPU")
      + launcher = PrepareForLaunch(function, distributed_type="XLA")
        print(f"Launching a training on {num_processes} TPU cores.")
        xmp.spawn(launcher, args=args, nprocs=num_processes, start_method="fork")

together with running this cell before launching:

import os

# Remove leftover Cloud TPU variables (popping with a default avoids a KeyError if they are unset).
os.environ.pop('CLOUD_TPU_TASK_ID', None)
os.environ.pop('TPU_PROCESS_ADDRESSES', None)
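
With both changes applied, a minimal launch along the lines of the linked example notebook should take the TPU path. This is only a sketch (training_loop, seed, and the printed device mirror that notebook, not accelerate internals):

from accelerate import Accelerator, notebook_launcher
from accelerate.utils import set_seed

def training_loop(seed=42):
    # On TPU the Accelerator must only be instantiated inside the launched function.
    set_seed(seed)
    accelerator = Accelerator()
    accelerator.print(accelerator.device)  # an xla device in each of the 8 processes

# Expected output with the patch: "Launching a training on 8 TPU cores."
notebook_launcher(training_loop, (42,), num_processes=8)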

SunMarc commented 4 days ago

Hi @lhiqwj173, thanks for the report. distributed_type="TPU" is indeed deprecated in the latest accelerate. Would you like to submit a PR with the fix you proposed? Otherwise, I can do it! One question I have: why do we need to change TPU_NAME to TPU_WORKER_ID?

lhiqwj173 commented 4 days ago

I'm not sure about the situation before, but currently the Kaggle environment with TPU enabled has no 'TPU_NAME' variable at all, so using 'TPU_NAME' to detect the TPU always comes back false. That's why I opted for a variable that reliably distinguishes the TPU environment (full environment dump below).
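
Evaluated in that environment, the old and the proposed check behave as described (a small illustration, not accelerate code):

import os

# Old condition: TPU_NAME is absent on current Kaggle TPU VMs, so the TPU branch is skipped.
print(os.environ.get("TPU_NAME", None) is not None)        # False

# Proposed condition: TPU_WORKER_ID is set (see the dump below), so the XLA launch path is taken.
print(os.environ.get("TPU_WORKER_ID", None) is not None)   # True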

environ{'TPU_WORKER_ID': '0',
        'HOSTNAME': '3b1dce008685',
        'PYTHON_VERSION': '3.10.14',
        'KAGGLE_DATA_PROXY_TOKEN': 'eyJhbGciOiJBMTI4S1ciLCJlbmMiOiJBMTI4Q0JDLUhTMjU2In0.__I5goXD7XlUpcC6XAP0iV65xdQiLShDiBDfaAGsd4UsuI9awPWvGA.YqAmfYXUOxsj-xzY0W7dDA.z8YGL7aIYKBX-DCdLFXZ_0cTgDOo2dVtJQ_GAAcpYRzh6bZVsdSja-IqsIQIwmcRFMYcFw_rS2dcsF_a-V5T9bWWv9-DDEucNoIKcoULBhNGRujPsqI4jzbLHCuL_czalBrPe_E6jNP3Ri3LpSHurU5mp14NBA0sLM0Gg2_sF2VDyEUSaJOHatuR170QLKjLz-CEpcpXfVr2V2J0Jtw-WRka3TQdrvlxOvIauaqvwjP8U-4KH2UqY0KWzU44XJ9v1sjRDCy2aqRFVpMWDnGsYmj811mJ0tAdwlvolBJ6lHpSDR_-wHQiF_hIBIRpGYm0.qj_80T3PCKkHk7akeFw8SQ',
        'TF_CPP_MIN_LOG_LEVEL': '2',
        'KAGGLE_URL_BASE': 'https://www.kaggle.com/',
        'TPU_SKIP_MDS_QUERY': '1',
        'KAGGLE_DOCKER_IMAGE': 'gcr.io/kaggle-gpu-images/python-tpuvm@sha256:56d8cb1f88608771922efeb2b2a2038cb87d0b57ffa689f9c8a6b58f53c0dbd3',
        'KAGGLE_KERNEL_INTEGRATIONS': '',
        'PWD': '/kaggle/working',
        'TPU_CHIPS_PER_HOST_BOUNDS': '2,2,1',
        'BUILD_DATE': '20240603-164143',
        'PYTHON_SETUPTOOLS_VERSION': '65.5.1',
        'HOME': '/root',
        'LANG': 'C.UTF-8',
        'TPU_ACCELERATOR_TYPE': 'v3-8',
        'GPG_KEY': 'A035C8C19219BA821ECEA86B64E628F8D684696D',
        'TPU_RUNTIME_METRICS_PORTS': '8431',
        'ISTPUVM': '1',
        'SLICE_BUILDER_WORKER_IPS': '127.0.0.1',
        'PYTHONPATH': '/kaggle/lib/kagglegym:/kaggle/lib',
        'KAGGLE_DATA_PROXY_PROJECT': 'kaggle-161607',
        'KAGGLE_USER_SECRETS_TOKEN': 'eyJhbGciOiJkaXIiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2In0..7YgGZAeJYiEPhT6y73o7jg.IgpVOx3jTa6P5TqE2NKbWC0pJ1TqNEsqThUdXKxyWiI99dhnCjfHOSZOgtENLd_NDxkpfrTHt5xA7V6X7ps_JE1O8rfAfINyX2odUfHb5gKBRLdqiww_AChqjzwh7eHJqVzLNASxAKLqBPpMdH30RQ.gAwpBxvM3WjB2erYCHq9RA',
        'TPU_HOST_BOUNDS': '1,1,1',
        'SHLVL': '0',
        'PJRT_DEVICE': 'TPU',
        'KAGGLE_KERNEL_RUN_TYPE': 'Interactive',
        'PYTHON_PIP_VERSION': '23.0.1',
        'XRT_TPU_CONFIG': 'localservice;0;localhost:51011',
        'KAGGLE_GCP_ZONE': 'us-east1-d',
        'PYTHON_GET_PIP_SHA256': 'dfe9fd5c28dc98b5ac17979a953ea550cec37ae1b47a5116007395bfacff2ab9',
        'TPU_WORKER_HOSTNAMES': 'localhost',
        'GIT_COMMIT': 'c493f6f4bae6881bf1665547702e9eddef42107c',
        'PYTHON_GET_PIP_URL': 'https://github.com/pypa/get-pip/raw/dbf0c85f76fb6e1ab42aa672ffca6f0a675d9ee4/public/get-pip.py',
        'KAGGLE_CONTAINER_NAME': 'kaggle_cSnr0rQEoqcq8nQeHS45x3aTDDRAE4o3yEcM1ScqS6w-185646110-webtier',
        'PATH': '/usr/local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin',
        'PYTHONUSERBASE': '/root/.local',
        'KAGGLE_GRPC_DATA_PROXY_URL': 'dp.kaggle.net:443',
        'KAGGLE_DATA_PROXY_URL': 'https://dp.kaggle.net/',
        '_': '/usr/local/bin/jupyter',
        'JPY_PARENT_PID': '1',
        'PYDEVD_USE_FRAME_EVAL': 'NO',
        'TERM': 'xterm-color',
        'CLICOLOR': '1',
        'FORCE_COLOR': '1',
        'CLICOLOR_FORCE': '1',
        'PAGER': 'cat',
        'GIT_PAGER': 'cat',
        'MPLBACKEND': 'module://matplotlib_inline.backend_inline',
        'GRPC_VERBOSITY': 'ERROR',
        'LIBTPU_INIT_ARGS': ' --xla_latency_hiding_scheduler_rerun=1 --xla_tpu_prefer_async_allgather_to_allreduce=true',
        'ALLOW_MULTIPLE_LIBTPU_LOAD': '1',
        'TPU_ML_PLATFORM': 'PyTorch/XLA',
        'XLA_FLAGS': ' --xla_cpu_enable_fast_math=false --xla_gpu_simplify_all_fp_conversions=false --xla_gpu_force_compilation_parallelism=8',
        'PTXLA_TPU_LIBRARY_PATH': '/usr/local/lib/python3.10/site-packages/torch_xla/lib/libtpu.so'}

SunMarc commented 3 days ago

Thanks for the report! It looks like Kaggle updated their environment. I'll try to reproduce this in Kaggle and check whether the same happens in Colab. cc @muellerzr