How severely does this issue affect your experience of using Ray?
- High: It blocks me from completing my task.
Hello,
I am implementing a graph convolutional network (GCN) as a custom model to estimate the Q-values with DQN. Essentially, my GCN has 3 convolutional layers and a readout layer, followed by a fully connected network that estimates the Q-values associated with each action.
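For context, this is roughly how the model is wired. The skeleton below is a simplified placeholder (a plain Linear stack standing in for the heterogeneous GCN layers and readout), not my actual code:

import numpy as np
import torch
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2


class HetGCN(TorchModelV2, nn.Module):
    """Simplified skeleton of my custom Q-model (placeholder layers, not the real GCN)."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name, **kwargs):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        # params_s / params_t arrive via model_config["custom_model_config"]
        custom_cfg = model_config.get("custom_model_config", {})
        hidden = 64  # placeholder width; the real layers are built from custom_cfg
        self.input_proj = nn.Linear(int(np.prod(obs_space.shape)), hidden)
        self.gcn = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(3)])
        self.q_head = nn.Linear(hidden, action_space.n)  # one Q-value per discrete action

    def forward(self, input_dict, state, seq_lens):
        # placeholder: the real model builds graph tensors from the raw observation here
        x = torch.relu(self.input_proj(input_dict["obs"].float().flatten(1)))
        for layer in self.gcn:  # stands in for the 3 graph-convolution layers
            x = torch.relu(layer(x))
        q_values = self.q_head(x)  # readout + FC head -> shape (batch, n_actions)
        return q_values, state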
However, I am having trouble outputting the Q-values estimated by the final layer of the network, because RLlib apparently tries to pass them to another FC layer that expects 256 inputs:
(RolloutWorker pid=89182) File "/home/malin/anaconda3/envs/marl/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 114, in forward
(RolloutWorker pid=89182) return F.linear(input, self.weight, self.bias)
(RolloutWorker pid=89182) RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x5 and 256x256)
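From the shapes, it looks like my model's (batch, n_actions) output, here (32, 5), is being fed into a linear layer that expects 256 input features. A tiny standalone snippet reproduces the same error:

import torch
import torch.nn as nn

q_out = torch.randn(32, 5)       # what my model returns: (batch, n_actions)
extra_fc = nn.Linear(256, 256)   # the extra FC layer RLlib seems to insert after my model
extra_fc(q_out)                  # RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x5 and 256x256)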
Currently there are 2 agents in the simulation, and this is the algorithm configuration:
ModelCatalog.register_custom_model("GCN_V1_1", HetGCN)
ModelCatalog.register_custom_model("GCN_V1_2", HetGCN)
def gen_policy(i):
    config = {
        "model": {
            "custom_model": ["GCN_V1_1", "GCN_V1_2"][i % 2],
            "custom_model_config": {
                "params_s": params_s,
                "params_t": params_t,
            },
        },
        "no_final_linear": True,
        "gamma": [0.95, 0.99],
    }
    return PolicySpec(None, observation_space=obs_space, action_space=act_space,
                      config=config)

policies = {"policy_{}".format(i): gen_policy(i) for i in range(2)}
policy_ids = list(policies.keys())
env_name = "uam_v1"
register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator()))
stop = {
    "timesteps_total": 2000e3,
}
learning_rate_start = 2e-4
learning_rate_end = 1e-5
n_timesteps = 1000*args.num_episodes
config = {
    "env": env_name,
    "log_level": "DEBUG",
    "framework": "torch",
    "num_gpus": args.num_gpus,
    "seed": seed,
    "num_workers": args.num_workers,
    "num_envs_per_worker": 1,
    "batch_mode": "truncate_episodes",
    "_disable_preprocessor_api": True,
    "gamma": 0.99,
    "rollout_fragment_length": 512,
    "train_batch_size": 20,
    "lr": learning_rate_start,
    "model": {
        "custom_model": "GCN_V1_1",
        "custom_model_config": {
            "params_s": params_s,
            "params_t": params_t,
        },
    },
    "multiagent": {
        "policies": policies,
        "policy_mapping_fn": (lambda agent_id: policy_ids[agent_id]),
    },
}
results = tune.run(
    "DQN",
    name="DQN",
    stop=stop,
    checkpoint_freq=10,
    local_dir="~/ray_results/" + env_name,
    config=config,
)
In particular, I tried setting "no_final_linear": True, but it seems to have no effect. I can see that the Q-values are being computed by the GCN, but how can I stop RLlib from forwarding them to another FC network?
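In case the placement matters: in gen_policy above, "no_final_linear" sits at the top level of the per-policy config. I am not sure whether it is only read from inside the "model" dict, i.e. something like the variant below (same keys as above, just moved):

def gen_policy(i):
    config = {
        "model": {
            "custom_model": ["GCN_V1_1", "GCN_V1_2"][i % 2],
            "custom_model_config": {"params_s": params_s, "params_t": params_t},
            "no_final_linear": True,  # moved inside the model config; is this where it belongs?
        },
    }
    return PolicySpec(None, observation_space=obs_space, action_space=act_space,
                      config=config)

Thanks in advance!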