Setting up the trajectory view API correctly for repeated and non-repeated inputs

  • High: It blocks me from completing my task.

I am trying to use the trajectory view API for a model that should have an LSTM and a fully connected network in parallel. This is the workflow I have in mind:

recurrent obs = obs["repeated"][-10:0] --> LSTM --+
                                                  +--> Fully Connected Network --> Output
non-recurrent = obs["fixed"] ---------------------+
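A minimal PyTorch sketch of the layout in this diagram, assuming the LSTM's final hidden state is concatenated with the fixed features before the fully connected layers. `TwoBranchNet` and all layer sizes are hypothetical (13 and 4 are taken from the observation space below), and this is plain `torch.nn` without the RLlib `TorchModelV2` wiring. The window length of 11 in the usage line assumes `shift="-10:0"` covers steps t-10 through t inclusive.

```python
import torch
import torch.nn as nn


class TwoBranchNet(nn.Module):
    """Hypothetical sketch: LSTM over the repeated window; fixed features bypass it."""

    def __init__(self, repeated_dim: int = 13, fixed_dim: int = 4,
                 lstm_hidden: int = 32, fc_hidden: int = 64, num_outputs: int = 2):
        super().__init__()
        # batch_first=True -> LSTM input shape is [batch, time, features]
        self.lstm = nn.LSTM(repeated_dim, lstm_hidden, batch_first=True)
        # The fully connected network sees the LSTM summary plus the fixed features
        self.fc = nn.Sequential(
            nn.Linear(lstm_hidden + fixed_dim, fc_hidden),
            nn.ReLU(),
            nn.Linear(fc_hidden, num_outputs),
        )

    def forward(self, repeated: torch.Tensor, fixed: torch.Tensor) -> torch.Tensor:
        # repeated: [B, T, repeated_dim], fixed: [B, fixed_dim]
        _, (h_n, _) = self.lstm(repeated)
        lstm_summary = h_n[-1]  # final hidden state of the last LSTM layer: [B, lstm_hidden]
        return self.fc(torch.cat([lstm_summary, fixed], dim=-1))


net = TwoBranchNet()
out = net(torch.zeros(5, 11, 13), torch.zeros(5, 4))
print(out.shape)  # torch.Size([5, 2])
```

Using only the final hidden state discards the per-step LSTM outputs; that is one common choice for summarizing a window, not the only one.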

I think my best bet is to implement a custom RNN model, so this is how I set up my observation space:

import numpy as np
from gym import spaces

obs_space = spaces.Dict({
    # 13 features per time step; feature 8 is bounded to [-1, 1]
    "repeated": spaces.Box(low=np.array([0] * 3 + [0, 0, 0, 0, 0, -1, 0, 0, 0, -np.inf]),
                           high=np.array([np.inf] * 8 + [1] + [np.inf] * 4)),
    "fixed": spaces.Box(low=np.array([0, 0, -1, 0]),
                        high=np.array([1, 1, np.inf, np.inf])),
})

I then have this in my Policy.postprocess_trajectory function:

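from typing import Dict, Optional

import numpy as np
from ray.rllib.agents.a3c.a3c_torch_policy import A3CTorchPolicy
from ray.rllib.evaluation.episode import Episode
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AgentID

class new_policy(A3CTorchPolicy):
    def postprocess_trajectory(self, sample_batch: SampleBatch,
                               other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
                               episode: Optional[Episode] = None):
        sample_batch = super().postprocess_trajectory(sample_batch, other_agent_batches, episode)
        print("sample", sample_batch["obs"])
        # obs is [T, 17]: "fixed" (4) then "repeated" (13); slice columns, not rows
        sample_batch["fixed"] = np.array(sample_batch["obs"][:, :4])
        sample_batch["repeated"] = np.array(sample_batch["obs"][:, 4:])
        return sample_batch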

Here I slice the flattened ["obs"] array instead of using ["obs"]["fixed"], because at this point obs seems to be a flattened NumPy array rather than the dict I expected, but that's fine.
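To illustrate the flattening and slicing, a small NumPy sketch, assuming the Dict observation is flattened with its keys in sorted order ("fixed" before "repeated"), which is what `gym.spaces.Dict` does when given a plain dict. The helpers `flatten_obs` and `split_flat_obs` are hypothetical. On a batch of shape [T, 17], `obs[:4]` selects the first four time steps (rows), while `obs[:, :4]` selects the first four features (columns).

```python
import numpy as np

FIXED_DIM = 4      # size of obs["fixed"]
REPEATED_DIM = 13  # size of obs["repeated"]


def flatten_obs(obs: dict) -> np.ndarray:
    # Concatenate the Dict components in sorted key order: "fixed", then "repeated"
    return np.concatenate([np.asarray(obs[k], dtype=np.float32) for k in sorted(obs)])


def split_flat_obs(batch: np.ndarray):
    # batch has shape [T, FIXED_DIM + REPEATED_DIM]; slice along the feature axis
    return batch[:, :FIXED_DIM], batch[:, FIXED_DIM:]


T = 6
batch = np.stack([
    flatten_obs({"repeated": np.full(REPEATED_DIM, t), "fixed": np.full(FIXED_DIM, -t)})
    for t in range(T)
])
fixed, repeated = split_flat_obs(batch)
print(batch.shape, fixed.shape, repeated.shape)  # (6, 17) (6, 4) (6, 13)
print(batch[:4].shape)  # (4, 17): the first four time steps, not the "fixed" features
```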

Then I set up the view requirement in my model:

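from typing import Dict, List

import gym
import torch
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.typing import ModelConfigDict, TensorType

class my_model(TorchModelV2, torch.nn.Module):
    def __init__(self, obs_space: gym.spaces.Space, action_space: gym.spaces.Space,
                 num_outputs: int, model_config: ModelConfigDict, name: str):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        torch.nn.Module.__init__(self)
        self.view_requirements["rep"] = ViewRequirement(data_col="repeated", shift="-10:0")

    def forward(self, input_dict: Dict[str, TensorType], state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
        print("Got input", input_dict["obs"], input_dict["obs_flat"], input_dict)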
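To make the requested window concrete, a NumPy sketch of what `shift="-10:0"` is intended to select at each time step t: the rows from t-10 through t, i.e. 11 steps, assuming both ends of the range are inclusive. Zero-padding before the episode start is an assumption made for illustration, and the helper `window_view` is hypothetical; it is not how RLlib builds the column internally.

```python
import numpy as np


def window_view(column: np.ndarray, start: int = -10, end: int = 0) -> np.ndarray:
    """Hypothetical helper: for each step t, gather column[t + start] ... column[t + end].

    Both ends are inclusive. Steps before the episode start are zero-padded
    (an illustrative assumption). column: [T, F] -> returns [T, end - start + 1, F].
    """
    assert start <= end <= 0, "this sketch only supports windows ending at or before t"
    T, F = column.shape
    pad = -start
    padded = np.concatenate([np.zeros((pad, F), dtype=column.dtype), column], axis=0)
    # Original step t sits at index t + pad in the padded array
    return np.stack([padded[t + pad + start: t + pad + end + 1] for t in range(T)])


repeated = np.arange(6 * 13, dtype=np.float32).reshape(6, 13)  # 6 steps, 13 features
windows = window_view(repeated)
print(windows.shape)  # (6, 11, 13)
```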

However, in forward() the input_dict contains no "rep" column at all! I also don't see "repeated" or "fixed" columns. I suspect my understanding of the trajectory view API is flawed, and I would appreciate any insight into what is going wrong here and what the best way to do this would be.

Thanks in advance!