lanl / hippynn

python library for atomistic machine learning
https://lanl.github.io/hippynn/

Fix a problem with torch RNG when restarting #12

Closed tautomer closed 1 year ago

tautomer commented 1 year ago

When a model is reloaded onto a different CUDA device, a TypeError: RNG state must be a torch.ByteTensor may be thrown.

For example, suppose the model was originally trained on GPU 1 and now needs to be loaded onto GPU 0. Per the torch docs, you can do load_checkpoint_from_cwd(map_location={'cuda:1': 'cuda:0'}), which works fine. Unfortunately, this approach requires users to know which GPU was used originally, which is obviously a problem for automation.

However, using load_checkpoint_from_cwd(map_location=torch.device(0)) or load_checkpoint_from_cwd(map_location=lambda storage, loc: storage.cuda(0)) will throw TypeError: RNG state must be a torch.ByteTensor.
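
For concreteness, a minimal sketch of the two call patterns (the import path is assumed from the serialization module referenced below):

import torch
from hippynn.experiment.serialization import load_checkpoint_from_cwd

# Works, but requires knowing that the original device was cuda:1.
check = load_checkpoint_from_cwd(map_location={"cuda:1": "cuda:0"})

# Fails with "TypeError: RNG state must be a torch.ByteTensor",
# because the saved CPU RNG tensor gets mapped onto cuda:0 as well.
check = load_checkpoint_from_cwd(map_location=torch.device(0))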

The relevant code is

hippynn.experiment.serialization.restore_checkpoint

torch.random.set_rng_state(state["torch_rng_state"])

and torch.random.set_rng_state

def set_rng_state(new_state: torch.Tensor) -> None:
    r"""Sets the random number generator state.

    .. note: This function only works for CPU. For CUDA, please use
             torch.manual_seed(seed), which works for both CPU and CUDA.

    Args:
        new_state (torch.ByteTensor): The desired state
    """
    default_generator.set_state(new_state)

state["torch_rng_state"] will be a tensor on GPU 0 for the latter two ways of map_location

tensor([13, 78, 72,  ...,  0,  0,  0], device='cuda:0', dtype=torch.uint8)

but it will stay on the CPU with {'cuda:1': 'cuda:0'}:

tensor([13, 78, 72,  ...,  0,  0,  0], dtype=torch.uint8)

Forcing the tensor to be transferred to the CPU solves the problem, and the RNG state lives on the CPU originally anyway.
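
A one-line sketch of this fix inside restore_checkpoint (moving the saved state back to the CPU before restoring it):

# move the (possibly remapped) RNG state back to the CPU before restoring it
torch.random.set_rng_state(state["torch_rng_state"].cpu())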

tautomer commented 1 year ago

Two more questions related to this problem.

  1. Why does the "explicit" way of mapping work? Is it because it only transfers what was originally on GPU 1 to GPU 0, while whatever is on the CPU stays there?

  2. Will training be faster if we transfer the RNG state to the GPU? I guess there is no significant difference?

tautomer commented 1 year ago

I just realized this fix will cause other problems... Maybe I should say that map_location=torch.device(0) should not be allowed in our package in the first place. For example, if the database was originally on a GPU, then this mapping will work fine. However, if the database was originally on the CPU, map_location=torch.device(0) will obviously cause problems as well.

Currently all three torch tensors saved by hippynn more or less rely on torch.device, so I think the best option might be writing a separate file (torch_device, for example) when better_model is true for the first time.

If the user gives map_location as an argument, handle it based on what it is (a rough sketch follows this list):

  1. a dict: we assume it's right
  2. a function: we can't know the mapped device, so skip the conversion
  3. everything else: convert it to a dict based on the logged device name
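
A rough sketch of that dispatch, with illustrative names only (logged_device stands for the device name read from the proposed torch_device file):

# Sketch only; none of these names are final.
if isinstance(map_location, dict):
    pass  # 1. assume the user's mapping is right
elif callable(map_location):
    pass  # 2. we can't know the mapped device, so skip the conversion
else:
    # 3. convert everything else to a dict based on the logged device name
    map_location = {logged_device: str(map_location)}
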
lubbersnick commented 1 year ago

Per our conversation let's do this a different way. Rather than manually assuming the device for RNG should be CPU, let's add an optional device argument when de-serializing that invokes map_location to CPU first (capturing whatever device things are on) and then moves the other pieces to the device argument. If both map_location and the device are given then we raise an error.

tautomer commented 1 year ago

Per our conversation let's do this a different way. Rather than manually assuming the device for RNG should be CPU, let's add an optional device argument when de-serializing that invokes map_location to CPU first (capturing whatever device things are on) and then moves the other pieces to the device argument. If both map_location and the device are given then we raise an error.

Sure, but to point out one thing: the RNG tensor is not assumed to be on the CPU; it has to be. The function we use right now, torch.set_rng_state, only works for the CPU, according to its source code.

def set_rng_state(new_state: torch.Tensor) -> None:
    r"""Sets the random number generator state.
    .. note: This function only works for CPU. For CUDA, please use
             torch.manual_seed(seed), which works for both CPU and CUDA.
    Args:
        new_state (torch.ByteTensor): The desired state
    """
    default_generator.set_state(new_state)

torch.manual_seed works for both CPU and CUDA, but I don't know how you can restart with this function. It only takes an integer (i.e., the original seed). I got different RNG states between set_rng_state(reload_tensor) and manual_seed(orig_seed).
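
A quick self-contained check that shows the difference (nothing here is hippynn-specific, just the stock torch RNG API):

import torch

torch.manual_seed(0)
_ = torch.rand(3)                     # advance the default CPU generator a bit
saved = torch.random.get_rng_state()  # full generator state, a CPU ByteTensor

# Restoring the saved state continues the stream exactly where it left off...
torch.random.set_rng_state(saved)
a = torch.rand(3)

# ...while re-seeding with the original seed restarts the stream from scratch.
torch.manual_seed(0)
b = torch.rand(3)

assert not torch.equal(a, b)  # different RNG states, as observed above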

I'm not sure how you view this .cpu() approach. If we force the state tensor to be transferred to the CPU, restarting should always "work" in the sense that it doesn't throw an error. Especially since, if you google "map_location", most results will tell you to simply do map_location=torch.device(0). This is exactly how I discovered this bug. Of course, this comes at the risk of a huge DB landing in VRAM, which can cause memory issues later. There might be other problems as well that I don't foresee yet.

I think you view this kind of behavior as unsafe, as it allows ambiguous actions behind the scenes. However, I still think we should do something about this error if we don't do state["torch_rng_state"].cpu(). One possibility is a customized error message rather than the original torch one, which doesn't indicate anything unless you actually look into it.

try:
    torch.random.set_rng_state(state["torch_rng_state"])
except TypeError:
    raise RuntimeError("Don't use 'map_location=torch.device(device_var)', please. Check the manual for a proper way of restarting across devices.")

For the auto-handling of devices, I thought about it again last night.

If we do map_location="cpu", is there a way to know which device the tensors were on? Let me know if you are aware of a way to "inspect" the .pt files. I think torch.load will just load the tensors without saying anything, or throw an error and exit. With map_location="cpu", it never fails, so we can't know the original device. However, if we leave the transfer problem to users, things will be very simple.

Possible approaches in my mind:

  1. The user knows exactly the old and new devices, so map_location={"old": "new"} is given (maybe together with an argument like no_auto_device). In this case, we don't have to handle devices at all; it's the user's responsibility to get everything right.
  2. Move everything to the CPU, which should always be safe. Then the user is responsible for moving some tensors back in the script.
  3. I think out of all the tensors, only the database has the CPU-vs-GPU scenario? By adding one argument like move_db_to_gpu, we can handle reloading automatically. If the user passes auto_device and move_db_to_gpu, we can:
    1. Load everything to the CPU.
    2. Move things that should 100% be on the GPU to the new GPU, like the model.
    3. Move or keep the DB based on move_db_to_gpu.
    4. Done?

lubbersnick commented 1 year ago

so, here's my proposal

Does that make sense? It does give me the thought that we could try and mark DBs as non-restartable, but that's a separate question.

edit: update proposed keyword to "model_device"

tautomer commented 1 year ago

I see why I was confused now. If we don't try to auto-handle the database locations, there will be no ambiguities. We can then easily handle all situations.

The proposed load_checkpoint will look like the snippet below:

from ..tools import device_fallback


def load_checkpoint(
    structure_fname,
    state_fname,
    restore_db=True,
    map_location=None,
    model_device=None,
    **kwargs
):
    # TODO: this check could be a standalone function, as loading the model alone needs it as well
    # Case 1: map_location is given (conflict if model_device is also given)
    if map_location is not None:
        if model_device is not None:
            raise KeyError(
                "The map_location option conflicts with the model_device option. Use only one of them."
            )
        else:
            print(f"Tensors will be loaded with the option map_location={map_location}.")
    # Case 2: map_location is None and model_device is not None: we handle the devices
    # Case 3: both are None: no transfer across devices happens; map_location (None) is passed to torch.load directly
    elif model_device is not None:
        if model_device == "auto":
            model_device = device_fallback()
        map_location = "cpu"

    with open(structure_fname, "rb") as pfile:
        structure = torch.load(pfile, map_location=map_location, **kwargs)

    with open(state_fname, "rb") as pfile:
        state = torch.load(pfile, map_location=map_location, **kwargs)

    # transfer stuff back to model_device where needed
    structure = restore_checkpoint(structure, state, restore_db=restore_db)
    # no transfer needed in these cases: the user handled map_location themselves
    # (or passed nothing), or the target device is the CPU
    if model_device is None or model_device == "cpu":
        return structure
    else:
        # FIXME: what else should be transferred?
        # FIXME: is it `.send_to_device` or `.cuda`?
        structure["training_modules"].model.send_to_device(model_device)
        return structure
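
A hypothetical call under this proposed signature (the checkpoint file names here are just placeholders):

# Hypothetical usage; "auto" resolves the device through device_fallback().
check = load_checkpoint(
    "experiment_structure.pt",
    "experiment_state.pt",
    model_device="auto",
)
model = check["training_modules"].model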

Loading the model alone should be handled in a similar way, but it's much simpler.

Does this make sense to you?

Still, this block looks nice to me, as the above implementation won't catch the error from map_location=torch.device(0). Loading will be fine, but users can't learn much from reading the original error message.

try:
    torch.random.set_rng_state(state["torch_rng_state"])
except TypeError:
    raise RuntimeError("Don't use 'map_location=torch.device(device_var)', please. Check the manual for a proper way of restarting across devices.")

I can prepare an example of the different ways of restarting and the corresponding docs. This is very much necessary. Many people (including me, not too long ago) don't know how to properly restart.

lubbersnick commented 1 year ago

This is the gist, yes. The check for existing map_location keywords can be moved inside the second set of if statements to make the code simpler. To send the model to a device, it's just model.to(device).

Documentation improvements would be really helpful. With proper documentation I don't think we need to catch and transmute the error for map_location, because map_location will now rarely be necessary.

tautomer commented 1 year ago

The check for existing map_location keywords can be moved inside the second set of if statements to make the code simpler.

I didn't do it this way because I thought we'd better quit ASAP if there is a conflict, rather than spending (possibly quite a bit of) time loading the tensors only to find out we would exit right after that. Also, we can wrap that if into a function, since load_model will need exactly the same check. If we do that, we won't see the ugly if 😂

to send the model to a device it's just model.to(device).

I see. Let me change this snippet. But we might want to work on this (how to move tensors) a little bit. Either a section in the docs or unifying the method names?

Also, I'm still not 100% sure on all the restarted quantities:

name               device
model              GPU if possible
metric_tracker     CPU?
rng_state          CPU
training_modules   how about stuff other than model?
controller         CPU
database           we don't control it here

lubbersnick commented 1 year ago

What I mean by reordering the if statements:

if model_device is not None:
    if "map_location" in kwargs:
        raise TypeError("Passing map_location explicitly and the model device are incompatible")
    if model_device == "auto":
        model_device = device_fallback()
    kwargs['map_location'] = 'cpu'

Basically you have two checks on whether model_device is None; they can be coalesced to create simpler control flow.

lubbersnick commented 1 year ago

training modules: model -> device, loss -> device, evaluator -> cpu
metric tracker: we shouldn't have problems leaving this on CPU. It's not actually a torch module although it does have the best model stored on it.
rng state: CPU
controller: does not store tensors?
database: we don't control it here.
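
A minimal sketch of that transfer step, assuming training_modules exposes model, loss, and evaluator as attributes (only .model appears in the snippet above, so the other two names are assumptions):

# Sketch only; attribute names beyond .model are assumed.
tm = structure["training_modules"]
tm.model.to(model_device)  # model -> device
tm.loss.to(model_device)   # loss -> device
# evaluator, metric tracker, rng state, and controller stay on the CPU;
# the database is not handled here.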

tautomer commented 1 year ago

Basically you have two checks on whether model device is None, they can be coalesced to create simpler control flow.

Oops.. I get what you were saying now. Not sure what's wrong with my mind.

, loss-> device

Let me move loss to device as well.

Also, what's your opinion on defining the function with map_location=None, model_device=None, **kwargs vs. model_device=None, **kwargs (where map_location would live in the kwargs dict)? I went with the first one because, for people who heavily rely on auto-completion (like me...), it's very useful to have an explicit argument there. The docstrings will also be clearer. I can surely remove it if you don't like it.
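
For concreteness, the two signatures being compared (sketches only):

# Option 1: explicit keyword, friendlier to auto-completion and docstrings
def load_checkpoint(structure_fname, state_fname, restore_db=True,
                    map_location=None, model_device=None, **kwargs):
    ...

# Option 2: map_location rides along inside **kwargs
def load_checkpoint(structure_fname, state_fname, restore_db=True,
                    model_device=None, **kwargs):
    ...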