Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a regression in show_percentages / total calculation #248

Merged
merged 2 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fix a regression in show_percentages / total calculation
Fix #226

Fix #223
  • Loading branch information
jnothman committed Dec 28, 2023
commit 235e4d631283a31182806d8f29ebf43cc7670952
11 changes: 10 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
What's new in version 0.9
-------------------------

- Fixes a bug where ``show_percentages`` used the incorrect denominator if
filtering (e.g. ``min_subset_size``) was applied. This bug was a regression
introduced in version 0.7. (:issue:`248`)
- Ability to disable the totals plot with ``totals_plot_elements=0``. (:issue:`246`)
- Ability to set the totals y-axis label. (:issue:`243`)

What's new in version 0.8
-------------------------

Expand All @@ -10,7 +19,7 @@ What's new in version 0.8
- Added `subsets` attribute to QueryResult. (:issue:`198`)
- Fixed a bug where more than 64 categories could result in an error. (:issue:`193`)

Patch release 0.8.1 handles deprecations in dependencies.
Patch release 0.8.2 handles deprecations in dependencies.

What's new in version 0.7
-------------------------
Expand Down
6 changes: 1 addition & 5 deletions upsetplot/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ def _process_data(

df = results.data
agg = results.subset_sizes
totals = results.category_totals
total = agg.sum()

# add '_bin' to df indicating index in agg
# XXX: ugly!
Expand All @@ -80,7 +78,7 @@ def _pack_binary(X):
if reverse:
agg = agg[::-1]

return total, df, agg, totals
return results.total, df, agg, results.category_totals


def _multiply_alpha(c, mult):
Expand Down Expand Up @@ -678,8 +676,6 @@ def make_grid(self, fig=None):
fig.set_figheight((colw * (n_cats + sizes.sum())) / render_ratio)

text_nelems = int(np.ceil(figw / colw - non_text_nelems))
# print('textw', textw, 'figw', figw, 'colw', colw,
# 'ncols', figw/colw, 'text_nelems', text_nelems)

GS = self._reorient(matplotlib.gridspec.GridSpec)
gridspec = GS(
Expand Down
30 changes: 18 additions & 12 deletions upsetplot/reformat.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,17 +165,20 @@ class QueryResult:
for `data`.
category_totals : Series
Total size of each category, regardless of selection.
total : number
Total number of samples / sum of value
"""

def __init__(self, data, subset_sizes, category_totals):
def __init__(self, data, subset_sizes, category_totals, total):
self.data = data
self.subset_sizes = subset_sizes
self.category_totals = category_totals
self.total = total

def __repr__(self):
return (
"QueryResult(data={data}, subset_sizes={subset_sizes}, "
"category_totals={category_totals}".format(**vars(self))
"category_totals={category_totals}, total={total}".format(**vars(self))
)

@property
Expand Down Expand Up @@ -270,7 +273,7 @@ def query(
-------
QueryResult
Including filtered ``data``, filtered and sorted ``subset_sizes`` and
overall ``category_totals``.
overall ``category_totals`` and ``total``.

Examples
--------
Expand Down Expand Up @@ -325,11 +328,12 @@ def query(

data, agg = _aggregate_data(data, subset_size, sum_over)
data = _check_index(data)
totals = [
grand_total = agg.sum()
category_totals = [
agg[agg.index.get_level_values(name).values.astype(bool)].sum()
for name in agg.index.names
]
totals = pd.Series(totals, index=agg.index.names)
category_totals = pd.Series(category_totals, index=agg.index.names)

if include_empty_subsets:
nlevels = len(agg.index.levels)
Expand Down Expand Up @@ -361,15 +365,17 @@ def query(

# sort:
if sort_categories_by in ("cardinality", "-cardinality"):
totals.sort_values(ascending=sort_categories_by[:1] == "-", inplace=True)
category_totals.sort_values(
ascending=sort_categories_by[:1] == "-", inplace=True
)
elif sort_categories_by == "-input":
totals = totals[::-1]
category_totals = category_totals[::-1]
elif sort_categories_by in (None, "input"):
pass
else:
raise ValueError("Unknown sort_categories_by: %r" % sort_categories_by)
data = data.reorder_levels(totals.index.values)
agg = agg.reorder_levels(totals.index.values)
data = data.reorder_levels(category_totals.index.values)
agg = agg.reorder_levels(category_totals.index.values)

if sort_by in ("cardinality", "-cardinality"):
agg = agg.sort_values(ascending=sort_by[:1] == "-")
Expand All @@ -383,12 +389,12 @@ def query(
pd.MultiIndex.from_tuples(index_tuples, names=agg.index.names)
)
elif sort_by == "-input":
print("<", agg)
agg = agg[::-1]
print(">", agg)
elif sort_by in (None, "input"):
pass
else:
raise ValueError("Unknown sort_by: %r" % sort_by)

return QueryResult(data=data, subset_sizes=agg, category_totals=totals)
return QueryResult(
data=data, subset_sizes=agg, category_totals=category_totals, total=grand_total
)
1 change: 1 addition & 0 deletions upsetplot/tests/test_upsetplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,7 @@ def test_filter_subsets(filter_params, expected, sort_by):
)
# category totals should not be affected
assert_series_equal(upset_full.totals, upset_filtered.totals)
assert upset_full.total == pytest.approx(upset_filtered.total)


@pytest.mark.parametrize(
Expand Down