Custom PyTorch policy network

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

Hi all,
I am trying to implement a custom PyTorch network for RLlib, but I get the following error:

2022-06-29 16:54:45,575	INFO services.py:1250 -- View the Ray dashboard at http://127.0.0.1:8265
:actor_name:RolloutWorker
/home/rdbt/ETHZ/dbt_python/housing_design_making-general-env/agents_floorplan/simple_rllib_agent.py:36: DeprecationWarning: `np.float` is a deprecated alias for the builtin `float`. To silence this warning, use `float` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.float64` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  dtype=np.float)
/home/rdbt/ETHZ/dbt_python/housing_design_making-general-env/agents_floorplan/simple_rllib_agent.py:36: DeprecationWarning: `np.float` is a deprecated alias for the builtin `float`. To silence this warning, use `float` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.float64` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  dtype=np.float)
[2022-06-29 16:54:47,098 E 64123 64123] core_worker.cc:1561: Pushed Error with JobID: 01000000 of type: task with message: ray::RolloutWorker.__init__() (pid=64123, ip=129.132.204.244, repr=<ray.rllib.evaluation.rollout_worker.modify_class.<locals>.Class object at 0x7fae8447deb0>)
TypeError: get_distribution_inputs_and_class() missing 1 required positional argument: 'obs_batch'

During handling of the above exception, another exception occurred:

ray::RolloutWorker.__init__() (pid=64123, ip=129.132.204.244, repr=<ray.rllib.evaluation.rollout_worker.modify_class.<locals>.Class object at 0x7fae8447deb0>)
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 583, in __init__
    self._build_policy_map(
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1382, in _build_policy_map
    self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/policy/policy_map.py", line 143, in create_policy
    self[policy_id] = class_(observation_space, action_space,
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/policy/policy_template.py", line 279, in __init__
    self._initialize_loss_from_dummy_batch(
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/policy/policy.py", line 751, in _initialize_loss_from_dummy_batch
    self.compute_actions_from_input_dict(
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 299, in compute_actions_from_input_dict
    return self._compute_action_helper(input_dict, state_batches,
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/utils/threading.py", line 21, in wrapper
    return func(self, *a, **k)
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 352, in _compute_action_helper
    self.action_distribution_fn(
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/agents/dqn/dqn_torch_policy.py", line 212, in get_distribution_inputs_and_class
    q_vals = compute_q_values(
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/agents/dqn/dqn_torch_policy.py", line 344, in compute_q_values
    model_out, state = model(input_dict, state_batches or [], seq_lens)
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 243, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "/home/rdbt/ETHZ/dbt_python/housing_design_making-general-env/agents_floorplan/simple_rllib_agent.py", line 137, in forward
    network_output = self.network(obs_transformed)
  File "/home/rdbt/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/rdbt/.local/lib/python3.8/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/rdbt/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/rdbt/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 446, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/rdbt/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 442, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: expected scalar type Byte but found Float at time: 1.65651e+09
2022-06-29 16:54:47,100	ERROR actor.py:750 -- Exception raised in creation task: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=64123, ip=129.132.204.244)
TypeError: get_distribution_inputs_and_class() missing 1 required positional argument: 'obs_batch'

During handling of the above exception, another exception occurred:

ray::RolloutWorker.__init__() (pid=64123, ip=129.132.204.244)
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 583, in __init__
    self._build_policy_map(
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1382, in _build_policy_map
    self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/policy/policy_map.py", line 143, in create_policy
    self[policy_id] = class_(observation_space, action_space,
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/policy/policy_template.py", line 279, in __init__
    self._initialize_loss_from_dummy_batch(
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/policy/policy.py", line 751, in _initialize_loss_from_dummy_batch
    self.compute_actions_from_input_dict(
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 299, in compute_actions_from_input_dict
    return self._compute_action_helper(input_dict, state_batches,
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/utils/threading.py", line 21, in wrapper
    return func(self, *a, **k)
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 352, in _compute_action_helper
    self.action_distribution_fn(
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/agents/dqn/dqn_torch_policy.py", line 212, in get_distribution_inputs_and_class
    q_vals = compute_q_values(
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/agents/dqn/dqn_torch_policy.py", line 344, in compute_q_values
    model_out, state = model(input_dict, state_batches or [], seq_lens)
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 243, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "/home/rdbt/ETHZ/dbt_python/housing_design_making-general-env/agents_floorplan/simple_rllib_agent.py", line 137, in forward
    network_output = self.network(obs_transformed)
  File "/home/rdbt/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/rdbt/.local/lib/python3.8/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/rdbt/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/rdbt/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 446, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/rdbt/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 442, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: expected scalar type Byte but found Float
[2022-06-29 16:54:47,105 E 64123 64123] core_worker.cc:1561: Pushed Error with JobID: 01000000 of type: task with message: ray::RolloutWorker.foreach_env()::Exiting (pid=64123, ip=129.132.204.244, repr=<ray.rllib.evaluation.rollout_worker.modify_class.<locals>.Class object at 0x7fae8447deb0>)
  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1009, in foreach_env
    if self.async_env is None:
AttributeError: 'RolloutWorker' object has no attribute 'async_env' at time: 1.65651e+09
Traceback (most recent call last):

  File "/home/rdbt/ETHZ/dbt_python/housing_design_making-general-env/agents_floorplan/simple_rllib_agent.py", line 188, in <module>
    trainer = dqn.DQNTrainer(config=config, env=env_name)  # this is for only train

  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/agents/trainer_template.py", line 137, in __init__
    Trainer.__init__(self, config, env, logger_creator)

  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 611, in __init__
    super().__init__(config, logger_creator)

  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/tune/trainable.py", line 106, in __init__
    self.setup(copy.deepcopy(self.config))

  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/agents/trainer_template.py", line 147, in setup
    super().setup(config)

  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 764, in setup
    self._init(self.config, self.env_creator)

  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/agents/trainer_template.py", line 171, in _init
    self.workers = self._make_workers(

  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 846, in _make_workers
    return WorkerSet(

  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/evaluation/worker_set.py", line 103, in __init__
    self._local_worker = self._make_worker(

  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/evaluation/worker_set.py", line 399, in _make_worker
    worker = cls(

  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 538, in __init__
    policy_dict = _determine_spaces_for_multi_agent_dict(

  File "/home/rdbt/anaconda3/envs/rlb/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1467, in _determine_spaces_for_multi_agent_dict
    raise ValueError(

ValueError: `observation_space` not provided in PolicySpec for default_policy and env does not have an observation space OR no spaces received from other workers' env(s) OR no `observation_space` specified in config!

I have provided a simple, runnable script below. Could anyone help me figure out what the issue is?

# %% Imports
import os

import gym
import numpy as np

import ray
from ray import tune
from ray.rllib.agents import ppo, dqn
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.agents.dqn.dqn import DQNTrainer
from ray.tune.logger import pretty_print

from abc import ABC
from torch import nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models import ModelCatalog


# %% Env class
class SimpleEnv(gym.Env):
    def __init__(self, env_config={'env_name': 'simple_cnn_env'}):
        self.env_name = env_config['env_name']
        self.action_space = gym.spaces.Discrete(2)
        observation_space_fc = gym.spaces.Box(0.0*np.ones(2, dtype=float),
                                              1.0*np.ones(2, dtype=float),
                                              shape=(2,), 
                                              dtype=np.float)
        
        observation_space_cnn = gym.spaces.Box(low=0, 
                                               high=255,
                                               shape=(21, 21, 3),
                                               dtype=np.uint8)
        if self.env_name == 'simple_fc_env':
            self.observation_space = observation_space_fc
        elif self.env_name == 'simple_cnn_env':
            self.observation_space = observation_space_cnn
        elif self.env_name == 'simple_fccnn_env':
            self.observation_space = gym.spaces.Tuple((observation_space_fc, 
                                                observation_space_cnn))
        self.initial_observation = self.reset()
        

    def reset(self):
        if self.env_name == 'simple_fc_env':
            observation = np.array([0, 1])
        elif self.env_name == 'simple_cnn_env':
            observation = np.zeros((21, 21, 3), dtype=np.uint8)
        elif self.env_name == 'simple_fccnn_env':
            observation = (np.array([0, 1]), np.zeros((21, 21, 3), dtype=np.uint8))
        self.timestep = 0
        return observation


    def _get_action(self, obs):
        return np.random.randint(self.action_space.n)
    
    
    def _take_action(self, action):
        next_observation = self.initial_observation
        done = False if self.timestep <=3 else True
        reward = 1 if done else 0
        return next_observation, reward, done
        

    def step(self, action):
        self.timestep += 1
        observation, reward, done = self._take_action(action)
        return observation, reward, done, {}        


    def seed(self, seed: int = None):
        self.np_random, seed = gym.utils.seeding.np_random(seed)
        return [seed]
    
# %% Env registration
def _register_env(env_name):
    if 'simple' in env_name:
        # from simple_env import SimpleEnv
        
        env_config = {'env_name': env_name}
        env = SimpleEnv(env_config)
        
        ray.tune.register_env(env_name, lambda config: SimpleEnv(env_config))
        
    elif env_name == 'master_env':
        raise NotImplementedError


# %% Network
class MySimpleCnn(TorchModelV2, nn.Module, ABC):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        self._in_channels = obs_space.shape[2]
        self._num_actions = num_outputs

        self.network = nn.Sequential(
            nn.Conv2d(self._in_channels, 16, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 5, stride=2, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 256, 11, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(256*11*11, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU())

        self._actor_head = nn.Sequential(
            nn.Linear(in_features=256, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=self._num_actions))

        self._critic_head = nn.Sequential(
            nn.Linear(in_features=256, out_features=1))

    def forward(self, input_dict, state, seq_lens):
        obs_transformed = input_dict['obs'].permute(0, 3, 1, 2)
        network_output = self.network(obs_transformed)
        value = self._critic_head(network_output)
        self._value = value.reshape(-1)
        logits = self._actor_head(network_output)
        return logits, state

    def value_function(self):
        return self._value


# %% Main
if __name__ == "__main__":
    ray.init(local_mode=True)
    
    env_name = 'simple_cnn_env' # master_env, simple_fc_env simple_cnn_env simple_fccnn_env
    _register_env(env_name)
    
    agent_name = 'dqn'
    learner_name = 'trainer' # trainer tunner 
    custom_model_flag = True
    
    _config = {
            "env": env_name,  # or "corridor" if registered above
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
            "num_workers": 1,  # parallelism
            "framework": 'torch',
            "lr":1e-3,
            "model": {"dim": 21,
                      "conv_filters": [[16, [3, 3], 1],
                                      [32, [5, 5], 2],
                                      [512, [11, 11], 1]],
                }
        }
    
    if custom_model_flag:
        ModelCatalog.register_custom_model("MySimpleCnn", MySimpleCnn)
        _config.update({
            "model": {
                "custom_model": "MySimpleCnn",
                "vf_share_layers": True,
            },
        })
    
    if agent_name == 'ppo':
        config = ppo.DEFAULT_CONFIG.copy()
        config.update(_config)
        trainer = ppo.PPOTrainer(config=config, env=env_name) # this is for only train
        agent = PPOTrainer # this is for only tune
    elif agent_name == 'dqn':
        config = dqn.DEFAULT_CONFIG.copy()
        config.update(_config)
        trainer = dqn.DQNTrainer(config=config, env=env_name)  # this is for only train
        agent = DQNTrainer  # this is for only tune
    
    
    ## training/tuning
    stop = {"training_iteration": 2,
            "timesteps_total": 1000,
            "episode_reward_mean": 6}
    
    simple_storage_dir = "storage/simples/simple"
    
    if learner_name == 'trainer':
        for _ in range(stop['training_iteration']):
            result = trainer.train()
            print(pretty_print(result))
            if result["timesteps_total"] >= stop['timesteps_total'] or \
                    result["episode_reward_mean"] >= stop['episode_reward_mean']:
                break
            
    elif learner_name == 'tunner':
        results = tune.run(agent, 
                           local_dir=simple_storage_dir,
                           config=config, 
                           stop=stop,
                           checkpoint_freq=10,
                           checkpoint_at_end=True,
                           )

    ray.shutdown()

I am aware that RLlib has a built-in vision model that could be used here, but the reason I am writing my own CNN is that I plan to build a more complex model later on.
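(For completeness, this is roughly what relying on RLlib's built-in vision network would look like instead of the custom model, reusing the conv_filters entry already present in my config above; an untested sketch:)

# Sketch: with no custom_model registered, RLlib builds its default vision network
# for the 21x21x3 image observation from the conv_filters spec below.
config = dqn.DEFAULT_CONFIG.copy()
config.update({
    "env": env_name,
    "framework": "torch",
    "model": {"dim": 21,
              "conv_filters": [[16, [3, 3], 1],
                               [32, [5, 5], 2],
                               [512, [11, 11], 1]]},
})
trainer = dqn.DQNTrainer(config=config, env=env_name)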

Thanks!

Thanks to the Reinforcement Learning with RLLib — Griddly 1.3.9 documentation, I finally fixed the bug. Here is the clean, working code in case you are interested:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Dec 23 13:32:45 2021

@author: Reza Kakooee
"""

# %% Imports
import os

import gym
import numpy as np

import ray
from ray import tune
from ray.rllib.agents import ppo, dqn
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.agents.dqn.dqn import DQNTrainer
from ray.tune.logger import pretty_print

from abc import ABC
from torch import nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models import ModelCatalog


# %% Env class
class SimpleEnv(gym.Env):
    def __init__(self, env_config={'env_name': 'simple_cnn_env'}):
        self.env_name = env_config['env_name']
        self.action_space = gym.spaces.Discrete(2)
        observation_space_fc = gym.spaces.Box(0.0*np.ones(2, dtype=float),
                                              1.0*np.ones(2, dtype=float),
                                              shape=(2,), 
                                              dtype=float)
        
        observation_space_cnn = gym.spaces.Box(low=0, 
                                               high=255,
                                               shape=(21, 21, 3),
                                               dtype=np.uint8)
        if self.env_name == 'simple_fc_env':
            self.observation_space = observation_space_fc
        elif self.env_name == 'simple_cnn_env':
            self.observation_space = observation_space_cnn
        elif self.env_name == 'simple_fccnn_env':
            self.observation_space = gym.spaces.Tuple((observation_space_fc, 
                                                observation_space_cnn))
        self.initial_observation = self.reset()
        

    def reset(self):
        if self.env_name == 'simple_fc_env':
            observation = np.array([0, 1])
        elif self.env_name == 'simple_cnn_env':
            observation = np.zeros((21, 21, 3), dtype=np.uint8)
        elif self.env_name == 'simple_fccnn_env':
            observation = (np.array([0, 1]), np.zeros((21, 21, 3), dtype=np.uint8))
        self.timestep = 0
        return observation


    def _get_action(self, obs):
        return np.random.randint(self.action_space.n)
    
    
    def _take_action(self, action):
        next_observation = self.initial_observation
        done = False if self.timestep <=3 else True
        reward = 1 if done else 0
        return next_observation, reward, done
        

    def step(self, action):
        self.timestep += 1
        observation, reward, done = self._take_action(action)
        return observation, reward, done, {}        


    def seed(self, seed: int = None):
        self.np_random, seed = gym.utils.seeding.np_random(seed)
        return [seed]
    
# %% Env registration
def _register_env(env_name):
    if 'simple' in env_name:
        # from simple_env import SimpleEnv
        
        env_config = {'env_name': env_name}
        env = SimpleEnv(env_config)
        
        ray.tune.register_env(env_name, lambda config: SimpleEnv(env_config))
        
    elif env_name == 'master_env':
        raise NotImplementedError


# %% Network
class SimpleConv(TorchModelV2, nn.Module, ABC):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        self._in_channels = obs_space.shape[2]
        self._num_actions = num_outputs  # action_space.n

        self.network = nn.Sequential(
            nn.Conv2d(self._in_channels, 16, 3),   # was: stride=1, padding=1
            nn.ReLU(),
            nn.Conv2d(16, 32, 5),                  # was: stride=2, padding=2
            nn.ReLU(),
            nn.Conv2d(32, 256, 11),                # was: stride=1
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(256*5*5, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU())

        self._actor_head = nn.Sequential(
            nn.Linear(in_features=256, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=self._num_actions))

        self._critic_head = nn.Sequential(
            nn.Linear(in_features=256, out_features=1))

    def forward(self, input_dict, state, seq_lens):
        obs_transformed = input_dict['obs'].permute(0, 3, 1, 2)
        network_output = self.network(obs_transformed.float())
        value = self._critic_head(network_output)
        self._value = value.reshape(-1)
        logits = self._actor_head(network_output)
        return logits, state

    def value_function(self):
        return self._value


# %% Main
if __name__ == "__main__":
    ray.init(local_mode=True)
    
    env_name = 'simple_cnn_env' # master_env, simple_fc_env simple_cnn_env simple_fccnn_env
    _register_env(env_name)
    
    agent_name = 'dqn'
    learner_name = 'trainer' # trainer tunner 
    custom_model_flag = True
    
    _config = {
            "env": env_name,  # or "corridor" if registered above
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
            "num_workers": 1,  # parallelism
            "framework": 'torch',
            "lr":1e-3,
            "model": {"dim": 21,
                      "conv_filters": [[16, [3, 3], 1],
                                      [32, [5, 5], 2],
                                      [512, [11, 11], 1]],
                }
        }
    
    if custom_model_flag:
        ModelCatalog.register_custom_model("SimpleConv", SimpleConv)
        _config.update({
            "model": {
                "custom_model": "SimpleConv",
                "vf_share_layers": True,
            },
        })
    
    if agent_name == 'ppo':
        config = ppo.DEFAULT_CONFIG.copy()
        config.update(_config)
        trainer = ppo.PPOTrainer(config=config, env=env_name) # this is for only train
        agent = PPOTrainer # this is for only tune
    elif agent_name == 'dqn':
        config = dqn.DEFAULT_CONFIG.copy()
        config.update(_config)
        trainer = dqn.DQNTrainer(config=config, env=env_name)  # this is for only train
        agent = DQNTrainer  # this is for only tune
    
    
    ## training/tuning
    stop = {"training_iteration": 2,
            "timesteps_total": 1000,
            "episode_reward_mean": 6}
    
    simple_storage_dir = "storage/simples/simple"
    
    if learner_name == 'trainer':
        for _ in range(stop['training_iteration']):
            result = trainer.train()
            print(pretty_print(result))
            if result["timesteps_total"] >= stop['timesteps_total'] or \
                    result["episode_reward_mean"] >= stop['episode_reward_mean']:
                break
            
    elif learner_name == 'tunner':
        results = tune.run(agent, 
                           local_dir=simple_storage_dir,
                           config=config, 
                           stop=stop,
                           checkpoint_freq=10,
                           checkpoint_at_end=True,
                           )

    ray.shutdown()

@deepgravity can you shed some light on what your bug was?

Sure. I cast the obs_transformed tensor to float and adapted the network's linear-layer input size to match my observation dimensions.
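To make that concrete, here is a small standalone sketch (plain torch, no RLlib) that mimics the model's conv stack on a dummy 21x21x3 uint8 observation; it shows both where the .float() cast goes and where the 256*5*5 comes from:

import torch
from torch import nn

# Same conv stack as SimpleConv above (kernels 3, 5, 11, no padding) on a 21x21x3 uint8 obs.
# Spatial size: 21 -> 19 -> 15 -> 5, so the flattened conv output is 256 * 5 * 5.
network = nn.Sequential(
    nn.Conv2d(3, 16, 3),
    nn.ReLU(),
    nn.Conv2d(16, 32, 5),
    nn.ReLU(),
    nn.Conv2d(32, 256, 11),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(256 * 5 * 5, 256),
    nn.ReLU())

obs = torch.zeros((1, 21, 21, 3), dtype=torch.uint8)  # NHWC uint8, as the env returns it
obs_transformed = obs.permute(0, 3, 1, 2)             # NCHW, still uint8 (Byte)
out = network(obs_transformed.float())                # without .float(), conv2d raises the Byte/Float error
print(out.shape)                                      # torch.Size([1, 256])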