How severe does this issue affect your experience of using Ray?
- Low: It annoys or frustrates me for a moment.
I have a simple torch-xla-based data-parallel trainer. It works with Ray 2.3.1 but not with Ray 2.4.0. With Ray 2.4.0, this process does not quit: `python3 -m torch_neuronx.distributed._xrt_run_server --port 48805 --pid_to_track 52718`, and subsequent runs of the program fail with a "neuron device busy" error. What changed in Ray 2.4.0, and how do I address this issue in the latest release? If I manually kill this process, I can rerun the code successfully. I have attached the code snippet here:
import os
import time
import uuid
from dataclasses import dataclass, field

import ray
from ray.air import session
from ray.train._internal.utils import get_address_and_port
from ray.train.backend import Backend, BackendConfig
from ray.train.data_parallel_trainer import DataParallelTrainer
# refer to https://docs.ray.io/en/latest/train/api/doc/ray.train.data_parallel_trainer.DataParallelTrainer.html#ray.train.data_parallel_trainer.DataParallelTrainer
# https://docs.ray.io/en/latest/train/api/doc/ray.train.trainer.BaseTrainer.html#ray.train.trainer.BaseTrainer
def train_loop_per_worker():
# Report intermediate results for callbacks or logging and
# checkpoint data.
# session.report(...)
dataset_shard_for_this_worker = session.get_dataset_shard("train")
os.environ["LOCAL_RANK"] = str(session.get_local_rank())
os.environ["RANK"] = str(session.get_world_rank())
os.environ["LOCAL_WORLD_SIZE"] = str(session.get_local_world_size())
os.environ["WORLD_SIZE"] = str(session.get_world_size())
# os.environ["NODE_RANK"] = str(session.get_node_rank())
os.environ["GROUP_RANK"] = str(session.get_node_rank())
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend
import torch_xla.distributed.xla_multiprocessing as xmp
torch.distributed.init_process_group('xla')
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.net1 = nn.Linear(10, 10)
self.relu = nn.ReLU()
self.net2 = nn.Linear(10, 5)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def cleanup():
torch.distributed.destroy_process_group()
def train_fn():
device = xm.xla_device()
rank = xm.get_ordinal()
# Create the model and move to device
model = Model().to(device)
ddp_model = DDP(model, gradient_as_bucket_view=True)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
num_iterations = 100
for step in range(num_iterations):
optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10).to(device))
labels = torch.randn(20, 5).to(device)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()
xm.mark_step()
if rank == 0:
print(f"Loss after step {step}: {loss.cpu()}")
time.sleep(5)
train_fn()
print("XLA ", xm.get_ordinal())
print("XLA ", xm.xrt_world_size())
print(os.environ)
cleanup()
# refer to https://docs.ray.io/en/master/train/api/doc/ray.train.backend.Backend.html?highlight=on_training_start
# https://docs.ray.io/en/master/_modules/ray/train/torch/config.html
class MyBackend(Backend):
def on_start(self, worker_group, backend_config):
def set_env_var(env_var_value, addr, port):
import os
os.environ["TORCHELASTIC_RUN_ID"] = env_var_value
os.environ["MASTER_ADDR"] = addr
os.environ["MASTER_PORT"] = str(port)
master_addr, master_port = worker_group.execute_single(
0, get_address_and_port
)
worker_group.execute(set_env_var, backend_config.env_var, addr=master_addr, port=master_port)
@dataclass
class MyBackendConfig(BackendConfig):
env_var: str = str(uuid.uuid4())
@property
def backend_cls(self):
return MyBackend
class MyTrainer(DataParallelTrainer):
def __init__(self, train_loop_per_worker, my_backend_config:
MyBackendConfig, **kwargs):
super().__init__(
train_loop_per_worker,
backend_config=my_backend_config, **kwargs)
train_dataset = ray.data.from_items([1, 2, 3])
print(train_dataset.count())
assert train_dataset.count() == 3
trainer = MyTrainer(
train_loop_per_worker,
MyBackendConfig(),
scaling_config=ray.air.config.ScalingConfig(num_workers=2),
datasets={"train": train_dataset},
)
result = trainer.fit()
ray.shutdown()