Vf_preds not in SampleBatch (for PPO)

Hi!

I have implemented a custom model and a custom action distribution that I am training with PPO.
Everything works flawlessly, but when sample_batch = compute_bootstrap_value(sample_batch, policy) is called, I get

value = dict.__getitem__(self, key)
KeyError: 'vf_preds'

so I am unable to run a training step. Why are the value predictions not in the batch?

My custom model has a value_function() defined that is called and works…

This is my algo config:

algo = (
    PPOConfig()
    .rollouts(
        num_rollout_workers=4,
        num_envs_per_worker=1,
        create_env_on_local_worker=False,
        rollout_fragment_length=512,
        batch_mode="truncate_episodes", 
    )
    .resources(
        num_gpus=0
    )
    .environment(
        env="env_name"
    )
    .experimental(
        _disable_initialize_loss_from_dummy_batch=True,
    )
    .training(
        model={
            "custom_model": "CustomModel",
            "custom_action_dist": "CustomActionDist",
        },
        # ppo args
        gamma=0.99,
        lr=0.0001,
        train_batch_size=2048,
        use_critic=True,
        use_gae=True,
        lambda_=0.95,
        kl_coeff=0.0,
        sgd_minibatch_size=64,
        num_sgd_iter=3,
        clip_param=0.1,
        entropy_coeff=0.01,
    )
    .multi_agent(
        policies={
            "player0": (None, env.observation_space["player0"], env.action_space["player0"], {}),
            "player1": (None, env.observation_space["player1"], env.action_space["player1"], {})
        },
        policy_mapping_fn=(lambda agent_id, *args, **kwargs: "player0" if agent_id == "player0" else "player1"),
        policies_to_train=["player0", "player1"]
    )
    .build()
)
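
The config above refers to "CustomModel" and "CustomActionDist" by name, so both classes have to be registered with RLlib's model catalog before build() is called. A minimal sketch of that registration (assuming CustomModel and CustomActionDist are the classes defined below):

from ray.rllib.models import ModelCatalog

# Register the custom model and action distribution under the names used in the config.
ModelCatalog.register_custom_model("CustomModel", CustomModel)
ModelCatalog.register_custom_action_dist("CustomActionDist", CustomActionDist)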

and this is my custom model:

class CustomModel(TorchModelV2, nn.Module):
    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        **kwargs,
    ):
        nn.Module.__init__(self)
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name, **kwargs
        )

        # Define network
        self.encoder = ...  # my CNN (elided)

        # Define actors
        self.actor1 = ...  # actor for dim 1 of the dict action space
        self.actor2 = ...  # actor for dim 2 of the dict action space

        # Define critic
        self.critic = ...  # single critic

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        self.features_ = self.encoder(input_dict)
        return self.features_, state
    
    @override(TorchModelV2)
    def value_function(self):
        value = torch.reshape(self.critic(self.features_), [-1])
        return value

Note that the model only returns the context (features); my custom action distribution uses it to sample actions from each actor head, following the implementation in ray/rllib/examples/autoregressive_action_dist.py in the ray-project/ray repository on GitHub.
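
For anyone reading along, the rough shape of such a distribution is sketched below, modeled on that example but not the exact implementation: the TorchCategorical heads, the dict keys, and the 256 feature size are placeholders; the two heads are sampled independently for brevity (the autoregressive example conditions the second head on the first sampled action); and logp/entropy/kl are omitted.

from ray.rllib.models.torch.torch_action_dist import (
    TorchCategorical,
    TorchDistributionWrapper,
)


class CustomActionDist(TorchDistributionWrapper):
    """Sketch: sample a two-part dict action from the model's two actor heads."""

    def _head_distributions(self):
        # self.inputs holds the encoder features returned by CustomModel.forward();
        # the model's actor heads turn them into per-component logits.
        logits1 = self.model.actor1(self.inputs)
        logits2 = self.model.actor2(self.inputs)
        return TorchCategorical(logits1, self.model), TorchCategorical(logits2, self.model)

    def sample(self):
        d1, d2 = self._head_distributions()
        a1, a2 = d1.sample(), d2.sample()
        self._action_logp = d1.logp(a1) + d2.logp(a2)
        # The returned structure must mirror the Dict action space; keys are placeholders.
        return {"action1": a1, "action2": a2}

    def deterministic_sample(self):
        d1, d2 = self._head_distributions()
        a1, a2 = d1.deterministic_sample(), d2.deterministic_sample()
        self._action_logp = d1.logp(a1) + d2.logp(a2)
        return {"action1": a1, "action2": a2}

    def sampled_action_logp(self):
        return self._action_logp

    @staticmethod
    def required_model_output_shape(action_space, model_config):
        # forward() returns the shared feature vector, so report its size here
        # (256 stands in for the encoder's output dimension).
        return 256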

Thank you! :slight_smile:


I am seeing a similar issue - vf_preds not in SampleBatch for PPO. I am getting the following error:

File "/export/home/xxx/xxx/miniconda3/envs/xxx/lib/python3.10/site-packages/ray/rllib/evaluation/postprocessing.py", line 113, in compute_advantages
  SampleBatch.VF_PREDS in rollout or not use_critic
AssertionError: use_critic=True but values not found

Hi @omsrisagar,

In my experience, this error is usually a follow-on from some other, more fundamental error.

Are there any other errors, or can you share the full traceback?

Sorry for the long delay in getting back to you; I was caught up with work and had stopped working on this project. I came back to post the fix I used, in case it benefits anyone.

I fixed this issue by adding the following line to my torch model:
self.view_requirements[SampleBatch.VF_PREDS] = ViewRequirement(SampleBatch.VF_PREDS)

With full context:

from argparse import Namespace

import gym

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.typing import ModelConfigDict


class Module(TorchModelV2, Base):

    def __init__(self,
                 obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 num_outputs: int,
                 model_config: ModelConfigDict,
                 name: str,
                 args: Namespace
        ):
        '''
        Seq2Seq agent
        '''
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name)
        args.num_actions_low = action_space['action_low'].n
        Base.__init__(self, args)

        # Add vf_preds to the view requirements as it is needed for computing GAE in PPO
        self.view_requirements[SampleBatch.VF_PREDS] = ViewRequirement(SampleBatch.VF_PREDS)
        self.view_requirements[SampleBatch.ACTION_DIST_INPUTS] = ViewRequirement(SampleBatch.ACTION_DIST_INPUTS)
        self.view_requirements[SampleBatch.ACTION_LOGP] = ViewRequirement(SampleBatch.ACTION_LOGP)
        ...

Without this line, I get the above error. As per your suggestion, the full traceback is as follows (I am using Ray 2.5.1 and training with the PPO algorithm):

Traceback (most recent calls WITHOUT Sacred internals):
  File "/export/home/xxx/yyy/Projects/rl_training/train/rllib_client.py", line 519, in main
    result = run(params)
  File "/export/home/xxx/yyy/Projects/rl_training/train/rllib_client.py", line 475, in run
    result = trainer.train()
  File "/export/home/xxx/yyy/miniconda3/envs/ai2thor/lib/python3.10/site-packages/ray/tune/trainable/trainable.py", line 389, in train
    raise skipped from exception_cause(skipped)
  File "/export/home/xxx/yyy/miniconda3/envs/ai2thor/lib/python3.10/site-packages/ray/tune/trainable/trainable.py", line 386, in train
    result = self.step()
  File "/export/home/xxx/yyy/miniconda3/envs/ai2thor/lib/python3.10/site-packages/ray/rllib/algorithms/algorithm.py", line 803, in step
    results, train_iter_ctx = self._run_one_training_iteration()
  File "/export/home/xxx/yyy/miniconda3/envs/ai2thor/lib/python3.10/site-packages/ray/rllib/algorithms/algorithm.py", line 2853, in _run_one_training_iteration
    results = self.training_step()
  File "/export/home/xxx/yyy/miniconda3/envs/ai2thor/lib/python3.10/site-packages/ray/rllib/algorithms/ppo/ppo.py", line 403, in training_step
    train_batch = synchronous_parallel_sample(
  File "/export/home/xxx/yyy/miniconda3/envs/ai2thor/lib/python3.10/site-packages/ray/rllib/execution/rollout_ops.py", line 82, in synchronous_parallel_sample
    sample_batches = [worker_set.local_worker().sample()]
  File "/export/home/xxx/yyy/miniconda3/envs/ai2thor/lib/python3.10/site-packages/ray/rllib/evaluation/rollout_worker.py", line 915, in sample
    batches = [self.input_reader.next()]
  File "/export/home/xxx/yyy/miniconda3/envs/ai2thor/lib/python3.10/site-packages/ray/rllib/evaluation/sampler.py", line 92, in next
    batches = [self.get_data()]
  File "/export/home/xxx/yyy/miniconda3/envs/ai2thor/lib/python3.10/site-packages/ray/rllib/evaluation/sampler.py", line 277, in get_data
    item = next(self._env_runner)
  File "/export/home/xxx/yyy/miniconda3/envs/ai2thor/lib/python3.10/site-packages/ray/rllib/evaluation/sampler.py", line 617, in _env_runner
    active_envs, to_eval, outputs = _process_observations(
  File "/export/home/xxx/yyy/miniconda3/envs/ai2thor/lib/python3.10/site-packages/ray/rllib/evaluation/sampler.py", line 955, in _process_observations
    ma_sample_batch = sample_collector.postprocess_episode(
  File "/export/home/xxx/yyy/miniconda3/envs/ai2thor/lib/python3.10/site-packages/ray/rllib/evaluation/collectors/simple_list_collector.py", line 513, in postprocess_episode
    post_batches[agent_id] = policy.postprocess_trajectory(
  File "/export/home/xxx/yyy/miniconda3/envs/ai2thor/lib/python3.10/site-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py", line 215, in postprocess_trajectory
    return compute_gae_for_sample_batch(
  File "/export/home/xxx/yyy/miniconda3/envs/ai2thor/lib/python3.10/site-packages/ray/rllib/evaluation/postprocessing.py", line 222, in compute_gae_for_sample_batch
    batch = compute_advantages(
  File "/export/home/xxx/yyy/miniconda3/envs/ai2thor/lib/python3.10/site-packages/ray/rllib/evaluation/postprocessing.py", line 113, in compute_advantages
    SampleBatch.VF_PREDS in rollout or not use_critic
AssertionError: use_critic=True but values not found
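
As a quick sanity check, the built policy can be inspected to confirm the extra view requirement was actually picked up; a minimal sketch, assuming "player0" is one of the policy IDs from the config earlier in the thread (algo.get_policy() without an argument returns the default policy in the single-policy case):

from ray.rllib.policy.sample_batch import SampleBatch

# `algo` is the Algorithm built from the PPOConfig above.
policy = algo.get_policy("player0")
assert SampleBatch.VF_PREDS in policy.view_requirements, (
    "vf_preds will not be collected into the SampleBatch for this policy"
)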