Very slow gradient descent on remote workers

I am using ray.tune to run ImpalaTrainer trainables. For some reason, gradient descent becomes unbearably slow when using remote workers; however, running in local mode (ray.init(local_mode=True)) does not have this problem. The chart below shows two otherwise identical runs, one with local_mode=True and one with local_mode=False, both launched via ray.tune:

[chart: two identical runs, local_mode=True vs. local_mode=False]

Looking at the relative wall-clock time, the grad time starts out reasonable but keeps increasing as each run progresses:

[chart: grad time over wall-clock time for each run]

Config:

from ray.tune import grid_search

base_model = {
    "custom_model": RayObsGraph,
    "custom_model_config": {},
    "max_seq_len": 8,
}
rnn_model = {
    "use_lstm": True,
    "max_seq_len": 8,
    "lstm_cell_size": 16,
}
cfg = {
    "ray": {
        "env_config": {"dim": 4, "max_items": 4, "max_queries": 4},
        # These are rllib/ray specific
        "framework": "torch",
        "model": grid_search([base_model, rnn_model]),
        "num_workers": 16,
        "num_cpus_per_worker": 2,
        "num_envs_per_worker": 1,
        # This corresponds to the number of learner GPUs used, not the total
        # used for the environments/rollouts.
        # Total GPU usage: num_gpus (trainer proc) + num_gpus_per_worker (workers)
        "num_gpus": 1,
        # Size of batches (in timesteps) placed in the learner queue
        "rollout_fragment_length": 16,
        # Total number of timesteps to train per batch
        "train_batch_size": 512,
        "lr": 0.0001,
        "env": RecallEnv.__name__,
    },
}
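For completeness, the runs are launched roughly like this. This is a minimal sketch: the register_env lambda, the stop criterion, and the cfg name are mine rather than copied from my actual script.

import ray
from ray import tune
from ray.tune.registry import register_env

# Register the custom environment under its class name (matches cfg["ray"]["env"]).
register_env(RecallEnv.__name__, lambda env_config: RecallEnv(env_config))

# Toggled between True and False for the two runs compared in the charts above.
ray.init(local_mode=False)

tune.run(
    "IMPALA",                             # i.e. ImpalaTrainer
    config=cfg["ray"],                    # the "ray" sub-dict shown above
    stop={"timesteps_total": 1_000_000},  # placeholder stop criterion
)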

I use a custom TorchModelV2 that uses state, so I'm wondering whether this has to do with computing gradients from past states.

I plotted the computation graph and reduced the complexity of my backward pass. Replacing

out = torch.cat([gnn_out[b, node_idx[b]] for b in range(batch)])

with

gnn_out[torch.arange(B, device=flat.device), node_idx.squeeze()]

cuts the gradient time by a factor of 10. However, running with ray.init(local_mode=True) is still much faster:
[chart: local_mode=True vs. remote workers after the indexing change]
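To illustrate the indexing change above, here is a standalone comparison of the two approaches; the shapes are made up for the example, and gnn_out, node_idx, and B just stand in for my actual tensors:

import torch

B, N, F = 32, 10, 16                    # batch size, nodes per graph, feature size
gnn_out = torch.randn(B, N, F, requires_grad=True)
node_idx = torch.randint(0, N, (B, 1))  # one selected node per batch element

# Loop version: B small slices plus a cat, each adding its own node to the autograd graph.
slow = torch.cat([gnn_out[b, node_idx[b]] for b in range(B)])

# Vectorized version: a single advanced-indexing op, one node in the graph.
fast = gnn_out[torch.arange(B, device=gnn_out.device), node_idx.squeeze(-1)]

assert torch.allclose(slow, fast)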

Logit computation graph for posterity:

[image: logit computation graph]

To anyone else who runs into this issue: I seem to have solved the problem by decreasing the number of workers from 10. I have ample memory free, so it isn't a disk-swapping issue.
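If you want to find the sweet spot rather than guess, the worker count can be swept the same way I already grid-search the models (the values here are arbitrary):

from ray.tune import grid_search

cfg["ray"]["num_workers"] = grid_search([2, 4, 8, 16])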


The issue is not the gradients or the computation graph, but the loss computation. I've narrowed the problem down to the loss call in torch_policy.py, line 434:

loss_out = force_list(
    self._loss(self, self.model, self.dist_class, train_batch))

This occurs both with and without vtrace.
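For reference, here is a minimal sketch of how that call can be timed in place (it assumes the loss runs on a CUDA device; the synchronize calls keep async kernel launches from hiding the cost):

import time
import torch

torch.cuda.synchronize()   # don't bill pending GPU work to the loss
start = time.perf_counter()
loss_out = force_list(
    self._loss(self, self.model, self.dist_class, train_batch))
torch.cuda.synchronize()   # wait for the loss's own kernels to finish
print("loss computation took %.3fs" % (time.perf_counter() - start))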

Looking at A3CTorchPolicy, the slowdown occurs in model.from_batch(train_batch), where it calls model.__call__ with the sample batch. I will investigate further.
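Roughly, from memory of RLlib's ModelV2 (treat this as a paraphrase, not the exact source), from_batch just repackages the SampleBatch and forwards it through the model:

# Simplified paraphrase of ModelV2.from_batch, not the actual RLlib code.
def from_batch(self, train_batch, is_training=True):
    input_dict = train_batch.copy()
    input_dict["is_training"] = is_training
    # Collect recurrent state inputs: state_in_0, state_in_1, ...
    states = []
    i = 0
    while "state_in_{}".format(i) in input_dict:
        states.append(input_dict["state_in_{}".format(i)])
        i += 1
    # This is the model.__call__ that shows up as the hotspot.
    return self.__call__(input_dict, states, train_batch.get("seq_lens"))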

@sven1977 please take a look