Fails restoring weights #41508

I am posting this here as well, as a copy of #41508, since I am not sure where my problem belongs.

The code in examples/restore_1_of_n_agents_from_checkpoint.py does not seem to work (at least in my case).

The weights are not restored but re-initialized. What I observe is that instead of getting the same policy reward means (in Wandb) as before, I get values that look freshly initialized.

Maybe the example is not up to date, or maybe I am doing something wrong here. I am using tune.Tuner().fit() and not tune.run() as in the example, but I am not sure why that would make it fail…

Versions / Dependencies

Python 3.10
Ray 2.8

Reproduction script


import os

import ray
from ray import train, tune
from ray.air.integrations.wandb import WandbLoggerCallback
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.tune.registry import get_trainable_cls

config = (
    get_trainable_cls("PPO").get_default_config()
    ...
    ...
    .multi_agent(
        policies={
            "prey": PolicySpec(
                policy_class=None,  # infer automatically from Algorithm
                observation_space=env.observation_space[0],  # if None, infer automatically from env
                action_space=env.action_space[0],  # if None, infer automatically from env
                config={"gamma": 0.85},  # use main config plus <- this override here
            ),
            "predator": PolicySpec(
                policy_class=None,
                observation_space=env.observation_space[0],
                action_space=env.action_space[0],
                config={"gamma": 0.85},
            ),
        },
        policy_mapping_fn=lambda agent_id, *args, **kwargs: (
            "prey" if env.agents[agent_id].agent_type == 0 else "predator"
        ),
        policies_to_train=["prey", "predator"],
    )
)

path_to_checkpoint = "/blablabla/ray_results/PPO_2023-11-29_02-51-09/PPO_CustomEnvironment_c4c87_00000_0_2023-11-29_02-51-09/checkpoint_000008"

def restore_weights(path, policy_type):
    # Load the policy's sub-checkpoint and return its weights.
    checkpoint_path = os.path.join(path, f"policies/{policy_type}")
    restored_policy = Policy.from_checkpoint(checkpoint_path)
    return restored_policy.get_weights()


class RestoreWeightsCallback(DefaultCallbacks):
    def on_algorithm_init(self, *, algorithm: "Algorithm", **kwargs) -> None:
        # Overwrite the freshly initialized weights with the checkpointed ones.
        algorithm.set_weights({"predator": restore_weights(path_to_checkpoint, "predator")})
        algorithm.set_weights({"prey": restore_weights(path_to_checkpoint, "prey")})

config.callbacks(RestoreWeightsCallback)

ray.init()

# Define experiment    
tuner = tune.Tuner(
    "PPO",                                  
    param_space=config,                         
    run_config=train.RunConfig(         
        stop={                                    
            "training_iteration": 1,
            "timesteps_total": 20000,
        },
        verbose=3,
        callbacks=[WandbLoggerCallback(       
            project="ppo_marl", 
            group="PPO",
            api_key="blabla",
            log_config=True,
        )],
        checkpoint_config=train.CheckpointConfig(        
            checkpoint_at_end=True,
            checkpoint_frequency=1
        ),
    ),
)

# Run the experiment 
results = tuner.fit()

ray.shutdown()
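As an extra sanity check on the callback itself (not part of the original example), one could log a simple checksum of a weight tensor before and after set_weights() inside on_algorithm_init, to see whether the restore actually changes the freshly initialized weights. A rough sketch, reusing restore_weights from above; the printed sums are only for eyeballing:

class RestoreWeightsCheckCallback(DefaultCallbacks):
    def on_algorithm_init(self, *, algorithm, **kwargs) -> None:
        for policy_id in ("predator", "prey"):
            before = algorithm.get_policy(policy_id).get_weights()
            algorithm.set_weights(
                {policy_id: restore_weights(path_to_checkpoint, policy_id)}
            )
            after = algorithm.get_policy(policy_id).get_weights()
            # Sum of the first weight tensor before/after the restore; if both
            # sums are identical, set_weights() did not change anything.
            first_key = next(iter(after))
            print(policy_id, float(before[first_key].sum()), float(after[first_key].sum()))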

I checked that the checkpoints were saved correctly. If I use


from ray.rllib.algorithms.algorithm import Algorithm

path_to_checkpoint = "/blablabla/ray_results/PPO_2023-11-29_02-51-09/PPO_CustomEnvironment_c4c87_00000_0_2023-11-29_02-51-09/checkpoint_000008"

algo = Algorithm.from_checkpoint(path_to_checkpoint)

and then use algo.compute_single_action() to run the environment for several steps and visualize the agents, I get the correct output.
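For reference, that verification rollout looks roughly like the following. This is only a sketch: it assumes my custom MultiAgentEnv (env) follows the Gymnasium API and uses the same agent_type-based policy mapping as above.

# Rough verification rollout; env and the policy mapping are the same as in
# the reproduction script above.
obs, infos = env.reset()
for _ in range(200):
    actions = {
        agent_id: algo.compute_single_action(
            agent_obs,
            policy_id="prey" if env.agents[agent_id].agent_type == 0 else "predator",
        )
        for agent_id, agent_obs in obs.items()
    }
    obs, rewards, terminateds, truncateds, infos = env.step(actions)
    if terminateds.get("__all__", False) or truncateds.get("__all__", False):
        break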

So it really only fails when I try to keep training those previously trained policies with the method described above.

I was also able to run it with tune.run() instead of tune.Tuner().fit(), but it still does not seem to work. The way I assess that is by visualizing an episode run in 3 environments (see also the sketch after this list):

  1. The initial environment whose policies I want to retrieve
  2. An environment after the attempt to restore the weights
  3. An environment after one training step

Environments 2 and 3 show similar behavior, which differs from 1.
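A numerical alternative to this visual check would be to compare the original checkpoint's weights with the weights the new run saves after its first iteration: if the restore worked, the new weights should only be a small perturbation of the old ones; if the policies were re-initialized, they will be completely different. A rough sketch, where path_to_new_checkpoint is a hypothetical placeholder for the new run's first checkpoint:

import numpy as np

old = restore_weights(path_to_checkpoint, "prey")
new = restore_weights(path_to_new_checkpoint, "prey")  # hypothetical new-run checkpoint
for key in old:
    print(key, "close" if np.allclose(old[key], new[key], atol=1e-2) else "diverged")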

This is linked to this:

but I need to use Ray Tune for the Wandb integration, and those solutions use PPOTrainer directly.

I want to know too… I have a checkpoint made by tuner.fit() and I want to expand the model and reuse the pretrained weights.