- High: It blocks me from completing my task.
Hi,
I’m currently working on a custom neural network model for reinforcement learning using Ray RLlib with PyTorch. I’m facing an issue where my torch_dqn_model
is not receiving the configuration parameters correctly, and I’m seeking assistance in troubleshooting this problem.
This is my neural network:
from ray.rllib.utils.typing import ModelConfigDict
import torch
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from nn.models.structure2vec_model import Structure2VecModel as EmbedModel
from nn.models.action_eval_model import ActionEvalModel as EvalModel
import torch.nn.functional as F
from ray.rllib.utils.typing import TensorStructType
class MySequential(TorchModelV2, nn.Module):
    """Custom RLlib PyTorch model that chains an embedding model and an
    action-evaluation model.

    The observation is converted with ``self.convertInput``, passed through
    ``EmbedModel`` and the resulting embedding is scored by ``EvalModel``.

    The model reads two required keys from
    ``model_config["custom_model_config"]``:

    * ``"hidden-layer-size"`` -- hidden layer size of the evaluation model
      (also mirrored into ``model_config["fcnet_hiddens"]``).
    * ``"embed_dim"`` -- embedding dimension shared by both sub-models
      (also mirrored into ``model_config["embed_dim"]``).

    Args:
        obs_space: Observation space handed in by RLlib.
        action_space: Action space handed in by RLlib.
        num_outputs: Output size requested by RLlib.
        model_config: RLlib model config dict; must contain
            ``"custom_model_config"`` with the keys listed above.
        name: Model name handed in by RLlib.
        **kwargs: Extra keyword arguments; accepted and ignored.

    Raises:
        KeyError: If ``"custom_model_config"``, ``"hidden-layer-size"`` or
            ``"embed_dim"`` is missing.
    """

    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        **kwargs,
    ):
        config = model_config["custom_model_config"]
        hidden_layer_size = config["hidden-layer-size"]
        embed_dim = config["embed_dim"]

        model_config["fcnet_hiddens"] = [hidden_layer_size]
        model_config["embed_dim"] = embed_dim

        # Forward the constructor argument: ``self.num_outputs`` does not
        # exist yet at this point, so reading it here raised AttributeError.
        TorchModelV2.__init__(
            self,
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=num_outputs,
            model_config=model_config,
            name=name,
        )
        # Must run before sub-modules are assigned so that PyTorch can
        # register them (parameters(), state_dict(), .to(device), ...).
        nn.Module.__init__(self)

        # Sizes come from the custom model config instead of hard-coded
        # literals, so values supplied through the RLlib config take effect.
        # NOTE(review): device=0 and seed=11 are kept as in the original
        # ("change seed" was an open reminder there).
        self.embed_model = EmbedModel(
            num_layers=2,
            embed_dim=embed_dim,
            device=0,
            activation=F.relu,
            seed=11,
        )
        self.eval_model = EvalModel(
            embed_dim=embed_dim,
            hidden_layer_size=hidden_layer_size,
            device=0,
            activation=F.relu,
            seed=11,
        )

    def forward(self, input_dict, state, seq_lens):
        """Run the embedding and evaluation models on ``input_dict["obs"]``.

        Args:
            input_dict: RLlib input dict; only the ``"obs"`` entry is read.
            state: Recurrent state; returned unchanged.
            seq_lens: Unused.

        Returns:
            A tuple ``(model_output, state)``.
        """
        # NOTE(review): kept from the original -- overriding num_outputs on
        # every forward pass looks unintended; confirm whether the value
        # passed to __init__ should be used instead.
        self.num_outputs = 4

        # NOTE(review): ``convertInput`` is not defined in the visible class;
        # it must be provided elsewhere. The original comment described the
        # inputs as "1*1, 3*3" -- TODO confirm the expected shapes.
        converted = self.convertInput(input_dict["obs"])
        embedding = self.embed_model(converted[1], converted[0])
        model_out = self.eval_model(embedding)
        return model_out, state
My config is:
# Extend the existing algorithm config with training, evaluation, debugging
# and framework settings. Dictionary keys use plain ASCII quotes: the
# typographic quotes from the original paste are a SyntaxError in Python.
run_config = (
    run_config
    .training(
        lr=args.lr,
        train_batch_size=1,
        model={
            "custom_model": custom_model,
            "custom_model_config": custom_model_config,
        },
    )
    .evaluation(
        evaluation_config={"explore": True},
        evaluation_interval=1,
        evaluation_duration=100,
    )
    .debugging(log_level=args.log_level)
    # NOTE(review): eager_tracing is presumably a TF2-only option; for a
    # PyTorch custom model confirm that framework="torch" is selected on
    # run_config before this point.
    .framework(eager_tracing=True)
)