Tune.choice for list of lists

As of Ray v1.3, the following fails with the error below. Can anyone suggest a workaround? Thank you.

In [1]: from ray import tune
In [2]: c = tune.choice([[1,2],[3,4]])
In [3]: c.sample()
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-3-b421c6722185> in <module>
----> 1 c.sample()

~/miniconda3/envs/fctk/lib/python3.7/site-packages/ray/tune/sample.py in sample(self, spec, size)                                                                             
     45     def sample(self, spec=None, size=1):
     46         sampler = self.get_sampler()
---> 47         return sampler.sample(self, spec=spec, size=size)
     48 
     49     def is_grid(self):

~/miniconda3/envs/fctk/lib/python3.7/site-packages/ray/tune/sample.py in sample(self, domain, spec, size)                                                                     
    295                    size: int = 1):
    296 
--> 297             items = np.random.choice(domain.categories, size=size).tolist()
    298             return items if len(items) > 1 else domain.cast(items[0])
    299 

mtrand.pyx in numpy.random.mtrand.RandomState.choice()

ValueError: a must be 1-dimensional
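The error comes from NumPy rather than Tune: np.random.choice converts the categories to an array, [[1,2],[3,4]] becomes a 2-D array, and it rejects anything that is not 1-D. A minimal sketch of the usual way around this, assuming NumPy >= 1.17 for default_rng, is to sample integer indices (which are always 1-D) and use them to index the original list. sample_categorical is a hypothetical helper, not part of Tune.

```python
import numpy as np

categories = [[1, 2], [3, 4]]

# np.random.choice turns the nested list into a 2-D array and refuses it.
try:
    np.random.choice(categories)
except ValueError as err:
    print("np.random.choice failed:", err)


def sample_categorical(categories, size=1, rng=None):
    """Hypothetical helper: pick whole elements of `categories` by index."""
    rng = rng if rng is not None else np.random.default_rng()
    # The indices form a 1-D integer array, so nested categories are fine.
    idx = rng.integers(len(categories), size=size)
    items = [categories[i] for i in idx]
    # Mirrors the return convention in the traceback: a list only when size > 1.
    return items if size > 1 else items[0]


print(sample_categorical(categories))          # either [1, 2] or [3, 4]
print(sample_categorical(categories, size=3))  # a list of three picks
```

Because only the indices pass through NumPy, the categories themselves can be any Python objects.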

One workaround is to sample a key instead and map it back to the list yourself:

mapping = {"a": [1, 2], "b": [3, 4]}
tune.choice(["a", "b"])
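With this approach Tune only ever samples plain strings, and the trainable looks up the list it needs. A minimal sketch in plain Python; MAPPING, the "pair" config key, and resolve_pair are hypothetical names, and the commented search space assumes the tune.choice call from the snippet above.

```python
# Hypothetical names: MAPPING, "pair", and resolve_pair are illustrative only.
# With Ray, the search space would be {"pair": tune.choice(list(MAPPING))}.
MAPPING = {"a": [1, 2], "b": [3, 4]}


def resolve_pair(config):
    """Map the sampled key back to the list it stands for."""
    return MAPPING[config["pair"]]


def trainable(config):
    pair = resolve_pair(config)
    # ... real training code would use `pair` here ...
    return {"sum": sum(pair)}


print(trainable({"pair": "b"}))  # {'sum': 7}
```

Sampling integer positions with tune.choice([0, 1]) and indexing a plain list works the same way; string keys just make the chosen option readable in Tune's results.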

@Stuart_Siegel We should try to fix this, though! Would you be willing to open an issue on GitHub for me?

Thanks, @rliaw.

I was just about to post that I had found a workaround: define my own sampler that samples with Python's random module, as ray<v1.3 did, instead of np.random.choice.

For reference, in case others are interested:

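import random

from ray import tune
from ray.tune.sample import Uniform  # Ray 1.x module path


class LegacyCategoricalSampler(Uniform):
    # random.choices picks whole elements, so nested lists work (unlike np.random.choice)
    def sample(self, domain: "Categorical", spec=None, size: int = 1):
        items = random.choices(domain.categories, k=size)
        return items if len(items) > 1 else domain.cast(items[0])


c = tune.choice([[1, 2], [3, 4]])
c.set_sampler(LegacyCategoricalSampler(), allow_override=True)
c.sample()  # works as before (returns either [1, 2] or [3, 4])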
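The core of that sampler can be exercised without Ray. This sketch drops Tune's Domain object and its cast step, so legacy_categorical_sample is a hypothetical stand-in rather than a Tune API.

```python
import random


def legacy_categorical_sample(categories, size=1):
    # random.choices always returns a list of `size` whole elements,
    # so each category may itself be a list.
    items = random.choices(categories, k=size)
    # A list only when size > 1, otherwise the single chosen element.
    return items if len(items) > 1 else items[0]


print(legacy_categorical_sample([[1, 2], [3, 4]]))          # [1, 2] or [3, 4]
print(legacy_categorical_sample([[1, 2], [3, 4]], size=3))  # three picks
print(legacy_categorical_sample([1, 2, 3]))                 # scalars still work
```

The len(items) check relies on random.choices always producing a list. With a bare random.choice(...) result, len() would apply to the chosen category itself, which happens to work for [1, 2] but fails for scalar categories.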
