GaussianMixtureModel.forward

GaussianMixtureModel.forward(data)

Computes the log-probability of observing each of the provided datapoints for each of the GMM's components.

Parameters:

data (Tensor) -- A tensor of shape [num_datapoints, num_features] for which to compute the log-probabilities.

Return type:

Tuple[Tensor, Tensor]

Returns:

  • A tensor of shape [num_datapoints, num_components] with the log-responsibilities for each datapoint and component. These are the logits of the Categorical distribution over the components.

  • A tensor of shape [num_datapoints] with the log-likelihood of each datapoint.
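The two return values are related: a datapoint's log-likelihood is the log-sum-exp over components of the weighted per-component log-probabilities. Its log-responsibilities are those same joint log-probabilities minus the log-likelihood, so exponentiating each row yields probabilities that sum to 1. The following minimal NumPy sketch illustrates this relationship. It is not this library's implementation: it assumes full covariance matrices, uses hypothetical parameter names (`means`, `covariances`, `log_weights`) and a hypothetical function name (`gmm_forward`), and substitutes NumPy arrays for the Tensor inputs and outputs.

```python
import numpy as np


def gmm_forward(data, means, covariances, log_weights):
    """Hypothetical NumPy sketch of a full-covariance GMM forward pass.

    data:        [num_datapoints, num_features]
    means:       [num_components, num_features]
    covariances: [num_components, num_features, num_features]
    log_weights: [num_components], log of the mixture weights
    Returns (log_responsibilities [N, K], log_likelihood [N]).
    """
    n, d = data.shape
    k = means.shape[0]
    log_joint = np.empty((n, k))  # log pi_j + log N(x_n | mu_j, Sigma_j)
    for j in range(k):
        diff = data - means[j]                     # [N, D]
        chol = np.linalg.cholesky(covariances[j])  # Sigma_j = L L^T
        # Solving L z = diff^T makes ||z||^2 the squared Mahalanobis distance.
        z = np.linalg.solve(chol, diff.T)          # [D, N]
        maha = np.sum(z ** 2, axis=0)              # [N]
        log_det = 2.0 * np.sum(np.log(np.diag(chol)))
        log_prob = -0.5 * (d * np.log(2.0 * np.pi) + log_det + maha)
        log_joint[:, j] = log_weights[j] + log_prob
    # log p(x_n) = logsumexp_j log_joint[n, j], shifted by the row max for stability.
    m = log_joint.max(axis=1, keepdims=True)
    log_likelihood = (m + np.log(np.exp(log_joint - m).sum(axis=1, keepdims=True)))[:, 0]
    # Normalize in log space: each row of exp(log_responsibilities) sums to 1.
    log_responsibilities = log_joint - log_likelihood[:, None]
    return log_responsibilities, log_likelihood


rng = np.random.default_rng(0)
data = rng.normal(size=(5, 2))
means = np.array([[0.0, 0.0], [3.0, 3.0]])
covariances = np.stack([np.eye(2), 2.0 * np.eye(2)])
log_weights = np.log(np.array([0.4, 0.6]))
log_resp, log_lik = gmm_forward(data, means, covariances, log_weights)
print(log_resp.shape, log_lik.shape)  # (5, 2) (5,)
```

Working in log space and subtracting the row maximum before exponentiating avoids underflow when datapoints lie far from every component.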