Skip to content

Commit

Permalink
Plot sorting and lcc bug fix (graspologic-org#141)
Browse files Browse the repository at this point in the history
* trying heatmap color norm, intersect bug

* line tweak

* working but ugly plot code

* fixing hier plot

* tweaks

* stash

* test for sorting inds

* test for recursive case

* trying heatmap color norm, intersect bug

* line tweak

* working but ugly plot code

* fixing hier plot

* tweaks

* stash

* test for sorting inds

* test for recursive case

* comments

* bug fix and berlin figure

* gridplot bug

* convert to numpy array always

* convert to numpy for heatmap

* heatmap fontsize scaling

* gridplot size fix

* update fontsize name and description
  • Loading branch information
bdpedigo authored Apr 23, 2019
1 parent 80b81f3 commit df833a0
Show file tree
Hide file tree
Showing 8 changed files with 549 additions and 57 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,6 @@ notebooks/bpedigo/emmons/*.csv
notebooks/bpedigo/diag_procrust.py
/graspy-env
notebooks/bpedigo/new_pairs.py
*.code-workspace
*.code-workspace
*.png

153 changes: 117 additions & 36 deletions graspy/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@
# Email: [email protected]
# Copyright (c) 2018. All rights reserved.

from operator import itemgetter

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from ..utils import import_graph, pass_to_ranks
from ..embed import selectSVD
from sklearn.utils import check_array, check_consistent_length
from matplotlib import colors
from mpl_toolkits.axes_grid1 import make_axes_locatable
from sklearn.utils import check_array, check_consistent_length

from ..embed import selectSVD
from ..utils import import_graph, pass_to_ranks


def _check_common_inputs(
Expand Down Expand Up @@ -99,6 +103,7 @@ def heatmap(
cbar=True,
inner_hier_labels=None,
outer_hier_labels=None,
hier_label_fontsize=30,
):
r"""
Plots a graph as a heatmap.
Expand Down Expand Up @@ -146,6 +151,9 @@ def heatmap(
Categorical labeling of the nodes, ignored without `inner_hier_labels`
If not None, will plot these labels as the second level of a hierarchy on the
marginals
hier_label_fontsize : int
size (in points) of the text labels for the `inner_hier_labels` and
`outer_hier_labels`.
"""
_check_common_inputs(
figsize=figsize, title=title, context=context, font_scale=font_scale
Expand Down Expand Up @@ -189,13 +197,15 @@ def heatmap(
arr = import_graph(X)
arr = _transform(arr, transform)
if inner_hier_labels is not None:
inner_hier_labels = np.array(inner_hier_labels)
if outer_hier_labels is None:
arr = _sort_graph(arr, inner_hier_labels, np.ones_like(inner_hier_labels))
else:
outer_hier_labels = np.array(outer_hier_labels)
arr = _sort_graph(arr, inner_hier_labels, outer_hier_labels)

# Global plotting settings
CBAR_KWS = dict(shrink=0.7)
CBAR_KWS = dict(shrink=0.7, norm=colors.NoNorm)

with sns.plotting_context(context, font_scale=font_scale):
fig, ax = plt.subplots(figsize=figsize)
Expand All @@ -217,10 +227,14 @@ def heatmap(
plot.set_yticklabels([])
plot.set_xticklabels([])
_plot_groups(
plot, arr[0].shape[0], inner_hier_labels, outer_hier_labels
plot,
arr,
inner_hier_labels,
outer_hier_labels,
fontsize=hier_label_fontsize,
)
else:
_plot_groups(plot, arr[0].shape[0], inner_hier_labels)
_plot_groups(plot, arr, inner_hier_labels, fontsize=hier_label_fontsize)
return plot


Expand All @@ -238,6 +252,7 @@ def gridplot(
legend_name="Type",
inner_hier_labels=None,
outer_hier_labels=None,
hier_label_fontsize=30,
):
r"""
Plots multiple graphs as a grid, with intensity denoted by the size
Expand Down Expand Up @@ -290,6 +305,9 @@ def gridplot(
Categorical labeling of the nodes, ignored without `inner_hier_labels`
If not None, will plot these labels as the second level of a hierarchy on the
marginals
hier_label_fontsize : int
size (in points) of the text labels for the `inner_hier_labels` and
`outer_hier_labels`.
"""
_check_common_inputs(
height=height, title=title, context=context, font_scale=font_scale
Expand All @@ -301,17 +319,24 @@ def gridplot(
msg = "X must be a list, not {}.".format(type(X))
raise TypeError(msg)

check_consistent_length(X, labels, inner_hier_labels, outer_hier_labels)
if labels is None:
labels = np.arange(len(X))

check_consistent_length(X, labels)
for g in X:
check_consistent_length(g, inner_hier_labels, outer_hier_labels)

graphs = [_transform(arr, transform) for arr in graphs]

if inner_hier_labels is not None:
inner_hier_labels = np.array(inner_hier_labels)
if outer_hier_labels is None:
graphs = [
_sort_graph(arr, inner_hier_labels, np.ones_like(inner_hier_labels))
for arr in graphs
]
else:
outer_hier_labels = np.array(outer_hier_labels)
graphs = [
_sort_graph(arr, inner_hier_labels, outer_hier_labels) for arr in graphs
]
Expand Down Expand Up @@ -358,14 +383,19 @@ def gridplot(
if inner_hier_labels is not None:
if outer_hier_labels is not None:
_plot_groups(
plot.ax, graphs[0].shape[0], inner_hier_labels, outer_hier_labels
plot.ax,
graphs[0],
inner_hier_labels,
outer_hier_labels,
fontsize=hier_label_fontsize,
)
else:
_plot_groups(plot.ax, graphs[0].shape[0], inner_hier_labels)
_plot_groups(
plot.ax, graphs[0], inner_hier_labels, fontsize=hier_label_fontsize
)
return plot


# TODO would it be cool if pairplot reduced to single plot
def pairplot(
X,
labels=None,
Expand All @@ -380,6 +410,7 @@ def pairplot(
alpha=0.7,
size=50,
marker=".",
diag_kind="auto",
):
r"""
Plot pairwise relationships in a dataset.
Expand Down Expand Up @@ -467,7 +498,6 @@ def pairplot(
else:
variables = col_names

diag_kind = "auto"
df = pd.DataFrame(X, columns=col_names)
if labels is not None:
if legend_name is None:
Expand Down Expand Up @@ -748,67 +778,107 @@ def screeplot(
return ax


def _sort_inds(inner_labels, outer_labels):
def _sort_inds(graph, inner_labels, outer_labels):
sort_df = pd.DataFrame(columns=("inner_labels", "outer_labels"))
sort_df["inner_labels"] = inner_labels
if outer_labels is not None:
sort_df["outer_labels"] = outer_labels
sort_df.sort_values(
by=["outer_labels", "inner_labels"], kind="mergesort", inplace=True
)
outer_labels = sort_df["outer_labels"]
inner_labels = sort_df["inner_labels"]
sort_df["outer_labels"] = outer_labels

# get frequencies of the different labels so we can sort by them
inner_label_counts = _get_freq_vec(inner_labels)
outer_label_counts = _get_freq_vec(outer_labels)

# inverse counts so we can sort largest to smallest
# would rather do it this way so can still sort alphabetical for ties
sort_df["inner_counts"] = len(inner_labels) - inner_label_counts
sort_df["outer_counts"] = len(outer_labels) - outer_label_counts

# get node edge sums (not exactly degrees if weighted)
node_edgesums = graph.sum(axis=1) + graph.sum(axis=0)
sort_df["node_edgesums"] = node_edgesums.max() - node_edgesums

sort_df.sort_values(
by=[
"outer_counts",
"outer_labels",
"inner_counts",
"inner_labels",
"node_edgesums",
],
kind="mergesort",
inplace=True,
)

sorted_inds = sort_df.index.values
return sorted_inds


def _sort_graph(graph, inner_labels, outer_labels):
inds = _sort_inds(inner_labels, outer_labels)
inds = _sort_inds(graph, inner_labels, outer_labels)
graph = graph[inds, :][:, inds]
return graph


def _get_freqs(inner_labels, outer_labels=None):
_, outer_freq = np.unique(outer_labels, return_counts=True)
# use this because unique would give alphabetical
_, outer_freq = _unique_like(outer_labels)
outer_freq_cumsum = np.hstack((0, outer_freq.cumsum()))

# for each group of outer labels, calculate the boundaries of the inner labels
inner_freq = np.array([])
for i in range(outer_freq.size):
start_ind = outer_freq_cumsum[i]
stop_ind = outer_freq_cumsum[i + 1]
_, temp_freq = np.unique(inner_labels[start_ind:stop_ind], return_counts=True)
_, temp_freq = _unique_like(inner_labels[start_ind:stop_ind])
inner_freq = np.hstack([inner_freq, temp_freq])
inner_freq_cumsum = np.hstack((0, inner_freq.cumsum()))

return inner_freq, inner_freq_cumsum, outer_freq, outer_freq_cumsum


def _get_freq_vec(vals):
# give each set of labels a vector corresponding to its frequency
_, inv, counts = np.unique(vals, return_counts=True, return_inverse=True)
count_vec = counts[inv]
return count_vec


def _unique_like(vals):
# gives output like
uniques, inds, counts = np.unique(vals, return_index=True, return_counts=True)
inds_sort = np.argsort(inds)
uniques = uniques[inds_sort]
counts = counts[inds_sort]
return uniques, counts


# assume that the graph has already been plotted in sorted form
def _plot_groups(ax, n_verts, inner_labels, outer_labels=None):
def _plot_groups(ax, graph, inner_labels, outer_labels=None, fontsize=30):
plot_outer = True
if outer_labels is None:
outer_labels = np.ones_like(inner_labels)
plot_outer = False
sorted_inds = _sort_inds(inner_labels, outer_labels)

sorted_inds = _sort_inds(graph, inner_labels, outer_labels)

inner_labels = inner_labels[sorted_inds]
outer_labels = outer_labels[sorted_inds]

inner_freq, inner_freq_cumsum, outer_freq, outer_freq_cumsum = _get_freqs(
inner_labels, outer_labels
)
inner_unique, _ = _unique_like(inner_labels)
outer_unique, _ = _unique_like(outer_labels)

inner_unique = np.unique(inner_labels)
outer_unique = np.unique(outer_labels)

n_verts = graph.shape[0]
# draw lines
for x in inner_freq_cumsum:
if x != inner_freq_cumsum[0]:
x -= 0.2
ax.vlines(x, 0, n_verts, linestyle="dashed", lw=0.9, alpha=0.25, zorder=3)
if x == inner_freq_cumsum[-1]:
x -= 1
ax.hlines(x, 0, n_verts, linestyle="dashed", lw=0.9, alpha=0.25, zorder=3)

# generic curve that we will use for everything
lx = np.linspace(-np.pi / 2.0 + 0.05, np.pi / 2.0 - 0.05, 50)
lx = np.linspace(-np.pi / 2.0 + 0.05, np.pi / 2.0 - 0.05, 500)
tan = np.tan(lx)
curve = np.hstack((tan[::-1], tan))

Expand All @@ -835,6 +905,7 @@ def _plot_groups(ax, n_verts, inner_labels, outer_labels=None):
"inner",
"x",
n_verts,
fontsize,
)
# side inner curves
# ax_y = divider.new_horizontal(
Expand All @@ -850,11 +921,13 @@ def _plot_groups(ax, n_verts, inner_labels, outer_labels=None):
"inner",
"y",
n_verts,
fontsize,
)

if plot_outer:
# top outer curves
ax_x2 = divider.new_vertical(size="5%", pad=0.25, pack_start=False)
pad_scalar = 0.35 / 30 * fontsize
ax_x2 = divider.new_vertical(size="5%", pad=pad_scalar, pack_start=False)
ax.figure.add_axes(ax_x2)
_plot_brackets(
ax_x2,
Expand All @@ -865,9 +938,10 @@ def _plot_groups(ax, n_verts, inner_labels, outer_labels=None):
"outer",
"x",
n_verts,
fontsize,
)
# side outer curves
ax_y2 = divider.new_horizontal(size="5%", pad=0.25, pack_start=True)
ax_y2 = divider.new_horizontal(size="5%", pad=pad_scalar, pack_start=True)
ax.figure.add_axes(ax_y2)
_plot_brackets(
ax_y2,
Expand All @@ -878,31 +952,38 @@ def _plot_groups(ax, n_verts, inner_labels, outer_labels=None):
"outer",
"y",
n_verts,
fontsize,
)
return ax


def _plot_brackets(ax, group_names, tick_loc, tick_width, curve, level, axis, max_size):
def _plot_brackets(
ax, group_names, tick_loc, tick_width, curve, level, axis, max_size, fontsize
):
for x0, width in zip(tick_loc, tick_width):
x = np.linspace(x0 - width, x0 + width, 100)
x = np.linspace(x0 - width, x0 + width, 1000)
if axis == "x":
ax.plot(x, -curve, c="k")
ax.patch.set_alpha(0)
elif axis == "y":
ax.plot(curve, x, c="k")
ax.patch.set_alpha(0)
ax.set_yticks([])
ax.set_xticks([])
ax.tick_params(axis=axis, which=u"both", length=0, pad=7)
for direction in ["left", "right", "bottom", "top"]:
ax.spines[direction].set_visible(False)
if axis == "x":
ax.set_xticks(tick_loc)
ax.set_xticklabels(group_names, fontsize=15, verticalalignment="center")
ax.set_xticklabels(group_names, fontsize=fontsize, verticalalignment="center")
ax.xaxis.set_label_position("top")
ax.xaxis.tick_top()
ax.xaxis.labelpad = 30
ax.set_xlim(0, max_size)
ax.tick_params(axis="x", which="major", pad=5 + fontsize / 4)
elif axis == "y":
ax.set_yticks(tick_loc)
ax.set_yticklabels(group_names, fontsize=15, verticalalignment="center")
ax.set_yticklabels(group_names, fontsize=fontsize, verticalalignment="center")
# ax.yaxis.set_label_position('top')
# ax.yaxis.tick_top()
ax.set_ylim(0, max_size)
Expand Down
8 changes: 7 additions & 1 deletion graspy/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,9 +488,15 @@ def get_multigraph_intersect_lcc(graphs, return_inds=False):
recurse = True
break
if recurse:
new_graphs, inds_intersection = get_multigraph_intersect_lcc(
new_graphs, new_inds_intersection = get_multigraph_intersect_lcc(
new_graphs, return_inds=True
)
# new inds intersection are the indices of new_graph that were kept on recurse
# need to do this because indices could have shifted during recursion
if type(graphs[0]) is np.ndarray:
inds_intersection = inds_intersection[new_inds_intersection]
else:
inds_intersection = new_inds_intersection
if type(graphs) != list:
new_graphs = np.stack(new_graphs)
if return_inds:
Expand Down
Loading

0 comments on commit df833a0

Please sign in to comment.