'infos' automatically stripped when only accessed in a mixin

I have a mixin which needs to override the compute_actions_from_input_dict function in a policy.

I was previously doing this by overriding compute_actions in version 1.1.0, but wanted to update it to use 1.2.0.

I’m using a method to create action masks that relies on choosing from a tree of possible actions during the exploration phase and applying the resulting masks when calculating logits/logps etc.
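
For reference, the masking step at the logits level looks roughly like the sketch below. This is a minimal, hedged example of the general technique, not my actual tree-walking code; apply_action_mask and its mask argument are illustrative names.

import torch

def apply_action_mask(dist_inputs: torch.Tensor,
                      mask: torch.Tensor) -> torch.Tensor:
    # dist_inputs: raw logits from the model, shape [batch, num_actions].
    # mask: bool tensor of the same shape, True where an action is valid.
    # Invalid actions get the most negative representable logit, so their
    # probability (and hence logp) is effectively zero after the softmax.
    neg_inf = torch.finfo(dist_inputs.dtype).min
    return torch.where(mask, dist_inputs,
                       torch.full_like(dist_inputs, neg_inf))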

I’m passing the info dictionary to compute_actions_from_input_dict… but it is getting stripped out of the dictionary because it’s not accessed in either the model or the post-processing.
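
For concreteness, this is roughly the shape I expect the infos column to have inside compute_actions_from_input_dict (an illustrative sketch; the tree contents are made up, but the nesting matches what my env puts into each step’s info dict):

# One info dict per environment in the batch, as returned by env.step().
infos = [
    {'valid_action_tree': {0: {1: {2: [3]}}}},  # env 0
    {'valid_action_tree': {0: {0: {0: [0]}}}},  # env 1: trivial tree
]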

My current workaround is to disable the trajectory view API (_use_trajectory_view_api=False) and override compute_actions() instead. In version 1.2.0, however, this throws another error about model.num_framestack which I can’t work out, so I ended up reverting to 1.1.0.
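
For completeness, that workaround boils down to a config like this (a sketch; _use_trajectory_view_api was an experimental top-level config key in these versions, and invalid_action_masking is my own env option, not an RLlib one):

config = {
    # Fall back to the pre-trajectory-view compute_actions() code path.
    "_use_trajectory_view_api": False,
    "env_config": {
        # Custom flag that my mixin checks before applying masks.
        "invalid_action_masking": True,
    },
}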

Is there something I am doing wrong?

I’ve included my custom compute function below, up to the point where ‘infos’ is accessed:

# Imports referenced by the snippet (added for completeness):
import numpy as np
import torch

from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override


@override(Policy)
def compute_actions_from_input_dict(self, input_dict, explore=None,
                                    timestep=None, episodes=None, **kwargs):
    # This mixin only makes sense if the env emits valid-action trees.
    if not self.config['env_config'].get('invalid_action_masking', False):
        raise RuntimeError(
            'invalid_action_masking must be set to True in env_config '
            'to use this mixin')

    # Fall back to the policy-wide defaults where not explicitly given.
    explore = explore if explore is not None else self.config["explore"]
    timestep = timestep if timestep is not None else self.global_timestep

    with torch.no_grad():
        # Pass lazy (torch) tensor dict to Model as `input_dict`.
        input_dict = self._lazy_tensor_dict(input_dict)
        # Pack internal state inputs into (separate) list.
        state_batches = [
            input_dict[k] for k in input_dict.keys()
            if k.startswith("state_in")
        ]
        # Calculate RNN sequence lengths.
        seq_lens = np.array([1] * len(input_dict["obs"])) \
            if state_batches else None

        # Call the exploration before_compute_actions hook.
        self.exploration.before_compute_actions(
            explore=explore, timestep=timestep)

        dist_inputs, state_out = self.model(input_dict, state_batches,
                                            seq_lens)

        # Extract the valid-action tree from the info batch. On dummy
        # batches 'infos' holds placeholder tensors instead of dicts,
        # so fall back to a trivial single-action tree in that case.
        valid_action_trees = []
        infos = input_dict[SampleBatch.INFOS]
        for info in infos:
            if not isinstance(info, torch.Tensor) and 'valid_action_tree' in info:
                valid_action_trees.append(info['valid_action_tree'])
            else:
                valid_action_trees.append({0: {0: {0: [0]}}})

        # ... (rest omitted; the failure occurs at the 'infos' access above)
Hey @Bam4d. Thanks for filing this. I can reproduce the error.
Taking a look right now …


Here are the issue and the PR for the fix:

Issue: [RLlib] `Policy.compute_actions_from_input_dict` does not properly track accessed fields for the Policy's view-requirements. · Issue #14385 · ray-project/ray · GitHub
PR: [RLlib] Issue 14385: `Policy.compute_actions_from_input_dict` does not properly track accessed fields for Policy's view requirements. by sven1977 · Pull Request #14386 · ray-project/ray · GitHub
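
For anyone who finds this thread later: once that fix is in, a policy can also declare the column explicitly so it is never pruned. A minimal sketch against the post-fix ViewRequirement API (keep_infos is just an illustrative helper; double-check the field names in your RLlib version):

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement

def keep_infos(policy):
    # Mark 'infos' as needed at action-computation time so the
    # trajectory view API does not prune it from the input dict.
    policy.view_requirements[SampleBatch.INFOS] = ViewRequirement(
        used_for_compute_actions=True)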
