Size mismatch when training PPO in pettingzoo environment

As the title says, I’m occasionally getting a size-mismatch error; it does not occur every time.

I’m using a customized model for a parametric action space:

class TorchMaskedActions(DQNTorchModel):
    """Parametric-actions model: adds an action mask on top of a plain FC net.

    Expects the (flattened) observation to be a Dict with keys
    "action_mask" (length ``action_space.n``) and "observation".
    NOTE(review): obs_space here is RLlib's flattened Box of that Dict —
    the layout assumption below must match the env; verify against it.
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 **kw):
        # Initialize the wrapped model with the full flattened obs space.
        DQNTorchModel.__init__(self, obs_space, action_space, num_outputs,
                               model_config, name, **kw)

        # Width of the "observation" part = total flattened width minus the
        # mask width. NOTE(review): the reported runtime size mismatch
        # (input width 32 vs. layer expecting 21092) suggests the tensor fed
        # to forward() does not always have this width — confirm obs_space
        # really is the flattened Dict space of THIS env for BOTH policies.
        obs_len = obs_space.shape[0]-action_space.n

        # NOTE(review): slicing low/high with [:obs_len] takes the FIRST
        # obs_len bounds; if Dict flattening orders "action_mask" before
        # "observation", these are the mask's bounds, not the observation's —
        # [action_space.n:] may be intended. Confirm the flattening order.
        orig_obs_space = Box(shape=(obs_len,), low=obs_space.low[:obs_len], high=obs_space.high[:obs_len])
        # FC net over the observation part only; emits one logit per action.
        self.action_embed_model = TorchFC(orig_obs_space, action_space, action_space.n, model_config, name + "_action_embed")

    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.
        # (ModelV2.__call__ restores the Dict structure, so "obs" is a dict
        # with "action_mask" and "observation" keys here.)
        action_mask = input_dict["obs"]["action_mask"]

        # Compute per-action logits from the observation part only.
        action_logits, _ = self.action_embed_model({
            "obs": input_dict["obs"]['observation']
        })
        # turns probit action mask into logit action mask:
        # log(1) = 0 for valid actions, log(0) = -inf clamped to -1e10 for
        # invalid ones, so masked logits vanish under softmax.
        inf_mask = torch.clamp(torch.log(action_mask), -1e10, FLOAT_MAX)

        return action_logits + inf_mask, state

    def value_function(self):
        # Delegate to the embedded FC net (vf_share_layers applies there).
        return self.action_embed_model.value_function()

The configuration settings are:

# Build the trainer config from PPO's defaults.
alg_name = "PPO"
config = deepcopy(get_agent_class(alg_name)._default_config)

# Share value-function layers with the policy net. This was assigned twice
# in a row before; once is enough. NOTE: the top-level key is deprecated in
# newer RLlib — the authoritative setting is the one under config["model"].
config['vf_share_layers'] = True

# One policy per agent id; the mapping function routes each agent to the
# policy of the same name. Both use the same obs/act spaces.
config["multiagent"] = {
    "policies": {
        "attacker_0": (None, obs_space, act_space, {}),
        "defender_0": (None, obs_space, act_space, {}),
    },
    "policy_mapping_fn": lambda agent_id: agent_id,
    "count_steps_by": "env_steps",
}

# Resource / rollout settings.
config["num_gpus"] = int(os.environ.get("RLLIB_NUM_GPUS", "0"))
config["log_level"] = "INFO"
config["num_workers"] = 1
config["rollout_fragment_length"] = 30
config["train_batch_size"] = 200
config["horizon"] = 200
config["no_done_at_end"] = False
config["framework"] = "torch"

# Use the registered parametric-actions model for both policies.
config["model"] = {
    "custom_model": "pa_model",
    'vf_share_layers': True
}
config['env'] = "cyberEnv-v0"

The tune call is:

# Kick off training with Tune: checkpoint every 10 iterations and once more
# at the end; the iteration cap is effectively "run until stopped".
stop_criteria = {"training_iteration": 1000000}
analysis = tune.run(
    alg_name,
    name="PPO",
    stop=stop_criteria,
    config=config,
    checkpoint_freq=10,
    checkpoint_at_end=True,
)

The error message I got is:

  File "/home/yinuod/anaconda3/envs/rllib/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 934, in _compute_action_helper
    dist_inputs, state_out = self.model(input_dict, state_batches,
  File "/home/yinuod/anaconda3/envs/rllib/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 244, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "<ipython-input-3-ccce39a25015>", line 24, in forward
  File "/home/yinuod/anaconda3/envs/rllib/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 244, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "/home/yinuod/anaconda3/envs/rllib/lib/python3.8/site-packages/ray/rllib/models/torch/fcnet.py", line 124, in forward
    self._features = self._hidden_layers(self._last_flat_in)
  File "/home/yinuod/anaconda3/envs/rllib/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/yinuod/anaconda3/envs/rllib/lib/python3.8/site-packages/torch/nn/modules/container.py", line 100, in forward
    input = module(input)
  File "/home/yinuod/anaconda3/envs/rllib/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/yinuod/anaconda3/envs/rllib/lib/python3.8/site-packages/ray/rllib/models/torch/misc.py", line 160, in forward
    return self._model(x)
  File "/home/yinuod/anaconda3/envs/rllib/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/yinuod/anaconda3/envs/rllib/lib/python3.8/site-packages/torch/nn/modules/container.py", line 100, in forward
    input = module(input)
  File "/home/yinuod/anaconda3/envs/rllib/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/yinuod/anaconda3/envs/rllib/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/yinuod/anaconda3/envs/rllib/lib/python3.8/site-packages/torch/nn/functional.py", line 1610, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: size mismatch, m1: [32 x 32], m2: [21092 x 256] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:41