Skip to content

Commit

Permalink
make sure to access dataframe differently
Browse files Browse the repository at this point in the history
  • Loading branch information
MaximilianFranz committed Feb 24, 2020
1 parent bfe2e6f commit 3aac88a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/justcause/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,10 @@ def generate_data(
), "Covariate function should return a dataframe with `n_samples` rows"
else:
indices = random_state.choice(covariates.shape[0], n_samples, replace=False)
covariates = covariates[indices, :]
if isinstance(covariates, pd.DataFrame):
covariates = covariates.iloc[indices, :] # ensure proper access
else:
covariates = covariates[indices, :]

if covariate_names is None:
if isinstance(covariates, pd.DataFrame):
Expand Down
8 changes: 8 additions & 0 deletions tests/test_data_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ def covariates(_, *, random_state, **kwargs):
assert len(df) == len(ihdp_cov)
assert len(set(df.columns).intersection(set(cov_names))) == 25

# Works with DataFrame and np.array when n_samples is given
gen = generate_data(
covariates=ihdp_cov, treatment=treatment, outcomes=outcomes, n_samples=500,
)
df = list(gen)[0]
assert len(df) == 500
assert len(set(df.columns).intersection({f"x_{i}" for i in range(25)})) == 25


def test_ihdp_generator():
gen = multi_expo_on_ihdp(setting="multi-modal", n_replications=10)
Expand Down

0 comments on commit 3aac88a

Please sign in to comment.