jinmang2 / anomaly_detection_on_video

Implementation for Anomaly Detection on Video

lightning checkpointing #9

Open jinmang2 opened 1 year ago

jinmang2 commented 1 year ago

Purpose

jinmang2 commented 1 year ago

How to save a checkpoint in lightning?

Note: the code analysis below is based on lightning 2.0.7, so file paths and source code may change in later versions.

Usually a checkpoint is saved either by calling pl.Trainer's save_checkpoint method after training finishes, or by adding pl.callbacks.ModelCheckpoint to the callbacks.

pl.callbacks.ModelCheckpoint saves a checkpoint at whatever point its configured arguments dictate (last / non-monitor / update-best, etc.), and in every case it goes through the method below.

# lightning.pytorch.callbacks.model_checkpoint.py
class ModelCheckpoint(Checkpoint):
...
    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        trainer.save_checkpoint(filepath, self.save_weights_only)

        self._last_global_step_saved = trainer.global_step

        # notify loggers
        if trainer.is_global_zero:
            for logger in trainer.loggers:
                logger.after_save_checkpoint(proxy(self))

In other words, what matters is pl.Trainer's save_checkpoint method. Taking that apart as well:

# lightning.pytorch.trainer.trainer.py
class Trainer:
...
    def save_checkpoint(
        self, filepath: _PATH, weights_only: bool = False, storage_options: Optional[Any] = None
    ) -> None:
        r"""Runs routine to create a checkpoint.

        Args:
            filepath: Path where checkpoint is saved.
            weights_only: If ``True``, will only save the model weights.
            storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin

        Raises:
            AttributeError:
                If the model is not attached to the Trainer before calling this method.

        """
        if self.model is None:
            raise AttributeError(
                "Saving a checkpoint is only possible if a model is attached to the Trainer. Did you call"
                " `Trainer.save_checkpoint()` before calling `Trainer.{fit,validate,test,predict}`?"
            )
        checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
        self.strategy.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
        self.strategy.barrier("Trainer.save_checkpoint")
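To make the division of labor in Trainer.save_checkpoint concrete, here is a minimal, dependency-free model of the chain: a connector builds the dict, a strategy writes it, then a barrier syncs processes. ToyTrainer, ToyConnector and ToyStrategy are hypothetical stand-ins, not Lightning classes, and pickle stands in for the CheckpointIO plugin (torch.save) only so the sketch runs without torch.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional


class ToyConnector:
    """Hypothetical stand-in for _CheckpointConnector: it only builds the dict."""

    def __init__(self, trainer: "ToyTrainer") -> None:
        self.trainer = trainer

    def dump_checkpoint(self, weights_only: bool = False) -> Dict[str, Any]:
        checkpoint: Dict[str, Any] = {
            "global_step": self.trainer.global_step,
            "state_dict": dict(self.trainer.model or {}),
        }
        if not weights_only:
            checkpoint["optimizer_states"] = [{"lr": 0.1}]  # dummy optimizer state
        return checkpoint


class ToyStrategy:
    """Hypothetical stand-in for a Strategy: it owns the actual I/O."""

    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        # pickle plays the role of the CheckpointIO plugin (torch.save in Lightning)
        Path(filepath).write_bytes(pickle.dumps(checkpoint))

    def barrier(self, name: str) -> None:
        pass  # single process: nothing to synchronize


class ToyTrainer:
    def __init__(self, model: Optional[Dict[str, float]] = None) -> None:
        self.model = model
        self.global_step = 0
        self.strategy = ToyStrategy()
        self._checkpoint_connector = ToyConnector(self)

    def save_checkpoint(self, filepath: str, weights_only: bool = False) -> None:
        if self.model is None:
            raise AttributeError("Attach a model before saving a checkpoint.")
        checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)  # build
        self.strategy.save_checkpoint(checkpoint, filepath)  # write
        self.strategy.barrier("Trainer.save_checkpoint")  # sync processes


trainer = ToyTrainer(model={"w": 1.0})
with tempfile.TemporaryDirectory() as tmp:
    path = str(Path(tmp) / "toy.ckpt")
    trainer.save_checkpoint(path, weights_only=True)
    print(sorted(pickle.loads(Path(path).read_bytes())))  # ['global_step', 'state_dict']
```

Keeping "what goes into the dict" (connector) separate from "how and where it is written" (strategy) is what lets a distributed strategy change the writing step without touching the dict layout.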

Here self.strategy appears to be the (distributed) training strategy such as DDP, while the checkpoint itself is actually produced by the dump_checkpoint method of self._checkpoint_connector. Extracting only the important parts of that code:

# lightning.pytorch.trainer.trainer.py
class Trainer:
    @_defaults_from_env_vars
    def __init__(
        self,
        ...
    ) -> None:
        ...
        self._checkpoint_connector = _CheckpointConnector(self)
        ...
...
# lightning.pytorch.trainer.connectors.checkpoint_connector.py
class _CheckpointConnector:
    ...
    def dump_checkpoint(self, weights_only: bool = False) -> dict:
        """Creating a model checkpoint dictionary object from various component states.

        Args:
            weights_only: saving model weights only
        Return:
            structured dictionary: {
                'epoch':                     training epoch
                'global_step':               training global step
                'pytorch-lightning_version': The version of PyTorch Lightning that produced this checkpoint
                'callbacks':                 "callback specific state"[] # if not weights_only
                'optimizer_states':          "PT optim's state_dict"[]   # if not weights_only
                'lr_schedulers':             "PT sched's state_dict"[]   # if not weights_only
                'state_dict':                Model's state_dict (e.g. network weights)
                precision_plugin.__class__.__qualname__:  precision plugin state_dict # if not weights_only
                CHECKPOINT_HYPER_PARAMS_NAME:
                CHECKPOINT_HYPER_PARAMS_KEY:
                CHECKPOINT_HYPER_PARAMS_TYPE:
                something_cool_i_want_to_save: anything you define through model.on_save_checkpoint
                LightningDataModule.__class__.__qualname__: pl DataModule's state
            }

        """
        trainer = self.trainer
        model = trainer.lightning_module
        datamodule = trainer.datamodule

        checkpoint = {
            # the epoch and global step are saved for compatibility but they are not relevant for restoration
            "epoch": trainer.current_epoch,
            "global_step": trainer.global_step,
            "pytorch-lightning_version": pl.__version__,
            "state_dict": self._get_lightning_module_state_dict(),
            "loops": self._get_loops_state_dict(),
        }

        if not weights_only:
            # dump callbacks
            checkpoint["callbacks"] = call._call_callbacks_state_dict(trainer)

            optimizer_states = []
            for i, optimizer in enumerate(trainer.optimizers):
                # Rely on accelerator to dump optimizer state
                optimizer_state = trainer.strategy.optimizer_state(optimizer)
                optimizer_states.append(optimizer_state)

            checkpoint["optimizer_states"] = optimizer_states

            # dump lr schedulers
            lr_schedulers = []
            for config in trainer.lr_scheduler_configs:
                lr_schedulers.append(config.scheduler.state_dict())
            checkpoint["lr_schedulers"] = lr_schedulers

            # precision plugin
            prec_plugin = trainer.precision_plugin
            prec_plugin_state_dict = prec_plugin.state_dict()
            if prec_plugin_state_dict:
                checkpoint[prec_plugin.__class__.__qualname__] = prec_plugin_state_dict
            prec_plugin.on_save_checkpoint(checkpoint)

        # dump hyper-parameters
        for obj in (model, datamodule):
            if obj and obj.hparams:
                if hasattr(obj, "_hparams_name"):
                    checkpoint[obj.CHECKPOINT_HYPER_PARAMS_NAME] = obj._hparams_name
                # dump arguments
                if _OMEGACONF_AVAILABLE and isinstance(obj.hparams, Container):
                    checkpoint[obj.CHECKPOINT_HYPER_PARAMS_KEY] = obj.hparams
                    checkpoint[obj.CHECKPOINT_HYPER_PARAMS_TYPE] = type(obj.hparams)
                else:
                    checkpoint[obj.CHECKPOINT_HYPER_PARAMS_KEY] = dict(obj.hparams)

        # dump stateful datamodule
        if datamodule is not None:
            datamodule_state_dict = call._call_lightning_datamodule_hook(trainer, "state_dict")
            if datamodule_state_dict:
                checkpoint[datamodule.__class__.__qualname__] = datamodule_state_dict

        # on_save_checkpoint hooks
        if not weights_only:
            # if state is returned from callback's on_save_checkpoint
            # it overrides the returned state from callback's state_dict
            # support for returning state in on_save_checkpoint
            # will be removed in v1.8
            call._call_callbacks_on_save_checkpoint(trainer, checkpoint)
        call._call_lightning_module_hook(trainer, "on_save_checkpoint", checkpoint)
        return checkpoint

Analyzing the code above, I confirmed the following:

  1. At the very top, the checkpoint dict is created with epoch, global_step, pytorch-lightning_version, state_dict, and loops.
  2. If the weights_only option is False, callbacks, optimizer_states, lr_schedulers, and the precision plugin state are also put into the checkpoint dict.
  3. The hparams of pl.LightningModule and pl.LightningDataModule are also recorded under that object's CHECKPOINT_HYPER_PARAMS_{NAME | KEY | TYPE} keys.
  4. If weights_only is False, _call_callbacks_on_save_checkpoint in lightning.pytorch.trainer.call.py runs the on_save_checkpoint method of every callback registered on the trainer.
  5. _call_lightning_module_hook in lightning.pytorch.trainer.call.py then calls the LightningModule's on_save_checkpoint (inherited from CheckpointHooks); if you have overridden it to write what you want to save, your custom function runs here.
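The ordering in the list above can be captured in a simplified, stdlib-only sketch. dump_checkpoint_sketch is a hypothetical name, it omits the datamodule and precision-plugin state, and "hyper_parameters" is assumed to be the value of CHECKPOINT_HYPER_PARAMS_KEY. The point it shows is that the LightningModule hook runs last and receives the fully assembled dict.

```python
from typing import Any, Callable, Dict, List

Hook = Callable[[Dict[str, Any]], None]


def dump_checkpoint_sketch(
    state_dict: Dict[str, Any],
    hparams: Dict[str, Any],
    callback_hooks: List[Hook],
    module_hook: Hook,
    weights_only: bool = False,
) -> Dict[str, Any]:
    """Hypothetical simplification of _CheckpointConnector.dump_checkpoint."""
    # 1. keys that are always present
    checkpoint: Dict[str, Any] = {
        "epoch": 0,
        "global_step": 0,
        "state_dict": dict(state_dict),
        "loops": {},
    }
    # 2. training state, skipped when weights_only=True
    if not weights_only:
        checkpoint["callbacks"] = {}
        checkpoint["optimizer_states"] = []
        checkpoint["lr_schedulers"] = []
    # 3. hyperparameters (key name "hyper_parameters" is an assumption here)
    if hparams:
        checkpoint["hyper_parameters"] = dict(hparams)
    # 4. callbacks' on_save_checkpoint, only when weights_only=False
    if not weights_only:
        for hook in callback_hooks:
            hook(checkpoint)
    # 5. the LightningModule's on_save_checkpoint always runs last
    module_hook(checkpoint)
    return checkpoint


seen: List[str] = []
dump_checkpoint_sketch(
    {"w": 1.0},
    {"lr": 0.1},
    callback_hooks=[lambda c: c.update(my_callback_state=1)],
    module_hook=lambda c: seen.extend(c),
)
print("my_callback_state" in seen)  # True: the module hook sees the callbacks' additions
```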

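From the analysis so far:

  1. The ckpt's state_dict is obtained through _CheckpointConnector's own _get_lightning_module_state_dict method.
    • _CheckpointConnector cannot be passed as an argument to pl.Trainer; it has to be replaced in the constructor of a subclass.
  2. The on_save_checkpoint defined in the LightningModule runs last. Likewise, the callbacks' on_save_checkpoint hooks run at the end, right before it (and only when weights_only is False).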

Possible ways to implement this HF-style:

  1. Create HuggingfaceTrainer, a subclass of pl.Trainer, and inject into its constructor a _CheckpointConnector whose dump_checkpoint has been modified.
  2. Wrap dump_checkpoint of pl.Trainer's self._checkpoint_connector so that it does what we want.
  3. In the LightningModule's on_save_checkpoint method, pop state_dict from the existing checkpoint and call self.model.save_pretrained, leaving only the path in the checkpoint.
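A minimal sketch of option 3, assuming a Hugging Face-style model that exposes save_pretrained. FakePretrainedModel, ToyLitModule and the "pretrained_path" key are hypothetical, and JSON replaces real weight files so the example runs without torch or transformers. Since _load_state (analyzed further down) calls load_state_dict(checkpoint["state_dict"]) right after on_load_checkpoint, the sketch also has to rebuild state_dict in on_load_checkpoint.

```python
import json
import os
import tempfile
from typing import Any, Dict


class FakePretrainedModel:
    """Hypothetical Hugging Face-style model whose weights are plain floats."""

    def __init__(self) -> None:
        self.weights: Dict[str, float] = {"linear.weight": 0.5}

    def save_pretrained(self, save_directory: str) -> None:
        os.makedirs(save_directory, exist_ok=True)
        with open(os.path.join(save_directory, "weights.json"), "w") as f:
            json.dump(self.weights, f)


class ToyLitModule:
    """Hypothetical LightningModule-like wrapper that owns `self.model`."""

    def __init__(self, save_dir: str) -> None:
        self.model = FakePretrainedModel()
        self.save_dir = save_dir

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # This hook runs last in dump_checkpoint, so nothing re-adds "state_dict".
        checkpoint.pop("state_dict", None)
        self.model.save_pretrained(self.save_dir)
        checkpoint["pretrained_path"] = self.save_dir  # only the path is kept

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # load_state_dict(checkpoint["state_dict"]) follows this hook, so rebuild it
        # from the saved directory, re-adding the "model." prefix of the wrapper.
        with open(os.path.join(checkpoint["pretrained_path"], "weights.json")) as f:
            weights = json.load(f)
        checkpoint["state_dict"] = {f"model.{k}": v for k, v in weights.items()}


with tempfile.TemporaryDirectory() as tmp:
    module = ToyLitModule(os.path.join(tmp, "hf"))
    ckpt: Dict[str, Any] = {"state_dict": {"model.linear.weight": 0.5}, "epoch": 3}
    module.on_save_checkpoint(ckpt)
    print("state_dict" in ckpt)  # False
    module.on_load_checkpoint(ckpt)
    print(ckpt["state_dict"])  # {'model.linear.weight': 0.5}
```

This is also why the save-side choice cannot be made in isolation: whatever is removed from the checkpoint on save must be restorable before the automatic load_state_dict call on load.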

Which approach to take depends on how checkpoint loading (load_checkpoint) works.

jinmang2 commented 1 year ago

How to load a checkpoint in lightning?

I asked ChatGPT about the ways to load a checkpoint in lightning.

  1. Use the load_from_checkpoint class method
    • A class method provided by LightningModule; it loads the model from a checkpoint file path you specify directly.
      model = MyModel.load_from_checkpoint(checkpoint_path="path/to/checkpoint.ckpt")
  2. Use the Trainer's resume_from_checkpoint parameter
    • To continue training, pass the checkpoint file path through the resume_from_checkpoint parameter when creating the Trainer object.
      trainer = pl.Trainer(resume_from_checkpoint="path/to/checkpoint.ckpt")
      trainer.fit(model, dataloader)
  3. Use it together with the ModelCheckpoint callback
    • The ModelCheckpoint callback automatically manages the best or most recent checkpoints, which makes them easy to load later.
      checkpoint_callback = pl.callbacks.ModelCheckpoint(...)
      trainer = pl.Trainer(callbacks=[checkpoint_callback])
      trainer.fit(model, dataloader)
      # later, load the best checkpoint
      best_model_path = checkpoint_callback.best_model_path
      model = MyModel.load_from_checkpoint(best_model_path)
  4. Load manually with torch.load
    • A PyTorch Lightning checkpoint is stored as a dictionary, so when needed you can also load it manually with PyTorch's plain torch.load function.
      checkpoint = torch.load("path/to/checkpoint.ckpt")
      model.load_state_dict(checkpoint["state_dict"])

When I asked ChatGPT again, it said these methods work in pl 1.x, so I will additionally go through the source and docs to check whether they still apply in 2.x.

jinmang2 commented 1 year ago

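First, resume_from_checkpoint (item 2 in the list above) was deprecated in lightning issue 9501; resuming is now controlled through the ckpt_path argument of fit, i.e. trainer.fit(model, ckpt_path=...).

This change appears to come from Issue 9006, which set out to minimize the number of Trainer arguments.

To restate, three approaches are available: 1. naively use torch.load and customize everything yourself, 2. use LightningModule's load_from_checkpoint method, 3. pass the ckpt_path argument to pl.Trainer's fit method.

Since one goal of this repo is to build the ability to analyze and use lightning, I will avoid hand-rolling a torch.load-based implementation and instead dig into the source code behind 2 and 3.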

jinmang2 commented 1 year ago

load_from_checkpoint

pl.LightningModule's load_from_checkpoint classmethod is very simple: it forwards all of its arguments to _load_from_checkpoint and casts the result to Self.

# lightning.pytorch.core.module.py
...
from typing import cast
...
from typing_extensions import Self
...
from lightning.pytorch.core.saving import _load_from_checkpoint
...

class LightningModule(
    _DeviceDtypeModuleMixin,
    HyperparametersMixin,
    ModelHooks,
    DataHooks,
    CheckpointHooks,
    Module,
):
    ...
    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path: Union[_PATH, IO],
        map_location: _MAP_LOCATION_TYPE = None,
        hparams_file: Optional[_PATH] = None,
        strict: bool = True,
        **kwargs: Any,
    ) -> Self:
        loaded = _load_from_checkpoint(
            cls,
            checkpoint_path,
            map_location,
            hparams_file,
            strict,
            **kwargs,
        )
        return cast(Self, loaded)

_load_from_checkpoint looks like this:

# lightning.pytorch.core.saving.py
def _load_from_checkpoint(
    cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
    checkpoint_path: Union[_PATH, IO],
    map_location: _MAP_LOCATION_TYPE = None,
    hparams_file: Optional[_PATH] = None,
    strict: Optional[bool] = None,
    **kwargs: Any,
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
    map_location = map_location or _default_map_location
    with pl_legacy_patch():
        checkpoint = pl_load(checkpoint_path, map_location=map_location)

    # convert legacy checkpoints to the new format
    checkpoint = _pl_migrate_checkpoint(
        checkpoint, checkpoint_path=(checkpoint_path if isinstance(checkpoint_path, (str, Path)) else None)
    )

    if hparams_file is not None:
        extension = str(hparams_file).split(".")[-1]
        if extension.lower() == "csv":
            hparams = load_hparams_from_tags_csv(hparams_file)
        elif extension.lower() in ("yml", "yaml"):
            hparams = load_hparams_from_yaml(hparams_file)
        else:
            raise ValueError(".csv, .yml or .yaml is required for `hparams_file`")

        # overwrite hparams by the given file
        checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams

    # TODO: make this a migration:
    # for past checkpoint need to add the new key
    checkpoint.setdefault(cls.CHECKPOINT_HYPER_PARAMS_KEY, {})
    # override the hparams with values that were passed in
    checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)

    if issubclass(cls, pl.LightningDataModule):
        return _load_state(cls, checkpoint, **kwargs)
    if issubclass(cls, pl.LightningModule):
        model = _load_state(cls, checkpoint, strict=strict, **kwargs)
        state_dict = checkpoint["state_dict"]
        if not state_dict:
            rank_zero_warn(f"The state dict in {checkpoint_path!r} contains no parameters.")
            return model

        device = next((t for t in state_dict.values() if isinstance(t, torch.Tensor)), torch.tensor(0)).device
        assert isinstance(model, pl.LightningModule)
        return model.to(device)

    raise NotImplementedError(f"Unsupported {cls}")

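The code looks long, but it is actually simple. It first loads the checkpoint with pl_load (inside pl_legacy_patch, which patches things so that checkpoints saved by earlier versions can be loaded by the current version), then applies _pl_migrate_checkpoint to convert a legacy ckpt to the new format. If an hparams_file was given, it reads that file and writes it into the checkpoint; it then sets the default hparams key, merges the extra kwargs into it, and finally instantiates the datamodule/LightningModule with _load_state and returns it.

In other words, pl_load on the first line is the most important part (it is where the checkpoint is actually loaded).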
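The flow above can be condensed into a short stdlib sketch. ToyModule, migrate and load_from_checkpoint_sketch are hypothetical; pickle stands in for torch.load, the legacy "weights" key is an invented example of what a migration might fix, and the real _load_state additionally filters kwargs against the constructor signature.

```python
import io
import pickle
from typing import IO, Any, Dict, Type, TypeVar

T = TypeVar("T", bound="ToyModule")


class ToyModule:
    """Hypothetical module with one constructor argument and a weight dict."""

    def __init__(self, hidden: int = 8) -> None:
        self.hidden = hidden
        self.weights: Dict[str, float] = {}

    def load_state_dict(self, state_dict: Dict[str, float]) -> None:
        self.weights = dict(state_dict)


def migrate(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical migration: pretend older files stored weights under "weights".
    if "weights" in checkpoint and "state_dict" not in checkpoint:
        checkpoint["state_dict"] = checkpoint.pop("weights")
    return checkpoint


def load_from_checkpoint_sketch(cls: Type[T], f: IO[bytes], **kwargs: Any) -> T:
    checkpoint = pickle.load(f)  # 1. role of pl_load / torch.load
    checkpoint = migrate(checkpoint)  # 2. role of _pl_migrate_checkpoint
    checkpoint.setdefault("hyper_parameters", {})  # 3. default hparams key...
    checkpoint["hyper_parameters"].update(kwargs)  #    ...overridden by caller kwargs
    model = cls(**checkpoint["hyper_parameters"])  # 4. role of _load_state: build...
    model.load_state_dict(checkpoint["state_dict"])  #    ...then load the weights
    return model


legacy = io.BytesIO(pickle.dumps({"weights": {"w": 1.0}, "hyper_parameters": {"hidden": 4}}))
model = load_from_checkpoint_sketch(ToyModule, legacy, hidden=16)
print(model.hidden, model.weights)  # 16 {'w': 1.0}
```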

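Let's go through it step by step. First, let's check how pl_load loads the checkpoint file (pl_load is the _load function of lightning.fabric.utilities.cloud_io shown below).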

# lightning.fabric.utilities.cloud_io.py
def _load(
    path_or_url: Union[IO, _PATH],
    map_location: _MAP_LOCATION_TYPE = None,
) -> Any:
    """Loads a checkpoint.

    Args:
        path_or_url: Path or URL of the checkpoint.
        map_location: a function, ``torch.device``, string or a dict specifying how to remap storage locations.

    """
    if not isinstance(path_or_url, (str, Path)):
        # any sort of BytesIO or similar
        return torch.load(
            path_or_url,
            map_location=map_location,  # type: ignore[arg-type] # upstream annotation is not correct
        )
    if str(path_or_url).startswith("http"):
        return torch.hub.load_state_dict_from_url(
            str(path_or_url),
            map_location=map_location,  # type: ignore[arg-type]
        )
    fs = get_filesystem(path_or_url)
    with fs.open(path_or_url, "rb") as f:
        return torch.load(f, map_location=map_location)  # type: ignore[arg-type]

torch.load accepts as f either a file-like object (one implementing read, readline, tell, and seek) or a str/os.PathLike object containing a file name. pl_load passes a file-like object straight to torch.load, and if the path starts with http it downloads the checkpoint with torch.hub.load_state_dict_from_url. Otherwise, get_filesystem uses fsspec's url_to_fs to detect the protocol and return the matching filesystem object; the file at that path is opened with fs.open (used as a context manager) and loaded with torch.load. In every branch the whole checkpoint dict is returned, not only the state_dict. As mentioned above, pl_legacy_patch is a context manager that temporarily registers legacy artifacts that exist in old checkpoints but are no longer used, and unregisters them once loading finishes.
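The branch order of _load can be checked without torch by returning a label instead of loading anything. pick_load_branch is a hypothetical helper that mirrors only the isinstance and "http" checks in the code above; the returned strings are made-up labels.

```python
import io
import os
from pathlib import Path
from typing import IO, Union


def pick_load_branch(path_or_url: Union[IO[bytes], str, os.PathLike]) -> str:
    """Hypothetical: report which loader lightning.fabric's `_load` would use."""
    if not isinstance(path_or_url, (str, Path)):
        return "torch.load(file_like)"  # BytesIO or any other file-like object
    if str(path_or_url).startswith("http"):
        return "torch.hub.load_state_dict_from_url"
    return "fsspec open + torch.load"  # local path, s3://, gs://, ...


print(pick_load_branch(io.BytesIO(b"")))  # torch.load(file_like)
print(pick_load_branch("https://example.com/a.ckpt"))  # torch.hub.load_state_dict_from_url
print(pick_load_branch(Path("checkpoints/last.ckpt")))  # fsspec open + torch.load
print(pick_load_branch("http_backup/model.ckpt"))  # torch.hub.load_state_dict_from_url
```

The last line shows a side effect of the plain startswith("http") check: a relative local path whose first directory happens to begin with "http" also takes the URL branch.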

As mentioned, _load does nothing more or less than run torch.load. pl_legacy_patch lets checkpoints from earlier versions load without errors, but their contents still have to be converted to the current version's format; _pl_migrate_checkpoint does this (see lightning.pytorch.utilities.migration.utils.py).

Once torch.load has brought the checkpoint into memory, the LightningModule still needs to pick up the parts it needs, such as the state_dict. Lightning handles this in the _load_state function, which works as follows:

# lightning.pytorch.core.saving.py
def _load_state(
    cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
    checkpoint: Dict[str, Any],
    strict: Optional[bool] = None,
    **cls_kwargs_new: Any,
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
    cls_spec = inspect.getfullargspec(cls.__init__)
    cls_init_args_name = inspect.signature(cls.__init__).parameters.keys()

    self_var, args_var, kwargs_var = parse_class_init_keys(cls)
    drop_names = [n for n in (self_var, args_var, kwargs_var) if n]
    cls_init_args_name = list(filter(lambda n: n not in drop_names, cls_init_args_name))

    cls_kwargs_loaded = {}
    # pass in the values we saved automatically
    if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
        if issubclass(cls, pl.LightningModule):
            # TODO: make this a migration:
            # 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys
            for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS:
                cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {}))

        # 2. Try to restore model hparams from checkpoint using the new key
        cls_kwargs_loaded.update(checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_KEY, {}))

        # 3. Ensure that `cls_kwargs_old` has the right type, back compatibility between dict and Namespace
        cls_kwargs_loaded = _convert_loaded_hparams(cls_kwargs_loaded, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))

        # 4. Update cls_kwargs_new with cls_kwargs_old, such that new has higher priority
        args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
        if args_name and args_name in cls_init_args_name:
            cls_kwargs_loaded = {args_name: cls_kwargs_loaded}

    _cls_kwargs = {}
    _cls_kwargs.update(cls_kwargs_loaded)
    _cls_kwargs.update(cls_kwargs_new)

    if not cls_spec.varkw:
        # filter kwargs according to class init unless it allows any argument via kwargs
        _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name}

    obj = cls(**_cls_kwargs)

    if isinstance(obj, pl.LightningModule):
        # give model a chance to load something
        obj.on_load_checkpoint(checkpoint)

    if isinstance(obj, pl.LightningDataModule):
        if obj.__class__.__qualname__ in checkpoint:
            obj.load_state_dict(checkpoint[obj.__class__.__qualname__])
        return obj

    # load the state_dict on the model automatically
    assert strict is not None
    keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)

    if not strict:
        if keys.missing_keys:
            rank_zero_warn(
                f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}"
            )
        if keys.unexpected_keys:
            rank_zero_warn(
                f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}"
            )

    return obj

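_load_state first uses Python's built-in inspect library to get the constructor argument names of the LightningModule or LightningDataModule (cls_init_args_name), dropping the self/*args/**kwargs names.

Because LightningModule inherits HyperparametersMixin, it can use the save_hyperparameters method, which registers the args/kwargs given to the constructor in self.hparams. These are automatically saved to the checkpoint and can be accessed via the CHECKPOINT_HYPER_PARAMS_{KEY|TYPE|NAME} keys; the middle part of _load_state reads them back (details to be checked later). Since they are the arguments needed to construct the LightningModule, they are merged into _cls_kwargs together with the kwargs passed by the caller (the caller's values take priority), filtered down to the constructor's parameters unless __init__ accepts **kwargs, and then used to instantiate the object. A LightningDataModule just runs its load_state_dict method and returns; since that is not the current focus, only the LightningModule path is examined here.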
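The kwargs handling in _load_state boils down to "merge, caller wins, then filter by the constructor signature unless it takes **kwargs". build_init_kwargs below is a hypothetical simplification: it calls inspect.getfullargspec directly instead of Lightning's parse_class_init_keys and ignores the CHECKPOINT_HYPER_PARAMS_NAME wrapping case.

```python
import inspect
from typing import Any, Dict


def build_init_kwargs(cls: type, loaded: Dict[str, Any], new: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical simplification of the kwargs handling in `_load_state`."""
    spec = inspect.getfullargspec(cls.__init__)
    drop = {"self", spec.varargs, spec.varkw}  # names of self/*args/**kwargs
    init_names = [n for n in inspect.signature(cls.__init__).parameters if n not in drop]
    merged = {**loaded, **new}  # values passed by the caller take priority
    if not spec.varkw:  # no **kwargs in __init__: keep only real parameters
        merged = {k: v for k, v in merged.items() if k in init_names}
    return merged


class StrictInit:
    def __init__(self, lr: float, hidden: int = 8) -> None:
        self.lr, self.hidden = lr, hidden


class LooseInit:
    def __init__(self, lr: float, **kwargs: Any) -> None:
        self.lr, self.extra = lr, kwargs


saved = {"lr": 0.1, "removed_arg": True}  # e.g. an argument that no longer exists
print(build_init_kwargs(StrictInit, saved, {"lr": 0.01}))  # {'lr': 0.01}
print(build_init_kwargs(LooseInit, saved, {}))  # {'lr': 0.1, 'removed_arg': True}
```

The filtering step is what keeps an old checkpoint loadable after a constructor argument has been removed, as long as __init__ does not accept **kwargs.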

After the LightningModule object is created, _load_state runs the two methods below:

# give model a chance to load something
obj.on_load_checkpoint(checkpoint)

# load the state_dict on the model automatically
assert strict is not None
keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)

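on_load_checkpoint is a hook the user can override: you decide which keys to read from the checkpoint file and what to store on the object. Since LightningModule inherits from nn.Module, load_state_dict is simply torch's Module.load_state_dict, which loads checkpoint["state_dict"] into the model; with strict=False, Lightning warns about missing and unexpected keys instead of failing.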
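The strict/non-strict behavior can be modeled with set differences, which is how missing_keys and unexpected_keys are defined. This is a simplification under stated assumptions: torch's real Module.load_state_dict also validates tensor shapes, IncompatibleKeys and the two helpers are hypothetical re-creations, and warnings.warn stands in for Lightning's rank_zero_warn.

```python
import warnings
from typing import Dict, Iterable, List, NamedTuple


class IncompatibleKeys(NamedTuple):
    missing_keys: List[str]
    unexpected_keys: List[str]


def check_keys(model_keys: Iterable[str], state_dict: Dict[str, object], strict: bool) -> IncompatibleKeys:
    """Hypothetical key bookkeeping only; no shape checks."""
    model_set, ckpt_set = set(model_keys), set(state_dict)
    keys = IncompatibleKeys(sorted(model_set - ckpt_set), sorted(ckpt_set - model_set))
    if strict and (keys.missing_keys or keys.unexpected_keys):
        raise RuntimeError(f"Error(s) in loading state_dict: {keys}")
    return keys


def load_like_lightning(
    model_keys: Iterable[str], state_dict: Dict[str, object], strict: bool = True
) -> IncompatibleKeys:
    keys = check_keys(model_keys, state_dict, strict)
    if not strict:  # mirrors the rank_zero_warn calls at the end of _load_state
        if keys.missing_keys:
            warnings.warn(f"In model but not in checkpoint: {keys.missing_keys}")
        if keys.unexpected_keys:
            warnings.warn(f"In checkpoint but not in model: {keys.unexpected_keys}")
    return keys


print(load_like_lightning(["model.w", "model.b"], {"model.w": 1.0, "head.w": 2.0}, strict=False))
# IncompatibleKeys(missing_keys=['model.b'], unexpected_keys=['head.w'])
```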

Finally, _load_from_checkpoint moves the LightningModule to the device of the first tensor in the checkpoint's state_dict and returns it.

jinmang2 commented 1 year ago

lightning.pytorch.Trainer(...).fit(model=model, ckpt_path={YOUR_CKPT_PATH})

pl.Trainer.fit is so abstracted that it looks remarkably simple.

# lightning.pytorch.trainer.trainer.py
class Trainer(...):
    ...
    def fit(
        self,
        model: "pl.LightningModule",
        train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional[LightningDataModule] = None,
        ckpt_path: Optional[str] = None,
    ) -> None:
        model = _maybe_unwrap_optimized(model)
        self.strategy._lightning_module = model
        _verify_strategy_supports_compile(model, self.strategy)
        call._call_and_handle_interrupt(
            self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
        )

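The first three lines unwrap a possibly compiled (torch.compile) model, attach it to the strategy, and verify that the strategy supports compilation; the line that actually matters is the last one.

_call_and_handle_interrupt in lightning.pytorch.trainer.call.py is designed to handle errors for pl.Trainer's main entry points fit, validate, test, and predict. If trainer.strategy.launcher is set, it runs trainer_fn through that launcher; otherwise it calls trainer_fn directly. Errors are handled separately for three kinds of exceptions.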
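A minimal sketch of the dispatch described above, assuming a launcher object with a launch(fn, *args, **kwargs) method. ToyLauncher and call_and_handle_interrupt_sketch are hypothetical, and the exception handling is deliberately left out.

```python
from typing import Any, Callable, Optional


class ToyLauncher:
    """Hypothetical launcher; a real one would start worker processes first."""

    def __init__(self) -> None:
        self.calls = 0

    def launch(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        self.calls += 1
        return fn(*args, **kwargs)


def call_and_handle_interrupt_sketch(
    launcher: Optional[ToyLauncher], trainer_fn: Callable[..., Any], *args: Any, **kwargs: Any
) -> Any:
    # Exception handling is intentionally omitted from this sketch.
    if launcher is not None:
        return launcher.launch(trainer_fn, *args, **kwargs)
    return trainer_fn(*args, **kwargs)


launcher = ToyLauncher()
print(call_and_handle_interrupt_sketch(launcher, lambda x: x + 1, 1), launcher.calls)  # 2 1
print(call_and_handle_interrupt_sketch(None, lambda x: x * 2, 3))  # 6
```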

So the part that matters is self._fit_impl.

# lightning.pytorch.trainer.trainer.py
class Trainer(...):
    ...
    def _fit_impl(
        self,
        model: "pl.LightningModule",
        train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional[LightningDataModule] = None,
        ckpt_path: Optional[str] = None,
    ) -> None:
        log.debug(f"{self.__class__.__name__}: trainer fit stage")

        self.state.fn = TrainerFn.FITTING
        self.state.status = TrainerStatus.RUNNING
        self.training = True

        # if a datamodule comes in as the second arg, then fix it for the user
        if isinstance(train_dataloaders, LightningDataModule):
            datamodule = train_dataloaders
            train_dataloaders = None
        # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
        if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None:
            raise MisconfigurationException(
                "You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.fit(datamodule=...)`"
            )

        # links data to the trainer
        self._data_connector.attach_data(
            model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
        )

        ckpt_path = self._checkpoint_connector._select_ckpt_path(
            self.state.fn,
            ckpt_path,
            model_provided=True,
            model_connected=self.lightning_module is not None,
        )
        self._run(model, ckpt_path=ckpt_path)

        assert self.state.stopped
        self.training = False
        return

The code performs the training setup and then does the actual work in self._run. Before looking at _run, let's see what _select_ckpt_path does.

# lightning.pytorch.trainer.trainer.py
class Trainer(...):
    @_defaults_from_env_vars
    def __init__(
        self,
        ...
    ) -> None:
        ...
        self._checkpoint_connector = _CheckpointConnector(self)
        ...
    ...
    @property
    def ckpt_path(self) -> Optional[_PATH]:
        """Set to the path/URL of a checkpoint loaded via :meth:`~lightning.pytorch.trainer.trainer.Trainer.fit`,
        :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`,
        :meth:`~lightning.pytorch.trainer.trainer.Trainer.test`, or
        :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`. ``None`` otherwise."""
        return self._checkpoint_connector._ckpt_path

    @ckpt_path.setter
    def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None:
        """Allows you to manage which checkpoint is loaded statefully.

        .. code-block:: python

            trainer = Trainer()
            trainer.ckpt_path = "my/checkpoint/file.ckpt"
            trainer.fit(model)
            ...

            # you will be in charge of resetting this
            trainer.ckpt_path = None
            trainer.test(model)

        """
        self._checkpoint_connector._ckpt_path = ckpt_path
        self._checkpoint_connector._user_managed = bool(ckpt_path)
...
# lightning.pytorch.trainer.connectors.checkpoint_connector.py
class _CheckpointConnector:
    ...
    def _select_ckpt_path(
        self, state_fn: TrainerFn, ckpt_path: Optional[_PATH], model_provided: bool, model_connected: bool
    ) -> Optional[_PATH]:
        """Called by the ``Trainer`` to select the checkpoint path source."""
        if self._user_managed:
            if ckpt_path:
                rank_zero_warn(
                    f"`trainer.ckpt_path = {self._ckpt_path!r}` was called but then you"
                    f" passed `trainer.fit(ckpt_path={ckpt_path!r})`. The latter will be loaded."
                )
                # reset the previous path
                self._ckpt_path = None
                self._user_managed = False
                ckpt_path = self._parse_ckpt_path(
                    state_fn,
                    ckpt_path,
                    model_provided=model_provided,
                    model_connected=model_connected,
                )
            else:
                ckpt_path = self._ckpt_path
        else:
            ckpt_path = self._parse_ckpt_path(
                state_fn,
                ckpt_path,
                model_provided=model_provided,
                model_connected=model_connected,
            )
        return ckpt_path

    def _parse_ckpt_path(
        self, state_fn: TrainerFn, ckpt_path: Optional[_PATH], model_provided: bool, model_connected: bool
    ) -> Optional[_PATH]:
        """Converts the ``ckpt_path`` special values into an actual filepath, depending on the trainer
        configuration."""
        if ckpt_path is None and SLURMEnvironment.detect() and self._hpc_resume_path is not None:
            ckpt_path = "hpc"

        from lightning.pytorch.callbacks.on_exception_checkpoint import OnExceptionCheckpoint

        ft_checkpoints = [cb for cb in self.trainer.callbacks if isinstance(cb, OnExceptionCheckpoint)]
        fn = state_fn.value
        if ckpt_path is None and ft_checkpoints and self.trainer.state.fn == TrainerFn.FITTING:
            ckpt_path = "last"
            rank_zero_warn(
                f"`.{fn}(ckpt_path=None)` was called without a model."
                " The last model of the previous `fit` call will be used."
                f" You can pass `{fn}(ckpt_path='best')` to use the best model or"
                f" `{fn}(ckpt_path='last')` to use the last model."
                " If you pass a value, this warning will be silenced."
            )

        if model_provided and ckpt_path is None:
            # use passed model to function without loading weights
            return None

        if model_connected and ckpt_path is None:
            ckpt_path = "best"
            ft_tip = (
                " There is also an on-exception checkpoint available, however it is used by default only when fitting."
                if ft_checkpoints
                else ""
            )
            rank_zero_warn(
                f"`.{fn}(ckpt_path=None)` was called without a model."
                " The best model of the previous `fit` call will be used."
                + ft_tip
                + f" You can pass `.{fn}(ckpt_path='best')` to use the best model or"
                f" `.{fn}(ckpt_path='last')` to use the last model."
                " If you pass a value, this warning will be silenced."
            )

        if ckpt_path == "best":
            if len(self.trainer.checkpoint_callbacks) > 1:
                rank_zero_warn(
                    f'`.{fn}(ckpt_path="best")` is called with Trainer configured with multiple `ModelCheckpoint`'
                    " callbacks. It will use the best checkpoint path from first checkpoint callback."
                )

            if not self.trainer.checkpoint_callback:
                raise ValueError(f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured.')

            has_best_model_path = self.trainer.checkpoint_callback.best_model_path
            if hasattr(self.trainer.checkpoint_callback, "best_model_path") and not has_best_model_path:
                if self.trainer.fast_dev_run:
                    raise ValueError(
                        f'You cannot execute `.{fn}(ckpt_path="best")` with `fast_dev_run=True`.'
                        f" Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`"
                    )
                raise ValueError(
                    f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
                )
            # load best weights
            ckpt_path = getattr(self.trainer.checkpoint_callback, "best_model_path", None)

        elif ckpt_path == "last":
            candidates = {getattr(ft, "ckpt_path", None) for ft in ft_checkpoints}
            for callback in self.trainer.checkpoint_callbacks:
                if isinstance(callback, ModelCheckpoint):
                    candidates |= callback._find_last_checkpoints(self.trainer)
            candidates_fs = {path: get_filesystem(path) for path in candidates if path}
            candidates_ts = {path: fs.modified(path) for path, fs in candidates_fs.items() if fs.exists(path)}
            if not candidates_ts:
                # not an error so it can be set and forget before the first `fit` run
                rank_zero_warn(
                    f'.{fn}(ckpt_path="last") is set, but there is no last checkpoint available.'
                    " No checkpoint will be loaded."
                )
                return None
            ckpt_path = max(candidates_ts, key=candidates_ts.get)  # type: ignore[arg-type]

        elif ckpt_path == "hpc":
            if not self._hpc_resume_path:
                raise ValueError(
                    f'`.{fn}(ckpt_path="hpc")` is set but no HPC checkpoint was found.'
                    " Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`"
                )
            ckpt_path = self._hpc_resume_path

        if not ckpt_path:
            raise ValueError(
                f"`.{fn}()` found no path for the best weights: {ckpt_path!r}. Please"
                f" specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`"
            )
        return ckpt_path

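I wondered what the _user_managed attribute of trainer._checkpoint_connector was for, and pl.Trainer shows its purpose. Assigning trainer.ckpt_path = {YOUR_CKPT_PATH} means the user manages the ckpt path directly: the path is stored in _checkpoint_connector._ckpt_path and _user_managed is set to True. In _select_ckpt_path:

  • if _user_managed is True and a ckpt_path is also passed to fit, Lightning warns, resets _ckpt_path/_user_managed, and runs _parse_ckpt_path on the passed path (the fit argument wins);
  • if _user_managed is True and fit gets no ckpt_path, the user-set path (_checkpoint_connector._ckpt_path) is returned as-is;
  • if _user_managed is False, _parse_ckpt_path selects the path.

_parse_ckpt_path is the CheckpointConnector method that turns ckpt_path into an actual file path depending on the trainer configuration. A concrete path is returned unchanged. If ckpt_path is None, it first becomes "hpc" (on SLURM with an HPC resume path available) or "last" (while fitting with an OnExceptionCheckpoint callback). If it is still None and a model was passed, None is returned so the given model is used without loading weights; if only a previously connected model exists, it falls back to "best". Finally, "best", "last", and "hpc" are resolved to the matching checkpoint path (best_model_path of the first checkpoint callback, the most recently modified last checkpoint, or the HPC resume path).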
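The two methods can be condensed into plain functions to see which path wins in each case. This is a hypothetical simplification: SLURM detection is reduced to "an HPC path exists", the "last" lookup (which in Lightning picks the most recently modified candidate) is reduced to a single last_path argument, and the warnings are left out.

```python
from typing import Any, Optional


def parse_ckpt_path(
    ckpt_path: Optional[str],
    *,
    model_provided: bool = True,
    model_connected: bool = False,
    fitting: bool = True,
    hpc_path: Optional[str] = None,
    has_on_exception_ckpt: bool = False,
    best_path: Optional[str] = None,
    last_path: Optional[str] = None,
) -> Optional[str]:
    """Hypothetical simplification of _CheckpointConnector._parse_ckpt_path."""
    if ckpt_path is None and hpc_path is not None:  # Lightning also requires SLURM
        ckpt_path = "hpc"
    if ckpt_path is None and has_on_exception_ckpt and fitting:
        ckpt_path = "last"
    if model_provided and ckpt_path is None:
        return None  # use the passed model without loading weights
    if model_connected and ckpt_path is None:
        ckpt_path = "best"
    if ckpt_path == "best":
        if not best_path:
            raise ValueError('ckpt_path="best" but no best checkpoint is available')
        ckpt_path = best_path
    elif ckpt_path == "last":
        if not last_path:
            return None  # Lightning warns and loads nothing
        ckpt_path = last_path
    elif ckpt_path == "hpc":
        if not hpc_path:
            raise ValueError('ckpt_path="hpc" but no HPC checkpoint was found')
        ckpt_path = hpc_path
    if not ckpt_path:
        raise ValueError("no checkpoint path found")
    return ckpt_path  # any concrete path passes through unchanged


def select_ckpt_path(user_path: Optional[str], fit_path: Optional[str], **kw: Any) -> Optional[str]:
    """Hypothetical simplification of _CheckpointConnector._select_ckpt_path."""
    if user_path:  # trainer.ckpt_path was set, i.e. _user_managed is True
        if fit_path:  # fit(ckpt_path=...) wins; Lightning warns and resets the state
            return parse_ckpt_path(fit_path, **kw)
        return user_path
    return parse_ckpt_path(fit_path, **kw)


print(select_ckpt_path(None, None))  # None
print(select_ckpt_path("user.ckpt", "fit.ckpt"))  # fit.ckpt
print(select_ckpt_path(None, None, has_on_exception_ckpt=True, last_path="exc.ckpt"))  # exc.ckpt
```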

For trainer._run, I'll include only the parts that are needed.

# lightning.pytorch.trainer.trainer.py
class Trainer(...):
    ...
    def _run(
        self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
    ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
        ...
        # ----------------------------
        # SET UP THE TRAINER
        # ----------------------------
        ...
        # check if we should delay restoring checkpoint till later
        if not self.strategy.restore_checkpoint_after_setup:
            log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
            self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
        ... setup ...
        if self.strategy.restore_checkpoint_after_setup:
            log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
            self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
        ...
        return results

The only difference between the two branches is whether the checkpoint is restored before or after setup, which is decided by the `trainer.strategy.restore_checkpoint_after_setup` attribute.
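Here is a toy function that records which step runs first, to make the ordering concrete. It is not Lightning code, and `run` and the step labels are hypothetical. As far as I can tell, the base `Strategy` returns False for this property and `DeepSpeedStrategy` returns True. That fits its `load_checkpoint` override further down, which calls `self.deepspeed_engine`, and that engine only exists after setup.

```python
from typing import List


def run(restore_checkpoint_after_setup: bool) -> List[str]:
    """Toy model of the ordering in `Trainer._run`; returns the sequence of steps."""
    steps: List[str] = []

    def restore() -> None:
        steps.append("restore")

    if not restore_checkpoint_after_setup:
        restore()  # default: restore before the strategy sets up the model
    steps.append("setup")
    if restore_checkpoint_after_setup:
        restore()  # strategies that need their setup done before loading
    return steps


print(run(False))  # ['restore', 'setup']
print(run(True))   # ['setup', 'restore']
```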

# lightning.pytorch.trainer.connectors.checkpoint_connector.py
class _CheckpointConnector:
    ...
    def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
        """Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:

        1. from HPC weights if `checkpoint_path` is ``None`` and on SLURM or passed keyword `"hpc"`.
        2. from fault-tolerant auto-saved checkpoint if found
        3. from `checkpoint_path` file if provided
        4. don't restore

        """
        self._ckpt_path = checkpoint_path
        if not checkpoint_path:
            log.debug("`checkpoint_path` not specified. Skipping checkpoint loading.")
            return

        rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}")
        with pl_legacy_patch():
            loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
        self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path)
    ...
    def restore_model(self) -> None:
        """Restores a model's weights from a PyTorch Lightning checkpoint.

        Hooks are called first to give the LightningModule a chance to modify the contents, then finally the model gets
        updated with the loaded weights.

        """
        if not self._loaded_checkpoint:
            return

        trainer = self.trainer
        # hook: give user access to checkpoint if needed.
        call._call_lightning_module_hook(trainer, "on_load_checkpoint", self._loaded_checkpoint)

        # restore model state_dict
        trainer.strategy.load_model_state_dict(self._loaded_checkpoint)
    ...
    def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
        # restore modules after setup
        self.resume_start(checkpoint_path)
        self.restore_model()
        self.restore_datamodule()
        if self.trainer.state.fn == TrainerFn.FITTING:
            # restore callback states
            self.restore_callbacks()
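The order in `_restore_modules_and_callbacks` is: load the file, call the `on_load_checkpoint` hook, load the weights, restore the datamodule, and restore callbacks only when fitting. The toy below walks through that sequence. `ToyModule` and the step labels are hypothetical stand-ins. The key-renaming hook shows why the hook runs before the weights are loaded: it can still edit the checkpoint dict in place.

```python
from typing import Any, Dict, List, Optional


class ToyModule:
    """Hypothetical stand-in for a LightningModule (not Lightning's class)."""

    def __init__(self) -> None:
        self.weights: Dict[str, float] = {"w": 0.0}

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # The hook runs before the weights are loaded, so it can edit the
        # checkpoint in place, e.g. rename a key from an older model version.
        state_dict = checkpoint["state_dict"]
        if "old_w" in state_dict:
            state_dict["w"] = state_dict.pop("old_w")

    def load_state_dict(self, state_dict: Dict[str, float]) -> None:
        self.weights.update(state_dict)


def restore_modules_and_callbacks(
    module: ToyModule, checkpoint: Optional[Dict[str, Any]], fitting: bool
) -> List[str]:
    """Toy version of the restore order; returns the steps that ran."""
    steps: List[str] = []
    if checkpoint is None:  # nothing was loaded, so every restore step is skipped
        return steps
    module.on_load_checkpoint(checkpoint)
    steps.append("on_load_checkpoint")
    module.load_state_dict(checkpoint["state_dict"])
    steps.append("load_state_dict")
    steps.append("restore_datamodule")  # placeholder for the datamodule step
    if fitting:
        steps.append("restore_callbacks")  # callback state only matters when resuming fit
    return steps


model = ToyModule()
print(restore_modules_and_callbacks(model, {"state_dict": {"old_w": 1.5}}, fitting=False))
# ['on_load_checkpoint', 'load_state_dict', 'restore_datamodule']
print(model.weights)  # {'w': 1.5}
```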

# lightning.pytorch.strategies.strategy.py
class Strategy(ABC):
    ...
    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
        torch.cuda.empty_cache()
        return self.checkpoint_io.load_checkpoint(checkpoint_path)

    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        assert self.lightning_module is not None
        self.lightning_module.load_state_dict(checkpoint["state_dict"])

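The flow is the same as what we saw earlier in `load_from_checkpoint`. The difference is that it goes through the `load_checkpoint` and `load_model_state_dict` methods of `trainer.strategy`. This seems to be designed that way because some strategies override these methods; the DeepSpeed strategy, for example, overrides them as follows.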
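Routing both calls through the strategy is what makes this override possible. In the toy below, the connector-side `restore` function never changes, but an engine-style strategy loads the weights inside `load_checkpoint`. It returns only the remaining "client state" and turns `load_model_state_dict` into a no-op. All class and function names are hypothetical, and the dict "disk" replaces real files. As I understand it, DeepSpeed's client state likewise excludes `state_dict`, so the base `load_model_state_dict` would fail on it.

```python
from typing import Any, Dict, Mapping


class BaseStrategy:
    """Hypothetical stand-in for `Strategy`: read the file, then load the weights."""

    def __init__(self, storage: Dict[str, Dict[str, Any]]) -> None:
        self.storage = storage  # fake "disk": path -> checkpoint dict
        self.model_state: Dict[str, Any] = {}

    def load_checkpoint(self, path: str) -> Dict[str, Any]:
        return self.storage[path]

    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        self.model_state = dict(checkpoint["state_dict"])


class EngineStrategy(BaseStrategy):
    """Hypothetical DeepSpeed-like strategy: the engine restores weights itself."""

    def load_checkpoint(self, path: str) -> Dict[str, Any]:
        ckpt = self.storage[path]
        self.model_state = dict(ckpt["state_dict"])  # the "engine" loads weights here
        # hand back only the client state, without "state_dict"
        return {k: v for k, v in ckpt.items() if k != "state_dict"}

    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        pass  # weights were already restored in load_checkpoint


def restore(strategy: BaseStrategy, path: str) -> None:
    # the caller only talks to the strategy, so both variants work unchanged
    checkpoint = strategy.load_checkpoint(path)
    strategy.load_model_state_dict(checkpoint)


disk = {"a.ckpt": {"state_dict": {"w": 1.0}, "epoch": 3}}
for strategy in (BaseStrategy(disk), EngineStrategy(disk)):
    restore(strategy, "a.ckpt")
    print(type(strategy).__name__, strategy.model_state)
# BaseStrategy {'w': 1.0}
# EngineStrategy {'w': 1.0}
```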

# lightning.pytorch.strategies.deepspeed.py
class DeepSpeedStrategy(DDPStrategy):
    strategy_name = "deepspeed"
    DEEPSPEED_ENV_VAR = "PL_DEEPSPEED_CONFIG_PATH"
    ...
    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
        if self.load_full_weights and self.zero_stage_3:
            # Broadcast to ensure we load from the rank 0 checkpoint
            # This doesn't have to be the case when using deepspeed sharded checkpointing
            checkpoint_path = self.broadcast(checkpoint_path)
            return super().load_checkpoint(checkpoint_path)

        _validate_checkpoint_directory(checkpoint_path)

        # Rely on deepspeed to load the checkpoint and necessary information
        assert self.lightning_module is not None

        from lightning.pytorch.trainer.states import TrainerFn

        is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING

        _, client_state = self.deepspeed_engine.load_checkpoint(
            checkpoint_path, load_optimizer_states=is_fitting, load_lr_scheduler_states=False
        )
        if client_state is None:
            raise MisconfigurationException(
                "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint "
                "or a single checkpoint file with `Trainer(strategy=DeepSpeedStrategy(load_full_weights=True))`."
            )
        return client_state
    ...
    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint()`
        if self.load_full_weights and self.zero_stage_3:
            self.model_to_device()
            self._restore_zero_state(checkpoint)