Skip to content

Commit

Permalink
[MRG+2] switch to multinomial composition for mixture sampling (scikit-learn#7702)
Browse files Browse the repository at this point in the history

* switch to multinomial composition for mixture sampling

* add shape assertions to test

* Use n_components=3 to test actual regression

Previously n_components and n_features had the same value in this test, so
places where one was mistakenly used in place of the other went undetected.
  • Loading branch information
lesteve committed Oct 20, 2016
2 parents cd714b1 + 4e1c101 commit ad6f094
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
2 changes: 1 addition & 1 deletion sklearn/mixture/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def sample(self, n_samples=1):

_, n_features = self.means_.shape
rng = check_random_state(self.random_state)
n_samples_comp = np.round(self.weights_ * n_samples).astype(int)
n_samples_comp = rng.multinomial(n_samples, self.weights_)

if self.covariance_type == 'full':
X = np.vstack([
Expand Down
18 changes: 14 additions & 4 deletions sklearn/mixture/tests/test_gaussian_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,7 @@ def test_property():

def test_sample():
rng = np.random.RandomState(0)
rand_data = RandomData(rng, scale=7)
rand_data = RandomData(rng, scale=7, n_components=3)
n_features, n_components = rand_data.n_features, rand_data.n_components

for covar_type in COVARIANCE_TYPE:
Expand All @@ -935,8 +935,10 @@ def test_sample():
gmm.sample, 0)

# Just to make sure the class samples correctly
X_s, y_s = gmm.sample(20000)
for k in range(n_features):
n_samples = 20000
X_s, y_s = gmm.sample(n_samples)

for k in range(n_components):
if covar_type == 'full':
assert_array_almost_equal(gmm.covariances_[k],
np.cov(X_s[y_s == k].T), decimal=1)
Expand All @@ -953,9 +955,17 @@ def test_sample():
decimal=1)

means_s = np.array([np.mean(X_s[y_s == k], 0)
for k in range(n_features)])
for k in range(n_components)])
assert_array_almost_equal(gmm.means_, means_s, decimal=1)

# Check shapes of sampled data, see
# https://github.com/scikit-learn/scikit-learn/issues/7701
assert_equal(X_s.shape, (n_samples, n_features))

for sample_size in range(1, 100):
X_s, _ = gmm.sample(sample_size)
assert_equal(X_s.shape, (sample_size, n_features))


@ignore_warnings(category=ConvergenceWarning)
def test_init():
Expand Down

0 comments on commit ad6f094

Please sign in to comment.