Setting up trajectory view correctly for repeated+non-repeated input

  • High: It blocks me from completing my task.

I am trying to use the trajectory view API for my model, which is supposed to have an LSTM and a fully connected network in parallel. This is the workflow I have in mind:

```
recurrent obs = obs["repeated"][-10:0] → LSTM ─┐
                                               ├─→ Fully Connected Network → Output
non-recurrent = obs["fixed"] ──────────────────┘
```
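A minimal sketch of the two-branch layout in the diagram, written as plain PyTorch rather than as an RLlib ModelV2. The class name TwoBranchNet, the hidden sizes and num_outputs=2 are hypothetical; it assumes the "repeated" window arrives as a [batch, 11, 13] tensor (timesteps t-10 through t) and "fixed" as [batch, 4], matching the 13 and 4 feature sizes of the observation space.

```python
import torch
import torch.nn as nn


class TwoBranchNet(nn.Module):
    """Hypothetical two-branch network: an LSTM over the "repeated" window,
    concatenated with the "fixed" features, then a fully connected head."""

    def __init__(self, repeated_dim=13, fixed_dim=4, lstm_hidden=32,
                 fc_hidden=64, num_outputs=2):
        super().__init__()
        # Recurrent branch: consumes [batch, time, repeated_dim]
        self.lstm = nn.LSTM(repeated_dim, lstm_hidden, batch_first=True)
        # Shared head: consumes the last LSTM output plus the fixed features
        self.head = nn.Sequential(
            nn.Linear(lstm_hidden + fixed_dim, fc_hidden),
            nn.ReLU(),
            nn.Linear(fc_hidden, num_outputs),
        )

    def forward(self, repeated_window, fixed):
        # lstm_out: [batch, time, lstm_hidden]
        lstm_out, _ = self.lstm(repeated_window)
        last_step = lstm_out[:, -1, :]  # output at the most recent timestep
        return self.head(torch.cat([last_step, fixed], dim=-1))


net = TwoBranchNet()
out = net(torch.zeros(8, 11, 13), torch.zeros(8, 4))
print(out.shape)  # torch.Size([8, 2])
```

Running the LSTM over the whole window on every call means no recurrent state has to be carried between calls; the trade-off is recomputing 11 LSTM steps each time.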

I think my best bet is to implement a custom RNN model.
So this is how I set up my observation space:

```python
import numpy as np
from gym import spaces

obs_space = spaces.Dict({
    # 13 features per timestep; index 8 is bounded to [-1, 1]
    "repeated": spaces.Box(low=np.array([0] * 3 + [0, 0, 0, 0, 0, -1, 0, 0, 0, -np.inf]),
                           high=np.array([np.inf] * 8 + [1] + [np.inf] * 4)),
    "fixed": spaces.Box(low=np.array([0, 0, -1, 0]),
                        high=np.array([1, 1, np.inf, np.inf])),
})
```

I then have this in my Policy.postprocess_trajectory function:

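```python
from typing import Dict, Optional
import numpy as np
from ray.rllib.agents.a3c.a3c_torch_policy import A3CTorchPolicy
from ray.rllib.evaluation.episode import Episode
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AgentID

class new_policy(A3CTorchPolicy):
    def postprocess_trajectory(self, sample_batch: SampleBatch,
                               other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
                               episode: Optional[Episode] = None):
        sample_batch = super(new_policy, self).postprocess_trajectory(
            sample_batch, other_agent_batches, episode)
        print("sample", sample_batch["obs"])
        # obs is flattened to [timesteps, 4 + 13]: slice feature columns, not rows.
        sample_batch["fixed"] = np.array(sample_batch["obs"][:, :4])
        sample_batch["repeated"] = np.array(sample_batch["obs"][:, 4:])
        return sample_batch
```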

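Here I slice ["obs"] by position instead of using ["obs"]["fixed"], because at this point obs seems to be a flattened numpy array rather than the dict I would have expected. That's fine: gym's Dict space sorts its keys, so the 4 "fixed" features come before the 13 "repeated" ones.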
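Because sample_batch["obs"] is two-dimensional ([timesteps, features]), indexing it with [:4] selects the first four timesteps, not the first four features. This numpy sketch uses a hypothetical batch of 5 flattened observations and assumes "fixed" occupies the first 4 columns:

```python
import numpy as np

# Hypothetical batch of 5 flattened observations: 4 "fixed" + 13 "repeated" features.
FIXED_DIM, REPEATED_DIM = 4, 13
batch_obs = np.arange(5 * (FIXED_DIM + REPEATED_DIM), dtype=np.float32).reshape(5, -1)

wrong = batch_obs[:4]                # first 4 timesteps, all 17 features
fixed = batch_obs[:, :FIXED_DIM]     # all timesteps, first 4 features
repeated = batch_obs[:, FIXED_DIM:]  # all timesteps, remaining 13 features

print(wrong.shape, fixed.shape, repeated.shape)  # (4, 17) (5, 4) (5, 13)
```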

Then I set up my view requirement in my model:

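```python
from typing import Dict, List
import gym
import torch
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.typing import ModelConfigDict, TensorType

class my_model(TorchModelV2, torch.nn.Module):
    def __init__(self, obs_space: gym.spaces.Space, action_space: gym.spaces.Space,
                 num_outputs: int, model_config: ModelConfigDict, name: str):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        torch.nn.Module.__init__(self)
        # Request the "repeated" column from t-10 to t under the key "rep".
        self.view_requirements["rep"] = ViewRequirement(data_col="repeated", shift="-10:0")

    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
        print("Got input", input_dict["obs"], input_dict["obs_flat"], input_dict)
        # (network layers and return value omitted in this excerpt)
```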

However, in my forward call I don't get the "rep" column at all! I also don't see the "repeated" or "fixed" columns. I suspect I am misunderstanding the trajectory view API, and I would appreciate any insight into what is going wrong and what the best way to do this would be.

Thanks in advance!

Hey @utkarshp,

Have you tried setting _disable_preprocessor_api to True in your trainer config? Setting this flag tells RLlib not to flatten the observation's data structure, so your model receives the original dict.
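For reference, a sketch of where the flag goes, assuming a Ray 1.x trainer config dict in which _disable_preprocessor_api is a top-level key. The custom model name "my_model" is hypothetical and would need to be registered with ModelCatalog.register_custom_model:

```python
# Sketch of a Ray 1.x trainer config (assumption: top-level experimental flag).
config = {
    "framework": "torch",
    # Keep Dict observations unflattened so the model sees obs["fixed"] and obs["repeated"].
    "_disable_preprocessor_api": True,
    "model": {
        # Hypothetical name; register it via ModelCatalog.register_custom_model.
        "custom_model": "my_model",
    },
}
```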

The trajectory view API defines how trajectories are constructed: which data columns a model needs and how those columns depend on values at other timesteps. So here you first need access to the "repeated" and "fixed" keys of each observation; then you can use the trajectory view API to decide how to build your model's input from those keys over time.
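To make the two steps concrete, here is a numpy-only sketch with no RLlib involved. The helper build_shifted_window is hypothetical: it mimics a shift range such as "-10:0" by gathering rows t-10 through t (assumed inclusive on both ends) of one column, and the zero padding before the episode start is an assumption about how early timesteps are filled.

```python
import numpy as np


def build_shifted_window(column, t, start=-10, end=0):
    """Hypothetical helper: return rows t+start .. t+end (inclusive) of `column`,
    zero-padding rows that fall before the episode start (padding is assumed)."""
    rows = []
    for i in range(t + start, t + end + 1):
        rows.append(column[i] if i >= 0 else np.zeros_like(column[0]))
    return np.stack(rows)


# Hypothetical episode of 3 unflattened Dict observations.
episode = [
    {"fixed": np.full(4, t, dtype=np.float32), "repeated": np.full(13, t, dtype=np.float32)}
    for t in range(3)
]
# Step 1: access each key of each obs to form per-key columns.
fixed_col = np.stack([o["fixed"] for o in episode])        # [T, 4]
repeated_col = np.stack([o["repeated"] for o in episode])  # [T, 13]

# Step 2: apply the time window only to "repeated"; "fixed" uses the current step.
t = 2
window = build_shifted_window(repeated_col, t)  # [11, 13]
current_fixed = fixed_col[t]                    # [4]
print(window.shape, current_fixed.shape)  # (11, 13) (4,)
```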