Error with "column_names" when using Ray with TRL's sft_trainer

1. Severity of the issue: (select one)
Medium: Significantly affects my productivity but can find a workaround.

2. Environment:

  • Ray version: 2.50.0
  • Python version: 3.11.10
  • TRL: 0.25.1
  • transformers: 4.57.1

3. What happened vs. what you expected:

  • Expected: `SFTTrainer` should consume the iterable produced by Ray Train's dataset shards without errors.
  • Actual: In practice, we found that the iterable produced by `ray.train.get_dataset_shard()` and `iter_torch_batches()` is not compatible with TRL's `SFTTrainer`. Specifically, `SFTTrainer` validates the dataset's `column_names`, but that attribute is missing from the Ray iterable, causing the program to crash.
```python
def train_func(self):
    ……
    train_ds = raytrain.get_dataset_shard("train")
    eval_ds = raytrain.get_dataset_shard("validation")
    train_ds_iterable = train_ds.iter_torch_batches(
        batch_size=self.trainParameter.batch_size,
        local_shuffle_buffer_size=raytrain.get_context().get_world_size()
        * self.trainParameter.batch_size,
    )
    eval_ds_iterable = eval_ds.iter_torch_batches(batch_size=2)
    ……
    trainer = SFTTrainer(
        self.model,
        train_dataset=train_ds_iterable,
        eval_dataset=eval_ds_iterable,
    )
    ……
```

with error:

```
object has no attribute 'column_names'
```

and the relevant code in TRL:

```python
def get_dataset_column_names(dataset: Dataset | IterableDataset) -> list[str]:
    return list(next(iter(dataset)).keys()) if dataset.column_names is None else dataset.column_names
```

The error occurs because the iterable returned by Ray's `iter_torch_batches()` does not have a `column_names` attribute, which TRL's `SFTTrainer` expects. This incompatibility is known: Ray's `iter_torch_batches()` yields batches as plain dictionaries, not as HuggingFace `Dataset` or `IterableDataset` objects, so attributes like `column_names` are missing, and TRL's validation logic raises an `AttributeError`. Passing Ray Data iterables directly to TRL trainers is not supported by either library as of the versions listed above.
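The failure can be reproduced without Ray or TRL installed. The generator below is a hypothetical stand-in for the object returned by `iter_torch_batches()`, and the helper mirrors TRL's `get_dataset_column_names` shown above:

```python
# Stand-in for the iterable returned by Ray's iter_torch_batches():
# it yields plain dicts and has no `column_names` attribute.
def fake_batch_iterable():
    yield {"input_ids": [1, 2, 3], "labels": [1, 2, 3]}


def get_dataset_column_names(dataset):
    # Mirror of TRL's helper: accessing dataset.column_names raises
    # AttributeError when `dataset` is a plain generator/iterable.
    return (
        list(next(iter(dataset)).keys())
        if dataset.column_names is None
        else dataset.column_names
    )


try:
    get_dataset_column_names(fake_batch_iterable())
except AttributeError as e:
    print(e)  # 'generator' object has no attribute 'column_names'
```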

A common workaround is to convert your data into a HuggingFace `Dataset` before passing it to `SFTTrainer`, rather than handing Ray Data iterables directly to TRL trainers. Alternatively, you can preprocess your data into a HuggingFace `Dataset` up front and use that as the trainer input, bypassing Ray's `iter_torch_batches()` entirely.
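To illustrate the attribute gap (not a full fix: `SFTTrainer` may also rely on other `Dataset` behaviors that a plain iterable lacks, so converting to a real HuggingFace `Dataset` remains the more robust path), a thin wrapper can expose the `column_names` attribute that TRL's check looks for. This is a hypothetical sketch; the `ColumnNamedIterable` name is my own:

```python
class ColumnNamedIterable:
    """Wrap a batch iterable and expose `column_names` like an HF IterableDataset.

    Hypothetical sketch: this only satisfies TRL's get_dataset_column_names
    check; SFTTrainer may depend on other Dataset APIs this wrapper lacks.
    """

    def __init__(self, batch_iter_factory):
        # Store a factory (e.g. lambda: ds.iter_torch_batches(...)) so the
        # stream can be re-created each time it is iterated.
        self._factory = batch_iter_factory
        # None tells TRL to infer the columns from the first yielded batch,
        # matching IterableDataset's behavior when features are unknown.
        self.column_names = None

    def __iter__(self):
        return iter(self._factory())


def get_dataset_column_names(dataset):
    # Mirror of TRL's helper shown earlier in this issue.
    return (
        list(next(iter(dataset)).keys())
        if dataset.column_names is None
        else dataset.column_names
    )


wrapped = ColumnNamedIterable(
    lambda: iter([{"input_ids": [1, 2], "labels": [1, 2]}])
)
print(get_dataset_column_names(wrapped))  # ['input_ids', 'labels']
```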
