When ray returns two outputs it appears: cannot pickle ‘weakref’ object
epochs = 2
# Define training function
@ray.remote(num_returns=2)
def train_mnist_binary_classification():
# Load and preprocess MNIST dataset
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
#y_train = (y_train == 0).astype(int)
#y_test = (y_test == 0).astype(int)
x_train, y_train = filter_36(x_train, y_train)
x_test, y_test = filter_36(x_test, y_test)
n_samples = 1500
test_samples = 100
x_train, x_test, y_train, y_test = x_train[:n_samples], x_test[:test_samples], y_train[:n_samples], y_test[:test_samples]
y_train, y_test = one_hot(y_train), one_hot(y_test)
# Create model
model = create_model()
# Compile model
opt = keras.optimizers.SGD(learning_rate = 0.09)
model.compile(opt, loss='binary_crossentropy', metrics =['binary_accuracy'])
# Train the model
model.fit(x_train,
y_train,
epochs = epochs,
batch_size = 64,
shuffle = True,
validation_data = (x_test, y_test))
model.summary()
return model.history, model.layers[1].get_weights()[0]
# Execute the training function remotely
t1, t2 = train_mnist_binary_classification.remote()
# Execute the training function remotely
a = ray.get(t1)
b = ray.get(t2)
print(a)
print(b)
---------------------------------------------------------------------------
RayTaskError(TypeError) Traceback (most recent call last)
Cell In[3], line 2
1 # Execute the training function remotely
----> 2 a = ray.get(t1)
3 b = ray.get(t2)
4 print(a)
File /opt/conda/envs/tf/lib/python3.9/site-packages/ray/_private/auto_init_hook.py:22, in wrap_auto_init.<locals>.auto_init_wrapper(*args, **kwargs)
19 @wraps(fn)
20 def auto_init_wrapper(*args, **kwargs):
21 auto_init_ray()
---> 22 return fn(*args, **kwargs)
File /opt/conda/envs/tf/lib/python3.9/site-packages/ray/_private/client_mode_hook.py:103, in client_mode_hook.<locals>.wrapper(*args, **kwargs)
101 if func.__name__ != "init" or is_client_mode_enabled_by_default:
102 return getattr(ray, func.__name__)(*args, **kwargs)
--> 103 return func(*args, **kwargs)
File /opt/conda/envs/tf/lib/python3.9/site-packages/ray/_private/worker.py:2624, in get(object_refs, timeout)
2622 worker.core_worker.dump_object_store_memory_usage()
2623 if isinstance(value, RayTaskError):
-> 2624 raise value.as_instanceof_cause()
2625 else:
2626 raise value
RayTaskError(TypeError): ray::train_mnist_binary_classification() (pid=685065, ip=172.17.0.3)
File "/opt/conda/envs/tf/lib/python3.9/site-packages/ray/cloudpickle/cloudpickle_fast.py", line 88, in dumps
cp.dump(obj)
File "/opt/conda/envs/tf/lib/python3.9/site-packages/ray/cloudpickle/cloudpickle_fast.py", line 733, in dump
return Pickler.dump(self, obj)
TypeError: cannot pickle 'weakref' object