Open jinmang2 opened 1 year ago
Note that: 아래 코드 분석은 lightning 2.0.7 버전을 기준으로 분석한 결과이기 때문에 차후 경로 및 소스코드에 변경이 있을 수 있음.
보통 학습이 종료된 이후 pl.Trainer
의 save_checkpoint
method를 활용하여 저장하거나 callbacks에 pl.callbacks.ModelCheckpoint
를 활용해서 저장하게 된다.
pl.callbacks.ModelCheckpoint
에선 last/non-monitor/update-best 등 세팅된 인자를 기반으로 어느 시점이든 checkpoint를 저장하게 되는데, 전부 공통적으로 아래 method를 활용한다.
# lightning.pytorch.callbacks.model_checkpoint.py
# Excerpt (lightning 2.0.7): the single choke point through which every
# ModelCheckpoint save path (last / non-monitor / best-update) writes a file.
class ModelCheckpoint(Checkpoint):
...
def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
# Delegate the actual serialization to the Trainer.
trainer.save_checkpoint(filepath, self.save_weights_only)
# Record the global step at which the last checkpoint was written.
self._last_global_step_saved = trainer.global_step
# notify loggers
# Only the global-rank-0 process notifies loggers; `proxy` is presumably
# a weakref.proxy to avoid keeping the callback alive — verify upstream.
if trainer.is_global_zero:
for logger in trainer.loggers:
logger.after_save_checkpoint(proxy(self))
즉, 중요한 것은 pl.Trainer
의 save_checkpoint
method. 이 또한 뜯어보면,
# lightning.pytorch.trainer.trainer.py
# Excerpt (lightning 2.0.7): Trainer-level checkpoint save entry point.
class Trainer:
...
def save_checkpoint(
self, filepath: _PATH, weights_only: bool = False, storage_options: Optional[Any] = None
) -> None:
r"""Runs routine to create a checkpoint.
Args:
filepath: Path where checkpoint is saved.
weights_only: If ``True``, will only save the model weights.
storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin
Raises:
AttributeError:
If the model is not attached to the Trainer before calling this method.
"""
if self.model is None:
raise AttributeError(
"Saving a checkpoint is only possible if a model is attached to the Trainer. Did you call"
" `Trainer.save_checkpoint()` before calling `Trainer.{fit,validate,test,predict}`?"
)
# Assemble the checkpoint dict from all component states...
checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
# ...then let the (possibly distributed) strategy persist it.
self.strategy.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
# Synchronize ranks so no process continues before the save completes.
self.strategy.barrier("Trainer.save_checkpoint")
여기서 self.strategy
는 ddp 등의 분산학습 전략을 의미하는 것으로 보이고 실제로 checkpoint를 받아오는 것은 self._checkpoint_connector
의 dump_checkpoint
method로 확인된다. 해당 코드에서도 중요한 부분만 추출하여 분석하면,
# lightning.pytorch.trainer.trainer.py
# Excerpt (lightning 2.0.7): where the checkpoint connector is created.
class Trainer:
@_defaults_from_env_vars
def __init__(
self,
...
) -> None:
...
# The connector is instantiated internally — it cannot be injected via a
# Trainer constructor argument; replacing it requires subclassing Trainer.
self._checkpoint_connector = _CheckpointConnector(self)
...
...
# lightning.pytorch.trainer.connectors.checkpoint_connector.py
# Excerpt (lightning 2.0.7): builds the in-memory checkpoint dict.
class _CheckpointConnector:
...
def dump_checkpoint(self, weights_only: bool = False) -> dict:
"""Creating a model checkpoint dictionary object from various component states.
Args:
weights_only: saving model weights only
Return:
structured dictionary: {
'epoch': training epoch
'global_step': training global step
'pytorch-lightning_version': The version of PyTorch Lightning that produced this checkpoint
'callbacks': "callback specific state"[] # if not weights_only
'optimizer_states': "PT optim's state_dict"[] # if not weights_only
'lr_schedulers': "PT sched's state_dict"[] # if not weights_only
'state_dict': Model's state_dict (e.g. network weights)
precision_plugin.__class__.__qualname__: precision plugin state_dict # if not weights_only
CHECKPOINT_HYPER_PARAMS_NAME:
CHECKPOINT_HYPER_PARAMS_KEY:
CHECKPOINT_HYPER_PARAMS_TYPE:
something_cool_i_want_to_save: anything you define through model.on_save_checkpoint
LightningDataModule.__class__.__qualname__: pl DataModule's state
}
"""
trainer = self.trainer
model = trainer.lightning_module
datamodule = trainer.datamodule
# Base payload — always saved regardless of `weights_only`.
checkpoint = {
# the epoch and global step are saved for compatibility but they are not relevant for restoration
"epoch": trainer.current_epoch,
"global_step": trainer.global_step,
"pytorch-lightning_version": pl.__version__,
"state_dict": self._get_lightning_module_state_dict(),
"loops": self._get_loops_state_dict(),
}
# Everything below is skipped when only weights are requested.
if not weights_only:
# dump callbacks
checkpoint["callbacks"] = call._call_callbacks_state_dict(trainer)
optimizer_states = []
for i, optimizer in enumerate(trainer.optimizers):
# Rely on accelerator to dump optimizer state
optimizer_state = trainer.strategy.optimizer_state(optimizer)
optimizer_states.append(optimizer_state)
checkpoint["optimizer_states"] = optimizer_states
# dump lr schedulers
lr_schedulers = []
for config in trainer.lr_scheduler_configs:
lr_schedulers.append(config.scheduler.state_dict())
checkpoint["lr_schedulers"] = lr_schedulers
# precision plugin
prec_plugin = trainer.precision_plugin
prec_plugin_state_dict = prec_plugin.state_dict()
if prec_plugin_state_dict:
checkpoint[prec_plugin.__class__.__qualname__] = prec_plugin_state_dict
prec_plugin.on_save_checkpoint(checkpoint)
# dump hyper-parameters
# Hyper-parameters of both the LightningModule and the DataModule are
# stored under their CHECKPOINT_HYPER_PARAMS_* keys.
for obj in (model, datamodule):
if obj and obj.hparams:
if hasattr(obj, "_hparams_name"):
checkpoint[obj.CHECKPOINT_HYPER_PARAMS_NAME] = obj._hparams_name
# dump arguments
if _OMEGACONF_AVAILABLE and isinstance(obj.hparams, Container):
checkpoint[obj.CHECKPOINT_HYPER_PARAMS_KEY] = obj.hparams
checkpoint[obj.CHECKPOINT_HYPER_PARAMS_TYPE] = type(obj.hparams)
else:
checkpoint[obj.CHECKPOINT_HYPER_PARAMS_KEY] = dict(obj.hparams)
# dump stateful datamodule
if datamodule is not None:
datamodule_state_dict = call._call_lightning_datamodule_hook(trainer, "state_dict")
if datamodule_state_dict:
checkpoint[datamodule.__class__.__qualname__] = datamodule_state_dict
# on_save_checkpoint hooks
# These hooks run LAST, after the dict is fully assembled, so user code
# can mutate/override anything collected above (e.g. pop "state_dict").
if not weights_only:
# if state is returned from callback's on_save_checkpoint
# it overrides the returned state from callback's state_dict
# support for returning state in on_save_checkpoint
# will be removed in v1.8
call._call_callbacks_on_save_checkpoint(trainer, checkpoint)
call._call_lightning_module_hook(trainer, "on_save_checkpoint", checkpoint)
return checkpoint
위 코드를 분석하여 아래의 사실을 확인할 수 있었다.
epoch
, global_step
, pytorch-lightning_version
, state_dict
, loops
가 계산된다.callbacks
, optimizer_states
, lr_schedulers
, precision_plugin 등 또한 checkpoint dict에 넘겨진다.pl.LightningModule
, pl.LightningDataModule
의 hparams 또한 해당 object의 CHECKPOINT_HYPER_PARAMS_{}.format(NAME | KEY | TYPE)
의 key에 기록된다lightning.pytorch.trainer.call.py
의 _call_callbacks_on_save_checkpoint
로 trainer에 등록된 checkpoint들의 on_save_checkpoint
method를 전부 수행해준다.lightning.pytorch.trainer.call.py
의 _call_lightning_module_hook
으로 lightningmodule에서 CheckpointHooks
를 상속받아서 가지고 있는 on_save_checkpoint
method을 override해서 저장할 내역을 작성했다면 해당 custom 함수를 실행시켜준다.현재까지의 분석으로
_CheckpointConnector
가 기본적으로 가지고 있는 _get_lightning_module_state_dict
method로 얻어온다.
_CheckpointConnector는
pl.Trainer`의 인자로 줄 수 없다. 상속해서 생성자에서 건드려야한다.on_save_checkpoint
는 제일 마지막에 수행된다. 유사하게 callbacks들의 on_save_checkpoint
또한 마지막에 수행된다.이를 hf-style로 구현할 방법으로는,
pl.Trainer
를 상속한 HuggingfaceTrainer
를 만들고 _CheckpointConnector
의 dump_checkpoint를 수정한 객체를 생성자에 주입한다.pl.Trainer
의 self._checkpoint_connector
의 dump_checkpoint
를 wrapping하여 원하는 동작을 수행하도록 수정한다.on_save_checkpoint
method에서 기존 checkpoint의 state_dict
를 pop하고 self.model.save_pretrained
를 수행해준다. checkpoint에는 path만 넘겨준다.어떤 방식으로 구현할지는 load_checkpoint
를 어떻게 수행하는지에 달렸다.
chatgpt에게 lightning에서 checkpoint를 loading하는 방법들에 대해 물어봤다.
load_from_checkpoint
클래스 메서드 사용
LightningModule
에서 제공하는 클래스 메서드로, 체크포인트 파일 경로를 직접 지정하여 모델을 로드할 수 있음.
model = MyModel.load_from_checkpoint(checkpoint_path="path/to/checkpoint.ckpt")
Trainer
의 resume_from_checkpoint
매개변수 사용
Trainer
객체를 생성할 때, resume_from_checkpoint
매개변수를 사용하여 checkpoint file path를 지정할 수 있음
trainer = pl.Trainer(resume_from_checkpoint="path/to/checkpoint.ckpt")
trainer.fit(model, dataloader)
ModelCheckpoint
콜백과 함께 사용
ModelCheckpoint
콜백을 사용하면, 최상의 체크포인트나 최근 체크포인트 등을 자동으로 관리할 수 있으며, 이를 통해 나중에 쉽게 로드할 수 있음
checkpoint_callback = pl.callbacks.ModelCheckpoint(...)
trainer = pl.Trainer(callbacks=[checkpoint_callback])
trainer.fit(model, dataloader)
best_model_path = checkpoint_callback.best_model_path model = MyModel.load_from_checkpoint(best_model_path)
4. 수동으로 `torch.load` 사용
- PyTorch Lightning의 체크포인트는 사전 형식으로 저장되므로, 필요한 경우 PyTorch의 기본 `torch.load` 함수를 사용하여 수동으로 로드할 수도 있음
```python
checkpoint = torch.load("path/to/checkpoint.ckpt")
model.load_state_dict(checkpoint['state_dict'])
```
재차 chatgpt에게 물어본 결과, pl 1.x 버전에서 동작하는 방법들이라고 언급했기에 2.x에서의 사용 여부와 docs를 추가적으로 살펴보며 분석하고자 한다.
우선, 2번의 resume_from_checkpoint
의 경우, lightning의 issue 9501에서 deprecated되었고 fit의 ckpt_path
를 통해 제어하도록 수정된 것을 확인할 수 있었다.
Issue 9006에서 Trainer의 arguments를 최소화하고자 수정을 한 것으로 보인다.
위를 다시 정리하면, 1. naive하게 torch.load
를 사용해서 직접 custom 2. LightningModule
의 load_from_checkpoint
메서드 활용 3. pl.Trainer
의 fit
method의 ckpt_path
인자를 넣어서. 이렇게 총 세 가지의 방식을 활용할 수 있다.
본 레포는 lightning
에 대해 분석하고 활용할 수 있는 능력을 기르는 것에도 목적이 있기에 torch.load
를 활용해서 직접 구현하는 방법은 지양하고 2와 3에 대해 source code 동작을 뜯어보고자 한다.
load_from_checkpoint
pl.LightningModule
의 load_from_checkpoint
classmethod는 매우 단순하다.
# lightning.pytorch.core.module.py
...
from typing import cast
...
from typing_extensions import Self
...
from lightning.pytorch.core.saving import _load_from_checkpoint
...
# Excerpt (lightning 2.0.7): the classmethod is a thin wrapper around the
# module-level `_load_from_checkpoint` helper.
class LightningModule(
_DeviceDtypeModuleMixin,
HyperparametersMixin,
ModelHooks,
DataHooks,
CheckpointHooks,
Module,
):
...
@classmethod
def load_from_checkpoint(
cls,
checkpoint_path: Union[_PATH, IO],
map_location: _MAP_LOCATION_TYPE = None,
hparams_file: Optional[_PATH] = None,
strict: bool = True,
**kwargs: Any,
) -> Self:
# All real work happens in the shared helper (also used by DataModules).
loaded = _load_from_checkpoint(
cls,
checkpoint_path,
map_location,
hparams_file,
strict,
**kwargs,
)
# The helper is typed for both LightningModule and LightningDataModule,
# so narrow the return type to `Self` for callers here.
return cast(Self, loaded)
_load_from_checkpoint
는 아래와 같다.
# lightning.pytorch.core.saving.py
# Excerpt (lightning 2.0.7): load a checkpoint file, migrate its format,
# merge hparams, then instantiate the module/datamodule via `_load_state`.
def _load_from_checkpoint(
cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
checkpoint_path: Union[_PATH, IO],
map_location: _MAP_LOCATION_TYPE = None,
hparams_file: Optional[_PATH] = None,
strict: Optional[bool] = None,
**kwargs: Any,
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
map_location = map_location or _default_map_location
# `pl_legacy_patch` temporarily registers legacy artifacts so checkpoints
# written by older lightning versions can still be deserialized.
with pl_legacy_patch():
checkpoint = pl_load(checkpoint_path, map_location=map_location)
# convert legacy checkpoints to the new format
checkpoint = _pl_migrate_checkpoint(
checkpoint, checkpoint_path=(checkpoint_path if isinstance(checkpoint_path, (str, Path)) else None)
)
# An explicit hparams file (csv/yaml) overrides whatever was saved.
if hparams_file is not None:
extension = str(hparams_file).split(".")[-1]
if extension.lower() == "csv":
hparams = load_hparams_from_tags_csv(hparams_file)
elif extension.lower() in ("yml", "yaml"):
hparams = load_hparams_from_yaml(hparams_file)
else:
raise ValueError(".csv, .yml or .yaml is required for `hparams_file`")
# overwrite hparams by the given file
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams
# TODO: make this a migration:
# for past checkpoint need to add the new key
checkpoint.setdefault(cls.CHECKPOINT_HYPER_PARAMS_KEY, {})
# override the hparams with values that were passed in
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
if issubclass(cls, pl.LightningDataModule):
return _load_state(cls, checkpoint, **kwargs)
if issubclass(cls, pl.LightningModule):
model = _load_state(cls, checkpoint, strict=strict, **kwargs)
state_dict = checkpoint["state_dict"]
if not state_dict:
rank_zero_warn(f"The state dict in {checkpoint_path!r} contains no parameters.")
return model
# Move the model to the device the saved tensors live on (defaults to
# CPU when the state_dict holds no tensors).
device = next((t for t in state_dict.values() if isinstance(t, torch.Tensor)), torch.tensor(0)).device
assert isinstance(model, pl.LightningModule)
return model.to(device)
raise NotImplementedError(f"Unsupported {cls}")
코드가 길어보이지만 실상은 첫 줄의 pl_load
함수로 checkpoint를 호출하고(pl_legacy_patch
로 이전 버전에서 저장된 checkpoint를 현재 버전에서도 load할 수 있도록하는 patch를 적용) legacy ckpt일 경우를 대비해 new format으로 바꿔주는 _pl_migrate_checkpoint
함수를 적용한 다음, 만일 hparams_file
이 입력으로 들어왔다면 이를 읽고 checkpoint에 추가, default key를 설정해주고 _load_state
함수를 사용하여 datamodule/lightningmodule을 instantiate하고 이를 반환해주는 단순한 코드이다.
즉, 첫 줄의 pl_load
가 제일 중요한 부분이다. (실제로 checkpoint가 불러져오는 부분임)
하나씩 분석하자. 우선 pl_load
에서 어떻게 checkpoint file을 불러오는지 확인하자
# lightning.fabric.utilities.cloud_io.py
# Excerpt (lightning 2.0.7, aliased as `pl_load`): dispatches on the input —
# file-like object -> torch.load, http URL -> torch.hub, else fsspec path.
def _load(
path_or_url: Union[IO, _PATH],
map_location: _MAP_LOCATION_TYPE = None,
) -> Any:
"""Loads a checkpoint.
Args:
path_or_url: Path or URL of the checkpoint.
map_location: a function, ``torch.device``, string or a dict specifying how to remap storage locations.
"""
if not isinstance(path_or_url, (str, Path)):
# any sort of BytesIO or similar
return torch.load(
path_or_url,
map_location=map_location, # type: ignore[arg-type] # upstream annotation is not correct
)
if str(path_or_url).startswith("http"):
return torch.hub.load_state_dict_from_url(
str(path_or_url),
map_location=map_location, # type: ignore[arg-type]
)
# Resolve the path's protocol through fsspec so remote filesystems
# (s3, gcs, ...) open the same way as local files.
fs = get_filesystem(path_or_url)
with fs.open(path_or_url, "rb") as f:
return torch.load(f, map_location=map_location) # type: ignore[arg-type]
torch.load
는 f로 file-like object(read
, readline
, tell
, seek
method가 구현되어있는) 혹은 str/os.PathLike object(file name을 포함하는)을 받는다. pl_load
는 file-like object인 경우 바로 torch.load
에 태워주고 hub에서 받아오도록 path가 http로 시작하면 torch.hub.load_state_dict_from_url
메서드로 state_dict를 받아온다. 마지막으로 fsspec의 url_to_fs
함수를 사용하여 protocol을 판별하고 이를 context manager 역할을 수행할 수 있는 filesystem class로 반환하고 해당 path의 file을 열고 torch.load
로 state_dict를 호출해준다. pl_legacy_patch
는 위에서 언급했듯이 old checkpoint에는 존재하나 현재는 사용하지 않는 legacy artifacts를 잠시 register하고 load가 끝난 후에는 해제해주는 context manager이다.
언급했듯 _load
는 torch.load
를 수행하는 그 이상도 이하도 아니며 pl_legacy_patch
를 활용하여 이전 버전의 checkpoint도 문제없이 호출할 수 있도록 코드를 수행했지만 현재 사용하는 version의 format으로 맞춰줄 필요가 있다. 이를 migrate_checkpoint
메서드가 수행해준다. (lightning.pytorch.utilities.migration.utils.py 참고)
torch.load
로 in-memory에 checkpoint 파일을 부르고 LightningModule
로 checkpoint 중 state_dict던지 필요한 부분을 불러야할 필요가 있다. 이를 lightning에서는 _load_state
함수에서 처리하고 있고 아래와 같이 동작한다.
# lightning.pytorch.core.saving.py
# Excerpt (lightning 2.0.7): instantiate the (Data)Module from a checkpoint —
# reconstruct constructor kwargs from saved hparams, build the object, run
# `on_load_checkpoint`, then load the state_dict (modules only).
def _load_state(
cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
checkpoint: Dict[str, Any],
strict: Optional[bool] = None,
**cls_kwargs_new: Any,
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
# Introspect the constructor signature to know which kwargs it accepts.
cls_spec = inspect.getfullargspec(cls.__init__)
cls_init_args_name = inspect.signature(cls.__init__).parameters.keys()
# Drop self/*args/**kwargs names from the accepted-parameter list.
self_var, args_var, kwargs_var = parse_class_init_keys(cls)
drop_names = [n for n in (self_var, args_var, kwargs_var) if n]
cls_init_args_name = list(filter(lambda n: n not in drop_names, cls_init_args_name))
cls_kwargs_loaded = {}
# pass in the values we saved automatically
if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
if issubclass(cls, pl.LightningModule):
# TODO: make this a migration:
# 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys
for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS:
cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {}))
# 2. Try to restore model hparams from checkpoint using the new key
cls_kwargs_loaded.update(checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_KEY, {}))
# 3. Ensure that `cls_kwargs_old` has the right type, back compatibility between dict and Namespace
cls_kwargs_loaded = _convert_loaded_hparams(cls_kwargs_loaded, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))
# 4. Update cls_kwargs_new with cls_kwargs_old, such that new has higher priority
args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
if args_name and args_name in cls_init_args_name:
cls_kwargs_loaded = {args_name: cls_kwargs_loaded}
# Caller-supplied kwargs (cls_kwargs_new) win over saved hparams.
_cls_kwargs = {}
_cls_kwargs.update(cls_kwargs_loaded)
_cls_kwargs.update(cls_kwargs_new)
if not cls_spec.varkw:
# filter kwargs according to class init unless it allows any argument via kwargs
_cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name}
obj = cls(**_cls_kwargs)
if isinstance(obj, pl.LightningModule):
# give model a chance to load something
obj.on_load_checkpoint(checkpoint)
# DataModules only restore their own state_dict and return early.
if isinstance(obj, pl.LightningDataModule):
if obj.__class__.__qualname__ in checkpoint:
obj.load_state_dict(checkpoint[obj.__class__.__qualname__])
return obj
# load the state_dict on the model automatically
assert strict is not None
keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)
# With strict=False, surface mismatches as warnings instead of errors.
if not strict:
if keys.missing_keys:
rank_zero_warn(
f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}"
)
if keys.unexpected_keys:
rank_zero_warn(
f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}"
)
return obj
LightningModule
혹은 LightningDataModule
의 생성자의 arguments들을 파이썬의 기본 내장 inspect
라이브러리를 활용하여 cls_init_args_name
을 얻는다.
LightningModule
은 HyperparametersMixin
을 상속받아 save_hyperparameters
method를 사용할 수 있다. 생성자에 주어진 args/kwargs를 self.hparams에 등록하는데 이는 자동으로 checkpoint에 저장되고 CHECKPOINT_HYPER_PARAMS_{}.format(KEY|TYPE|NAME)
을 통해 접근할 수 있다. 중간 부분은 해당 부분을 불러오는 부분이다. (상세 동작 추후 확인). 이는 LightningModule
을 호출할 때 필요한 인자이기 때문에 _cls_kwargs에 할당한 다음 object를 instantiate한다. LightningDataModule
의 경우 해당 객체의 load_state_dict
메서드를 수행하고 끝나고 현재 관심사가 아니기 때문에 LightningModule
만 살펴본다.
LightningModule
object를 생성하고 _load_state
에선 아래 두 method를 수행한다.
# give model a chance to load something
obj.on_load_checkpoint(checkpoint)
# load the state_dict on the model automatically
assert strict is not None
keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)
on_load_checkpoint
는 user custom 함수로 checkpoint 파일에서 받아올 부분을 직접 작성하여 어떤 key를 받아와서 객체에 저장할 지 결정할 수 있다.
load_state_dict
는 LightningModule
이 nn.Module
을 상속받기 때문에 torch의 Module의 load_state_dict
method를 그대로 수행해서 model에 넘겨준다.
최종적으로 _load_from_checkpoint
함수에서는 checkpoint의 state_dict의 device로 lightning module의 device를 할당해주고 반환한다.
lightning.pytorch.Trainer(...).fit(model=model, ckpt_path={YOUR_CKPT_PATH})
pl.Trainer.fit
의 동작은 추상화되어있어 참으로 단순해보인다.
# lightning.pytorch.trainer.trainer.py
# Excerpt (lightning 2.0.7): public fit entry point.
class Trainer(...):
...
def fit(
self,
model: "pl.LightningModule",
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
ckpt_path: Optional[str] = None,
) -> None:
# Unwrap a possibly torch.compile-d model and attach it to the strategy.
model = _maybe_unwrap_optimized(model)
self.strategy._lightning_module = model
_verify_strategy_supports_compile(model, self.strategy)
# The real work: `_fit_impl` runs under the error-handling wrapper that
# deals with KeyboardInterrupt / tuner exits / other exceptions.
call._call_and_handle_interrupt(
self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
)
앞의 세 줄은 model/strategy를 wrapping하고 verify하는 부분이고 실제로 중요한 부분은 마지막 줄이다.
lightning.pytorch.trainer.call.py
의 _call_and_handle_interrupt
함수는 pl.Trainer
의 main entry point인 fit
, validate
, test
, predict
함수에 대한 error handling을 위해 설계된 함수이다. trainer.strategy.launcher
가 있으면 해당 launcher를 사용하여 trainer_fn을 실행하고 그렇지 않으면 trainer_fn을 직접 호출한다. 에러는 아래의 3 종류에 따라 처리한다.
_TunerExitException
: teardown 호출 후 trainer status를 FINISHED로 설정KeyboardInterrupt
: 사용자가 process를 중단하려고 하면 graceful shutdown 시도. 이후 trainer status를 INTERRUPTED로 설정
on_exception
을 호출하여 사용자가 중단한 경우에 설정을 저장하거나 등을 수행이 가능한 것으로 보임BaseException
: trainer status를 INTERRUPTED로 설정하고 logger로 failed 설정즉, 중요한 부분은 self._fit_impl
.
# lightning.pytorch.trainer.trainer.py
# Excerpt (lightning 2.0.7): fit implementation — sets trainer state,
# attaches data, resolves the checkpoint path, then calls `_run`.
class Trainer(...):
...
def _fit_impl(
self,
model: "pl.LightningModule",
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
ckpt_path: Optional[str] = None,
) -> None:
log.debug(f"{self.__class__.__name__}: trainer fit stage")
self.state.fn = TrainerFn.FITTING
self.state.status = TrainerStatus.RUNNING
self.training = True
# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(train_dataloaders, LightningDataModule):
datamodule = train_dataloaders
train_dataloaders = None
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None:
raise MisconfigurationException(
"You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.fit(datamodule=...)`"
)
# links data to the trainer
self._data_connector.attach_data(
model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
)
# Resolve `ckpt_path` (exact path or special values) to an actual file path.
ckpt_path = self._checkpoint_connector._select_ckpt_path(
self.state.fn,
ckpt_path,
model_provided=True,
model_connected=self.lightning_module is not None,
)
self._run(model, ckpt_path=ckpt_path)
assert self.state.stopped
self.training = False
return
코드를 보면 training setup을 수행하고 self._run
을 통해 실질적인 동작을 수행하는 것을 확인할 수 있다. _run
을 확인하기 전에 _select_ckpt_path
가 어떤 동작을 수행하는지 확인해보자.
# lightning.pytorch.trainer.trainer.py
# Excerpt (lightning 2.0.7): stateful `ckpt_path` property on the Trainer.
class Trainer(...):
@_defaults_from_env_vars
def __init__(
self,
...
) -> None:
...
self._checkpoint_connector = _CheckpointConnector(self)
...
...
@property
def ckpt_path(self) -> Optional[_PATH]:
"""Set to the path/URL of a checkpoint loaded via :meth:`~lightning.pytorch.trainer.trainer.Trainer.fit`,
:meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`,
:meth:`~lightning.pytorch.trainer.trainer.Trainer.test`, or
:meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`. ``None`` otherwise."""
return self._checkpoint_connector._ckpt_path
@ckpt_path.setter
def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None:
"""Allows you to manage which checkpoint is loaded statefully.
.. code-block:: python
trainer = Trainer()
trainer.ckpt_path = "my/checkpoint/file.ckpt"
trainer.fit(model)
...
# you will be in charge of resetting this
trainer.ckpt_path = None
trainer.test(model)
"""
# Assigning through the setter both stores the path on the connector and
# flags it as user-managed (True for any non-empty path).
self._checkpoint_connector._ckpt_path = ckpt_path
self._checkpoint_connector._user_managed = bool(ckpt_path)
...
# lightning.pytorch.trainer.connectors.checkpoint_connector.py
# Excerpt (lightning 2.0.7): decides whether the user-managed path
# (`trainer.ckpt_path = ...`) or the `fit(ckpt_path=...)` argument wins.
class _CheckpointConnector:
...
def _select_ckpt_path(
self, state_fn: TrainerFn, ckpt_path: Optional[_PATH], model_provided: bool, model_connected: bool
) -> Optional[_PATH]:
"""Called by the ``Trainer`` to select the checkpoint path source."""
if self._user_managed:
# Both sources given: the explicit fit(ckpt_path=...) wins, with a warning.
if ckpt_path:
rank_zero_warn(
f"`trainer.ckpt_path = {self._ckpt_path!r}` was called but then you"
f" passed `trainer.fit(ckpt_path={ckpt_path!r})`. The latter will be loaded."
)
# reset the previous path
self._ckpt_path = None
self._user_managed = False
ckpt_path = self._parse_ckpt_path(
state_fn,
ckpt_path,
model_provided=model_provided,
model_connected=model_connected,
)
else:
# User-managed and no override: keep the stored path as-is.
ckpt_path = self._ckpt_path
else:
ckpt_path = self._parse_ckpt_path(
state_fn,
ckpt_path,
model_provided=model_provided,
model_connected=model_connected,
)
return ckpt_path
# Excerpt (lightning 2.0.7): resolves the special values "best"/"last"/"hpc"
# (or None) into an actual checkpoint file path; explicit paths pass through.
def _parse_ckpt_path(
self, state_fn: TrainerFn, ckpt_path: Optional[_PATH], model_provided: bool, model_connected: bool
) -> Optional[_PATH]:
"""Converts the ``ckpt_path`` special values into an actual filepath, depending on the trainer
configuration."""
# On SLURM with an HPC auto-resume checkpoint available, default to "hpc".
if ckpt_path is None and SLURMEnvironment.detect() and self._hpc_resume_path is not None:
ckpt_path = "hpc"
from lightning.pytorch.callbacks.on_exception_checkpoint import OnExceptionCheckpoint
ft_checkpoints = [cb for cb in self.trainer.callbacks if isinstance(cb, OnExceptionCheckpoint)]
fn = state_fn.value
if ckpt_path is None and ft_checkpoints and self.trainer.state.fn == TrainerFn.FITTING:
ckpt_path = "last"
rank_zero_warn(
f"`.{fn}(ckpt_path=None)` was called without a model."
" The last model of the previous `fit` call will be used."
f" You can pass `{fn}(ckpt_path='best')` to use the best model or"
f" `{fn}(ckpt_path='last')` to use the last model."
" If you pass a value, this warning will be silenced."
)
if model_provided and ckpt_path is None:
# use passed model to function without loading weights
return None
if model_connected and ckpt_path is None:
ckpt_path = "best"
ft_tip = (
" There is also an on-exception checkpoint available, however it is used by default only when fitting."
if ft_checkpoints
else ""
)
rank_zero_warn(
f"`.{fn}(ckpt_path=None)` was called without a model."
" The best model of the previous `fit` call will be used."
+ ft_tip
+ f" You can pass `.{fn}(ckpt_path='best')` to use the best model or"
f" `.{fn}(ckpt_path='last')` to use the last model."
" If you pass a value, this warning will be silenced."
)
if ckpt_path == "best":
if len(self.trainer.checkpoint_callbacks) > 1:
rank_zero_warn(
f'`.{fn}(ckpt_path="best")` is called with Trainer configured with multiple `ModelCheckpoint`'
" callbacks. It will use the best checkpoint path from first checkpoint callback."
)
if not self.trainer.checkpoint_callback:
raise ValueError(f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured.')
# NOTE(review): the attribute is read before the hasattr() guard below —
# if `best_model_path` could be missing this would raise first; verify upstream.
has_best_model_path = self.trainer.checkpoint_callback.best_model_path
if hasattr(self.trainer.checkpoint_callback, "best_model_path") and not has_best_model_path:
if self.trainer.fast_dev_run:
raise ValueError(
f'You cannot execute `.{fn}(ckpt_path="best")` with `fast_dev_run=True`.'
f" Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`"
)
raise ValueError(
f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
)
# load best weights
ckpt_path = getattr(self.trainer.checkpoint_callback, "best_model_path", None)
elif ckpt_path == "last":
# Collect candidate "last" paths from exception callbacks and all
# ModelCheckpoint callbacks, then pick the most recently modified one.
candidates = {getattr(ft, "ckpt_path", None) for ft in ft_checkpoints}
for callback in self.trainer.checkpoint_callbacks:
if isinstance(callback, ModelCheckpoint):
candidates |= callback._find_last_checkpoints(self.trainer)
candidates_fs = {path: get_filesystem(path) for path in candidates if path}
candidates_ts = {path: fs.modified(path) for path, fs in candidates_fs.items() if fs.exists(path)}
if not candidates_ts:
# not an error so it can be set and forget before the first `fit` run
rank_zero_warn(
f'.{fn}(ckpt_path="last") is set, but there is no last checkpoint available.'
" No checkpoint will be loaded."
)
return None
ckpt_path = max(candidates_ts, key=candidates_ts.get) # type: ignore[arg-type]
elif ckpt_path == "hpc":
if not self._hpc_resume_path:
raise ValueError(
f'`.{fn}(ckpt_path="hpc")` is set but no HPC checkpoint was found.'
" Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`"
)
ckpt_path = self._hpc_resume_path
if not ckpt_path:
raise ValueError(
f"`.{fn}()` found no path for the best weights: {ckpt_path!r}. Please"
f" specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`"
)
return ckpt_path
trainer._checkpoint_connector
의 _user_managed
attr이 무엇인지 궁금했는데 pl.Trainer
에서 용도를 알 수 있었다. trainer.ckpt_path = {YOUR_CKPT_PATH}
와 같이 할당할 경우 user가 직접 ckpt path를 관리한다는 의미로 _user_managed
가 True로 설정됨과 동시에 _checkpoint_connector
의 ckpt_path에 할당한다. _user_managed
옵션이 True인데도 self.fit
함수에 ckpt_path
가 입력됐을 경우 원 상태로 복구하고 warning을 띄우고 _parse_ckpt_path
method를 실행한다. 아닐 경우에도 _parse_ckpt_path
를 실행하여 ckpt_path를 select하고 만일 trainer.fit
에 ckpt_path가 입력되지 않았고 _user_managed
옵션도 False일 경우 trainer._ckpt_path
를 select한다.
_parse_ckpt_path
는 ckpt_path
를 trainer configuration에 depending된 actual filepath로 변환하는 CheckpointConnector의 method. 만일 특정 ckpt_path를 입력했다면 아무런 동작도 하지않고 그대로 반환해주며 ckpt_path
가 None이면 trainer의 status, callbacks 등에 따라 hpc
, last
, best
를 할당하고 이 후 각자 설정에 맞는 checkpoint path를 가져온다.
trainer._run
은 짧게 필요한 부분만 담고자 한다.
# lightning.pytorch.trainer.trainer.py
# Excerpt (lightning 2.0.7, heavily elided): only the checkpoint-restore
# portions of `_run` are shown.
class Trainer(...):
...
def _run(
self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
...
# ----------------------------
# SET UP THE TRAINER
# ----------------------------
...
# check if we should delay restoring checkpoint till later
# The same restore call runs either before or after setup, depending on
# the strategy's `restore_checkpoint_after_setup` flag.
if not self.strategy.restore_checkpoint_after_setup:
log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
... setup ...
if self.strategy.restore_checkpoint_after_setup:
log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
...
return results
trainer.strategy.restore_checkpoint_after_setup
attr에 따라 setup 전에 restore checkpoint를 수행하냐 혹은 후에 하느냐 차이.
# lightning.pytorch.trainer.connectors.checkpoint_connector.py
# Excerpt (lightning 2.0.7): restore path — pre-load the checkpoint into
# memory, then restore model / datamodule / callbacks from it.
class _CheckpointConnector:
...
def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
"""Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:
1. from HPC weights if `checkpoint_path` is ``None`` and on SLURM or passed keyword `"hpc"`.
2. from fault-tolerant auto-saved checkpoint if found
3. from `checkpoint_path` file if provided
4. don't restore
"""
self._ckpt_path = checkpoint_path
if not checkpoint_path:
log.debug("`checkpoint_path` not specified. Skipping checkpoint loading.")
return
rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}")
# Loading goes through the strategy (not torch.load directly) so that
# strategies like DeepSpeed can override how checkpoints are read.
with pl_legacy_patch():
loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path)
...
def restore_model(self) -> None:
"""Restores a model's weights from a PyTorch Lightning checkpoint.
Hooks are called first to give the LightningModule a chance to modify the contents, then finally the model gets
updated with the loaded weights.
"""
if not self._loaded_checkpoint:
return
trainer = self.trainer
# hook: give user access to checkpoint if needed.
call._call_lightning_module_hook(trainer, "on_load_checkpoint", self._loaded_checkpoint)
# restore model state_dict
trainer.strategy.load_model_state_dict(self._loaded_checkpoint)
...
def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
# restore modules after setup
self.resume_start(checkpoint_path)
self.restore_model()
self.restore_datamodule()
# Callback state is only restored when fitting.
if self.trainer.state.fn == TrainerFn.FITTING:
# restore callback states
self.restore_callbacks()
# lightning.pytorch.strategies.strategy.py
# Excerpt (lightning 2.0.7): base-strategy defaults that subclasses
# (e.g. DeepSpeedStrategy) override.
class Strategy(ABC):
...
def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
# Free cached GPU memory before loading; delegate I/O to the plugin.
torch.cuda.empty_cache()
return self.checkpoint_io.load_checkpoint(checkpoint_path)
def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
assert self.lightning_module is not None
self.lightning_module.load_state_dict(checkpoint["state_dict"])
동작은 앞서 load_from_checkpoint
에서 본 과정과 동일하나 차이점은 trainer.strategy
가 가지고 있는 load_checkpoint
와 load_model_state_dict
method를 사용한다는 점이다. 이는 예를 들어 deepspeed strategy의 경우 아래와 같이 override하기에 이렇게 수정한 것으로 보인다.
# lightning.pytorch.strategies.deepspeed.py
# Excerpt (lightning 2.0.7): DeepSpeed overrides the base Strategy loading —
# the engine itself loads weights, so `load_model_state_dict` is mostly a no-op.
class DeepSpeedStrategy(DDPStrategy):
strategy_name = "deepspeed"
DEEPSPEED_ENV_VAR = "PL_DEEPSPEED_CONFIG_PATH"
...
def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
# Full-weights + ZeRO stage 3: fall back to the base implementation
# after syncing the path from rank 0.
if self.load_full_weights and self.zero_stage_3:
# Broadcast to ensure we load from the rank 0 checkpoint
# This doesn't have to be the case when using deepspeed sharded checkpointing
checkpoint_path = self.broadcast(checkpoint_path)
return super().load_checkpoint(checkpoint_path)
# Otherwise the path must be a DeepSpeed checkpoint directory.
_validate_checkpoint_directory(checkpoint_path)
# Rely on deepspeed to load the checkpoint and necessary information
assert self.lightning_module is not None
from lightning.pytorch.trainer.states import TrainerFn
is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
_, client_state = self.deepspeed_engine.load_checkpoint(
checkpoint_path, load_optimizer_states=is_fitting, load_lr_scheduler_states=False
)
if client_state is None:
raise MisconfigurationException(
"DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint "
"or a single checkpoint file with `Trainer(strategy=DeepSpeedStrategy(load_full_weights=True))`."
)
return client_state
...
def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
# override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint()`
if self.load_full_weights and self.zero_stage_3:
self.model_to_device()
self._restore_zero_state(checkpoint)
목적