Hi!
I have implemented a custom model and a custom action distribution that I am training with PPO.
Everything works flawlessly, but when `sample_batch = compute_bootstrap_value(sample_batch, policy)` is called, I get a `KeyError: 'vf_preds'` (raised from `value = dict.__getitem__(self, key)`), so I am unable to run a training step. Why are the value predictions not in the batch?
My custom model defines a `value_function()` that is called and works…
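For context, this is roughly how I sanity-checked it after a single forward pass (`model`, `input_dict`, and `batch_size` below are stand-ins for my actual model instance and a real observation batch, so this is not runnable as-is):

```python
# Rough sanity check; model / input_dict / batch_size are placeholders.
logits, state = model.forward(input_dict, [], None)
vf = model.value_function()
assert vf.shape == (batch_size,)  # one scalar value prediction per sample
```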
This is my algo config:
```python
algo = (
    PPOConfig()
    .rollouts(
        num_rollout_workers=4,
        num_envs_per_worker=1,
        create_env_on_local_worker=False,
        rollout_fragment_length=512,
        batch_mode="truncate_episodes",
    )
    .resources(
        num_gpus=0,
    )
    .environment(
        env="env_name",
    )
    .experimental(
        _disable_initialize_loss_from_dummy_batch=True,
    )
    .training(
        model={
            "custom_model": "CustomModel",
            "custom_action_dist": "CustomActionDist",
        },
        # PPO args
        gamma=0.99,
        lr=0.0001,
        train_batch_size=2048,
        use_critic=True,
        use_gae=True,
        lambda_=0.95,
        kl_coeff=0.0,
        sgd_minibatch_size=64,
        num_sgd_iter=3,
        clip_param=0.1,
        entropy_coeff=0.01,
    )
    .multi_agent(
        policies={
            "player0": (None, env.observation_space["player0"], env.action_space["player0"], {}),
            "player1": (None, env.observation_space["player1"], env.action_space["player1"], {}),
        },
        policy_mapping_fn=(lambda agent_id, *args, **kwargs: "player0" if agent_id == "player0" else "player1"),
        policies_to_train=["player0", "player1"],
    )
    .build()
)
```
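As a debugging step, one thing I can check after `.build()` is whether `vf_preds` is part of the policy's view requirements at all (a small sketch; `"player0"` is one of the policy IDs from the config above):

```python
from ray.rllib.policy.sample_batch import SampleBatch

# If the policy never collects value predictions during rollouts,
# compute_bootstrap_value() fails with exactly this KeyError.
policy = algo.get_policy("player0")
print(SampleBatch.VF_PREDS in policy.view_requirements)  # SampleBatch.VF_PREDS == "vf_preds"
```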
and this is my custom model:
```python
import torch
import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override


class CustomModel(TorchModelV2, nn.Module):
    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        **kwargs,
    ):
        nn.Module.__init__(self)
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name, **kwargs
        )
        # Define network
        self.encoder = ...  # my CNN
        # Define actors
        self.actor1 = ...  # actor for dim 1 of the dict action space
        self.actor2 = ...  # actor for dim 2 of the dict action space
        # Define critic
        self.critic = ...  # single critic

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        self.features_ = self.encoder(input_dict)
        return self.features_, state

    @override(TorchModelV2)
    def value_function(self):
        value = torch.reshape(self.critic(self.features_), [-1])
        return value
```
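Both are registered under the string names referenced in the config; roughly:

```python
from ray.rllib.models import ModelCatalog

# Register under the names used in .training(model={...}) above.
ModelCatalog.register_custom_model("CustomModel", CustomModel)
ModelCatalog.register_custom_action_dist("CustomActionDist", CustomActionDist)
```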
Note that the model only returns the context (features), and my custom action distribution uses these features to sample actions from each actor, following the implementation in `ray/rllib/examples/autoregressive_action_dist.py` in the ray-project/ray repo on GitHub.
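For reference, here is a trimmed sketch of what the distribution does. The `Categorical` heads and the way `actor2` is conditioned on the first action are simplified placeholders for my real heads, and the real class also implements `deterministic_sample`, `logp`, `entropy`, `kl`, and `required_model_output_shape` as in the linked example:

```python
import torch
from torch.distributions import Categorical

from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper


class CustomActionDist(TorchDistributionWrapper):
    """Trimmed sketch: sample dim 1, then dim 2 conditioned on dim 1."""

    def sample(self):
        # self.inputs holds the encoder features returned by CustomModel.forward().
        dist1 = Categorical(logits=self.model.actor1(self.inputs))
        a1 = dist1.sample()
        # The real actor2 conditions on a1; this call signature is simplified.
        dist2 = Categorical(logits=self.model.actor2(self.inputs, a1))
        a2 = dist2.sample()
        self._action_logp = dist1.log_prob(a1) + dist2.log_prob(a2)
        return (a1, a2)  # my real space is a Dict; simplified to a tuple here

    def sampled_action_logp(self):
        return self._action_logp
```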
Thank you!