RNNSAC with continuous action space

Dear Ray team,
I'm trying to use the RNNSAC algorithm with a continuous action space environment.

From ray.rllib.examples I took the rnnsac_stateless_cartpole.py example and changed the environment from StatelessCartPole to MountainCarContinuous-v0, to test the algorithm before applying it to my custom environment.

As a result, I get an error about mismatched dimensions that I cannot get past.

Can you help me with this problem?
I'm relatively new to this area; maybe I overlooked something about using SAC with an RNN and continuous action spaces?
Thank you very much in advance!!

My code (the same as ray.rllib.examples.rnnsac_stateless_cartpole.py, with only the env changed):

import json
from pathlib import Path

import ray
from ray import tune
from ray.rllib.agents.registry import get_trainer_class

from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole


config = {
    "name": "RNNSAC_example",
    "local_dir": str(Path(__file__).parent / "example_out"),
    "checkpoint_freq": 1,
    "keep_checkpoints_num": 1,
    "checkpoint_score_attr": "episode_reward_mean",
    "stop": {
        "episode_reward_mean": 65.0,
        "timesteps_total": 100000,
    },
    "metric": "episode_reward_mean",
    "mode": "max",
    "verbose": 2,
    "config": {
        "framework": "torch",
        "num_workers": 4,
        "num_envs_per_worker": 1,
        "num_cpus_per_worker": 1,
        "log_level": "INFO",

        # "env": envs["RepeatAfterMeEnv"],
        # "env": envs["StatelessCartPole"],
        "env" : "MountainCarContinuous-v0",
        "horizon": 1000,
        "gamma": 0.95,
        "batch_mode": "complete_episodes",
        "prioritized_replay": False,
        "buffer_size": 100000,
        "learning_starts": 1000,
        "train_batch_size": 480,
        "target_network_update_freq": 480,
        "tau": 0.3,
        "burn_in": 4,
        "zero_init_states": False,
        "optimization": {
            "actor_learning_rate": 0.005,
            "critic_learning_rate": 0.005,
            "entropy_learning_rate": 0.0001
        },
        "model": {
            "max_seq_len": 20,
        },
        "policy_model": {
            "use_lstm": True,
            "lstm_cell_size": 64,
            "fcnet_hiddens": [64, 64],
            "lstm_use_prev_action": True,
            "lstm_use_prev_reward": True,
        },
        "Q_model": {
            "use_lstm": True,
            "lstm_cell_size": 64,
            "fcnet_hiddens": [64, 64],
            "lstm_use_prev_action": True,
            "lstm_use_prev_reward": True,
        },
    },
}

if __name__ == "__main__":
    # INIT
    ray.init(num_cpus=5)

    # TRAIN
    results = tune.run("RNNSAC", **config)

    # TEST
    best_checkpoint = results.best_checkpoint
    print("Loading checkpoint: {}".format(best_checkpoint))
    checkpoint_config_path = str(
        Path(best_checkpoint).parent.parent / "params.json")
    with open(checkpoint_config_path, "rb") as f:
        checkpoint_config = json.load(f)

    checkpoint_config["explore"] = False

    agent = get_trainer_class("RNNSAC")(
        env=config["config"]["env"], config=checkpoint_config)
    agent.restore(best_checkpoint)

    env = agent.env_creator({})
    state = agent.get_policy().get_initial_state()
    prev_action = 0
    prev_reward = 0
    obs = env.reset()

    eps = 0
    ep_reward = 0
    while eps < 10:
        action, state, info_trainer = agent.compute_action(
            obs,
            state=state,
            prev_action=prev_action,
            prev_reward=prev_reward,
            full_fetch=True)
        obs, reward, done, info = env.step(action)
        prev_action = action
        prev_reward = reward
        ep_reward += reward
        try:
            env.render()
        except (NotImplementedError, ImportError):
            pass
        if done:
            eps += 1
            print("Episode {}: {}".format(eps, ep_reward))
            ep_reward = 0
            state = agent.get_policy().get_initial_state()
            prev_action = 0
            prev_reward = 0
            obs = env.reset()
    ray.shutdown()

Full error message:

2021-11-29 18:27:50,134	ERROR worker.py:425 -- Exception raised in creation task: The actor died because of an error raised in its creation task, ray::RNNSACTrainer.__init__() (pid=20300, ip=127.0.0.1)
(pid=20300)   File "python\ray\_raylet.pyx", line 565, in ray._raylet.execute_task
(pid=20300)   File "python\ray\_raylet.pyx", line 569, in ray._raylet.execute_task
(pid=20300)   File "python\ray\_raylet.pyx", line 519, in ray._raylet.execute_task.function_executor
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\_private\function_manager.py", line 576, in actor_method_executor
(pid=20300)     return method(__ray_actor, *args, **kwargs)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
(pid=20300)     return method(self, *_args, **_kwargs)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\agents\trainer_template.py", line 137, in __init__
(pid=20300)     Trainer.__init__(self, config, env, logger_creator)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\agents\trainer.py", line 623, in __init__
(pid=20300)     super().__init__(config, logger_creator)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\tune\trainable.py", line 107, in __init__
(pid=20300)     self.setup(copy.deepcopy(self.config))
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
(pid=20300)     return method(self, *_args, **_kwargs)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\agents\trainer_template.py", line 147, in setup
(pid=20300)     super().setup(config)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\agents\trainer.py", line 776, in setup
(pid=20300)     self._init(self.config, self.env_creator)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
(pid=20300)     return method(self, *_args, **_kwargs)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\agents\trainer_template.py", line 176, in _init
(pid=20300)     num_workers=self.config["num_workers"])
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
(pid=20300)     return method(self, *_args, **_kwargs)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\agents\trainer.py", line 864, in _make_workers
(pid=20300)     logdir=self.logdir)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\evaluation\worker_set.py", line 89, in __init__
(pid=20300)     lambda p, pid: (pid, p.observation_space, p.action_space)))
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\_private\client_mode_hook.py", line 105, in wrapper
(pid=20300)     return func(*args, **kwargs)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\worker.py", line 1627, in get
(pid=20300)     raise value
(pid=20300) ray.exceptions.RayActorError: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=24004, ip=127.0.0.1)
(pid=20300)   File "python\ray\_raylet.pyx", line 565, in ray._raylet.execute_task
(pid=20300)   File "python\ray\_raylet.pyx", line 569, in ray._raylet.execute_task
(pid=20300)   File "python\ray\_raylet.pyx", line 519, in ray._raylet.execute_task.function_executor
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\_private\function_manager.py", line 576, in actor_method_executor
(pid=20300)     return method(__ray_actor, *args, **kwargs)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
(pid=20300)     return method(self, *_args, **_kwargs)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 588, in __init__
(pid=20300)     seed=seed)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
(pid=20300)     return method(self, *_args, **_kwargs)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1385, in _build_policy_map
(pid=20300)     conf, merged_conf)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\policy\policy_map.py", line 144, in create_policy
(pid=20300)     merged_config)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\policy\policy_template.py", line 282, in __init__
(pid=20300)     stats_fn=None if self.config["in_evaluation"] else stats_fn,
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\policy\policy.py", line 732, in _initialize_loss_from_dummy_batch
(pid=20300)     self._dummy_batch, explore=False)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\policy\torch_policy.py", line 303, in compute_actions_from_input_dict
(pid=20300)     seq_lens, explore, timestep)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\utils\threading.py", line 21, in wrapper
(pid=20300)     return func(self, *a, **k)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\policy\torch_policy.py", line 348, in _compute_action_helper
(pid=20300)     is_training=False)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\agents\sac\rnnsac_torch_policy.py", line 175, in action_distribution_fn
(pid=20300)     _, q_state_out = model.get_q_values(model_out, states_in["q"], seq_lens)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\agents\sac\rnnsac_torch_model.py", line 101, in get_q_values
(pid=20300)     seq_lens)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\agents\sac\rnnsac_torch_model.py", line 91, in _get_q_value
(pid=20300)     out, state_out = net(model_out, state_in, seq_lens)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\models\modelv2.py", line 243, in __call__
(pid=20300)     res = self.forward(restored, state or [], seq_lens)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\models\torch\recurrent_net.py", line 187, in forward
(pid=20300)     wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\models\torch\fcnet.py", line 124, in forward
(pid=20300)     self._features = self._hidden_layers(self._last_flat_in)
(pid=20300)   File ".\WorkingENV\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
(pid=20300)     return forward_call(*input, **kwargs)
(pid=20300)   File ".\WorkingENV\lib\site-packages\torch\nn\modules\container.py", line 141, in forward
(pid=20300)     input = module(input)
(pid=20300)   File ".\WorkingENV\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
(pid=20300)     return forward_call(*input, **kwargs)
(pid=20300)   File ".\WorkingENV\lib\site-packages\ray\rllib\models\torch\misc.py", line 160, in forward
(pid=20300)     return self._model(x)
(pid=20300)   File ".\WorkingENV\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
(pid=20300)     return forward_call(*input, **kwargs)
(pid=20300)   File ".\WorkingENV\lib\site-packages\torch\nn\modules\container.py", line 141, in forward
(pid=20300)     input = module(input)
(pid=20300)   File ".\WorkingENV\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
(pid=20300)     return forward_call(*input, **kwargs)
(pid=20300)   File ".\WorkingENV\lib\site-packages\torch\nn\modules\linear.py", line 103, in forward
(pid=20300)     return F.linear(input, self.weight, self.bias)
(pid=20300)   File ".\WorkingENV\lib\site-packages\torch\nn\functional.py", line 1848, in linear
(pid=20300)     return torch._C._nn.linear(input, weight, bias)
(pid=20300) RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x2 and 3x64)
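
For reference, the shapes in the error seem to line up with the dimensions of MountainCarContinuous-v0. Below is a quick sanity check I ran; the interpretation in the comments (that the Q-network's first layer expects the observation and the action concatenated) is only my guess, not something I have confirmed in the RLlib code:

import gym

# Standard Gym spaces of the environment used above.
env = gym.make("MountainCarContinuous-v0")
print(env.observation_space.shape)  # (2,) -> [position, velocity]
print(env.action_space.shape)       # (1,) -> one continuous action (force)

# In the error, mat1 (32x2) matches a batch of 32 two-dimensional
# observations, while mat2 (3x64) is the first hidden layer's weight,
# i.e. it expects 3 input features -- possibly obs (2) + action (1)
# concatenated, which would explain the mismatch for a continuous
# action space.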

I have the same problem.
@sven1977, do you have any idea how to solve it?

Hi, I also have the same problem, and I created an issue on GitHub.

Thank you in advance, @sven1977!