I installed the Aug 12 nightly to verify the issue, and it seems to have already been fixed. The fix appears to be "[RLlib] Issue 17653: Torch multi-GPU (>1) broken for LSTMs. (#17657)" (ray-project/ray@811d71b). My test script, for posterity:
import gym
import ray
import torch
from ray import tune
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
class Env(gym.Env):
    """Trivial single-state env: one observation, one action, zero reward, never done."""

    def __init__(self, cfg):
        self.observation_space = gym.spaces.Discrete(1)
        self.action_space = gym.spaces.Discrete(1)

    def step(self, action):
        return 0, 0, False, {}

    def reset(self):
        return 0
class Model(TorchModelV2, torch.nn.Module):
    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        **custom_model_kwargs,
    ):
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        torch.nn.Module.__init__(self)

        self.num_outputs = num_outputs
        self.obs_dim = gym.spaces.utils.flatdim(obs_space)
        self.act_space = action_space
        self.act_dim = gym.spaces.utils.flatdim(action_space)
        self.values = None

        self.logit_branch = SlimFC(
            in_size=1,
            out_size=self.num_outputs,
            activation_fn=None,
        )
        self.value_branch = SlimFC(
            in_size=1,
            out_size=1,
            activation_fn=None,
        )

    def forward(
        self,
        input_dict,
        state,
        seq_lens,
    ):
        # Blow up if the model is ever run in training mode.
        if self.training:
            raise Exception("Shit's wack, yo")
        logits = input_dict["obs_flat"].reshape(-1, 1)
        self.values = input_dict["obs_flat"].reshape(-1)
        return logits, []

    def value_function(self):
        assert self.values is not None, "must call forward() first"
        return self.values
cfg = {
    "env_config": {},
    "framework": "torch",
    "num_gpus": 1,
    "env": Env,
    "model": {
        "custom_model": Model,
    },
}

ray.init()
analysis = tune.run(
    PPOTrainer,
    config=cfg,
)
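
Side note: the linked fix is specifically about the Torch multi-GPU (>1) path with LSTMs, so to hit that exact code path I believe the config would need a variant along these lines. This is just my assumption, not part of the original repro; "num_gpus" and "use_lstm" are the stock RLlib config keys.

# Hypothetical config variant to exercise the multi-GPU LSTM path from #17657.
# Assumes at least 2 GPUs are available; "use_lstm": True makes RLlib wrap the
# default model in its built-in LSTM wrapper instead of the custom Model above.
multi_gpu_lstm_cfg = {
    "env_config": {},
    "framework": "torch",
    "num_gpus": 2,  # >1 GPU is what the original issue was about
    "env": Env,
    "model": {
        "use_lstm": True,
    },
}

analysis = tune.run(PPOTrainer, config=multi_gpu_lstm_cfg)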