How the max_seq_len param impacts a custom LSTM implementation

I am creating a custom LSTM model for PPO that handles a time sequence of, say, 10 observations. I am doing this as advised in the docs:

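from ray.rllib.policy.sample_batch import SampleBatch

# Inside the custom model's __init__: request the last
# observation_time_sequence_len observations (t-9 ... t for a length of 10).
from_ = observation_time_sequence_len - 1
self.view_requirements[SampleBatch.OBS].shift = "-{}:0".format(from_)
self.view_requirements[SampleBatch.OBS].shift_from = -from_
self.view_requirements[SampleBatch.OBS].shift_to = 0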

Doing that, I do not have to add a time dimension in the forward method.
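For comparison, here is a pure-Python sketch of what a shift of "-{}:0".format(from_) on SampleBatch.OBS roughly asks for: for every timestep t, a window holding the last observation_time_sequence_len observations. The helper name frame_stack and the zero padding before the episode start are assumptions for illustration, not RLlib's internal code.

```python
from typing import List


def frame_stack(episode_obs: List[float], seq_len: int,
                pad_value: float = 0.0) -> List[List[float]]:
    """Hypothetical helper: for each timestep t, return obs[t-(seq_len-1) .. t].

    Timesteps before the episode start are filled with pad_value
    (zero padding is an assumption made for this sketch).
    """
    from_ = seq_len - 1
    windows = []
    for t in range(len(episode_obs)):
        window = [episode_obs[s] if s >= 0 else pad_value
                  for s in range(t - from_, t + 1)]
        windows.append(window)
    return windows


obs = [10.0, 11.0, 12.0, 13.0, 14.0, 15.0]
windows = frame_stack(obs, seq_len=3)
print(windows[0])  # [0.0, 0.0, 10.0]
print(windows[5])  # [13.0, 14.0, 15.0]
```

Each row carries its own history, so it does not depend on any incoming internal state; that is the main difference from the max_seq_len chunks of a recurrent model.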

How does RLlib's max_seq_len param impact this? I do not quite understand how my observation_time_sequence_len could override, or at least work in concert with, max_seq_len.

I am seeing different results when I change the max_seq_len value.

I didn't get what max_seq_len does or how I can tune it. I hope someone can help.


Hey @hossein836 and @rob65, the max_seq_len in your model’s config dict specifies how many timesteps are “pushed” through the LSTM as one coherent sequence (with one initial internal-state input and one internal-state output at the end).
So, e.g., setting this to 20 basically chops your episodes into chunks of length 20, keeps these chunks intact (no shuffling of single timesteps), and feeds them to the LSTM (observation input shape [B, 20, ...]). For each of the B chunks, you need one initial state tensor: either zero-initialized if the chunk is at the beginning of an episode, or the LSTM's internal-state output from the previous chunk.
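A minimal sketch of that chopping step, assuming a single episode of scalar observations. chop_episode is a hypothetical helper; RLlib does its own padding and passes the true chunk lengths (seq_lens) to recurrent models, but this is not its implementation.

```python
from typing import List, Tuple


def chop_episode(episode_obs: List[float], max_seq_len: int,
                 pad_value: float = 0.0) -> Tuple[List[List[float]], List[int]]:
    """Hypothetical helper: split one episode into consecutive chunks.

    Returns the padded chunks (shape [B, max_seq_len]) and the true
    length of each chunk (seq_lens).
    """
    chunks, seq_lens = [], []
    for start in range(0, len(episode_obs), max_seq_len):
        chunk = list(episode_obs[start:start + max_seq_len])
        seq_lens.append(len(chunk))
        # Pad the last (shorter) chunk so every row has the same time length.
        chunk += [pad_value] * (max_seq_len - len(chunk))
        chunks.append(chunk)
    return chunks, seq_lens


chunks, seq_lens = chop_episode([float(t) for t in range(12)], max_seq_len=5)
print(seq_lens)   # [5, 5, 2]
print(chunks[2])  # [10.0, 11.0, 0.0, 0.0, 0.0]
```

The order of timesteps inside a chunk is preserved; only whole chunks (together with their initial states) could be shuffled.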

# obs-space=Discrete(100)  # assume observations are the timesteps in the episode
# single internal state shape = (1,)  <- keep it simple for this example
# max-seq-len=5

observations = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
# add time dim -> shape [B=3, T=5]:
observations = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14]]
# LSTM states, one initial state per chunk -> shape [B=3, 1]:
internal_states = [
    [0.0],  # all zeros, as the episode starts here
    [x],    # state output of the LSTM when computing the action for timestep 4
    [y],    # state output of the LSTM when computing the action for timestep 9
]
output, out_states = my_lstm(observations, internal_states)
# out_states.shape = (3, 1)  # 3=B; 1=internal state shape
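The state hand-over in this example can be checked with a toy recurrent cell. The cell below, tanh(w * h + u * x), is a hypothetical stand-in for an LSTM with a state of shape (1,); the point is only that running the chunks with carried-over states reproduces one pass over the whole episode.

```python
import math
from typing import List, Tuple


def toy_cell(h: float, x: float, w: float = 0.5, u: float = 0.1) -> float:
    """Hypothetical scalar recurrent cell (not a real LSTM)."""
    return math.tanh(w * h + u * x)


def run_chunk(chunk: List[float], h0: float) -> Tuple[List[float], float]:
    """Push one chunk through the cell from h0; return outputs and final state."""
    h, outputs = h0, []
    for x in chunk:
        h = toy_cell(h, x)
        outputs.append(h)
    return outputs, h


observations = [float(t) for t in range(15)]
max_seq_len = 5
chunks = [observations[i:i + max_seq_len]
          for i in range(0, len(observations), max_seq_len)]

# Initial state per chunk: zero at episode start, then the previous
# chunk's final state (the x and y of the example).
chunked_outputs, states_in = [], []
h = 0.0
for chunk in chunks:
    states_in.append(h)
    out, h = run_chunk(chunk, h)
    chunked_outputs.extend(out)

# One pass over the whole episode, for comparison.
full_outputs, _ = run_chunk(observations, 0.0)
print(len(states_in))                  # 3 -> one initial state per chunk (B=3)
print(chunked_outputs == full_outputs)  # True
```

Since the forward computation is the same, differences when changing max_seq_len most plausibly come from training: gradients only flow back through time within one chunk (truncated backpropagation through time), and each chunk's incoming state is treated as a fixed input rather than something to backpropagate into. A longer max_seq_len lets the LSTM learn dependencies over more timesteps, at the cost of more padding for short episodes.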

I got that, thanks :ok_hand: