RNN support for offline algorithms

Hey!

RNNs are currently not supported by offline algorithms. For example, this reproduction script (for MARWIL):

def test_marwil_rnn():
    """Train MARWIL with a custom RNN model from offline CartPole data.

    Reproduction script for RNN support in offline algorithms. Learning
    happens from the recorded ``tests/data/cartpole/large.json`` file, and
    evaluation runs against the live ``CartPole-v0`` environment.

    Raises:
        ValueError: if the mean evaluation reward never exceeds
            ``min_reward`` within ``num_iterations`` training iterations.
    """
    ModelCatalog.register_custom_model("rnn", RNNModel)
    # This path may change depending on the location of this file
    # (works for rllib.agents.marwil.tests).
    rllib_dir = Path(__file__).parent.parent.parent.parent
    print("rllib dir={}".format(rllib_dir))
    data_file = os.path.join(rllib_dir, "tests/data/cartpole/large.json")
    print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))

    config = marwil.DEFAULT_CONFIG.copy()
    config["num_workers"] = 0
    # Evaluation must be switched on explicitly. Per RLlib's common trainer
    # config, `evaluation_interval` defaults to None, which disables it.
    # Without this, `results` never contains "evaluation", `learnt` can
    # never become True, and the check below would always fail.
    config["evaluation_interval"] = 1
    config["evaluation_num_workers"] = 0
    # Evaluate on actual environment.
    config["evaluation_config"] = {"input": "sampler"}
    config["model"] = {
        "custom_model": "rnn",
    }
    config["input_evaluation"] = []  # ["is", "wis"]
    # Learn from offline data.
    config["input"] = [data_file]
    num_iterations = 10
    min_reward = 70.0

    frameworks = "tf"
    for _ in framework_iterator(config, frameworks=frameworks):
        trainer = marwil.MARWILTrainer(config=config, env="CartPole-v0")
        try:
            learnt = False
            for i in range(num_iterations):
                results = trainer.train()
                check_train_results(results)

                eval_results = results.get("evaluation")
                if eval_results:
                    print("iter={} R={} ".format(i, eval_results["episode_reward_mean"]))
                    # Learn until some reward is reached on an actual live env.
                    if eval_results["episode_reward_mean"] > min_reward:
                        print("learnt!")
                        learnt = True
                        break

            if not learnt:
                raise ValueError(
                    "MARWILTrainer did not reach {} reward from expert "
                    "offline data!".format(min_reward)
                )

            check_compute_single_action(trainer, include_prev_action_reward=True)
        finally:
            # Always release the trainer's resources, even when training
            # or the learning check above raises.
            trainer.stop()


if __name__ == "__main__":
    test_marwil_rnn()

leads to:

Traceback (most recent call last):
  File "c:/Users/<username>/Documents/dev/ray/rllib/agents/marwil/tests/test_marwil_rnn.py", line 155, in <module>
    test_marwil_rnn()
  File "c:/Users/<username>/Documents/dev/ray/rllib/agents/marwil/tests/test_marwil_rnn.py", line 126, in test_marwil_rnn
    results = trainer.train()
  File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\tune\trainable.py", line 315, in train
    result = self.step()
  File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\rllib\agents\trainer.py", line 982, in step
    raise e
  File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\rllib\agents\trainer.py", line 963, in step
    step_attempt_results = self.step_attempt()
  File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\rllib\agents\trainer.py", line 1042, in step_attempt
    step_results = self._exec_plan_or_training_iteration_fn()
  File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\rllib\agents\trainer.py", line 1966, in _exec_plan_or_training_iteration_fn
    results = next(self.train_exec_impl)
  File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\util\iter.py", line 756, in __next__
    return next(self.built_iterator)
  File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\util\iter.py", line 843, in apply_filter
    for item in it:
  File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\util\iter.py", line 843, in apply_filter
    for item in it:
  File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\util\iter.py", line 843, in apply_filter
    for item in it:
  File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\util\iter.py", line 1075, in build_union
    item = next(it)
  File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\util\iter.py", line 756, in __next__
    return next(self.built_iterator)
  File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\util\iter.py", line 791, in apply_foreach
    result = fn(item)
  File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\rllib\execution\train_ops.py", line 114, in __call__
    self.workers.local_worker().learn_on_batch(batch)
  File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 856, in learn_on_batch
    to_fetch[pid] = policy._build_learn_on_batch(
  File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\rllib\policy\tf_policy.py", line 1088, in _build_learn_on_batch
    self._get_loss_inputs_dict(postprocessed_batch, shuffle=False))
  File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\rllib\policy\tf_policy.py", line 1140, in _get_loss_inputs_dict
    train_batch[key],
  File "C:\Users\<username>\Miniconda3\envs\<env_name>\lib\site-packages\ray\rllib\policy\sample_batch.py", line 722, in __getitem__
    value = dict.__getitem__(self, key)
KeyError: 'state_in_0'

The state (`state_in_*`) and `seq_lens` keys are never set in the SampleBatch for offline algorithms (unless we set them ourselves when writing the data with the JsonWriter).

I might be interested in helping to add RNN support to offline algorithms, but I am wondering what the best way to achieve this is. Where could I add the missing keys to the SampleBatch object?

Are any changes besides the ones below needed to run my script without errors?

Hey @Fabien-Couthouis, yeah, it would be great to have this! We could probably borrow some of the buffer magic that R2D2 and RNNSAC currently use (setting the replay buffers to return sequences instead of single timesteps) to make this work.

1 Like

Happy to help out with the PR. Let me know if you would like to give it a shot.

Hello!

I tried setting the `replay_sequence_length` config to a value > 1 for MARWIL. I also added the `seq_lens` key to the batch in the JsonReader class, but the state keys are still missing from the batch when `learn_on_loaded_batch` is called:

Traceback (most recent call last):
  File "<ray_path>/ray/rllib/agents/marwil/tests/test_marwil_rnn.py", line 85, in <module>
    test_marwil_rnn()
  File "<ray_path>/ray/rllib/agents/marwil/tests/test_marwil_rnn.py", line 61, in test_marwil_rnn
    results = trainer.train()
  File "<ray_path>\ray\tune\trainable.py", line 315, in train
    result = self.step()
  File "<ray_path>\ray\rllib\agents\trainer.py", line 991, in step
    raise e
  File "<ray_path>\ray\rllib\agents\trainer.py", line 972, in step
    step_attempt_results = self.step_attempt()
  File "<ray_path>\ray\rllib\agents\trainer.py", line 1051, in step_attempt
    step_results = self._exec_plan_or_training_iteration_fn()
  File "<ray_path>\ray\rllib\agents\trainer.py", line 1993, in _exec_plan_or_training_iteration_fn
    results = next(self.train_exec_impl)
  File "<ray_path>\ray\util\iter.py", line 756, in __next__
    return next(self.built_iterator)
  File "<ray_path>\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "<ray_path>\ray\util\iter.py", line 843, in apply_filter
    for item in it:
  File "<ray_path>\ray\util\iter.py", line 843, in apply_filter
    for item in it:
  File "<ray_path>\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "<ray_path>\ray\util\iter.py", line 843, in apply_filter
    for item in it:
  File "<ray_path>\ray\util\iter.py", line 1075, in build_union
    item = next(it)
  File "<ray_path>\ray\util\iter.py", line 756, in __next__
    return next(self.built_iterator)
  File "<ray_path>\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "<ray_path>\ray\util\iter.py", line 791, in apply_foreach
    result = fn(item)
  File "<ray_path>\ray\rllib\execution\train_ops.py", line 329, in __call__
    results = policy.learn_on_loaded_batch(
  File "<ray_path>\ray\rllib\policy\dynamic_tf_policy.py", line 549, in learn_on_loaded_batch
    return self.learn_on_batch(sliced_batch)
  File "<ray_path>\ray\rllib\policy\tf_policy.py", line 419, in learn_on_batch
    fetches = self._build_learn_on_batch(builder, postprocessed_batch)
  File "<ray_path>\ray\rllib\policy\tf_policy.py", line 1088, in _build_learn_on_batch
    self._get_loss_inputs_dict(postprocessed_batch, shuffle=False))
  File "<ray_path>\ray\rllib\policy\tf_policy.py", line 1144, in _get_loss_inputs_dict
    train_batch[key],
  File "<ray_path>\ray\rllib\policy\sample_batch.py", line 722, in __getitem__
    value = dict.__getitem__(self, key)
KeyError: 'state_in_0'

Where can I set the `state_in_i` keys in the batch? As far as I understand, for online algorithms the states are added to the batch in SimpleListCollector, but offline algorithms do not use this class for training. Should I do it in the InputReader (JsonReader / D4RLReader), or is there a way to add the states elsewhere?

Here are the changes I have made so far: