latte.metrics.torch.wrapper
Module Contents
Classes
TorchMetricWrapper – A wrapper class for converting a Latte metric to a TorchMetrics metric.
- class TorchMetricWrapper(metric, name=None, compute_on_step=False, dist_sync_on_step=False, process_group=None, dist_sync_fn=None, **kwargs)
Bases: torchmetrics.Metric
A wrapper class for converting a Latte metric to a TorchMetrics metric.
- Parameters:
metric (Callable[..., LatteMetric]) – Class handle of the Latte metric to be converted.
name (Optional[str], optional) – Name of the TorchMetrics metric object, by default None. If None, the name of the Latte metric is used.
**kwargs – Keyword arguments to be passed to the Latte metric.
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.
process_group (Optional[Any]) – Specify the process group on which synchronization is called. Default: None (which selects the entire world).
dist_sync_fn (Callable) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.
See also
torchmetrics.Metric – TorchMetrics base metric class
- update(*args, **kwargs)
Convert inputs to np.ndarray and call the functional update_state method.
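The input conversion described for update() could look roughly like the sketch below. This is an illustration only, not Latte's actual code: the helper names to_numpy and convert_inputs are hypothetical, and the sketch assumes inputs are either torch.Tensor objects or values that np.asarray accepts.

```python
import numpy as np
import torch


def to_numpy(value):
    """Hypothetical helper: return `value` as an np.ndarray."""
    if isinstance(value, torch.Tensor):
        # Detach from the autograd graph and move to CPU before converting.
        return value.detach().cpu().numpy()
    return np.asarray(value)


def convert_inputs(*args, **kwargs):
    """Hypothetical helper: convert every positional and keyword input."""
    np_args = tuple(to_numpy(a) for a in args)
    np_kwargs = {key: to_numpy(val) for key, val in kwargs.items()}
    return np_args, np_kwargs


# Example: a tensor that requires grad and a plain list are both converted.
z = torch.ones(2, 3, requires_grad=True)
np_args, np_kwargs = convert_inputs(z, a=[0, 1, 2])
```

After such a conversion, the NumPy arrays can be handed to the wrapped metric's functional update_state method.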
- compute()
Calculate the metric values and convert them to torch.Tensor or a collection of them.
- Returns:
Metric values
- Return type:
Union[torch.Tensor, Collection[torch.Tensor]]
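The output conversion in compute() can be pictured with the following minimal sketch. It is an assumption about the general approach, not the wrapper's real implementation: to_tensor is a hypothetical helper, and the sketch assumes the metric returns NumPy values either alone or nested in a dict, list, or tuple.

```python
import numpy as np
import torch


def to_tensor(values):
    """Hypothetical helper: convert an array, or a collection of arrays,
    to torch.Tensor while preserving the container type."""
    if isinstance(values, dict):
        return {key: to_tensor(val) for key, val in values.items()}
    if isinstance(values, (list, tuple)):
        return type(values)(to_tensor(val) for val in values)
    # np.asarray also covers NumPy scalars and plain Python numbers.
    return torch.as_tensor(np.asarray(values))


# Example: a dict of results, one of which is a tuple of scalar arrays.
results = {"scores": np.array([0.25, 0.75]), "pair": (np.array(1.0), np.array(2.0))}
tensors = to_tensor(results)
```

Preserving the container type means callers see the same structure the underlying metric produced, with only the leaf values changed to tensors.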