MARWIL with gymnasium Dict as action space

It seems likely that the main issues are (1) ensuring your dataset is written with only native Python types (no numpy arrays or numpy scalars) in all nested Dict fields, and (2) reading the dataset back in a form RLlib’s MARWIL offline pipeline can consume. Below is a minimal end-to-end example for both create_dataset.py and train_marwil.py that should work for a Dict action space with Discrete and Box subspaces.


create_dataset.py

import numpy as np
import ray

def to_serializable_action(action):
    return {
        "rotate": int(action["rotate"]),  # Discrete as int
        "thrust": [float(x) for x in action["thrust"]]  # Box as list of floats
    }

def generate_parquet_dataset(num_samples=100, filename="offline_data.parquet"):
    data = []
    for _ in range(num_samples):
        obs = np.random.rand(4).astype(np.float32).tolist()
        action = {
            "rotate": int(np.random.randint(0, 3)),
            "thrust": [float(np.random.uniform(-1, 1))]
        }
        reward = float(np.random.randn())
        new_obs = np.random.rand(4).astype(np.float32).tolist()
        done = bool(np.random.choice([False, True], p=[0.9, 0.1]))
        data.append({
            "obs": obs,
            "actions": to_serializable_action(action),
            "rewards": reward,
            "new_obs": new_obs,
            "dones": done,
        })
    ray.init(ignore_reinit_error=True)
    ds = ray.data.from_items(data)
    ds.write_parquet(filename)  # writes one or more .parquet files under this path (treated as a directory)
    ray.shutdown()

if __name__ == "__main__":
    generate_parquet_dataset()
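
To sanity-check the written data before training, you can read it back with Ray Data and confirm the nested action fields are plain Python values (an optional quick check):

import ray

ray.init(ignore_reinit_error=True)
ds = ray.data.read_parquet("offline_data.parquet")
print(ds.schema())  # expect a nested "actions" struct with an int field and a list-of-floats field
print(ds.take(1))   # inspect one full row
ray.shutdown()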

train_marwil.py

import gymnasium as gym
import numpy as np
from ray.rllib.algorithms.marwil import MARWILConfig
from ray.tune.registry import register_env

# Minimal custom env for Dict obs/action space
class DummyDictEnv(gym.Env):
    def __init__(self, config=None):
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(4,), dtype=np.float32)
        self.action_space = gym.spaces.Dict({
            "rotate": gym.spaces.Discrete(3),
            "thrust": gym.spaces.Box(-1, 1, (1,), dtype=np.float32)
        })
        self.state = self.observation_space.sample()
        self.steps = 0

    def reset(self, *, seed=None, options=None):
        self.state = self.observation_space.sample()
        self.steps = 0
        return self.state, {}

    def step(self, action):
        self.state = self.observation_space.sample()
        reward = float(np.random.randn())
        done = self.steps > 10
        self.steps += 1
        return self.state, reward, done, False, {}

register_env("DummyDictEnv", lambda config: DummyDictEnv(config))

config = (
    MARWILConfig()
    .environment(env="DummyDictEnv")
    .offline_data(
        input_="offline_data.parquet",
        input_read_episodes=False,  # For tabular data
    )
    .training(train_batch_size_per_learner=32)
)

if __name__ == "__main__":
    algo = config.build()
    for i in range(5):
        result = algo.train()
        # Pure offline training does not step the env, so episode-reward metrics may be
        # missing or NaN unless evaluation is configured; use .get() to avoid a KeyError.
        print(f"Iteration {i}: episode_reward_mean={result.get('episode_reward_mean')}")
    algo.stop()
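
If you also want to keep the trained policy, you can save a checkpoint just before algo.stop(); a minimal sketch (the exact return value of save() differs across Ray versions):

    # Inside the __main__ block, right before algo.stop():
    checkpoint = algo.save()  # writes an algorithm checkpoint and returns its location/metadata
    print("Checkpoint:", checkpoint)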

This setup should avoid serialization issues and allow MARWIL to process your offline data. If you use a different observation/action space, adjust the env and dataset accordingly.
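
For example, if your Dict action space also had a MultiDiscrete subspace, the same conversion pattern applies; here is a hypothetical sketch (the "gear" subspace and its [2, 4] sizes are made up for illustration):

import gymnasium as gym
import numpy as np

# Hypothetical Dict action space with an extra MultiDiscrete "gear" subspace.
action_space = gym.spaces.Dict({
    "rotate": gym.spaces.Discrete(3),
    "thrust": gym.spaces.Box(-1, 1, (1,), dtype=np.float32),
    "gear": gym.spaces.MultiDiscrete([2, 4]),
})

def to_serializable_action(action):
    # Every leaf must end up as a native Python type so the parquet rows stay plain data.
    return {
        "rotate": int(action["rotate"]),
        "thrust": [float(x) for x in action["thrust"]],
        "gear": [int(g) for g in action["gear"]],
    }

print(to_serializable_action(action_space.sample()))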

Would you like more detail on customizing the environment or dataset schema?

Hint: Mention @RunLLM in the post for followups.