-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* sample size calculation * refactor * multiple test correction
- Loading branch information
1 parent
dd45bc7
commit 5cf6139
Showing
10 changed files
with
435 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import logging | ||
from fastapi import APIRouter, Depends, HTTPException | ||
from statsd import StatsClient | ||
import asyncio | ||
from concurrent.futures import ThreadPoolExecutor | ||
|
||
from ..toolkit.statistics import Statistics | ||
|
||
from .req import SampleSizeCalculationData | ||
from .res import SampleSizeCalculationResult | ||
|
||
|
||
_logger = logging.getLogger("epstats") | ||
|
||
|
||
def get_sample_size_calculation_router(get_executor_pool, get_statsd) -> APIRouter: | ||
def _sample_size_calculation(data: SampleSizeCalculationData, statsd: StatsClient): | ||
try: | ||
|
||
if data.std is None: | ||
f = Statistics.required_sample_size_per_variant_bernoulli | ||
else: | ||
f = Statistics.required_sample_size_per_variant | ||
|
||
sample_size_per_variant = f(**data.dict()) | ||
|
||
_logger.info((f"Calculation finished, sample_size_per_variant = {sample_size_per_variant}.")) | ||
return SampleSizeCalculationResult(sample_size_per_variant=sample_size_per_variant) | ||
except Exception as e: | ||
_logger.error(f"Cannot calculate the sample size because of: '{e}'") | ||
_logger.exception(e) | ||
statsd.incr("errors.sample_size_calculation") | ||
raise HTTPException( | ||
status_code=500, | ||
detail=f"Cannot calculate the sample size because of: '{e}'", | ||
) | ||
|
||
router = APIRouter() | ||
|
||
@router.post("/sample-size-calculation", response_model=SampleSizeCalculationResult) | ||
async def sample_size_calculation( | ||
data: SampleSizeCalculationData, | ||
evaluation_pool: ThreadPoolExecutor = Depends(get_executor_pool), | ||
statsd: StatsClient = Depends(get_statsd), | ||
): | ||
""" | ||
Calculates sample size based on `data`. | ||
""" | ||
_logger.info(f"Calling the sample size calculation with {data.json()}") | ||
statsd.incr("requests.sample_size_calculation") | ||
loop = asyncio.get_event_loop() | ||
return await loop.run_in_executor(evaluation_pool, _sample_size_calculation, data, statsd) | ||
|
||
return router |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import pytest | ||
from fastapi.testclient import TestClient | ||
|
||
from src.epstats.main import api | ||
from src.epstats.main import get_statsd, get_executor_pool | ||
|
||
from .depend import get_test_executor_pool, get_test_statsd | ||
|
||
|
||
client = TestClient(api) | ||
api.dependency_overrides[get_statsd] = get_test_statsd | ||
api.dependency_overrides[get_executor_pool] = get_test_executor_pool | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"n_variants, minimum_effect, mean, std, expected", | ||
[(2, 0.10, 0.2, 1.2, 56512), (2, 0.05, 0.4, None, 9489), (3, 0.05, 0.4, None, 11492)], | ||
) | ||
def test_sample_size_calculation(n_variants, minimum_effect, mean, std, expected): | ||
json_blob = { | ||
"minimum_effect": minimum_effect, | ||
"mean": mean, | ||
"std": std, | ||
"n_variants": n_variants, | ||
} | ||
|
||
resp = client.post("/sample-size-calculation", json=json_blob) | ||
assert resp.status_code == 200 | ||
assert resp.json()["sample_size_per_variant"] == expected | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"n_variants, minimum_effect, mean, expected_message", | ||
[ | ||
(2, -0.4, 0.2, "minimum_effect must be greater than zero"), | ||
(2, 0.05, 1.4, "mean must be between zero and one"), | ||
(1, 0.05, 0.2, "must be at least two variants"), | ||
], | ||
) | ||
def test_sample_size_calculation_error(n_variants, minimum_effect, mean, expected_message): | ||
json_blob = { | ||
"minimum_effect": minimum_effect, | ||
"mean": mean, | ||
"n_variants": n_variants, | ||
} | ||
|
||
resp = client.post("/sample-size-calculation", json=json_blob) | ||
assert resp.status_code == 500 | ||
assert expected_message in resp.content.decode() |
Oops, something went wrong.