Vf_preds not in SampleBatch (for PPO)

Hi!

I have implemented a custom model and a custom action distribution that I am training with PPO.
Everything works flawlessly, but when sample_batch = compute_bootstrap_value(sample_batch, policy) is called, I get

value = dict.__getitem__(self, key)
KeyError: 'vf_preds'

so I am unable to run a training step. Why are the value predictions not in the batch?

My custom model has a value_function() defined, which does get called and works as expected…
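For reference, after building the algo (config below) this is roughly how I checked that the value head itself runs; the compute_single_action/full_fetch call is just an illustration, not my actual test code:

obs = env.observation_space["player0"].sample()
# Ask for the full policy output so the extra action fetches come back
# alongside the action; for PPO I would expect "vf_preds" to show up there.
action, state_out, extra_fetches = algo.compute_single_action(
    obs, policy_id="player0", full_fetch=True
)
print(extra_fetches.keys())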

This is my algo config:

algo = (
    PPOConfig()
    .rollouts(
        num_rollout_workers=4,
        num_envs_per_worker=1,
        create_env_on_local_worker=False,
        rollout_fragment_length=512,
        batch_mode="truncate_episodes", 
    )
    .resources(
        num_gpus=0
    )
    .environment(
        env="env_name"
    )
    .experimental(
        _disable_initialize_loss_from_dummy_batch=True,
    )
    .training(
        model={
            "custom_model": "CustomModel",
            "custom_action_dist": "CustomActionDist",
        },
        # ppo args
        gamma=0.99,
        lr=0.0001,
        train_batch_size=2048,
        use_critic=True,
        use_gae=True,
        lambda_=0.95,
        kl_coeff=0.0,
        sgd_minibatch_size=64,
        num_sgd_iter=3,
        clip_param=0.1,
        entropy_coeff=0.01,
    )
    .multi_agent(
        policies={
            "player0": (None, env.observation_space["player0"], env.action_space["player0"], {}),
            "player1": (None, env.observation_space["player1"], env.action_space["player1"], {})
        },
        policy_mapping_fn=(lambda agent_id, *args, **kwargs: "player0" if agent_id == "player0" else "player1"),
        policies_to_train=["player0", "player1"]
    )
    .build()
)
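The KeyError shows up as soon as I call train(); my training loop is nothing special (the iteration count is just illustrative):

# Plain training loop; the error above is raised inside algo.train(),
# while the collected sample batches are being postprocessed.
for i in range(100):
    result = algo.train()
    print(i, result["episode_reward_mean"])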

and this is my custom model:

class CustomModel(TorchModelV2, nn.Module):
    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        **kwargs,
    ):
        nn.Module.__init__(self)
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name, **kwargs
        )

        # Define encoder network
        self.encoder = ...  # my CNN (omitted)

        # Define actors
        self.actor1 = ...  # actor for dim 1 of the dict action space
        self.actor2 = ...  # actor for dim 2 of the dict action space

        # Define critic
        self.critic = ...  # single critic

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        # Cache the encoder output so value_function() can reuse it.
        self.features_ = self.encoder(input_dict)
        return self.features_, state
    
    @override(TorchModelV2)
    def value_function(self):
        value = torch.reshape(self.critic(self.features_), [-1])
        return value

Note that the model only returns the context (features); my custom action distribution then uses these features together with the model's actor heads to sample each action dimension, following the implementation in ray/rllib/examples/autoregressive_action_dist.py in the ray-project/ray repo on GitHub. A simplified sketch of the pattern is below.
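For reference, my custom action distribution follows roughly this pattern (heavily simplified sketch: the head names actor1/actor2 come from the model above, while the dict keys "dim1"/"dim2", the 256-dim feature size, and the independent heads are placeholders; the real autoregressive example also feeds the first action into the second head):

from ray.rllib.models.torch.torch_action_dist import (
    TorchCategorical,
    TorchDistributionWrapper,
)
from ray.rllib.utils.annotations import override

FEATURE_DIM = 256  # placeholder: size of the encoder features fed to the distribution


class CustomActionDist(TorchDistributionWrapper):
    """Sketch: self.inputs are the encoder features returned by CustomModel.forward();
    the actor heads live on the model and are applied here."""

    def _head_dists(self):
        # Build one categorical per action dim from the shared features.
        # (In the autoregressive example the second head also receives the
        # first sampled action; omitted here because my head signatures differ.)
        dist1 = TorchCategorical(self.model.actor1(self.inputs), self.model)
        dist2 = TorchCategorical(self.model.actor2(self.inputs), self.model)
        return dist1, dist2

    @override(TorchDistributionWrapper)
    def sample(self):
        dist1, dist2 = self._head_dists()
        a1, a2 = dist1.sample(), dist2.sample()
        self._action_logp = dist1.logp(a1) + dist2.logp(a2)
        return {"dim1": a1, "dim2": a2}

    @override(TorchDistributionWrapper)
    def deterministic_sample(self):
        dist1, dist2 = self._head_dists()
        a1, a2 = dist1.deterministic_sample(), dist2.deterministic_sample()
        self._action_logp = dist1.logp(a1) + dist2.logp(a2)
        return {"dim1": a1, "dim2": a2}

    @override(TorchDistributionWrapper)
    def sampled_action_logp(self):
        return self._action_logp

    @override(TorchDistributionWrapper)
    def logp(self, actions):
        dist1, dist2 = self._head_dists()
        return dist1.logp(actions["dim1"]) + dist2.logp(actions["dim2"])

    @override(TorchDistributionWrapper)
    def entropy(self):
        dist1, dist2 = self._head_dists()
        return dist1.entropy() + dist2.entropy()

    @staticmethod
    def required_model_output_shape(action_space, model_config):
        # The model outputs only the encoder features; the heads are applied here.
        return FEATURE_DIM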

Thank you! :slight_smile: