from dataclasses import dataclass
from typing import Tuple
import torch
from lightkit.nn import Configurable
from torch import jit, nn
@dataclass
class KMeansModelConfig:
    """
    Configuration class for a K-Means model.

    See also:
        :class:`KMeansModel`
    """

    #: The number of clusters.
    num_clusters: int
    #: The number of features of each cluster.
    num_features: int
class KMeansModel(Configurable[KMeansModelConfig], nn.Module):
    """
    PyTorch module for the K-Means model.

    The centroids managed by this model are non-trainable parameters: they are stored as a
    registered buffer rather than as an ``nn.Parameter``.
    """

    def __init__(self, config: KMeansModelConfig):
        """
        Args:
            config: The configuration to use for initializing the module's buffers.
        """
        # ``nn.Module`` must be initialized before any buffer can be registered.
        # NOTE(review): assumes ``Configurable.__init__`` accepts the config as its first
        # argument and forwards the remaining arguments up the MRO -- confirm against lightkit.
        super().__init__(config)

        #: The centers of all clusters, buffer of shape ``[num_clusters, num_features]``.
        self.centroids: torch.Tensor
        self.register_buffer("centroids", torch.empty(config.num_clusters, config.num_features))

        # ``torch.empty`` returns uninitialized memory, so give the centroids defined values.
        # NOTE(review): this call is not visible in the extracted source -- confirm upstream.
        self.reset_parameters()

    @jit.unused
    def reset_parameters(self) -> None:
        """
        Resets the parameters of the KMeans model.

        It samples all cluster centers from a standard Normal.
        """
        nn.init.normal_(self.centroids)

    def forward(self, data: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Computes the distance of each datapoint to each centroid as well as the "inertia", the
        squared distance of each datapoint to its closest centroid.

        Args:
            data: A tensor of shape ``[num_datapoints, num_features]`` for which to compute the
                distances and inertia.

        Returns:
            - A tensor of shape ``[num_datapoints, num_clusters]`` with the distance from each
              datapoint to each centroid.
            - A tensor of shape ``[num_datapoints]`` with the assignments, i.e. the indices of
              each datapoint's closest centroid.
            - A tensor of shape ``[num_datapoints]`` with the inertia (squared distance to the
              closest centroid) of each datapoint.
        """
        # Pairwise distances between every datapoint (rows) and every centroid (columns).
        distances = torch.cdist(data, self.centroids)
        # Keep the reduced dimension so that the indices can be fed straight into ``gather``.
        assignments = distances.min(1, keepdim=True).indices
        # Pick each datapoint's distance to its closest centroid and square it.
        inertias = distances.gather(1, assignments).square()
        return distances, assignments.squeeze(1), inertias.squeeze(1)