Importing an LSTM into MATLAB

Hi,

I am training an LSTM for RL using Ray (RLlib) in Python. I would like to export the trained model and then import it into MATLAB, for which I am using ONNX. As far as I understand, the network has to be initialized in MATLAB after importing. However, I cannot work out the correct input shapes/formats in MATLAB to make this work.

Minimum working example:

Python code to train LSTM:

import torch
import numpy as np
from ray.rllib.algorithms.ppo import PPOConfig

# Configure the algorithm
algo = (
    PPOConfig()
    .env_runners(num_env_runners=1)
    .resources(num_gpus=0)
    .environment(env="CartPole-v1")
    .training(model={"use_lstm": True})
    .build()
)

# train for 2 iterations (each algo.train() call runs one training iteration)

for i in range(2):
    result = algo.train()

# get policy
ppo_policy = algo.get_policy()

# batch size
B=1

# initialize LSTM input:
input_dict = {"obs": torch.tensor(np.random.uniform(0, 1.0, size=(B,4)).astype(np.float32))}
state_batches = [torch.zeros((B,256), dtype=torch.float32),torch.zeros((B,256), dtype=torch.float32)]
seq_lens = torch.ones([B], dtype=int)

# apply LSTM to inputs (reuse the policy fetched above)
model = ppo_policy.model
print(model(input_dict, state=state_batches, seq_lens=seq_lens))

# save model to ONNX
ppo_policy.export_model('onnx14', onnx=14)
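
To double-check which inputs the exported graph actually declares, something like the following sketch could be run on the Python side before moving to MATLAB. It assumes the `onnx` package is installed and that the export wrote a file named `model.onnx` into the `onnx14` directory (the file name is an assumption on my part). `format_shape` and `onnx_input_shapes` are hypothetical helper names, not part of Ray or ONNX.

```python
import os


def format_shape(dims):
    """Render a list of dimensions; symbolic dims keep their name, unknown dims become '?'."""
    parts = []
    for d in dims:
        if isinstance(d, int) and d > 0:
            parts.append(str(d))   # fixed size
        elif isinstance(d, str) and d:
            parts.append(d)        # symbolic (dynamic) dimension name
        else:
            parts.append("?")      # no size information in the graph
    return "[" + ", ".join(parts) + "]"


def onnx_input_shapes(path):
    """Return {input name: list of dims} for the real graph inputs of an ONNX file."""
    import onnx  # imported lazily so format_shape works without onnx installed

    model = onnx.load(path)
    # Older exports may list weights as graph inputs too; skip those.
    initializer_names = {init.name for init in model.graph.initializer}
    shapes = {}
    for inp in model.graph.input:
        if inp.name in initializer_names:
            continue
        dims = []
        for d in inp.type.tensor_type.shape.dim:
            if d.HasField("dim_value"):
                dims.append(d.dim_value)
            elif d.HasField("dim_param"):
                dims.append(d.dim_param)
            else:
                dims.append(None)
        shapes[inp.name] = dims
    return shapes


if __name__ == "__main__":
    path = os.path.join("onnx14", "model.onnx")  # assumed export location
    if os.path.exists(path):
        for name, dims in onnx_input_shapes(path).items():
            print(name, format_shape(dims))
```

The printed shapes are in the Python (batch-first) ordering, so they would still need to be mapped to whatever dimension labels the imported MATLAB network expects.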

MATLAB code (I inferred the shapes from the documentation comments within the MATLAB code generated when importing the ONNX model; I have also tried other shapes):

% Import model from where I saved it

net = importNetworkFromONNX('path/to/onnx-model');

B=1;
input_size = [B,4];
state_size = [B,256];
seq_lens_size = [B,256];

input = dlarray(randn(input_size), 'CB');
state = dlarray(randn(state_size), 'CB');
seq_lens = dlarray(randn(seq_lens_size), 'CB');

initialize(net, input, state, seq_lens);

Error message:

Any help welcome!

Best,
Andreas