Hello, I have a question about offline RL training with a custom action-masking model.
I have a historical dataset that I have already converted to SampleBatch format and saved as JSON files. In my case, each line is a JSON object containing a full episode's data, so the files hold one episode per line rather than one timestep per line. The obs and new_obs fields hold lz4-encoded data with two fields, observations and action_mask, which work with the example action-masking model in the RLlib examples repo.
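To make the layout concrete, here is a minimal sketch in plain Python of how one such line breaks down. The helper name `describe_episode_line` and the shortened record are made up for illustration, and the packed obs / new_obs strings are replaced by placeholders and left undecoded.

```python
import json


def describe_episode_line(line):
    """Split one episode-per-line record into list columns and packed columns."""
    record = json.loads(line)
    # Per-timestep columns are plain JSON lists, one entry per timestep.
    list_cols = {k: v for k, v in record.items() if isinstance(v, list)}
    # Packed columns (obs, new_obs) are single strings covering the whole episode.
    packed_cols = sorted(
        k for k, v in record.items() if isinstance(v, str) and k != "type"
    )
    lengths = {len(v) for v in list_cols.values()}
    if len(lengths) > 1:
        raise ValueError(f"columns disagree on episode length: {lengths}")
    num_timesteps = lengths.pop() if lengths else 0
    return num_timesteps, sorted(list_cols), packed_cols


# Hypothetical, shortened record; the real obs/new_obs strings are much longer.
example_line = json.dumps({
    "type": "SampleBatch",
    "t": [0, 1],
    "eps_id": ["C", "C"],
    "obs": "<packed>",
    "actions": [1, 3],
    "rewards": [9.88, 0.32],
    "new_obs": "<packed>",
    "terminateds": [False, True],
})

num_timesteps, list_cols, packed_cols = describe_episode_line(example_line)
print(num_timesteps, list_cols, packed_cols)
```

So the scalar columns are already per-timestep lists, while each packed string has to be decoded as a whole before its per-timestep contents are visible.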
Below is my agent config, where I disabled the preprocessor API because my obs is not flattened (ray/rllib/models/modelv2.py at fc630c9c94f835cb44f2297958e36ddf64dc5939 · ray-project/ray · GitHub).
The current code raises an error when reading the offline JSON data with JsonReader (ray/rllib/offline/json_reader.py at fc630c9c94f835cb44f2297958e36ddf64dc5939 · ray-project/ray · GitHub), which by default treats each line as a single timestep. tree.map_structure_up_to then raises KeyError: 'action_mask' not found, because I am passing in a list of dictionaries rather than a single dictionary.
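To show the shape mismatch in isolation, here is a minimal sketch in plain Python. It assumes the decoded obs of one episode is a list with one {"observations": ..., "action_mask": ...} dict per timestep. `stack_episode_obs` is a hypothetical helper, not an RLlib function, and the values are made up.

```python
def stack_episode_obs(per_step_obs):
    """Turn a list of per-timestep dicts into one dict of per-key lists."""
    if not per_step_obs:
        return {}
    keys = per_step_obs[0].keys()
    for i, step in enumerate(per_step_obs):
        # Every timestep must carry the same fields, or stacking is ill-defined.
        if step.keys() != keys:
            raise KeyError(f"timestep {i} has keys {sorted(step)}, expected {sorted(keys)}")
    return {key: [step[key] for step in per_step_obs] for key in keys}


# Hypothetical decoded episode with two timesteps.
episode_obs = [
    {"observations": [0.1, 0.2], "action_mask": [1.0, 0.0]},
    {"observations": [0.3, 0.4], "action_mask": [0.0, 1.0]},
]

# The raw list has no "action_mask" entry of its own; the stacked dict does.
stacked = stack_episode_obs(episode_obs)
print(stacked["action_mask"])
```

A single dict with one stacked entry per key is the kind of structure a per-key lookup such as 'action_mask' can succeed on, whereas the list of dicts is not.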
I patched this method locally to accommodate episodic input, but realized that I cannot do that when I deploy the code for remote training. Could anyone point me in the right direction to fix this issue more generally?
from ray.rllib.algorithms.marwil import MARWILConfig

agent_config = (
    MARWILConfig()
    .framework(agent_config_yaml['framework'])
    .rollouts(create_env_on_local_worker=True)
    .rollouts(num_rollout_workers=0)
    .debugging(seed=0, log_level="WARN")
    .reporting(min_train_timesteps_per_iteration=1000, keep_per_episode_custom_metrics=True)
    .environment(env=ENV, env_config=env_config)
    .offline_data(
        input_config={
            "paths": train_input_path,
            "format": "json",
        },
        input_=train_input_path,
    )
    # Preprocessor API disabled because obs is a dict, not a flattened array.
    .experimental(
        _disable_preprocessor_api=True,
    )
    .training(
        model={
            "fcnet_hiddens": [256, 256],
            "custom_model": "custom_marwil_model",
            "custom_model_config": {
                "input_files": train_input_path,
            },
        },
        beta=agent_config_yaml.get('beta', 0.0),  # 0.0 for behavior cloning
    )
    .callbacks(CustomRLlibCallbacks)
)
Sample offline data looks like this:
{"type": "SampleBatch", "t": [0, 1], "eps_id": ["C", "C"], "obs": "BCJNGGhArAoAAAAAAADpsQIAAGGABZWhCgABAPEcjBVudW1weS5jb3JlLm11bHRpYXJyYXmUjAxfcmVjb25zdHJ1Y3SUk5SMBSkAUpSMB25kIwDwD5OUSwCFlEMBYpSHlFKUKEsBSwKFlGgDjAVkdHlwZTMAYgJPOJSJiB4A0QOMAXyUTk5OSv////8FAPEPSz90lGKJXZQofZQojAthY3Rpb25fbWFza5RdlChHngBhAAAARz/wCAAPCQA3Ak8AwAAAZYwMb2JzZXJ2YXoAGnN2AAAhAAACAAaIAAAOAAACAA8SAP8GDqgBAC4BAAIADzIBWlE/O8H6oC8CMmFZPAkAQO+bYOASAACMAAACAA+QAFoAcQAAAgAPdQCHBTIBAKcAAAIAD6sAPwVjABS/ngEPCQCYBb0AABwBAAIAD5MESB91DAULAH4AAAIAD4IANgBNAAACAA8MBSUF9QMARQAAAgAGPAIFGwAFCQAAIAAAAgAGvgAADgAAAgAPEgAAABcAAAIABlEAAA4AAAIADxIAkAXzAAUJAAC5AAACAA+9AABBPvu3r6kEACAAAAIABiQAXz7v3lMguwQVBS0AQT716TzxBEE/IG7tWgBQP0lVrmBIAFE/791twAkAMjdH0AkAEwUtAAB6AAACAFFHPsPqLC0AAA4AAAIABpAAUD9IpmSAPwAAFwAAAgAGGwAADgAAAgAHfgAcNVoADmwAADIAAAIAD0gAAAAXAAACAAYbAAVaAAAXAAACAAYbAAAOAAACAA8SAC0FYwBBPyBLrVEAUD9DLBVA8wAjPxXPAAUJAE8+1X9JngEEIz8jNgAARAAAAgAGmQAGlQExHjy/ngEUvrABDwkAGkE+2ksV5gEPNgAaBS0AQT7DSCG9AAUSAA8JABEFIAEA1AAAAgAH2AAPqAk0AFYAsAAAAABldWV0lGIuAAAAAA==", "actions": [1, 3], "rewards": [9.879999999999999, 0.321115074], "new_obs": 
"BCJNGGhArAoAAAAAAADpcQMAAGGABZWhCgABAPEcjBVudW1weS5jb3JlLm11bHRpYXJyYXmUjAxfcmVjb25zdHJ1Y3SUk5SMBSkAUpSMB25kIwDwD5OUSwCFlEMBYpSHlFKUKEsBSwKFlGgDjAVkdHlwZTMAYgJPOJSJiB4A0QOMAXyUTk5OSv////8FAPEPSz90lGKJXZQofZQojAthY3Rpb25fbWFza5RdlChHngA2AAAACQAhP/ARAA8JAC4CRgDAAABljAxvYnNlcnZhegAfc3YABAAqAAACAA+IAAAAFwAAAgAPGwAABawAACAAAAIADxIAEgApAAACAAZRAAAOAAACAA8SAJAF4QAFCQAAuQAAAgAPvQAAUD77t6+gzAEAIAAAAgAGJABfPu/eUyAbAAMAKQAAAgAPLQAAUD716TzgNgBBPyBu7VoAUT9JVa5gEgBB791twAkAMjdH0AkAEwUtAABNAAACAFFHPsPqLC0AAA4AAAIABmMAUD9IpmSAPwAAFwAAAgAGGwAADgAAAgAHfgAcNVoADmwAADIAAAIAD0gAAAAXAAACAAYbAAVaAAAXAAACAAYbAAAOAAACAA8SAC0FYwBBPyBLrVEAUD9DLBVA8wAjPxXPAAUJAE8+1X9JngEEIz8jNgAARAAAAgAGmQAGlQExHjy/ngEUvrABDwkAGkE+2ksV5gEPNgAaBS0AQT7DSCG9AAUSAA8JABEFIAEA1AAAAgAH2AAPkwQ9H3UMBR0AkAAAAgAPiwAkADsAAAIADwwFNwX+Aw4JAABpAAACAAasAAAOAAACAA8SAAAAFwAAAgAGbQEADgAAAgAPEgCHFL8PAg4JAAC5AAACAA+9AABBPvaSWpcEACAAAAIABiQAQT75UIEvAwAXAAACAAYbAAAOAAACAAYSAAUtAEE/AjNxGwBBPx2Tl1oAQj9FnRbRAzHv7N0CA0E/N+OE8QRKPvtNKWMAQT6wiuo/AAAFAAACAAZjAEE/Q1aoPwAAFwAAAgAGGwAADgAAAgA0Rz7ifgAxL9oFrQMAFwAAAgAzRz7yGwAADgAAAgAGLQAADgAAAgAPUQAAABcAAAIADy0AGwAyAAACAA9RAAkAIAAAAgAPJABaFL4gAQ8JAKEF3QEANwEAAgAGOwEF8AMAFwAAAgAG1QMFGwAPCQARADsAsAAAAABldWV0lGIuAAAAAA==", "truncateds": [false, true], "prev_actions": [0, 1], "prev_rewards": [0.0, 9.879999999999999], "terminateds": [false, true]}