facebookresearch / nougat

Implementation of Nougat Neural Optical Understanding for Academic Documents
https://facebookresearch.github.io/nougat/
MIT License

Fine-tuning base model fails with KeyError "pytorch-lightning_version" #165

Closed OrianeN closed 11 months ago

OrianeN commented 11 months ago

I'm trying to run a dummy fine-tuning with a dataset I've created using only your arXiv paper as a sample (train = val; this is just to try out the training pipeline).

I've created a config YAML file that looks like:

resume_from_checkpoint_path: "/path/to/nougat/models/0.1.0-base/pytorch_model.bin" 
result_path: "result_dummy_paper"
model_path: null
dataset_paths: [
  "/path/to/nougat-exp/dataset-generation-1sample/train.jsonl",  
  "/path/to/nougat-exp/dataset-generation-1sample/validation.jsonl"
]
tokenizer: "/path/to/nougat/models/0.1.0-base/tokenizer.json"
exp_name: "nougat_dummy_paper"
train_batch_sizes: [1]
num_workers: 4
val_batch_sizes: [1]
val_batches: 1
input_size: [896, 672]
max_length: 4096
max_position_embeddings: 4096
accumulate_grad_batches: 1
window_size: 7
patch_size: 4
embed_dim: 128
hidden_dimension: 1024
num_heads: [4, 8, 16, 32]
encoder_layer: [2, 2, 14, 2]
decoder_layer: 10
align_long_axis: False
num_nodes: 1
seed: 25
lr: 5e-5
min_lr: 7.5e-6
lr_step: 16
gamma: 0.9996
warmup_steps: 3
num_training_samples_per_epoch: 10
max_epochs: 1
max_steps: -1
val_check_interval: null
check_val_every_n_epoch: 1
gradient_clip_val: 0.5
verbose: True

Then I launched python3 nougat/train.py --config train_nougat.yaml --debug, but got a KeyError: 'pytorch-lightning_version':

resume_from_checkpoint_path: /path/to/nougat/models/0.1.0-base/pytorch_model.bin 
result_path: result_dummy_paper
model_path: None
dataset_paths: 
  - /path/to/nougat-exp/dataset-generation-1sample/train.jsonl
  - /path/to/nougat-exp/dataset-generation-1sample/validation.jsonl

tokenizer: /path/to/nougat/models/0.1.0-base/tokenizer.json
exp_name: nougat_dummy_paper
train_batch_sizes: 
  - 1
num_workers: 4
val_batch_sizes: 
  - 1
val_batches: 1
input_size: 
  - 896
  - 672
max_length: 4096
max_position_embeddings: 4096
accumulate_grad_batches: 1
window_size: 7
patch_size: 4
embed_dim: 128
hidden_dimension: 1024
num_heads: 
  - 4
  - 8
  - 16
  - 32
encoder_layer: 
  - 2
  - 2
  - 14
  - 2
decoder_layer: 10
align_long_axis: False
num_nodes: 1
seed: 25
lr: 5e-05
min_lr: 7.5e-06
lr_step: 16
gamma: 0.9996
warmup_steps: 3
num_training_samples_per_epoch: 10
max_epochs: 1
max_steps: -1
val_check_interval: None
check_val_every_n_epoch: 1
gradient_clip_val: 0.5
verbose: True
debug: True
job: None
exp_version: 20231027_164844
Config is saved at result_dummy_paper/nougat_dummy_paper/20231027_164844/config.yaml
Global seed set to 25
/usr/local/lib/python3.10/dist-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
[rank: 0] Global seed set to 25
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
[rank: 1] Global seed set to 25
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
[rank: 2] Global seed set to 25
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
[rank: 3] Global seed set to 25
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
(the torch.meshgrid UserWarning above is repeated once per rank)
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 4 processes
----------------------------------------------------------------------------------------------------

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/callbacks/model_checkpoint.py:617: UserWarning: Checkpoint directory /nas-labs/OCR/experiments/nougat-exp/result_dummy_paper/nougat_dummy_paper/20231027_164844 exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
Restoring states from the checkpoint path at /path/to/nougat/models/0.1.0-base/pytorch_model.bin

Traceback (most recent call last):
  File "/.../nougat/train.py", line 238, in <module>
    train(config)
  File "/.../nougat/train.py", line 208, in train
    trainer.fit(
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 532, in fit
    call._call_and_handle_interrupt(
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 42, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 93, in launch
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 571, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 946, in _run
    self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 399, in _restore_modules_and_callbacks
    self.resume_start(checkpoint_path)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 84, in resume_start
    self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/utilities/migration/utils.py", line 142, in _pl_migrate_checkpoint
    old_version = _get_version(checkpoint)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/utilities/migration/utils.py", line 163, in _get_version
    return checkpoint["pytorch-lightning_version"]
KeyError: 'pytorch-lightning_version'
(the same traceback is printed, interleaved, by each of the four distributed ranks)

I have pytorch-lightning version "2.0.9.post0", but I've also tried other versions (e.g. 2.0.0) and got the same error. nougat-ocr version: 0.1.17.
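For context, a checkpoint that Lightning can resume from is a dict of metadata keys ("pytorch-lightning_version", "state_dict", "optimizer_states", ...), while the released pytorch_model.bin is a bare state_dict, which is exactly what _get_version trips over. A quick way to confirm this (a sketch; path as in the config above):

import torch

# The released file is a raw state_dict: its keys are parameter names,
# not Lightning checkpoint metadata such as "pytorch-lightning_version".
ckpt = torch.load("/path/to/nougat/models/0.1.0-base/pytorch_model.bin", map_location="cpu")
print("pytorch-lightning_version" in ckpt)  # False
print(list(ckpt.keys())[:3])                # parameter names, e.g. encoder/decoder weights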

I've tried adding a version entry to the checkpoint's state_dict, following this PyTorch discussion, but I still get the same error. Here's the modified CustomCheckpointIO.load_checkpoint method in train.py:

def load_checkpoint(self, path, storage_options=None):
    """
    Load a checkpoint from the specified path.

    Args:
        `path` (str): The path from which the checkpoint will be loaded.
        `storage_options` (dict, optional): Additional storage options.
    """
    path = Path(path)

    if path.is_file():
        print("path:", path, path.is_dir())
        ckpt = torch.load(path)
        if "state_dict" not in ckpt:
            ckpt["state_dict"] = {
                "model." + key: value
                for key, value in torch.load(
                    path.parent / "pytorch_model.bin"
                ).items()
            }
    else:
        ckpt = torch.load(path / "artifacts.ckpt")
        state_dict = torch.load(path / "pytorch_model.bin")
        ckpt["state_dict"] = {
            "model." + key: value for key, value in state_dict.items()
        }

    print("Custom loaded ckpt, adding pl version")
    ckpt["state_dict"]["pytorch-lightning_version"] = pl.__version__

    return ckpt
OrianeN commented 11 months ago

It seems that the CustomCheckpointIO instance custom_ckpt was not used. The PyTorch Lightning documentation suggests passing it to the Trainer constructor:

custom_ckpt = CustomCheckpointIO()
[...]
trainer = pl.Trainer(
    [...]
    callbacks=[
        lr_callback,
        grad_norm_callback,
        checkpoint_callback,
        GradientAccumulationScheduler({0: config.accumulate_grad_batches}),
    ],
    plugins=[custom_ckpt]
)
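Passed via plugins=, the plugin makes the Trainer route all checkpoint save/load calls through CustomCheckpointIO instead of Lightning's default checkpoint I/O.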

Adding that to train.py changed the output slightly: the custom loading function's print statements now appear:

[...]
Restoring states from the checkpoint path at /nas-labs/OCR/experiments/nougat-exp/nougat/models/0.1.0-base/pytorch_model.bin
path: /.../nougat/models/0.1.0-base/pytorch_model.bin False
Custom loaded ckpt, adding pl version
Traceback (most recent call last):
  File "/.../nougat/train.py", line 239, in <module>
    train(config)
  File "/.../nougat/train.py", line 209, in train
    trainer.fit(
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 532, in fit
    call._call_and_handle_interrupt(
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 42, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 93, in launch
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 571, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 946, in _run
    self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 399, in _restore_modules_and_callbacks
    self.resume_start(checkpoint_path)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 84, in resume_start
    self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/utilities/migration/utils.py", line 142, in _pl_migrate_checkpoint
    old_version = _get_version(checkpoint)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/utilities/migration/utils.py", line 163, in _get_version
    return checkpoint["pytorch-lightning_version"]
KeyError: 'pytorch-lightning_version'
OrianeN commented 11 months ago

Update: setting the version at the top level of the checkpoint dict, i.e. ckpt["pytorch-lightning_version"] = pl.__version__ (rather than inside ckpt["state_dict"] as above), in CustomCheckpointIO.load_checkpoint() made the checkpoint load, but afterwards I got this error:

KeyError: 'Trying to restore optimizer state but checkpoint contains only the model. This is probably due to ModelCheckpoint.save_weights_only being set to True.'
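(The released pytorch_model.bin holds only model weights, with no optimizer or scheduler state, so Lightning refuses to resume a full training state from it.)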

I guess the released model isn't meant to be fine-tuned, so I'll close this issue.
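For anyone landing here: since the released file is a plain state_dict, a weights-only fine-tune (fresh optimizer, no Lightning resume) should still be possible. A minimal sketch, assuming the NougatModel.from_pretrained entry point that predict.py uses; whether train.py wires its model_path config key up the same way is an assumption to verify:

import torch
from nougat import NougatModel

# Sketch: initialise the model from the released weights (as predict.py
# does) instead of asking Lightning to resume trainer state from
# pytorch_model.bin via resume_from_checkpoint_path.
model = NougatModel.from_pretrained("/path/to/nougat/models/0.1.0-base")
model.train()

# Nothing to restore: the optimizer and scheduler start from scratch.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

Equivalently, pointing model_path at the released directory and leaving resume_from_checkpoint_path null in the YAML may be the intended route for fine-tuning, as opposed to resuming an interrupted run.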