I execute the following code from the section “Converting tabular data to RLlib’s episode format” of the user guide “Working with offline data” (Ray 2.42.1):
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core import (
    COMPONENT_LEARNER_GROUP,
    COMPONENT_LEARNER,
    COMPONENT_RL_MODULE,
    DEFAULT_MODULE_ID,
)
from ray.rllib.core.rl_module import RLModuleSpec

# Set up a path for the tabular data records.
tabular_data_path = "tmp/rllib_offline_recording_tabular"

# Configure the algorithm for recording.
config = (
    PPOConfig()
    # The environment needs to be specified.
    .environment(
        env="CartPole-v1",
    )
    # Make sure to sample complete episodes because
    # recording is based on RLlib's episode objects.
    .env_runners(
        batch_mode="complete_episodes",
    )
    # Set up 5 evaluation `EnvRunners` for recording.
    # Sample 500 episodes in each evaluation rollout.
    .evaluation(
        evaluation_num_env_runners=5,
        evaluation_duration=500,
    )
    # Use the checkpointed expert policy from the preceding PPO training.
    # Note, we have to use the same `model_config` as
    # the one with which the expert policy was trained, otherwise
    # the module state can't be loaded.
    .rl_module(
        model_config=DefaultModelConfig(
            fcnet_hiddens=[32],
            fcnet_activation="linear",
            # Share encoder layers between value network
            # and policy.
            vf_share_layers=True,
        ),
    )
    # Define the output path and format. In this example you
    # want to store the data in tabular format instead of
    # RLlib's episode objects.
    .offline_data(
        output=tabular_data_path,
        # Write tabular (column) data, not episode objects.
        output_write_episodes=False,
    )
)
# `<my_checkpoint_path>` holds the checkpoint of a pretrained CartPole algorithm.
best_checkpoint = "<my_checkpoint_path>"

# Build the algorithm.
algo = config.build()
# Load the PPO-trained `RLModule` to use in recording.
algo.restore_from_path(
    best_checkpoint,
    # Load only the `RLModule` component here.
    component=COMPONENT_RL_MODULE,
)

# Run a single evaluation iteration and record the data.
for i in range(1):
    print(f"Iteration {i + 1}")
    res_eval = algo.evaluate()
    print(res_eval)

# Stop the algorithm. Note, this is important when defining
# `output_max_rows_per_file`. Otherwise, episodes remaining in the
# `EnvRunner`s' buffers aren't written to disk.
algo.stop()

from ray import data

# Read the tabular data into a Ray dataset.
ds = data.read_parquet(tabular_data_path)
# Now, print its schema.
print("Tabular data schema of expert experiences:\n")
print(ds.schema())
Now the output of “print(ds.schema())” is:
Column               Type
------               ----
eps_id               string
agent_id             null
module_id            null
obs                  string
actions              int32
rewards              double
new_obs              string
terminateds          bool
truncateds           bool
action_dist_inputs   numpy.ndarray(shape=(2,), dtype=float)
action_logp          float
weights_seq_no       int64
and NOT the output given in the tutorial:
# Column               Type
# ------               ----
# eps_id               string
# agent_id             null
# module_id            null
# obs                  numpy.ndarray(shape=(4,), dtype=float)
# actions              int32
# rewards              double
# new_obs              numpy.ndarray(shape=(4,), dtype=float)
# terminateds          bool
# truncateds           bool
# action_dist_inputs   numpy.ndarray(shape=(2,), dtype=float)
# action_logp          float
# weights_seq_no       int64
So when I load the Parquet dataset, my observations apparently are still in a serialized string format, which leads to outputs of “ds.take_batch(batch_size=1)” similar to this:
'obs': array(['BCJNGGhAjwAAAAAAAAAdigAAAFKABZWEAAEA8hqME251bXB5Ll9jb3JlLm51bWVyaWOUjAtfZnJvbWJ1ZmZlcpSTlCiWEC8A8QUAo0B5vNrqRL2j96U8ttAkvZSMBUEA8RaUjAVkdHlwZZSTlIwCZjSUiYiHlFKUKEsDjAE8lE5OTkr/////BQDwA0sAdJRiSwSFlIwBQ5R0lFKULgAAAAA=']
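A quick type check confirms that the “obs” column really contains strings rather than arrays:

batch = ds.take_batch(batch_size=1)
# The "obs" entries come back as strings (numpy str_ objects),
# not as float arrays of shape (4,).
print(type(batch["obs"][0]))
print(batch["obs"][0][:12])  # 'BCJNGGhAjwAA'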
My Python version is 3.12, the Ray version is 2.42.1, and the OS is Linux. I don’t know whether the problem arises during Parquet saving or loading, but is there any adjustment I can make to obtain the data in the usual dtypes the “CartPole-v1” env uses, e.g. “numpy.ndarray(shape=(4,), dtype=float)” for the observations? Sadly, I am not very experienced with Parquet and have no idea how to reliably convert the strings into the desired dtypes. However, I want to use the offline data to work with Decision Mamba, and for that I need the data in the correct dtypes.
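The closest I have come is the decoding guess below. This is purely an assumption on my part: the base64-decoded strings begin with the LZ4 frame magic bytes, so I am guessing at a base64 → LZ4 → pickle layering. I don’t know whether this is what RLlib actually writes, or whether there is a supported setting to get plain arrays back instead:

import base64
import pickle

import lz4.frame

# Assumed decoding, not a documented API: treat each "obs" string as a
# base64-encoded, LZ4-compressed pickle of a numpy array.
s = ds.take_batch(batch_size=1)["obs"][0]
raw = base64.b64decode(s)            # starts with b'\x04"M\x18', the LZ4 frame magic
payload = lz4.frame.decompress(raw)  # hopefully a pickled numpy array
obs = pickle.loads(payload)
print(obs.shape, obs.dtype)          # hoping for (4,) float32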
Thanks in advance for your help