1. Severity of the issue: (select one)
None: I’m just curious or want clarification.
Low: Annoying but doesn’t hinder my work.
Medium: Significantly affects my productivity but can find a workaround.
High: Completely blocks me.
2. Environment:
-
Ray version: 2.44.1
-
Python version: 3.10
-
OS: Ubuntu
-
Cloud/Infrastructure: GCP
-
Other libs/tools (if relevant): Pytorch
3. What happened vs. what you expected:
-
Expected: I am creating an inference endpoint from an RLlib PPO checkpoint. Are the approach and the code below correct? This is my first time doing this in Ray.
-
Actual: NA
Code
import os
import numpy as np
from typing import Any, Dict, List, Union
import ray
from ray import serve
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.algorithm import Algorithm
# Deployment settings, overridable through environment variables.
# CHECKPOINT_DIR: checkpoint path handed to the deployment constructor.
CHECKPOINT_DIR = os.getenv("CHECKPOINT_DIR", "/app/checkpoint")
# PROJECT_ROOT: directory prepended to the import path by the deployment.
PROJECT_ROOT = os.getenv("PROJECT_ROOT", "/app")

# Shapes an observation may take in a request payload.
ObsType = Union[
    List[float],
    List[int],
    Dict[str, Any],
]
def _to_numpy_obs(obs: Any) -> Any:
    """Convert a JSON-decoded observation into NumPy form.

    Lists and tuples become ``float32`` arrays, dicts are converted value by
    value (recursively, keys unchanged), and anything else is returned as-is.
    """
    if isinstance(obs, (list, tuple)):
        return np.asarray(obs, dtype=np.float32)
    if not isinstance(obs, dict):
        return obs
    converted = {}
    for key, value in obs.items():
        converted[key] = _to_numpy_obs(value)
    return converted
@serve.deployment(
    name="PPOInference",
    num_replicas=int(os.environ.get("NUM_REPLICAS", "1")),
    ray_actor_options={"num_cpus": float(os.environ.get("NUM_CPUS", "1"))},
)
class PPOInference:
    """Ray Serve deployment that returns actions from a restored PPO checkpoint.

    The algorithm is restored once per replica and its RLModule is kept on the
    instance; every request is answered by a forward pass through that module.

    Request (``POST /predict`` or ``POST /``), JSON object:
        ``obs``: one observation, or a list of observations (batch).
        ``explore``: optional; truthy values sample from the exploration
            distribution, otherwise the deterministic action is returned.

    Response: ``{"actions": [...]}`` (one entry per observation) or
    ``{"error": ...}`` when the request is malformed.

    NOTE(review): inference calls the RLModule directly, so env-to-module and
    module-to-env connector pipelines are not applied. If training relied on
    them (observation filtering/flattening, action clipping or unsquashing),
    or if the module is stateful, that logic has to be added here -- confirm
    against the training config.
    """

    def __init__(self, checkpoint_dir: str = CHECKPOINT_DIR):
        """Restore the algorithm from ``checkpoint_dir`` and cache its RLModule."""
        import sys

        # Ensure custom modules referenced by the checkpoint pickle are importable.
        if PROJECT_ROOT not in sys.path:
            sys.path.insert(0, PROJECT_ROOT)

        env_name = os.environ.get("ENV_NAME", "CartPole-v1")
        # Per the original author, an explicit config is required for this
        # checkpoint -- kept unchanged.
        cfg = PPOConfig().environment(env=env_name).framework("torch")

        print(f"[init] Restoring from checkpoint: {checkpoint_dir}")
        self.algo = Algorithm.from_checkpoint(checkpoint_dir, config=cfg)
        # New API stack: the RLModule is what actually computes actions, so
        # keep a reference for use in __call__.
        self.module = self.algo.get_module()
        print("[init] Restore OK")
        print("[init] Module:", type(self.module))

    def _stack_obs(self, obs_list):
        """Stack per-sample observations along a new leading batch axis.

        Dict observations are stacked key by key (recursively).
        """
        first = obs_list[0]
        if isinstance(first, dict):
            return {
                key: self._stack_obs([sample[key] for sample in obs_list])
                for key in first
            }
        return np.stack([np.asarray(sample) for sample in obs_list])

    def _to_tensor(self, value, torch_module):
        """Convert a (possibly nested dict of) NumPy array(s) to torch tensors."""
        if isinstance(value, dict):
            return {
                key: self._to_tensor(item, torch_module)
                for key, item in value.items()
            }
        return torch_module.as_tensor(value)

    def _compute_actions(self, obs_batch, explore):
        """Run the RLModule on a list of observations; return a list of actions.

        Args:
            obs_batch: non-empty list of observations already converted by
                ``_to_numpy_obs``.
            explore: if true, use the exploration forward pass and sample;
                otherwise use the inference forward pass and act
                deterministically.

        Returns:
            A list with one JSON-serializable action per observation.
            NOTE(review): assumes the sampled action is a single tensor
            (e.g. Discrete/Box action spaces) -- confirm for nested spaces.
        """
        # Local imports keep the module importable without touching the
        # file-level import block.
        import torch
        from ray.rllib.core.columns import Columns

        batch = {Columns.OBS: self._to_tensor(self._stack_obs(obs_batch), torch)}
        with torch.no_grad():
            if explore:
                out = self.module.forward_exploration(batch)
            else:
                out = self.module.forward_inference(batch)

            if Columns.ACTIONS in out:
                # The module already produced actions; use them as-is.
                actions = out[Columns.ACTIONS]
            else:
                # Otherwise build the action distribution from its inputs.
                if explore:
                    dist_cls = self.module.get_exploration_action_dist_cls()
                else:
                    dist_cls = self.module.get_inference_action_dist_cls()
                dist = dist_cls.from_logits(out[Columns.ACTION_DIST_INPUTS])
                if not explore:
                    dist = dist.to_deterministic()
                actions = dist.sample()
        # Native Python values so the response is always JSON-serializable.
        return actions.detach().cpu().tolist()

    async def __call__(self, request):
        """Handle one HTTP request and return a JSON-serializable dict."""
        # Enforce POST /predict (but keep POST / too).
        method = request.method.upper()
        path = request.url.path
        if method != "POST" or path not in ("/predict", "/"):
            return {"error": "Use POST /predict"}

        try:
            payload = await request.json()
        except ValueError:
            # json.JSONDecodeError is a ValueError subclass.
            return {"error": "Request body must be valid JSON."}
        if not isinstance(payload, dict):
            return {"error": "Request body must be a JSON object."}

        explore = bool(payload.get("explore", False))
        obs = payload.get("obs", None)
        if obs is None:
            return {
                "error": "Missing field 'obs'.",
                "example": {"obs": [0.1, 0.0, 0.2, -0.1], "explore": False},
            }

        # A list whose first element is itself a list or dict is treated as a
        # batch. NOTE(review): a single multi-dimensional observation (e.g. an
        # image as nested lists) is indistinguishable from a batch here.
        is_batch = (
            isinstance(obs, list)
            and len(obs) > 0
            and isinstance(obs[0], (list, dict))
        )
        if is_batch:
            obs_batch = [_to_numpy_obs(o) for o in obs]
        else:
            obs_batch = [_to_numpy_obs(obs)]

        return {"actions": self._compute_actions(obs_batch, explore)}
def main():
    """Start Ray and Ray Serve, deploy ``PPOInference`` and keep serving.

    The HTTP proxy listens on all interfaces; the port comes from the
    ``PORT`` environment variable (default 8000).
    """
    ray.init(ignore_reinit_error=True)
    serve.start(
        http_options={
            "host": "0.0.0.0",
            "port": int(os.environ.get("PORT", "8000")),
        }
    )
    # serve.run() returns as soon as the app is deployed unless blocking=True.
    # Without it this script would fall off the end of main() and the driver
    # process would exit, taking the endpoint down with it.
    serve.run(PPOInference.bind(CHECKPOINT_DIR), blocking=True)


if __name__ == "__main__":
    main()