How can I get the value of a state from a trained PPO model with the new API?

I have a trained PPO model and now I want to give it a state observation and return the expected value of the state. I’m struggling to find how to do this with the new RLlib API.

To get the expected value (the value function output) for a given state observation from a trained PPO model with the new RLlib API, use the RLModule's compute_values method. This requires access to the RLModule (e.g., via algo.get_module()), and the observation must be passed in the expected batch format (typically a dictionary with key "obs" mapping to a torch tensor with a leading batch dimension). The value function output is returned as a tensor. For example:

import torch
obs = ...  # your observation as a numpy array
module = algo.get_module()  # the default RLModule of the trained Algorithm
with torch.no_grad():
    value = module.compute_values({"obs": torch.from_numpy(obs).float().unsqueeze(0)})

This approach is shown in the RLlib example script train_w_bc_finetune_w_ppo.py and discussed in the forum thread “KeyError: ‘advantages’ on MARL”.

Would you like a step-by-step breakdown or more details on handling batch/multi-agent cases?
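For batched or multi-agent setups the pattern is the same: fetch the RLModule for the policy you care about and pass a batch of observations. A minimal sketch, assuming a 4-dimensional observation space and a module id of "policy_0" (both hypothetical):

import numpy as np
import torch

# Two example observations; the 4-dimensional obs space is only an assumption here.
obs_batch = np.stack([np.zeros(4), np.ones(4)]).astype(np.float32)

# In a multi-agent setup, pass the module id of the policy whose value you want.
module = algo.get_module("policy_0")

with torch.no_grad():
    values = module.compute_values({"obs": torch.from_numpy(obs_batch)})
print(values)  # one value per observation in the batch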

Here’s a script I wrote for that:

import torch

from ray.rllib.core.columns import Columns
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID

# Target values to compare against: expected value of each one-hot state, per opponent.
targets = {
    1: {0: 0.25, 1: -0.25, 2: 0.25, 3: 1},
    2: {0: 0.25, 1: 1, 2: 0.25, 3: -0.25},
}

# Grab the trained RLModule (with its critic head) from the learner.
critic = algo.learner_group._learner._module[DEFAULT_POLICY_ID]

for opponent in [1, 2]:
    print('---')
    for state in range(4):
        # One-hot encode the state.
        obs = torch.zeros(4)
        obs[state] += 1
        batch = {
            Columns.OBS: obs.unsqueeze(0),
            # SELECTED_OPPONENT and AUGMENTATION are custom batch keys defined elsewhere.
            SELECTED_OPPONENT: torch.tensor([opponent], dtype=torch.long),
            AUGMENTATION: torch.tensor([0]),
        }
        with torch.no_grad():
            value = critic.compute_values(batch).item()
        print(f'State {state}; Opponent {opponent}:\t{value:.02f}\t({targets[opponent][state]:.02f})')

This critic uses an augmentation input (SELECTED_OPPONENT); you can ignore that part.
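If you want to skip the inner loop, the four one-hot states can also be evaluated in a single batched call; a sketch building on the script above (it assumes critic, the opponent id, and the custom keys from that script):

# Batched variant: evaluate all four one-hot states for one opponent at once.
opponent = 1
obs = torch.eye(4)  # each row is one one-hot state observation
batch = {
    Columns.OBS: obs,
    SELECTED_OPPONENT: torch.full((4,), opponent, dtype=torch.long),
    AUGMENTATION: torch.zeros(4, dtype=torch.long),
}
with torch.no_grad():
    values = critic.compute_values(batch)  # tensor of shape (4,)
print(values)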