forked from graspologic-org/graspologic
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Plot sorting and lcc bug fix (graspologic-org#141)
* trying heatmap color norm, intersect bug * line tweak * working but ugly plot code * fixing hier plot * tweaks * stash * test for sorting inds * test for recursive case * trying heatmap color norm, intersect bug * line tweak * working but ugly plot code * fixing hier plot * tweaks * stash * test for sorting inds * test for recursive case * comments * bug fix and berlin figure * gridplot bug * convert to numpy array always * convert to numpy for heatmap * heatmap fontsize scaling * gridplot size fix * update fontsize name and description
- Loading branch information
Showing
8 changed files
with
549 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,14 +3,18 @@ | |
# Email: [email protected] | ||
# Copyright (c) 2018. All rights reserved. | ||
|
||
from operator import itemgetter | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import pandas as pd | ||
import seaborn as sns | ||
from ..utils import import_graph, pass_to_ranks | ||
from ..embed import selectSVD | ||
from sklearn.utils import check_array, check_consistent_length | ||
from matplotlib import colors | ||
from mpl_toolkits.axes_grid1 import make_axes_locatable | ||
from sklearn.utils import check_array, check_consistent_length | ||
|
||
from ..embed import selectSVD | ||
from ..utils import import_graph, pass_to_ranks | ||
|
||
|
||
def _check_common_inputs( | ||
|
@@ -99,6 +103,7 @@ def heatmap( | |
cbar=True, | ||
inner_hier_labels=None, | ||
outer_hier_labels=None, | ||
hier_label_fontsize=30, | ||
): | ||
r""" | ||
Plots a graph as a heatmap. | ||
|
@@ -146,6 +151,9 @@ def heatmap( | |
Categorical labeling of the nodes, ignored without `inner_hier_labels` | ||
If not None, will plot these labels as the second level of a hierarchy on the | ||
marginals | ||
hier_label_fontsize : int | ||
size (in points) of the text labels for the `inner_hier_labels` and | ||
`outer_hier_labels`. | ||
""" | ||
_check_common_inputs( | ||
figsize=figsize, title=title, context=context, font_scale=font_scale | ||
|
@@ -189,13 +197,15 @@ def heatmap( | |
arr = import_graph(X) | ||
arr = _transform(arr, transform) | ||
if inner_hier_labels is not None: | ||
inner_hier_labels = np.array(inner_hier_labels) | ||
if outer_hier_labels is None: | ||
arr = _sort_graph(arr, inner_hier_labels, np.ones_like(inner_hier_labels)) | ||
else: | ||
outer_hier_labels = np.array(outer_hier_labels) | ||
arr = _sort_graph(arr, inner_hier_labels, outer_hier_labels) | ||
|
||
# Global plotting settings | ||
CBAR_KWS = dict(shrink=0.7) | ||
CBAR_KWS = dict(shrink=0.7, norm=colors.NoNorm) | ||
|
||
with sns.plotting_context(context, font_scale=font_scale): | ||
fig, ax = plt.subplots(figsize=figsize) | ||
|
@@ -217,10 +227,14 @@ def heatmap( | |
plot.set_yticklabels([]) | ||
plot.set_xticklabels([]) | ||
_plot_groups( | ||
plot, arr[0].shape[0], inner_hier_labels, outer_hier_labels | ||
plot, | ||
arr, | ||
inner_hier_labels, | ||
outer_hier_labels, | ||
fontsize=hier_label_fontsize, | ||
) | ||
else: | ||
_plot_groups(plot, arr[0].shape[0], inner_hier_labels) | ||
_plot_groups(plot, arr, inner_hier_labels, fontsize=hier_label_fontsize) | ||
return plot | ||
|
||
|
||
|
@@ -238,6 +252,7 @@ def gridplot( | |
legend_name="Type", | ||
inner_hier_labels=None, | ||
outer_hier_labels=None, | ||
hier_label_fontsize=30, | ||
): | ||
r""" | ||
Plots multiple graphs as a grid, with intensity denoted by the size | ||
|
@@ -290,6 +305,9 @@ def gridplot( | |
Categorical labeling of the nodes, ignored without `inner_hier_labels` | ||
If not None, will plot these labels as the second level of a hierarchy on the | ||
marginals | ||
hier_label_fontsize : int | ||
size (in points) of the text labels for the `inner_hier_labels` and | ||
`outer_hier_labels`. | ||
""" | ||
_check_common_inputs( | ||
height=height, title=title, context=context, font_scale=font_scale | ||
|
@@ -301,17 +319,24 @@ def gridplot( | |
msg = "X must be a list, not {}.".format(type(X)) | ||
raise TypeError(msg) | ||
|
||
check_consistent_length(X, labels, inner_hier_labels, outer_hier_labels) | ||
if labels is None: | ||
labels = np.arange(len(X)) | ||
|
||
check_consistent_length(X, labels) | ||
for g in X: | ||
check_consistent_length(g, inner_hier_labels, outer_hier_labels) | ||
|
||
graphs = [_transform(arr, transform) for arr in graphs] | ||
|
||
if inner_hier_labels is not None: | ||
inner_hier_labels = np.array(inner_hier_labels) | ||
if outer_hier_labels is None: | ||
graphs = [ | ||
_sort_graph(arr, inner_hier_labels, np.ones_like(inner_hier_labels)) | ||
for arr in graphs | ||
] | ||
else: | ||
outer_hier_labels = np.array(outer_hier_labels) | ||
graphs = [ | ||
_sort_graph(arr, inner_hier_labels, outer_hier_labels) for arr in graphs | ||
] | ||
|
@@ -358,14 +383,19 @@ def gridplot( | |
if inner_hier_labels is not None: | ||
if outer_hier_labels is not None: | ||
_plot_groups( | ||
plot.ax, graphs[0].shape[0], inner_hier_labels, outer_hier_labels | ||
plot.ax, | ||
graphs[0], | ||
inner_hier_labels, | ||
outer_hier_labels, | ||
fontsize=hier_label_fontsize, | ||
) | ||
else: | ||
_plot_groups(plot.ax, graphs[0].shape[0], inner_hier_labels) | ||
_plot_groups( | ||
plot.ax, graphs[0], inner_hier_labels, fontsize=hier_label_fontsize | ||
) | ||
return plot | ||
|
||
|
||
# TODO would it be cool if pairplot reduced to single plot | ||
def pairplot( | ||
X, | ||
labels=None, | ||
|
@@ -380,6 +410,7 @@ def pairplot( | |
alpha=0.7, | ||
size=50, | ||
marker=".", | ||
diag_kind="auto", | ||
): | ||
r""" | ||
Plot pairwise relationships in a dataset. | ||
|
@@ -467,7 +498,6 @@ def pairplot( | |
else: | ||
variables = col_names | ||
|
||
diag_kind = "auto" | ||
df = pd.DataFrame(X, columns=col_names) | ||
if labels is not None: | ||
if legend_name is None: | ||
|
@@ -748,67 +778,107 @@ def screeplot( | |
return ax | ||
|
||
|
||
def _sort_inds(inner_labels, outer_labels): | ||
def _sort_inds(graph, inner_labels, outer_labels): | ||
sort_df = pd.DataFrame(columns=("inner_labels", "outer_labels")) | ||
sort_df["inner_labels"] = inner_labels | ||
if outer_labels is not None: | ||
sort_df["outer_labels"] = outer_labels | ||
sort_df.sort_values( | ||
by=["outer_labels", "inner_labels"], kind="mergesort", inplace=True | ||
) | ||
outer_labels = sort_df["outer_labels"] | ||
inner_labels = sort_df["inner_labels"] | ||
sort_df["outer_labels"] = outer_labels | ||
|
||
# get frequencies of the different labels so we can sort by them | ||
inner_label_counts = _get_freq_vec(inner_labels) | ||
outer_label_counts = _get_freq_vec(outer_labels) | ||
|
||
# inverse counts so we can sort largest to smallest | ||
# would rather do it this way so can still sort alphabetical for ties | ||
sort_df["inner_counts"] = len(inner_labels) - inner_label_counts | ||
sort_df["outer_counts"] = len(outer_labels) - outer_label_counts | ||
|
||
# get node edge sums (not exactly degrees if weighted) | ||
node_edgesums = graph.sum(axis=1) + graph.sum(axis=0) | ||
sort_df["node_edgesums"] = node_edgesums.max() - node_edgesums | ||
|
||
sort_df.sort_values( | ||
by=[ | ||
"outer_counts", | ||
"outer_labels", | ||
"inner_counts", | ||
"inner_labels", | ||
"node_edgesums", | ||
], | ||
kind="mergesort", | ||
inplace=True, | ||
) | ||
|
||
sorted_inds = sort_df.index.values | ||
return sorted_inds | ||
|
||
|
||
def _sort_graph(graph, inner_labels, outer_labels): | ||
inds = _sort_inds(inner_labels, outer_labels) | ||
inds = _sort_inds(graph, inner_labels, outer_labels) | ||
graph = graph[inds, :][:, inds] | ||
return graph | ||
|
||
|
||
def _get_freqs(inner_labels, outer_labels=None): | ||
_, outer_freq = np.unique(outer_labels, return_counts=True) | ||
# use this because unique would give alphabetical | ||
_, outer_freq = _unique_like(outer_labels) | ||
outer_freq_cumsum = np.hstack((0, outer_freq.cumsum())) | ||
|
||
# for each group of outer labels, calculate the boundaries of the inner labels | ||
inner_freq = np.array([]) | ||
for i in range(outer_freq.size): | ||
start_ind = outer_freq_cumsum[i] | ||
stop_ind = outer_freq_cumsum[i + 1] | ||
_, temp_freq = np.unique(inner_labels[start_ind:stop_ind], return_counts=True) | ||
_, temp_freq = _unique_like(inner_labels[start_ind:stop_ind]) | ||
inner_freq = np.hstack([inner_freq, temp_freq]) | ||
inner_freq_cumsum = np.hstack((0, inner_freq.cumsum())) | ||
|
||
return inner_freq, inner_freq_cumsum, outer_freq, outer_freq_cumsum | ||
|
||
|
||
def _get_freq_vec(vals): | ||
# give each set of labels a vector corresponding to its frequency | ||
_, inv, counts = np.unique(vals, return_counts=True, return_inverse=True) | ||
count_vec = counts[inv] | ||
return count_vec | ||
|
||
|
||
def _unique_like(vals): | ||
# gives output like | ||
uniques, inds, counts = np.unique(vals, return_index=True, return_counts=True) | ||
inds_sort = np.argsort(inds) | ||
uniques = uniques[inds_sort] | ||
counts = counts[inds_sort] | ||
return uniques, counts | ||
|
||
|
||
# assume that the graph has already been plotted in sorted form | ||
def _plot_groups(ax, n_verts, inner_labels, outer_labels=None): | ||
def _plot_groups(ax, graph, inner_labels, outer_labels=None, fontsize=30): | ||
plot_outer = True | ||
if outer_labels is None: | ||
outer_labels = np.ones_like(inner_labels) | ||
plot_outer = False | ||
sorted_inds = _sort_inds(inner_labels, outer_labels) | ||
|
||
sorted_inds = _sort_inds(graph, inner_labels, outer_labels) | ||
|
||
inner_labels = inner_labels[sorted_inds] | ||
outer_labels = outer_labels[sorted_inds] | ||
|
||
inner_freq, inner_freq_cumsum, outer_freq, outer_freq_cumsum = _get_freqs( | ||
inner_labels, outer_labels | ||
) | ||
inner_unique, _ = _unique_like(inner_labels) | ||
outer_unique, _ = _unique_like(outer_labels) | ||
|
||
inner_unique = np.unique(inner_labels) | ||
outer_unique = np.unique(outer_labels) | ||
|
||
n_verts = graph.shape[0] | ||
# draw lines | ||
for x in inner_freq_cumsum: | ||
if x != inner_freq_cumsum[0]: | ||
x -= 0.2 | ||
ax.vlines(x, 0, n_verts, linestyle="dashed", lw=0.9, alpha=0.25, zorder=3) | ||
if x == inner_freq_cumsum[-1]: | ||
x -= 1 | ||
ax.hlines(x, 0, n_verts, linestyle="dashed", lw=0.9, alpha=0.25, zorder=3) | ||
|
||
# generic curve that we will use for everything | ||
lx = np.linspace(-np.pi / 2.0 + 0.05, np.pi / 2.0 - 0.05, 50) | ||
lx = np.linspace(-np.pi / 2.0 + 0.05, np.pi / 2.0 - 0.05, 500) | ||
tan = np.tan(lx) | ||
curve = np.hstack((tan[::-1], tan)) | ||
|
||
|
@@ -835,6 +905,7 @@ def _plot_groups(ax, n_verts, inner_labels, outer_labels=None): | |
"inner", | ||
"x", | ||
n_verts, | ||
fontsize, | ||
) | ||
# side inner curves | ||
# ax_y = divider.new_horizontal( | ||
|
@@ -850,11 +921,13 @@ def _plot_groups(ax, n_verts, inner_labels, outer_labels=None): | |
"inner", | ||
"y", | ||
n_verts, | ||
fontsize, | ||
) | ||
|
||
if plot_outer: | ||
# top outer curves | ||
ax_x2 = divider.new_vertical(size="5%", pad=0.25, pack_start=False) | ||
pad_scalar = 0.35 / 30 * fontsize | ||
ax_x2 = divider.new_vertical(size="5%", pad=pad_scalar, pack_start=False) | ||
ax.figure.add_axes(ax_x2) | ||
_plot_brackets( | ||
ax_x2, | ||
|
@@ -865,9 +938,10 @@ def _plot_groups(ax, n_verts, inner_labels, outer_labels=None): | |
"outer", | ||
"x", | ||
n_verts, | ||
fontsize, | ||
) | ||
# side outer curves | ||
ax_y2 = divider.new_horizontal(size="5%", pad=0.25, pack_start=True) | ||
ax_y2 = divider.new_horizontal(size="5%", pad=pad_scalar, pack_start=True) | ||
ax.figure.add_axes(ax_y2) | ||
_plot_brackets( | ||
ax_y2, | ||
|
@@ -878,31 +952,38 @@ def _plot_groups(ax, n_verts, inner_labels, outer_labels=None): | |
"outer", | ||
"y", | ||
n_verts, | ||
fontsize, | ||
) | ||
return ax | ||
|
||
|
||
def _plot_brackets(ax, group_names, tick_loc, tick_width, curve, level, axis, max_size): | ||
def _plot_brackets( | ||
ax, group_names, tick_loc, tick_width, curve, level, axis, max_size, fontsize | ||
): | ||
for x0, width in zip(tick_loc, tick_width): | ||
x = np.linspace(x0 - width, x0 + width, 100) | ||
x = np.linspace(x0 - width, x0 + width, 1000) | ||
if axis == "x": | ||
ax.plot(x, -curve, c="k") | ||
ax.patch.set_alpha(0) | ||
elif axis == "y": | ||
ax.plot(curve, x, c="k") | ||
ax.patch.set_alpha(0) | ||
ax.set_yticks([]) | ||
ax.set_xticks([]) | ||
ax.tick_params(axis=axis, which=u"both", length=0, pad=7) | ||
for direction in ["left", "right", "bottom", "top"]: | ||
ax.spines[direction].set_visible(False) | ||
if axis == "x": | ||
ax.set_xticks(tick_loc) | ||
ax.set_xticklabels(group_names, fontsize=15, verticalalignment="center") | ||
ax.set_xticklabels(group_names, fontsize=fontsize, verticalalignment="center") | ||
ax.xaxis.set_label_position("top") | ||
ax.xaxis.tick_top() | ||
ax.xaxis.labelpad = 30 | ||
ax.set_xlim(0, max_size) | ||
ax.tick_params(axis="x", which="major", pad=5 + fontsize / 4) | ||
elif axis == "y": | ||
ax.set_yticks(tick_loc) | ||
ax.set_yticklabels(group_names, fontsize=15, verticalalignment="center") | ||
ax.set_yticklabels(group_names, fontsize=fontsize, verticalalignment="center") | ||
# ax.yaxis.set_label_position('top') | ||
# ax.yaxis.tick_top() | ||
ax.set_ylim(0, max_size) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.