Usage of torchmetrics for non-additive metrics

How severely does this issue affect your experience of using Ray?

  • Medium: It contributes to significant difficulty to complete my task, but I can work around it.

EDIT: After digging a bit further, I realized that I had misunderstood several points. My question is now more about torchmetrics itself than about ray, so I’m closing it. What I was looking for is explained in Implementing a Metric — PyTorch-Metrics 0.8.2 documentation.

Hello there!
I would like to better understand how torchmetrics is integrated with ray and how to use it properly. I’ve read the doc page Monitoring and Logging Metrics — Ray 2.37.0, which explains how to use torchmetrics, but the example does not fully cover my use case.

The torchmetric that I’ve implemented is a likelihood ratio. The .update method only increments some additive quantities (loss, counter, etc.), and the .compute method returns a ratio, that is, a non-additive metric. To make this clearer, my metric looks like this:

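from torchmetrics.metric import Metric
import torch

class LossRatio(Metric):
    def __init__(self):
        # Metric.__init__ must run before add_state can register states.
        super().__init__()
        # Both states are additive, so they are summed across processes.
        self.add_state("loss1", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("loss2", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, loss1, loss2):
        # Accumulate the additive quantities only.
        self.loss1 += loss1
        self.loss2 += loss2

    def compute(self):
        # The ratio itself is non-additive; it is derived from the sums.
        return self.loss1 / self.loss2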

From what I understand from the doc example, each worker creates a torchmetrics instance and calls its .compute method, and the result is then reported to ray via train.report. How are the metrics aggregated across multiple processes? I would have designed this differently: instantiate the metric on the head, share it with each worker, and call .compute only on the head after every worker has called .update.
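To show why the aggregation order matters for a ratio, here is a small sketch in plain Python. It uses neither torchmetrics nor ray, and the per-worker numbers are made up for illustration. As far as I understand, summing the states first (Option B) is what dist_reduce_fx="sum" is meant to achieve, but that is my assumption rather than something the doc example spells out.

```python
# Hypothetical accumulated states on two workers (made-up numbers).
worker_states = [
    {"loss1": 2.0, "loss2": 4.0},  # worker 0
    {"loss1": 9.0, "loss2": 6.0},  # worker 1
]

# Option A: each worker computes its own ratio, then the ratios are averaged.
per_worker_ratios = [s["loss1"] / s["loss2"] for s in worker_states]
mean_of_ratios = sum(per_worker_ratios) / len(per_worker_ratios)

# Option B: sum the additive states across workers, then compute the ratio once.
total_loss1 = sum(s["loss1"] for s in worker_states)
total_loss2 = sum(s["loss2"] for s in worker_states)
ratio_of_sums = total_loss1 / total_loss2

print("mean of per-worker ratios:", mean_of_ratios)
print("ratio of summed states:   ", ratio_of_sums)
```

The two values differ (1.0 versus roughly 1.1 here), so averaging per-worker .compute results would not give the ratio over the whole dataset.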

Could you help with this?
Thanks a lot.

Closing since it’s resolved