TFRecordDataset -> Ray Dataset for TensorflowTrainer

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

I would like to train an existing TF2 pipeline using Ray Train 3.0.0.dev0. The original pipeline uses TFRecords. The model class creates the dataset by calling methods on a tf.data.TFRecordDataset object:

dataset = tf.data.TFRecordDataset(...)
dataset = dataset.map(...)
dataset = dataset.batch(self.batch_size)
dataset = dataset.prefetch(...)

When I pass this object into TensorflowTrainer(datasets={"train": dataset}), I get this error:

ValueError: At least one value in the `datasets` dict is not a `ray.data.Dataset`: {'train': <PrefetchDataset shapes: ((None, 96, 96, 3), (None,), (None,)), types: (tf.float32, tf.float32, tf.string)>, 'test': <PrefetchDataset shapes: ((None, 96, 96, 3), (None,), (None,)), types: (tf.float32, tf.float32, tf.string)>}

Do I need to convert the TFRecordDataset into something else, or is there a way to read TFRecords and pass them efficiently to the workers with Ray?



Hi @vladislav, TensorflowTrainer(datasets={…}) has to be given Ray Datasets. Ray Datasets currently doesn’t support reading the TFRecord format directly (reading it the way you mentioned will not work), so you will need to convert your data into a format that Ray Datasets can read, for example NumPy. The full list of supported formats is on the "Ray Datasets: Distributed Data Loading and Compute" docs page (Ray 3.0.0.dev0).

Hi @vladislav
Thanks for posting this question. As @jianxiao said, TFRecord is not yet natively supported by Ray Datasets.
But there are a few options here:

  1. Use our SimpleTensorFlowDatasource to convert a tf.data.Dataset instance into a Ray Dataset. Note that this relies on TensorFlow’s TFRecordDataset for the I/O, so the read won’t be parallel or distributed and the data won’t be spread across the Ray cluster. This option is best for small data.

  2. Use our ray.data.read_binary_files() API to read the TFRecord files (in parallel) as raw binary data, then use .map() to decode the records. Our Creating Datasets user guide has an example of doing this with PNG files.

  3. Write a TFRecordDatasource that does both the parallel reading and the decoding. Unfortunately, we don’t currently have a built-in datasource for TFRecords, but it should be fairly easy to create one with our custom datasource API. In particular, a few users have subclassed our FileBasedDatasource, which only requires the subclass to implement the _read_file method (or _read_stream, if you want streaming reads of the TFRecord files). Our JSONDatasource is one example of such a subclass.

I’d probably recommend (2) in the short term to try out Ray Datasets’ parallel I/O and downstream operations, and look at adding a TFRecordDatasource down the road. We’d love to see the latter happen and could definitely help shepherd the implementation!

@Clark_Zinzow FYI, I found your great answer somewhere, and we have another user asking for this 🙂


Thanks @jianxiao @xwjiang2010 for your answers and ideas!

We can’t use option #1; our dataset is large (>100 GB).

I’m trying to understand how to implement #2. For example, I’ve created two parse functions to map over the binary data, but TF doesn’t let me parse the bytes directly, so I get errors in both cases. Do I need to write a low-level decoder from bytes to tensors?

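import ray
import tensorflow as tf

# Feature spec for each serialized tf.train.Example
features = {
    'foo': tf.io.FixedLenFeature([], tf.string),
    'bar': tf.io.FixedLenFeature([], tf.float32),
}

def parse_single(serialized_input):
    return tf.io.parse_single_example(serialized_input, features=features)

def parse_sequence(serialized_input):
    ctx, seq = tf.io.parse_single_sequence_example(serialized_input, context_features=features)
    print(f'>>> {ctx}, {seq}\n')
    return serialized_input

bin_dataset = ray.data.read_binary_files(...)
dataset = bin_dataset.map(parse_single)
# dataset = bin_dataset.map(parse_sequence)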

@bveeramani @jianxiao Can you provide some pointers?

@vladislav Does this sort of approach work for you?

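from typing import Dict

import pandas as pd
import ray
import tensorflow as tf
from ray.data.block import Block
from ray.data.datasource import FileBasedDatasource


class TFRecordsDatasource(FileBasedDatasource):

    _FILE_EXTENSION = "tfrecords"

    def _read_file(
        self, f: "pyarrow.NativeFile", path: str, features: Dict[str, tf.io.FixedLenFeature], **reader_args
    ) -> Block:
        # `f` (the opened file) is unused; TensorFlow reopens `path` and handles the TFRecord framing.
        dataset = tf.data.TFRecordDataset([path])
        dataset = dataset.map(lambda serialized: tf.io.parse_single_example(serialized, features))

        foo = [record["foo"].numpy().decode("utf-8") for record in dataset]
        bar = [float(record["bar"]) for record in dataset]
        return pd.DataFrame({"foo": foo, "bar": bar})


# `features` is the same feature spec as above; reader kwargs are forwarded to _read_file.
dataset = ray.data.read_datasource(TFRecordsDatasource(), paths=..., features=features)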

Here’s the full reproduction.
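One detail in the datasource above: foo and bar are built by two separate list comprehensions over the same tf.data pipeline, so each file is read and parsed twice. A single pass avoids that. The helper below is a hypothetical sketch (records_to_columns is an illustrative name), assuming records arrive as dicts of NumPy/Python values the way tf.data.Dataset.as_numpy_iterator() yields them, with foo as UTF-8 bytes and bar as a number.

```python
from typing import Dict, Iterable, List, Mapping


def records_to_columns(records: Iterable[Mapping[str, object]]) -> Dict[str, List]:
    """Build the 'foo' and 'bar' columns in one pass over the parsed records."""
    foo: List[str] = []
    bar: List[float] = []
    for record in records:
        value = record["foo"]
        # tf.string features come back as bytes from as_numpy_iterator(); decode them to str.
        foo.append(value.decode("utf-8") if isinstance(value, bytes) else str(value))
        bar.append(float(record["bar"]))
    return {"foo": foo, "bar": bar}
```

Inside _read_file, the last three lines could then become return pd.DataFrame(records_to_columns(dataset.as_numpy_iterator())). Because the helper accepts any iterable, it also works on a one-shot generator, which is exactly the single-pass case.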


@bveeramani Thank you for your help!

With your approach, the data flow becomes: TFRecords → pandas.DataFrame → Ray Dataset. And, in the train_loop function, do I need to convert it to a TensorFlow object with ray.train.tensorflow.prepare_dataset_shard(dataset.to_tf(..)) again before passing it to TensorFlow?

I also tried running a test job on a smaller dataset (~5 GB) with a simple train_loop that contains nothing but dataset = session.get_dataset_shard("train") and print(type(dataset)). This job fails with the following errors 50% of the time.

⚠️ The blocks of this dataset are estimated to be 1.8x larger than the target block size of 512 MiB. This may lead to out-of-memory errors during processing. Consider reducing the size of input files or using `.repartition(n)` to increase the number of dataset blocks.

  File ".../lib/python3.7/site-packages/ray/cloudpickle/", line 73, in dumps
  File ".../lib/python3.7/site-packages/ray/cloudpickle/", line 620, in dump
    return Pickler.dump(self, obj)

And, in the train_loop function, do I need to convert it to a TensorFlow object with ray.train.tensorflow.prepare_dataset_shard(dataset.to_tf(..)) again before passing it to TensorFlow?

Yeah, that sounds right. get_dataset_shard returns a Ray Dataset, and you need to convert that dataset to a format that TensorFlow understands.

This job fails with the following errors 50% of the time.

What happens when you repartition the dataset? With the way I implemented TFRecordsDatasource, each block corresponds to one file, so if your files are large, the blocks will be large too.
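To pick a value for .repartition(n), a rough rule is enough: divide the dataset size by the 512 MiB target block size the warning mentions. The warning’s 1.8x estimate corresponds to blocks of roughly 1.8 × 512 MiB ≈ 922 MiB. The helper below is a hypothetical sketch of that arithmetic (suggest_num_blocks is an illustrative name, not a Ray API); it only bounds the average block size, since one-block-per-file reads can still be uneven.

```python
MiB = 2 ** 20


def suggest_num_blocks(total_bytes: int, target_block_bytes: int = 512 * MiB) -> int:
    """Smallest block count whose average block size is at most target_block_bytes."""
    # Integer ceiling division avoids float rounding on large byte counts.
    return max(1, (total_bytes + target_block_bytes - 1) // target_block_bytes)
```

For the ~5 GB test dataset this gives 10 blocks, e.g. ds = ds.repartition(suggest_num_blocks(5 * 10**9)). Since repartitioning happens after the read, each file is still loaded as one large block first, so splitting the input files into smaller ones (the other fix the warning suggests) may help more with the intermittent failures.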
