Hey folks,
I'm having trouble getting a “train_batch” into the shape [batch, seq, feature] for my custom MARL RNN model.
I thought I could just take the example RNN model from the Ray repo and adjust a few configs, but I couldn't find the right ones. For the “worker steps” the data looks fine, but I don't understand where the extra dimension comes from. Is there any way to get the “train_batch” into the shape [Batch_Size, Num_Agents, Agent_Obs]?
Right now the “train_batch” looks confusing to me, and I couldn't find anything related in the docs (I did find the “Implementing custom Recurrent Networks” section, and I know that forward_rnn() receives batches with the time dimension already added).
Also, why is the first dim = 5 in the “train_batch”?
If you have any idea, please let me know; I really have no clue what to do next.
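To make it concrete, this is roughly what I mean (just a sketch; num_agents and agent_obs_size are placeholders, and I'm assuming the flat observation is simply the per-agent observations concatenated):

import torch

num_agents = 4        # placeholder values, not the real env
agent_obs_size = 8

# what I currently get in the train_batch: one flattened obs per row
flat_obs = torch.zeros(32, num_agents * agent_obs_size)

# what I would like to work with: [Batch_Size, Num_Agents, Agent_Obs]
wanted = flat_obs.view(-1, num_agents, agent_obs_size)
print(wanted.shape)  # torch.Size([32, 4, 8])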
import torch
import torch.nn as nn

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.utils.annotations import override


class TorchRNNModel(TorchRNN, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, fc_size=64, lstm_state_size=256):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)

        self.obs_size = get_preprocessor(obs_space)(obs_space).size
        self.fc_size = fc_size
        self.lstm_state_size = lstm_state_size

        # FC layer -> LSTM -> separate action and value heads.
        self.fc1 = nn.Linear(self.obs_size, self.fc_size)
        self.lstm = nn.LSTM(
            self.fc_size, self.lstm_state_size, batch_first=True)
        self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
        self.value_branch = nn.Linear(self.lstm_state_size, 1)
        self._features = None

    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        # Debug output: inputs arrive here as [batch, seq, feature].
        print("_" * 60)
        print(inputs.size())
        print(inputs)
        print(seq_lens)
        print("_" * 60)

        x = nn.functional.relu(self.fc1(inputs))
        # Feed the incoming hidden state into the LSTM and return the
        # updated one (instead of echoing the old state back).
        self._features, [h, c] = self.lstm(
            x, [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)])
        action_out = self.action_branch(self._features)
        return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]

    @override(ModelV2)
    def value_function(self):
        return torch.reshape(self.value_branch(self._features), [-1])

    @override(ModelV2)
    def get_initial_state(self):
        # Hidden states live on the same device as the model weights.
        h = [
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
        ]
        return h
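Not shown above: the model is registered with RLlib's ModelCatalog under the name the config uses ("commnet"), roughly like this:

from ray.rllib.models import ModelCatalog

ModelCatalog.register_custom_model("commnet", TorchRNNModel)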
policies = {"shared_policy": (None, obs_space, act_space, {})}

tune.run(
    "PPO",
    stop={"episode_reward_mean": 200},
    checkpoint_freq=20,
    config={
        # Environment specific
        "env": MABasicFactory,
        "env_config": env.env_config,
        # General
        "num_gpus": 0,
        "num_workers": 1,
        "shuffle_sequences": False,
        "multiagent": {
            "policies": policies,
            "policy_mapping_fn": (lambda agent_id: "shared_policy"),
        },
        "framework": "torch",
        "model": {
            "custom_model": "commnet",
            "custom_model_config": {},
            "max_seq_len": 33,
        },
    },
)
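If my mental model is right, the sample batch gets chunked by seq_lens and padded up to max_seq_len before it reaches forward_rnn(), so the first dim of the train_batch would be the number of padded sequences rather than the env batch size (which might be where the 5 comes from?). This is the plain-PyTorch picture I have in my head, with made-up sizes:

import torch
from torch.nn.utils.rnn import pad_sequence

# made-up flat batch of 10 timesteps, feature size 3,
# coming from three sequences of lengths 4, 2 and 4
flat_batch = torch.arange(30.0).view(10, 3)
seq_lens = [4, 2, 4]

chunks = list(torch.split(flat_batch, seq_lens))   # one tensor per sequence
padded = pad_sequence(chunks, batch_first=True)    # [num_seqs, max(seq_lens), feature]
print(padded.shape)  # torch.Size([3, 4, 3])

Is that roughly what happens, and if so, is there a config knob to recover the [Batch_Size, Num_Agents, Agent_Obs] view from it?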