Running into AttributeError: 'TorchCategorical' object has no attribute 'log_prob' when training MAPPO in a unity scene

How severe does this issue affect your experience of using Ray?

  • High: It blocks me to complete my task.

I am using RLlib to train MAPPO in a scene adapted from DodgeBall: I added waypoints to the scene so that the agents move from one point to another, which makes my action space discrete. I used the unity_3d_local.py and unity_3d_env.py files as the basic code framework and modified the action spec as well as the observation spec to suit my task. Since there is one discrete action branch with a branch size of 9, I set the action spec to MultiDiscrete([9]). But when I run the code, it raises this error: AttributeError: 'TorchCategorical' object has no attribute 'log_prob'.
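For reference, a single discrete branch of size 9 corresponds to a MultiDiscrete([9]) space, whose actions are length-1 vectors with one entry in {0, ..., 8}. A minimal stdlib stand-in (for illustration only, not the actual gym/gymnasium API) that mimics sampling from such a space:

```python
import random

# A MultiDiscrete([9]) action space: one discrete branch of size 9.
# Each action is a length-1 vector whose single entry lies in {0, ..., 8}.
branch_sizes = [9]

def sample_action(sizes):
    """Sample one value per discrete branch, mimicking MultiDiscrete sampling."""
    return [random.randrange(n) for n in sizes]

action = sample_action(branch_sizes)
assert len(action) == 1
assert 0 <= action[0] < 9
```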

If I use the training API provided by Unity and only use Unity to train, it works.

How can I solve the problem?

Okay, I solved it partially. In short: use TensorFlow as the framework.

While debugging this, I came to think that RLlib's torch support, especially when dealing with MultiDiscrete action spaces, has certain problems.

As I was modifying this example script from RLlib's repo, I set the training framework to torch, because 1) the algorithm I am going to train is implemented in torch, and 2) judging from the default parameter setting, torch seemed to be the default choice.

Then it gives me the previously mentioned error:

AttributeError: 'TorchCategorical' object has no attribute 'log_prob'.

However, according to the torch repo, torch's own Categorical distribution does have this attribute.
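Indeed, torch.distributions.Categorical exposes log_prob; for logits l and action a it computes the log-softmax at the chosen index, i.e. l[a] - logsumexp(l). A dependency-free sketch of that computation (a stand-in for the torch call, not the torch implementation itself):

```python
import math

def categorical_log_prob(logits, action):
    """log P(action) for a categorical distribution parameterized by logits:
    log_softmax(logits)[action] = logits[action] - logsumexp(logits)."""
    m = max(logits)  # subtract the max for numerical stability
    logsumexp = m + math.log(sum(math.exp(l - m) for l in logits))
    return logits[action] - logsumexp

# Uniform logits over 9 actions -> log(1/9) for every action.
lp = categorical_log_prob([0.0] * 9, 3)
assert abs(lp - math.log(1 / 9)) < 1e-12
```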

After going through the scripts, I still couldn't solve the problem. But when I looked back at the original example script, I found that the actual default framework is TensorFlow ('tf'). I tested with 'tf' as my framework, and the training process executed successfully.
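For the record, the workaround is a one-line config change. In recent RLlib versions it would look roughly like this (the builder-style framework() method is assumed from the AlgorithmConfig API; older scripts pass "framework": "tf" in a config dict instead):

```python
from ray.rllib.algorithms.ppo import PPOConfig

# Workaround: train with the TensorFlow framework instead of torch.
config = (
    PPOConfig()
    .framework("tf")  # "torch" triggers the TorchCategorical.log_prob error
)
```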


I also tested other example scenes released by Unity's ML-Agents, setting the framework to 'torch' instead of 'tf', and the same error appeared whenever the action space is MultiDiscrete.

The tutorial for this test can be found here: Unity-and-rllib

As my algorithm implementation is based on torch, I might still need to find a solution to this problem… But if you just want to play around with RLlib's built-in algorithms, just modify the example script and choose TensorFlow! :slight_smile:

And this is the parameter setting; it looks like the default framework is 'torch', which led me in the wrong direction…

So, after digging into the code, I have a solution for using torch as the training framework.

The "TorchCategorical" that appears in the error is not implemented by torch; it is a wrapper implemented by RLlib. You can find the definition of TorchCategorical at line 67 of rllib/models/torch/torch_distributions.py.

The way to solve this problem is to go to the TorchMultiCategorical class defined in the same file and modify the following line

logps = torch.stack([cat.log_prob(act) for cat, act in zip(self._cats, value)])

into this:

logps = torch.stack([cat._get_torch_distribution(logits=cat.logits).log_prob(act) for cat, act in zip(self._cats, value)])
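To see why the original line fails and the patched one works, here is a dependency-free mock of the situation: the RLlib wrapper stores logits but (in this version) does not itself define log_prob, so the workaround rebuilds the underlying torch distribution from those logits and calls log_prob on that instead. These are illustrative stand-ins, not the real RLlib or torch classes:

```python
import math

class InnerCategorical:
    """Stand-in for torch.distributions.Categorical: does have log_prob."""
    def __init__(self, logits):
        self.logits = logits
    def log_prob(self, action):
        m = max(self.logits)
        lse = m + math.log(sum(math.exp(l - m) for l in self.logits))
        return self.logits[action] - lse

class WrapperCategorical:
    """Stand-in for RLlib's TorchCategorical: stores logits, no log_prob."""
    def __init__(self, logits):
        self.logits = logits
    def _get_torch_distribution(self, logits=None):
        return InnerCategorical(logits)

cat = WrapperCategorical([0.0] * 9)
try:
    cat.log_prob(2)          # the original code path
except AttributeError:
    pass                     # -> object has no attribute 'log_prob'

# The workaround: rebuild the inner distribution from the stored logits.
lp = cat._get_torch_distribution(logits=cat.logits).log_prob(2)
assert abs(lp - math.log(1 / 9)) < 1e-12
```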

Also change the TorchCategorical._get_torch_distribution function, defined in the same torch_distributions.py file, to this:

def _get_torch_distribution(
    self,
    probs: torch.Tensor = None,
    logits: torch.Tensor = None,
) -> "torch.distributions.Distribution":
    assert (probs is None) != (
        logits is None
    ), "Exactly one out of `probs` and `logits` must be set!"
    if probs is not None:
        return torch.distributions.categorical.Categorical(probs)
    else:
        return torch.distributions.categorical.Categorical(logits=logits)
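The important detail in this helper is passing logits by keyword: torch.distributions.Categorical(x) treats a positional argument as probabilities (normalized by their sum), while Categorical(logits=x) applies a softmax first. A stdlib sketch of the difference between the two interpretations:

```python
import math

def as_probs(x):
    """Positional-argument interpretation: x is normalized as probabilities."""
    s = sum(x)
    return [v / s for v in x]

def from_logits(x):
    """Keyword interpretation: probabilities are softmax(x)."""
    m = max(x)
    exps = [math.exp(v - m) for v in x]
    s = sum(exps)
    return [e / s for e in exps]

x = [1.0, 2.0]
# As probs: 1/3 and 2/3.  As logits: softmax gives ~0.269 and ~0.731.
assert abs(as_probs(x)[1] - 2 / 3) < 1e-12
assert abs(from_logits(x)[1] - math.exp(2) / (math.exp(1) + math.exp(2))) < 1e-12
```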

Also, if you experience an error like this:

Traceback (most recent call last):
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/_private/auto_init_hook.py", line 24, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/_private/worker.py", line 2547, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AttributeError): ray::PPO.train() (pid=702896, ip=192.168.0.38, actor_id=4759bace2b689d18ce33dc4101000000, repr=PPO)
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/tune/trainable/trainable.py", line 400, in train
    raise skipped from exception_cause(skipped)
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/tune/trainable/trainable.py", line 397, in train
    result = self.step()
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/rllib/algorithms/algorithm.py", line 853, in step
    results, train_iter_ctx = self._run_one_training_iteration()
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/rllib/algorithms/algorithm.py", line 2838, in _run_one_training_iteration
    results = self.training_step()
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/rllib/algorithms/ppo/ppo.py", line 448, in training_step
    train_results = self.learner_group.update(
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/rllib/core/learner/learner_group.py", line 184, in update
    self._learner.update(
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/rllib/core/learner/learner.py", line 1304, in update
    ) = self._update(nested_tensor_minibatch)
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/rllib/core/learner/torch/torch_learner.py", line 365, in _update
    return self._possibly_compiled_update(batch)
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/rllib/core/learner/torch/torch_learner.py", line 123, in _uncompiled_update
    loss_per_module = self.compute_loss(fwd_out=fwd_out, batch=batch)
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/rllib/core/learner/learner.py", line 1024, in compute_loss
    loss = self.compute_loss_for_module(
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/rllib/algorithms/ppo/torch/ppo_torch_learner.py", line 87, in compute_loss_for_module
    action_kl = prev_action_dist.kl(curr_action_dist)
  File "/home/alexpalms/miniconda3/envs/new-ray/lib/python3.8/site-packages/ray/rllib/models/torch/torch_distributions.py", line 327, in kl
    for cat, oth_cat in zip(self._cats, other.cats)
AttributeError: '<class 'ray.rllib.models.torch.torch_distributions' object has no attribute 'cats'

You can check out this github issue
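Judging from the traceback alone (this is a guess; confirm against the linked issue), the kl method references other.cats while the attribute is stored as _cats, which is exactly the kind of name mismatch Python reports as a missing attribute. A stand-in illustrating it:

```python
class MultiCat:
    """Stand-in illustrating the suspected attribute-name mismatch."""
    def __init__(self, cats):
        self._cats = cats  # stored with a leading underscore

a, b = MultiCat([1, 2]), MultiCat([3, 4])

try:
    pairs = list(zip(a._cats, b.cats))   # mirrors `zip(self._cats, other.cats)`
except AttributeError:
    pairs = list(zip(a._cats, b._cats))  # the likely fix: `_cats` on both sides
assert pairs == [(1, 3), (2, 4)]
```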