- High: It blocks me from completing my task.
I am trying to use the trajectory view API for my model, which is supposed to have an LSTM and a fully connected network in parallel. This is the workflow I have in mind:
```
recurrent obs = obs["repeated"][-10:0] --> LSTM --+
                                                  +--> Fully Connected Network --> Output
non-recurrent = obs["fixed"] ---------------------+
```
I am thinking my best bet would be to implement a custom RNN.
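For reference, here is a minimal plain-PyTorch sketch of the two-branch layout in the diagram above, outside of RLlib. The class name `TwoBranchNet`, the hidden sizes, `num_outputs=2`, and summarizing the window by the LSTM output at its last time step are all my own assumptions; only the input sizes (13 "repeated" features, 4 "fixed" features, an 11-step window for `-10:0`) come from the setup in this post. It is stateless (it re-reads the whole window every step), so it is not the same thing as RLlib's state-carrying recurrent models.

```python
import torch
import torch.nn as nn


class TwoBranchNet(nn.Module):
    """Hypothetical sketch: an LSTM over a window of "repeated" observations,
    followed by a fully connected head over [LSTM summary, "fixed" features]."""

    def __init__(self, rep_dim: int = 13, fixed_dim: int = 4,
                 lstm_hidden: int = 64, fc_hidden: int = 64, num_outputs: int = 2):
        super().__init__()
        self.lstm = nn.LSTM(rep_dim, lstm_hidden, batch_first=True)
        self.fc = nn.Sequential(
            nn.Linear(lstm_hidden + fixed_dim, fc_hidden),
            nn.ReLU(),
            nn.Linear(fc_hidden, num_outputs),
        )

    def forward(self, rep: torch.Tensor, fixed: torch.Tensor) -> torch.Tensor:
        # rep:   [B, T, rep_dim]  window of past "repeated" observations
        # fixed: [B, fixed_dim]   current "fixed" observation
        lstm_out, _ = self.lstm(rep)
        last = lstm_out[:, -1, :]  # LSTM output at the last step of the window
        return self.fc(torch.cat([last, fixed], dim=-1))


net = TwoBranchNet()
out = net(torch.zeros(8, 11, 13), torch.zeros(8, 4))  # batch of 8, 11-step window
print(out.shape)  # torch.Size([8, 2])
```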
So this is how I set up my obs space:
```python
import numpy as np
from gym import spaces

obs_space = spaces.Dict({
    # 13 recurrent features per timestep
    "repeated": spaces.Box(low=np.array([0] * 3 + [0, 0, 0, 0, 0, -1, 0, 0, 0, -np.inf]),
                           high=np.array([np.inf] * 8 + [1] + [np.inf] * 4),
                           dtype=np.float64),
    # 4 non-recurrent features
    "fixed": spaces.Box(low=np.array([0, 0, -1, 0]),
                        high=np.array([1, 1, np.inf, np.inf]),
                        dtype=np.float64),
})
```
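To make the flattened layout concrete: "fixed" (4 values) plus "repeated" (13 values) gives a 17-dimensional vector. The numpy sketch below assumes the flattener concatenates sub-observations in sorted key order, which is what gym's `Dict` does when it is built from a plain dict, so "fixed" lands in the first 4 slots. `flatten_dict_obs` is a hypothetical helper for illustration, not RLlib's preprocessor.

```python
import numpy as np


def flatten_dict_obs(obs: dict) -> np.ndarray:
    """Hypothetical helper: concatenate sub-observations in sorted key order,
    mimicking the key ordering of a gym Dict space built from a plain dict."""
    return np.concatenate([np.asarray(obs[k], dtype=np.float64).ravel()
                           for k in sorted(obs)])


sample = {
    "repeated": np.arange(13, dtype=np.float64),  # 13 recurrent features
    "fixed": np.array([1.0, 0.0, 2.5, 3.0]),      # 4 non-recurrent features
}
flat = flatten_dict_obs(sample)
print(flat.shape)  # (17,)
```

With this ordering, `flat[:4]` is the "fixed" part and `flat[4:]` is the "repeated" part.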
I then have this in my Policy.postprocess_trajectory function:
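```python
from typing import Dict, Optional

import numpy as np
# Ray 1.x module paths
from ray.rllib.agents.a3c.a3c_torch_policy import A3CTorchPolicy
from ray.rllib.evaluation.episode import Episode
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AgentID


class new_policy(A3CTorchPolicy):
    def postprocess_trajectory(self, sample_batch: SampleBatch,
                               other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
                               episode: Optional[Episode] = None):
        sample_batch = super().postprocess_trajectory(sample_batch, other_agent_batches, episode)
        print("sample", sample_batch["obs"])
        # obs is [T, 17]: split along the feature axis, "fixed" (4) first, then "repeated" (13)
        sample_batch["fixed"] = np.array(sample_batch["obs"][:, :4])
        sample_batch["repeated"] = np.array(sample_batch["obs"][:, 4:])
        return sample_batch
```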
Here I slice `sample_batch["obs"]` instead of indexing `["obs"]["fixed"]` because `obs` seems to be a flattened NumPy array at this point, not the dict I would have expected, but that's fine.
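One detail worth keeping in mind when slicing: assuming the flattened `obs` column holds one row per timestep (shape `[T, 17]` for the space above), indexing with `[:4]` selects the first 4 *timesteps*, while `[:, :4]` selects the first 4 *features* of every timestep. A small numpy illustration, with `T = 6` as an arbitrary stand-in:

```python
import numpy as np

T = 6  # hypothetical number of timesteps in the batch
obs = np.arange(T * 17, dtype=np.float64).reshape(T, 17)  # stand-in for sample_batch["obs"]

rows = obs[:4]         # first 4 timesteps, all 17 features
fixed = obs[:, :4]     # every timestep, first 4 features ("fixed")
repeated = obs[:, 4:]  # every timestep, remaining 13 features ("repeated")

print(rows.shape, fixed.shape, repeated.shape)  # (4, 17) (6, 4) (6, 13)
```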
Then I set up my view requirement in my model:
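```python
from typing import Dict, List
import gym
import torch
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.typing import ModelConfigDict, TensorType

class my_model(TorchModelV2, torch.nn.Module):
    def __init__(self, obs_space: gym.spaces.Space, action_space: gym.spaces.Space,
                 num_outputs: int, model_config: ModelConfigDict, name: str):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        torch.nn.Module.__init__(self)  # base inits create self.view_requirements
        self.view_requirements["rep"] = ViewRequirement(data_col="repeated", shift="-10:0")

    def forward(self, input_dict: Dict[str, TensorType], state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
        print("Got input", input_dict["obs"], input_dict["obs_flat"], input_dict)
```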
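For clarity about what `shift="-10:0"` is meant to deliver: an 11-step window, timesteps `t-10` through `t` inclusive, of the requested column. The numpy sketch below builds that window for a single timestep of a `[T, 13]` column. `gather_shifted` is a hypothetical helper, and zero-padding the slots that fall before the episode start is my simplifying assumption; RLlib's actual padding behaviour may differ.

```python
import numpy as np


def gather_shifted(column: np.ndarray, t: int, start: int = -10, end: int = 0) -> np.ndarray:
    """Hypothetical helper: rows t+start .. t+end (inclusive) of a [T, D] column.
    Indices before the episode start (i < 0) are filled with zeros (an assumption).
    Assumes t + end < T."""
    window = []
    for i in range(t + start, t + end + 1):
        if i >= 0:
            window.append(column[i])
        else:
            window.append(np.zeros(column.shape[1], dtype=column.dtype))
    return np.stack(window)


repeated = np.arange(20 * 13, dtype=np.float64).reshape(20, 13)  # 20 timesteps, 13 features
w = gather_shifted(repeated, t=3)
print(w.shape)  # (11, 13)
```

At `t=3` only 4 real timesteps exist, so the first 7 rows of the window are padding.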
However, in my `forward` call I don't get the "rep" column at all! I also don't see the "repeated" or "fixed" columns. I suspect there is a flaw in my understanding of the trajectory view API, and I would appreciate any insight into what is going wrong here and what the best way to do this would be.
Thanks in advance!