Modifying network weights with new API

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

Hello. On my journey migrating to the new API stack, I have run into another roadblock. In my working PPO training setup prior to the migration, I used a custom on_algorithm_init() callback to change how the network weights were initialized. Since migrating to the new API stack, this no longer works.

The callback code I’m trying to use is:

```python
from typing import Callable, Dict, Optional

from ray.rllib.callbacks.callbacks import RLlibCallback


class CustomCallbacks(RLlibCallback):
    def __init__(
        self,
        legacy_callbacks_dict: Optional[Dict[str, Callable]] = None,  # required by RLlib
    ):
        super().__init__()

    def on_algorithm_init(self, *, algorithm, **kwargs) -> None:
        """Called when a new algorithm instance has finished its setup() but before training begins.
        We use it to initialize the NN weights.
        """
        # Get the initial weights from the newly created NN
        policy_dict = algorithm.get_weights(["default_policy"])["default_policy"]

        # Re-initialize the weights in policy_dict here...

        # Stuff the modified weights into the newly created NN
        to_algo = {"default_policy": policy_dict}
        algorithm.set_weights(to_algo)
```
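The re-initialization step elided in the callback could look something like the following minimal sketch. It assumes `policy_dict` maps parameter names to 2-D weight matrices of shape `(fan_out, fan_in)` (the layout of `torch.nn.Linear.weight`) and 1-D bias vectors. Plain nested lists stand in for the real arrays so the sketch runs without NumPy or PyTorch, and the function name `glorot_reinit` is hypothetical. It applies Glorot (Xavier) uniform initialization, drawing each weight from U(-sqrt(6 / (fan_in + fan_out)), +sqrt(6 / (fan_in + fan_out))), and zeroes the biases.

```python
import math
import random
from typing import Dict, List, Union

Matrix = List[List[float]]
Vector = List[float]
Param = Union[Matrix, Vector]


def glorot_reinit(weights: Dict[str, Param], seed: int = 0) -> Dict[str, Param]:
    """Hypothetical helper: return a NEW dict in which every 2-D weight matrix is
    re-drawn from a Glorot (Xavier) uniform distribution and every 1-D bias vector
    is reset to zero. The input dict is not modified."""
    rng = random.Random(seed)  # seeded so re-initialization is reproducible
    out: Dict[str, Param] = {}
    for name, value in weights.items():
        if value and isinstance(value[0], list):
            # 2-D matrix, assumed (fan_out, fan_in) like torch.nn.Linear.weight
            fan_out, fan_in = len(value), len(value[0])
            limit = math.sqrt(6.0 / (fan_in + fan_out))
            out[name] = [
                [rng.uniform(-limit, limit) for _ in range(fan_in)]
                for _ in range(fan_out)
            ]
        else:
            # 1-D bias vector: reset to zeros, keeping its length
            out[name] = [0.0] * len(value)
    return out


if __name__ == "__main__":
    # Toy stand-in for a single 4 -> 3 dense layer
    policy_dict = {"fc1.weight": [[0.5] * 4 for _ in range(3)], "fc1.bias": [0.1] * 3}
    new_weights = glorot_reinit(policy_dict, seed=42)
    print(new_weights["fc1.bias"])
```

Returning a fresh dict rather than mutating in place leaves the originally fetched weights available for comparison, and the result keeps the same keys and shapes, so it can be handed to whichever API ends up writing weights back into the module.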

The error I get is:

```
Traceback (most recent call last):
  File "/home/starkj/projects/day_trader/staging/train.py", line 385, in <module>
    main(sys.argv)
  File "/home/starkj/projects/day_trader/staging/train.py", line 182, in main
    algo = cfg.build_algo()
           ^^^^^^^^^^^^^^^^
  File "/home/starkj/miniconda3/envs/trader3/lib/python3.12/site-packages/ray/rllib/algorithms/algorithm_config.py", line 957, in build_algo
    return algo_class(
           ^^^^^^^^^^^
  File "/home/starkj/miniconda3/envs/trader3/lib/python3.12/site-packages/ray/rllib/algorithms/algorithm.py", line 590, in __init__
    super().__init__(
  File "/home/starkj/miniconda3/envs/trader3/lib/python3.12/site-packages/ray/tune/trainable/trainable.py", line 158, in __init__
    self.setup(copy.deepcopy(self.config))
  File "/home/starkj/miniconda3/envs/trader3/lib/python3.12/site-packages/ray/rllib/algorithms/algorithm.py", line 955, in setup
    make_callback(
  File "/home/starkj/miniconda3/envs/trader3/lib/python3.12/site-packages/ray/rllib/callbacks/utils.py", line 32, in make_callback
    getattr(callback_obj, callback_name)(*(args or ()), **(kwargs or {}))
  File "/home/starkj/projects/day_trader/staging/custom_callbacks.py", line 43, in on_algorithm_init
    algorithm.set_weights(to_algo)
  File "/home/starkj/miniconda3/envs/trader3/lib/python3.12/site-packages/ray/rllib/algorithms/algorithm.py", line 2170, in set_weights
    self.env_runner_group.local_env_runner.set_weights(weights)
  File "/home/starkj/miniconda3/envs/trader3/lib/python3.12/site-packages/ray/rllib/utils/deprecation.py", line 121, in _ctor
    deprecation_warning(
  File "/home/starkj/miniconda3/envs/trader3/lib/python3.12/site-packages/ray/rllib/utils/deprecation.py", line 48, in deprecation_warning
    raise ValueError(msg)
ValueError: `set_weights` has been deprecated. Use `SingleAgentEnvRunner.set_state()` instead.
```

I have been unable to find any Ray documentation, including the API migration guide, that describes how to do this with the new API. Looking at the source code for the RLlib Algorithm class, it appears to be doing exactly what the deprecation message asks for, which makes the message all the more baffling. I figured out how to get the local env runner from the Algorithm object, but SingleAgentEnvRunner.set_state() is structured in a way that makes it very time-consuming to work out how to call it with meaningful input. In fact, the current documentation (and source code) for the Algorithm class gives every indication that my code should work as is. I have also been unable to find where this deprecation message is generated, so I cannot get to the bottom of this problem.

As a final frustration, I feel the Ray team is starting to use the term "deprecated" rather loosely, and imposing draconian measures on such items. Normally, "deprecated" indicates something that is no longer desirable and will be unavailable in the future: a warning to start migrating away from it. This usage, by contrast, raises a full-blown ValueError, leaving no choice but to move away from it immediately (even though it appears it should work).
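The traceback shows that the ValueError is raised inside RLlib's deprecation helper (`ray/rllib/utils/deprecation.py`: `_ctor` calls `deprecation_warning`, which runs `raise ValueError(msg)`), reached through the local env runner's `set_weights`, rather than in the body of `set_weights` itself. The following is a simplified, from-scratch sketch of that warn-versus-raise pattern to make the distinction concrete. The `deprecated` decorator, its `error` flag, and the two example functions are hypothetical names for this illustration, not RLlib's actual implementation.

```python
import functools
import warnings
from typing import Any, Callable, Optional, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


def deprecated(new: Optional[str] = None, error: bool = False) -> Callable[[F], F]:
    """Hypothetical decorator marking a function as deprecated.

    error=False: conventional deprecation, emit a DeprecationWarning and still run.
    error=True:  hard removal, raise ValueError before the wrapped body ever runs.
    """
    def decorator(fn: F) -> F:
        msg = f"`{fn.__name__}` has been deprecated."
        if new:
            msg += f" Use `{new}` instead."

        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if error:
                # The behavior seen in the traceback: no fallback, the call just fails
                raise ValueError(msg)
            # The conventional meaning of "deprecated": warn, then do the work anyway
            warnings.warn(msg, DeprecationWarning, stacklevel=2)
            return fn(*args, **kwargs)

        return wrapper  # type: ignore[return-value]

    return decorator


@deprecated(new="set_state()")
def set_weights_soft(weights: dict) -> dict:
    return weights


@deprecated(new="set_state()", error=True)
def set_weights_hard(weights: dict) -> dict:
    return weights
```

With `error=False`, `set_weights_soft({"w": 1})` still returns its argument and only emits a `DeprecationWarning`; with `error=True`, `set_weights_hard({"w": 1})` raises `ValueError` whose message ends in "Use `set_state()` instead." without ever running the function body.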

Has anyone else successfully modified weights of a network using the new API?