When ray returns two outputs it appears: cannot pickle 'weakref' object

When ray returns two outputs it appears: cannot pickle ‘weakref’ object

epochs = 2
# Define training function
@ray.remote(num_returns=2)
def train_mnist_binary_classification():
    # Load and preprocess MNIST dataset
    (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    #y_train = (y_train == 0).astype(int)
    #y_test = (y_test == 0).astype(int)
    x_train, y_train = filter_36(x_train, y_train)
    x_test, y_test = filter_36(x_test, y_test)
    n_samples = 1500
    test_samples = 100
    x_train, x_test, y_train, y_test = x_train[:n_samples], x_test[:test_samples], y_train[:n_samples], y_test[:test_samples]
    y_train, y_test = one_hot(y_train), one_hot(y_test)

    # Create model
    model = create_model()

    # Compile model
    opt = keras.optimizers.SGD(learning_rate = 0.09)
    model.compile(opt, loss='binary_crossentropy', metrics =['binary_accuracy'])

    # Train the model
    model.fit(x_train, 
                   y_train,
                   epochs = epochs,
                   batch_size = 64,
                   shuffle = True, 
                   validation_data = (x_test, y_test))
    model.summary()
    return model.history, model.layers[1].get_weights()[0]

# Execute the training function remotely
t1, t2 = train_mnist_binary_classification.remote()
# Execute the training function remotely
a = ray.get(t1)
b = ray.get(t2)
print(a)
print(b)
---------------------------------------------------------------------------
RayTaskError(TypeError)                   Traceback (most recent call last)
Cell In[3], line 2
      1 # Execute the training function remotely
----> 2 a = ray.get(t1)
      3 b = ray.get(t2)
      4 print(a)

File /opt/conda/envs/tf/lib/python3.9/site-packages/ray/_private/auto_init_hook.py:22, in wrap_auto_init.<locals>.auto_init_wrapper(*args, **kwargs)
     19 @wraps(fn)
     20 def auto_init_wrapper(*args, **kwargs):
     21     auto_init_ray()
---> 22     return fn(*args, **kwargs)

File /opt/conda/envs/tf/lib/python3.9/site-packages/ray/_private/client_mode_hook.py:103, in client_mode_hook.<locals>.wrapper(*args, **kwargs)
    101     if func.__name__ != "init" or is_client_mode_enabled_by_default:
    102         return getattr(ray, func.__name__)(*args, **kwargs)
--> 103 return func(*args, **kwargs)

File /opt/conda/envs/tf/lib/python3.9/site-packages/ray/_private/worker.py:2624, in get(object_refs, timeout)
   2622     worker.core_worker.dump_object_store_memory_usage()
   2623 if isinstance(value, RayTaskError):
-> 2624     raise value.as_instanceof_cause()
   2625 else:
   2626     raise value

RayTaskError(TypeError): ray::train_mnist_binary_classification() (pid=685065, ip=172.17.0.3)
  File "/opt/conda/envs/tf/lib/python3.9/site-packages/ray/cloudpickle/cloudpickle_fast.py", line 88, in dumps
    cp.dump(obj)
  File "/opt/conda/envs/tf/lib/python3.9/site-packages/ray/cloudpickle/cloudpickle_fast.py", line 733, in dump
    return Pickler.dump(self, obj)
TypeError: cannot pickle 'weakref' object