[Dataset] function add_column inserts repeats of sub-column instead of whole column

How severely does this issue affect your experience of using Ray?
Low-Medium: It is annoying, but I can work around it.

Hi,
I am using Ray Datasets for last-mile preprocessing of my data, as discussed in the documentation.

While trying to add an ID column to my tensor dataset (I need this column for further steps), I noticed that the ray.data.Dataset.add_column function does not behave as it should. Instead of mapping the entire pandas DataFrame or Series given as a parameter, it adds only a portion of the column over and over. I am guessing that this portion is something like the length of the whole column divided by the parallelism, i.e. the batch length, and that this sub-column is then copied repeatedly (probably once per batch).

I have been able to work around this by repartitioning the dataset to 1 partition before adding the column.

Here is a code/output example using a small dataset:

print(df)
Dataset(num_blocks=29, num_rows=29, schema={__value__: TensorDtype(shape=(512,), dtype=float64), id: object, test: object})

print(ids)
                   id
0      PMSQ01000054.1
1   CAJCBY010000033.1
2   JABXKY010000147.1
3      DNMQ01000225.1
4   JAAYQI010000217.1
5      PBSX01000072.1
6   JABBOX010000109.1
7   JAAZAC010000025.1
8   CAIXRL010000197.1
9      DHMB01000127.1
10  JAAYXU010000041.1
11  JACMKV010000045.1
12  NZ_LCZE01000023.1
13     URSE01000035.1
14  CACNVV010000042.1
15     PMOP01000016.1
16     UQEY01000009.1
17     VBOR01000009.1
18     WLHF01000026.1
19  JACNFQ010000081.1
20     DHUT01000069.1
21     WBXD01000017.1
22  JAABRC010000419.1
23     URUG01000300.1
24     QMMC01000579.1
25  CAAFRK010000216.1
26  NZ_LRTR01000260.1
27     DKBA01000026.1
28  CAJCHR010000269.1

# Adding column with n_partitions == n_rows

print(df.add_column('id', lambda ds : ids).to_pandas())

Read->Map_Batches: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 499.50it/s]
                                            __value__              id
0   [60.0, 68.0, 67.0, 58.0, 46.0, 74.0, 76.0, 64....  PMSQ01000054.1
1   [16.0, 0.0, 10.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0....  PMSQ01000054.1
2   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  PMSQ01000054.1
3   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  PMSQ01000054.1
4   [108.0, 61.0, 56.0, 66.0, 36.0, 29.0, 19.0, 26...  PMSQ01000054.1
5   [78.0, 48.0, 45.0, 76.0, 44.0, 23.0, 28.0, 25....  PMSQ01000054.1
6   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  PMSQ01000054.1
7   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  PMSQ01000054.1
8   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  PMSQ01000054.1
9   [29.0, 14.0, 14.0, 19.0, 0.0, 0.0, 0.0, 0.0, 1...  PMSQ01000054.1
10  [0.0, 12.0, 12.0, 0.0, 0.0, 0.0, 14.0, 0.0, 0....  PMSQ01000054.1
11  [15.0, 14.0, 11.0, 15.0, 11.0, 17.0, 14.0, 0.0...  PMSQ01000054.1
12  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  PMSQ01000054.1
13  [55.0, 31.0, 46.0, 39.0, 17.0, 21.0, 23.0, 21....  PMSQ01000054.1
14  [42.0, 15.0, 0.0, 30.0, 18.0, 0.0, 0.0, 11.0, ...  PMSQ01000054.1
15  [10.0, 0.0, 13.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11....  PMSQ01000054.1
16  [387.0, 210.0, 251.0, 311.0, 219.0, 94.0, 152....  PMSQ01000054.1
17  [0.0, 0.0, 0.0, 0.0, 0.0, 11.0, 0.0, 0.0, 0.0,...  PMSQ01000054.1
18  [41.0, 38.0, 44.0, 51.0, 28.0, 37.0, 23.0, 29....  PMSQ01000054.1
19  [13.0, 10.0, 27.0, 10.0, 13.0, 31.0, 20.0, 16....  PMSQ01000054.1
20  [69.0, 29.0, 45.0, 47.0, 23.0, 13.0, 0.0, 10.0...  PMSQ01000054.1
21  [82.0, 62.0, 37.0, 67.0, 33.0, 74.0, 42.0, 47....  PMSQ01000054.1
22  [0.0, 0.0, 0.0, 0.0, 11.0, 0.0, 0.0, 0.0, 0.0,...  PMSQ01000054.1
23  [29.0, 14.0, 21.0, 23.0, 15.0, 0.0, 0.0, 0.0, ...  PMSQ01000054.1
24  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 10.0,...  PMSQ01000054.1
25  [13.0, 13.0, 0.0, 20.0, 10.0, 0.0, 0.0, 0.0, 1...  PMSQ01000054.1
26  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  PMSQ01000054.1
27  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  PMSQ01000054.1
28  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.0, 0.0, 0.0,...  PMSQ01000054.1

# Repartition to 3 partitions before adding column

print(df.repartition(3).add_column('id', lambda ds : ids).to_pandas())

Read: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 2268.42it/s]
Repartition: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 483.47it/s]
Map_Batches: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 264.70it/s]
                                            __value__                 id
0   [60.0, 68.0, 67.0, 58.0, 46.0, 74.0, 76.0, 64....     PMSQ01000054.1
1   [16.0, 0.0, 10.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0....  CAJCBY010000033.1
2   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  JABXKY010000147.1
3   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...     DNMQ01000225.1
4   [108.0, 61.0, 56.0, 66.0, 36.0, 29.0, 19.0, 26...  JAAYQI010000217.1
5   [78.0, 48.0, 45.0, 76.0, 44.0, 23.0, 28.0, 25....     PBSX01000072.1
6   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  JABBOX010000109.1
7   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  JAAZAC010000025.1
8   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  CAIXRL010000197.1
9   [29.0, 14.0, 14.0, 19.0, 0.0, 0.0, 0.0, 0.0, 1...     PMSQ01000054.1
10  [0.0, 12.0, 12.0, 0.0, 0.0, 0.0, 14.0, 0.0, 0....  CAJCBY010000033.1
11  [15.0, 14.0, 11.0, 15.0, 11.0, 17.0, 14.0, 0.0...  JABXKY010000147.1
12  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...     DNMQ01000225.1
13  [55.0, 31.0, 46.0, 39.0, 17.0, 21.0, 23.0, 21....  JAAYQI010000217.1
14  [42.0, 15.0, 0.0, 30.0, 18.0, 0.0, 0.0, 11.0, ...     PBSX01000072.1
15  [10.0, 0.0, 13.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11....  JABBOX010000109.1
16  [387.0, 210.0, 251.0, 311.0, 219.0, 94.0, 152....  JAAZAC010000025.1
17  [0.0, 0.0, 0.0, 0.0, 0.0, 11.0, 0.0, 0.0, 0.0,...  CAIXRL010000197.1
18  [41.0, 38.0, 44.0, 51.0, 28.0, 37.0, 23.0, 29....     DHMB01000127.1
19  [13.0, 10.0, 27.0, 10.0, 13.0, 31.0, 20.0, 16....     PMSQ01000054.1
20  [69.0, 29.0, 45.0, 47.0, 23.0, 13.0, 0.0, 10.0...  CAJCBY010000033.1
21  [82.0, 62.0, 37.0, 67.0, 33.0, 74.0, 42.0, 47....  JABXKY010000147.1
22  [0.0, 0.0, 0.0, 0.0, 11.0, 0.0, 0.0, 0.0, 0.0,...     DNMQ01000225.1
23  [29.0, 14.0, 21.0, 23.0, 15.0, 0.0, 0.0, 0.0, ...  JAAYQI010000217.1
24  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 10.0,...     PBSX01000072.1
25  [13.0, 13.0, 0.0, 20.0, 10.0, 0.0, 0.0, 0.0, 1...  JABBOX010000109.1
26  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  JAAZAC010000025.1
27  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  CAIXRL010000197.1
28  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.0, 0.0, 0.0,...     DHMB01000127.1

# Repartition to 1 partition before adding column

print(df.repartition(1).add_column('id', lambda ds : ids).to_pandas())

Read: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 2122.81it/s]
Repartition: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 79.16it/s]
Map_Batches: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 93.25it/s]
                                            __value__                 id
0   [60.0, 68.0, 67.0, 58.0, 46.0, 74.0, 76.0, 64....     PMSQ01000054.1
1   [16.0, 0.0, 10.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0....  CAJCBY010000033.1
2   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  JABXKY010000147.1
3   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...     DNMQ01000225.1
4   [108.0, 61.0, 56.0, 66.0, 36.0, 29.0, 19.0, 26...  JAAYQI010000217.1
5   [78.0, 48.0, 45.0, 76.0, 44.0, 23.0, 28.0, 25....     PBSX01000072.1
6   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  JABBOX010000109.1
7   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  JAAZAC010000025.1
8   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  CAIXRL010000197.1
9   [29.0, 14.0, 14.0, 19.0, 0.0, 0.0, 0.0, 0.0, 1...     DHMB01000127.1
10  [0.0, 12.0, 12.0, 0.0, 0.0, 0.0, 14.0, 0.0, 0....  JAAYXU010000041.1
11  [15.0, 14.0, 11.0, 15.0, 11.0, 17.0, 14.0, 0.0...  JACMKV010000045.1
12  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  NZ_LCZE01000023.1
13  [55.0, 31.0, 46.0, 39.0, 17.0, 21.0, 23.0, 21....     URSE01000035.1
14  [42.0, 15.0, 0.0, 30.0, 18.0, 0.0, 0.0, 11.0, ...  CACNVV010000042.1
15  [10.0, 0.0, 13.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11....     PMOP01000016.1
16  [387.0, 210.0, 251.0, 311.0, 219.0, 94.0, 152....     UQEY01000009.1
17  [0.0, 0.0, 0.0, 0.0, 0.0, 11.0, 0.0, 0.0, 0.0,...     VBOR01000009.1
18  [41.0, 38.0, 44.0, 51.0, 28.0, 37.0, 23.0, 29....     WLHF01000026.1
19  [13.0, 10.0, 27.0, 10.0, 13.0, 31.0, 20.0, 16....  JACNFQ010000081.1
20  [69.0, 29.0, 45.0, 47.0, 23.0, 13.0, 0.0, 10.0...     DHUT01000069.1
21  [82.0, 62.0, 37.0, 67.0, 33.0, 74.0, 42.0, 47....     WBXD01000017.1
22  [0.0, 0.0, 0.0, 0.0, 11.0, 0.0, 0.0, 0.0, 0.0,...  JAABRC010000419.1
23  [29.0, 14.0, 21.0, 23.0, 15.0, 0.0, 0.0, 0.0, ...     URUG01000300.1
24  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 10.0,...     QMMC01000579.1
25  [13.0, 13.0, 0.0, 20.0, 10.0, 0.0, 0.0, 0.0, 1...  CAAFRK010000216.1
26  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  NZ_LRTR01000260.1
27  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...     DKBA01000026.1
28  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.0, 0.0, 0.0,...  CAJCHR010000269.1

Should I report this on Ray's GitHub?

Hi @nicdemon, the behavior of add_column looks like it is working as expected. Note that the UDF passed to .add_column is applied to each batch, so your lambda ds : ids will apply the entire ids to each batch, which may not be what you want.
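
For illustration, here is a minimal, self-contained sketch (not from the original post; the dataset and column names are made up for the example) of how the .add_column UDF is meant to be used: it receives each batch as a pandas DataFrame and should return the new column's values for that batch only.

import ray
import pandas as pd

# Toy dataset with a single "value" column (illustrative only).
ds = ray.data.from_pandas(pd.DataFrame({"value": list(range(6))}))

# The UDF is called once per batch; `batch` is a pandas DataFrame and the
# returned Series supplies the new column's values for that batch only.
ds = ds.add_column("value_plus_one", lambda batch: batch["value"] + 1)
print(ds.to_pandas())

Because the lambda in the original snippet ignores its batch argument and always returns the full ids, every block ends up with IDs taken from the start of ids, which matches the repeated values in the outputs above.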

In your case, you may create another dataset for ids and then zip the two together:

import ray

ds1 = ray.data.range(29)

ds2 = ray.data.from_items([
"   PMSQ01000054.1 ",
" CAJCBY010000033.1",
" JABXKY010000147.1",
"    DNMQ01000225.1",
" JAAYQI010000217.1",
"    PBSX01000072.1",
" JABBOX010000109.1",
" JAAZAC010000025.1",
" CAIXRL010000197.1",
"    DHMB01000127.1",
" JAAYXU010000041.1",
" JACMKV010000045.1",
" NZ_LCZE01000023.1",
"    URSE01000035.1",
" CACNVV010000042.1",
"    PMOP01000016.1",
"    UQEY01000009.1",
"    VBOR01000009.1",
"    WLHF01000026.1",
" JACNFQ010000081.1",
"    DHUT01000069.1",
"    WBXD01000017.1",
" JAABRC010000419.1",
"    URUG01000300.1",
"    QMMC01000579.1",
" CAAFRK010000216.1",
" NZ_LRTR01000260.1",
"    DKBA01000026.1",
" CAJCHR010000269.1",
])

ds3 = ds1.zip(ds2)
ds3.show()

The resulting ds3 will have:

(10, '   PMSQ01000054.1 ')
(11, ' CAJCBY010000033.1')
(12, ' JABXKY010000147.1')
(13, '    DNMQ01000225.1')
(14, ' JAAYQI010000217.1')
(15, '    PBSX01000072.1')
(16, ' JABBOX010000109.1')
(17, ' JAAZAC010000025.1')
(18, ' CAIXRL010000197.1')
(19, '    DHMB01000127.1')
(20, ' JAAYXU010000041.1')
(21, ' JACMKV010000045.1')
(22, ' NZ_LCZE01000023.1')
(23, '    URSE01000035.1')
(24, ' CACNVV010000042.1')
(25, '    PMOP01000016.1')
(26, '    UQEY01000009.1')
(27, '    VBOR01000009.1')
(28, '    WLHF01000026.1')
(29, ' JACNFQ010000081.1')

Great! Thanks for your help @jianxiao. I guess I hadn't understood the .add_column function well, even though I have read its API documentation a few times…