Running into AttributeError: 'TorchCategorical' object has no attribute 'log_prob' when training MAPPO in a Unity scene

I am using RLlib to train MAPPO in a scene adapted from Dodgeball. I added waypoints to the scene and let the agents move from one waypoint to another, so my action space is discrete. I used two files as the basic code framework and modified the action spec and the observation spec to suit my task. Since there is one discrete action branch with 9 possible actions, I set the action spec to MultiDiscrete([9]). But when I run the code, I get this error: AttributeError: ‘TorchCategorical’ object has no attribute ‘log_prob’.
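With a single branch of size 9, MultiDiscrete([9]) covers the same nine actions as a plain Discrete(9) space. The difference is that RLlib then routes the actions through its multi-categorical distribution (the TorchMultiCategorical class that comes up later in this thread) instead of a single categorical. The stdlib-only sketch below does not use RLlib; its function names are hypothetical. It shows that a multi-categorical log-probability is the sum of the per-branch categorical log-probabilities, so with one branch the two coincide:

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one branch of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def categorical_logp(logits: Sequence[float], action: int) -> float:
    """Log-probability of `action` under a categorical built from `logits`."""
    return log_softmax(logits)[action]


def multi_categorical_logp(
    branch_logits: Sequence[Sequence[float]], actions: Sequence[int]
) -> float:
    """Sum of per-branch categorical log-probs (one action per branch)."""
    assert len(branch_logits) == len(actions), "need one action per branch"
    return sum(categorical_logp(lg, a) for lg, a in zip(branch_logits, actions))


# MultiDiscrete([9]): a single branch with 9 possible actions.
logits = [0.1 * i for i in range(9)]
single = categorical_logp(logits, 4)
multi = multi_categorical_logp([logits], [4])
print(math.isclose(single, multi))  # True
```

So the error is not about the action space being wrong; it comes from the code path RLlib uses for MultiDiscrete spaces.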

If I train only with the training API provided by Unity, it works.

How can I solve the problem?

Okay, I partially solved it. In short: use TensorFlow as the framework.

While working on this bug, I got the impression that RLlib’s torch support has some problems, especially when dealing with MultiDiscrete action spaces.

I was modifying this example script from RLlib’s repo, and I set the training framework to torch because 1) the algorithm that I am going to train is implemented in torch, and 2) judging from the default parameter settings, torch seemed to be the default choice.

Then I got the error mentioned above:

AttributeError: ‘TorchCategorical’ object has no attribute ‘log_prob’.

However, in the torch repo, the Categorical distribution (torch.distributions.Categorical) does have a log_prob method.

After going through the scripts, I still couldn’t solve the problem. But when I looked back at the original example script, I found that its actual default framework is TensorFlow (‘tf’). When I tested with ‘tf’ as my framework, training ran successfully.


I also tested other example scenes released with Unity’s ML-Agents, setting the framework to ‘torch’ instead of ‘tf’, and the same error appeared whenever the action space was MultiDiscrete.

The tutorial for this test can be found here: Unity-and-rllib

As my algorithm implementation is based on torch, I may still need to find a solution to this problem… But if you just want to play around with RLlib’s included algorithms, just modify the example script and choose TensorFlow! :slight_smile:

And here is the parameter setting: it looks like the default framework is ‘torch’, which sent me in the wrong direction…
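The mix-up above comes from the framework being chosen in two places: a command-line default in the example script and the "framework" entry that ends up in the RLlib config. A small stdlib-only sketch of that pattern makes the effective value explicit so it can be checked before training starts. The --framework flag, its choices, and the parse_args/build_config helpers are hypothetical here, not copied from the RLlib example; only the idea that RLlib’s dict-style config takes a "framework" key is assumed:

```python
import argparse
from typing import Any, Dict, List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    # Hypothetical flag: pick the deep-learning backend explicitly instead
    # of relying on whatever default the script happens to ship with.
    parser.add_argument(
        "--framework", choices=["tf", "tf2", "torch"], default="tf"
    )
    return parser.parse_args(argv)


def build_config(args: argparse.Namespace) -> Dict[str, Any]:
    # Only the framework entry is shown; a real config also needs the env,
    # the multi-agent policy setup, and so on.
    return {"framework": args.framework}


config = build_config(parse_args(["--framework", "torch"]))
print(config["framework"])  # torch
```

Printing the resolved value at startup is a cheap way to notice when the script default and the value you assumed disagree.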

So after digging into the code, I found a solution for training with torch as the framework.

The “TorchCategorical” that appears in the error is not implemented by torch; it is a wrapper implemented by RLlib. You can find the definition of TorchCategorical at line 67 of
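The failure mode can be reproduced without RLlib: a wrapper class keeps the real distribution internally and exposes its own method names, so calling the inner library’s method name on the wrapper raises AttributeError. In this sketch, Categorical stands in for torch.distributions.Categorical and WrappedCategorical for RLlib’s TorchCategorical. Both are simplified, hypothetical classes, and naming the wrapper’s method logp is an assumption based on RLlib’s Distribution interface, so check the installed source:

```python
import math
from typing import List


class Categorical:
    """Stand-in for torch.distributions.Categorical: it has log_prob()."""

    def __init__(self, logits: List[float]) -> None:
        m = max(logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        self._logp = [x - log_z for x in logits]

    def log_prob(self, action: int) -> float:
        return self._logp[action]


class WrappedCategorical:
    """Stand-in for a framework wrapper: it keeps the inner distribution in
    a private attribute and exposes logp() instead of log_prob()."""

    def __init__(self, logits: List[float]) -> None:
        self.logits = logits
        self._dist = Categorical(logits)

    def logp(self, action: int) -> float:
        return self._dist.log_prob(action)


cat = WrappedCategorical([0.0, 0.0, 0.0])
print(cat.logp(1))  # the wrapper's own method works
try:
    cat.log_prob(1)  # the inner library's name is not defined on the wrapper
except AttributeError as err:
    print(err)  # 'WrappedCategorical' object has no attribute 'log_prob'
```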

To fix this, go to the TorchMultiCategorical class defined in the same file and change the following line

logps = torch.stack([cat.log_prob(act) for cat, act in zip(self._cats, value)])

into this:

logps = torch.stack([cat._get_torch_distribution(logits=cat.logits).log_prob(act) for cat, act in zip(self._cats, value)])
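Conceptually, the patched line does this: for each branch, rebuild a plain categorical distribution from that branch’s logits, take the log-probability of that branch’s action, and stack the per-branch values so the surrounding method can combine them (a sum across branches is assumed below). The stdlib-only sketch mirrors that with hypothetical stand-in classes. It also shows that if the wrapper already offers its own log-probability method (named logp here, an assumption to verify against your RLlib version), calling it directly gives the same result as unwrapping:

```python
import math
from typing import List, Sequence


class Categorical:
    """Stand-in for torch.distributions.Categorical."""

    def __init__(self, logits: Sequence[float]) -> None:
        m = max(logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        self._logp = [x - log_z for x in logits]

    def log_prob(self, action: int) -> float:
        return self._logp[action]


class WrappedCategorical:
    """Stand-in wrapper that stores its logits and an inner distribution."""

    def __init__(self, logits: Sequence[float]) -> None:
        self.logits = list(logits)
        self._dist = Categorical(self.logits)

    def _get_torch_distribution(self, logits: Sequence[float]) -> Categorical:
        return Categorical(logits)

    def logp(self, action: int) -> float:
        return self._dist.log_prob(action)


def multi_logp_unwrapped(cats: List[WrappedCategorical], actions: Sequence[int]) -> float:
    # Mirrors the patched line: rebuild the inner distribution per branch.
    logps = [
        c._get_torch_distribution(logits=c.logits).log_prob(a)
        for c, a in zip(cats, actions)
    ]
    return sum(logps)


def multi_logp_wrapper(cats: List[WrappedCategorical], actions: Sequence[int]) -> float:
    # Alternative: call the wrapper's own method, if your version has one.
    return sum(c.logp(a) for c, a in zip(cats, actions))


cats = [WrappedCategorical([0.2, 1.0, -0.5]), WrappedCategorical([0.0] * 9)]
actions = [1, 4]
print(math.isclose(multi_logp_unwrapped(cats, actions), multi_logp_wrapper(cats, actions)))  # True
```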

Also change the TorchCategorical._get_torch_distribution method, which is in the same file, to this:

    def _get_torch_distribution(
        self,
        probs: torch.Tensor = None,
        logits: torch.Tensor = None,
    ) -> "torch.distributions.Distribution":
        # Exactly one of `probs` and `logits` may be passed.
        assert (probs is None) != (
            logits is None
        ), "Exactly one out of `probs` and `logits` must be set!"
        if probs is not None:
            return torch.distributions.categorical.Categorical(probs=probs)
        # No probs were given, so build the distribution from the logits.
        return torch.distributions.categorical.Categorical(logits=logits)
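The method above is meant to build the categorical from exactly one of probs or logits. A stdlib-only sketch of the same contract (make_categorical_logps is a hypothetical name, not RLlib or torch API) shows why either input is enough: softmax turns logits into probabilities, so both routes give the same log-probabilities, while passing both or neither is rejected by the assertion:

```python
import math
from typing import List, Optional, Sequence


def make_categorical_logps(
    probs: Optional[Sequence[float]] = None,
    logits: Optional[Sequence[float]] = None,
) -> List[float]:
    """Return per-action log-probabilities from exactly one of probs/logits."""
    assert (probs is None) != (logits is None), (
        "Exactly one out of `probs` and `logits` must be set!"
    )
    if probs is not None:
        total = sum(probs)  # normalise, as torch does for unnormalised probs
        return [math.log(p / total) for p in probs]
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


logits = [1.0, 2.0, 0.5]
probs = [math.exp(v) for v in make_categorical_logps(logits=logits)]
from_probs = make_categorical_logps(probs=probs)
from_logits = make_categorical_logps(logits=logits)
print(all(math.isclose(a, b) for a, b in zip(from_probs, from_logits)))  # True
```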

Also, if you run into an error like this:

Traceback (most recent call last):
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/air/execution/_internal/", line 110, in resolve_future
    result = ray.get(future)
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/_private/", line 24, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/_private/", line 103, in wrapper
    return func(*args, **kwargs)
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/_private/", line 2547, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AttributeError): ray::PPO.train() (pid=702896, ip=, actor_id=4759bace2b689d18ce33dc4101000000, repr=PPO)
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/tune/trainable/", line 400, in train
    raise skipped from exception_cause(skipped)
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/tune/trainable/", line 397, in train
    result = self.step()
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/rllib/algorithms/", line 853, in step
    results, train_iter_ctx = self._run_one_training_iteration()
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/rllib/algorithms/", line 2838, in _run_one_training_iteration
    results = self.training_step()
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/rllib/algorithms/ppo/", line 448, in training_step
    train_results = self.learner_group.update(
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/rllib/core/learner/", line 184, in update
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/rllib/core/learner/", line 1304, in update
    ) = self._update(nested_tensor_minibatch)
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/rllib/core/learner/torch/", line 365, in _update
    return self._possibly_compiled_update(batch)
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/rllib/core/learner/torch/", line 123, in _uncompiled_update
    loss_per_module = self.compute_loss(fwd_out=fwd_out, batch=batch)
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/rllib/core/learner/", line 1024, in compute_loss
    loss = self.compute_loss_for_module(
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/rllib/algorithms/ppo/torch/", line 87, in compute_loss_for_module
    action_kl = prev_action_dist.kl(curr_action_dist)
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/rllib/models/torch/", line 327, in kl
    for cat, oth_cat in zip(self._cats, other.cats)
AttributeError: '<class 'ray.rllib.models.torch.torch_distributions' object has no attribute 'cats'

you can check out this GitHub issue.
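That traceback ends in the kl method of the multi-categorical class, which zips self._cats with other.cats. Since the logp line quoted earlier stores the branches as self._cats, the other distribution most likely has no public cats attribute, and reading other._cats instead looks like the natural patch; this is an inference, so compare with the GitHub issue and your installed version before editing. The stdlib-only sketch below uses a hypothetical MultiCategorical class, not RLlib’s. It shows the intended computation: the KL divergence between two multi-categorical distributions is the sum of the per-branch categorical KL divergences, with the branches read from the same private attribute on both sides:

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def categorical_kl(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for two categorical distributions given as probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


class MultiCategorical:
    """Hypothetical stand-in: one categorical (as probabilities) per branch."""

    def __init__(self, branch_logits: Sequence[Sequence[float]]) -> None:
        self._cats = [softmax(logits) for logits in branch_logits]

    def kl(self, other: "MultiCategorical") -> float:
        # Read the branches from the same private attribute on both sides;
        # using `other.cats` here would raise AttributeError.
        return sum(categorical_kl(p, q) for p, q in zip(self._cats, other._cats))


prev = MultiCategorical([[0.0, 1.0, 2.0], [0.5] * 9])
curr = MultiCategorical([[0.1, 0.9, 2.2], [0.0] * 9])
print(prev.kl(curr) >= 0.0)  # True
print(prev.kl(prev))  # 0.0
```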