ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

Fails restoring weights #41508

Open Finebouche opened 10 months ago

Finebouche commented 10 months ago

What happened + What you expected to happen

The example examples/restore_1_of_n_agents_from_checkpoint.py does not seem to work (at least in my case).

The weights are not restored but re-initialized. I can tell because in W&B, instead of the per-policy mean rewards continuing from where the previous run ended, I get values that look like those of freshly initialized policies.

Maybe the example is out of date, or maybe I am doing something wrong. I am using tune.Tuner().fit() rather than tune.run() as in the example, but I don't see why that would make it fail...

Versions / Dependencies

Python 3.10, Ray 2.8

Reproduction script


import os

import ray
from ray import train, tune
from ray.air.integrations.wandb import WandbLoggerCallback
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.tune.registry import get_trainable_cls

# `env` is an instance of the custom multi-agent environment (definition omitted).

config = (
    get_trainable_cls("PPO").get_default_config()
    # ...
    # ...
    .multi_agent(
        policies={
            "prey": PolicySpec(
                policy_class=None,  # infer automatically from Algorithm
                observation_space=env.observation_space[0],  # if None, infer automatically from env
                action_space=env.action_space[0],  # if None, infer automatically from env
                config={"gamma": 0.85},  # use main config plus this override
            ),
            "predator": PolicySpec(
                policy_class=None,
                observation_space=env.observation_space[0],
                action_space=env.action_space[0],
                config={"gamma": 0.85},
            ),
        },
        # Map each agent to a policy based on its type.
        policy_mapping_fn=lambda agent_id, *args, **kwargs: (
            "prey" if env.agents[agent_id].agent_type == 0 else "predator"
        ),
        policies_to_train=["prey", "predator"],
    )
)

path_to_checkpoint = "/blablabla/ray_results/PPO_2023-11-29_02-51-09/PPO_CustomEnvironment_c4c87_00000_0_2023-11-29_02-51-09/checkpoint_000008"


def restore_weights(path, policy_id):
    """Load the weights of a single policy from an Algorithm checkpoint directory."""
    checkpoint_path = os.path.join(path, "policies", policy_id)
    restored_policy = Policy.from_checkpoint(checkpoint_path)
    return restored_policy.get_weights()


class RestoreWeightsCallback(DefaultCallbacks):
    def on_algorithm_init(self, *, algorithm: Algorithm, **kwargs) -> None:
        # Overwrite the freshly initialized weights with the checkpointed ones.
        algorithm.set_weights({"predator": restore_weights(path_to_checkpoint, "predator")})
        algorithm.set_weights({"prey": restore_weights(path_to_checkpoint, "prey")})


config.callbacks(RestoreWeightsCallback)

ray.init()

# Define the experiment.
tuner = tune.Tuner(
    "PPO",
    param_space=config,
    run_config=train.RunConfig(
        stop={
            "training_iteration": 1,
            "timesteps_total": 20000,
        },
        verbose=3,
        callbacks=[
            WandbLoggerCallback(
                project="ppo_marl",
                group="PPO",
                api_key="blabla",
                log_config=True,
            )
        ],
        checkpoint_config=train.CheckpointConfig(
            checkpoint_at_end=True,
            checkpoint_frequency=1,
        ),
    ),
)

# Run the experiment.
results = tuner.fit()

ray.shutdown()
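
The callback only works if the checkpoint contains one subdirectory per policy under policies/, which is the layout restore_weights() in the script above assumes. A minimal stdlib-only sketch to confirm that before launching training; missing_policy_dirs is a hypothetical helper, not an RLlib API:

```python
import os
from typing import Iterable


def missing_policy_dirs(checkpoint_dir: str, policy_ids: Iterable[str]) -> list[str]:
    """Return the policy IDs whose `policies/<policy_id>` subdirectory is missing.

    Assumes the layout used by restore_weights():
    <checkpoint_dir>/policies/<policy_id>/
    """
    return [
        policy_id
        for policy_id in policy_ids
        if not os.path.isdir(os.path.join(checkpoint_dir, "policies", policy_id))
    ]
```

For the script above, missing_policy_dirs(path_to_checkpoint, ["prey", "predator"]) should return an empty list; a non-empty result would point at a path problem rather than at the restore logic itself.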

Issue Severity

High: It blocks me from completing my task.

Finebouche commented 10 months ago

I should add that I checked that the checkpoints were saved correctly. If I use


from ray.rllib.algorithms.algorithm import Algorithm

path_to_checkpoint = "/blablabla/ray_results/PPO_2023-11-29_02-51-09/PPO_CustomEnvironment_c4c87_00000_0_2023-11-29_02-51-09/checkpoint_000008"

algo = Algorithm.from_checkpoint(path_to_checkpoint)

and then use algo.compute_single_action() to run the environment for several steps and visualize the agents, I get the correct behavior.

It only fails when I try to continue training these policies with the method described above.
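
Since Algorithm.from_checkpoint() restores correctly while the callback route does not seem to, comparing weights numerically can show where they get lost. A minimal stdlib sketch; weights_match is a hypothetical helper, and it assumes a policy's get_weights() returns a dict of parameter name to array (numpy arrays and torch tensors both provide .tolist(), which the sketch relies on):

```python
import math
from typing import Any, Mapping


def _to_plain(x: Any) -> Any:
    # numpy arrays and torch tensors expose .tolist(); plain lists/floats pass through.
    return x.tolist() if hasattr(x, "tolist") else x


def _close(a: Any, b: Any, rel_tol: float, abs_tol: float) -> bool:
    """Recursively compare nested sequences of numbers, preserving shape."""
    a, b = _to_plain(a), _to_plain(b)
    a_seq, b_seq = isinstance(a, (list, tuple)), isinstance(b, (list, tuple))
    if a_seq or b_seq:
        if not (a_seq and b_seq) or len(a) != len(b):
            return False
        return all(_close(x, y, rel_tol, abs_tol) for x, y in zip(a, b))
    return math.isclose(float(a), float(b), rel_tol=rel_tol, abs_tol=abs_tol)


def weights_match(
    expected: Mapping[str, Any],
    actual: Mapping[str, Any],
    rel_tol: float = 1e-6,
    abs_tol: float = 1e-8,
) -> bool:
    """True if both weight dicts have the same keys and numerically equal values."""
    if expected.keys() != actual.keys():
        return False
    return all(_close(expected[k], actual[k], rel_tol, abs_tol) for k in expected)
```

For example, at the end of on_algorithm_init one could check weights_match(restore_weights(path_to_checkpoint, "prey"), algorithm.get_weights()["prey"]), assuming Algorithm.get_weights() returns a dict keyed by policy ID, and repeat the check at the start of the next iteration to see whether the restored weights survive.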

Finebouche commented 10 months ago

I feel like it might be because I am using tune.Tuner().fit() and not tune.run(). In this other example with train(), it seems to work for the person who tried it. Is there a way that fit() re-initializes the weights? Can this be prevented?

Finebouche commented 10 months ago

Related to #40777, #32751, #36761, #36830, #41290 and #40347, all about loading previously trained models/policies.

Finebouche commented 10 months ago

The trick found here of passing the checkpoint via a "start_from_checkpoint" parameter in tune.Tuner()'s param_space doesn't work either :/

Finebouche commented 10 months ago

I was able to use tune.run() instead of tune.Tuner().fit(), but it still does not seem to work. I assess this by visualizing an episode in 3 environments:

  1. The original one, whose policies I want to restore
  2. An environment after attempting to restore the weights
  3. An environment after one step

2 and 3 behave similarly to each other and differently from 1.
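
Judging "similar behavior" from single visualized episodes is noisy; averaging returns over several seeded episodes makes the comparison between setups 1, 2 and 3 more concrete. A stdlib sketch; run_episode is a hypothetical callable you would write around your own environment and policy (for example using compute_single_action()), taking a seed and returning the total reward of one episode:

```python
import statistics
from typing import Callable


def summarize_returns(
    run_episode: Callable[[int], float],
    n_episodes: int = 20,
    base_seed: int = 0,
) -> tuple[float, float]:
    """Run episodes with seeds base_seed .. base_seed + n_episodes - 1.

    Returns (mean, sample standard deviation) of the episode returns.
    """
    if n_episodes < 2:
        raise ValueError("need at least 2 episodes to estimate a standard deviation")
    returns = [run_episode(base_seed + i) for i in range(n_episodes)]
    return statistics.mean(returns), statistics.stdev(returns)
```

Using the same seeds for all three setups keeps the comparison fair: if the mean for setup 2 sits close to setup 3 and far from setup 1, the weights were most likely not restored.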

A side problem is that tune.run is absent from the documentation, so at first I thought it was being deprecated. I eventually found the information I needed in the function's implementation in the repo, but it wasn't straightforward at all.

Questions that still remain:

  1. Is tune.run absent from the docs because it's being deprecated?
  2. Why does weight restoration still fail with tune.run and with tune.Tuner().fit() + callbacks, while it works with Algorithm.from_checkpoint(path_to_checkpoint)?

Finebouche commented 10 months ago

Also linked to: #40626, #40777 and #37515

The documentation should clearly explain how to do this.