Offline learning with MARWIL and LSTM

Hey!

I am trying to launch offline learning with MARWIL (Ray 1.9) and a custom LSTM model, and I encounter an error:

Failure # 1 (occurred at 2021-12-09_14-21-27)
Traceback (most recent call last):
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\tune\trial_runner.py", line 924, in _process_trial
    results = self.trial_executor.fetch_result(trial)
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\tune\ray_trial_executor.py", line 787, in fetch_result
    result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\_private\client_mode_hook.py", line 105, in wrapper
    return func(*args, **kwargs)
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\worker.py", line 1713, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(ValueError): ray::MARWIL.train() (pid=20736, ip=127.0.0.1, repr=MARWIL)
  File "python\ray\_raylet.pyx", line 625, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 629, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 578, in ray._raylet.execute_task.function_executor
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\_private\function_manager.py", line 609, in actor_method_executor
    return method(__ray_actor, *args, **kwargs)
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
    return method(self, *_args, **_kwargs)
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\tune\trainable.py", line 314, in train
    result = self.step()
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
    return method(self, *_args, **_kwargs)
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\rllib\agents\trainer.py", line 880, in step
    raise e
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\rllib\agents\trainer.py", line 867, in step
    result = self.step_attempt()
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
    return method(self, *_args, **_kwargs)
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\rllib\agents\trainer.py", line 920, in step_attempt
    step_results = next(self.train_exec_impl)
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\util\iter.py", line 756, in __next__
    return next(self.built_iterator)
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\util\iter.py", line 843, in apply_filter
    for item in it:
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\util\iter.py", line 843, in apply_filter
    for item in it:
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\util\iter.py", line 843, in apply_filter
    for item in it:
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\util\iter.py", line 1075, in build_union
    item = next(it)
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\util\iter.py", line 756, in __next__
    return next(self.built_iterator)
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  [Previous line repeated 1 more time]
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\util\iter.py", line 471, in base_iterator
    yield ray.get(futures, timeout=timeout)
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\_private\client_mode_hook.py", line 105, in wrapper
    return func(*args, **kwargs)
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\worker.py", line 1713, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(ValueError): ray::RolloutWorker.par_iter_next() (pid=21924, ip=127.0.0.1, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x0000025A6412C940>)
  File "python\ray\_raylet.pyx", line 625, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 629, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 578, in ray._raylet.execute_task.function_executor
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\_private\function_manager.py", line 609, in actor_method_executor
    return method(__ray_actor, *args, **kwargs)
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
    return method(self, *_args, **_kwargs)
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\util\iter.py", line 1151, in par_iter_next
    return next(self.local_it)
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 381, in gen_rollouts
    yield self.sample()
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
    return method(self, *_args, **_kwargs)
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 789, in sample
    estimator.process(sub_batch)
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\rllib\offline\off_policy_estimator.py", line 121, in process
    self.new_estimates.append(self.estimate(batch))
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\rllib\offline\is_estimator.py", line 17, in estimate
    new_prob = self.action_log_likelihood(batch)
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\rllib\offline\off_policy_estimator.py", line 99, in action_log_likelihood
    log_likelihoods: TensorType = self.policy.compute_log_likelihoods(
  File "C:\Users\<USERNAME>2\Miniconda3\envs\<CONDA_ENV_NAME>\lib\site-packages\ray\rllib\policy\tf_policy.py", line 380, in compute_log_likelihoods
    raise ValueError(
ValueError: Must pass in RNN state batches for placeholders [<tf.Tensor 'default_policy_wk1/Placeholder:0' shape=(?, 200) dtype=float32>, <tf.Tensor 'default_policy_wk1/Placeholder_1:0' shape=(?, 200) dtype=float32>], got []
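If I read the error correctly, the two (?, 200) placeholders are the LSTM's hidden and cell state inputs (200 is my lstm_cell_size), and compute_log_likelihoods is being called with an empty state_batches list. A sketch of the state shapes the placeholders seem to expect (the variable names here are my own, not from RLlib):

```python
import numpy as np

batch_size = 4        # arbitrary example batch size
lstm_cell_size = 200  # matches the (?, 200) placeholder shape in the error

# The two placeholders presumably correspond to the LSTM hidden (h) and
# cell (c) states, one (batch_size, lstm_cell_size) array each:
state_batches = [
    np.zeros((batch_size, lstm_cell_size), dtype=np.float32)
    for _ in range(2)
]
```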

I generated my data using the example from the doc:

batch_builder.add_values(
    t=t,
    eps_id=eps_id,
    agent_index=0,
    obs=prep.transform(obs),
    actions=action,
    action_prob=1.0,  # put the true action probability here
    action_logp=0.0,
    rewards=rew,
    prev_actions=prev_action,
    prev_rewards=prev_reward,
    dones=done,
    infos=info,
    new_obs=prep.transform(new_obs),
)
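(As an aside, I kept action_prob and action_logp consistent, since the off-policy estimator in the traceback reads these fields: action_logp is just the log of action_prob, so the placeholder pair 1.0 / 0.0 is consistent. With real behavior-policy probabilities you would derive one from the other:)

```python
import math

action_prob = 1.0                    # placeholder from the docs example
action_logp = math.log(action_prob)  # log(1.0) == 0.0, matching the snippet

# With a real behavior-policy probability, keep the two fields in sync:
true_action_prob = 0.25              # hypothetical value
true_action_logp = math.log(true_action_prob)
```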

What should I do to solve this error? Should I add the last state info in batch_builder.add_values, or is that done automatically somewhere? I did not find anything related to offline learning + LSTM in the docs or in the examples.

Thanks!

Hi @Fabien-Couthouis,

MARWIL does not currently support RNNs.
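That said, the ValueError itself is raised by the importance-sampling off-policy estimator (is_estimator.py in your traceback), which calls compute_log_likelihoods without RNN state. If my memory of the Ray 1.9 config is right, you may be able to sidestep that particular call by disabling input evaluation (a sketch, not a verified fix; the data path is a placeholder):

```python
# Sketch of a MARWIL config that skips the IS/WIS off-policy estimators,
# which are what call compute_log_likelihoods in your traceback.
# "input_evaluation" is the Ray 1.9-era config key for these estimators;
# its default for offline input is ["is", "wis"].
config = {
    "input": "/path/to/your/offline/data",  # placeholder path
    "input_evaluation": [],  # disable the estimators
    "model": {
        "use_lstm": True,     # or your custom LSTM model instead
        "lstm_cell_size": 200,
    },
}
```

Note this would only silence the estimator error; the underlying limitation is still that MARWIL's loss does not handle recurrent state.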

Hi @mannyv, thanks for your quick answer!
It seems like the current documentation is misleading, then: the model-support table in the algorithms section lists RNN support for MARWIL.

Hi @Fabien-Couthouis,

Yeah, the documentation is wrong and has been for a long time.