Neuron device busy with ray 2.4.0 after the program quits successfully

How severely does this issue affect your experience of using Ray?

  • Low: It annoys or frustrates me for a moment.

I have a simple torch_xla based data-parallel trainer. It works with ray 2.3.1 but not with ray 2.4.0. With ray 2.4.0, the following process does not exit after the program quits successfully: python3 -m torch_neuronx.distributed._xrt_run_server --port 48805 --pid_to_track 52718, and subsequent runs of the program then fail with a Neuron device busy error. What changed in ray 2.4.0, and how do I address this in the latest release? If I manually kill this process, I can rerun the code successfully (a sketch of that manual workaround is included after the snippet). I have attached the code snippet here:

import ray
from ray.air import session
from ray.train.data_parallel_trainer import DataParallelTrainer
from dataclasses import dataclass
from ray.train.backend import Backend, BackendConfig
from ray.train._internal.utils import get_address_and_port
import os
import uuid
import time


# refer to https://docs.ray.io/en/latest/train/api/doc/ray.train.data_parallel_trainer.DataParallelTrainer.html#ray.train.data_parallel_trainer.DataParallelTrainer
# https://docs.ray.io/en/latest/train/api/doc/ray.train.trainer.BaseTrainer.html#ray.train.trainer.BaseTrainer

def train_loop_per_worker():
    # Report intermediate results for callbacks or logging and
    # checkpoint data.
    # session.report(...)

    dataset_shard_for_this_worker = session.get_dataset_shard("train")

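    # Expose the Ray Train rank/world-size info as the usual torchrun-style
    # environment variables so the torch_xla/Neuron stack can pick them up.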
    os.environ["LOCAL_RANK"] = str(session.get_local_rank())
    os.environ["RANK"] = str(session.get_world_rank())
    os.environ["LOCAL_WORLD_SIZE"] = str(session.get_local_world_size())
    os.environ["WORLD_SIZE"] = str(session.get_world_size())
    # os.environ["NODE_RANK"] = str(session.get_node_rank())
    os.environ["GROUP_RANK"] = str(session.get_node_rank())

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP
    import torch.optim as optim
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_backend
    import torch_xla.distributed.xla_multiprocessing as xmp
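    # Importing torch_xla.distributed.xla_backend above registers the 'xla'
    # backend with torch.distributed, so it can be initialized here.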
    torch.distributed.init_process_group('xla')

    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.net1 = nn.Linear(10, 10)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(10, 5)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))

    def cleanup():
        torch.distributed.destroy_process_group()

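    # Minimal training loop: push random batches through a small MLP on the
    # XLA device and step the optimizer.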
    def train_fn():
        device = xm.xla_device()
        rank = xm.get_ordinal()

        # Create the model and move to device
        model = Model().to(device)
        ddp_model = DDP(model, gradient_as_bucket_view=True)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
        num_iterations = 100
        for step in range(num_iterations):
            optimizer.zero_grad()
            outputs = ddp_model(torch.randn(20, 10).to(device))
            labels = torch.randn(20, 5).to(device)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            xm.mark_step()
            if rank == 0:
                print(f"Loss after step {step}: {loss.cpu()}")

    time.sleep(5)
    train_fn()
    print("XLA ", xm.get_ordinal())
    print("XLA ", xm.xrt_world_size())

    print(os.environ)
    cleanup()


# refer to https://docs.ray.io/en/master/train/api/doc/ray.train.backend.Backend.html?highlight=on_training_start
# https://docs.ray.io/en/master/_modules/ray/train/torch/config.html
class MyBackend(Backend):

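    # Runs on the driver when the worker group starts: grab worker 0's node
    # address and a free port, then set them (plus a shared run id) on every
    # worker as MASTER_ADDR / MASTER_PORT / TORCHELASTIC_RUN_ID.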
    def on_start(self, worker_group, backend_config):
        def set_env_var(env_var_value, addr, port):
            import os
            os.environ["TORCHELASTIC_RUN_ID"] = env_var_value
            os.environ["MASTER_ADDR"] = addr
            os.environ["MASTER_PORT"] = str(port)

        master_addr, master_port = worker_group.execute_single(
            0, get_address_and_port
        )
        worker_group.execute(set_env_var, backend_config.env_var, addr=master_addr, port=master_port)


@dataclass
class MyBackendConfig(BackendConfig):
    env_var: str = str(uuid.uuid4())

    @property
    def backend_cls(self):
        return MyBackend


class MyTrainer(DataParallelTrainer):
    def __init__(self, train_loop_per_worker,
                 my_backend_config: MyBackendConfig, **kwargs):
        super().__init__(
            train_loop_per_worker,
            backend_config=my_backend_config, **kwargs)


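# Driver code: build a tiny in-memory dataset and launch the trainer with two workers.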
train_dataset = ray.data.from_items([1, 2, 3])
print(train_dataset.count())
assert train_dataset.count() == 3
trainer = MyTrainer(
    train_loop_per_worker,
    MyBackendConfig(),
    scaling_config=ray.air.config.ScalingConfig(num_workers=2),
    datasets={"train": train_dataset},
)
result = trainer.fit()
ray.shutdown()
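
For reference, the manual workaround I mentioned above looks roughly like this. It is only a sketch: it assumes psutil is installed and simply looks for any leftover _xrt_run_server process and kills it before the next run (an equivalent pkill -f _xrt_run_server from the shell works too).

import psutil


def kill_stale_xrt_servers():
    # Find any lingering torch_neuronx XRT server processes left over from a
    # previous run and kill them so the Neuron device is released.
    for proc in psutil.process_iter(["pid", "cmdline"]):
        cmdline = " ".join(proc.info["cmdline"] or [])
        if "torch_neuronx.distributed._xrt_run_server" in cmdline:
            print(f"Killing stale XRT server pid={proc.info['pid']}: {cmdline}")
            try:
                proc.kill()
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                pass


if __name__ == "__main__":
    kill_stale_xrt_servers()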