Dear Ray-Team,
I'm trying to use the RNNSAC algorithm with a continuous-action-space environment.
From ray.rllib.examples I took the rnnsac_stateless_cartpole.py example and changed the environment from StatelessCartPole to MountainCarContinuous-v0 to test the algorithm for my custom environment.
As a result, I get an error message about incorrect tensor dimensions that I cannot get past.
Can you help me with this problem?
I'm relatively new to this area — maybe I overlooked or missed something related to using SAC with RNNs and continuous action spaces?
Thank you very much in advance!!
My code (same as ray.rllib.examples.rnnsac_stateless_cartpole.py, only changed the env.):
import json
from pathlib import Path
import ray
from ray import tune
from ray.rllib.agents.registry import get_trainer_class
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
# Experiment spec handed to ``tune.run(**config)``: the top-level keys are
# tune.run() arguments; the nested "config" dict is the RNNSAC trainer config.
config = {
    "name": "RNNSAC_example",
    "local_dir": str(Path(__file__).parent / "example_out"),
    "checkpoint_freq": 1,
    "keep_checkpoints_num": 1,
    "checkpoint_score_attr": "episode_reward_mean",
    # Stop as soon as either criterion is reached.
    "stop": {
        "episode_reward_mean": 65.0,
        "timesteps_total": 100000,
    },
    "metric": "episode_reward_mean",
    "mode": "max",
    "verbose": 2,
    # --- RNNSAC trainer configuration ---
    "config": {
        "framework": "torch",
        "num_workers": 4,
        "num_envs_per_worker": 1,
        "num_cpus_per_worker": 1,
        "log_level": "INFO",
        # "env": envs["RepeatAfterMeEnv"],
        # "env": envs["StatelessCartPole"],
        # Continuous-action env swapped in for the original StatelessCartPole.
        "env": "MountainCarContinuous-v0",
        "horizon": 1000,
        "gamma": 0.95,
        # Complete episodes are required so LSTM sequences are not cut.
        "batch_mode": "complete_episodes",
        "prioritized_replay": False,
        "buffer_size": 100000,
        "learning_starts": 1000,
        "train_batch_size": 480,
        "target_network_update_freq": 480,
        "tau": 0.3,
        # R2D2-style burn-in of the recurrent state before loss computation.
        "burn_in": 4,
        "zero_init_states": False,
        "optimization": {
            "actor_learning_rate": 0.005,
            "critic_learning_rate": 0.005,
            "entropy_learning_rate": 0.0001,
        },
        "model": {
            "max_seq_len": 20,
        },
        # Recurrent policy network.
        "policy_model": {
            "use_lstm": True,
            "lstm_cell_size": 64,
            "fcnet_hiddens": [64, 64],
            "lstm_use_prev_action": True,
            "lstm_use_prev_reward": True,
        },
        # Recurrent Q-network (same architecture as the policy).
        "Q_model": {
            "use_lstm": True,
            "lstm_cell_size": 64,
            "fcnet_hiddens": [64, 64],
            "lstm_use_prev_action": True,
            "lstm_use_prev_reward": True,
        },
    },
}
if __name__ == "__main__":
    # INIT: local Ray cluster — 4 rollout workers + 1 driver CPU.
    ray.init(num_cpus=5)

    # TRAIN: run the RNNSAC trainer under Tune with the spec above.
    results = tune.run("RNNSAC", **config)

    # TEST: reload the best checkpoint and roll out 10 evaluation episodes.
    best_checkpoint = results.best_checkpoint
    print("Loading checkpoint: {}".format(best_checkpoint))

    # The trial directory (two levels above the checkpoint file) holds the
    # exact config the run was trained with.
    checkpoint_config_path = str(
        Path(best_checkpoint).parent.parent / "params.json")
    with open(checkpoint_config_path, "rb") as f:
        checkpoint_config = json.load(f)
    checkpoint_config["explore"] = False  # deterministic actions at test time

    agent = get_trainer_class("RNNSAC")(
        env=config["config"]["env"], config=checkpoint_config)
    agent.restore(best_checkpoint)

    env = agent.env_creator({})
    lstm_state = agent.get_policy().get_initial_state()
    last_action = 0
    last_reward = 0
    obs = env.reset()

    episode_count = 0
    episode_return = 0
    while episode_count < 10:
        # full_fetch=True returns (action, rnn_state_out, extra_fetches).
        action, lstm_state, _ = agent.compute_action(
            obs,
            state=lstm_state,
            prev_action=last_action,
            prev_reward=last_reward,
            full_fetch=True)
        obs, reward, done, _ = env.step(action)
        last_action, last_reward = action, reward
        episode_return += reward
        try:
            env.render()
        except (NotImplementedError, ImportError):
            # Rendering is best-effort; not every env/backend supports it.
            pass
        if done:
            episode_count += 1
            print("Episode {}: {}".format(episode_count, episode_return))
            # Reset episode bookkeeping and recurrent state.
            episode_return = 0
            lstm_state = agent.get_policy().get_initial_state()
            last_action = 0
            last_reward = 0
            obs = env.reset()

    ray.shutdown()
Full error message:
2021-11-29 18:27:50,134 ERROR worker.py:425 -- Exception raised in creation task: The actor died because of an error raised in its creation task, ray::RNNSACTrainer.__init__() (pid=20300, ip=127.0.0.1)
(pid=20300) File "python\ray\_raylet.pyx", line 565, in ray._raylet.execute_task
(pid=20300) File "python\ray\_raylet.pyx", line 569, in ray._raylet.execute_task
(pid=20300) File "python\ray\_raylet.pyx", line 519, in ray._raylet.execute_task.function_executor
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\_private\function_manager.py", line 576, in actor_method_executor
(pid=20300) return method(__ray_actor, *args, **kwargs)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
(pid=20300) return method(self, *_args, **_kwargs)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\agents\trainer_template.py", line 137, in __init__
(pid=20300) Trainer.__init__(self, config, env, logger_creator)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\agents\trainer.py", line 623, in __init__
(pid=20300) super().__init__(config, logger_creator)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\tune\trainable.py", line 107, in __init__
(pid=20300) self.setup(copy.deepcopy(self.config))
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
(pid=20300) return method(self, *_args, **_kwargs)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\agents\trainer_template.py", line 147, in setup
(pid=20300) super().setup(config)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\agents\trainer.py", line 776, in setup
(pid=20300) self._init(self.config, self.env_creator)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
(pid=20300) return method(self, *_args, **_kwargs)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\agents\trainer_template.py", line 176, in _init
(pid=20300) num_workers=self.config["num_workers"])
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
(pid=20300) return method(self, *_args, **_kwargs)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\agents\trainer.py", line 864, in _make_workers
(pid=20300) logdir=self.logdir)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\evaluation\worker_set.py", line 89, in __init__
(pid=20300) lambda p, pid: (pid, p.observation_space, p.action_space)))
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\_private\client_mode_hook.py", line 105, in wrapper
(pid=20300) return func(*args, **kwargs)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\worker.py", line 1627, in get
(pid=20300) raise value
(pid=20300) ray.exceptions.RayActorError: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=24004, ip=127.0.0.1)
(pid=20300) File "python\ray\_raylet.pyx", line 565, in ray._raylet.execute_task
(pid=20300) File "python\ray\_raylet.pyx", line 569, in ray._raylet.execute_task
(pid=20300) File "python\ray\_raylet.pyx", line 519, in ray._raylet.execute_task.function_executor
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\_private\function_manager.py", line 576, in actor_method_executor
(pid=20300) return method(__ray_actor, *args, **kwargs)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
(pid=20300) return method(self, *_args, **_kwargs)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 588, in __init__
(pid=20300) seed=seed)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
(pid=20300) return method(self, *_args, **_kwargs)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1385, in _build_policy_map
(pid=20300) conf, merged_conf)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\policy\policy_map.py", line 144, in create_policy
(pid=20300) merged_config)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\policy\policy_template.py", line 282, in __init__
(pid=20300) stats_fn=None if self.config["in_evaluation"] else stats_fn,
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\policy\policy.py", line 732, in _initialize_loss_from_dummy_batch
(pid=20300) self._dummy_batch, explore=False)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\policy\torch_policy.py", line 303, in compute_actions_from_input_dict
(pid=20300) seq_lens, explore, timestep)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\utils\threading.py", line 21, in wrapper
(pid=20300) return func(self, *a, **k)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\policy\torch_policy.py", line 348, in _compute_action_helper
(pid=20300) is_training=False)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\agents\sac\rnnsac_torch_policy.py", line 175, in action_distribution_fn
(pid=20300) _, q_state_out = model.get_q_values(model_out, states_in["q"], seq_lens)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\agents\sac\rnnsac_torch_model.py", line 101, in get_q_values
(pid=20300) seq_lens)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\agents\sac\rnnsac_torch_model.py", line 91, in _get_q_value
(pid=20300) out, state_out = net(model_out, state_in, seq_lens)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\models\modelv2.py", line 243, in __call__
(pid=20300) res = self.forward(restored, state or [], seq_lens)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\models\torch\recurrent_net.py", line 187, in forward
(pid=20300) wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\models\torch\fcnet.py", line 124, in forward
(pid=20300) self._features = self._hidden_layers(self._last_flat_in)
(pid=20300) File ".\WorkingENV\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
(pid=20300) return forward_call(*input, **kwargs)
(pid=20300) File ".\WorkingENV\lib\site-packages\torch\nn\modules\container.py", line 141, in forward
(pid=20300) input = module(input)
(pid=20300) File ".\WorkingENV\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
(pid=20300) return forward_call(*input, **kwargs)
(pid=20300) File ".\WorkingENV\lib\site-packages\ray\rllib\models\torch\misc.py", line 160, in forward
(pid=20300) return self._model(x)
(pid=20300) File ".\WorkingENV\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
(pid=20300) return forward_call(*input, **kwargs)
(pid=20300) File ".\WorkingENV\lib\site-packages\torch\nn\modules\container.py", line 141, in forward
(pid=20300) input = module(input)
(pid=20300) File ".\WorkingENV\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
(pid=20300) return forward_call(*input, **kwargs)
(pid=20300) File ".\WorkingENV\lib\site-packages\torch\nn\modules\linear.py", line 103, in forward
(pid=20300) return F.linear(input, self.weight, self.bias)
(pid=20300) File ".\WorkingENV\lib\site-packages\torch\nn\functional.py", line 1848, in linear
(pid=20300) return torch._C._nn.linear(input, weight, bias)
(pid=20300) RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x2 and 3x64)