MARL Custom RNN Model Batch Shape (batch, seq, feature)

Hey folks,

I'm having trouble getting a “train_batch” in the shape [batch, seq, feature] for my custom MARL RNN model.
I thought I could just use the example RNN model from the Ray repo and adjust some configs, but I couldn't find the right ones. For the “worker steps” the data seems fine, but I don't understand why there is an extra dimension. For the “train_batch”, is there any way to get it in the shape [Batch_Size, Num_Agents, Agent_Obs]?
Right now the “train_batch” looks a bit confusing to me, and I didn't find anything related in the docs (of course I found the “Implementing custom Recurrent Networks” chapter, and I know that forward_rnn() receives batches with the time dimension already added).
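As a rough mental model of that time dimension: in RLlib's recurrent setup, as far as I understand it, each agent's trajectory is cut into sequences (listed in seq_lens), every sequence is zero-padded to max_seq_len, and the flat batch is reshaped to [num_seqs, max_seq_len, F]. So the time axis holds consecutive env steps of one agent, not the agents of one step. The sketch below uses plain PyTorch and a hypothetical helper `pad_and_add_time_dim`; it is not RLlib's actual implementation.

```python
import torch


def pad_and_add_time_dim(flat_obs, seq_lens, max_seq_len):
    """Hypothetical sketch: split a flat [sum(seq_lens), F] batch into
    sequences, zero-pad each to max_seq_len and stack them into
    [num_seqs, max_seq_len, F]."""
    chunks = torch.split(flat_obs, seq_lens.tolist())
    out = flat_obs.new_zeros(len(chunks), max_seq_len, flat_obs.shape[-1])
    for i, chunk in enumerate(chunks):
        out[i, : chunk.shape[0]] = chunk  # remaining rows stay zero (padding)
    return out


# 5 timesteps of a 2-feature observation, forming two sequences of length 3 and 2.
flat_obs = torch.arange(10, dtype=torch.float32).reshape(5, 2)
seq_lens = torch.tensor([3, 2])
batch = pad_and_add_time_dim(flat_obs, seq_lens, max_seq_len=4)
print(batch.shape)  # torch.Size([2, 4, 2])
```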

Also, why is the first dim = 5 in the “train_batch”?

If you have any idea, don't hesitate to tell me; I really have no clue what to do next :smiley:

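```python
import torch
import torch.nn as nn
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN


class TorchRNNModel(TorchRNN, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name,
                 fc_size=64, lstm_state_size=256):
        # nn.Module must be initialized before any sub-layers are assigned.
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        self.obs_size = get_preprocessor(obs_space)(obs_space).size
        self.fc_size = fc_size
        self.lstm_state_size = lstm_state_size

        self.fc1 = nn.Linear(self.obs_size, self.fc_size)
        self.lstm = nn.LSTM(self.fc_size, self.lstm_state_size, batch_first=True)
        self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
        self.value_branch = nn.Linear(self.lstm_state_size, 1)

        # Last LSTM outputs, reused by value_function().
        self._features = None

    def forward_rnn(self, inputs, state, seq_lens):
        # inputs: [B, T, obs_size]; state: [h, c], each [B, lstm_state_size].
        x = nn.functional.relu(self.fc1(inputs))
        # nn.LSTM expects states of shape [num_layers, B, H], hence unsqueeze(0).
        self._features, (h, c) = self.lstm(
            x, (torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)))
        action_out = self.action_branch(self._features)
        # Return the updated state (not the incoming one) so it carries over.
        return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]

    def value_function(self):
        return torch.reshape(self.value_branch(self._features), [-1])

    def get_initial_state(self):
        # Zero [h, c] states of size lstm_state_size, on the model's device.
        h = [
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
        ]
        return h
```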
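To see the shapes forward_rnn() deals with without starting Ray, the same layers can be run on a dummy batch in plain PyTorch. The sizes are made up (B=2 sequences, T=4 timesteps, 2 observation features, 4 actions); this is only a shape check, not part of RLlib.

```python
import torch
import torch.nn as nn

# Made-up sizes: B sequences, T timesteps each, 2 obs features, 4 actions.
B, T, obs_size = 2, 4, 2
fc_size, hidden, num_outputs = 64, 256, 4

fc1 = nn.Linear(obs_size, fc_size)
lstm = nn.LSTM(fc_size, hidden, batch_first=True)
action_branch = nn.Linear(hidden, num_outputs)

inputs = torch.randn(B, T, obs_size)                      # what forward_rnn() receives
state = [torch.zeros(B, hidden), torch.zeros(B, hidden)]  # [h, c], each [B, H]

x = torch.relu(fc1(inputs))
features, (h, c) = lstm(x, (state[0].unsqueeze(0), state[1].unsqueeze(0)))
logits = action_branch(features)                          # [B, T, num_outputs]
new_state = [h.squeeze(0), c.squeeze(0)]                  # back to [B, H] each
print(logits.shape, new_state[0].shape)
```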

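```python
from ray import tune
from ray.rllib.models import ModelCatalog

# Assumption: "commnet" is the name the model class above is registered under.
ModelCatalog.register_custom_model("commnet", TorchRNNModel)

# obs_space, act_space, MABasicFactory and env come from my own env code.
policies = {"shared_policy": (None, obs_space, act_space, {})}

tune.run(
    "PPO",  # assumed: launched via tune.run with PPO
    stop={"episode_reward_mean": 200},
    config={
        # Environment specific
        "env": MABasicFactory,
        "env_config": env.env_config,

        # General
        "num_gpus": 0,
        "num_workers": 1,

        "shuffle_sequences": False,

        "multiagent": {
            "policies": policies,
            "policy_mapping_fn": (lambda agent_id: "shared_policy"),
        },
        "framework": "torch",
        "model": {
            "custom_model": "commnet",
            "custom_model_config": {},
            "max_seq_len": 33,
        },
    },
)
```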
cc @amogkam @rliaw @kai


cc @sven1977 for RLlib


Is my problem understandable?

What are your settings for num_workers, envs_per_worker, max_seq_len, horizon, replay_sequence_length and batch_mode? Does your environment always run for a fixed set of steps or does it vary by episode?


Hi @mannyv, thanks for your answer! I set max_seq_len and num_workers myself. The other settings are just the default values, as you can see in the code example.

Maybe there is a problem with using PPO?

" Does your environment always run for a fixed set of steps or does it vary by episode?"
It varies by episode.

"max_seq_len": 33,  # just for testing purposes
"num_workers": 1,

"num_envs_per_worker": 1,
"horizon": None,
"batch_mode": "truncate_episodes"
replay_sequence_length: isn't used by PPO
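With "max_seq_len": 33 and episodes of varying length, each agent's trajectory ends up as sequences of at most 33 steps, and shorter ones are zero-padded to 33. Below is a simplified sketch of that cutting; `chop_seq_lens` is a hypothetical helper, and it ignores that RLlib also cuts at episode and rollout-fragment boundaries (which matters with "truncate_episodes").

```python
def chop_seq_lens(trajectory_len, max_seq_len):
    """Hypothetical helper: split one agent's trajectory of `trajectory_len`
    steps into consecutive chunks of at most `max_seq_len` steps."""
    full, rest = divmod(trajectory_len, max_seq_len)
    return [max_seq_len] * full + ([rest] if rest else [])


print(chop_seq_lens(70, 33))  # [33, 33, 4]
```

So the leading dimension of the padded train batch counts sequences, not agents. As far as I know, PPO uses seq_lens to mask the padded steps out of the loss.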

Also, is the behavior I'm seeing a bug, or do I just have very specific requirements?

Hey @CodingBurmer, am I right in assuming that you would like the model to get all agents' observations as input? (I'm trying to understand what you mean by seq and num_agents above.)
If so, you should probably take a look at our “centralized_critic” examples, where we postprocess the batch so that it also contains the other agents' observations.



Hi @sven1977, here is a quick example of what I need.

In the env shown above, each agent gets only its own coordinates (x, y) as its observation, so the agents have to “talk” to each other to solve the problem. Here is some pseudocode of an env step:

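```python
# actions: Up=0, Down=1, Left=2, Right=3
agent_actions = {0: 3, 1: 0, 2: 1, 3: 1}

# Simplified: a real MultiAgentEnv.step() also returns rewards, dones and infos.
obs_dict = env.step(agent_actions)

# obs_dict   -> {0: [2, 1], 1: [5, 2], 2: [5, 0], 3: [5, 0]}
# obs_tensor -> [[2, 1], [5, 2], [5, 0], [5, 0]]
```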

This obs_tensor can be seen as a sequence of features (observations) that I can feed to an RNN. First we feed [2,1] (the first agent's observation) to the RNN layer, then [5,2] (the second agent's observation), and so on. The point is that the RNN layer updates its hidden state after the first observation and “remembers” the position of the first agent, so when the next agent's observation is fed in, the network knows where the first agent is and can make a better decision. Many papers report that this works very well. Right now I'm just trying to port my PyTorch model to RLlib to improve its performance.
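The idea of running the RNN over the agent axis of a single env step can be sketched in plain PyTorch as follows. `AgentSequenceNet` and its sizes are hypothetical; the LSTM hidden state starts at zero for every env step, so agent i's output depends only on the observations of agents 0..i at that step.

```python
import torch
import torch.nn as nn


class AgentSequenceNet(nn.Module):
    """Hypothetical sketch: an LSTM that treats the agents of one env step
    as a sequence, so later agents see a summary of earlier agents."""

    def __init__(self, obs_size=2, hidden_size=64, num_actions=4):
        super().__init__()
        self.encoder = nn.Linear(obs_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.policy_head = nn.Linear(hidden_size, num_actions)

    def forward(self, obs):
        # obs: [batch_size, num_agents, obs_size]; the agent axis plays the role of "time".
        x = torch.relu(self.encoder(obs))
        out, _ = self.lstm(x)  # no initial state passed -> zeros for each step
        return self.policy_head(out)  # [batch_size, num_agents, num_actions]


obs = torch.tensor([[[2, 1], [5, 2], [5, 0], [5, 0]]], dtype=torch.float32)  # one env step
logits = AgentSequenceNet()(obs)
print(logits.shape)  # torch.Size([1, 4, 4])
```

Because the LSTM accepts sequences of any length, the same weights work for a different number of agents, which is the scalability property mentioned below.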

In short, after a few steps I need my batch to look like this:

Batch_shape = (batch_size, seq, features)
 batch_size: number of sequences
 seq: number of observations per sequence, i.e. the number of agents (here 4)
 features: shape of one observation (here 2: x, y)

```
Batch = [
    [[2, 1], [5, 2], [5, 0], [5, 0]],  # step v
    [[x, y], [x, y], [x, y], [x, y]],  # step z
    [[x, y], [x, y], [x, y], [x, y]],  # step y
    [[x, y], [x, y], [x, y], [x, y]],  # step x
]
```

Here v, z, y, x are arbitrary steps and don't have to be in ascending order.
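Building that (batch_size, seq, features) tensor from per-step observation dicts could look like this. The first step reuses the example above; the other two steps are made-up values, and `step_to_seq` is a hypothetical helper that orders agents by id so position i in the sequence is always agent i.

```python
import torch

# Observations of 4 agents at three arbitrary env steps (steps 2 and 3 are made up).
steps = [
    {0: [2, 1], 1: [5, 2], 2: [5, 0], 3: [5, 0]},
    {0: [3, 1], 1: [5, 3], 2: [4, 0], 3: [5, 1]},
    {0: [3, 2], 1: [5, 4], 2: [4, 1], 3: [6, 1]},
]


def step_to_seq(obs_dict):
    # One env step -> [num_agents, features], agents ordered by id.
    return [obs_dict[agent_id] for agent_id in sorted(obs_dict)]


batch = torch.tensor([step_to_seq(s) for s in steps], dtype=torch.float32)
print(batch.shape)  # torch.Size([3, 4, 2]) = (batch_size, seq, features)
```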

And yes, you are right: I could easily solve this problem by adding a centralized critic, but I need the agents to be scalable.

Looks really interesting, were you able to solve the problem?


@CodingBurmer, one comment that might be helpful: the first few passes through the model you are seeing are probably not real data. They are dummy data pushed through the model by the trajectory_view code to determine the view requirements. In the examples I have run, I usually see 4 forward passes with dummy data as part of the setup before training actually starts.

As sven1977 said, you could look at the centralized_critic_2 example to see how you could share (rewrite) agent observations so that they include the observations of other agents during training. You don't need to actually use a centralized critic; just take the observation-sharing part of the example.
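The observation-sharing part could look roughly like the function below. If I remember centralized_critic_2 correctly, it does the rewrite in an observer function passed through the multiagent "observation_fn" setting; `share_observations` and the "own_obs"/"all_obs" keys are hypothetical, and the policy's observation space would have to be declared as a matching Dict space.

```python
def share_observations(agent_obs, **kwargs):
    """Hypothetical observer: give every agent its own obs plus the obs of
    all agents (ordered by agent id), so a per-agent policy sees the team.
    Extra keyword arguments are accepted and ignored."""
    all_obs = [agent_obs[agent_id] for agent_id in sorted(agent_obs)]
    return {
        agent_id: {"own_obs": obs, "all_obs": all_obs}
        for agent_id, obs in agent_obs.items()
    }


shared = share_observations({0: [2, 1], 1: [5, 2], 2: [5, 0], 3: [5, 0]})
print(shared[1])  # {'own_obs': [5, 2], 'all_obs': [[2, 1], [5, 2], [5, 0], [5, 0]]}
```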

Do you have a sample repo or a minimal example you could share?
