I’m training a custom policy with PPO in RLlib. Inside my policy, torch.cuda.is_available() keeps returning False despite my GPUs definitely being available (verified with nvidia-smi). This prevents my model from running on CUDA and makes my whole pipeline untenable if I have to rely solely on the CPU. This is my main method:
import ray
from ray.rllib.models import ModelCatalog
from ray.tune.registry import get_trainable_cls

if __name__ == "__main__":
    ray.init()
    ModelCatalog.register_custom_model("transformer_policy", PromptModelPolicy)
    config = (
        get_trainable_cls("PPO")
        .get_default_config()
        .environment(SearchEnv, env_config=env_config)
        .framework("torch")
        .training(
            model={
                "custom_model": "transformer_policy",
                "custom_model_config": model_kwargs,
            },
            num_sgd_iter=1,
        )
        .resources(num_gpus=8)
        .rollouts(num_rollout_workers=8)
    )
    config.lr = 1e-3
    algo = config.build()
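My understanding (an assumption on my part, not something I've confirmed in the docs) is that .resources(num_gpus=8) only reserves GPUs for the learner/driver process, while rollout workers default to num_gpus_per_worker=0 and may therefore be launched with an empty CUDA_VISIBLE_DEVICES. Here's a minimal standalone sanity check I can run to compare what a task sees with and without a GPU reservation; ray.get_gpu_ids() is the real Ray API, but the expected outputs in the comments are my guess:

import os

import ray
import torch

ray.init()

@ray.remote(num_gpus=0)  # mirrors RLlib's default num_gpus_per_worker=0
def check_without_gpu():
    return ray.get_gpu_ids(), os.environ.get("CUDA_VISIBLE_DEVICES"), torch.cuda.is_available()

@ray.remote(num_gpus=1)  # what a worker with a GPU reservation would see
def check_with_gpu():
    return ray.get_gpu_ids(), os.environ.get("CUDA_VISIBLE_DEVICES"), torch.cuda.is_available()

# I expect the first call to report no GPU IDs and CUDA unavailable,
# and the second to report one GPU ID and CUDA available.
print(ray.get(check_without_gpu.remote()))
print(ray.get(check_with_gpu.remote()))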
For reference, here’s my policy, where the issue is occurring:
import torch
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2


class PromptModelPolicy(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name, **kwargs):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)
        model_family = kwargs["model_family"]
        prompt_model_path = kwargs["prompt_model_path"]
        prompt_model = PromptModel.load(model_family=model_family, checkpoint_path=prompt_model_path)
        self.prompt_model = prompt_model
        self.value_layer = nn.Linear(PROMPT_MODEL_EMBEDDING_DIM, 1)
        self.log_std = nn.Parameter(torch.zeros(PROMPT_MODEL_EMBEDDING_DIM))

    def buildphase_bypass(self, prompts, device):
        # Dummy outputs returned while RLlib probes the model during the build phase.
        prompt_embedding = torch.zeros((prompts.shape[0], PROMPT_MODEL_EMBEDDING_DIM)).to(device)
        self.value_out = torch.zeros((prompts.shape[0], 1)).to(device)
        log_std = self.log_std.expand_as(prompt_embedding).to(device)
        return torch.cat([prompt_embedding, log_std], dim=1), []

    def forward(self, input_dict, state, seq_lens):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Device is {device}, cuda available: {torch.cuda.is_available()}")
        obs = input_dict["obs"]
        prompts, prompt_lengths = obs
        prompts, prompt_lengths = prompts.long().to(device), prompt_lengths.long().to(device)
        # If all the prompt lengths are 0, we are in the build phase; just return dummy values.
        if all(int(x) == 0 for x in prompt_lengths):
            return self.buildphase_bypass(prompts, device)
        # Do the forward pass and get the embeddings and the value function.
        print(f"Received prompts of shape {prompts.shape} and prompt_lengths of length {len(prompt_lengths)}")
        print(f"Device for prompts: {prompts.device}")
        prompt_embedding = self.prompt_model(prompts, prompt_lengths).squeeze(1).to(device)
        self.value_out = self.value_layer(prompt_embedding).to(device)
        log_std = self.log_std.expand_as(prompt_embedding).to(device)
        # Return the model output and the (empty) recurrent state.
        print(f"Prompt embedding shape: {prompt_embedding.shape}")
        return torch.cat([prompt_embedding, log_std], dim=1), []

    def value_function(self):
        values = self.value_out.squeeze(1)
        return values
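One workaround I've considered (just a sketch, not something RLlib prescribes) is to infer the device from wherever the module's parameters actually live, instead of calling torch.cuda.is_available() inside forward():

import torch.nn as nn

def infer_device(module: nn.Module):
    # Take the device RLlib actually placed the module's parameters on,
    # rather than querying torch.cuda.is_available() at forward time.
    return next(module.parameters()).device

I could then use device = infer_device(self) at the top of forward(). But if the rollout worker itself never got a GPU, this still returns cpu, which is consistent with what I'm seeing.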
Note that this happens even when I set local_mode to False in ray.init().
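To narrow it down, I can also ask each rollout worker directly what it sees. A sketch using algo.workers.foreach_worker(); this matches the RLlib version I'm on, but the exact API may differ across versions:

import torch

# Run after `algo = config.build()` above: ask the local worker and every
# remote rollout worker whether it can see CUDA.
results = algo.workers.foreach_worker(
    lambda w: (w.worker_index, torch.cuda.is_available())
)
print(results)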
How severe does this issue affect your experience of using Ray?
- High: It blocks me from completing my task.