Very slow gradient descent on remote workers

I am using ray.tune to run ImpalaTrainer trainables. For some reason, gradient descent becomes unbearably slow when using remote workers, whereas local workers (ray.init(local_mode=True)) do not seem to have this problem. The chart below shows two otherwise identical runs, one with local_mode=True and one with local_mode=False, both launched through ray.tune.


Looking at the relative wall-clock time, the grad time starts out reasonable but keeps increasing as each run progresses:



from ray.tune import grid_search

# RayObsGraph (custom model) and RecallEnv are the author's own classes
base_model = {
    "custom_model": RayObsGraph,
    "custom_model_config": {
        # (contents omitted in the original post)
    },
    "max_seq_len": 8,
}
rnn_model = {
    "use_lstm": True,
    "max_seq_len": 8,
    "lstm_cell_size": 16,
}
# Enclosing experiment config (variable name assumed; only its "ray" section was posted)
config = {
    "ray": {
        "env_config": {"dim": 4, "max_items": 4, "max_queries": 4},
        # These are rllib/ray specific
        "framework": "torch",
        "model": grid_search([base_model, rnn_model]),
        "num_workers": 16,
        # Total GPU usage: num_gpus (trainer proc) + num_gpus_per_worker (workers)
        "num_cpus_per_worker": 2,
        "num_envs_per_worker": 1,
        # this corresponds to the number of learner GPUs used,
        # not the total used for the environments/rollouts
        "num_gpus": 1,
        # Size of batches (in timesteps) placed in the learner queue
        "rollout_fragment_length": 16,
        # Total number of timesteps to train per batch
        "train_batch_size": 512,
        "lr": 0.0001,
        "env": RecallEnv.__name__,
    },
}

I use a custom TorchModelV2 model that uses recurrent state, so I'm wondering whether this has to do with computing gradients through past states.

I plotted the computation graph and reduced the complexity of my backward pass by replacing

out = torch.stack([gnn_out[b, node_idx[b]] for b in range(batch)])

with

out = gnn_out[torch.arange(batch, device=gnn_out.device), node_idx.squeeze()]

This reduces the gradient time by a factor of 10. However, running Ray with ray.init(local_mode=True) is still much faster.
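A small check of why that rewrite helps, assuming `gnn_out` has shape `[batch, num_nodes, features]` and `node_idx` holds one long index per batch row; the function names are hypothetical. Both versions select the same rows and produce the same gradients, but the loop records one indexing op per row plus a stack in the autograd graph, while the vectorized form records a single advanced-indexing op.

```python
import torch


def gather_loop(gnn_out: torch.Tensor, node_idx: torch.Tensor) -> torch.Tensor:
    # One indexing op per batch row, then a stack: the autograd graph grows with batch size
    batch = gnn_out.shape[0]
    return torch.stack([gnn_out[b, node_idx[b]] for b in range(batch)])


def gather_vectorized(gnn_out: torch.Tensor, node_idx: torch.Tensor) -> torch.Tensor:
    # A single advanced-indexing op: row b picks node node_idx[b]
    rows = torch.arange(gnn_out.shape[0], device=gnn_out.device)
    # reshape(-1) accepts node_idx of shape [batch] or [batch, 1]
    return gnn_out[rows, node_idx.reshape(-1)]


if __name__ == "__main__":
    torch.manual_seed(0)
    gnn_out = torch.randn(4, 5, 3, requires_grad=True)  # [batch, num_nodes, features]
    node_idx = torch.tensor([0, 2, 4, 1])
    a = gather_loop(gnn_out, node_idx)
    b = gather_vectorized(gnn_out, node_idx)
    print(a.shape, torch.equal(a, b))
```

Using `reshape(-1)` instead of `squeeze()` is a small defensive choice: it keeps a 1-D index even when the batch size is 1.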

Logit computation graph for posterity:


For anyone else who runs into this issue: I seem to have solved the problem by decreasing the number of workers from 10. I have ample free memory, so it isn't a disk-swapping issue.


The issue is not the gradients or the computation graph but the loss computation. I've narrowed the problem down to line 434:

loss_out = force_list(
    self._loss(self, self.model, self.dist_class, train_batch))

This occurs both with and without vtrace.

Looking at A3CTorchPolicy, the slowdown occurs in model.from_batch(train_batch), where it calls model.__call__ with the sample batch. I will investigate further.

@sven1977 please take a look

Just to follow up in case others run into this issue: the issue seems to be with PyTorch, probably due to the GPU scheduler. I've found that at some point the models simply experience a 10-100x increase in forward-pass time. I dropped into the debugger when the slowdown occurred and fed zeros through the network to verify this. Flushing the torch GPU cache, upgrading to torch 1.8.2, and various other approaches do not appear to fix the issue.
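A minimal sketch of that kind of check, assuming you can call the network directly on a fixed input such as an all-zeros batch. `median_forward_time` is a hypothetical helper, not an RLlib or PyTorch API. Because CUDA kernels launch asynchronously, passing `sync=torch.cuda.synchronize` when timing a GPU model makes each sample include the kernels' execution rather than only their launch.

```python
import statistics
import time
from typing import Any, Callable, Optional


def median_forward_time(
    forward: Callable[[Any], Any],
    make_input: Callable[[], Any],
    repeats: int = 20,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Median wall-clock seconds for forward(inputs) over `repeats` calls."""
    inputs = make_input()  # e.g. an all-zeros batch, built once outside the timed region
    samples = []
    for _ in range(repeats):
        if sync is not None:
            sync()  # drain previously queued GPU work so it isn't billed to this call
        start = time.perf_counter()
        forward(inputs)
        if sync is not None:
            sync()  # wait until this call's work has actually finished
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)
```

With a hypothetical GPU module `net`, something like `median_forward_time(net, lambda: torch.zeros(32, 64, device="cuda"), sync=torch.cuda.synchronize)` inside `torch.no_grad()` times only the forward pass; comparing the value from early in training with the value after the slowdown sets in would expose a jump of the size described above.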

I've found this issue only occurs when the Ray trainers receive more rollouts than they can process. If the rollout queue is emptied quickly and stays empty most of the time, the issue does not seem to occur.
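A back-of-the-envelope way to check that condition, assuming steady-state rates: compare the timesteps the workers produce per second with the timesteps the learner consumes per second. `env_steps_per_sec` and `learner_step_sec` are numbers you would measure yourself, and the function names are hypothetical.

```python
def rollout_pressure(
    num_workers: int,
    num_envs_per_worker: int,
    env_steps_per_sec: float,
    train_batch_size: int,
    learner_step_sec: float,
) -> float:
    """Timesteps produced per second divided by timesteps consumed per second.

    Above 1.0 the learner queue tends to fill up; at or below 1.0 it tends to drain.
    """
    produced = num_workers * num_envs_per_worker * env_steps_per_sec
    consumed = train_batch_size / learner_step_sec
    return produced / consumed


def max_workers_without_backlog(
    num_envs_per_worker: int,
    env_steps_per_sec: float,
    train_batch_size: int,
    learner_step_sec: float,
) -> int:
    """Largest num_workers for which rollout_pressure stays at or below 1.0."""
    consumed = train_batch_size / learner_step_sec
    return int(consumed // (num_envs_per_worker * env_steps_per_sec))


if __name__ == "__main__":
    # Made-up rates: 200 env steps/s per env, 0.5 s per learner step of 512 timesteps
    print(rollout_pressure(16, 1, 200.0, 512, 0.5))
    print(max_workers_without_backlog(1, 200.0, 512, 0.5))
```

With those illustrative rates and the 16-worker config above, pressure is 3.125, and at most 5 workers would keep the queue draining. This ignores that the learner step time itself grows once the slowdown starts, so treat it as a lower bound on how much to cut.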

TL;DR If you run into this issue, either decrease the number of workers/envs to reduce the rate at which rollouts are produced, or make your model more efficient.


@sven1977, this is the same issue as in this discussion: "[RLlib] Ray trains extremely slow when learner queue is full"

Very cool, thanks so much for digging into this @smorad and finding the bug on the torch end! And for updating the posts with the links @Bam4d !

@sven1977 I have done some more looking into this and I have come up with a workaround that may be the solution to this issue.

Basically, what I think is happening is that the PyTorch GPU scheduler needs to interrupt a thread to read, write, or execute GPU instructions, and to do this the thread needs to be in an interruptible state. In RLlib, the Python code actually spinlocks (which is not interruptible) when the queue is full, for example in ray/ (ray-project/ray on GitHub, master branch):

    def __call__(self, x: Any) -> Any:
        try:
            self.queue.put_nowait(x)  # non-blocking put; raises queue.Full immediately
        except queue.Full:
            return _NextValueNotReady()
        return x

The workaround I found is to add a tiny time.sleep(0.0001) in the except branch. This actually sleeps and releases the thread at the OS level (an interruptible state), and the problem just disappears. Doing a put with a timeout might also work.
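A minimal, self-contained sketch of that workaround, modeled on the `__call__` snippet above. `NextValueNotReady` is a local stand-in for RLlib's internal `_NextValueNotReady`, and `EnqueueWithBackoff` and `backoff_s` are hypothetical names, not RLlib API. One well-documented effect of `time.sleep` is that it releases the GIL and blocks the thread in the OS, so other Python threads (such as a learner thread) get to run instead of competing with a tight retry loop.

```python
import queue
import time
from typing import Any


class NextValueNotReady:
    """Stand-in for RLlib's internal _NextValueNotReady sentinel."""


class EnqueueWithBackoff:
    """Non-blocking enqueue that backs off briefly when the queue is full."""

    def __init__(self, output_queue: queue.Queue, backoff_s: float = 0.001):
        self.queue = output_queue
        self.backoff_s = backoff_s  # 1 ms by default; 0 disables the sleep

    def __call__(self, x: Any) -> Any:
        try:
            self.queue.put_nowait(x)
        except queue.Full:
            # Sleeping blocks this thread in the OS and releases the GIL, so other
            # threads can run before the caller retries
            if self.backoff_s > 0:
                time.sleep(self.backoff_s)
            return NextValueNotReady()
        return x
```

Keeping the delay a constructor argument means it can be enabled only for the pipelines that need it and left at 0 everywhere else.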

I don't know whether there are many more places where spinlocks are used to wait on queues, but busy-waiting is generally bad practice outside real-time-critical code (which this is not).

I'm not entirely sure whether this causes the workers to "run slower", though; there would need to be a backpressure mechanism to avoid memory leaks where the environments just keep creating more and more data.
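A sketch of backpressure with the standard library, swapping in a blocking `put` with a timeout on a bounded `queue.Queue`: while the queue is full the producer is parked inside `put` rather than spinning, so production is throttled to the consumer's pace and buffered data never exceeds `maxsize`. `produce` and `consume` are hypothetical stand-ins for rollout workers and a learner, not RLlib code.

```python
import queue
import threading
import time
from typing import Iterable, List


def produce(q: queue.Queue, items: Iterable, stop: threading.Event, timeout_s: float = 0.01) -> int:
    """Enqueue each item with a blocking, time-limited put; return how many were enqueued."""
    sent = 0
    for item in items:
        while not stop.is_set():
            try:
                # Blocks (without spinning) until a slot frees up or timeout_s elapses
                q.put(item, timeout=timeout_s)
                sent += 1
                break
            except queue.Full:
                continue  # still full after waiting; re-check stop and try again
        if stop.is_set():
            break
    return sent


def consume(q: queue.Queue, n: int, out: List, work_s: float = 0.002) -> None:
    """Slow consumer: takes n items, pausing work_s after each to mimic a learner step."""
    for _ in range(n):
        out.append(q.get())
        time.sleep(work_s)


if __name__ == "__main__":
    q = queue.Queue(maxsize=4)  # at most 4 items buffered at any time
    received: List = []
    consumer = threading.Thread(target=consume, args=(q, 20, received))
    consumer.start()
    sent = produce(q, range(20), threading.Event())
    consumer.join()
    print(sent, received == list(range(20)))
```

The timeout keeps the producer responsive to a shutdown signal (`stop`) instead of blocking forever on a full queue.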

I’ll have more of a play with this today and see if I can get a PR out.


time.sleep(0) should yield the thread without adding any extra delay.

I think adding this before the return with a comment explaining why would be an ideal solution. If the queue is full, then the thread should yield its quantum to a thread that can actually make use of the computation.


Ah, this is great, I did not know this!

I'll try this out, and if it works, I'll link the PR here.

So sleep(0) does not work, but a 1 ms sleep works perfectly.

The only issue is that the sample throughput is calculated incorrectly: the timer does not account for the fact that when the queues are full, there is practically no wait time for subsequent batches. I've treated this as a separate bug, so I have not included a fix for it here.


Awesome, this has been merged!

Thanks @Bam4d and @smorad for your invaluable help here!

Actually, @Bam4d @smorad, this PR now causes our TD3 tests (agents/ddpg/tests/) to time out very often.

Try this one here:

I think the sleep should only be added to algorithms that use a learner thread (e.g. IMPALA), not generally to every place that returns _NextValueNotReady.

@smorad @Bam4d, could you let me know whether you are still seeing the slowdown with this new PR? It also seems to fix the TD3 flakiness/timeouts, which were probably caused by the added sleep.