Model.training never True

Previously, TorchModelV2.training would be set to False when sampling actions for rollouts and to True when training the model. With the latest nightly wheel (Aug 10), this is no longer the case: TorchModelV2.training is always False.

I believe this is a big issue: torch.nn.Module.training controls dropout, batch norm, etc., so these layers will silently stay in eval mode at train time.
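
For reference, this is plain PyTorch behavior, independent of RLlib. A minimal standalone demonstration of what goes wrong when the flag is stuck at False:

import torch

torch.manual_seed(0)
drop = torch.nn.Dropout(p=0.5)
x = torch.ones(4)

drop.train()    # training mode: elements randomly zeroed, survivors scaled by 2
print(drop(x))  # a mix of 0. and 2. values

drop.eval()     # eval mode: dropout is the identity
print(drop(x))  # tensor([1., 1., 1., 1.]), which is what every dropout layer
                # silently does at train time if .training never becomes True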

This has been verified with a custom model training on CartPole:

def forward(self, input_dict, state, seq_lens):
    if self.training:
        raise Exception('Training')
    ...

But the model keeps training and never crashes, i.e. self.training is never True.
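
An alternative check that doesn't touch the model's forward() is a trainer callback. A minimal sketch, assuming the DefaultCallbacks API in current wheels and that the flag has already been toggled by the time the hook fires (CheckTrainingFlag is just an illustrative name):

from ray.rllib.agents.callbacks import DefaultCallbacks

class CheckTrainingFlag(DefaultCallbacks):
    def on_learn_on_batch(self, *, policy, train_batch, result, **kwargs):
        # Runs on the learner side; with correct behavior the underlying
        # nn.Module should be in train mode here. On the affected wheels
        # this fails because .training is stuck at False.
        assert policy.model.training, "model is in eval mode during training!"

Passing "callbacks": CheckTrainingFlag in the trainer config then makes every training iteration fail loudly instead of silently.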

Definitely a bug! Can you post this as a GitHub issue and tag @sven1977 and @michaelzhiluo?

I think the training flag is set to True in this line of code: ray/torch_policy.py at 3e010c5760c99be5a9940001f33db087c52eb8e7 · ray-project/ray · GitHub
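
For context, the toggling one would expect around that code looks roughly like this. Purely an illustrative sketch of the pattern, not the actual torch_policy.py source:

import torch

class SketchTorchPolicy:
    def __init__(self, model: torch.nn.Module):
        self.model = model

    def compute_actions(self, obs_batch):
        # Rollouts: eval mode, so dropout/batch norm are frozen.
        self.model.eval()
        with torch.no_grad():
            return self.model(obs_batch)

    def learn_on_batch(self, train_batch):
        # Training: train mode, so dropout/batch norm are active.
        # The report above says this switch never happens.
        self.model.train()
        return self.model(train_batch)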

Lmk otherwise

Installed the Aug 12 nightly to verify: the issue seems to have already been fixed (the script below now raises its exception at train time, i.e. self.training is finally True). The fix seems to be in [RLlib] Issue 17653: Torch multi-GPU (>1) broken for LSTMs. (#17657) · ray-project/ray@811d71b · GitHub. My test script, for posterity:

import gym
import ray
import torch
from ray import tune
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

class Env(gym.Env):
    # Minimal one-state environment: a single observation and a single action.
    def __init__(self, cfg):
        self.observation_space = gym.spaces.Discrete(1)
        self.action_space = gym.spaces.Discrete(1)

    def step(self, action):
        # obs, reward, done, info; episodes never end on their own, so
        # RLlib truncates them at its configured horizon.
        return 0, 0, False, {}

    def reset(self):
        return 0

class Model(TorchModelV2, torch.nn.Module):
    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        **custom_model_kwargs,
    ):
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        torch.nn.Module.__init__(self)
        self.num_outputs = num_outputs
        self.obs_dim = gym.spaces.utils.flatdim(obs_space)
        self.act_space = action_space
        self.act_dim = gym.spaces.utils.flatdim(action_space)
        self.values = None

        # These branches are never used in forward(); they only give the
        # model trainable parameters so the optimizer has something to step.
        self.logit_branch = SlimFC(
            in_size=1,
            out_size=self.num_outputs,
            activation_fn=None,
        )
        self.value_branch = SlimFC(
            in_size=1,
            out_size=1,
            activation_fn=None,
        )

    def forward(self, input_dict, state, seq_lens):
        # If the flag is ever set correctly, this raises during the learner's
        # forward pass and the run crashes, which is exactly the test.
        if self.training:
            raise Exception("Shit's wack, yo")

        logits = input_dict["obs_flat"].reshape(-1, 1)
        self.values = input_dict["obs_flat"].reshape(-1)
        return logits, []

    def value_function(self):
        assert self.values is not None, "must call forward() first"
        return self.values

cfg = {
    "env_config": {},
    "framework": "torch",
    "num_gpus": 1,  # set to 0 to run on CPU
    "env": Env,
    "model": {
        "custom_model": Model,
    },
}
ray.init()
analysis = tune.run(
    PPOTrainer,
    config=cfg,
)

Thanks for raising this @smorad!
The fix for this has been merged here: