RLPredictor with attention net

  • High: It blocks me from completing my task.

I have this code:

 import numpy as np

 num_transformers = 1
 attention_dim = 64
 memory = 50
 # One zero-initialized memory tensor per transformer unit
 state = [
   np.zeros([memory, attention_dim], np.float32)
   for _ in range(num_transformers)
 ]

predictor = RLPredictor.from_checkpoint(Checkpoint.from_directory(checkpoint_path))
action = predictor.predict(obs, state)

I get this error:

TypeError: predict() takes 2 positional arguments but 3 were given

@evo11x predict() is defined in the Predictor class. It takes only a single argument besides self, so passing state as a second argument raises this TypeError.
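To illustrate why the message says "2 positional arguments but 3 were given": Python counts self as the first positional argument. The snippet below is a minimal sketch using a toy class; `ToyPredictor` is a hypothetical stand-in and not Ray's actual Predictor implementation.

```python
class ToyPredictor:
    """Hypothetical stand-in for illustration; not Ray's Predictor class."""

    def predict(self, data):
        # Accepts exactly one argument besides self.
        return data


predictor = ToyPredictor()
obs, state = [0.0], [[0.0]]  # placeholder values

try:
    # self + obs + state = 3 positional arguments, but only 2 are accepted
    predictor.predict(obs, state)
    error_message = ""
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```

Calling `predictor.predict(obs)` with only the observation avoids this particular error.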

With a single argument (only the observation), I get an error about invalid seq lens.

The observations need an additional dimension for the sequence length. With a sequence length of 4, for example, you stack 4 observation tensors along that dimension.

Which 4 observation tensors? Do you have an example?


I think what is needed is the time dimension, so the input has the shape (BATCH_SIZE, SEQ_LEN, OBS_DIM_1, OBS_DIM_2, etc.), where the second axis is the time dimension.
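As a rough sketch of that layout, here is how 4 consecutive observations could be stacked with NumPy. The sizes (`seq_len = 4`, `obs_dim = 8`) and the dummy observation values are assumptions for illustration only; whether RLPredictor accepts this input as-is is not confirmed in this thread.

```python
import numpy as np

# Assumed, illustrative sizes (not taken from the thread)
seq_len = 4   # number of time steps in the sequence
obs_dim = 8   # size of a single flat observation

# Four consecutive observations, e.g. from successive environment steps.
# Each one is filled with its time index so the ordering is easy to check.
observations = [np.full(obs_dim, t, dtype=np.float32) for t in range(seq_len)]

# Stack along a new leading time axis: (SEQ_LEN, OBS_DIM)
sequence = np.stack(observations, axis=0)

# Add a batch axis in front: (BATCH_SIZE=1, SEQ_LEN, OBS_DIM)
batch = sequence[np.newaxis, ...]

print(batch.shape)
```

For observations with more than one dimension (images, for example), the same stacking applies and the extra observation axes simply follow the time axis.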
