lanl / hippynn

python library for atomistic machine learning
https://lanl.github.io/hippynn/
Other

Proper device handling for restart #14

Closed (tautomer closed this 1 year ago)

tautomer commented 1 year ago

@lubbersnick I guess it's still too early to merge, for two reasons:

  1. More tests are needed. I haven't found any issues, but I may have missed something.
  2. The documentation update is still missing.

However, I decided to open this PR so that at least you can review the code, and someone might be able to test it.

The core part is basically the same as we discussed, but I split the code into multiple functions.

Additionally, I looked at experiment.routines.set_devices and mimicked its behavior:

        structure["training_modules"].model.to(model_device)
        structure["training_modules"].loss.to(model_device)
        structure["training_modules"].evaluator.model_device = model_device
        structure["training_modules"].evaluator.model = structure["training_modules"].model

Not sure if the last two lines are important or not.
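Whether the last two lines matter probably depends on whether `evaluator.model` already refers to the same object as `training_modules.model`. A quick check with plain PyTorch (no hippynn objects involved) shows that `nn.Module.to` converts parameters in place and returns the module itself, so rebinding a reference that already points to that module is harmless; it would only matter if the evaluator holds a separate module object. This sketch uses a dtype change as a stand-in for a device move so it runs without a GPU.

```python
import torch
from torch import nn

model = nn.Linear(4, 2)
param_before = next(model.parameters())

# Module.to() converts parameters in place and returns the module itself.
# float64 stands in for a device such as "cuda:0" so this runs on CPU only.
returned = model.to(torch.float64)

assert returned is model
# Same Parameter object as before; only its underlying data was converted.
param_after = next(model.parameters())
assert param_after is param_before
assert param_after.dtype == torch.float64
```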

I tested with a model trained on GPU 1. Here are the results:

| Option | Expected behavior | Actual behavior |
|---|---|---|
| load_checkpoint_from_cwd(map_location={"cuda:1": "cuda:0"}) | model on GPU 0 | model on GPU 0 |
| load_checkpoint_from_cwd(map_location=torch.device(0)) | failure because of rng_state | failure because of rng_state |
| load_checkpoint_from_cwd(model_device=2) | model on GPU 2 | model on GPU 2 |
| load_checkpoint_from_cwd(model_device="auto") | model on GPU 0 | model on GPU 0 |
| load_checkpoint_from_cwd(model_device="cpu") | model on CPU | model on CPU |
| load_checkpoint_from_cwd() | model on GPU 1 | model on GPU 1 |
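The `model_device` values in the table could be normalized into a `torch.device` along these lines. This is a minimal sketch, not hippynn's actual code: the helper name `resolve_model_device` is hypothetical, and the CPU fallback for `"auto"` when CUDA is unavailable is an assumption (the table only shows `"auto"` picking GPU 0).

```python
from typing import Optional, Union

import torch

DeviceSpec = Union[None, int, str, torch.device]


def resolve_model_device(model_device: DeviceSpec) -> Optional[torch.device]:
    """Hypothetical helper: map a model_device argument to a torch.device.

    None    -> None (keep whatever device the checkpoint was saved on)
    "auto"  -> cuda:0 if CUDA is available, else cpu (assumed fallback)
    int n   -> cuda:n
    other   -> passed to torch.device, e.g. "cpu" or "cuda:1"
    """
    if model_device is None:
        return None
    if isinstance(model_device, torch.device):
        return model_device
    if isinstance(model_device, str) and model_device == "auto":
        if torch.cuda.is_available():
            return torch.device("cuda", 0)
        return torch.device("cpu")
    if isinstance(model_device, int):
        # constructing a CUDA device object does not require a GPU
        return torch.device("cuda", model_device)
    return torch.device(model_device)
```

Returning `None` for the default lets the caller skip the move entirely, which matches the last row of the table (model stays on GPU 1).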

Similar tests were done for load_model_from_cwd as well.

One thing to note: restoring the database always keeps it on CPU, so my concern was unnecessary. Even passing map_location={"cuda:1": "cuda:0"} will move a database originally on GPU 1 to GPU 0.
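The `map_location=torch.device(0)` row in the table most likely fails because that setting also moves the saved CPU RNG state onto the GPU, while `torch.set_rng_state` only accepts a CPU `ByteTensor`. A minimal sketch of the safer ordering, using plain PyTorch and a hypothetical checkpoint key `"torch_rng_state"`: load to CPU, restore the RNG state, and only then move the model.

```python
import io

import torch

# Save a checkpoint-like dict containing the CPU RNG state.
# "torch_rng_state" is a hypothetical key name used for illustration.
buffer = io.BytesIO()
torch.save({"torch_rng_state": torch.get_rng_state()}, buffer)
expected = torch.rand(3)  # numbers drawn right after the state was captured

# Load onto CPU so the RNG state stays a CPU ByteTensor ...
buffer.seek(0)
checkpoint = torch.load(buffer, map_location="cpu")
# ... and restore it (a GPU-resident state tensor would be rejected here).
torch.set_rng_state(checkpoint["torch_rng_state"])
replayed = torch.rand(3)

# The RNG replays the same numbers; the model can be moved to its
# target device after this point.
assert torch.equal(expected, replayed)
```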

tautomer commented 1 year ago

Oops, there is a problem with actually restarting training.

I think the problem is that the exp_avg and exp_avg_sq tensors stored in the values of controller.optimizer.state need to be moved as well. The keys are also tensors; I'm not sure whether they should be moved. Let me test it.
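To see what `optimizer.state` holds, here is a toy plain-PyTorch example (not hippynn). After one Adam step, the state is keyed by the model's own `Parameter` objects, and each value is a dict with Adam's `step`, `exp_avg`, and `exp_avg_sq`. Since `Module.to` converts parameters in place by default, the keys presumably follow the model automatically; only the value tensors are separate and can be left behind.

```python
import torch
from torch import nn

model = nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# One training step so that Adam populates its per-parameter state.
loss = model(torch.randn(8, 3)).pow(2).mean()
loss.backward()
optimizer.step()

for param, state in optimizer.state.items():
    # Keys are the model's Parameter objects themselves.
    assert any(param is p for p in model.parameters())
    # Values hold Adam's moment buffers plus the step counter.
    print(sorted(state), tuple(state["exp_avg"].shape))
```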

tautomer commented 1 year ago

Ok, fixed.

But the code looks really ugly.

# move the moment buffers stored for each parameter to the model device
for _, v in structure["controller"].optimizer.state_dict()["state"].items():
    v["exp_avg"] = v["exp_avg"].to(model_device)
    v["exp_avg_sq"] = v["exp_avg_sq"].to(model_device)

@lubbersnick Is there a better way to achieve this?

lubbersnick commented 1 year ago

This is totally fragile because those parameters are specific to Adam... why are only some of the items on the device? do you understand that part?

tautomer commented 1 year ago

Hi, Nick. The trick in https://github.com/pytorch/pytorch/issues/8741#issuecomment-496907204 does work. It should work for all kinds of optimizers, although I haven't checked that explicitly...
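The trick from that PyTorch issue comment can be written as a small helper (the name `optimizer_to` is hypothetical). It walks every per-parameter state dict and moves whatever tensors it finds, so it does not depend on Adam's `exp_avg`/`exp_avg_sq` names. Note that it also moves scalar entries such as Adam's `step` counter.

```python
from typing import Union

import torch
from torch import nn


def optimizer_to(optimizer: torch.optim.Optimizer, device: Union[str, torch.device]) -> None:
    """Move every tensor in the optimizer's per-parameter state to `device`.

    Optimizer-agnostic: no Adam-specific keys are assumed.
    """
    for state in optimizer.state.values():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                # reassigning an existing key while iterating is safe
                state[key] = value.to(device)


# Toy usage on CPU; on a GPU machine use e.g. torch.device("cuda:0").
model = nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters())
model(torch.randn(4, 3)).sum().backward()
optimizer.step()

device = torch.device("cpu")
model.to(device)                 # parameters (the state keys) move in place
optimizer_to(optimizer, device)  # moment buffers and step counters follow
```

Working on `optimizer.state` directly also avoids relying on `state_dict()` handing back references to the live per-parameter dicts, which the loop above depends on.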

why are only some of the items on the device? do you understand that part?

Here is the full traceback of the error. I guess it's because the grad tensor is on the GPU, but the exp_avg tensor is on the CPU, so the multiplication fails. Unfortunately, I'm not totally sure why the other entries can stay on the CPU; I don't know enough about torch itself.
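A possible explanation, based on how PyTorch restores optimizers: `Optimizer.load_state_dict` casts each loaded state tensor to the device (and floating-point dtype) of the parameter it belongs to, but by default leaves the `step` counter as-is. So if the optimizer state is presumably loaded while the model is still on CPU and the model is moved afterwards, `exp_avg`/`exp_avg_sq` stay on CPU while `grad` is on the GPU, which matches the error in the traceback. Round-tripping the state dict after moving the model is one optimizer-agnostic way to re-sync. This is a sketch of the mechanism, not what the PR implements; a dtype change stands in for the device move so it runs without a GPU.

```python
import torch
from torch import nn

model = nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters())
model(torch.randn(4, 3)).sum().backward()
optimizer.step()

# Convert the model only after the optimizer state exists.
# (float64 stands in for a device change such as .to("cuda:0").)
model.to(torch.float64)
before = optimizer.state[model.weight]["exp_avg"].dtype  # buffer did not follow

# load_state_dict re-casts each state tensor (except "step") to its
# parameter's dtype and device, so a round trip re-syncs the buffers.
optimizer.load_state_dict(optimizer.state_dict())
after = optimizer.state[model.weight]["exp_avg"].dtype
print(before, after)
```

Ordering matters with this approach: the model has to be moved before the state dict is (re)loaded.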

RuntimeError                              Traceback (most recent call last)
Cell In [8], line 1
----> 1 metric_tracker = setup_and_train(
      2         training_modules=training_modules,
      3         database=database,
      4         setup_params=experiment_params,
      5     )

File ~/software/hippynn/hippynn/experiment/routines.py:128, in setup_and_train(training_modules, database, setup_params)
    123 training_modules, controller, metric_tracker = setup_training(
    124     training_modules=training_modules, setup_params=setup_params
    125 )
    127 # Actually do the training
--> 128 return train_model(
    129     training_modules=training_modules,
    130     database=database,
    131     controller=controller,
    132     metric_tracker=metric_tracker,
    133     callbacks=None,
    134     batch_callbacks=None,
    135 )

File ~/software/hippynn/hippynn/experiment/routines.py:301, in train_model(training_modules, database, controller, metric_tracker, callbacks, batch_callbacks, store_all_better, store_best, store_structure_file, store_metrics, quiet)
    298     serialization.create_structure_file(training_modules, database, controller)
    300 try:
--> 301     training_loop(
    302         training_modules=training_modules,
    303         database=database,
    304         controller=controller,
    305         metric_tracker=metric_tracker,
    306         callbacks=callbacks,
    307         batch_callbacks=batch_callbacks,
    308         store_best=store_best,
    309         store_all_better=store_all_better,
    310         quiet=quiet,
    311     )
    313 except KeyboardInterrupt:
    314     print("******* TRAINING INTERRUPTED *******")

File ~/software/hippynn/hippynn/experiment/routines.py:474, in training_loop(training_modules, database, controller, metric_tracker, callbacks, batch_callbacks, store_all_better, store_best, quiet)
    471 batch_train_loss = loss(*batch_model_outputs, *batch_targets)[0].mean()
    473 batch_train_loss.backward()
--> 474 optimizer.step()
    476 if batch_callbacks:
    477     for cb in batch_callbacks:

File ~/.conda/envs/hippynn/lib/python3.10/site-packages/torch/optim/optimizer.py:113, in Optimizer._hook_for_profile.<locals>.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    111 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
    112 with torch.autograd.profiler.record_function(profile_name):
--> 113     return func(*args, **kwargs)

File ~/.conda/envs/hippynn/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File ~/.conda/envs/hippynn/lib/python3.10/site-packages/torch/optim/adam.py:157, in Adam.step(self, closure)
    153                 max_exp_avg_sqs.append(state['max_exp_avg_sq'])
    155             state_steps.append(state['step'])
--> 157     adam(params_with_grad,
    158          grads,
    159          exp_avgs,
    160          exp_avg_sqs,
    161          max_exp_avg_sqs,
    162          state_steps,
    163          amsgrad=group['amsgrad'],
    164          beta1=beta1,
    165          beta2=beta2,
    166          lr=group['lr'],
    167          weight_decay=group['weight_decay'],
    168          eps=group['eps'],
    169          maximize=group['maximize'],
    170          foreach=group['foreach'],
    171          capturable=group['capturable'])
    173 return loss

File ~/.conda/envs/hippynn/lib/python3.10/site-packages/torch/optim/adam.py:213, in adam(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, foreach, capturable, amsgrad, beta1, beta2, lr, weight_decay, eps, maximize)
    210 else:
    211     func = _single_tensor_adam
--> 213 func(params,
    214      grads,
    215      exp_avgs,
    216      exp_avg_sqs,
    217      max_exp_avg_sqs,
    218      state_steps,
    219      amsgrad=amsgrad,
    220      beta1=beta1,
    221      beta2=beta2,
    222      lr=lr,
    223      weight_decay=weight_decay,
    224      eps=eps,
    225      maximize=maximize,
    226      capturable=capturable)
...
--> 262 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    263 exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
    265 if capturable:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

lubbersnick commented 1 year ago

p.s. please update the changelog before PR is done

tautomer commented 1 year ago

p.s. please update the changelog before PR is done

Oops, I didn't do it, but I guess there will be another round before the PR is really finalized.

# -- Project information -----------------------------------------------------

project = "hippynn"
copyright = "2019, Los Alamos National Laboratory"
author = "Nicholas Lubbers"

# The full version, including alpha/beta/rc tags
release = "0.0.1b2"

I think the year and release number in docs/source/conf.py should be changed as well.

I updated the docstrings of everything in serialization.py, so every function should now have type hints for both inputs and outputs. I also added a section to examples/restarting.rst. Do these make sense to you?

lubbersnick commented 1 year ago

Do not mess with the copyright year... LANL files copyright.

tautomer commented 1 year ago

Do not mess with the copyright year... LANL files copyright.

Really? I thought the year should be updated every year. I'll only bump the version number to 0.0.1b4, then.

Let me know if you have any comments on the PR itself. I will update the changelog.

tautomer commented 1 year ago

Done. The changelog update is still missing, as I plan to add it in the GH Pages PR anyway.

It fully works in my tests, although these tests don't cover all situations. I guess we will only find more bugs once we actually use the code.

The code below can serve as a semi-automatic test. These assertions should cover all the objects that have to be on a specific device. Since we haven't started building the test suite yet, I don't know where to keep this code, so I'm posting it here.

import torch
from hippynn.experiment.serialization import load_checkpoint_from_cwd

# must be a torch.device (not a string) so the equality assertions below work
model_device = torch.device("cpu")
checkpoint = load_checkpoint_from_cwd(model_device=model_device)
training_modules = checkpoint["training_modules"]
controller = checkpoint["controller"]

# model should be on `model_device`
assert next(training_modules.model.parameters()).device == model_device

# evaluator.model should be on `model_device`
assert next(training_modules.evaluator.model.parameters()).device == model_device

# loss should be on `model_device`
assert next(training_modules.loss.parameters()).device == model_device

# evaluator.loss should always be on "cpu"
assert next(training_modules.evaluator.loss.parameters()).device == torch.device("cpu")

# training_modules.evaluator.model_device must be set properly,
# otherwise some tensors will be sent to the device assigned before restarting
assert training_modules.evaluator.model_device == model_device

# the optimizer is trickier: its state_dict is a nested structure of
# dicts and lists, so walk it and check every non-scalar tensor
def check_optimizer(obj, target_device: torch.device) -> None:
    if isinstance(obj, torch.Tensor):
        # ignore scalar tensors (e.g. Adam's "step"),
        # which are kept on CPU by default
        if obj.dim() > 0:
            assert obj.device == target_device
    elif isinstance(obj, dict):
        for v in obj.values():
            check_optimizer(v, target_device)
    elif isinstance(obj, (list, tuple)):
        for v in obj:
            check_optimizer(v, target_device)
    # anything else (ints, floats, bools, None) carries no device


check_optimizer(controller.optimizer.state_dict(), model_device)