Broadcast issue when using a custom env with a custom policy

Hi, I’m new to Ray and RLlib.

I’m trying to prototype a workflow where I use PPO to train a transformer-based embedding model. I keep running into the same issue: the build() method complains about some output not being broadcastable, and nothing I’ve tried to resolve it has worked.

Note that the mismatched dimension is always exactly half of the model’s embedding dimension. For instance, if I use a model with an embedding dim of 768, it complains about a mismatch between torch.Size([32, 768]) and torch.Size([32, 384]).

Here’s the stack trace:

    algo = config.build()
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/algorithms/algorithm_config.py", line 1071, in build
    return algo_class(
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/algorithms/algorithm.py", line 475, in __init__
    super().__init__(
  File "/usr/local/lib/python3.8/dist-packages/ray/tune/trainable/trainable.py", line 170, in __init__
    self.setup(copy.deepcopy(self.config))
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/algorithms/algorithm.py", line 601, in setup
    self.workers = WorkerSet(
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/evaluation/worker_set.py", line 172, in __init__
    self._setup(
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/evaluation/worker_set.py", line 262, in _setup
    self._local_worker = self._make_worker(
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/evaluation/worker_set.py", line 967, in _make_worker
    worker = cls(
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/evaluation/rollout_worker.py", line 738, in __init__
    self._update_policy_map(policy_dict=self.policy_dict)
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/evaluation/rollout_worker.py", line 1985, in _update_policy_map
    self._build_policy_map(
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/evaluation/rollout_worker.py", line 2097, in _build_policy_map
    new_policy = create_policy_for_framework(
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/utils/policy.py", line 142, in create_policy_for_framework
    return policy_class(observation_space, action_space, merged_config)
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py", line 64, in __init__
    self._initialize_loss_from_dummy_batch()
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/policy/policy.py", line 1489, in _initialize_loss_from_dummy_batch
    self.loss(self.model, self.dist_class, train_batch)
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py", line 112, in loss
    curr_action_dist.logp(train_batch[SampleBatch.ACTIONS])
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/models/torch/torch_action_dist.py", line 268, in logp
    return super().logp(actions).sum(-1)
  File "/usr/local/lib/python3.8/dist-packages/ray/rllib/models/torch/torch_action_dist.py", line 37, in logp
    return self.dist.log_prob(actions)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributions/normal.py", line 79, in log_prob
    self._validate_sample(value)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributions/distribution.py", line 288, in _validate_sample
    raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'.
ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([32, 384]) vs torch.Size([32, 192]).
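
For reference, the last frame can be reproduced in isolation with plain PyTorch, using the sizes from the error above (this is just an illustration of the shape mismatch, not RLlib code):

    import torch
    from torch.distributions import Normal

    # A Normal whose batch_shape is [32, 192], as in the error above.
    dist = Normal(torch.zeros(32, 192), torch.ones(32, 192), validate_args=True)

    # Evaluating log_prob on a [32, 384] actions tensor raises the same ValueError,
    # because 384 cannot broadcast against 192.
    dist.log_prob(torch.zeros(32, 384))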

Here’s the dummy prototype code that produces it:

import gymnasium as gym
from gymnasium import spaces
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from scipy.spatial.distance import cosine

import ray
from ray.rllib.algorithms.ppo import PPO
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.tune.registry import get_trainable_cls
from ray.rllib.utils.annotations import override
import os
from ray.rllib.utils.framework import try_import_torch
import pdb
torch, nn = try_import_torch()


embedding_model_name  = "sentence-transformers/all-MiniLM-L6-v2"
embedding_dim = 384


class CustomEnv(gym.Env):
    def __init__(self, env_config):
        self.tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
        self.model = AutoModel.from_pretrained(embedding_model_name)
        self.vocab_size = self.tokenizer.vocab_size
        self.max_length = 512
        self.observation_space = spaces.Box(low=0, high=self.vocab_size-1, shape=(self.max_length,), dtype=np.int32)
        self.action_space = spaces.Box(low=-1, high=1, shape=(embedding_dim,), dtype=np.float32)  # the action is an embedding of size embedding_dim (384)
        self.index = 0

        self.setup_batch_and_database()
        self.reset()


    def setup_batch_and_database(self):
        # Dummy batch of strings and a random embeddings database for prototyping
        batch = ["Hello world", "How are you doing today?", "I am doing well", "I am doing terrible", "I am doing great", "I am a fish", "I am a man", "How now brown cow"]
        dummy_embeddings_database = torch.randn((len(batch), embedding_dim))
        self.batch = batch
        self.embeddings_database = dummy_embeddings_database
            

    def reset(self,*, seed=None, options=None):
        if self.index >= len(self.batch):
            self.index = 0  # reset index if all strings have been used
        current_string = self.batch[self.index]
        tokenized_string = self.tokenizer(current_string, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt")["input_ids"]
        self.index += 1
        return (tokenized_string.numpy().reshape(-1).astype(np.int32), {})

    def step(self, action):
        # action is the embedding vector produced by your policy
        top_doc_embedding = self.search_database(action)  # replace with logic to search your database using the action embedding
        reward = self.calculate_reward(action, top_doc_embedding)  # replace with logic to compute reward
        done = truncated = True
        dummy_observation = np.ones(self.max_length, dtype=np.int32).reshape(-1)
        return dummy_observation, reward, done, truncated, {}

    def search_database(self, action):
        # Find the closest embedding in the database to the action embedding
        # let's just do a dot product  with the database for now
        action = torch.tensor(action)
        logits = torch.matmul(action, self.embeddings_database.T)  # embeddings_database is already a tensor
        top_doc_idx = torch.argmax(logits)
        top_doc_embedding = self.embeddings_database[top_doc_idx]
        return top_doc_embedding


    def calculate_reward(self, action, top_doc_embedding):
        #dummy reward function
        reward = -1 * cosine(action, top_doc_embedding)
        return reward


class TransformerPolicy(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)  # you should also initialize the nn.Module
        self.transformer = AutoModel.from_pretrained(embedding_model_name)
        self.value_layer = nn.Linear(embedding_dim, 1)

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"].long()
        print(f">>>>>> Shape of obs is {obs.shape}")
        transformer_outputs = self.transformer(obs)
        action_out = transformer_outputs.pooler_output
        self._value_out = self.value_layer(action_out)

        print(f">>>>>> Shape of action_out is {action_out.shape}")
        print(f">>>>>> Shape of value_out is {self._value_out.shape}")      
        # assert all elements in action_out are in the range [-1,1]
        assert torch.all(action_out <= 1) and torch.all(action_out >= -1)  
        return action_out, []
    
    def value_function(self):
        values = self._value_out.squeeze(1)
        return values
        

if __name__ == "__main__":
    ray.init()

    ModelCatalog.register_custom_model("transformer_policy", TransformerPolicy)

    config = (
        get_trainable_cls("PPO")
        .get_default_config()
        .environment(CustomEnv, env_config={"foo": 5})
        .framework("torch")
        .training(
            model={
                "custom_model": "transformer_policy",
            }
        )
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
        .rollouts(num_rollout_workers=0)
    )
    algo = config.build()

    ray.shutdown()

Hi @badwolf,

You are using PPO on a continuous action space. With the default DiagGaussian action distribution, if the model outputs a vector of size d, the first half is used as the means and the second half as the log standard deviations of the Normal distribution over your actions. That’s why the distribution is only 192-dimensional while your action space is 384-dimensional, hence the shape mismatch. To fix this, your custom model should output 2 × action_dim values, i.e. the num_outputs that RLlib passes to it (768 here).
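
For example, here is a rough sketch of one way to do that (untested; action_head is just an illustrative name, the rest of your model stays the same):

    class TransformerPolicy(TorchModelV2, nn.Module):
        def __init__(self, obs_space, action_space, num_outputs, model_config, name):
            TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
            nn.Module.__init__(self)
            self.transformer = AutoModel.from_pretrained(embedding_model_name)
            # RLlib passes num_outputs = 2 * action_dim (768 here) for the default
            # DiagGaussian distribution: one half means, one half log-stds.
            self.action_head = nn.Linear(embedding_dim, num_outputs)
            self.value_layer = nn.Linear(embedding_dim, 1)

        def forward(self, input_dict, state, seq_lens):
            obs = input_dict["obs"].long()
            pooled = self.transformer(obs).pooler_output   # [B, embedding_dim]
            self._value_out = self.value_layer(pooled)     # [B, 1]
            # Project to 2 * action_dim. The log-std half is unbounded, so the
            # [-1, 1] assert from your original forward() no longer applies.
            return self.action_head(pooled), []

        def value_function(self):
            return self._value_out.squeeze(1)

With that change the distribution gets 384 means and 384 log-stds, which matches your 384-dimensional action space.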


Thanks, that was it!