Segfault in TorchTrainer for num_workers > 0 in DataLoader

My PyTorch Ray training program crashes with a segfault in the DataLoader workers after 2–3 minutes. It works when `num_workers = 0`. I have stripped away all other code, and now there is a really bare-bones TorchTrainer with a dummy IterableDataset. It still crashes when `num_workers > 0` in the DataLoader, and works otherwise with `num_workers = 0`. I am using Ray 2.3.1 and torch `1.13.1+cu117`. Has something changed? I have attached the code for reproducing the issue. I am using an EC2 VM-based setup. If I move just the dummy DataLoader loop code into a standalone script and submit it as a Ray job, it still works. Ray 2.3.0 seems to work fine.
Code for reproduction:

import argparse
import datetime
import errno
import json
import os
import random
import time
from itertools import chain, cycle, islice, takewhile, repeat
from os import listdir
from os.path import isfile, join
from pathlib import Path

import boto3
import numpy as np
import ray
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from ray import train
from ray.air import session, RunConfig, Checkpoint, CheckpointConfig
from ray.air.config import ScalingConfig, DatasetConfig
from ray.air.integrations.mlflow import MLflowLoggerCallback
from ray.train.torch import TorchCheckpoint
from ray.train.torch import TorchTrainer
from ray.tune.logger import TBXLoggerCallback
from ray.tune.syncer import SyncConfig
from torch.utils.data import DataLoader, IterableDataset
#import torch.backends.cudnn as cudnn
#import torch.distributed as dist
#import torchmetrics


def worker_init_fn(worker_id):
    """Per-worker initialization hook passed to the DataLoader.

    Does nothing in this stripped-down repro. Per-worker setup, such as
    seeding, would go here.

    Args:
        worker_id: Index of the DataLoader worker process being started.

    Returns:
        None.
    """
    return None
class MyIterableDataset(IterableDataset):
    """Dummy iterable dataset that yields the list ``[1, 2, 3]`` forever."""

    def __init__(self):
        super().__init__()

    def __iter__(self):
        # The stream is endless (itertools.repeat with no count). With
        # num_workers > 0, each DataLoader worker holds its own replica of the
        # dataset and produces its own endless stream; no sharding is done.
        return repeat([1, 2, 3])

def train_loop_per_worker(config: dict):
    """Per-worker training loop executed by TorchTrainer.

    Stripped-down repro: builds a multi-process DataLoader over
    MyIterableDataset and iterates it. No model or optimizer is involved.

    Args:
        config: The ``train_loop_config`` passed by TorchTrainer; unused here.
    """
    # The real data loading (load_data()) was removed for this repro.
    train_dataset = MyIterableDataset()

    # NOTE(review): multiprocessing_context="fork" forks the Ray training
    # worker process, which presumably already runs background threads (and
    # CUDA when use_gpu=True). Forking a multi-threaded process is a known
    # source of crashes, so compare against "forkserver"/"spawn" when triaging.
    train_loader = DataLoader(
        train_dataset,
        num_workers=2,
        batch_size=None,  # disables automatic batching; items are yielded one by one
        persistent_workers=True,
        prefetch_factor=2,
        worker_init_fn=worker_init_fn,
        multiprocessing_context="fork",
    )

    # MyIterableDataset never ends, so this loop runs until the process
    # crashes or is stopped.
    for query_image, catalog_image, text in train_loader:
        pass  # TODO: training step goes here.

if __name__ == "__main__":
    config = {}
    # NOTE: refer to "Distributed Deep Learning with Ray Train User Guide" and
    # the "Ray Train API" reference for Ray 2.3.1.
    trainer = TorchTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config=config,
        scaling_config=ScalingConfig(
            num_workers=1,
            use_gpu=True,
            resources_per_worker={"CPU": 2, "GPU": 1},
        ),
    )

    # NOTE: interpret training results
    result = trainer.fit()

@matthewdeng is following up on this one on GitHub.