I am running the following example with PDB:
import argparse
import os
import numpy as np
import gym
from gym.spaces import Box, Discrete, MultiDiscrete
from typing import Dict, List, Union
from gym.envs.classic_control import CartPoleEnv
import ray
from ray import tune
from ray.tune.registry import register_env
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
torch, nn = try_import_torch()
class StatelessCartPole(CartPoleEnv):
    """Partially observable variant of the CartPole gym environment.

    We delete the x- and angular velocity components of the state, so that it
    can only be solved by a memory enhanced model (policy).
    """

    def __init__(self, config=None):
        """Create the env and replace its observation space by a 2D one.

        Args:
            config: Optional env config. Accepted so the class can be used as
                an env creator target; it is not read here.
        """
        # The thresholds read below are attributes of the parent env, so the
        # parent constructor has to run first.
        super().__init__()

        # Fix our observation-space (remove 2 velocity components).
        high = np.array(
            [
                self.x_threshold * 2,
                self.theta_threshold_radians * 2,
            ],
            dtype=np.float32)
        self.observation_space = Box(low=-high, high=high, dtype=np.float32)

    def step(self, action):
        """Step the parent env and drop the two velocity components.

        Returns:
            Tuple of (observation [x-pos, angle], reward, done, info).
        """
        next_obs, reward, done, info = super().step(action)
        # next_obs is [x-pos, x-veloc, angle, angle-veloc]
        return np.array([next_obs[0], next_obs[2]]), reward, done, info

    def reset(self):
        """Reset the parent env and return only [x-pos, angle]."""
        init_obs = super().reset()
        # init_obs is [x-pos, x-veloc, angle, angle-veloc]
        return np.array([init_obs[0], init_obs[2]])
class TorchRNNModel(TorchRNN, nn.Module):
    """Recurrent policy/value network: fc -> LSTM -> (action head, value head)."""

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 fc_size=64,
                 lstm_state_size=256):
        """Build the layers of the model.

        Args:
            obs_space: Observation space; its preprocessed size is the input
                width of the first dense layer.
            action_space: Action space, forwarded to the base class.
            num_outputs: Width of the action head.
            model_config: Model config dict, forwarded to the base class.
            name: Model name, forwarded to the base class.
            fc_size: Width of the dense layer in front of the LSTM.
            lstm_state_size: Hidden/cell state size of the LSTM.
        """
        # torch requires nn.Module to be initialized before any sub-module
        # is assigned as an attribute.
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)

        self.obs_size = get_preprocessor(obs_space)(obs_space).size
        self.fc_size = fc_size
        self.lstm_state_size = lstm_state_size

        # Build the Module from fc + LSTM + 2xfc (action + value outs).
        self.fc1 = nn.Linear(self.obs_size, self.fc_size)
        self.lstm = nn.LSTM(
            self.fc_size, self.lstm_state_size, batch_first=True)
        self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
        self.value_branch = nn.Linear(self.lstm_state_size, 1)
        # Holds the current "base" output (before logits layer).
        self._features = None

    @override(ModelV2)
    def get_initial_state(self):
        """Return two zero tensors of size `lstm_state_size` (one per state)."""
        # TODO: (sven): Get rid of `get_initial_state` once Trajectory
        #  View API is supported across all of RLlib.
        # Place hidden states on same device as model.
        h = [
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0)
        ]
        return h

    @override(ModelV2)
    def value_function(self):
        """Return the value head output for the last forward pass, flattened."""
        assert self._features is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._features), [-1])

    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        """Feeds `inputs` (B x T x ..) through the fc layer and the LSTM.

        The LSTM output is kept in `self._features` so that
        `value_function()` can be evaluated afterwards.

        Args:
            inputs: Input batch; the LSTM is built with `batch_first=True`,
                so the layout is (B x T x ...).
            state: List of two state tensors, passed to the LSTM as its
                initial (h, c) pair.
            seq_lens: Sequence lengths; not used by this implementation.

        Returns:
            NN Outputs (B x T x ...) as sequence.
            The state batches as a List of two items (h- and c-states).
        """
        # Debugging breakpoint used to inspect `inputs` on every call.
        import pdb; pdb.set_trace()
        x = nn.functional.relu(self.fc1(inputs))
        # Add a leading dim to each state tensor before handing it to the
        # LSTM; it is squeezed away again on the way out.
        self._features, [h, c] = self.lstm(
            x, [torch.unsqueeze(state[0], 0),
                torch.unsqueeze(state[1], 0)])
        action_out = self.action_branch(self._features)
        return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]
# Register the custom model and the partially observable env under the
# string keys referenced from the config below.
ModelCatalog.register_custom_model("rnn", TorchRNNModel)
register_env("StatelessPendulum", lambda _: StatelessCartPole())

config = {
    "env": "StatelessPendulum",
    "env_config": {
        "repeat_delay": 2,
    },
    "gamma": 0.9,
    # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
    "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
    "num_workers": 0,
    "num_envs_per_worker": 20,
    "entropy_coeff": 0.001,
    "num_sgd_iter": 5,
    "vf_loss_coeff": 1e-5,
    "simple_optimizer": True,
    "sgd_minibatch_size": 128,
    "model": {
        "custom_model": "rnn",
        "max_seq_len": 20,
        "custom_model_config": {
            "cell_size": 32,
        },
    },
    "framework": "torch",
}

# Stopping criteria handed to `tune.run` below.
stop = {
    "training_iteration": 10,
    "timesteps_total": 1e6,
    "episode_reward_mean": 500.,
}

results = tune.run('PPO', config=config, stop=stop, verbose=1)
The following is the output from the first few cycles.
As you can see, at the same breakpoint I get the following sizes:
torch.Size([32, 1, 2])
torch.Size([1, 1, 2])
torch.Size([4, 8, 2])
torch.Size([20, 1, 2])
torch.Size([20, 1, 2])
So what is the reason for these different sizes?
Also, since `_time_major` is False by default, I would expect the [batch, timesteps, features] format, which should give [32, 20, 2].
The version I am using is 2.0.0dev installed today.