Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Support saving and loading from remote paths in Fabric #19113

Open claudio-alanaai opened 10 months ago

claudio-alanaai commented 10 months ago

Bug description

Hello everyone,

I am training a model using FSDP with Fabric.

When I save the model to an S3 bucket with the following call:

    fabric.save(
        "s3://model-training-us/models/experiment/checkpoint",
        state
    )

I get the following error:

╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ /vision-language-model/vlm/train_vlm_fabric.py:520 in <module>               │
│                                                                              │
│   517 │                                                                      │
│   518 │   cfg = Config(args)                                                 │
│   519 │                                                                      │
│ ❱ 520 │   main(cfg)                                                          │
│   521                                                                        │
│                                                                              │
│ /vision-language-model/vlm/train_vlm_fabric.py:387 in main                   │
│                                                                              │
│   384 │   │   "step_count": 0,                                               │
│   385 │   }                                                                  │
│   386 │                                                                      │
│ ❱ 387 │   fabric.save(                                                       │
│   388 │   │   "s3://model-training-us/models/experiment1/t │
│   389 │   │   state,                                                         │
│   390 │   )                                                                  │
│                                                                              │
│ /usr/lib/python3/dist-packages/lightning/fabric/fabric.py:738 in save        │
│                                                                              │
│    735 │   │   │   for k, v in filter.items():                               │
│    736 │   │   │   │   if not callable(v):                                   │
│    737 │   │   │   │   │   raise TypeError(f"Expected `fabric.save(filter=.. │
│ ❱  738 │   │   self._strategy.save_checkpoint(path=path, state=_unwrap_objec │
│    739 │   │   self.barrier()                                                │
│    740 │                                                                     │
│    741 │   def load(                                                         │
│                                                                              │
│ /usr/lib/python3/dist-packages/lightning/fabric/strategies/fsdp.py:498 in    │
│ save_checkpoint                                                              │
│                                                                              │
│   495 │   │   │   │   │   _apply_filter(key, filter or {}, converted, full_s │
│   496 │   │   │                                                              │
│   497 │   │   │   if self.global_rank == 0:                                  │
│ ❱ 498 │   │   │   │   torch.save(full_state, path)                           │
│   499 │   │   else:                                                          │
│   500 │   │   │   raise ValueError(f"Unknown state_dict_type: {self._state_d │
│   501                                                                        │
│                                                                              │
│ /usr/lib/python3/dist-packages/torch/serialization.py:618 in save            │
│                                                                              │
│    615 │   _check_save_filelike(f)                                           │
│    616 │                                                                     │
│    617 │   if _use_new_zipfile_serialization:                                │
│ ❱  618 │   │   with _open_zipfile_writer(f) as opened_zipfile:               │
│    619 │   │   │   _save(obj, opened_zipfile, pickle_module, pickle_protocol │
│    620 │   │   │   return                                                    │
│    621 │   else:                                                             │
│                                                                              │
│ /usr/lib/python3/dist-packages/torch/serialization.py:492 in                 │
│ _open_zipfile_writer                                                         │
│                                                                              │
│    489 │   │   container = _open_zipfile_writer_file                         │
│    490 │   else:                                                             │
│    491 │   │   container = _open_zipfile_writer_buffer                       │
│ ❱  492 │   return container(name_or_buffer)                                  │
│    493                                                                       │
│    494                                                                       │
│    495 def _is_compressed_file(f) -> bool:                                   │
│                                                                              │
│ /usr/lib/python3/dist-packages/torch/serialization.py:463 in __init__        │
│                                                                              │
│    460 │   │   │   self.file_stream = io.FileIO(self.name, mode='w')         │
│    461 │   │   │   super().__init__(torch._C.PyTorchFileWriter(self.file_str │
│    462 │   │   else:                                                         │
│ ❱  463 │   │   │   super().__init__(torch._C.PyTorchFileWriter(self.name))   │
│    464 │                                                                     │
│    465 │   def __exit__(self, *args) -> None:                                │
│    466 │   │   self.file_like.write_end_of_file()                            │
╰──────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Parent directory s3:/model-training-us/models/experiment does not exist.

The path s3://model-training-us/models/ exists.

Isn't saving to an S3 bucket supported by fabric.save? Or am I encountering a weird bug?
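For what it's worth, the failure seems to come from torch.save itself rather than from Fabric: the traceback ends in torch._C.PyTorchFileWriter, which expects a local filesystem path (or a file-like object). A minimal sketch, independent of Fabric, that I believe reproduces the same error (bucket name is just a placeholder):

    import torch

    # Placeholder bucket/key, only for illustration: torch.save hands the path to
    # torch._C.PyTorchFileWriter, which expects a local filesystem path (or a
    # file-like object), so an s3:// URI fails its parent-directory check.
    torch.save({"weights": torch.zeros(2, 2)}, "s3://some-bucket/checkpoint.pt")
    # RuntimeError: Parent directory s3:/some-bucket does not exist.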

Thank you in advance for your attention.

Kind regards, Claudio

What version are you seeing the problem on?

v2.1

How to reproduce the bug

No response

Error messages and logs

See the traceback in the bug description above.

Environment

More info

No response

cc @borda @awaelchli @carmocca @justusschock

awaelchli commented 10 months ago

Hey @claudio-alanaai

Saving to remote paths is simply not supported / implemented yet, so this isn't really a bug. Generally, if something isn't documented, you can assume it isn't supported. It's definitely something we want; someone just has to do it :)

(duplicate of #18786)
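In the meantime, one possible workaround (a rough sketch, assuming the `fabric` and `state` objects from your snippet and that `s3fs` is installed for fsspec's S3 backend) is to save the checkpoint to a local path and upload it yourself:

    import fsspec

    # Sketch: write the checkpoint locally first (works with FSDP as usual) ...
    local_path = "checkpoints/checkpoint.pt"
    fabric.save(local_path, state)

    # ... then copy it to S3 from rank 0 only. The bucket/key below mirror the
    # paths from the report and are placeholders.
    if fabric.global_rank == 0:
        fs = fsspec.filesystem("s3")  # backed by s3fs
        fs.put(local_path, "s3://model-training-us/models/experiment/checkpoint.pt")
    fabric.barrier()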