Lightning-AI / lightning-Habana

Lightning support for Intel Habana accelerators.
Apache License 2.0

LightningCLI support for external accelerators #55

Open ankitgola005 opened 1 year ago

ankitgola005 commented 1 year ago

🚀 Feature

LightningCLI support for external accelerators

Motivation

LightningCLI helps avoid boilerplate code when building command-line tools. The current implementation does not appear to support external accelerators; it only accepts the accelerators that ship with the Lightning source.

Pitch

Extend support for external accelerators in LightningCLI.
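To make the pitch concrete, here is a minimal, self-contained sketch of a name-to-class registry that a CLI could consult when resolving the `accelerator` argument. Everything in it (`ACCELERATORS`, `register_accelerator`, `resolve_accelerator`, and the stand-in classes) is hypothetical and is not Lightning's actual API. It only illustrates how an external package could expose its own accelerator under a string name instead of being limited to the built-in set.

```python
from typing import Callable, Dict, Type, Union


class Accelerator:
    """Stand-in base class (hypothetical; not Lightning's Accelerator)."""


class CPUAccelerator(Accelerator):
    """Stand-in for an accelerator that ships with the framework."""


# Hypothetical registry mapping CLI strings to accelerator classes.
ACCELERATORS: Dict[str, Type[Accelerator]] = {"cpu": CPUAccelerator}


def register_accelerator(
    name: str, override: bool = False
) -> Callable[[Type[Accelerator]], Type[Accelerator]]:
    """Class decorator an external package could use to expose its accelerator by name."""

    def decorator(cls: Type[Accelerator]) -> Type[Accelerator]:
        existing = ACCELERATORS.get(name)
        if existing is not None and existing is not cls and not override:
            raise ValueError(f"Accelerator name {name!r} is already registered")
        ACCELERATORS[name] = cls
        return cls

    return decorator


def resolve_accelerator(value: Union[str, Accelerator]) -> Accelerator:
    """Accept either an Accelerator instance or a registered string name."""
    if isinstance(value, Accelerator):
        return value  # instances are used as-is
    cls = ACCELERATORS.get(value)
    if cls is None:
        raise ValueError(f"Unknown accelerator {value!r}; expected one of {sorted(ACCELERATORS)}")
    return cls()


@register_accelerator("hpu")
class ExternalHPUAccelerator(Accelerator):
    """Stand-in for an out-of-tree accelerator, e.g. one provided by a plugin package."""
```

With this in place, `resolve_accelerator("hpu")` returns an `ExternalHPUAccelerator`, and passing an instance returns that same object. A name clash with an existing entry raises unless `override=True`, so a plugin cannot silently shadow, or be shadowed by, a built-in accelerator.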

Alternatives

Additional context

First mentioned in #54

To reproduce:

from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.cli import LightningCLI
from lightning_habana import HPUAccelerator

class BMAccelerator(BoringModel):
    def on_fit_start(self):
        # Fail fast if the Trainer did not pick up the external HPUAccelerator
        assert isinstance(self.trainer.accelerator, HPUAccelerator), self.trainer.accelerator

model = BMAccelerator
accelerator = HPUAccelerator()

if __name__ == "__main__":
    # Run one method at a time; each produces one of the tracebacks below.

    # Method 1: pass an instance of a supported accelerator class from an external library
    cli = LightningCLI(model, trainer_defaults={'accelerator': accelerator})

    # Method 2: pass the accelerator as a string
    cli = LightningCLI(model, trainer_defaults={'accelerator': 'hpu'})

Running each method gives the following traceback:

Method 1: passing an instance of a supported accelerator class from an external library

Traceback (most recent call last):
  File "temp.py", line 34, in <module>
    cli = LightningCLI(model, trainer_defaults={'accelerator': HPUAccelerator()})
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/cli.py", line 353, in __init__
    self._run_subcommand(self.subcommand)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/cli.py", line 642, in _run_subcommand
    fn(**fn_kwargs)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 520, in fit
    call._call_and_handle_interrupt(
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 559, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 893, in _run
    self.strategy.setup_environment()
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/strategies/strategy.py", line 127, in setup_environment
    self.accelerator.setup_device(self.root_device)
  File "/home/agola/lightning-habana-fork/src/lightning_habana/pytorch/accelerator.py", line 50, in setup_device
    raise MisconfigurationException(f"Device should be HPU, got {device} instead.")
lightning.fabric.utilities.exceptions.MisconfigurationException: Device should be HPU, got cpu instead.
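The Method 1 traceback ends inside the external `HPUAccelerator.setup_device` (from `lightning_habana`), which rejects the device passed in by `strategy.setup_environment()`: the strategy's `root_device` was `cpu`, not an HPU. The sketch below reproduces only that check. `Device`, `SketchHPUAccelerator`, `SketchStrategy`, and `try_setup` are hypothetical stand-ins rather than the real Lightning or lightning_habana classes, and `ValueError` stands in for `MisconfigurationException`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Device:
    """Minimal stand-in for torch.device: only the device type matters here."""
    type: str


class SketchHPUAccelerator:
    """Hypothetical stand-in mirroring the device check seen in the traceback."""

    def setup_device(self, device: Device) -> None:
        if device.type != "hpu":
            # The real code raises MisconfigurationException with this message.
            raise ValueError(f"Device should be HPU, got {device.type} instead.")


class SketchStrategy:
    """Hypothetical stand-in: the strategy owns root_device and forwards it."""

    def __init__(self, accelerator: SketchHPUAccelerator, root_device: Device) -> None:
        self.accelerator = accelerator
        self.root_device = root_device

    def setup_environment(self) -> None:
        self.accelerator.setup_device(self.root_device)


def try_setup(root_device: Device) -> str:
    """Return "ok", or the error message raised during environment setup."""
    strategy = SketchStrategy(SketchHPUAccelerator(), root_device)
    try:
        strategy.setup_environment()
    except ValueError as err:
        return str(err)
    return "ok"
```

Here `try_setup(Device("cpu"))` returns the same message as the traceback, while `try_setup(Device("hpu"))` returns `"ok"`: the accelerator instance is used, but the strategy hands it a CPU device.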

Method 2: passing the accelerator as a string

Traceback (most recent call last):
  File "temp.py", line 33, in <module>
    cli = LightningCLI(model, trainer_defaults={'accelerator': "hpu"})
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/cli.py", line 353, in __init__
    self._run_subcommand(self.subcommand)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/cli.py", line 642, in _run_subcommand
    fn(**fn_kwargs)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 520, in fit
    call._call_and_handle_interrupt(
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 559, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 916, in _run
    call._call_lightning_module_hook(self, "on_fit_start")
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 142, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "temp.py", line 15, in on_fit_start
    assert isinstance(self.trainer.accelerator,
AssertionError: <lightning.pytorch.accelerators.hpu.HPUAccelerator object at 0x7f37f62917c0>
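The Method 2 assertion fails even though the printed object is an `HPUAccelerator`: its module is `lightning.pytorch.accelerators.hpu`, so the string `'hpu'` resolved to the accelerator bundled with Lightning rather than the class imported from `lightning_habana`. `isinstance` compares class objects, not class names. The self-contained sketch below shows this with two hypothetical stand-in classes built via the three-argument form of `type()`; it does not import either real package.

```python
def make_accelerator_class(module_name: str) -> type:
    """Build a fresh class named HPUAccelerator that reports module_name as its module."""
    return type("HPUAccelerator", (), {"__module__": module_name})


# Hypothetical stand-ins: one for the class bundled with Lightning, one for the external package.
BuiltinHPU = make_accelerator_class("builtin_sketch.accelerators.hpu")
ExternalHPU = make_accelerator_class("external_sketch.accelerator")

trainer_accelerator = BuiltinHPU()  # what the string 'hpu' resolved to

same_name = type(trainer_accelerator).__name__ == ExternalHPU.__name__  # True
same_class = isinstance(trainer_accelerator, ExternalHPU)  # False
```

Two classes can share a name and even identical bodies, yet remain unrelated types, which is exactly what the assertion in `on_fit_start` detects.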

Env

lightning                     2.0.0
lightning-fabric              2.0.3
lightning-habana              1.0.0
lightning-utilities           0.9.0
pytorch-lightning             2.0.5
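A version listing like the one above can be collected with the standard library's `importlib.metadata` (Python 3.8+). The helper name `collect_versions` and the package tuple are illustrative choices; packages that are not installed are reported as `None` instead of raising.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

PACKAGES = (
    "lightning",
    "lightning-fabric",
    "lightning-habana",
    "lightning-utilities",
    "pytorch-lightning",
)


def collect_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        # Left-align names in a 30-character column, similar to `pip list` output.
        print(f"{name:<30}{ver or 'not installed'}")
```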
github-actions[bot] commented 1 year ago

Hi! Thanks for your contribution! Great first issue!

jerome-habana commented 1 year ago

cc @Borda

stale[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.