How can I synchronize metrics in the `ray.train` valid_loop?

When evaluating the model in valid_loop, I get different metrics (such as loss, accuracy, and so on) from the different ranks in the Ray cluster. Is there a simple solution, similar to PyTorch's all_reduce, so that I can get the average loss on rank 0?
I tried @ray.remote, but I ran into two problems:

  1. I don't know how to wait until all the ranks have finished evaluating before doing the all_reduce;
  2. I don't know where to reset the remote object (used to calculate the metrics) before each epoch.

Do I understand correctly that you would like to simply average the results from all ranks? Unfortunately, we do not have a built-in way to do that right now, but it is on our backlog.

However, you should be able to use PyTorch communication primitives (such as torch.distributed.all_reduce) directly in the training function. That function is already spawned in a separate Ray process for each worker, so you do not need to create any extra ray.remote workers.
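A minimal sketch of that approach, assuming the process group has already been set up by Ray Train's TorchTrainer (the helper name global_average is hypothetical, not a Ray or PyTorch API). Each worker all-reduces its local (sum, count) pair, so every rank, including rank 0, ends up with the same sample-weighted average. Because all_reduce is a collective call, it blocks until every rank has reached it, which also covers waiting for all ranks to finish evaluation; and because the accumulators are ordinary local variables created in each validation pass, there is no remote object to reset between epochs.

```python
import torch
import torch.distributed as dist


def global_average(local_sum: float, local_count: int, device=None) -> float:
    """Sample-weighted average of a metric across all workers (hypothetical helper).

    Every rank must call this the same number of times; otherwise the collective
    all_reduce will hang waiting for the ranks that never arrive.
    """
    # Pack (sum, count) into one tensor so a single collective call is enough.
    stats = torch.tensor([local_sum, float(local_count)], dtype=torch.float64, device=device)
    if dist.is_available() and dist.is_initialized():
        # Element-wise sum over all ranks; every rank receives the reduced tensor.
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    total_sum, total_count = stats.tolist()
    # Avoid division by zero if no rank saw any samples.
    return total_sum / total_count if total_count > 0 else 0.0
```

Inside the training function this could be called after the validation loop as `global_average(loss_sum, num_samples, device=train.torch.get_device())`, printing the result only when `session.get_world_rank() == 0`. With the NCCL backend the tensor has to live on the worker's GPU, which is why the device is passed in; without an initialized process group the helper simply returns the local average.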

Your suggested solution works. Thank you very much, and thanks for your wonderful product!

@Yard1 Thanks for your help. I have another question: all_reduce works well in TorchTrainer, but it fails when I use ray.tune. The error log:

dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
  File "/home/zetlin/miniconda3/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1284, in all_reduce
    default_pg = _get_default_group()
  File "/home/zetlin/miniconda3/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 410, in _get_default_group
    raise RuntimeError(
RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.

How can I solve this problem?

That’s because Train automatically starts the PyTorch distributed backend, while Tune doesn’t. If you want to use PyTorch distributed methods, you should use them with Train.

What is it that you are trying to accomplish here? Are you trying to communicate between Tune trials?

My situation is this: one trial may need 2 GPUs, and in the valid_loop I need to calculate the average of the metrics from the 2 GPUs. So I want to know: do you have a solution for all_reduce in Tune?

Was this solved in "How to make all use of the GPU memory in Ray.tune" (#3 by Yard1)?

I will try it; yesterday I was stuck on another problem.

Sometimes a trial doesn't release its GPU memory. I found ray.tune.utils.wait_for_gpu in the documentation, but this function doesn't release the GPU memory itself, right? Do you have any suggestions on how to release GPU memory in the code? I tried torch.cuda.empty_cache(), but it didn't solve the problem.

Can you post the code for the training function you are running in Tune?

1. Training Code

def train_func_per_worker(args):
    os.environ['TORCH_DISTRIBUTED_DEBUG'] = "DETAIL" 
    global best_ade
    args = dict_to_namespace(args)
    log_writer = LogWriter("./")
    worker_info = f"Rank{session.get_world_rank()}" if not args.use_ray_tune else f"Trial{session.get_trial_name()}"

    # Check if the GPU memory has been released
    # assert all([ray.tune.utils.wait_for_gpu(gpu_id=gpu_id,target_util=0.1) for gpu_id in ray.get_gpu_ids()]) == True

    # Data loading code
    train_dataloader, valid_dataloader, test_dataloader = get_dataloader(args) # get pytorch DataLoader
    if not args.evaluate:
        train_dataloader = train.torch.prepare_data_loader(train_dataloader)
        valid_dataloader = train.torch.prepare_data_loader(valid_dataloader, move_to_device=False)
    test_dataloader = train.torch.prepare_data_loader(test_dataloader, move_to_device=False)

    # Create model
    if args.pretrained: 
        model = VectorNetModel(with_aux_loss=args.with_aux_loss)
        model.load_state_dict(ray.get(args.trained_state_dict_id))
    else:
        print(f"=> [{worker_info}]creating model VectorNetModel")
        model = VectorNetModel(with_aux_loss=args.with_aux_loss)
        if args.use_apollo_weight:
            model.load_state_dict(ray.get(args.apollo_weight_id), strict=False)

    model = train.torch.prepare_model(model)
    loss_fn = VectorLoss(alpha=args.loss_alpha, with_aux_loss = args.with_aux_loss,y_alpha=args.loss_y_alpha)
    optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=args.decay_lr_every, gamma=args.decay_lr_factor)

    # optionally resume from a checkpoint
    if args.resume_checkpoint: 
        checkpoint = ray.get(args.resume_checkpoint_id)
        args.start_epoch = checkpoint['epoch']+1
        best_ade = checkpoint['best_ade']
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        model_state_dict = checkpoint['model']
        if session.get_world_size() >= 2:
            model_state_dict = {'module.'+k : v for k, v in model_state_dict.items()}
        model.load_state_dict(model_state_dict)
        print("=> loaded checkpoint '{}' (epoch {} best_ade {})"
                .format(args.resume_checkpoint, checkpoint['epoch']+1, checkpoint['best_ade']))

    if args.evaluate: 
        start_time = time.time()
        avg_ADE, avg_FDE, avg_MR = evaluate_loop(test_dataloader,  model, loss_fn, epoch=0, args=args, 
                                        is_test_dataloader=True, log_writer=log_writer,
                                        calc_detail_metrics = args.calc_detail_metrics)
        print(f"ADE: {avg_ADE},  FDE: {avg_FDE}, MR: {avg_MR}")
        print(f"Used Time: {time.time() - start_time}s")
        return

    for epoch in range(args.start_epoch, args.epochs):
        start_time = time.time()
        train_loop(train_dataloader, model, loss_fn, optimizer,epoch, args,log_writer)
        # # session.report({}, checkpoint=TorchCheckpoint.from_model(model))
        scheduler.step()
        end_train_time = time.time()
        torch.cuda.empty_cache() 
        gc.collect()

        avg_ADE, avg_FDE, avg_MR = evaluate_loop(valid_dataloader, model, loss_fn, epoch, args, 
                                        is_test_dataloader = False, log_writer=log_writer)
        end_test_time = time.time()
        torch.cuda.empty_cache() 
        gc.collect()

        is_best = (best_ade > avg_ADE)
        best_ade = min(best_ade, avg_ADE)

        if args.ray_num_workers >= 2:
            trained_model_state_dict = model.module.state_dict()
        else:
            trained_model_state_dict = model.state_dict()
        ckt_dict = {
                'epoch': epoch,
                # 'arch': args.arch,
                'model': trained_model_state_dict,
                'best_ade': best_ade,
                'optimizer' : optimizer.state_dict(),
                'scheduler' : scheduler.state_dict()
            }

        if args.use_ray_tune: 
            # args.ray_num_workers==1 or (args.ray_num_workers >= 2 and session.get_world_rank()==0)::
            save_checkpoint(ckt_dict, is_best=is_best, filename=f'vectornet_vehicle_model_epoch{epoch}_ADE{avg_ADE:.4f}.pt')
            if is_best:
                checkpoint = Checkpoint.from_directory("./")
                session.report({"ADE": avg_ADE}, checkpoint=checkpoint)
            else:
                session.report({"ADE": avg_ADE})
            torch.cuda.empty_cache() 
            gc.collect()
            continue

        avg_ADE_test, avg_FDE_test, avg_MR_test = evaluate_loop(test_dataloader,  model, loss_fn, epoch, args, 
                                        is_test_dataloader=True, log_writer=log_writer,
                                        calc_detail_metrics = args.calc_detail_metrics)
        end_metric_time = time.time()
        torch.cuda.empty_cache() 
        gc.collect()

        # write log
        log_writer.write_train_params(epoch, start_time, end_train_time, end_test_time, end_metric_time, 
                                      avg_ADE, avg_FDE, avg_MR, scheduler)

        if session.get_world_rank() == 0:
            save_checkpoint(ckt_dict,is_best=is_best, filename=f'vectornet_vehicle_model_epoch{epoch}_ADE{avg_ADE:.4f}.pt')
        session.report( 
            {"Valid_ADE":avg_ADE, "Valid_FDE":avg_FDE, "Valid_MR":avg_MR,
            "Test_ADE":avg_ADE_test, "Test_FDE":avg_FDE_test, "Test_MR":avg_MR_test})

        filtered_args = {k:v for k,v in vars(args).items() if not isinstance(v, ray._raylet.ObjectRef)}
        log_writer.write_hparams(log_writer, filtered_args, best_ade)
    log_writer.close()
    torch.cuda.empty_cache() 
    gc.collect()
    print(f"[rank {session.get_world_rank()}] Done!") 


def train_loop(dataloader, model, loss_fn, optimizer,epoch,args,log_writer): 
    # size = len(dataloader.dataset) // session.get_world_size()  # Divide by world size
    worker_info = f"rank{session.get_world_rank()}" if not args.use_ray_tune else f"trial{session.get_trial_name()}"
    batch_time = AverageMeter('Batch_Time', ':6.3f')
    data_time = AverageMeter('Data_Load_Time', ':6.3f')
    Losses = AverageMeter('Loss', ':.4e')
    Aux_Loss = AverageMeter('Aux_Loss', ':.4e')
    progress = ProgressMeter(
        len(dataloader),
        [batch_time, data_time, Losses],
        prefix=f"Epoch{epoch + 1}({worker_info}): ")
    
    model.train()
    end = time.time()
    for batch, (X, y) in enumerate(dataloader):
        # measure data loading time
        data_time.update(time.time() - end) 
        # Compute prediction error
        pred, aux_true, aux_pred = model(X['target_obstacle_pos'],
                                         X['target_obstacle_pos_step'],
                                         X['vector_data'],
                                         X['vector_mask'],
                                         X['polyline_mask'],
                                         X['rand_mask'],
                                         X['polyline_id'])
        loss, aux_loss = loss_fn(pred, y['future_traj'], aux_pred, aux_true, true_mask=None)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        Losses.update(loss.item(), X['target_obstacle_pos'].shape[0])
        Aux_Loss.update(aux_loss.item(), X['target_obstacle_pos'].shape[0])

        if batch % args.print_freq == 0:
            progress.display(batch + 1)
            # loss, current = loss.item(), batch * len(X)

    print(f"[{worker_info}] Train Avg loss: {Losses.avg:>6e} ")

    if log_writer:
        # log the epoch averages rather than the loss of the last batch only
        log_writer.add_scalars("Loss",{"Train_Loss":Losses.avg}, epoch)
        log_writer.add_scalars("Loss",{"Train_Aux_Loss":Aux_Loss.avg}, epoch)
        log_writer.add_scalars("Used time/Batch_Time",{"Train":batch_time.avg}, epoch)
        log_writer.add_scalars("Used time/Data_Load_Time",{"Train":data_time.avg}, epoch)
    
    torch.cuda.empty_cache()
    gc.collect()


def evaluate_loop(dataloader, model, loss_fn, epoch, args, is_test_dataloader = True,log_writer=None,calc_detail_metrics = False):
    def run_test_loop(dataloader, base_progress=0):
        # print(f"evaluate loop rank:{session.get_world_rank()} dataset shape:{len(dataloader.dataset)}  batch_size:{args.batch_size}  world_size:{session.get_world_size()}")
        forecasted_trajectories, gt_trajectories,keys = {}, {}, []
        seq_id = 0

        max_n_guesses = 1 
        horizon = 30 
        with torch.no_grad():
            end = time.time()
            for i, (X, y) in enumerate(dataloader):
                # test_loss = 0 #initialization
                batch_size = len(X['target_obstacle_pos'])
                i = base_progress + i

                gt = y["future_traj"].cpu().numpy() 
                for k in input_keys:
                    if k != 'key':  # 'key' holds bytes; moving it to the GPU would trigger a warning
                        X[k] = X[k].cuda(args.device, non_blocking=True)
                y['future_traj'] = y['future_traj'].cuda(args.device, non_blocking=True)
                y['label_mask'] = y['label_mask'].cuda(args.device, non_blocking=True)
                    
                data_time.update(time.time() - end)
                pred = model(X['target_obstacle_pos'],
                             X['target_obstacle_pos_step'],
                             X['vector_data'],
                             X['vector_mask'],
                             X['polyline_mask'],
                             X['rand_mask'],
                             X['polyline_id'])
                
                dim_out = len(pred.shape)
                # pred_y = pred.unsqueeze(dim_out).view((batch_size, max_n_guesses, horizon, 2)).cumsum(axis=2).cpu().numpy()
                pred_y = pred.unsqueeze(dim_out).view((batch_size, max_n_guesses, horizon, 2)).cpu().numpy()

                # record the prediction and ground truth
                for batch_id in range(batch_size):
                    if args.with_metric_gt_traj_eq_30 and sum(~y['label_mask'][batch_id]) < 30:
                        continue 
                    forecasted_trajectories[seq_id] = [pred_y_k for pred_y_k in pred_y[batch_id]]
                    gt_trajectories[seq_id] = gt[batch_id]
                    keys.append(X['key'][batch_id].decode('ascii')) 
                    seq_id += 1
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % args.print_freq == 0:
                    progress.display(i + 1)
                
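                # calc metrics every 100 batches and on this worker's last batch (i starts from 0);
                # len(dataloader) is this worker's batch count, while len(dataloader.dataset) is the full dataset
                if i % 100 == 0 or i == len(dataloader) - 1: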
                    metric_results = get_displacement_errors_and_miss_rate( 
                        forecasted_trajectories,
                        gt_trajectories,
                        max_n_guesses,
                        horizon,
                        args.miss_threshold
                    )
                    ADE.update(metric_results["minADE"], seq_id)
                    FDE.update(metric_results["minFDE"], seq_id)
                    MR.update(metric_results["MR"], seq_id)
                    metric_results = get_xy_displacement_errors( 
                        forecasted_trajectories,
                        gt_trajectories,
                        max_n_guesses,
                        horizon,
                    )
                    ADE_x.update(metric_results['ADE_x'], seq_id)
                    ADE_y.update(metric_results['ADE_y'], seq_id)
                    if calc_detail_metrics:
                        get_detail_evaluate(forecasted_trajectories,gt_trajectories,keys, first_call=(i==0))
                    forecasted_trajectories, gt_trajectories,keys = {}, {}, []
                    seq_id = 0

    worker_info = f"rank{session.get_world_rank()}" if not args.use_ray_tune else f"trial{session.get_trial_name()}"
    batch_time = AverageMeter('Batch_Time', ':6.3f', Summary.NONE)
    data_time = AverageMeter('Data_Load_Time', ':6.3f', Summary.NONE)
    ADE = AverageMeter('minADE', ':6.2f', Summary.AVERAGE)
    FDE = AverageMeter('minFDE', ':6.2f', Summary.AVERAGE)
    MR  = AverageMeter('MR', ':2.2f', Summary.AVERAGE)
    ADE_y = AverageMeter('ADE_y', ':6.2f', Summary.AVERAGE)
    ADE_x = AverageMeter('ADE_x', ':6.2f', Summary.AVERAGE)
    prefix = f'Test({worker_info}): ' if is_test_dataloader else f'Valid({worker_info}): '
    progress = ProgressMeter(
        len(dataloader),
        [batch_time, data_time],
        prefix=prefix)

    # switch to evaluate mode 
    model.eval()
    run_test_loop(dataloader)   
    
    if session.get_world_size() >= 2:
        batch_time.all_reduce(train.torch.get_device())
        data_time.all_reduce(train.torch.get_device())
        ADE.all_reduce(train.torch.get_device())
        FDE.all_reduce(train.torch.get_device())
        MR.all_reduce(train.torch.get_device())
        ADE_x.all_reduce(train.torch.get_device())
        ADE_y.all_reduce(train.torch.get_device())
    print(f"[{worker_info}] ADE:{ADE.sum:.4f}/{ADE.count}  FDE:{FDE.sum:.4f}/{FDE.count}  MR:{MR.sum:.2f}/{MR.count}")

    prefix = "Test" if is_test_dataloader else "Valid"

    if log_writer:
        log_writer.add_scalars("Metrics/ADE",{f"{prefix}":ADE.avg},epoch)
        log_writer.add_scalars("Metrics/FDE",{f"{prefix}":FDE.avg},epoch)
        log_writer.add_scalars("Metrics/MR",{f"{prefix}":MR.avg},epoch)
        log_writer.add_scalars("Used time/Batch_Time",{f"{prefix}":batch_time.avg}, epoch)
        log_writer.add_scalars("Used time/Data_Load_Time",{f"{prefix}":data_time.avg}, epoch)
        log_writer.add_scalars("Metrics/ADE_x",{f"{prefix}":ADE_x.avg},epoch)
        log_writer.add_scalars("Metrics/ADE_y",{f"{prefix}":ADE_y.avg},epoch)
        
    torch.cuda.empty_cache()
    gc.collect()
        
    return ADE.avg, FDE.avg, MR.avg

Remarks:
For train.torch.prepare_data_loader on the validation dataloader, I set move_to_device=False because I wanted to silence the annoying warning "INFO train_loop_utils.py:617 -- Data type <class 'bytes'> doesn't support being moved to device." I also tested move_to_device=True, and the GPU out-of-memory error still occurs.
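If the only reason for move_to_device=False is the bytes-typed 'key' field, the manual copy loop in run_test_loop can be written as one generic step. A minimal sketch, assuming each batch is a dict whose values are either tensors or non-tensor metadata such as the list of bytes keys; move_tensors_to_device is a hypothetical helper, not part of Ray Train.

```python
from typing import Any, Dict

import torch


def move_tensors_to_device(batch: Dict[str, Any], device, non_blocking: bool = True) -> Dict[str, Any]:
    """Return a shallow copy of `batch` with every tensor value moved to `device`.

    Non-tensor values (for example a list of bytes keys) are passed through
    unchanged, so nothing tries to move them to the GPU.
    """
    return {
        k: v.to(device, non_blocking=non_blocking) if isinstance(v, torch.Tensor) else v
        for k, v in batch.items()
    }
```

In the loop this would look like `X = move_tensors_to_device(X, train.torch.get_device())`. It only replaces the per-key `.cuda(...)` calls; it does not change how much GPU memory the model itself uses.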

2. Tuner status:

+--------------------------+------------+---------------------+------------------------+------------------------+------------------------+--------+------------------+----------+--------------+---------------------+
| Trial name               | status     | loc                 |   train_loop_config/lo |   train_loop_config/lo |   train_loop_config/lr |   iter |   total time (s) |      ADE |   _timestamp |   _time_this_iter_s |
|                          |            |                     |               ss_alpha |             ss_y_alpha |                        |        |                  |          |              |                     |
|--------------------------+------------+---------------------+------------------------+------------------------+------------------------+--------+------------------+----------+--------------+---------------------|
| TorchTrainer_7ef95_00000 | TERMINATED | 10.20.84.14:970700  |              0.261408  |               1.41171  |            1.24188e-06 |     16 |         2064.36  |  3.9124  |   1670391475 |             131.881 |
| TorchTrainer_7ef95_00001 | TERMINATED | 10.20.84.14:970899  |              0.735604  |               3.98855  |            5.80859e-06 |     16 |         2077.1   |  2.5384  |   1670391496 |             125.745 |
| TorchTrainer_7ef95_00002 | TERMINATED | 10.20.84.14:997902  |              0.0320946 |               4.9361   |            7.16022e-05 |     12 |         1528.58  |  2.4324  |   1670393013 |             125.455 |
| TorchTrainer_7ef95_00003 | TERMINATED | 10.20.84.14:1001567 |              0.878715  |               1.98428  |            8.33248e-05 |     16 |         1991.18  |  1.8629  |   1670393491 |             121.188 |
| TorchTrainer_7ef95_00004 | TERMINATED | 10.20.84.14:1020884 |              0.559817  |               4.68423  |            1.37275e-06 |      3 |          389.758 | 13.4697  |   1670393405 |             124.802 |
| TorchTrainer_7ef95_00008 | TERMINATED | 10.20.84.14:1040897 |              0.905132  |               0.576755 |            1.93009e-05 |      3 |          376.791 |  2.52287 |   1670393870 |             119.155 |
| TorchTrainer_7ef95_00005 | ERROR      | 10.20.84.14:1028279 |              0.626533  |               2.23601  |            1.26431e-05 |        |                  |          |              |                     |
| TorchTrainer_7ef95_00006 | ERROR      | 10.20.84.14:1032359 |              0.384273  |               0.736505 |            2.80126e-05 |        |                  |          |              |                     |
| TorchTrainer_7ef95_00007 | ERROR      | 10.20.84.14:1036761 |              0.0195909 |               0.113402 |            0.000780299 |        |                  |          |              |                     |
| TorchTrainer_7ef95_00009 | ERROR      | 10.20.84.14:1041261 |              0.908589  |               4.07545  |            0.00308423  |        |                  |          |              |                     |
| TorchTrainer_7ef95_00010 | ERROR      | 10.20.84.14:1049156 |              0.27278   |               2.77148  |            6.73293e-06 |        |                  |          |              |                     |
| TorchTrainer_7ef95_00011 | ERROR      | 10.20.84.14:1053239 |              0.894232  |               1.35123  |            1.7537e-05  |        |                  |          |              |                     |
| TorchTrainer_7ef95_00012 | ERROR      | 10.20.84.14:1057358 |              0.674399  |               4.52086  |            6.76853e-06 |        |                  |          |              |                     |
| TorchTrainer_7ef95_00013 | ERROR      | 10.20.84.14:1061615 |              0.0822584 |               2.55529  |            1.46401e-05 |        |                  |          |              |                     |
| TorchTrainer_7ef95_00014 | ERROR      | 10.20.84.14:1065840 |              0.246589  |               2.65665  |            8.83962e-06 |        |                  |          |              |                     |
| TorchTrainer_7ef95_00015 | ERROR      | 10.20.84.14:1069923 |              0.569313  |               2.88062  |            0.00231894  |        |                  |          |              |                     |
+--------------------------+------------+---------------------+------------------------+------------------------+------------------------+--------+------------------+----------+--------------+---------------------+

3. Error:

Failure # 1 (occurred at 2022-12-07_14-10-40)
ray::_Inner.train() (pid=1028279, ip=10.20.84.14, repr=TorchTrainer)
  File "/home/zetlin/miniconda3/lib/python3.8/site-packages/ray/tune/trainable/trainable.py", line 355, in train
    raise skipped from exception_cause(skipped)
  File "/home/zetlin/miniconda3/lib/python3.8/site-packages/ray/train/_internal/utils.py", line 54, in check_for_failure
    ray.get(object_ref)
ray.exceptions.RayTaskError(RuntimeError): ray::RayTrainWorker._RayTrainWorker__execute() (pid=1028423, ip=10.20.84.14, repr=<ray.train._internal.worker_group.RayTrainWorker object at 0x7fd644337340>)
  File "/home/zetlin/miniconda3/lib/python3.8/site-packages/ray/train/_internal/worker_group.py", line 31, in __execute
    raise skipped from exception_cause(skipped)
  File "/home/zetlin/miniconda3/lib/python3.8/site-packages/ray/train/_internal/utils.py", line 129, in discard_return_wrapper
    train_func(*args, **kwargs)
  File "vectornet_train_ray.py", line 402, in train_func_per_worker
    train_loop(train_dataloader, model, loss_fn, optimizer,epoch, args,log_writer)
  File "vectornet_train_ray.py", line 173, in train_loop
    pred, aux_true, aux_pred = model(X['target_obstacle_pos'],
  File "/home/zetlin/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zetlin/Prediction_ML/prediction/model/vectornet_no_clf.py", line 101, in forward
    enc_out, aux_true, aux_pred = self.vector_net_encoder(data, v_mask, p_id, polyline_mask,rand_mask)
  File "/home/zetlin/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zetlin/Prediction_ML/prediction/model/vectornet.py", line 148, in forward
    data0 = self.sub_graph(input1, mask0).view(batch_size,polyline_num,256)#TODO reshape
  File "/home/zetlin/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zetlin/Prediction_ML/prediction/model/vectornet.py", line 61, in forward
    x = data.masked_fill(mask, -float("inf"))
RuntimeError: CUDA out of memory. Tried to allocate 1.38 GiB (GPU 0; 11.77 GiB total capacity; 5.64 GiB already allocated; 577.31 MiB free; 6.33 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

4. Output of nvidia-smi after the tuner finished:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:17:00.0 Off |                  N/A |
| 30%   49C    P2   102W / 350W |   2613MiB / 12053MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:CA:00.0 Off |                  N/A |
| 30%   31C    P8    15W / 350W |      5MiB / 12053MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      2403      G   /usr/lib/xorg/Xorg                  4MiB |
|    1   N/A  N/A      2403      G   /usr/lib/xorg/Xorg                  4MiB |
+-----------------------------------------------------------------------------+

PyTorch all_reduce works well in TorchTrainer, thanks a lot.


Hi @Yard1, the problem of GPU memory not being released in some Tune trials still exists. Please help me check my code if you have time. Thanks a lot!

Can you set Tuner(tune_config=tune.TuneConfig(reuse_actors=False))?

I tested on a small dataset, and your solution works well for the GPU memory release problem. You have helped me solve many problems over these past few days; words can't express how grateful I am!
