Building an inference endpoint from an RLlib PPO checkpoint

1. Severity of the issue: (select one)

None: I’m just curious or want clarification.

Low: Annoying but doesn’t hinder my work.

Medium: Significantly affects my productivity but can find a workaround.

High: Completely blocks me.

2. Environment:

  • Ray version: 2.44.1

  • Python version: 3.10

  • OS: Ubuntu

  • Cloud/Infrastructure: GCP

  • Other libs/tools (if relevant): PyTorch

3. What happened vs. what you expected:

  • Expected: I am creating an inference endpoint from an RLlib PPO checkpoint; is the approach and code below correct? This is my first time doing this in Ray.

  • Actual: NA

Code

import os
import sys
import numpy as np
from typing import Any, Dict, List, Union

import ray
from ray import serve
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.algorithm import Algorithm

CHECKPOINT_DIR = os.environ.get("CHECKPOINT_DIR", "/app/checkpoint")
PROJECT_ROOT = os.environ.get("PROJECT_ROOT", "/app")

ObsType = Union[List[float], List[int], Dict[str, Any]]


def _to_numpy_obs(obs: Any) -> Any:
    if isinstance(obs, dict):
        return {k: _to_numpy_obs(v) for k, v in obs.items()}
    if isinstance(obs, (list, tuple)):
        return np.asarray(obs, dtype=np.float32)
    return obs


@serve.deployment(
    name="PPOInference",
    num_replicas=int(os.environ.get("NUM_REPLICAS", "1")),
    ray_actor_options={"num_cpus": float(os.environ.get("NUM_CPUS", "1"))},
)
class PPOInference:
    def __init__(self, checkpoint_dir: str = CHECKPOINT_DIR):
        # Ensure custom modules referenced by checkpoint pickle are importable
        if PROJECT_ROOT not in sys.path:
            sys.path.insert(0, PROJECT_ROOT)

        env_name = os.environ.get("ENV_NAME", "CartPole-v1")

        # Explicit config is REQUIRED in your case
        cfg = (
            PPOConfig()
            .environment(env=env_name)
            .framework("torch")
        )

        print(f"[init] Restoring from checkpoint: {checkpoint_dir}")
        self.algo = Algorithm.from_checkpoint(checkpoint_dir, config=cfg)

        # New API stack: module exists
        module = self.algo.get_module()
        print("[init] Restore OK")
        print("[init] Module:", type(module))

    async def __call__(self, request):
        # --- NEW: enforce POST /predict (but keep POST / too) ---
        method = request.method.upper()
        path = request.url.path

        if method != "POST" or path not in ("/predict", "/"):
            return {"error": "Use POST /predict"}

        payload = await request.json()
        explore = bool(payload.get("explore", False))
        obs = payload.get("obs", None)

        if obs is None:
            return {
                "error": "Missing field 'obs'.",
                "example": {"obs": [0.1, 0.0, 0.2, -0.1], "explore": False},
            }

        # Batch
        if isinstance(obs, list) and len(obs) > 0 and isinstance(obs[0], (list, dict)):
            obs_batch = [_to_numpy_obs(o) for o in obs]
            # compute_actions() returns the (actions, state_outs, infos) tuple
            # only when full_fetch=True; by default it returns just the actions.
            actions, _, _ = self.algo.compute_actions(
                obs_batch, explore=explore, full_fetch=True
            )
            actions_list = actions.tolist() if hasattr(actions, "tolist") else list(actions)
            return {"actions": actions_list}

        # Single
        obs_np = _to_numpy_obs(obs)
        action = self.algo.compute_single_action(obs_np, explore=explore)
        if isinstance(action, tuple):
            action = action[0]
        return {"actions": [action]}


def main():
    ray.init(ignore_reinit_error=True)

    serve.start(
        http_options={
            "host": "0.0.0.0",
            "port": int(os.environ.get("PORT", "8000")),
        }
    )

    serve.run(PPOInference.bind(CHECKPOINT_DIR))


if __name__ == "__main__":
    main()

Your approach is mostly correct for serving a PPO checkpoint with Ray Serve, but there are some caveats. With Ray 2.44.1 and the new RLlib API stack, Algorithm.compute_single_action and compute_actions are deprecated and may not work as expected; instead, you should use the RLModule directly for inference. The recommended pattern is to load the RLModule from the checkpoint and use its forward_inference method, as shown in policy_inference_after_training_w_connector.py and ray_serve_with_rllib.py.
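As a minimal sketch of that pattern (assuming a new-API-stack checkpoint whose module is stored under learner_group/learner/rl_module/default_policy, and a discrete action space such as CartPole's; verify both against your own checkpoint and environment), loading the RLModule and running greedy inference could look like this, reusing the CHECKPOINT_DIR from your script:

import os

import numpy as np
import torch
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.rl_module import RLModule

# The sub-path below is an assumption based on recent new-API-stack
# checkpoints; adjust it if your checkpoint directory is laid out differently.
module_path = os.path.join(
    CHECKPOINT_DIR, "learner_group", "learner", "rl_module", "default_policy"
)
rl_module = RLModule.from_checkpoint(module_path)

# Single observation -> add a batch dimension and call forward_inference().
obs = np.array([0.1, 0.0, 0.2, -0.1], dtype=np.float32)
with torch.no_grad():
    out = rl_module.forward_inference({Columns.OBS: torch.from_numpy(obs)[None]})

# PPO modules return action-distribution inputs (logits for discrete spaces);
# greedy inference is the argmax over those logits.
action = int(np.argmax(out[Columns.ACTION_DIST_INPUTS][0].numpy()))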

For production inference, extract the RLModule from the checkpoint and call forward_inference on batched observations. See the official Ray Serve + RLlib example and deployment class for a minimal, robust pattern. You should also ensure that your observation preprocessing matches what was used during training.
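Applied to your deployment, a hedged sketch (not the official example verbatim) could look like the following; PPOModuleInference and module_checkpoint_dir are illustrative names, the module sub-path is the one shown in the previous snippet, and the argmax again assumes a discrete action space with a flat Box observation:

import numpy as np
import torch
from ray import serve
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.rl_module import RLModule


@serve.deployment(name="PPOModuleInference")
class PPOModuleInference:
    def __init__(self, module_checkpoint_dir: str):
        # Only the RLModule is loaded; no Algorithm, env, or workers are
        # needed in the serving replica.
        self.module = RLModule.from_checkpoint(module_checkpoint_dir)

    async def __call__(self, request):
        payload = await request.json()
        obs = np.asarray(payload["obs"], dtype=np.float32)
        if obs.ndim == 1:
            obs = obs[None, :]  # single observation -> batch of one

        with torch.no_grad():
            out = self.module.forward_inference(
                {Columns.OBS: torch.from_numpy(obs)}
            )

        # Greedy (deterministic) actions from the distribution inputs.
        logits = out[Columns.ACTION_DIST_INPUTS].numpy()
        return {"actions": np.argmax(logits, axis=-1).tolist()}

Keeping only the RLModule in the replica avoids restoring the full Algorithm (and its env/worker setup) at serve time; it also means any preprocessing you apply to incoming observations must replicate whatever connectors or preprocessing were used during training.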

Would you like a more detailed, step-by-step walkthrough of the recommended approach?
