How severely does this issue affect your experience of using Ray?
High: It blocks me from completing my task.
Hello, I am trying to train a CRR model for an environment with a discrete observation_space and a discrete action_space.
There seems to be a mismatch between the input data from my offline dataset (SampleBatch) and the input layer of the CRR model. This is my code:
```python
import gymnasium as gym
from gymnasium.spaces import Discrete
import ray
from ray.rllib.algorithms.crr import CRRConfig

BATCH_SIZE = 8
# The failing matmul is (BATCH_SIZE x 1) @ (NB_DISCRETE x ACTOR_HIDDENS[0]),
# i.e. (8 x 1) @ (13 x 256) with the settings below.

ray.init()

config = CRRConfig()
config = config.offline_data(
    input_="offline/output-2023-07-13_15-11-26_worker-0_0.json"
)
config = config.environment(
    observation_space=Discrete(13),
    action_space=Discrete(12),
)
config = config.training(
    train_batch_size=BATCH_SIZE,
    actor_hiddens=[256, 256],
)

# Build the CRR Algorithm from the config.
algo = config.build()
for _ in range(10):
    algo.train()

ray.shutdown()
```
And this is the error I am getting:
```
RuntimeError: mat1 and mat2 shapes cannot be multiplied (8x1 and 13x256)
```
It seems that `Discrete(13)` creates an input layer that expects a vector of size 13 (I assume one-hot encoded), but the data in the SampleBatch does not go through one-hot encoding. I am not sure what format I should use in the SampleBatch to make it work. I read about preprocessors but could not make them work either.
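To make the shape mismatch concrete, here is a small NumPy sketch (plain NumPy, not RLlib code) of the two cases. A batch of 8 observations given as integer indices has shape (8, 1) and cannot be multiplied by a (13, 256) weight, whereas the same observations one-hot encoded to shape (8, 13) can. The random `weight` matrix is only an assumed stand-in for the first dense layer implied by `actor_hiddens=[256, 256]`, and `one_hot` is a hypothetical helper, not an RLlib function.

```python
import numpy as np

NB_DISCRETE = 13  # size of Discrete(13)
HIDDEN = 256      # first entry of actor_hiddens
BATCH_SIZE = 8


def one_hot(indices, n):
    """Hypothetical helper: map integer observations to one-hot rows."""
    indices = np.asarray(indices, dtype=np.int64).reshape(-1)
    out = np.zeros((indices.shape[0], n), dtype=np.float32)
    out[np.arange(indices.shape[0]), indices] = 1.0
    return out


rng = np.random.default_rng(0)
# Assumed stand-in for the first dense layer's weight: (NB_DISCRETE, HIDDEN).
weight = rng.standard_normal((NB_DISCRETE, HIDDEN)).astype(np.float32)

# Integer observations stored as a column: shape (BATCH_SIZE, 1).
int_obs = rng.integers(0, NB_DISCRETE, size=(BATCH_SIZE, 1))

try:
    int_obs @ weight  # (8, 1) @ (13, 256): inner dimensions 1 and 13 differ
except ValueError as err:
    print("integer obs fail:", err)

# One-hot observations have shape (BATCH_SIZE, NB_DISCRETE), so the product works.
hidden = one_hot(int_obs, NB_DISCRETE) @ weight
print("one-hot obs ->", hidden.shape)  # (8, 256)
```

NumPy raises `ValueError` where PyTorch raises the `RuntimeError` above, but the failing shapes are the same: the first operand must have 13 columns for a 13-input layer.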
Here is what the data in my SampleBatch JSON file looks like:
```json
{"type": "SampleBatch", "obs": [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], "new_obs": [[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], "actions": [[11]], "rewards": [1.0], "dones": [false], "agent_index": [0], "eps_id": ["P0"], "unroll_id": [0], "weights": [1.0]}
{"type": "SampleBatch", "obs": [[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], "new_obs": [[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], "actions": [[1]], "rewards": [1.0], "dones": [false], "agent_index": [0], "eps_id": ["P0"], "unroll_id": [1], "weights": [1.0]}
```
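Before changing the config, it may help to check what shape each column of the offline file actually has per row. The sketch below uses only the standard library; it assumes one SampleBatch JSON object per line (as above), and the helpers `row_shape`, `describe` and `is_one_hot` are hypothetical names, not RLlib APIs. It shows that the `obs` rows are already 13-dim one-hot vectors, while the nested `"actions": [[11]]` gives each action row a shape of (1,) rather than a scalar.

```python
import json

# The two SampleBatch lines from the offline file, copied verbatim.
SAMPLE_LINES = [
    '{"type": "SampleBatch", "obs": [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], "new_obs": [[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], "actions": [[11]], "rewards": [1.0], "dones": [false], "agent_index": [0], "eps_id": ["P0"], "unroll_id": [0], "weights": [1.0]}',
    '{"type": "SampleBatch", "obs": [[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], "new_obs": [[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], "actions": [[1]], "rewards": [1.0], "dones": [false], "agent_index": [0], "eps_id": ["P0"], "unroll_id": [1], "weights": [1.0]}',
]


def row_shape(value):
    """Nested-list shape of a single row, e.g. [1.0, 0.0] -> (2,), 1.0 -> ()."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0] if value else None
    return tuple(shape)


def describe(line):
    """Map each column to (number of rows, per-row shape of the first row)."""
    batch = json.loads(line)
    return {
        key: (len(column), row_shape(column[0]))
        for key, column in batch.items()
        if key != "type"
    }


def is_one_hot(row, n):
    """True if row has length n with exactly one 1.0 and n - 1 zeros."""
    return len(row) == n and row.count(1.0) == 1 and row.count(0.0) == n - 1


for line in SAMPLE_LINES:
    shapes = describe(line)
    batch = json.loads(line)
    # obs rows have shape (13,); action rows have shape (1,) because of [[11]].
    print("obs", shapes["obs"], "actions", shapes["actions"])
    print("obs one-hot:", all(is_one_hot(row, 13) for row in batch["obs"]))
```

Comparing these per-row shapes with the `(8x1 and 13x256)` in the error may help narrow down whether the fix belongs in the data file or in the spaces passed to `config.environment(...)`.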
I would appreciate some help with this.