The context for this post is the same as this one: https://discuss.ray.io/t/registering-custom-env-that-passes-an-argument-in-1-13-0/8783. The code has been modified to look like the following:
```python
import argparse

import networkx as nx
import ray
from ray import tune
from ray.rllib.contrib.alpha_zero.models.custom_torch_models import DenseModel
from ray.rllib.models.catalog import ModelCatalog

# ind_set is my custom environment class, defined elsewhere in the project.


def env_creator(env_config):
    return ind_set(env_config)


tune.register_env("myenv", env_creator)
G = nx.dodecahedral_graph()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-workers", default=6, type=int)
    parser.add_argument("--training-iteration", default=100, type=int)  # change depending on graph size
    parser.add_argument("--ray-num-cpus", default=7, type=int)
    args = parser.parse_args()
    ray.init(num_cpus=args.ray_num_cpus)

    ModelCatalog.register_custom_model("dense_model", DenseModel)

    tune.run(
        "contrib/AlphaZero",
        stop={"training_iteration": args.training_iteration},
        max_failures=0,
        config={
            "env": "myenv",
            "env_config": {"graph": G},
            "num_workers": args.num_workers,
            "rollout_fragment_length": 10,
            "train_batch_size": 50,
            "sgd_minibatch_size": 8,
            "lr": 1e-4,
            "num_sgd_iter": 1,
            "mcts_config": {
                "puct_coefficient": 1.5,
                "num_simulations": 5,
                "temperature": 1.0,
                "dirichlet_epsilon": 0.20,
                "dirichlet_noise": 0.03,
                "argmax_tree_policy": False,
                "add_dirichlet_noise": True,
            },
            "ranked_rewards": {
                "enable": True,
            },
            "model": {
                "custom_model": "dense_model",
            },
        },
    )
```
Now I get the following error:

```
(RolloutWorker pid=8796) File "/home/IncompleteOmega/.local/lib/python3.8/site-packages/networkx/utils/decorators.py", line 86, in _not_implemented_for
(RolloutWorker pid=8796) dval is None or dval == g.is_directed()
(RolloutWorker pid=8796) AttributeError: 'EnvContext' object has no attribute 'is_directed'
```
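A likely reading of the traceback, assuming `ind_set.__init__` expects a graph as its first argument: RLlib calls the registered creator with an `EnvContext`, which is a `dict` subclass holding the contents of `env_config` plus worker metadata such as `worker_index`. `env_creator` forwards that whole object to `ind_set`, so the networkx function inside the environment receives the config dict instead of `G`, and its `g.is_directed()` check fails. The sketch below reproduces the same failure without Ray or networkx; `FakeEnvContext`, `check_undirected` and `IndSetEnv` are hypothetical stand-ins, not the real classes.

```python
class FakeEnvContext(dict):
    """Hypothetical stand-in for RLlib's EnvContext: the env_config dict plus worker metadata."""

    def __init__(self, env_config, worker_index=0):
        super().__init__(env_config)
        self.worker_index = worker_index


def check_undirected(g):
    """Mimics the g.is_directed() call made by networkx's not_implemented_for decorator."""
    if g.is_directed():
        raise TypeError("not implemented for directed type")
    return g


class IndSetEnv:
    """Hypothetical environment whose constructor expects a graph, as ind_set presumably does."""

    def __init__(self, graph):
        self.graph = check_undirected(graph)


def env_creator(env_config):
    # Same pattern as the original script: the whole config is passed on as if it were the graph.
    return IndSetEnv(env_config)


try:
    env_creator(FakeEnvContext({"graph": object()}))
except AttributeError as err:
    print(err)  # 'FakeEnvContext' object has no attribute 'is_directed'
```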
Note that I am using the networkx library (imported as `nx`). I suspect this is an issue with the way I have written the code and not with the networkx library, for two reasons:
- My networkx module is updated to the newest version.
- When I separately run `G = nx.dodecahedral_graph()` and then call `G.is_directed()`, the result is the expected `False`, so this part works correctly.
It would be great if someone could let me know how I can modify my code to get around this.
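A minimal sketch of one possible way around this, assuming `ind_set` expects the graph itself as its first constructor argument: pull the graph out of the config inside the creator, i.e. `return ind_set(env_config["graph"])`. The alternative is to leave `env_creator` as it is and have the environment's `__init__` read `env_config["graph"]` itself. Both variants are shown with hypothetical stand-ins (`FakeEnvContext`, `FakeGraph`, `IndSetEnv`, `IndSetEnvFromConfig`) so the snippet runs without Ray or networkx; in the real script the graph would be `nx.dodecahedral_graph()`.

```python
class FakeEnvContext(dict):
    """Hypothetical stand-in for RLlib's EnvContext (a dict subclass)."""

    def __init__(self, env_config, worker_index=0):
        super().__init__(env_config)
        self.worker_index = worker_index


class FakeGraph:
    """Hypothetical stand-in for an undirected networkx graph."""

    def is_directed(self):
        return False


class IndSetEnv:
    """Variant 1: the constructor takes the graph itself."""

    def __init__(self, graph):
        if graph.is_directed():
            raise TypeError("expected an undirected graph")
        self.graph = graph


def env_creator(env_config):
    # Unpack the graph from the config before building the environment.
    return IndSetEnv(env_config["graph"])


class IndSetEnvFromConfig(IndSetEnv):
    """Variant 2: the constructor takes the whole config and unpacks the graph itself."""

    def __init__(self, env_config):
        super().__init__(env_config["graph"])


G = FakeGraph()
ctx = FakeEnvContext({"graph": G})
env_a = env_creator(ctx)
env_b = IndSetEnvFromConfig(ctx)
print(env_a.graph is G, env_b.graph is G)  # True True
```

With the second variant, the original `env_creator` that passes `env_config` straight through would work unchanged, since the environment does the unpacking.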