1. Severity of the issue: (select one)
Medium: Significantly affects my productivity but can find a workaround.
2. Environment:
- Ray version: 2.50.0
- Python version: 3.11.10
- Trl: 0.25.1
- transformers: 4.57.1
3. What happened vs. what you expected:
- Expected: The iterables produced from a Ray Train dataset shard can be passed to TRL's `SFTTrainer` as `train_dataset` and `eval_dataset`. I would like an effective solution to this issue.
- Actual: In practice, we found that the iterable produced by `ray.train.get_dataset_shard()` and `iter_torch_batches()` is not fully compatible with TRL's `SFTTrainer`. Specifically, `SFTTrainer` reads the dataset's `column_names` attribute, but that attribute is missing from the iterable, causing the program to crash.
```
def train_func(self):
    ……
    train_ds = raytrain.get_dataset_shard("train")
    eval_ds = raytrain.get_dataset_shard("validation")
    train_ds_iterable = train_ds.iter_torch_batches(
        batch_size=self.trainParameter.batch_size,
        local_shuffle_buffer_size=raytrain.get_context().get_world_size() * self.trainParameter.batch_size,
    )
    eval_ds_iterable = eval_ds.iter_torch_batches(batch_size=2)
    ……
    trainer = SFTTrainer(self.model, train_dataset=train_ds_iterable, eval_dataset=eval_ds_iterable)
    ……
```
This fails with the error:
`object has no attribute 'column_names'`
The relevant code in TRL is:
```python
def get_dataset_column_names(dataset: Dataset | IterableDataset) -> list[str]:
    return list(next(iter(dataset)).keys()) if dataset.column_names is None else dataset.column_names
```
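
One workaround I am considering is a thin adapter that gives the Ray iterable a `column_names` attribute, so the helper quoted above can take its fallback path. The sketch below is only an illustration: `ColumnNamesIterable` is a hypothetical name, not part of Ray or TRL, the helper is copied from the snippet above with its type hints dropped, and plain dict batches stand in for the output of `iter_torch_batches()`. It only addresses the `column_names` lookup; I have not verified that `SFTTrainer` accepts such an object in its other code paths.

```python
from typing import Any, Callable, Iterable, Iterator, Optional


class ColumnNamesIterable:
    """Hypothetical adapter (not a Ray or TRL class).

    Wraps a factory that returns a fresh iterable of dict batches and
    exposes a `column_names` attribute next to it.
    """

    def __init__(
        self,
        make_batches: Callable[[], Iterable[dict[str, Any]]],
        column_names: Optional[list[str]] = None,
    ) -> None:
        self._make_batches = make_batches
        # None makes the quoted TRL helper peek at the first batch instead.
        self.column_names = column_names

    def __iter__(self) -> Iterator[dict[str, Any]]:
        # Call the factory on every pass so the data can be iterated again.
        return iter(self._make_batches())


def get_dataset_column_names(dataset) -> list[str]:
    # Same logic as the TRL snippet quoted above; type hints dropped so the
    # example needs no `datasets` import.
    return list(next(iter(dataset)).keys()) if dataset.column_names is None else dataset.column_names


# Stand-in for Ray batches (assumption: batches are dicts keyed by column).
fake_batches = [{"input_ids": [1, 2], "labels": [1, 2]}]
wrapped = ColumnNamesIterable(lambda: fake_batches)

print(get_dataset_column_names(wrapped))  # ['input_ids', 'labels']
```

In the training function this would presumably be used as `ColumnNamesIterable(lambda: train_ds.iter_torch_batches(batch_size=...))`, but that part is untested.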