GaussianMixture.predict

GaussianMixture.predict(data)

Computes the most likely component for each of the provided datapoints.

Parameters:

data (Union[ndarray[Any, dtype[float32]], Tensor]) -- The datapoints for which to obtain the most likely components.

Return type:

Tensor

Returns:

A tensor of shape [num_datapoints] containing, for each datapoint, the index of its most likely component.
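The returned index is the component with the highest posterior probability for that datapoint. The following NumPy sketch illustrates the computation for a mixture with diagonal covariances. It is not this library's implementation, and the helper `most_likely_components` and its parameters (`means`, `variances`, `weights`) are hypothetical names chosen for illustration.

```python
import numpy as np


def most_likely_components(data, means, variances, weights):
    """Hypothetical helper: index of the most likely component per datapoint.

    Assumes diagonal covariances.
    data: [N, D], means: [K, D], variances: [K, D], weights: [K].
    Returns an integer array of shape [N].
    """
    num_features = data.shape[1]
    # Squared Mahalanobis distance for every (datapoint, component) pair: [N, K]
    diff = data[:, None, :] - means[None, :, :]
    mahalanobis = np.sum(diff ** 2 / variances[None, :, :], axis=2)
    # Log-determinant of each diagonal covariance matrix: [K]
    log_det = np.sum(np.log(variances), axis=1)
    # Log-density of each datapoint under each Gaussian component: [N, K]
    log_density = -0.5 * (num_features * np.log(2 * np.pi) + log_det[None, :] + mahalanobis)
    # Adding the log prior weights gives unnormalized log posteriors
    log_joint = log_density + np.log(weights)[None, :]
    # Normalizing over components would not change the argmax, so skip it
    return np.argmax(log_joint, axis=1)


means = np.array([[0.0, 0.0], [5.0, 5.0]])
variances = np.ones((2, 2))
weights = np.array([0.5, 0.5])
data = np.array([[0.1, -0.2], [4.8, 5.1], [6.0, 4.0]])
result = most_likely_components(data, means, variances, weights)
print(result)  # [0 1 1]
```

As with `predict`, the result holds one component index per datapoint rather than per-component scores.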

Note

Use predict_proba() to obtain the probability of each component for every datapoint, rather than only the index of the most likely component.
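The two methods are related: taking the argmax over the component dimension of the probabilities is expected to reproduce the hard assignment, except possibly in the case of ties. The NumPy sketch below shows this relationship on a made-up matrix of unnormalized log posteriors. The function `probabilities_from_log_joint` is hypothetical and only illustrates the idea.

```python
import numpy as np


def probabilities_from_log_joint(log_joint):
    """Hypothetical helper: turn unnormalized log posteriors [N, K] into per-row probabilities."""
    # Subtract the row-wise maximum before exponentiating for numerical stability
    shifted = log_joint - log_joint.max(axis=1, keepdims=True)
    unnormalized = np.exp(shifted)
    return unnormalized / unnormalized.sum(axis=1, keepdims=True)


# Made-up log posteriors for 2 datapoints and 3 components
log_joint = np.array([[-1.0, -3.0, -2.0], [-5.0, -0.5, -0.7]])
proba = probabilities_from_log_joint(log_joint)  # analogous to predict_proba()
hard = np.argmax(proba, axis=1)                  # analogous to predict()
print(hard)  # [0 1]
```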

Attention

When this method is called in a multi-process environment, each process receives the predictions for only a subset of the datapoints. To aggregate predictions across processes, gather the values returned by this method from all processes.
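One way to gather the per-process results is `torch.distributed.all_gather_object`. The sketch below assumes that a `torch.distributed` process group has already been initialized (for example, by the launcher) and falls back to returning the local tensor when none exists. The helper name `gather_predictions` is hypothetical. The concatenated order follows process rank, which need not match the order of the original `data`. How the datapoints are split across processes is not specified here.

```python
import torch
import torch.distributed as dist


def gather_predictions(local_predictions: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: collect per-process predictions, concatenated in rank order."""
    if not (dist.is_available() and dist.is_initialized()):
        # Single-process run: there is nothing to gather
        return local_predictions
    gathered = [None] * dist.get_world_size()
    # Objects are pickled, so each rank may hold a different number of predictions
    dist.all_gather_object(gathered, local_predictions.cpu())
    return torch.cat(gathered)
```

Every process receives the full concatenated tensor, so any rank can compute aggregate statistics from it.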