TFRecordDataset -> ray.data.Dataset for TensorflowTrainer

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

I would like to perform training on an existing TF2 pipeline using Ray Train 3.0.0-dev0. The original pipeline uses TFRecords; the model class creates the dataset by calling methods on a TFRecordDataset object.

dataset = tf.data.TFRecordDataset(files)
dataset = dataset.map(self.parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(self.batch_size)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

When I pass this object into TensorflowTrainer(datasets={"train": dataset}), I get this error:

ValueError: At least one value in the `datasets` dict is not a `ray.data.Dataset`: {'train': <PrefetchDataset shapes: ((None, 96, 96, 3), (None,), (None,)), types: (tf.float32, tf.float32, tf.string)>, 'test': <PrefetchDataset shapes: ((None, 96, 96, 3), (None,), (None,)), types: (tf.float32, tf.float32, tf.string)>}

Do I need to convert the TFRecordDataset into something else, or is there an option to read and pass TFRecords efficiently to the workers with ray.data.read_binary_files / read_datasource?

Thanks!

Hi @vladislav, TensorflowTrainer(datasets={…}) must be fed Ray Datasets. Ray Datasets currently has no support for reading the TFRecord format directly (reading the raw bytes with ray.data.read_binary_files, as you mentioned, won't decode the records on its own), so you will need to convert your data into a format Ray Datasets can read (see the supported formats in Ray Datasets: Distributed Data Loading and Compute — Ray 3.0.0.dev0), e.g. NumPy.
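
For instance, a minimal sketch of a one-time offline conversion, assuming the parsed/batched pipeline from the original post and a hypothetical output directory:

import numpy as np
import ray

# Iterate the existing TF pipeline once and dump each batch to .npy shards.
# Only the image tensor is saved here; labels etc. would need their own
# shards or a packed layout.
for i, (images, labels, names) in enumerate(dataset.as_numpy_iterator()):
    np.save(f"/data/np_shards/images_{i}.npy", images)

# Read the converted shards back as a Ray Dataset (this read is parallel).
ray_ds = ray.data.read_numpy("/data/np_shards/")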

Hi @vladislav,
Thanks for posting this question. As @jianxiao said, TFRecord is not yet natively supported by Ray Datasets.
There are a few options here, though:

  1. Use our SimpleTensorFlowDatasource datasource to convert a tf.data.TFRecordDataset instance into a Ray Dataset. Note that this delegates the I/O to the TFRecordDataset, so the read is neither parallel nor distributed and the data will not be spread across the Ray cluster; this is best for small data.

  2. Use our ray.data.read_binary_files() API to read the TFRecord files (in parallel) as raw binary data, then use .map() to decode the records. We have an example of doing this with PNGs in our Creating Datasets user guide.

  3. Use a TFRecordDatasource datasource that does both the parallel reading and the decoding. Unfortunately, we don't currently have a built-in datasource for TFRecords; however, it should be fairly easy to create one with our custom datasource API. In particular, a few users have subclassed our FileBasedDatasource API, which only requires implementing the _read_file API (or the _read_stream API if you want streaming reads of the TFRecord files). Our JSONDatasource is one subclassing example; a rough sketch of such a subclass follows this list.
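
For illustration, here's a rough, untested sketch of what such a subclass might look like. It assumes the _read_file(self, f, path, **reader_args) hook of FileBasedDatasource (check the JSONDatasource source for the exact signature on your Ray version) and standard uncompressed TFRecord framing; the class name is otherwise hypothetical.

import struct

from ray.data.datasource import FileBasedDatasource

class TFRecordDatasource(FileBasedDatasource):
    # Hypothetical sketch: parallel read of TFRecord files, one block per file.

    def _read_file(self, f, path, **reader_args):
        # f is an open file handle for one TFRecord file.
        data = f.read()
        # TFRecord framing (uncompressed): 8-byte little-endian length,
        # 4-byte masked CRC of the length, payload, 4-byte payload CRC.
        # CRCs are skipped here rather than verified.
        records = []
        offset = 0
        while offset < len(data):
            (length,) = struct.unpack_from("<Q", data, offset)
            offset += 12  # length field + its CRC
            records.append(data[offset:offset + length])
            offset += length + 4  # payload + its CRC
        # Returns a simple block of serialized tf.train.Example bytes;
        # decoding into tensors/columns could happen here or in a later .map().
        return records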

I'd probably recommend (2) in the short term to try out Ray Datasets' parallel I/O and downstream operations, and look at adding a TFRecordDatasource down the road. We'd love to see the latter happen and could definitely help shepherd the implementation!

@Clark_Zinzow FYI, I found your great answer somewhere and we have another user asking for this :slight_smile:

Thanks @jianxiao @xwjiang2010 for your answers and ideas!

We can't use #1 from the list of options; our dataset is large (>100 GB).

I'm trying to better understand how to implement #2. For example, I've created two parse functions to map over the bytes data. But TF doesn't let me parse those bytes directly, so I see errors in both cases. It feels like I need to write a low-level decoder from bytes to Tensors?

features = {
    'foo': tf.io.FixedLenFeature([], tf.string),
    'bar': tf.io.FixedLenFeature([], tf.float32)
}

def parse_single(serialized_input):
    return tf.io.parse_single_example(serialized=serialized_input, features=features)

def parse_sequence(serialized_input):
    ctx, seq = tf.io.parse_sequence_example(serialized=serialized_input, context_features=features)
    print(f'>>> {ctx}, {seq}\n')
    return serialized_input

bin_dataset = ray.data.read_binary_files(filenames)
dataset = bin_dataset.map(parse_single)
# dataset = bin_dataset.map(parse_sequence)

@bveeramani @jianxiao Can you provide some pointers?
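
For reference, a rough sketch of the kind of low-level decoder described above: each row from ray.data.read_binary_files() is the raw bytes of an entire TFRecord file, not a single serialized Example, so the file first has to be split into records and each record parsed individually. This assumes standard uncompressed TFRecord framing and eager TF execution; row formats can differ across Ray versions, so adjust accordingly.

import struct

import ray
import tensorflow as tf

features = {
    'foo': tf.io.FixedLenFeature([], tf.string),
    'bar': tf.io.FixedLenFeature([], tf.float32)
}

def split_tfrecord_file(file_bytes):
    # TFRecord framing (uncompressed): 8-byte little-endian length,
    # 4-byte masked CRC of the length, payload, 4-byte payload CRC.
    # CRCs are skipped here rather than verified.
    records = []
    offset = 0
    while offset < len(file_bytes):
        (length,) = struct.unpack_from('<Q', file_bytes, offset)
        offset += 12  # length field + its CRC
        records.append(file_bytes[offset:offset + length])
        offset += length + 4  # payload + its CRC
    return records

def parse_record(serialized):
    # Parses one serialized tf.train.Example eagerly into numpy values.
    example = tf.io.parse_single_example(serialized=serialized, features=features)
    return {name: tensor.numpy() for name, tensor in example.items()}

bin_dataset = ray.data.read_binary_files(filenames)
# flat_map: one input file expands into many decoded records.
dataset = bin_dataset.flat_map(
    lambda file_bytes: [parse_record(rec) for rec in split_tfrecord_file(file_bytes)]
)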