ValueError: Raw data must be data for a single arm for non batched trials

Hi,

I am trying to tune hyperparameters with Ray Tune and AxSearch, following the “AX Example — Ray 2.4.0” docs page, but I am encountering this error:

(train pid=13418) Validation loss has not improved for 0 epochs, stopping training.
2023-05-24 16:21:07,084	ERROR trial_runner.py:671 -- Trial train_a6d842ff: Error stopping trial.
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/ax/utils/common/typeutils.py", line 86, in checked_cast_complex
    check_type("val", val, typ)
  File "/usr/local/lib/python3.10/dist-packages/typeguard/__init__.py", line 757, in check_type
    checker_func(argname, value, expected_type, memo)
  File "/usr/local/lib/python3.10/dist-packages/typeguard/__init__.py", line 558, in check_union
    raise TypeError('type of {} must be one of ({}); got {} instead'.
TypeError: type of val must be one of (Dict[str, Union[float, numpy.floating, numpy.integer, Tuple[Union[float, numpy.floating, numpy.integer], Union[float, numpy.floating, numpy.integer, NoneType]]]], float, numpy.floating, numpy.integer, Tuple[Union[float, numpy.floating, numpy.integer], Union[float, numpy.floating, numpy.integer, NoneType]], List[Tuple[Dict[str, Union[str, bool, float, int, NoneType]], Dict[str, Union[float, numpy.floating, numpy.integer, Tuple[Union[float, numpy.floating, numpy.integer], Union[float, numpy.floating, numpy.integer, NoneType]]]]]], List[Tuple[Dict[str, Hashable], Dict[str, Union[float, numpy.floating, numpy.integer, Tuple[Union[float, numpy.floating, numpy.integer], Union[float, numpy.floating, numpy.integer, NoneType]]]]]]); got dict instead

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/ray/tune/execution/trial_runner.py", line 662, in stop_trial
    self._search_alg.on_trial_complete(
  File "/usr/local/lib/python3.10/dist-packages/ray/tune/search/search_generator.py", line 142, in on_trial_complete
    self.searcher.on_trial_complete(trial_id=trial_id, result=result, error=error)
  File "/usr/local/lib/python3.10/dist-packages/ray/tune/search/concurrency_limiter.py", line 138, in on_trial_complete
    self.searcher.on_trial_complete(trial_id, result=result, error=error)
  File "/usr/local/lib/python3.10/dist-packages/ray/tune/search/ax/ax_search.py", line 323, in on_trial_complete
    self._process_result(trial_id, result)
  File "/usr/local/lib/python3.10/dist-packages/ray/tune/search/ax/ax_search.py", line 343, in _process_result
    self._ax.complete_trial(trial_index=ax_trial_index, raw_data=metric_dict)
  File "/usr/local/lib/python3.10/dist-packages/ax/service/ax_client.py", line 759, in complete_trial
    data_update_repr = self._update_trial_with_raw_data(
  File "/usr/local/lib/python3.10/dist-packages/ax/service/ax_client.py", line 1556, in _update_trial_with_raw_data
    update_info = trial.update_trial_data(
  File "/usr/local/lib/python3.10/dist-packages/ax/core/trial.py", line 298, in update_trial_data
    arm_name: checked_cast_complex(
  File "/usr/local/lib/python3.10/dist-packages/ax/utils/common/typeutils.py", line 89, in checked_cast_complex
    raise ValueError(message or f"Value was not of type {typ}: {val}")
ValueError: Raw data must be data for a single arm for non batched trials.

What is the issue here? Any help is appreciated…
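For reference, the first TypeError lists what Ax accepts for each metric value in `raw_data`: a plain number or NumPy scalar, or a `(mean, sem)` tuple of those. Ray's AxSearch passes `metric_dict` as `raw_data`, so "got dict instead" suggests that at least one reported metric value has some other type. A likely candidate, and this is only an assumption, is a PyTorch tensor, a NumPy array, or a nested dict reported for the optimized metric. The sketch below uses a hypothetical helper, `to_ax_raw_data` (not part of Ray or Ax), that coerces each value to a plain `float` and builds the `{name: (mean, None)}` shape. It can serve as a quick check on what the trainable reports. In practice the equivalent fix would be to report `float(val_loss)` (or `val_loss.item()` for a tensor) to Tune.

```python
import math
from typing import Any, Dict, Optional, Tuple


def to_ax_raw_data(metrics: Dict[str, Any]) -> Dict[str, Tuple[float, Optional[float]]]:
    """Hypothetical helper: coerce reported metrics into {name: (mean, sem)}.

    Ax's type check accepts floats/NumPy scalars or (mean, sem) tuples of them;
    anything else (tensors, arrays, nested dicts) triggers the error above.
    """
    raw: Dict[str, Tuple[float, Optional[float]]] = {}
    for name, value in metrics.items():
        # Containers and strings are never valid scalar metric values.
        if isinstance(value, (dict, list, tuple, str)):
            raise TypeError(f"metric {name!r} must be a scalar, got {type(value).__name__}")
        # float() works for int, float, NumPy scalars and anything defining __float__
        # (e.g. a 0-d tensor); it raises for multi-element arrays.
        mean = float(value)
        if not math.isfinite(mean):
            raise ValueError(f"metric {name!r} is not finite: {mean}")
        raw[name] = (mean, None)  # SEM unknown, so pass None
    return raw


if __name__ == "__main__":
    print(to_ax_raw_data({"val_loss": 0.25, "epoch": 3}))
    # {'val_loss': (0.25, None), 'epoch': (3.0, None)}
    try:
        to_ax_raw_data({"val_loss": {"mean": 0.25}})
    except TypeError as err:
        print(err)  # metric 'val_loss' must be a scalar, got dict
```

Running the helper on the exact dict passed to the Tune report call shows which key, if any, carries a non-scalar value.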