KMeans.predict

KMeans.predict(data)

Predicts the closest cluster for each datapoint provided.

Parameters:

data (Union[ndarray[Any, dtype[float32]], Tensor]) -- The datapoints for which to predict the clusters.

Return type:

Tensor

Returns:

Tensor of shape [num_datapoints] with the index of the closest cluster for each datapoint.
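
To make the return value concrete, the following is a minimal NumPy sketch of nearest-centroid assignment. It is not this library's implementation: the helper name `nearest_centroid_indices`, the explicit `centroids` array, and the use of squared Euclidean distance are assumptions made for illustration only.

```python
import numpy as np


def nearest_centroid_indices(data: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    """Hypothetical helper: index of the closest centroid for each datapoint."""
    # Broadcast to [num_datapoints, num_clusters, num_features].
    diffs = data[:, None, :] - centroids[None, :, :]
    # Squared Euclidean distances, shape [num_datapoints, num_clusters].
    sq_dists = (diffs ** 2).sum(axis=-1)
    # One cluster index per datapoint, shape [num_datapoints].
    return sq_dists.argmin(axis=1)


# Two illustrative centroids and three datapoints (made-up values).
centroids = np.array([[0.0, 0.0], [10.0, 10.0]], dtype=np.float32)
data = np.array([[1.0, 1.0], [9.0, 8.0], [-2.0, 0.5]], dtype=np.float32)

labels = nearest_centroid_indices(data, centroids)
print(labels)  # [0 1 0]
```

The result has one entry per input row, which matches the documented shape [num_datapoints].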

Attention

When calling this function in a multi-process environment, each process receives only the predictions for its own subset of the data. To obtain predictions for the full dataset, gather the values returned from this method across all processes.
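
The sketch below simulates this situation in a single process so the gathering step is easy to follow. Everything in it is an assumption for illustration: `predict_shard` is a hypothetical stand-in for the per-process call, the data is split into contiguous shards (a real distributed sampler may assign datapoints differently), and `np.concatenate` plays the role that a collective operation such as `torch.distributed.all_gather` would play in an actual multi-process run.

```python
import numpy as np


def predict_shard(shard: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the prediction one process makes on its shard."""
    sq_dists = ((shard[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return sq_dists.argmin(axis=1)


# Made-up one-dimensional data and centroids.
centroids = np.array([[0.0], [5.0]], dtype=np.float32)
data = np.array([[0.1], [4.9], [5.2], [-0.3], [6.0]], dtype=np.float32)

# Assumption: two processes, each handling one contiguous shard of the data.
world_size = 2
shards = np.array_split(data, world_size)

# Each simulated process only sees the predictions for its own shard.
per_rank = [predict_shard(shard, centroids) for shard in shards]

# Gathering in rank order restores one prediction per original datapoint.
gathered = np.concatenate(per_rank)
print([p.tolist() for p in per_rank])  # [[0, 1, 1], [0, 1]]
print(gathered.tolist())               # [0, 1, 1, 0, 1]
```

Note that the gathered order only matches the original data order here because the shards are contiguous and concatenated in rank order; with a different sharding scheme the results would need to be reordered accordingly.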