latte.metrics.torch.bundles

Module Contents

Classes

DependencyAwareMutualInformationBundle

Calculate Mutual Information Gap (MIG), Dependency-Aware Mutual Information Gap (DMIG), Dependency-Blind Mutual Information Gap (XMIG), and Dependency-Aware Latent Information Gap (DLIG) between latent vectors (z) and attributes (a).

LiadInterpolatabilityBundle

Calculate LIAD-based interpolatability metrics (smoothness and monotonicity) between latent vectors (z) and attributes (a), wrapped as a TorchMetrics metric.

class DependencyAwareMutualInformationBundle(reg_dim=None, discrete=False)

Bases: latte.metrics.torch.wrapper.TorchMetricWrapper

Calculate Mutual Information Gap (MIG), Dependency-Aware Mutual Information Gap (DMIG), Dependency-Blind Mutual Information Gap (XMIG), and Dependency-Aware Latent Information Gap (DLIG) between latent vectors (z) and attributes (a).
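For reference, the first of these quantities, the conventional MIG of Chen et al., can be sketched in plain Python for discrete latent codes and a discrete attribute. The sketch estimates the mutual information between each latent dimension and the attribute from empirical counts, takes the gap between the two largest values, and normalizes it by the attribute entropy. It is a minimal sketch that assumes the latents have already been discretized. It is not latte's implementation, and the helper names (`entropy`, `mutual_information`, `mig`) are hypothetical.

```python
import math
from collections import Counter
from typing import List, Sequence


def entropy(x: Sequence[int]) -> float:
    """Empirical Shannon entropy (in nats) of a discrete sequence."""
    n = len(x)
    return -sum((c / n) * math.log(c / n) for c in Counter(x).values())


def mutual_information(x: Sequence[int], y: Sequence[int]) -> float:
    """Empirical mutual information (in nats) between two discrete sequences."""
    n = len(x)
    p_x, p_y, p_xy = Counter(x), Counter(y), Counter(zip(x, y))
    return sum(
        (c / n) * math.log((c / n) / ((p_x[xi] / n) * (p_y[yi] / n)))
        for (xi, yi), c in p_xy.items()
    )


def mig(z_columns: List[Sequence[int]], a: Sequence[int]) -> float:
    """Conventional MIG for one attribute: top-two MI gap divided by H(a).

    Hypothetical helper; assumes at least two latent dimensions and a
    non-constant attribute (so that H(a) > 0).
    """
    mis = sorted((mutual_information(z, a) for z in z_columns), reverse=True)
    return (mis[0] - mis[1]) / entropy(a)


if __name__ == "__main__":
    a = [0, 0, 1, 1]
    z_informative = [0, 0, 1, 1]  # copies the attribute exactly
    z_unrelated = [0, 1, 0, 1]    # statistically independent of the attribute
    print(mig([z_informative, z_unrelated], a))  # ≈ 1.0
```

A gap of 1.0 means a single latent dimension carries all of the attribute's information and no other dimension carries any.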

Parameters:
  • reg_dim (Optional[List], optional) – regularized dimensions, by default None. Attribute a[:, i] is regularized by z[:, reg_dim[i]]. If None, a[:, i] is assumed to be regularized by z[:, i]. Note that this is the reg_dim behavior of the dependency-aware family, which differs from the default reg_dim behavior of the conventional MIG.

  • discrete (bool, optional) – Whether the attributes are discrete, by default False
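The reg_dim convention above can be illustrated with plain Python lists. The sketch maps each attribute index to the latent index that regularizes it and then pairs up the corresponding columns. It assumes 2-D attributes of shape (n_samples, n_attributes). The helpers `regularizing_dims` and `paired_columns` are hypothetical and not part of latte.

```python
from typing import Iterator, List, Optional, Sequence, Tuple


def regularizing_dims(n_attributes: int, reg_dim: Optional[Sequence[int]] = None) -> List[int]:
    """Return the latent index that regularizes each attribute.

    Follows the documented convention: a[:, i] is regularized by
    z[:, reg_dim[i]], or by z[:, i] when reg_dim is None.
    """
    if reg_dim is None:
        return list(range(n_attributes))
    if len(reg_dim) != n_attributes:
        raise ValueError("reg_dim must have one entry per attribute")
    return list(reg_dim)


def paired_columns(
    z: Sequence[Sequence[float]],
    a: Sequence[Sequence[float]],
    reg_dim: Optional[Sequence[int]] = None,
) -> Iterator[Tuple[List[float], List[float]]]:
    """Yield (attribute column, regularizing latent column) pairs.

    z and a are row-major lists of shape (n_samples, n_features) and
    (n_samples, n_attributes).
    """
    n_attributes = len(a[0])
    for i, d in enumerate(regularizing_dims(n_attributes, reg_dim)):
        yield [row[i] for row in a], [row[d] for row in z]
```

For example, with reg_dim=[2, 0], attribute 0 is paired with latent column 2 and attribute 1 with latent column 0.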

References

[1] R. T. Q. Chen, X. Li, R. Grosse, and D. Duvenaud, “Isolating sources of disentanglement in variational autoencoders”, in Proceedings of the 32nd International Conference on Neural Information Processing Systems, 2018.

[2] K. N. Watcharasupat and A. Lerch, “Evaluation of Latent Space Disentanglement in the Presence of Interdependent Attributes”, in Extended Abstracts of the Late-Breaking Demo Session of the 22nd International Society for Music Information Retrieval Conference, 2021.

[3] K. N. Watcharasupat, “Controllable Music: Supervised Learning of Disentangled Representations for Music Generation”, 2021.

update(z, a)

Update metric states. This function converts the tensors to NumPy arrays, then appends the latent vectors and attributes to the internal state lists.

Parameters:
  • z (torch.Tensor, (n_samples, n_features)) – a batch of latent vectors

  • a (torch.Tensor, (n_samples, n_attributes) or (n_samples,)) – a batch of attribute(s)

compute()

Compute metric values from the current state. The latent vectors and attributes in the internal states are concatenated along the sample dimension and passed to the metric function to obtain the metric values.

Returns:

A dictionary of mutual information metrics with keys [‘MIG’, ‘DMIG’, ‘XMIG’, ‘DLIG’], each mapping to a corresponding metric torch.Tensor of shape (n_attributes,).

Return type:

Dict[str, torch.Tensor]
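The state handling described for update() and compute() follows an accumulate-then-concatenate pattern. The framework-free sketch below mimics it with plain lists. Each update() stores a copy of the batch, and compute() joins all stored batches along the sample dimension before passing them to a metric function that returns a dictionary. The class `AccumulatingBundle` and the toy `toy_metric` are hypothetical stand-ins. The real class works with torch tensors and NumPy arrays and relies on TorchMetrics for state management.

```python
from typing import Callable, Dict, List, Sequence


class AccumulatingBundle:
    """Hypothetical stand-in for the update()/compute() state handling."""

    def __init__(self, metric_fn: Callable[..., Dict[str, List[float]]]):
        self.metric_fn = metric_fn
        self.z_batches: List[List[List[float]]] = []
        self.a_batches: List[List[List[float]]] = []

    def update(self, z: Sequence[Sequence[float]], a: Sequence[Sequence[float]]) -> None:
        # Store a copy of each batch in the internal state lists
        # (the documented class converts tensors to NumPy arrays here).
        self.z_batches.append([list(row) for row in z])
        self.a_batches.append([list(row) for row in a])

    def compute(self) -> Dict[str, List[float]]:
        # Concatenate all stored batches along the sample dimension,
        # then hand the full data to the metric function.
        z_all = [row for batch in self.z_batches for row in batch]
        a_all = [row for batch in self.a_batches for row in batch]
        return self.metric_fn(z_all, a_all)


def toy_metric(z: List[List[float]], a: List[List[float]]) -> Dict[str, List[float]]:
    """Toy metric: report the sample count and the per-attribute mean."""
    n = len(a)
    means = [sum(row[i] for row in a) / n for i in range(len(a[0]))]
    return {"n_samples": [float(n)], "attribute_mean": means}


if __name__ == "__main__":
    bundle = AccumulatingBundle(toy_metric)
    bundle.update([[0.1, 0.2]], [[1.0]])
    bundle.update([[0.3, 0.4], [0.5, 0.6]], [[2.0], [3.0]])
    print(bundle.compute())  # {'n_samples': [3.0], 'attribute_mean': [2.0]}
```

With the documented class, the corresponding sequence is one update(z, a) call per batch followed by a single compute(), which returns the ‘MIG’, ‘DMIG’, ‘XMIG’ and ‘DLIG’ tensors.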

class LiadInterpolatabilityBundle(reg_dim=None, liad_mode='forward', max_mode='lehmer', ptp_mode='naive', reduce_mode='attribute', liad_thresh=0.001, degenerate_val=np.nan, nanmean=True, clamp=False, p=2.0)

Bases: latte.metrics.torch.wrapper.TorchMetricWrapper

Calculate LIAD-based interpolatability metrics (smoothness and monotonicity) between latent vectors (z) and attributes (a), wrapped as a TorchMetrics metric.

Parameters:
  • metric (Callable[..., LatteMetric]) – Class handle of the Latte metric to be converted.

  • name (Optional[str], optional) – Name of the TorchMetrics metric object, by default None. If None, the name of the Latte metric is used.

  • **kwargs – Keyword arguments to be passed to the Latte metric.

  • compute_on_step – If set to False, forward() only calls update() and returns None.

  • dist_sync_on_step – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group – Specify the process group on which synchronization is called, by default None (which selects the entire world).

  • dist_sync_fn – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

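  • reg_dim (Optional[List[int]]) – by default None

  • liad_mode (str) – by default 'forward'

  • max_mode (str) – by default 'lehmer'

  • ptp_mode (Union[float, str]) – by default 'naive'

  • reduce_mode (str) – by default 'attribute'

  • liad_thresh (float) – by default 0.001

  • degenerate_val (float) – by default np.nan

  • nanmean (bool) – by default True

  • clamp (bool) – by default False

  • p (float) – by default 2.0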
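In the signature, max_mode defaults to 'lehmer' and p to 2.0, which suggests a Lehmer mean of order p used as a soft maximum. This reading is an assumption and is not stated in this reference. Under that assumption, the sketch below computes the Lehmer mean L_p(x) = Σ x_k^p / Σ x_k^(p - 1) of non-negative values. With p = 1 it gives the arithmetic mean, and larger p moves the result toward max(x). `lehmer_mean` is a hypothetical helper, not latte's implementation.

```python
from typing import Sequence


def lehmer_mean(x: Sequence[float], p: float = 2.0) -> float:
    """Lehmer mean sum(x**p) / sum(x**(p - 1)) of non-negative values.

    p = 1 gives the arithmetic mean; larger p weights large values more
    heavily, approaching max(x) as p grows. Hypothetical helper.
    """
    if p < 1:
        raise ValueError("this sketch only supports p >= 1")
    if any(v < 0 for v in x):
        raise ValueError("values must be non-negative")
    denominator = sum(v ** (p - 1) for v in x)
    if denominator == 0:
        return 0.0  # only reachable when p > 1 and all values are zero
    return sum(v ** p for v in x) / denominator
```

For x = [1, 2, 3], the sketch gives 2.0 at p = 1, 14/6 ≈ 2.33 at p = 2, and a value just below 3 at p = 30.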

See also

torchmetrics.Metric

TorchMetrics base metric class

update(z, a)

Update metric states. This function appends the latent vectors and attributes to the internal state lists.

Parameters:
  • z (torch.Tensor, (n_samples, n_interp) or (n_samples, n_features or n_attributes, n_interp)) – a batch of latent vectors

  • a (torch.Tensor, (n_samples, n_interp) or (n_samples, n_attributes, n_interp)) – a batch of attribute(s)

compute()

Compute metric values from the current state. The latent vectors and attributes in the internal states are concatenated along the sample dimension and passed to the metric function to obtain the metric values.

Returns:

A dictionary of LIAD-based interpolatability metrics with keys [‘smoothness’, ‘monotonicity’], each mapping to a corresponding metric torch.Tensor. See reduce_mode for details on the shape of the returned tensors.

Return type:

Dict[str, torch.Tensor]
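To make the ‘monotonicity’ entry more concrete, here is a plain-Python sketch built on several assumptions. First, it takes the LIAD along one interpolation to be the forward finite difference of attribute values with respect to latent values, which is our reading of liad_mode='forward'. Second, it ignores differences whose magnitude does not exceed liad_thresh. Third, it scores monotonicity as the absolute mean sign of the remaining differences and falls back to degenerate_val when none remain. These formulas are assumptions rather than latte's documented definitions, and `forward_liad` and `monotonicity` are hypothetical names. Only the defaults liad_thresh=0.001 and degenerate_val=NaN come from the signature.

```python
import math
from typing import List, Sequence


def forward_liad(z: Sequence[float], a: Sequence[float]) -> List[float]:
    """Forward-difference attribute change per unit latent change.

    Assumes z is strictly increasing along the interpolation.
    """
    return [(a[k + 1] - a[k]) / (z[k + 1] - z[k]) for k in range(len(z) - 1)]


def monotonicity(
    liad: Sequence[float],
    liad_thresh: float = 1e-3,
    degenerate_val: float = math.nan,
) -> float:
    """Absolute mean sign of the LIAD values whose magnitude exceeds liad_thresh."""
    signs = [math.copysign(1.0, d) for d in liad if abs(d) > liad_thresh]
    if not signs:
        return degenerate_val  # no difference is large enough to count
    return abs(sum(signs)) / len(signs)
```

For example, along z = [0, 0.25, 0.5, 0.75, 1] with a = [0, 1, 2, 2, 1], the LIAD values are [4, 4, 0, -4]. The zero is discarded by the threshold, so the score is |1 + 1 - 1| / 3 = 1/3. A strictly increasing attribute scores 1.0, and a constant attribute returns degenerate_val.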