Hi @arturn,

No worries about the delay! I'm really thankful you're helping me with this (and also full of joy because it is working now)!
I have a question about your last line of code: shouldn't that be

`config['model']['custom_model_config']['theta_dummy_handle'].set_theta.remote(new_theta)`

? So that when `build_vtrace_loss` changes its local copy of the `policy.config` dictionary, I still hold a reference to the original `RemoteThetaDummy` class (actor) and update its value without making any copy.
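The copy concern can actually be checked without Ray at all: a shallow copy of the config duplicates only the outer dict, not the objects stored inside it, so both copies hold the very same handle object. Here is a minimal sketch, with a hypothetical `DummyHandle` class standing in for the actor handle:

```python
# Hypothetical stand-in for a Ray actor handle; the point here is plain
# Python object identity, not Ray itself.
class DummyHandle:
    def __init__(self):
        self.value = 0

    def set_value(self, v):
        self.value = v

config = {"model": {"custom_model_config": {"theta_params": DummyHandle()}}}

# A shallow copy duplicates the outer dict only...
local_config = dict(config)

# ...so both copies still refer to the very same handle object:
handle_a = config["model"]["custom_model_config"]["theta_params"]
handle_b = local_config["model"]["custom_model_config"]["theta_params"]
print(handle_a is handle_b)  # True

# Updating through the copy is therefore visible through the original:
handle_b.set_value(42)
print(handle_a.value)  # 42
```

So even if a copy of the config dict is made somewhere along the way, method calls through either copy reach the same underlying object.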
It is working for me now, so thank you very much for your valuable responses @arturn!
Just for the record, here is my particular implementation, which may be useful to other people with similar needs:

In `train.py`, outside any other function or class:
```python
import ray
import torch

@ray.remote
class RemoteThetaDummy:
    def __init__(self):
        self.thetas = torch.zeros(2)  # A 2-d tensor in my case

    def set_thetas(self, new_thetas):
        print("Setting new thetas...")
        print("Old thetas:", self.thetas)  # Check whether the previous update worked or they are still zeros
        self.thetas = new_thetas
        print("New thetas updated to:", self.thetas)

    def get_thetas(self):
        return self.thetas
```
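For anyone who wants to try the interface without Ray or torch installed: the `@ray.remote` decorator only changes *how* the single instance is reached (via `.remote()` calls from any worker), not what it stores. A local, Ray-free sketch of the same class, with a plain Python list standing in for the 2-d torch tensor, behaves like this:

```python
# Local, Ray-free sketch of RemoteThetaDummy: a plain list replaces the
# torch tensor, and methods are called directly instead of via .remote().
class ThetaDummy:
    def __init__(self):
        self.thetas = [0.0, 0.0]  # 2-d parameter vector

    def set_thetas(self, new_thetas):
        self.thetas = new_thetas

    def get_thetas(self):
        return self.thetas

dummy = ThetaDummy()
print(dummy.get_thetas())  # [0.0, 0.0]
dummy.set_thetas([1.5, -0.5])
print(dummy.get_thetas())  # [1.5, -0.5]
```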
Before creating the `config` dictionary to pass to `tune.run`, I instantiate that class:

```python
thetas = RemoteThetaDummy.remote()
```
And, as in my second comment on this post, I pass that actor handle to the NN model and environment config:

```python
def train_rl():
    thetas = RemoteThetaDummy.remote()
    config = {
        ...
        "env_config": {"theta_params": thetas},
        "model": {
            "custom_model": "SomeCustomModel",
            "custom_model_config": {"theta_params": thetas},
        },
    }
    results = tune.run(ImpalaCustomTrainer,
                       config=config,
                       ...
                       )
```
Now, to read those parameters from within the neural network model:

In `__init__`:

```python
self.thetas = theta_params
```

In `forward_rnn()`:

```python
thetas = ray.get(self.thetas.get_thetas.remote())
print("NN thetas:", thetas)  # Check that the theta parameters are up to date
```
And finally, inside `build_vtrace_loss()` in `vtrace_torch_policy.py`:

```python
def build_vtrace_loss(policy, model, dist_class, train_batch):
    ...
    # Get the old theta values from the actor
    theta_handle = policy.config['model']['custom_model_config']['theta_params']
    thetas = ray.get(theta_handle.get_thetas.remote())
    # Do some computing to get updated theta parameters
    thetas = thetas + 42
    # Update the original theta class instance (actor)
    theta_handle.set_thetas.remote(thetas)
    # Check that the updated actor-side parameters match the local ones
    print("Local/global thetas:", thetas, ray.get(theta_handle.get_thetas.remote()))
```