ray-project / ray_lightning

Pytorch Lightning Distributed Accelerators using Ray
Apache License 2.0

Deterministic mode is not set on remote worker when `Trainer` is set to `deterministic` #213

Open MarkusSpanring opened 1 year ago

MarkusSpanring commented 1 year ago

🐛 Bug

Deterministic mode is not set on all workers when the `Trainer` is created with `deterministic=True`.

To Reproduce

The script is divided into two parts, test.py and model.py, to show the value of torch.backends.cudnn.deterministic in every worker and that its initial value is False.

#############
# test.py
#############
import torch

import pytorch_lightning as pl
from pytorch_lightning.strategies import Strategy
from pytorch_lightning import LightningModule, Trainer
from ray_lightning import RayStrategy
from model import BoringModel

def get_trainer(dir,
                strategy: Strategy,
                gpus=None,
                max_epochs: int = 1,
                limit_train_batches: int = 10,
                limit_val_batches: int = 10,
                **trainer_kwargs) -> Trainer:
    """Returns a Pytorch Lightning Trainer with the provided arguments."""

    trainer = pl.Trainer(
        default_root_dir=dir,
        gpus=gpus,
        strategy=strategy,
        max_epochs=max_epochs,
        deterministic=True,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        enable_progress_bar=False,
        **trainer_kwargs)
    return trainer

def train_test(trainer: Trainer, model: LightningModule):
    """Checks if training the provided model updates its weights."""
    initial_values = torch.tensor(
        [torch.sum(torch.abs(x)) for x in model.parameters()])
    trainer.fit(model)
    post_train_values = torch.tensor(
        [torch.sum(torch.abs(x)) for x in model.parameters()])
    assert trainer.state.finished, f"Trainer failed with {trainer.state}"
    # Check that the model is actually changed post-training.
    assert torch.norm(initial_values - post_train_values) > 0.1

def test_ray_train(tmpdir, num_workers):
    """Tests if training modifies model weights."""
    model = BoringModel()
    strategy = RayStrategy(num_workers=num_workers, use_gpu=True)
    trainer = get_trainer(tmpdir, strategy=strategy)
    train_test(trainer, model)

if __name__ == '__main__':
    test_ray_train("test", 1)
#############
# model.py
#############
import torch
from torch.utils.data import Dataset
from pytorch_lightning import LightningModule

print("Deterministic:", torch.backends.cudnn.deterministic)
torch.backends.cudnn.deterministic = True
print("Deterministic:", torch.backends.cudnn.deterministic)

class RandomDataset(Dataset):
    def __init__(self, size: int, length: int):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index: int):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.val_epoch = 0

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # Arbitrary loss to have a loss that updates the model weights
        # during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction,
                                            torch.ones_like(prediction))

    def step(self, x):
        x = self(x)
        out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
        return out

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]

    def train_dataloader(self):
        return torch.utils.data.DataLoader(RandomDataset(32, 64))

This leads to the following output:

MODEL Deterministic: False
MODEL Deterministic: True
2022-09-15 09:39:20,204 INFO worker.py:1518 -- Started a local Ray instance.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
(RayExecutor pid=1819295) MODEL Deterministic: False
(RayExecutor pid=1819295) MODEL Deterministic: True
(RayExecutor pid=1819295) /scratch/markus.spanring/conda/envs/storch/lib/python3.9/site-packages/pytorch_lightning/utilities/warnings.py:53: LightningDeprecationWarning: pytorch_lightning.utilities.warnings.rank_zero_deprecation has been deprecated in v1.6 and will be removed in v1.8. Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead.
(RayExecutor pid=1819295)   new_rank_zero_deprecation(
(RayExecutor pid=1819295) /scratch/markus.spanring/conda/envs/storch/lib/python3.9/site-packages/pytorch_lightning/utilities/warnings.py:58: LightningDeprecationWarning: ParallelStrategy.torch_distributed_backend was deprecated in v1.6 and will be removed in v1.8.
(RayExecutor pid=1819295)   return new_rank_zero_deprecation(*args, **kwargs)
(RayExecutor pid=1819295) Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
(RayExecutor pid=1819295) ----------------------------------------------------------------------------------------------------
(RayExecutor pid=1819295) distributed_backend=nccl
(RayExecutor pid=1819295) All distributed processes registered. Starting with 1 processes
(RayExecutor pid=1819295) ----------------------------------------------------------------------------------------------------
(RayExecutor pid=1819295) 
(RayExecutor pid=1819295) GPU available: True (cuda), used: True (Please ignore the previous info [GPU used: False]).
(RayExecutor pid=1819295) LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
(RayExecutor pid=1819295) 
(RayExecutor pid=1819295)   | Name  | Type   | Params
(RayExecutor pid=1819295) ---------------------------------
(RayExecutor pid=1819295) 0 | layer | Linear | 66    
(RayExecutor pid=1819295) ---------------------------------
(RayExecutor pid=1819295) 66        Trainable params
(RayExecutor pid=1819295) 0         Non-trainable params
(RayExecutor pid=1819295) 66        Total params
(RayExecutor pid=1819295) 0.000     Total estimated model params size (MB)

From this output one can see that model.py is loaded twice (once on the driver and once in the RayExecutor worker process) and that torch.backends.cudnn.deterministic is always False at import time.
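
The same behavior can be shown without Lightning at all, since the flags are per-process global state that Ray does not forward to its workers. A minimal sketch (the probe task is only illustrative and not part of the repro above):

```
import ray
import torch

# Set the flag in the driver process only.
torch.backends.cudnn.deterministic = True

@ray.remote
def probe():
    # Executed in a separate Ray worker process, which has its own copy of
    # the global PyTorch state.
    return (torch.backends.cudnn.deterministic,
            torch.are_deterministic_algorithms_enabled())

ray.init()
print("driver:", torch.backends.cudnn.deterministic)  # True
print("worker:", ray.get(probe.remote()))             # expected: (False, False)
```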

Expected behavior

Deterministic mode is set on all workers when running in distributed mode.

I know that PTL uses torch.use_deterministic_algorithms to set deterministic mode and that this does not set torch.backends.cudnn.deterministic. I stumbled over this behavior when I tried to compare the checkpoints of two DCGANs (lots of non-deterministic layers) that should have been identical. So I am positive that deterministic mode is not set on the remote worker.
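
For reference, the two switches are independent of each other, which can be checked on the driver alone (minimal sketch, no GPU required):

```
import torch

torch.use_deterministic_algorithms(True)
print(torch.are_deterministic_algorithms_enabled())  # True
print(torch.backends.cudnn.deterministic)            # still False

# The cuDNN flag has to be flipped explicitly.
torch.backends.cudnn.deterministic = True
print(torch.backends.cudnn.deterministic)            # True
```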

As a first workaround I have added

if trainer._accelerator_connector.deterministic:
    trainer._accelerator_connector._init_deterministic(True)

here

I am not sure if this is enough or if the accelerator needs to be initialized properly on each worker.
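
As a user-side stopgap (not a fix in ray_lightning itself), the flags can also be re-applied in a hook that Lightning runs in every training process, e.g. LightningModule.setup, which should also execute inside the Ray workers. A minimal sketch, with DeterministicBoringModel being a hypothetical name:

```
import torch
from model import BoringModel  # the BoringModel from model.py above


class DeterministicBoringModel(BoringModel):
    def setup(self, stage=None):
        # Runs in each process that participates in training, so the
        # deterministic flags are set inside the Ray workers as well,
        # not only in the driver process.
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
```

This does not address the underlying issue, though: whatever the Trainer derives from deterministic=True on the driver is still not forwarded to the Ray workers.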

Environment

```
* CUDA:
    - GPU:
        - NVIDIA A100-SXM4-80GB
        - NVIDIA A100-SXM4-80GB
        - NVIDIA A100-SXM4-80GB
        - NVIDIA A100-SXM4-80GB
    - available: True
    - version: 11.6
* Lightning:
    - lightning-bolts: 0.5.0
    - pytorch-lightning: 1.6.5
    - ray-lightning: 0.3.0
    - torch: 1.12.1
    - torch-fidelity: 0.3.0
    - torchdata: 0.4.1
    - torchmetrics: 0.9.3
    - torchtext: 0.13.1
    - torchvision: 0.13.1
* Packages:
    - absl-py: 1.2.0
    - aiohttp: 3.8.1
    - aiosignal: 1.2.0
    - anyio: 3.6.1
    - argon2-cffi: 21.3.0
    - argon2-cffi-bindings: 21.2.0
    - asttokens: 2.0.8
    - async-timeout: 4.0.2
    - attrs: 22.1.0
    - babel: 2.10.3
    - backcall: 0.2.0
    - beautifulsoup4: 4.11.1
    - bleach: 5.0.1
    - cachetools: 5.2.0
    - certifi: 2022.6.15
    - cffi: 1.15.1
    - cfgv: 3.3.1
    - charset-normalizer: 2.1.1
    - click: 8.0.4
    - configobj: 5.0.6
    - coverage: 6.4.4
    - cycler: 0.11.0
    - debugpy: 1.6.3
    - decorator: 5.1.1
    - defusedxml: 0.7.1
    - distlib: 0.3.6
    - entrypoints: 0.4
    - executing: 0.10.0
    - fastjsonschema: 2.16.1
    - filelock: 3.8.0
    - flake8: 5.0.4
    - flake8-docstrings: 1.5.0
    - flatten-json: 0.1.13
    - fonttools: 4.36.0
    - frozenlist: 1.3.1
    - fsspec: 2022.7.1
    - gitdb: 4.0.9
    - gitpython: 3.1.27
    - google-auth: 2.11.0
    - google-auth-oauthlib: 0.4.6
    - gputil: 1.4.0
    - grpcio: 1.43.0
    - identify: 2.5.3
    - idna: 3.3
    - importlib-metadata: 4.12.0
    - iniconfig: 1.1.1
    - ipykernel: 6.15.1
    - ipympl: 0.9.1
    - ipython: 8.4.0
    - ipython-genutils: 0.2.0
    - ipywidgets: 7.7.1
    - jedi: 0.18.1
    - jinja2: 3.1.2
    - joblib: 1.1.0
    - json5: 0.9.10
    - jsonschema: 4.12.1
    - jupyter-client: 7.3.4
    - jupyter-core: 4.11.1
    - jupyter-server: 1.18.1
    - jupyterlab: 3.4.5
    - jupyterlab-pygments: 0.2.2
    - jupyterlab-server: 2.15.1
    - jupyterlab-widgets: 3.0.1
    - kiwisolver: 1.4.4
    - lightning-bolts: 0.5.0
    - lxml: 4.9.1
    - markdown: 3.4.1
    - markupsafe: 2.1.1
    - matplotlib: 3.5.3
    - matplotlib-inline: 0.1.6
    - mccabe: 0.7.0
    - mistune: 0.8.4
    - mock: 4.0.3
    - msgpack: 1.0.4
    - multidict: 6.0.2
    - mypy: 0.971
    - mypy-extensions: 0.4.3
    - nbclassic: 0.4.3
    - nbclient: 0.6.6
    - nbconvert: 6.5.3
    - nbformat: 5.4.0
    - nest-asyncio: 1.5.5
    - nodeenv: 1.7.0
    - notebook: 6.4.12
    - notebook-shim: 0.1.0
    - numpy: 1.22.4
    - oauthlib: 3.2.0
    - overrides: 6.2.0
    - packaging: 21.3
    - pandas: 1.4.3
    - pandocfilters: 1.5.0
    - parso: 0.8.3
    - pexpect: 4.8.0
    - pickleshare: 0.7.5
    - pillow: 9.2.0
    - pip: 22.1.2
    - platformdirs: 2.5.2
    - pluggy: 1.0.0
    - portalocker: 2.5.1
    - pre-commit: 2.20.0
    - prometheus-client: 0.14.1
    - prompt-toolkit: 3.0.30
    - protobuf: 3.19.4
    - psutil: 5.9.1
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - py: 1.11.0
    - py-cpuinfo: 8.0.0
    - pyasn1: 0.4.8
    - pyasn1-modules: 0.2.8
    - pycodestyle: 2.9.1
    - pycparser: 2.21
    - pydeprecate: 0.3.2
    - pydocstyle: 6.1.1
    - pyflakes: 2.5.0
    - pygments: 2.13.0
    - pyparsing: 3.0.9
    - pyrsistent: 0.18.1
    - pytest: 7.1.2
    - pytest-benchmark: 3.4.1
    - pytest-cov: 3.0.0
    - pytest-mock: 3.8.2
    - pytest-mypy: 0.9.1
    - python-dateutil: 2.8.2
    - pytorch-lightning: 1.6.5
    - pytz: 2022.2.1
    - pyyaml: 6.0
    - pyzmq: 23.2.1
    - ray: 2.0.0
    - ray-lightning: 0.3.0
    - regex: 2022.8.17
    - requests: 2.28.1
    - requests-oauthlib: 1.3.1
    - rsa: 4.9
    - sacremoses: 0.0.41
    - scipy: 1.7.3
    - send2trash: 1.8.0
    - setuptools: 63.4.1
    - six: 1.16.0
    - smmap: 5.0.0
    - sniffio: 1.2.0
    - snowballstemmer: 2.2.0
    - soupsieve: 2.3.2.post1
    - stack-data: 0.4.0
    - tabulate: 0.8.10
    - tensorboard: 2.10.0
    - tensorboard-data-server: 0.6.1
    - tensorboard-plugin-wit: 1.8.1
    - tensorboardx: 2.5.1
    - terminado: 0.15.0
    - tinycss2: 1.1.1
    - toml: 0.10.2
    - tomli: 2.0.1
    - torch: 1.12.1
    - torch-fidelity: 0.3.0
    - torchdata: 0.4.1
    - torchmetrics: 0.9.3
    - torchtext: 0.13.1
    - torchvision: 0.13.1
    - tornado: 6.2
    - tqdm: 4.64.0
    - traitlets: 5.3.0
    - typing-extensions: 4.3.0
    - urllib3: 1.26.12
    - virtualenv: 20.16.3
    - vulture: 2.5
    - wcwidth: 0.2.5
    - webencodings: 0.5.1
    - websocket-client: 1.4.0
    - werkzeug: 2.2.2
    - wheel: 0.37.1
    - widgetsnbextension: 3.6.1
    - yarl: 1.8.1
    - zipp: 3.8.1
* System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.9.12
    - version: #49-Ubuntu SMP Thu Aug 4 18:03:25 UTC 2022
```