Hey,
I’m training a custom model using RLlib 3.5.0. The structure is like this:
class GATPolicyModelLSTM(TorchModelV2, nn.Module):
    """Custom GAT-based policy model for RLlib's TorchModelV2 API.

    NOTE(review): the original paste had lost all indentation; the method
    bodies below are restored to conventional Python layout. The `...`
    placeholders stand in for the original (elided) implementation.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        # RLlib convention: initialize TorchModelV2 first, then nn.Module,
        # so nn.Module's machinery is in place before any layers are built.
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)
        ...

    def forward(self, input_dict, state, seq_lens):
        # Must return (logits, state) per the TorchModelV2 contract;
        # `state` carries the recurrent (LSTM) hidden state between steps.
        ...
        return logits, state
Using this code:
# PPO configuration (old-style dict config for RLlib).
# NOTE(review): `entropy_coeff_schedule` keys are environment *timesteps*,
# not training iterations — the original comments said "iteration" and
# quoted 0.1 where the code has 0.2; both are corrected below.
config = {
    "env": "custom_env",
    "env_config": env_config,
    "framework": "torch",
    "num_gpus": 1,
    "num_workers": 5,
    "sample_async": True,
    "multiagent": {
        "policies": {
            # (policy_cls, obs_space, act_space, config) — None means
            # "use the algorithm's default policy class".
            "shared_policy": (None, obs_space, act_space, {}),
        },
        "policy_mapping_fn": policy_mapping_fn,
    },
    "model": {
        "custom_model": "stm_policy_model",
        "custom_model_config": {
            "use_lstm": True,
            "max_seq_len": 20,
            "lstm_cell_size": 256,
        },
    },
    "lr": 5e-5,
    "rollout_fragment_length": 'auto',
    "train_batch_size": 4000,
    "sgd_minibatch_size": 128,
    "num_sgd_iter": 10,
    "grad_clip": 0.5,
    "clip_param": 0.2,
    "entropy_coeff": 0.01,
    "entropy_coeff_schedule": [
        [0, 0.2],         # at timestep 0, entropy_coeff = 0.2
        [2500000, 0.01],  # anneals to 0.01 by timestep 2,500,000
    ],
    "lambda": 0.95,
    "vf_clip_param": 10.0,
    "batch_mode": "truncate_episodes",
    "log_level": "DEBUG",
    "log_sys_usage": False,
    "local_mode": True,
}
# NOTE(review): modern RLlib prefers the builder API
# (PPOConfig().environment(...).build()) over passing a raw dict — confirm
# this constructor signature against the installed Ray version.
algo = ppo.PPO(env="custom_env", config=config)
But I’m receiving errors when using the built-in LSTM wrapper.
Is there any template code that shows how to use the LSTM wrapper with a custom model?