latte.metrics.torch.wrapper

Module Contents

Classes

TorchMetricWrapper

A wrapper class for converting a Latte metric to a TorchMetrics metric.

class TorchMetricWrapper(metric, name=None, compute_on_step=False, dist_sync_on_step=False, process_group=None, dist_sync_fn=None, **kwargs)

Bases: torchmetrics.Metric

A wrapper class for converting a Latte metric to a TorchMetrics metric.

Parameters:
  • metric (Callable[..., LatteMetric]) – Class handle of the Latte metric to be converted.

  • name (Optional[str], optional) – Name of the TorchMetrics metric object, by default None. If None, the name of the Latte metric is used.

  • **kwargs – Keyword arguments to be passed to the Latte metric.

  • compute_on_step (bool) – If False, forward() only calls update() and returns None.

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Process group on which synchronization is called. Default: None, which selects the entire world.

  • dist_sync_fn (Callable) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.
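The parameters fall into two groups: `metric` and `**kwargs` configure the Latte metric, while the rest are TorchMetrics options. The sketch below illustrates only the first group. It assumes that the wrapper instantiates the class handle as `metric(**kwargs)` and that "the name of the Latte metric" means its class name. `ToyMeanMetric` and `build_latte_metric` are hypothetical and not part of Latte.

```python
from typing import Any, Callable, Optional, Tuple

import numpy as np


class ToyMeanMetric:
    """Hypothetical stand-in for a Latte metric (not part of Latte)."""

    def __init__(self, scale: float = 1.0) -> None:
        self.scale = scale
        self.total = 0.0
        self.count = 0

    def update_state(self, values: np.ndarray) -> None:
        # Accumulate a running sum and element count.
        self.total += float(np.sum(values))
        self.count += int(np.size(values))

    def compute(self) -> np.ndarray:
        # Scaled mean of everything seen so far (0 if nothing was seen).
        return np.asarray(self.scale * self.total / max(self.count, 1))


def build_latte_metric(
    metric: Callable[..., Any], name: Optional[str] = None, **kwargs: Any
) -> Tuple[Any, str]:
    # The class handle (not an instance) is called with the extra keyword arguments.
    instance = metric(**kwargs)
    # Assumption: "the name of the Latte metric" is taken to be its class name.
    resolved_name = name if name is not None else type(instance).__name__
    return instance, resolved_name


latte_metric, metric_name = build_latte_metric(ToyMeanMetric, scale=2.0)
latte_metric.update_state(np.array([1.0, 3.0]))
print(metric_name, float(latte_metric.compute()))  # ToyMeanMetric 4.0
```

Passing the class rather than an instance lets the wrapper own construction. Keyword arguments such as `scale` therefore reach the Latte metric unchanged, and the TorchMetrics options are kept separate as explicit parameters.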

See also

torchmetrics.Metric

TorchMetrics base metric class

update(*args, **kwargs)

Convert the inputs to np.ndarray and pass them to the wrapped Latte metric's update_state method.
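A minimal sketch of the conversion step follows. It assumes that every positional and keyword input is converted the same way. Tensors are detached and moved to CPU before `.numpy()`, and everything else goes through `np.asarray`. `to_numpy`, `RecordingMetric`, and `wrapped_update` are hypothetical helpers, not Latte's internals.

```python
from typing import Any, List

import numpy as np


def to_numpy(x: Any) -> np.ndarray:
    """Convert a tensor-like input to np.ndarray."""
    if hasattr(x, "detach"):
        # torch.Tensor: drop autograd history and move to host memory first.
        return x.detach().cpu().numpy()
    # Lists, scalars and existing arrays go through np.asarray.
    return np.asarray(x)


class RecordingMetric:
    """Hypothetical Latte-style metric that records what update_state receives."""

    def __init__(self) -> None:
        self.calls: List[tuple] = []

    def update_state(self, *args: Any, **kwargs: Any) -> None:
        self.calls.append((args, kwargs))


def wrapped_update(latte_metric: Any, *args: Any, **kwargs: Any) -> None:
    # Convert every input, then delegate to the Latte metric.
    np_args = [to_numpy(a) for a in args]
    np_kwargs = {k: to_numpy(v) for k, v in kwargs.items()}
    latte_metric.update_state(*np_args, **np_kwargs)


recorder = RecordingMetric()
wrapped_update(recorder, [0.1, 0.2], y=[[1, 0], [0, 1]])
args, kwargs = recorder.calls[0]
print(type(args[0]).__name__, kwargs["y"].shape)  # ndarray (2, 2)
```

Calling `.detach().cpu()` before `.numpy()` matters for real tensors, because `.numpy()` refuses tensors that require grad or live on a GPU.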

compute()

Calculate the metric values and convert them to torch.Tensor or a collection of them.

Returns:

Metric values

Return type:

Union[torch.Tensor, Collection[torch.Tensor]]
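The Latte metric may return a single array or a collection of them. A sketch of the output conversion is therefore a recursive walk that keeps the container structure and converts each leaf. It assumes dicts, lists, and tuples are the only containers, and uses `torch.as_tensor` for the leaves by default. `convert_output` is a hypothetical helper, and its `to_tensor` hook is an illustration choice, not part of Latte.

```python
from typing import Any, Callable, Optional

import numpy as np


def convert_output(value: Any, to_tensor: Optional[Callable[[np.ndarray], Any]] = None) -> Any:
    """Recursively convert the array leaves of a (possibly nested) metric result."""
    if to_tensor is None:
        # Imported lazily; any ndarray -> tensor callable can be passed instead.
        import torch

        to_tensor = torch.as_tensor
    if isinstance(value, dict):
        # Keep the keys, convert each value.
        return {k: convert_output(v, to_tensor) for k, v in value.items()}
    if isinstance(value, list):
        return [convert_output(v, to_tensor) for v in value]
    if isinstance(value, tuple):
        return tuple(convert_output(v, to_tensor) for v in value)
    # Leaf: scalars and arrays each become a single tensor.
    return to_tensor(np.asarray(value))
```

With PyTorch installed, `convert_output({"score": np.array(0.5)})` returns a dict with the same key whose value is a 0-d `torch.Tensor`. Preserving the container shape means callers can index the result exactly as they would index the Latte metric's own output.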