I have a mixin that needs to override the compute_actions_from_input_dict function of an RLlib policy.
In Ray 1.1.0 I did this by overriding compute_actions instead, but I want to update to 1.2.0.
My action-masking method relies on choosing from a tree of possible actions during the exploration phase and applying the corresponding action masks when calculating logits, logps, etc.
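A minimal NumPy sketch of this kind of tree-constrained masking, under stated assumptions: the tree layout (each dict level maps a valid value of one action component to the subtree for the next component, ending in a list) is inferred from the placeholder {0: {0: {0: [0]}}} in the code further down, each component gets its own independent logits vector, and all helper names are hypothetical. Invalid choices get a logit of -inf, so their probability is exactly zero and the summed log-probability covers only the sampled valid path.

```python
import numpy as np


def mask_from_level(valid_values, num_choices):
    """Boolean mask of length num_choices, True where a value is valid."""
    mask = np.zeros(num_choices, dtype=bool)
    mask[list(valid_values)] = True
    return mask


def masked_log_softmax(logits, mask):
    """Log-probabilities with invalid entries forced to -inf."""
    masked = np.where(mask, logits, -np.inf)
    shifted = masked - masked[mask].max()  # subtract max for numerical stability
    return shifted - np.log(np.exp(shifted).sum())


def sample_from_tree(tree, logits_per_component, rng):
    """Walk the (hypothetical) action tree one component at a time.

    Returns the sampled action components and their total log-probability.
    """
    actions, total_logp = [], 0.0
    node = tree
    for logits in logits_per_component:
        # Dict levels list valid values as keys; the innermost level is a list.
        valid = list(node.keys()) if isinstance(node, dict) else list(node)
        logp = masked_log_softmax(np.asarray(logits, dtype=float),
                                  mask_from_level(valid, len(logits)))
        choice = int(rng.choice(len(logits), p=np.exp(logp)))
        actions.append(choice)
        total_logp += float(logp[choice])
        # Descend into the subtree for the chosen value (None after the last level).
        node = node[choice] if isinstance(node, dict) else None
    return actions, total_logp


rng = np.random.default_rng(0)
tree = {0: {1: [2]}, 1: {0: [0, 1]}}
actions, logp = sample_from_tree(tree, [np.zeros(3)] * 3, rng)
print(actions, logp)
```

In a real policy the per-component logits would presumably be derived from the model's dist_inputs, and later components would usually be conditioned on earlier choices; the sketch keeps them independent only to stay short.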
I’m passing the info dictionary to compute_actions_from_input_dict, but it gets stripped out of the input dict because it isn’t accessed by either the model or the postprocessing.
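A toy model of what appears to be happening, assuming the trajectory view API keeps only the keys that some consumer (model, loss, or postprocessing) reads from a dummy batch. This is a hypothetical illustration, not RLlib’s code; TrackingDict and prune_unaccessed are made-up names.

```python
class TrackingDict(dict):
    """dict that records which keys were read with [] access."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.accessed_keys = set()

    def __getitem__(self, key):
        self.accessed_keys.add(key)
        return super().__getitem__(key)


def prune_unaccessed(batch, consumers, always_keep=("obs",)):
    """Run every consumer on a tracked copy, then keep only keys something read."""
    tracked = TrackingDict(batch)
    for consume in consumers:
        consume(tracked)
    keep = tracked.accessed_keys | set(always_keep)
    return {k: v for k, v in batch.items() if k in keep}


def model(b):
    # Stand-in for the model's forward pass: reads only "obs".
    return b["obs"]


def postprocess(b):
    # Stand-in for postprocessing: reads only "rewards".
    return b["rewards"]


batch = {"obs": [[0.0]], "rewards": [1.0],
         "infos": [{"valid_action_tree": {0: [0]}}]}
print(sorted(prune_unaccessed(batch, [model, postprocess])))  # -> ['obs', 'rewards']
```

In this toy, infos survives only if some consumer actually reads it. Whether RLlib 1.2.0 offers a way to declare such a requirement explicitly, rather than relying on access tracking, is not confirmed here.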
My current workaround is to disable the trajectory view API (_use_trajectory_view_api=False in the config) and override compute_actions() instead. In 1.2.0, however, this throws a different error about model.num_framestack that I can’t work out, so I ended up reverting to 1.1.0.
Is there something I am doing wrong?
I’ve included my custom compute function below, up to the point where 'infos' is accessed:
import numpy as np
import torch

from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override


# Method of the mixin class (class definition omitted).
@override(Policy)
def compute_actions_from_input_dict(self, input_dict, explore=None, timestep=None, episodes=None, **kwargs):
    if not self.config['env_config'].get('invalid_action_masking', False):
        raise RuntimeError('invalid_action_masking must be set to True in env_config to use this mixin')
    explore = explore if explore is not None else self.config["explore"]
    timestep = timestep if timestep is not None else self.global_timestep
    with torch.no_grad():
        # Pass lazy (torch) tensor dict to Model as `input_dict`.
        input_dict = self._lazy_tensor_dict(input_dict)
        # Pack internal state inputs into (separate) list.
        state_batches = [
            input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
        ]
        # Calculate RNN sequence lengths.
        seq_lens = np.array([1] * len(input_dict["obs"])) \
            if state_batches else None
        # Call the exploration before_compute_actions hook.
        self.exploration.before_compute_actions(
            explore=explore, timestep=timestep)
        dist_inputs, state_out = self.model(input_dict, state_batches,
                                            seq_lens)
        # Extract the tree from the info batch
        valid_action_trees = []
        infos = input_dict[SampleBatch.INFOS]
        for info in infos:
            if not isinstance(info, torch.Tensor) and 'valid_action_tree' in info:
                valid_action_trees.append(info['valid_action_tree'])
            else:
                valid_action_trees.append({0: {0: {0: [0]}}})
        # ... (rest of the function omitted)
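The info-extraction loop above can be pulled out into a small, framework-free helper and paired with a batched mask for the first action component. A sketch under assumptions: the helper names and DEFAULT_TREE are hypothetical, isinstance(info, dict) stands in for the original's torch.Tensor check so the example runs without torch, and the tree layout is inferred from the placeholder tree.

```python
import copy

import numpy as np

# Same placeholder tree the loop above falls back to.
DEFAULT_TREE = {0: {0: {0: [0]}}}


def extract_valid_action_trees(infos):
    """Return one action tree per batch row, falling back to DEFAULT_TREE."""
    trees = []
    for info in infos:
        if isinstance(info, dict) and "valid_action_tree" in info:
            trees.append(info["valid_action_tree"])
        else:
            # Non-dict entries (the original checks for torch.Tensor
            # placeholders) or dicts without a tree get a fresh default copy.
            trees.append(copy.deepcopy(DEFAULT_TREE))
    return trees


def first_level_masks(trees, num_choices):
    """Stack a (batch, num_choices) boolean mask for the first action component."""
    masks = np.zeros((len(trees), num_choices), dtype=bool)
    for row, tree in enumerate(trees):
        masks[row, list(tree.keys())] = True
    return masks


infos = [{"valid_action_tree": {1: {0: [0]}, 2: {1: [1]}}}, {}, np.zeros(1)]
print(first_level_masks(extract_valid_action_trees(infos), 3))
```

Copying the default avoids every fallback row sharing (and possibly mutating) one dict.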