Closed: tautomer closed this 1 year ago.
Two more questions related to this problem:

1. Why does the "explicit" way of mapping work? Is it because it only transfers what was originally on GPU 1 to GPU 0, while whatever is on the CPU stays there?
2. Will training be faster if we transfer the RNG state to the GPU? I guess there is no significant difference?
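A minimal sketch of the behaviour asked about in the first question, assuming a recent PyTorch: with a dict `map_location`, `torch.load` only rewrites storages whose saved location tag is a key of the dict, so a tensor saved from the CPU (tag `"cpu"`) is loaded back onto the CPU. The in-memory buffer stands in for a checkpoint file.

```python
import io

import torch

# Stand-in for a checkpoint file that holds a tensor saved from the CPU.
buffer = io.BytesIO()
torch.save({"torch_rng_state": torch.get_rng_state()}, buffer)
buffer.seek(0)

# Only storages tagged "cuda:1" would be remapped to "cuda:0"; the RNG state
# was saved with the tag "cpu", which is not a key, so it stays on the CPU.
state = torch.load(buffer, map_location={"cuda:1": "cuda:0"})
print(state["torch_rng_state"].device)  # cpu
```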
I just realized this fix will cause other problems... Maybe I should say that `map_location=torch.device(0)` should not be allowed in our package in the first place. For example, if the database was originally on a GPU, this mapping will be fine. However, if the database was originally on the CPU, `map_location=torch.device(0)` will obviously cause problems as well.
Currently all 3 torch tensors saved by hippynn more or less rely on `torch.device`, so I think the best option might be writing a separate file (`torch_device`, for example) when `better_model` is true for the first time.
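A sketch of the separate-file idea above, assuming a JSON sidecar; the file name `torch_device.json` and both helper functions are hypothetical and not part of hippynn. The recorded device can then be turned into the explicit dict form of `map_location`.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical sidecar file name; hippynn does not write such a file.
DEVICE_FILE = "torch_device.json"


def write_device_record(directory, device):
    """Record the device the model was trained on next to the checkpoint."""
    path = Path(directory) / DEVICE_FILE
    path.write_text(json.dumps({"device": str(device)}))
    return path


def read_device_record(directory):
    """Return the recorded device string, or None if nothing was recorded."""
    path = Path(directory) / DEVICE_FILE
    if not path.exists():
        return None
    return json.loads(path.read_text())["device"]


with tempfile.TemporaryDirectory() as tmp:
    missing = read_device_record(tmp)    # no record yet
    write_device_record(tmp, "cuda:1")   # e.g. when better_model is first true
    recorded = read_device_record(tmp)
    # The record becomes the explicit dict form of map_location.
    map_location = {recorded: "cuda:0"}
```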
If the user gives `map_location` in the argument, if it's
Per our conversation, let's do this a different way. Rather than manually assuming the device for the RNG state should be the CPU, let's add an optional device argument when de-serializing that uses `map_location` to load onto the CPU first (capturing whatever device things are on) and then moves the other pieces to the device given by that argument. If both `map_location` and the device are given, we raise an error.
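A minimal sketch of the load-to-CPU-then-move pattern described above, using a plain `torch.nn.Linear` as a stand-in for the model; `load_then_move` is a hypothetical helper, not hippynn's API, and the device choice only mimics what a fallback helper might do.

```python
import io

import torch

# Hypothetical stand-in for a trained model and its saved weights.
model = torch.nn.Linear(4, 2)
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)


def load_then_move(f, device):
    """Load every tensor onto the CPU first, then move the model to `device`."""
    f.seek(0)
    state_dict = torch.load(f, map_location="cpu")  # CPU always exists
    restored = torch.nn.Linear(4, 2)
    restored.load_state_dict(state_dict)
    return restored.to(device)  # moves parameters and buffers


# Mimic a fallback: use a GPU if there is one, otherwise stay on the CPU.
target = "cuda" if torch.cuda.is_available() else "cpu"
restored = load_then_move(buffer, target)
```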
Sure, but to point out one thing: the RNG tensor is not just assumed to be on the CPU, it has to be. The function we use right now, `torch.set_rng_state`, only works on the CPU, according to the source code:
```python
def set_rng_state(new_state: torch.Tensor) -> None:
    r"""Sets the random number generator state.

    .. note: This function only works for CPU. For CUDA, please use
             torch.manual_seed(seed), which works for both CPU and CUDA.

    Args:
        new_state (torch.ByteTensor): The desired state
    """
    default_generator.set_state(new_state)
```
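A small check of the two properties that matter for restarting, assuming the default CPU generator: `torch.get_rng_state()` returns a CPU `uint8` (Byte) tensor, and passing it back to `torch.set_rng_state` replays the same random numbers.

```python
import torch

# The CPU generator's state is a uint8 ("Byte") tensor living on the CPU.
saved_state = torch.get_rng_state()
first_draw = torch.rand(3)

# Putting the saved state back replays the same numbers, which is what a
# restart needs.
torch.set_rng_state(saved_state)
replayed_draw = torch.rand(3)
```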
`torch.manual_seed` works for both CPU and CUDA, but I don't know how you could restart with this function. It only takes an integer (i.e., the original seed). I got different RNG states from `set_rng_state(reload_tensor)` and `manual_seed(orig_seed)`.
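A sketch of why the two restarts differ, with an arbitrary seed of 1234: `manual_seed` rewinds the generator to the start of its sequence, while `set_rng_state` resumes from wherever the checkpoint was taken. Only the latter continues an interrupted run.

```python
import torch

torch.manual_seed(1234)                   # original seed at the start of training
_ = torch.rand(10)                        # training consumes some random numbers
checkpoint_state = torch.get_rng_state()  # state saved with the checkpoint
expected_next = torch.rand(3)             # what an uninterrupted run draws next

# Restart option 1: restore the saved state, continuing the sequence.
torch.set_rng_state(checkpoint_state)
from_state = torch.rand(3)

# Restart option 2: reseed, which rewinds to the beginning of the sequence.
torch.manual_seed(1234)
from_seed = torch.rand(3)
```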
I'm not sure how you view this `.cpu()` thing. If we force-transfer the state tensor to the CPU, restarting should always "work" in the sense that it doesn't throw an error. This matters especially because, if you google "map_location", most results will tell you to simply do `map_location=torch.device(0)`. This is exactly how I discovered this bug. Of course, this comes at the risk of a huge DB in VRAM, which can cause memory issues later. There might be other problems as well that I don't foresee yet.
I think you view this kind of behavior as unsafe, as it allows ambiguous actions behind the scenes. However, I still think we should do something about this error if we don't do `state["torch_rng_state"].cpu()`. One possibility is a customized error message rather than the original torch one, which doesn't indicate anything unless you actually look into it.
```python
try:
    torch.random.set_rng_state(state["torch_rng_state"])
except TypeError as err:
    raise RuntimeError(
        "Don't use 'map_location=torch.device(device_var)' please. "
        "Check the manual for a proper way of restarting across devices."
    ) from err
```
For the auto-handling of devices, I thought about it again last night.
If we do `map_location="cpu"`, is there a way to know which device the tensors were on? Let me know if you are aware of a way to "inspect" the .pt files.
I think `torch.load` will either just load the tensor without saying anything or throw an error and exit. With `map_location="cpu"`, it never fails, so we can't know the original device. However, if we leave the transfer problem to users, things will be very simple.
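One option not discussed in the thread, based on the documented callable form of `map_location`: the callable is invoked once per serialized storage with the storage (already deserialized on the CPU) and its original location tag, e.g. `"cpu"` or `"cuda:1"`. Returning the storage unchanged loads everything onto the CPU while still recording where each piece was saved from. The helper name `load_on_cpu_recording_devices` is hypothetical.

```python
import io

import torch


def load_on_cpu_recording_devices(f):
    """Load a checkpoint onto the CPU and report the devices it was saved from."""
    seen_locations = set()

    def to_cpu_and_record(storage, location):
        # `storage` is already on the CPU; `location` is its saved tag.
        seen_locations.add(location)
        return storage

    obj = torch.load(f, map_location=to_cpu_and_record)
    return obj, seen_locations


# Usage: a checkpoint-like dict saved from the CPU.
buffer = io.BytesIO()
torch.save({"torch_rng_state": torch.get_rng_state()}, buffer)
buffer.seek(0)
state, locations = load_on_cpu_recording_devices(buffer)
```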
Possible approaches in my mind:

- `map_location={"old": "new"}` is given (maybe together with an argument like `no_auto_device`). In this case, we don't have to handle devices at all. It's the user's responsibility to get everything right.
- `move_db_to_gpu`, we can handle reloading automatically. If the user uses `auto_device` and `move_db_to_gpu`, we can
- `move_db_to_gpu`

So, here's my proposal
Does that make sense? It does give me the thought that we could try to mark DBs as non-restartable, but that's a separate question.
Edit: updated the proposed keyword to `model_device`.
I see why I was confused now. If we don't try to auto-handle the database locations, there will be no ambiguities, and we can easily handle all situations.
The proposed `load_checkpoint` will look like this:
```python
import torch

from ..tools import device_fallback


def load_checkpoint(
    structure_fname,
    state_fname,
    restore_db=True,
    map_location=None,
    model_device=None,
    **kwargs
):
    # TODO: this check can be a standalone function, as loading the model alone needs it as well
    # map_location and model_device are mutually exclusive
    if map_location is not None:
        if model_device is not None:
            raise KeyError(
                "The map_location option conflicts with the model_device option. Use only one of them."
            )
        else:
            print(f"Tensors will be loaded with option map_location={map_location}.")
    # if map_location is None and model_device is given, we handle the devices:
    # load everything onto the CPU first, then move the model afterwards.
    # if both are None, no transfer across devices happens; map_location (None) goes straight to torch.load
    elif model_device is not None:
        if model_device == "auto":
            model_device = device_fallback()
        map_location = "cpu"
    with open(structure_fname, "rb") as pfile:
        structure = torch.load(pfile, map_location=map_location, **kwargs)
    with open(state_fname, "rb") as pfile:
        state = torch.load(pfile, map_location=map_location, **kwargs)
    structure = restore_checkpoint(structure, state, restore_db=restore_db)
    # No transfer is needed if model_device was not given (tensors are already where
    # map_location put them) or if the target is the CPU. map_location is always "cpu"
    # once model_device is given, so it cannot distinguish these cases.
    if model_device is None or model_device == "cpu":
        return structure
    else:
        # transfer stuff to model_device
        # FIXME: what else should be transferred?
        # FIXME: is it `.send_to_device` or `.cuda`?
        structure["training_modules"].model.send_to_device(model_device)
        return structure
```
`load_model` should be handled in a similar way, but it is much simpler.
Does this make sense to you?
Still, this block looks nice to me, as the above implementation won't catch the error from `map_location=torch.device(0)`. Loading will be fine, but users can't learn much from reading the original error message.
```python
try:
    torch.random.set_rng_state(state["torch_rng_state"])
except TypeError as err:
    raise RuntimeError(
        "Don't use 'map_location=torch.device(device_var)' please. "
        "Check the manual for a proper way of restarting across devices."
    ) from err
```
I can prepare an example of the different ways of restarting, along with the corresponding docs. This is very much necessary: many people (including me not too long ago) don't know how to restart properly.
This is the gist, yes. The check for existing `map_location` keywords can be moved inside the second set of if statements to make the code simpler. To send the model to a device, it's just `model.to(device)`.

Documentation improvements would be really helpful. With proper documentation, I don't think we need to catch and transmute the `map_location` error, because `map_location` will now be needed far less often.
> The check for existing map_location keywords can be moved inside the second set of if statements to make the code simpler.
I didn't do it this way because I thought we'd better quit as soon as possible if there is a conflict, rather than spend (possibly quite a bit of) time loading the tensors only to find we have to exit right after that. Also, we can wrap that if into a function, since `load_model` will need exactly the same logic. If we do that, we won't see the ugly if 😂
> to send the model to a device it's just `model.to(device)`.
I see. Let me change this snippet. But we might want to work on this (how to move tensors) a little bit: either a section of the docs, or unifying the method names?
Also, I'm still not 100% sure about all the restarted quantities:

| name | device |
|---|---|
| model | GPU if possible |
| metric_tracker | CPU? |
| rng_state | CPU |
| training_modules | how about stuff other than model? |
| controller | CPU |
| database | we don't control it here |
What I mean by reordering the if statements:
```python
if model_device is not None:
    if "map_location" in kwargs:
        raise TypeError("Passing map_location explicitly and the model device are incompatible")
    if model_device == "auto":
        model_device = device_fallback()
    kwargs["map_location"] = "cpu"
```
Basically, you have two checks on whether the model device is None; they can be coalesced to create simpler control flow.
- training modules: model -> device, loss -> device, evaluator -> CPU
- metric tracker: we shouldn't have problems leaving this on the CPU. It's not actually a torch module, although it does have the best model stored on it.
- rng state: CPU
- controller: does not store tensors?
- database: we don't control it here.
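A sketch of the placement summarized above. The `TrainingModules` namedtuple and `place_training_modules` are hypothetical stand-ins for hippynn's training modules; the model and loss are moved with `.to(device)` and the evaluator is left where it is.

```python
from collections import namedtuple

import torch

# Hypothetical container; field names follow the summary above.
TrainingModules = namedtuple("TrainingModules", ["model", "loss", "evaluator"])


def place_training_modules(modules, device):
    """Send the model and loss to `device`; leave the evaluator untouched."""
    return TrainingModules(
        model=modules.model.to(device),  # nn.Module.to moves in place and returns the module
        loss=modules.loss.to(device),
        evaluator=modules.evaluator,     # stays on the CPU
    )


modules = TrainingModules(
    model=torch.nn.Linear(3, 1),
    loss=torch.nn.MSELoss(),
    evaluator=object(),  # placeholder for an evaluator
)
device = "cuda" if torch.cuda.is_available() else "cpu"
placed = place_training_modules(modules, device)
```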
> Basically you have two checks on whether model device is None, they can be coalesced to create simpler control flow.
Oops... I get what you were saying now. Not sure what was wrong with my mind.
> loss -> device

Let me move loss to the device as well.
Also, what's your opinion on defining the function with `map_location=None, model_device=None, **kwargs` vs `model_device=None, **kwargs` (where `map_location` would be in the dict)? I did the first one because, for people who rely heavily on auto-completion (aka me...), it will be very useful to have an explicit argument there. Docstrings will also be clearer. I can certainly remove it if you don't like it.
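A sketch comparing the two signatures, with hypothetical function names and a stub `device_fallback` that always returns `"cpu"`: both variants apply the same coalesced check, but only the explicit keyword is visible in `inspect.signature`, which is what IDE auto-completion and generated docs typically read.

```python
import inspect


def device_fallback():
    # Hypothetical stub; a real helper would pick a GPU when one is available.
    return "cpu"


def resolve_explicit(map_location=None, model_device=None, **kwargs):
    """map_location is a named parameter, so it shows up in auto-completion."""
    if model_device is not None:
        if map_location is not None:
            raise TypeError("Passing map_location explicitly and the model device are incompatible")
        if model_device == "auto":
            model_device = device_fallback()
        map_location = "cpu"
    return map_location, model_device


def resolve_kwargs(model_device=None, **kwargs):
    """map_location can only arrive inside **kwargs."""
    if model_device is not None:
        if "map_location" in kwargs:
            raise TypeError("Passing map_location explicitly and the model device are incompatible")
        if model_device == "auto":
            model_device = device_fallback()
        kwargs["map_location"] = "cpu"
    return kwargs.get("map_location"), model_device


explicit_params = list(inspect.signature(resolve_explicit).parameters)
kwargs_params = list(inspect.signature(resolve_kwargs).parameters)
```

Behaviour is identical in both variants; the difference is only in discoverability.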
When a model is reloaded to a different CUDA device, the error `TypeError: RNG state must be a torch.ByteTensor` might be thrown.

For example, if the model was originally trained on GPU 1 and you now want to load it onto GPU 0, per the torch docs you can do `load_checkpoint_from_cwd(map_location={'cuda:1':'cuda:0'})`, which works fine. Unfortunately, this approach requires users to know which GPU was used, which is obviously a problem for automation. However, using `load_checkpoint_from_cwd(map_location=torch.device(0))` or `load_checkpoint_from_cwd(map_location=lambda storage, loc: storage.cuda(0))` will throw `TypeError: RNG state must be a torch.ByteTensor`.

The relevant code is `hippynn.experiment.serialization.restore_checkpoint` and `torch.random.set_rng_state`. `state["torch_rng_state"]` will be a tensor on GPU 0 for the latter two ways of setting `map_location`, but it will stay on the CPU with `{'cuda:1':'cuda:0'}`.
Forcing the tensor to be transferred to the CPU solves the problem, and the RNG state is originally on the CPU anyway.
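A minimal sketch of the fix described above, assuming the checkpoint state is a dict with a `"torch_rng_state"` entry as in hippynn; `restore_rng_state` is a hypothetical helper name. Calling `.cpu()` returns the tensor unchanged if it is already on the CPU and copies it back otherwise, so the restore works for any `map_location`.

```python
import io

import torch


def restore_rng_state(state):
    """Restore the CPU RNG wherever map_location put the saved state tensor."""
    # torch.random.set_rng_state only accepts a CPU ByteTensor.
    torch.random.set_rng_state(state["torch_rng_state"].cpu())


# Simulate a checkpoint round trip.
buffer = io.BytesIO()
torch.save({"torch_rng_state": torch.get_rng_state()}, buffer)
expected = torch.rand(3)  # draws that follow the saved state

buffer.seek(0)
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
state = torch.load(buffer, map_location=device)  # the call that broke restarts
restore_rng_state(state)
replayed = torch.rand(3)
```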