How does the rollout worker pass the `train_batch` to the loss function?

I'm implementing a centralized-critic method in my environment. In the postprocessing function, I collect the other agents' observations and concatenate them into a `neighbor_obs` column of the train batch with shape `(?, 144)`. However, when the batch is passed to `loss_fn`, `train_batch["neighbor_obs"]` has shape `(?,)`, which causes the following error:

Full traceback:
(pid=3113684) ray::RolloutWorker.__init__() (pid=3113684, ip=10.170.25.1)
(pid=3113684)   File "python/ray/_raylet.pyx", line 490, in ray._raylet.execute_task
(pid=3113684)   File "python/ray/_raylet.pyx", line 497, in ray._raylet.execute_task
(pid=3113684)   File "python/ray/_raylet.pyx", line 501, in ray._raylet.execute_task
(pid=3113684)   File "python/ray/_raylet.pyx", line 451, in ray._raylet.execute_task.function_executor
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/ray/_private/function_manager.py", line 563, in actor_method_executor
(pid=3113684)     return method(__ray_actor, *args, **kwargs)
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 534, in __init__
(pid=3113684)     self._build_policy_map(policy_dict, policy_config)
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1193, in _build_policy_map
(pid=3113684)     policy_map[name] = cls(obs_space, act_space, merged_conf)
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/tf_policy_template.py", line 237, in __init__
(pid=3113684)     DynamicTFPolicy.__init__(
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/dynamic_tf_policy.py", line 351, in __init__
(pid=3113684)     self._initialize_loss_from_dummy_batch(
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/dynamic_tf_policy.py", line 516, in _initialize_loss_from_dummy_batch
(pid=3113684)     loss = self._do_loss_init(train_batch)
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/dynamic_tf_policy.py", line 598, in _do_loss_init
(pid=3113684)     loss = self._loss_fn(self, self.model, self.dist_class, train_batch)
(pid=3113684)   File "/home/cy/rllibsumoutils-master/example/centralized_critic_tsc.py", line 169, in loss_with_central_critic
(pid=3113684)     policy._central_value_out = model.value_function()
(pid=3113684)   File "/home/cy/rllibsumoutils-master/example/centralized_critic_tsc.py", line 166, in <lambda>
(pid=3113684)     model.value_function = lambda: policy.model.central_value_function(
(pid=3113684)   File "/home/cy/rllibsumoutils-master/example/centralized_critic_tsc.py", line 80, in central_value_function
(pid=3113684)     self.central_vf([obs, other_obs]), [-1])
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer_v1.py", line 786, in __call__
(pid=3113684)     outputs = call_fn(cast_inputs, *args, **kwargs)
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py", line 424, in call
(pid=3113684)     return self._run_internal_graph(
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py", line 560, in _run_internal_graph
(pid=3113684)     outputs = node.layer(*args, **kwargs)
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer_v1.py", line 786, in __call__
(pid=3113684)     outputs = call_fn(cast_inputs, *args, **kwargs)
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/tensorflow/python/keras/layers/merge.py", line 183, in call
(pid=3113684)     return self._merge_function(inputs)
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/tensorflow/python/keras/layers/merge.py", line 522, in _merge_function
(pid=3113684)     return K.concatenate(inputs, axis=self.axis)
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py", line 201, in wrapper
(pid=3113684)     return target(*args, **kwargs)
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/tensorflow/python/keras/backend.py", line 2989, in concatenate
(pid=3113684)     return array_ops.concat([to_dense(x) for x in tensors], axis)
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py", line 201, in wrapper
(pid=3113684)     return target(*args, **kwargs)
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py", line 1677, in concat
(pid=3113684)     return gen_array_ops.concat_v2(values=values, axis=axis, name=name)
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/tensorflow/python/ops/gen_array_ops.py", line 1207, in concat_v2
(pid=3113684)     _, _, _op, _outputs = _op_def_library._apply_op_helper(
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/tensorflow/python/framework/op_def_library.py", line 748, in _apply_op_helper
(pid=3113684)     op = g._create_op_internal(op_type_name, inputs, dtypes=None,
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 3528, in _create_op_internal
(pid=3113684)     ret = Operation(
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 2015, in __init__
(pid=3113684)     self._c_op = _create_c_op(self._graph, node_def, inputs,
(pid=3113684)   File "/home/cy/anaconda3/envs/rl/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 1856, in _create_c_op
(pid=3113684)     raise ValueError(str(e))
(pid=3113684) ValueError: Shape must be rank 2 but is rank 1 for '{{node nt1/model_1/concatenate/concat}} = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32](nt1/model_1/Cast, nt1/model_1/Cast_1, nt1/model_1/concatenate/concat/axis)' with input shapes: [?,6], [?], [].

Can you help me figure out why the column loses its second dimension by the time it reaches the loss function?