1. Severity of the issue: (select one)
High: Completely blocks me.
Under the new API stack, the forward function needs to return the advantages; otherwise a KeyError is raised. If I return a placeholder for the advantages instead, the policy loss is always 0 (see the placeholder sketch after the code). How can this problem be solved? The code is as follows:
```python
combined = torch.cat([z, coord_feat], dim=1).unsqueeze(1).unsqueeze(1)
action_embed = self.fusion(combined).squeeze()
action_logits = self._pi_head(action_embed)
# Alternative masking I tried:
# action_logits += torch.log(obs["action_mask"].float().to(device) + 1e-10)
# Mask out invalid actions with a large negative logit.
action_logits = torch.where(
    obs["action_mask"].bool(),
    action_logits,
    torch.tensor(-1e10, device=device),
)
# Value-function prediction from the shared embedding.
values = self._vf_head(action_embed).squeeze(-1)
action_dist = self.action_dist_cls(logits=action_logits)
# Greedy alternative: action = torch.argmax(action_dist.logits, dim=-1)
action = action_dist.sample()
output = {
    Columns.ACTIONS: action,
    Columns.ACTION_DIST_INPUTS: action_logits,
    Columns.VF_PREDS: values,
    Columns.ACTION_LOGP: action_dist.logp(action),
    # Columns.ADVANTAGES:
    # Columns.VALUE_TARGETS:
    Columns.EMBEDDINGS: action_embed,
}
return output
```
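
For reference, this is the placeholder variant I mentioned (a minimal sketch; the zero-filled tensors are just my own stand-ins, not anything from the RLlib API). It makes the KeyError go away, but since the advantages are identically zero, the PPO surrogate loss is 0 everywhere:

```python
# Hypothetical placeholder version of the output dict above.
# Zero-filled ADVANTAGES / VALUE_TARGETS silence the KeyError,
# but a zero advantage makes the policy loss identically 0.
output = {
    Columns.ACTIONS: action,
    Columns.ACTION_DIST_INPUTS: action_logits,
    Columns.VF_PREDS: values,
    Columns.ACTION_LOGP: action_dist.logp(action),
    Columns.ADVANTAGES: torch.zeros_like(values),      # placeholder
    Columns.VALUE_TARGETS: torch.zeros_like(values),   # placeholder
    Columns.EMBEDDINGS: action_embed,
}
```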