Custom Encoder-Decoder model yields AssertionError during policy initialisation

Hello,
I'm trying to implement an encoder-decoder model with multi-head attention (MHA) inside the forward() and value_function() of my custom model. When I test it with trainer = algo.build(), I get an error from the postprocessing function:

"ray\rllib\evaluation\postprocessing.py", line 195, in compute_gae_for_sample_batch [repeated 9x across cluster]
(PPO pid=3512) assert vf_preds.shape == rewards.shape [repeated 9x across cluster]
(PPO pid=3512) AssertionError [repeated 9x across cluster]

This is strange, because on the first run, when the model is initialized, all tensor dimensions match, and I never feed any reward into the model manually. The error only shows up after the forward() call, at the very end of the policy initialisation. What does this assert check, and why does it fail?
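From the traceback I gather that compute_gae_for_sample_batch expects the value predictions and the rewards to line up per timestep. My (possibly wrong) reading of the check, in plain numpy:

import numpy as np

# Simplified reading of the assert, not the actual RLlib code: vf_preds and
# rewards should both be 1-D arrays with one entry per timestep.
rewards = np.zeros(10)             # shape (10,): one reward per step
vf_preds_flat = np.zeros(10)       # shape (10,): what value_function() apparently has to return
vf_preds_mine = np.zeros((10, 1))  # shape (10, 1): extra trailing dim from my value head, I think

assert vf_preds_flat.shape == rewards.shape    # passes
# assert vf_preds_mine.shape == rewards.shape  # would fail with the AssertionError above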
Below is the code, and thank you in advance!

import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env

ray.init()
env = GraphEnvironment()
register_env("env_graph", lambda config: GraphEnvironment())
ModelCatalog.register_custom_model("EncoderDecoder", my_model)

algo = (
    PPOConfig()
    .rollouts(num_rollout_workers=0, batch_mode="complete_episodes", num_envs_per_worker=1)
    .framework("torch")
    .resources(num_gpus=0)
    .environment(env="env_graph")
    .training(
        train_batch_size=10,
        sgd_minibatch_size=1,
        model={
            "custom_model": "EncoderDecoder",
            "_disable_preprocessor_api": False,
            "_disable_action_flattening": True,
        },
    )
    .experimental(
        _enable_new_api_stack=False,
    )
)
trainer = algo.build()

And the model looks like this:

class my_model(TorchModelV2, nn.Module):
    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name):
   
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)

        self.conf = model_config["custom_model_config"]
        conf = self.conf

        self.state_buffer = []
        self.num_of_nodes = 11
        node_feature_count = 3
        hidden_size = 128
        context_size = 129
        self.batch_size = 10
        self.init_state = self.get_initial_state()

        self.lin_in = nn.Linear(node_feature_count, hidden_size)  # .to(self.device)
        self.graph_emb = embed_nets.Embedding(hidden_size)
        self.graph_enc = attent_nets.GraphAttentionEncoder()
        self.graph_dec = decoder_.Decoder()
        self.value_out = MLP(context_size, 1)

       
    @override(TorchModelV2)
    def forward(self,
                input_dict,
                state,
                seq_lens):
        
        # nodes, last_node and node_mask come from input_dict (extraction code omitted here)
        embeddings = self.graph_emb(nodes)

        # Select the embedding of the previously visited node: for each batch entry
        # gather one node embedding (shape [32, 1, 128] in my runs).
        last_node_ = last_node.clone().detach().requires_grad_(False).type(dtype=torch.int64)
        last_node_expanded = last_node_.unsqueeze(-1).expand(-1, -1, 128)
        selected_embeddings = torch.gather(embeddings, 1, last_node_expanded)

        # Concatenate the selected embedding with the state on the last dimension.
        state_tensor = state[-1][:, None, None]
        current_context = torch.cat((selected_embeddings, state_tensor), dim=-1)

        encoded_in, _ = self.graph_enc(embeddings)
        cached_embeddings = self.graph_dec.precompute(encoded_in)
        log_p, mask = self.graph_dec.decode(cached_embeddings, last_node, node_mask)

        # Compute the value from the current context.
        self._value = self.value_out(current_context)

        return log_p, state  # previous state



    @override(TorchModelV2)
    def value_function(self):
        assert self._value is not None, "Must call `forward()` first!"
        value = self._value.squeeze(-1)  # B x T x h_dim --mlp--> B x T X 1
        return value

After changing value_function() to return torch.reshape(self._value, [-1]), following this example:

fast_model_example

the previous error could be mitigated.
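The value head now returns a flat [B] tensor; roughly, the adjusted value_function() looks like this (a minimal sketch):

    @override(TorchModelV2)
    def value_function(self):
        assert self._value is not None, "Must call `forward()` first!"
        # Flatten whatever the value MLP produced (e.g. [B, 1, 1]) down to a
        # 1-D [B] tensor, as in the fast_model example.
        return torch.reshape(self._value, [-1])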

But now another RuntimeError appears:

ray\rllib\utils\policy.py", line 141, in create_policy_for_framework
    return policy_class(observation_space, action_space, merged_config)

ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 64, in __init__
    self._initialize_loss_from_dummy_batch()

\ray\rllib\policy\policy.py", line 1518, in _initialize_loss_from_dummy_batch
    self.loss(self.model, self.dist_class, train_batch)

ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 112, in loss
    curr_action_dist.logp(train_batch[SampleBatch.ACTIONS])

RuntimeError: The size of tensor a (10) must match the size of tensor b (11) at non-singleton dimension 1

Apparently the action shapes (tensor a (10) vs. tensor b (11)) don't match in the train_batch during policy initialisation, here:

logp_ratio = torch.exp(curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) - train_batch[SampleBatch.ACTION_LOGP])

I assume that the 'actions' and 'prev_actions' in the input_dict passed to the custom model cause the error in the policy, although the logits returned by forward() have what look like correct shapes:

def forward(self, input_dict, state, seq_lens):
    ...... my code for the model here ......
    return logits, state

logits shape:       torch.Size([10, 1, 11])
action mask shape:  torch.Size([10, 11])
prev_actions:       tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), shape: torch.Size([10])
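My current guess is that the extra singleton dimension in the logits is what trips up the action distribution. If that is right, collapsing it before returning should give the [batch, num_outputs] shape it expects; this is only what I would try, not a confirmed fix:

    def forward(self, input_dict, state, seq_lens):
        ...
        log_p, mask = self.graph_dec.decode(cached_embeddings, last_node, node_mask)
        self._value = self.value_out(current_context)
        # Drop the singleton decoder-step dimension so the (Categorical) action
        # distribution sees [10, 11] instead of [10, 1, 11].
        logits = log_p.squeeze(1)  # [10, 1, 11] -> [10, 11]
        return logits, state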

My question is: what exactly does RLlib put into the input_dict passed to forward(), and why does this dict differ between calls? Sometimes the input_dict contains

input_dict: SampleBatch(8 (seqs=4): ['obs', 'new_obs', 'actions', 'prev_actions', 'rewards', 'prev_rewards', 'terminateds', 'truncateds', 'infos', 'eps_id', 'unroll_id', 'agent_index', 't', 'state_in_0', 'state_out_0', 'vf_preds', 'action_dist_inputs', 'action_prob', 'action_logp', 'values_bootstrapped', 'advantages', 'value_targets', 'obs_flat'])

and sometimes it yields only

input_dict: SampleBatch(1 (seqs=1): ['obs', 'prev_actions', 'state_in_0', 'obs_flat'])

where prev_actions has shape 10, while my environment's action space is

self.action_space = Discrete(11)
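Since the batch dimension RLlib feeds into forward() clearly varies (dummy init batch, rollout batch, SGD minibatch), I am considering reading it from the observation instead of hard-coding self.batch_size = 10. A hypothetical helper ('obs_flat' being the flattened observation that shows up in my input_dict listings above):

def _current_batch_size(input_dict):
    # Hypothetical helper: derive the batch dimension from the flattened observation
    # rather than assuming a fixed batch size of 10.
    return input_dict["obs_flat"].shape[0]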

UPDATE:

After disabling the manual batch-size settings

config = {
    "train_batch_size": 20,
    "sgd_minibatch_size": 1,
}
the error could be fixed. But it is still unclear to me why manually defining the batch size is so tricky…