build_for_inference() in env_runner_v2.py creates an empty state_out_1 and causes training startup to fail

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

I have an RNN model inheriting from ModelV2 that worked well with Ray 2.2.

In Ray 2.9, I set the option required to stay on the old API stack:
config.experimental(_enable_new_api_stack=False).build()

The configuration builds fine, but the problem occurs when I call
tuner = tune.Tuner("PPO", param_space=config, run_config=run_config)

The full error output is attached below, and I tried to look into the functions mentioned in it. I noticed that "sample_batches_by_policy" did not contain "state_out_1" while stepping through ray/rllib/evaluation/env_runner_v2.py. When build_for_inference() (defined in ray/rllib/evaluation/collectors/agent_collector.py and called from ray/rllib/connectors/agent/view_requirement.py) then runs, the buffer behind self.view_requirements['state_in_1'] is an empty list, which finally causes the IndexError.
In debug mode, self.view_requirements['state_in_1'] looks like this:

ViewRequirement(data_col='state_out_1', space=Box(-1.0, 1.0, (256,), float32), shift=-1, index=None, batch_repeat_value=20, used_for_compute_actions=True, used_for_training=True, shift_arr=array([-1]))
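
To show where this blows up, here is a minimal sketch of the failure mode. It only mimics the indexing pattern visible in the traceback (_cache_in_np() maps _to_float_np_array() over self.buffers["state_out_1"], and that helper looks at v[0] before checking the length); the buffer contents below are hypothetical:

import numpy as np

def to_float_np_array(v):
    # Simplified stand-in for _to_float_np_array() in agent_collector.py:
    # the real helper inspects v[0] (a torch.is_tensor check) before
    # converting, so an empty buffer triggers the indexing error.
    if v[0] is None:  # IndexError: list index out of range when v == []
        return None
    return np.asarray(v, dtype=np.float32)

# Hypothetical collector state: "state_out_1" never received any values
# during the rollout, so its per-agent buffer is an empty list.
buffers = {"state_out_1": [[]]}
cache = {k: [to_float_np_array(d) for d in v] for k, v in buffers.items()}

So the IndexError itself looks like a symptom: nothing ever fills the state_out_1 buffer during collection, which matches the missing key I saw in sample_batches_by_policy.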

Please advise how to proceed! I am willing to provide more information.

2024-02-05 05:07:41,742	ERROR tune_controller.py:1374 -- Trial task failed for trial PPO_MultiAgentArena_v3_85c05_00000
Traceback (most recent call last):
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
             ^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/_private/worker.py", line 2624, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(IndexError): ray::PPO.train() (pid=987409, ip=10.47.57.189, actor_id=da518257234fa0c302d5fd4d01000000, repr=PPO)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/tune/trainable/trainable.py", line 342, in train
    raise skipped from exception_cause(skipped)
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/tune/trainable/trainable.py", line 339, in train
    result = self.step()
             ^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py", line 852, in step
    results, train_iter_ctx = self._run_one_training_iteration()
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py", line 3042, in _run_one_training_iteration
    results = self.training_step()
              ^^^^^^^^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo.py", line 407, in training_step
    train_batch = synchronous_parallel_sample(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/execution/rollout_ops.py", line 83, in synchronous_parallel_sample
    sample_batches = worker_set.foreach_worker(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py", line 705, in foreach_worker
    handle_remote_call_result_errors(remote_results, self._ignore_worker_failures)
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py", line 78, in handle_remote_call_result_errors
    raise r.get()
ray.exceptions.RayTaskError(IndexError): ray::RolloutWorker.apply() (pid=987409, ip=10.47.57.189, actor_id=d64b201bd95cea973cd5da4701000000, repr=<ray.rllib.evaluation.rollout_worker._modify_class.<locals>.Class object at 0x7fd832842e10>)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/utils/actor_manager.py", line 189, in apply
    raise e
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/utils/actor_manager.py", line 178, in apply
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/execution/rollout_ops.py", line 84, in <lambda>
    lambda w: w.sample(), local_worker=False, healthy_only=True
              ^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py", line 694, in sample
    batches = [self.input_reader.next()]
               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/sampler.py", line 91, in next
    batches = [self.get_data()]
               ^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/sampler.py", line 276, in get_data
    item = next(self._env_runner)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/env_runner_v2.py", line 344, in run
    outputs = self.step()
              ^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/env_runner_v2.py", line 370, in step
    active_envs, to_eval, outputs = self._process_observations(
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/env_runner_v2.py", line 637, in _process_observations
    processed = policy.agent_connectors(acd_list)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/connectors/agent/pipeline.py", line 41, in __call__
    ret = c(ret)
          ^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/connectors/connector.py", line 265, in __call__
    return [self.transform(d) for d in acd_list]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/connectors/connector.py", line 265, in <listcomp>
    return [self.transform(d) for d in acd_list]
            ^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/connectors/agent/view_requirement.py", line 118, in transform
    sample_batch = agent_collector.build_for_inference()
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/agent_collector.py", line 366, in build_for_inference
    self._cache_in_np(np_data, data_col)
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/agent_collector.py", line 613, in _cache_in_np
    cache_dict[key] = [_to_float_np_array(d) for d in self.buffers[key]]
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/agent_collector.py", line 613, in <listcomp>
    cache_dict[key] = [_to_float_np_array(d) for d in self.buffers[key]]
                       ^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/agent_collector.py", line 32, in _to_float_np_array
    if torch and torch.is_tensor(v[0]):
                                 ~^^^
IndexError: list index out of range

Here is the code to reproduce the error:

TestMultiAgentCartPole.py, which is the Ray multi-agent CartPole example with only the model swapped for my custom RNN.

import argparse
import os
import random
import importlib
import ray
from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check_learning_achieved

tf1, tf, tfv = try_import_tf()

parser = argparse.ArgumentParser()

parser.add_argument("--num-agents", type=int, default=4)
parser.add_argument("--num-policies", type=int, default=2)
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "torch"],
    default="torch",
    help="The DL framework specifier.",
)
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.",
)
parser.add_argument(
    "--stop-iters", type=int, default=200, help="Number of iterations to train."
)
parser.add_argument(
    "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train."
)
parser.add_argument(
    "--stop-reward", type=float, default=300.0, help="Reward at which we stop training."
)
# os.environ["RLLIB_ENABLE_RL_MODULE"] = "False"
if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None, local_mode=True)

    # Register the models to use.
    # Each policy can have a different configuration (including custom model).


    def get_model(
        model_file,
        fc_size=200,
        rnn_hidden_size=256,
        max_seq_len=20,
        l2_curr=3,
        l2_inp=0,
        device="cuda",
        **_,
    ):
        md = importlib.import_module(model_file)
        myModel = getattr(md, "AnotherTorchRNNModel")
        modelName = "rnn_noFC"
        ModelCatalog.register_custom_model(modelName, myModel)
        print(f"Model Registered {model_file}.")
        model_dict = {
            "custom_model": modelName,
            "max_seq_len": max_seq_len,
            "custom_model_config": {
                "fc_size": fc_size,
                "rnn_hidden_size": rnn_hidden_size,
                "l2_lambda": l2_curr,
                "l2_lambda_inp": l2_inp,
                "device": device,  # or 'cuda'
            },
        }
        return model_dict

    model_dict = get_model("mySimpleRNN")

    def gen_policy(i):
        config = PPOConfig.overrides(
            model=model_dict,
            gamma=random.choice([0.95, 0.99]),
        )
        return PolicySpec(config=config)

    # Setup PPO with an ensemble of `num_policies` different policies.
    policies = {"policy_{}".format(i): gen_policy(i) for i in range(args.num_policies)}
    policy_ids = list(policies.keys())

    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
        pol_id = random.choice(policy_ids)
        return pol_id

    config = (
        PPOConfig()
        .experimental(_enable_new_api_stack=False)
        .environment(MultiAgentCartPole, env_config={"num_agents": args.num_agents})
        .framework(args.framework)
        .training(num_sgd_iter=10)
        .multi_agent(policies=policies, policy_mapping_fn=policy_mapping_fn)
    )

    # config.model.update(model_dict)
    config.experimental(_enable_new_api_stack=False).build()
    stop = {
        "episode_reward_mean": args.stop_reward,
        "timesteps_total": args.stop_timesteps,
        "training_iteration": args.stop_iters,
    }
    checkpoint_config = air.CheckpointConfig(
        checkpoint_frequency=5,
        # num_to_keep=100,
        checkpoint_at_end=True,
    )

    results = tune.Tuner(
        "PPO",
        param_space=config.to_dict(),
        run_config=air.RunConfig(
            stop=stop,
            verbose=1,
            checkpoint_config=checkpoint_config,
            local_dir="/home/lime/Documents/CartPoleTest",
        ),
    ).fit()

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
    ray.shutdown()
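
For reference, I launch the reproduction with both files in the same directory and the default arguments from the argparse setup above, roughly like this:

python TestMultiAgentCartPole.py --framework=torch

The trial fails with the traceback above as soon as the first rollouts are sampled.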

mySimpleRNN.py, which is the RNN model I am using.

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
from typing import Callable

torch, nn = try_import_torch()
class AnotherTorchRNNModel(RecurrentNetwork, nn.Module):
    def __init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            rnn_hidden_size=256,
            l2_lambda = 3,
            l2_lambda_inp=0,
            device="cuda"
    ):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        self.obs_size = get_preprocessor(obs_space)(obs_space).size
        self.rnn_hidden_size = model_config["custom_model_config"]["rnn_hidden_size"]
        self.l2_lambda = model_config["custom_model_config"]["l2_lambda"]
        self.l2_lambda_inp = model_config["custom_model_config"]["l2_lambda_inp"]

        # Build the Module from 0fc + RNN + 2xfc (action + value outs).
        # self.fc1 = nn.Linear(self.obs_size, self.fc_size)
        self.rnn = nn.RNN(self.obs_size, self.rnn_hidden_size, batch_first=True, nonlinearity='relu')
        self.action_branch = nn.Linear(self.rnn_hidden_size, num_outputs)
        self.value_branch = nn.Linear(self.rnn_hidden_size, 1)
        # Holds the current "base" output (before logits layer).
        self._features = None
        self.l2_loss = None
        self.l2_loss_inp = None
        self.original_loss = None

        self.activations = {}
        self.hooks = []
        self.device = device

    @override(ModelV2)
    def get_initial_state(self):
        # Place hidden states on same device as model.
        h = [
            self.rnn.weight_ih_l0.new(1, self.rnn_hidden_size).zero_().squeeze(0),
            self.rnn.weight_ih_l0.new(1, self.rnn_hidden_size).zero_().squeeze(0),
        ]
        return h

    @override(ModelV2)
    def value_function(self):
        assert self._features is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._features), [-1])

    @override(RecurrentNetwork)
    def forward_rnn(self, inputs, state, seq_lens):
        """Feeds `inputs` (B x T x ..) through the Gru Unit.
        Returns the resulting outputs as a sequence (B x T x ...).
        Values are stored in self._cur_value in simple (B) shape (where B
        contains both the B and T dims!).
        Returns:
            NN Outputs (B x T x ...) as sequence.
            The state batches as a List of two items (c- and h-states).
        """
        x = inputs
        y = torch.unsqueeze(state[0], 0)
        self._features, h = self.rnn(x, y)
        action_out = self.action_branch(self._features)
        return action_out, [torch.squeeze(h, 0)]

    @override(ModelV2)
    def custom_loss(self, policy_loss, loss_inputs):

        l2_lambda = self.l2_lambda
        l2_reg = torch.tensor(0.).to(self.device)
        # l2_reg += torch.norm(self.rnn.weight_hh_l0.data)
        l2_reg += torch.norm(self.rnn.weight_hh_l0).to(self.device)

        l2_lambda_inp = self.l2_lambda_inp
        l2_reg_inp = torch.tensor(0.).to(self.device)
        l2_reg_inp += torch.norm(self.rnn.weight_ih_l0).to(self.device)

        self.l2_loss = l2_lambda * l2_reg
        self.l2_loss_inp = l2_lambda_inp * l2_reg_inp
        self.original_loss = policy_loss

        assert self.l2_loss.requires_grad, "l2 loss no gradient"
        assert self.l2_loss_inp.requires_grad, "l2 loss no gradient"

        custom_loss = self.l2_loss + self.l2_loss_inp

        # depending on input add loss
        total_loss = [p_loss + custom_loss for p_loss in policy_loss]

        return total_loss

    def metrics(self):
        metrics = {
            "weight_loss": self.l2_loss.item(),
            # TODO Nguyen figure out if good or not
            "original_loss": self.original_loss[0].item(),
        }
        # You can print them to the command line here; with Torch models they are somehow not reported to the logger.
        # print(metrics)