DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Question] Can I load a RL model with Mac trained on windows platform? #2029

Closed Ssstirm closed 1 week ago

Ssstirm commented 1 week ago

❓ Question

I was trying to load an RL model on a MacBook with an Apple chip; the model was trained on a Windows machine with a GPU. When I simply load the model, I get this error:

---------------------------------------------------------------------------
UnpicklingError                           Traceback (most recent call last)
Cell In[3], line 1
----> 1 model = PPO.load("./model/test1", device='cpu')

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/stable_baselines3/common/base_class.py:680, in BaseAlgorithm.load(cls, path, env, device, custom_objects, print_system_info, force_reset, **kwargs)
    677     print("== CURRENT SYSTEM INFO ==")
    678     get_system_info()
--> 680 data, params, pytorch_variables = load_from_zip_file(
    681     path,
    682     device=device,
    683     custom_objects=custom_objects,
    684     print_system_info=print_system_info,
    685 )
    687 assert data is not None, "No data found in the saved file"
    688 assert params is not None, "No params found in the saved file"

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/stable_baselines3/common/save_util.py:451, in load_from_zip_file(load_path, load_data, custom_objects, device, verbose, print_system_info)
    447 file_content.seek(0)
    448 # Load the parameters with the right ``map_location``.
    449 # Remove ".pth" ending with splitext
    450 # Note(antonin): we cannot use weights_only=True, as it breaks with PyTorch 1.13, see GH#1911
--> 451 th_object = th.load(file_content, map_location=device, weights_only=False)
    452 # "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138
    453 if file_path == "pytorch_variables.pth" or file_path == "tensors.pth":
    454     # PyTorch variables (not state_dicts)

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/torch/serialization.py:1040, in load(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)
   1038     except RuntimeError as e:
   1039         raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
-> 1040 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/torch/serialization.py:1258, in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
   1252 if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
   1253     raise RuntimeError(
   1254         "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
   1255         f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or newer to restore this "
   1256         "functionality.")
-> 1258 magic_number = pickle_module.load(f, **pickle_load_args)
   1259 if magic_number != MAGIC_NUMBER:
   1260     raise RuntimeError("Invalid magic number; corrupt file?")

UnpicklingError: invalid load key, '\x00'.

Any solution to this? Thanks for any help


araffin commented 1 week ago

Hello, it should be possible, but make sure to have the same env (python version, pytorch version, cloudpickle version).

EDIT: maybe related https://discuss.pytorch.org/t/unpicklingerror-invalid-load-key-x00/73083/4
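The message `invalid load key, '\x00'` means the first byte handed to pickle was a null byte, so whatever was being read was not a pickle stream at all. A minimal sketch for checking what each `.pth` entry in the saved archive actually starts with, using only the standard library. The function name `classify_pth_entries` is hypothetical, and the sketch assumes the model is a zip archive whose `.pth` entries SB3 passes to `th.load`, as the traceback above suggests:

```python
import zipfile

# Leading bytes of the formats we may meet inside the archive.
ZIP_MAGIC = b"PK\x03\x04"                # zip container written by torch.save (PyTorch >= 1.6)
APPLEDOUBLE_MAGIC = b"\x00\x05\x16\x07"  # macOS "._name" metadata sidecar file


def classify_pth_entries(archive_path):
    """Return {entry_name: kind} for every ``.pth`` entry in a zip archive."""
    kinds = {}
    with zipfile.ZipFile(archive_path) as archive:
        for name in archive.namelist():
            if not name.endswith(".pth"):
                continue
            # Only the first few bytes are needed to tell the formats apart.
            with archive.open(name) as entry:
                head = entry.read(4)
            if head.startswith(ZIP_MAGIC):
                kinds[name] = "torch zip checkpoint"
            elif head.startswith(APPLEDOUBLE_MAGIC):
                kinds[name] = "AppleDouble metadata (not a checkpoint)"
            elif head[:1] == b"\x80":
                kinds[name] = "raw pickle (legacy torch format)"
            else:
                kinds[name] = f"unknown, starts with {head!r}"
    return kinds
```

Calling `classify_pth_entries("./model/test1.zip")` (the `.zip` suffix is an assumption about how the file is stored on disk) would show whether any entry is something other than a real checkpoint.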

Ssstirm commented 1 week ago

Hello, here is the system info:

== CURRENT SYSTEM INFO ==

OS: macOS-15.1-arm64-arm-64bit Darwin Kernel Version 24.1.0: Sun Jul 14 12:07:15 PDT 2024; root:xnu-11215.0.165.0.4~64/RELEASE_ARM64_T8112
Python: 3.10.12
Stable-Baselines3: 2.3.2
PyTorch: 2.2.0
GPU Enabled: False
Numpy: 1.23.5
Cloudpickle: 3.0.0
Gymnasium: 0.29.1
OpenAI Gym: 0.29.1

== SAVED MODEL SYSTEM INFO ==

OS: Windows-10-10.0.22631-SP0 10.0.22631
Python: 3.10.12
Stable-Baselines3: 2.3.2
PyTorch: 2.2.0+cu121
GPU Enabled: True
Numpy: 1.23.5
Cloudpickle: 3.0.0
Gymnasium: 0.29.1
OpenAI Gym: 0.29.1

The versions match, so it should be all right, but I still get the UnpicklingError above.
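Comparing two system info blocks by eye is error-prone, so here is a small sketch that parses `Key: value` lines and reports only the keys that differ. The helper names `parse_system_info` and `diff_system_info` are hypothetical, not part of SB3. The values are copied from the two blocks above, with the OS line shortened and the identical Numpy, Gymnasium and OpenAI Gym lines left out:

```python
def parse_system_info(text):
    """Parse ``Key: value`` lines into a dict, ignoring anything else."""
    info = {}
    for line in text.splitlines():
        # Tolerate an optional leading "- " bullet before the key.
        key, sep, value = line.strip().lstrip("- ").partition(": ")
        if sep:
            info[key] = value.strip()
    return info


def diff_system_info(current, saved):
    """Return {key: (current_value, saved_value)} for keys whose values differ."""
    keys = sorted(set(current) | set(saved))
    return {
        key: (current.get(key), saved.get(key))
        for key in keys
        if current.get(key) != saved.get(key)
    }


current = parse_system_info("""
OS: macOS-15.1-arm64-arm-64bit
Python: 3.10.12
Stable-Baselines3: 2.3.2
PyTorch: 2.2.0
GPU Enabled: False
Cloudpickle: 3.0.0
""")
saved = parse_system_info("""
OS: Windows-10-10.0.22631-SP0 10.0.22631
Python: 3.10.12
Stable-Baselines3: 2.3.2
PyTorch: 2.2.0+cu121
GPU Enabled: True
Cloudpickle: 3.0.0
""")

for key, (here, there) in diff_system_info(current, saved).items():
    print(f"{key}: current={here!r} saved={there!r}")
```

With these inputs, only `GPU Enabled`, `OS` and `PyTorch` are reported; the Python, Stable-Baselines3 and Cloudpickle versions are identical.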

araffin commented 1 week ago

Mmh, apart from the GPU/CPU difference and the arm vs. amd architecture, it looks fine. In the meantime, you can use another SB3 export method such as get_parameters() (see https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#accessing-and-modifying-model-parameters and the example above).

But the issue might be related to PyTorch itself (you can try upgrading PyTorch on both machines).

Ssstirm commented 1 week ago

The error above occurred because macOS generated a folder named macox containing files with the same names as the model files, ending in .pth. After removing them, I get a new error during model setup, shown below:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[6], line 1
----> 1 model = PPO.load("./model/test1", device='cpu',print_system_info=True)

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/stable_baselines3/common/base_class.py:740, in BaseAlgorithm.load(cls, path, env, device, custom_objects, print_system_info, force_reset, **kwargs)
    738 model.__dict__.update(data)
    739 model.__dict__.update(kwargs)
--> 740 model._setup_model()
    742 try:
    743     # put state_dicts back in place
    744     model.set_parameters(params, exact_match=True, device=device)

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/stable_baselines3/ppo/ppo.py:174, in PPO._setup_model(self)
    173 def _setup_model(self) -> None:
--> 174     super()._setup_model()
    176     # Initialize schedules for policy/value clipping
    177     self.clip_range = get_schedule_fn(self.clip_range)

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/stable_baselines3/common/on_policy_algorithm.py:134, in OnPolicyAlgorithm._setup_model(self)
    122         self.rollout_buffer_class = RolloutBuffer
    124 self.rollout_buffer = self.rollout_buffer_class(
    125     self.n_steps,
    126     self.observation_space,  # type: ignore[arg-type]
   (...)
    132     **self.rollout_buffer_kwargs,
    133 )
--> 134 self.policy = self.policy_class(  # type: ignore[assignment]
    135     self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
    136 )
    137 self.policy = self.policy.to(self.device)

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/stable_baselines3/common/policies.py:535, in ActorCriticPolicy.__init__(self, observation_space, action_space, lr_schedule, net_arch, activation_fn, ortho_init, use_sde, log_std_init, full_std, use_expln, squash_output, features_extractor_class, features_extractor_kwargs, share_features_extractor, normalize_images, optimizer_class, optimizer_kwargs)
    532 # Action distribution
    533 self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)
--> 535 self._build(lr_schedule)

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/stable_baselines3/common/policies.py:634, in ActorCriticPolicy._build(self, lr_schedule)
    631         module.apply(partial(self.init_weights, gain=gain))
    633 # Setup optimizer with initial learning rate
--> 634 self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/torch/optim/adam.py:45, in Adam.__init__(self, params, lr, betas, eps, weight_decay, amsgrad, foreach, maximize, capturable, differentiable, fused)
     39     raise ValueError(f"Invalid weight_decay value: {weight_decay}")
     41 defaults = dict(lr=lr, betas=betas, eps=eps,
     42                 weight_decay=weight_decay, amsgrad=amsgrad,
     43                 maximize=maximize, foreach=foreach, capturable=capturable,
     44                 differentiable=differentiable, fused=fused)
---> 45 super().__init__(params, defaults)
     47 if fused:
     48     if differentiable:

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/torch/optim/optimizer.py:278, in Optimizer.__init__(self, params, defaults)
    275     param_groups = [{'params': param_groups}]
    277 for param_group in param_groups:
--> 278     self.add_param_group(cast(dict, param_group))
    280 # Allows _cuda_graph_capture_health_check to rig a poor man's TORCH_WARN_ONCE in python,
    281 # which I don't think exists
    282 # https://github.com/pytorch/pytorch/issues/72948
    283 self._warned_capturable_if_run_uncaptured = True

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/torch/_compile.py:22, in _disable_dynamo.<locals>.inner(*args, **kwargs)
     20 @functools.wraps(fn)
     21 def inner(*args, **kwargs):
---> 22     import torch._dynamo
     24     return torch._dynamo.disable(fn, recursive)(*args, **kwargs)

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/torch/_dynamo/__init__.py:2
      1 import torch
----> 2 from . import allowed_functions, convert_frame, eval_frame, resume_execution
      3 from .backends.registry import list_backends, lookup_backend, register_backend
      4 from .code_context import code_context

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:45
     32 from .bytecode_transformation import (
     33     check_inst_exn_tab_entries_valid,
     34     Instruction,
   (...)
     37     transform_code_object,
     38 )
     39 from .cache_size import (
     40     CacheSizeRelevantForFrame,
     41     compute_cache_size,
     42     exceeds_cache_size_limit,
     43     is_recompilation,
     44 )
---> 45 from .eval_frame import always_optimize_code_objects, skip_code, TorchPatcher
     46 from .exc import (
     47     augment_exc_message,
     48     BackendCompilerFailed,
   (...)
     54     Unsupported,
     55 )
     56 from .guards import (
     57     CheckFunctionManager,
     58     get_and_maybe_log_recompilation_reason,
     59     GuardedCode,
     60 )

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:69
     66             continue
     67         globals()[name] = getattr(torch._C._dynamo.eval_frame, name)
---> 69 from . import config, convert_frame, external_utils, skipfiles, utils
     70 from .code_context import code_context
     71 from .exc import CondOpArgsMismatchError, UserError, UserErrorType

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/torch/_dynamo/skipfiles.py:39
     36 import torch.utils._content_store
     37 from .utils import getfile
---> 39 from .variables.functions import (
     40     NestedUserFunctionVariable,
     41     UserFunctionVariable,
     42     UserMethodVariable,
     43 )
     46 """
     47 A note on skipfiles:
     48 
   (...)
     86 you don't want to inline them.
     87 """
     90 BUILTIN_SKIPLIST = (
     91     abc,
     92     collections,
   (...)
    121     _weakrefset,
    122 )

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/torch/_dynamo/variables/__init__.py:26
     14 from .dicts import (
     15     ConstDictVariable,
     16     CustomizedDictVariable,
   (...)
     19     SetVariable,
     20 )
     21 from .functions import (
     22     NestedUserFunctionVariable,
     23     UserFunctionVariable,
     24     UserMethodVariable,
     25 )
---> 26 from .higher_order_ops import TorchHigherOrderOperatorVariable
     27 from .iter import (
     28     CountIteratorVariable,
     29     CycleIteratorVariable,
     30     IteratorVariable,
     31     RepeatIteratorVariable,
     32 )
     33 from .lazy import LazyVariableTracker

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py:11
      9 import torch.fx
     10 import torch.nn
---> 11 import torch.onnx.operators
     12 from torch._dispatch.python import enable_python_dispatcher
     13 from torch._dynamo.utils import deepcopy_to_fake_tensor, get_fake_value, get_real_value

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/torch/onnx/__init__.py:46
     33 from .errors import CheckerError  # Backwards compatibility
     34 from .utils import (
     35     _optimize_graph,
     36     _run_symbolic_function,
   (...)
     43     unregister_custom_op_symbolic,
     44 )
---> 46 from ._internal.exporter import (  # usort:skip. needs to be last to avoid circular import
     47     DiagnosticOptions,
     48     ExportOptions,
     49     ONNXProgram,
     50     ONNXProgramSerializer,
     51     ONNXRuntimeOptions,
     52     InvalidExportOptionsError,
     53     OnnxExporterError,
     54     OnnxRegistry,
     55     dynamo_export,
     56     enable_fake_mode,
     57 )
     59 from ._internal.onnxruntime import (
     60     is_onnxrt_backend_supported,
     61     OrtBackend as _OrtBackend,
     62     OrtBackendOptions as _OrtBackendOptions,
     63     OrtExecutionProvider as _OrtExecutionProvider,
     64 )
     66 __all__ = [
     67     # Modules
     68     "symbolic_helper",
   (...)
    114     "is_onnxrt_backend_supported",
    115 ]

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/torch/onnx/_internal/exporter/__init__.py:13
      1 __all__ = [
      2     "ONNXRegistry",
      3     "ONNXProgram",
   (...)
      9     "verification",
     10 ]
     12 from . import _testing as testing, _verification as verification
---> 13 from ._analysis import analyze
     14 from ._compat import export_compat
     15 from ._core import export, exported_program_to_ir

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/torch/onnx/_internal/exporter/_analysis.py:14
     11 from typing import TYPE_CHECKING
     13 import torch
---> 14 import torch._export.serde.schema
     15 from torch.export import graph_signature
     16 from torch.onnx._internal.exporter import _dispatching, _registration

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/torch/_export/__init__.py:69
     66 from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
     67 from torch.utils._sympy.value_ranges import ValueRangeError, ValueRanges
---> 69 from .exported_program import (
     70     _create_stateful_graph_module,
     71     _process_constraints,
     72     CallSpec,
     73 )
     74 from .passes.add_runtime_assertions_for_constraints_pass import (
     75     _AddRuntimeAssertionsForInlineConstraintsPass,
     76 )
     77 from .passes.lift_constant_tensor_pass import lift_constant_tensor_pass

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/torch/_export/exported_program.py:36
     25 from torch.export.graph_signature import (
     26     ExportBackwardSignature,
     27     ExportGraphSignature,
     28 )
     30 from torch.export.exported_program import (
     31     ExportedProgram,
     32     ModuleCallEntry,
     33     ModuleCallSignature,
     34 )
---> 36 from .utils import _check_input_constraints_pre_hook
     39 __all__ = [
     40     "ExportBackwardSignature",
     41     "ExportGraphSignature",
   (...)
     44     "ModuleCallSignature",
     45 ]
     48 # Information to maintain user calling/returning specs

File ~/miniconda3/envs/mlagents/lib/python3.10/site-packages/torch/_export/utils.py:24
      9 from torch.utils._pytree import (
     10     _register_pytree_node,
     11     Context,
   (...)
     17     UnflattenFunc,
     18 )
     21 SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS: Dict[str, Type[Any]] = {}
---> 24 @torch._dynamo.disable
     25 def _check_input_constraints_pre_hook(self, *args, **kwargs):
     26     flat_args, _ = tree_flatten(args)
     27     return _check_input_constraints_for_graph(
     28         self.graph,
     29         range_constraints=self.range_constraints,
     30         equality_constraints=self.equality_constraints,
     31     )(*flat_args)

AttributeError: partially initialized module 'torch._dynamo' has no attribute 'disable' (most likely due to a circular import)

I guess this could be a problem with PyTorch itself.
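For the first error, removing the macOS-generated files can also be done without unpacking the archive by hand. The sketch below copies a zip archive while skipping macOS metadata entries. It assumes the folder mentioned above is the `__MACOSX` directory that macOS archiving tools add, along with `._name` sidecar files and `.DS_Store`; the function names are hypothetical:

```python
import zipfile
from pathlib import PurePosixPath


def is_macos_metadata(name):
    """True for entries that macOS archiving adds alongside the real files."""
    parts = PurePosixPath(name).parts
    return "__MACOSX" in parts or any(
        part.startswith("._") or part == ".DS_Store" for part in parts
    )


def strip_macos_metadata(src_path, dst_path):
    """Copy a zip archive to ``dst_path`` without macOS metadata entries.

    Returns the list of entry names that were skipped.
    """
    skipped = []
    with zipfile.ZipFile(src_path) as src, zipfile.ZipFile(dst_path, "w") as dst:
        for info in src.infolist():
            if is_macos_metadata(info.filename):
                skipped.append(info.filename)
                continue
            # Reuse the original ZipInfo so names and timestamps are preserved.
            dst.writestr(info, src.read(info.filename))
    return skipped
```

A possible use is `strip_macos_metadata("./model/test1.zip", "./model/test1_clean.zip")`, followed by loading `./model/test1_clean` with `PPO.load`; both file names here are assumptions. Writing to a new file keeps the original archive untouched in case the cleanup removes something unexpected.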

Ssstirm commented 1 week ago

After upgrading PyTorch on macOS, it seems OK now (no error, at least).

Thanks for your help.