Issue with Custom PyTorch Model in Ray RLlib

  • High: It blocks me from completing my task.
    Hi,

I’m currently working on a custom neural network model for reinforcement learning using Ray RLlib with PyTorch. I’m facing an issue where my torch_dqn_model is not receiving the configuration parameters correctly, and I’m seeking assistance in troubleshooting this problem.

This is my neural network:

from ray.rllib.utils.typing import ModelConfigDict
import torch
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from nn.models.structure2vec_model import Structure2VecModel as EmbedModel
from nn.models.action_eval_model import ActionEvalModel as EvalModel
import torch.nn.functional as F
from ray.rllib.utils.typing import TensorStructType

class MySequential(TorchModelV2, nn.Module):
    """Custom RLlib Torch model: embed the observation, then evaluate it.

    The model is composed of two sub-models, an embedding model
    (``EmbedModel``) and an evaluation model (``EvalModel``). Their sizes
    are taken from ``model_config["custom_model_config"]``, which must
    contain the keys ``"hidden-layer-size"`` and ``"embed_dim"``.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kwargs):
        """Build the embedding and evaluation sub-models.

        Args:
            obs_space: Observation space, forwarded to ``TorchModelV2``.
            action_space: Action space, forwarded to ``TorchModelV2``.
            num_outputs: Output size, forwarded to ``TorchModelV2``.
            model_config: RLlib model config dict. It is updated in place
                with ``"fcnet_hiddens"`` and ``"embed_dim"`` derived from
                its ``"custom_model_config"`` entry.
            name: Model name, forwarded to ``TorchModelV2``.
            **kwargs: Accepted for compatibility; not used here.

        Raises:
            KeyError: If ``"custom_model_config"``, ``"hidden-layer-size"``
                or ``"embed_dim"`` is missing.
        """
        config = model_config["custom_model_config"]
        hidden_layer_size = config["hidden-layer-size"]
        embed_dim = config["embed_dim"]
        model_config["fcnet_hiddens"] = [hidden_layer_size]
        model_config["embed_dim"] = embed_dim

        # Pass the constructor argument: ``self.num_outputs`` does not exist
        # until the base-class constructor has run.
        TorchModelV2.__init__(
            self,
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=num_outputs,
            model_config=model_config,
            name=name,
        )
        # Must run before any sub-module is assigned as an attribute.
        nn.Module.__init__(self)

        # Size the sub-models from the custom config instead of fixed
        # constants, so the configured values actually take effect.
        # TODO: change seed
        self.embed_model = EmbedModel(
            num_layers=2,
            embed_dim=embed_dim,
            device=0,
            activation=F.relu,
            seed=11,
        )
        self.eval_model = EvalModel(
            embed_dim=embed_dim,
            hidden_layer_size=hidden_layer_size,
            device=0,
            activation=F.relu,
            seed=11,
        )

    def forward(self, input_dict, state, seq_lens):
        """Run the embedding model followed by the evaluation model.

        Args:
            input_dict: Input batch; only ``input_dict["obs"]`` is read.
            state: Passed through unchanged.
            seq_lens: Unused.

        Returns:
            A tuple ``(evaluation, state)`` where ``evaluation`` is the
            output of the evaluation model.
        """
        # NOTE(review): ``convertInput`` is not defined in this snippet; it
        # is presumably provided elsewhere and returns an indexable pair --
        # confirm before relying on this model.
        converted = self.convertInput(input_dict['obs'])
        embedding = self.embed_model(converted[1], converted[0])
        evaluation = self.eval_model(embedding)
        return evaluation, state

My config is:

# Extend the existing algorithm config with training, evaluation, debugging
# and framework settings. String keys use plain ASCII quotes; typographic
# quotes are not valid Python string delimiters.
run_config = (
    run_config
    .training(
        lr=args.lr,
        train_batch_size=1,
        model={
            "custom_model": custom_model,
            "custom_model_config": custom_model_config,
        },
    )
    .evaluation(
        evaluation_config={"explore": True},
        evaluation_interval=1,
        evaluation_duration=100,
    )
    .debugging(log_level=args.log_level)
    # NOTE(review): only ``eager_tracing`` is set here and no framework is
    # named. The custom model is a Torch model, so confirm that
    # ``framework="torch"`` is selected elsewhere on this config.
    .framework(eager_tracing=True)
)