Hey!
RNNs are currently not supported for offline algorithms. Indeed, this reproduction script (for MARWIL):
import os
from pathlib import Path

import ray
import ray.rllib.agents.marwil as marwil
from ray.rllib.examples.models.rnn_model import RNNModel
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.test_utils import (
    check_compute_single_action,
    check_train_results,
    framework_iterator,
)


def test_marwil_rnn():
    ModelCatalog.register_custom_model("rnn", RNNModel)
    # This path may change depending on the location of this file
    # (works for rllib.agents.marwil.tests).
    rllib_dir = Path(__file__).parent.parent.parent.parent
    print("rllib dir={}".format(rllib_dir))
    data_file = os.path.join(rllib_dir, "tests/data/cartpole/large.json")
    print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))

    config = marwil.DEFAULT_CONFIG.copy()
    config["num_workers"] = 0
    config["evaluation_num_workers"] = 0
    # Evaluate on an actual environment.
    config["evaluation_config"] = {"input": "sampler"}
    config["model"] = {
        "custom_model": "rnn",
    }
    config["input_evaluation"] = []  # ["is", "wis"]
    # Learn from offline data.
    config["input"] = [data_file]

    num_iterations = 10
    min_reward = 70.0
    frameworks = "tf"

    for _ in framework_iterator(config, frameworks=frameworks):
        trainer = marwil.MARWILTrainer(config=config, env="CartPole-v0")
        learnt = False
        for i in range(num_iterations):
            results = trainer.train()
            check_train_results(results)
            eval_results = results.get("evaluation")
            if eval_results:
                print("iter={} R={}".format(i, eval_results["episode_reward_mean"]))
                # Learn until some reward is reached on an actual live env.
                if eval_results["episode_reward_mean"] > min_reward:
                    print("learnt!")
                    learnt = True
                    break
        if not learnt:
            raise ValueError(
                "MARWILTrainer did not reach {} reward from expert "
                "offline data!".format(min_reward)
            )
        check_compute_single_action(trainer, include_prev_action_reward=True)
        trainer.stop()


if __name__ == "__main__":
    ray.init()
    test_marwil_rnn()
leads to:
Traceback (most recent call last):
File "c:/Users/<username>/Documents/dev/ray/rllib/agents/marwil/tests/test_marwil_rnn.py", line 155, in <module>
test_marwil_rnn()
File "c:/Users/<username>/Documents/dev/ray/rllib/agents/marwil/tests/test_marwil_rnn.py", line 126, in test_marwil_rnn
results = trainer.train()
File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\tune\trainable.py", line 315, in train
result = self.step()
File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\rllib\agents\trainer.py", line 982, in step
raise e
File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\rllib\agents\trainer.py", line 963, in step
step_attempt_results = self.step_attempt()
File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\rllib\agents\trainer.py", line 1042, in step_attempt
step_results = self._exec_plan_or_training_iteration_fn()
File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\rllib\agents\trainer.py", line 1966, in _exec_plan_or_training_iteration_fn
results = next(self.train_exec_impl)
File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\util\iter.py", line 756, in __next__
return next(self.built_iterator)
File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
for item in it:
File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\util\iter.py", line 843, in apply_filter
for item in it:
File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\util\iter.py", line 843, in apply_filter
for item in it:
File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
for item in it:
File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\util\iter.py", line 843, in apply_filter
for item in it:
File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\util\iter.py", line 1075, in build_union
item = next(it)
File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\util\iter.py", line 756, in __next__
return next(self.built_iterator)
File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
for item in it:
File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\util\iter.py", line 791, in apply_foreach
result = fn(item)
File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\rllib\execution\train_ops.py", line 114, in __call__
self.workers.local_worker().learn_on_batch(batch)
File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 856, in learn_on_batch
to_fetch[pid] = policy._build_learn_on_batch(
File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\rllib\policy\tf_policy.py", line 1088, in _build_learn_on_batch
self._get_loss_inputs_dict(postprocessed_batch, shuffle=False))
File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\rllib\policy\tf_policy.py", line 1140, in _get_loss_inputs_dict
train_batch[key],
File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\rllib\policy\sample_batch.py", line 722, in __getitem__
value = dict.__getitem__(self, key)
KeyError: 'state_in_0'
The state keys (`state_in_*`/`state_out_*`) and the `seq_lens` key are never set in the SampleBatch for offline algorithms (unless we set them ourselves via the JsonWriter).
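For concreteness, this is roughly what I mean by setting them via the JsonWriter. This is only a sketch: the 5-step episode, the zero states, and the 256-unit cell size are made up, and the cell size would have to match the custom model's actual state size.

import numpy as np

from ray.rllib.offline.json_writer import JsonWriter
from ray.rllib.policy.sample_batch import SampleBatch

# Hypothetical 5-step CartPole episode (obs dim 4) that already carries
# the RNN state columns and seq_lens the loss later looks up.
T, CELL = 5, 256  # CELL is arbitrary here; must match the model's state size.
batch = SampleBatch({
    SampleBatch.OBS: np.zeros((T, 4), np.float32),
    SampleBatch.ACTIONS: np.zeros((T,), np.int64),
    SampleBatch.REWARDS: np.ones((T,), np.float32),
    SampleBatch.DONES: np.array([False] * (T - 1) + [True]),
    "state_in_0": np.zeros((T, CELL), np.float32),
    "state_out_0": np.zeros((T, CELL), np.float32),
    "seq_lens": np.array([T], np.int32),
})

writer = JsonWriter("/tmp/cartpole-rnn")
writer.write(batch)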
I might be interested in helping add RNN support to the offline algorithms, but I am wondering what the best way to achieve this is. Where could I add the missing keys in the SampleBatch object?
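One option I can imagine (untested, and the placement is an assumption on my part) would be a small hook that patches each offline batch with the policy's zero initial state before the loss inputs are built. `policy.get_initial_state()` and `batch.count` are existing RLlib APIs; `add_rnn_columns` itself is just an illustrative name:

import numpy as np

def add_rnn_columns(policy, batch):
    """Fill in zero RNN state columns for batches read from offline data."""
    init_state = policy.get_initial_state()  # One array per RNN layer.
    if init_state and "state_in_0" not in batch:
        for i, state in enumerate(init_state):
            # Repeat the zero initial state for every timestep in the batch.
            batch["state_in_{}".format(i)] = np.tile(state, (batch.count, 1))
        # Treat the whole batch as one sequence (a simplification; real
        # episode boundaries would need proper per-episode seq_lens).
        batch["seq_lens"] = np.array([batch.count], dtype=np.int32)
    return batch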
Are there any changes needed, other than the ones below, in order to run my script without error?
- Pad RNN values in MARWIL/CQL as is done in PPO (see the masking sketch after this list)
- Fix the off-policy estimators (IS, WIS) for RNNs, as suggested in ray-project/ray#15800 ("[RLlib][MARWIL] loss calculations not masking padded values when using RNN")
- Add the state_in, state_out and seq_lens keys to the SampleBatch in the offline algorithms' policies
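For reference on the first two items, this is roughly the masking pattern PPO's tf loss already applies once batches carry `seq_lens`: average the loss only over valid (non-padded) timesteps. A sketch, not MARWIL's actual code; `masked_mean` is an illustrative name:

import tensorflow as tf

def masked_mean(per_timestep_loss, seq_lens, max_seq_len):
    # Padded train batches are flattened to [num_seqs * max_seq_len]; build
    # a boolean mask that is True only for the valid (non-padded) steps.
    mask = tf.sequence_mask(seq_lens, max_seq_len)
    mask = tf.reshape(mask, [-1])
    # Average the loss over valid timesteps only, ignoring padding.
    return tf.reduce_mean(tf.boolean_mask(per_timestep_loss, mask))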