# Source code for hhpy.plotting

"""
hhpy.plotting.py
~~~~~~~~~~~~~~~~

Contains plotting functions

"""

# standard imports
from copy import deepcopy

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import logging
import warnings

# third party imports
from matplotlib import patches
from matplotlib.animation import FuncAnimation
from matplotlib.legend import Legend
from colour import Color
from scipy import stats
from typing import Union, Sequence, Mapping, Callable, List

try:
    from IPython.core.display import HTML
except ImportError:
    HTML = None

# local imports
from hhpy.main import export, concat_cols, is_list_like, floor_signif, ceil_signif, list_intersection, \
    force_list, progressbar, DocstringProcessor
from hhpy.ds import get_df_corr, lfit, kde, df_count, quantile_split, top_n_coding, df_rmsd, df_agg

# --- constants

# Package-wide defaults. Most function signatures below bind these values as default arguments at definition
# time (e.g. ``number_format=rcParams['float_format']``); only code that looks a key up at call time
# (e.g. distplot's ``palette``) sees later changes to this dict.
rcParams = {
    # default color cycle, each entry is a matplotlib 'xkcd:' color name
    'palette': [
        'xkcd:blue',
        'xkcd:red',
        'xkcd:green',
        'xkcd:cyan',
        'xkcd:magenta',
        'xkcd:golden yellow',
        'xkcd:dark cyan',
        'xkcd:red orange',
        'xkcd:dark yellow',
        'xkcd:easter green',
        'xkcd:baby blue',
        'xkcd:light brown',
        'xkcd:strong pink',
        'xkcd:light navy blue',
        'xkcd:deep blue',
        'xkcd:deep red',
        'xkcd:ultramarine blue',
        'xkcd:sea green',
        'xkcd:plum',
        'xkcd:old pink',
        'xkcd:lawn green',
        'xkcd:amber',
        'xkcd:green blue',
        'xkcd:yellow green',
        'xkcd:dark mustard',
        'xkcd:bright lime',
        'xkcd:aquamarine',
        'xkcd:very light blue',
        'xkcd:light grey blue',
        'xkcd:dark sage',
        'xkcd:dark peach',
        'xkcd:shocking pink',
    ],
    # single-character matplotlib hatch patterns: / \ | - + x o O . *
    'hatches': list('/\\|-+xoO.*'),
    # figure / subplot sizes in inches
    'figsize_square': (7, 7),
    'fig_width': 7,
    'fig_height': 7,
    # format specs passed to format()
    'float_format': '.2f',
    'int_format': ',.0f',
    'legend_outside.legend_space': .1,
    'distplot.label_style': 'mu_sigma',
    'distplot.legend_loc': None,
    # random down-sampling of large data sets before plotting
    'max_n': 10000,
    'max_n.random_state': None,
    'max_n.sample_warn': True,
    'return_fig_ax': True,
    'corr_cutoff': 0,
    # NOTE(review): not referenced in the visible part of this module, presumably the animation output mode
    'animplot.mode': 'jshtml',
}

# accepted values for arguments that are validated at call time
validations = dict(
    valid_distfits=['kde', 'gauss', 'False', 'None'],
)

# Shared parameter descriptions. This object is used as the ``@docstr`` decorator on the plotting functions below,
# whose docstrings reference these entries via ``%(name)s`` placeholders (presumably filled in by
# hhpy.main.DocstringProcessor). Implicitly concatenated string pieces must carry their own separating spaces.
docstr = DocstringProcessor(
    ax_in='The axes object to plot on, defaults to current axis [optional]',
    ax_out='The axes object with the plot on it',
    fig_ax_out='if return_fig_ax: figure and axis objects as tuple, else None',
    x='Name of the x variable in data or vector data',
    y='Name of the y variable in data or vector data',
    t='Name of the t variable in data or vector data',
    x_novec='Name of the x variable in data',
    y_novec='Name of the y variable in data',
    t_novec='Name of the t variable in data',
    data='Pandas DataFrame containing named data, optional if vector data is used',
    data_novec='Pandas DataFrame containing named data',
    hue='Further split the plot by the levels of this variable [optional]',
    order='Either a string describing how the (hue) levels are to be ordered or an explicit list of levels to be '
          'used for plotting. Accepted strings are:'
          '''

        * ``sorted``: following python standard sorting conventions (alphabetical for string, ascending for value)

        * ``inv``: following python standard sorting conventions but in inverse order

        * ``count``: sorted by value counts, ascending

        * ``mean``, ``mean_ascending``, ``mean_descending``: sorted by mean value, defaults to descending

        * ``median``, ``median_ascending``, ``median_descending``: sorted by median value, defaults to descending

    ''',
    color='color used for plotting, must be known to matplotlib.pyplot [optional]',
    palette='Collection of colors to be used for plotting. Can be a dictionary with a color for each level, '
            'a list of colors or an individual color name. Must be valid colors known to pyplot [optional]',
    cmap='Color map to use [optional]',
    annotations='Whether to display annotations [optional]',
    number_format='The format string used for annotations [optional]',
    float_format='The format string used for displaying floats [optional]',
    int_format='The format string used for displaying integers [optional]',
    corr_target='target variable name, if specified only correlations with the target are shown [optional]',
    corr_cutoff='filter out all correlations whose absolute value is below the cutoff [optional]',
    col_wrap='After how many columns to create a new line of subplots [optional]',
    subplot_width='Width of each individual subplot [optional]',
    subplot_height='Height of each individual subplot [optional]',
    trendline='Whether to add a trendline [optional]',
    alpha='Alpha transparency level [optional]',
    max_n='Maximum number of samples to be used for plotting, if this number is exceeded max_n samples are drawn '
          'at random from the data which triggers a warning unless sample_warn is set to False. '
          'Set to False or None to use all samples for plotting. [optional]',
    max_n_random_state='Random state (seed) used for drawing the random samples [optional]',
    max_n_sample_warn='Whether to trigger a warning if the data has more samples than max_n [optional]',
    return_fig_ax='Whether to return the figure and axes objects as tuple to be captured as fig,ax = ..., '
                  'If False pyplot.show() is called and the plot returns None [optional]',
    legend='Whether to show a legend [optional]',
    legend_loc='Location of the legend, one of [bottom, right] or accepted value of pyplot.legend. '
               'If in [bottom, right] legend_outside is used, else pyplot.legend [optional]',
    legend_ncol='Number of columns to use in legend [optional]',
    legend_space='Only valid if legend_loc is bottom. The space between the main plot and the legend [optional]',
    kde_steps='Nr of steps the range is split into for kde fitting [optional]',
    linestyle='Linestyle used, must be a valid linestyle recognized by pyplot.plot [optional]',
    bins='Nr of bins of the histogram [optional]',
    sharex='Whether to share the x axis [optional]',
    sharey='Whether to share the y axis [optional]',
    row='Row index [optional]',
    col='Column index [optional]',
    legend_out='Whether to draw the legend outside of the axis, can also be a location string [optional]',
    legend_kws='Other keyword arguments passed to pyplot.legend [optional]',
    xlim='X limits for the axis as tuple, passed to ax.set_xlim() [optional]',
    ylim='Y limits for the axis as tuple, passed to ax.set_ylim() [optional]',
    grid='Whether to toggle ax.grid() [optional]',
    vline='A list of x positions to draw vlines at [optional]',
    to_abs='whether to cast the values to absolute before proceeding [optional]',
    label='label to use for the data [optional]',
    x_tick_rotation='Set x tick label rotation to this value [optional]',
    std_cutoff='remove data outside of std_cutoff standard deviations, for a good visual experience try 3 [optional]',
    do_print='whether to print intermediate steps to console [optional]',
    **validations
)


# --- functions
def _get_ordered_levels(data: pd.DataFrame, level: str, order: Union[list, str, None], x: str = None) -> list:
    """
    Determine the order in which the levels of a column are plotted.

    :param data: DataFrame containing the column ``level`` (and ``x`` for the value based orders)
    :param level: name of the column whose unique values are to be ordered
    :param order: an explicit sequence of levels (returned as list) or one of ``None`` / ``'sorted'``, ``'inv'``,
        ``'count'``, ``'mean'`` / ``'mean_descending'``, ``'mean_ascending'``, ``'median'`` /
        ``'median_descending'``, ``'median_ascending'``
    :param x: name of the value column, required for the mean / median based orders
    :return: list of levels in plotting order
    :raises ValueError: if ``order`` is an unknown string or a mean / median order is requested without ``x``
    """
    # explicit order: any non-string (list, tuple, numpy array, Series, ...) is used as given; checking this first
    # avoids element-wise comparisons like ``array == 'sorted'``
    if order is not None and not isinstance(order, str):
        return list(order)

    if order is None or order == 'sorted':
        return data[level].drop_duplicates().sort_values().tolist()

    if order == 'inv':
        return data[level].drop_duplicates().sort_values().tolist()[::-1]

    if order == 'count':
        # ascending by count, ties broken by level. The frame is built explicitly because the column names
        # produced by value_counts().reset_index() differ between pandas versions ('index' was removed in 2.0)
        _counts = data[level].value_counts()
        _df_counts = pd.DataFrame({'level': _counts.index, 'count': _counts.values})
        return _df_counts.sort_values(by=['count', 'level'])['level'].tolist()

    # value based orders: (aggregation, ascending)
    _agg_orders = {
        'mean': ('mean', False),
        'mean_descending': ('mean', False),
        'mean_ascending': ('mean', True),
        'median': ('median', False),
        'median_descending': ('median', False),
        'median_ascending': ('median', True),
    }
    if order not in _agg_orders:
        raise ValueError('order must be a list of levels or one of {}, got {!r}'.format(
            ['sorted', 'inv', 'count'] + list(_agg_orders), order))
    if x is None:
        raise ValueError('order={!r} requires the name of a value column x'.format(order))

    _agg_func, _ascending = _agg_orders[order]
    _df_agg = data.groupby(level)[x].agg(_agg_func).reset_index()
    # ties in the aggregated value are broken by ascending level
    return _df_agg.sort_values(by=[x, level], ascending=[_ascending, True])[level].tolist()


@docstr
@export
def heatmap(x: str, y: str, z: str, data: pd.DataFrame, ax: plt.Axes = None, cmap: object = None,
            agg_func: str = 'mean', invert_y: bool = True, **kwargs) -> plt.Axes:
    """
    Wrapper for seaborn heatmap in x-y-z format

    :param x: Variable name for x axis value
    :param y: Variable name for y axis value
    :param z: Variable name for z value, used to color the heatmap
    :param data: %(data)s
    :param ax: %(ax_in)s
    :param cmap: %(cmap)s
    :param agg_func: If more than one z value per x,y pair exists agg_func is used to aggregate the data.
        Must be a function name understood by pandas.DataFrame.agg
    :param invert_y: Whether to call ax.invert_yaxis (orders the heatmap as expected)
    :param kwargs: Other keyword arguments passed to seaborn heatmap
    :return: %(ax_out)s
    """
    if cmap is None:
        cmap = sns.diverging_palette(10, 220, as_cmap=True)

    # seaborn draws the DataFrame index along the vertical axis and the columns along the horizontal axis,
    # so y has to become the index and x the columns. pivot only accepts keyword arguments since pandas 2.0.
    _df = data.groupby([x, y]).agg({z: agg_func}).reset_index().pivot(index=y, columns=x, values=z)

    if ax is None:
        ax = plt.gca()

    sns.heatmap(_df, ax=ax, cmap=cmap, **kwargs)
    ax.set_title(z)

    # heatmap puts the first row at the top; inverting lets y increase upwards
    if invert_y:
        ax.invert_yaxis()

    return ax
@docstr
@export
def corrplot(data: pd.DataFrame, annotations: bool = True, number_format: str = rcParams['float_format'],
             ax: plt.Axes = None) -> plt.Axes:
    """
    function to create a correlation plot using a seaborn heatmap
    based on: https://www.linkedin.com/pulse/generating-correlation-heatmaps-seaborn-python-andrew-holt

    :param number_format: %(number_format)s
    :param data: %(data_novec)s
    :param annotations: %(annotations)s
    :param ax: %(ax_in)s
    :return: %(ax_out)s
    """
    # Create correlation df. Non-numeric columns are dropped explicitly: pandas < 2.0 did this implicitly in
    # DataFrame.corr, pandas >= 2.0 raises instead.
    _corr = data.select_dtypes(include=[np.number, 'bool']).corr()

    if ax is None:
        ax = plt.gca()

    # Generate color map
    _colormap = sns.diverging_palette(220, 10, as_cmap=True)

    # mask the upper triangle including the diagonal, so every correlation is shown exactly once
    _mask = np.zeros_like(_corr, dtype=bool)
    _mask[np.triu_indices_from(_mask)] = True

    # Generate heat map, allow annotations and place floats in map
    sns.heatmap(_corr, cmap=_colormap, annot=annotations, fmt=number_format, mask=_mask, ax=ax)

    # The mask leaves the last column and the first row empty -> drop their ticks. The y labels are captured
    # before the ticks change and re-applied so they stay aligned with the remaining rows.
    ax.set_xticks(ax.get_xticks()[:-1])
    _yticklabels = ax.get_yticklabels()[1:]
    ax.set_yticks(ax.get_yticks()[1:])
    ax.set_yticklabels(_yticklabels)

    return ax
@docstr
@export
def corrplot_bar(data: pd.DataFrame, target: str = None, columns: List[str] = None,
                 corr_cutoff: float = rcParams['corr_cutoff'], corr_as_alpha: bool = False, xlim: tuple = (-1, 1),
                 ax: plt.Axes = None) -> plt.Axes:
    """
    correlation plot as barchart

    :param data: %(data)s
    :param target: %(corr_target)s
    :param columns: columns for which to calculate the correlations, defaults to all numeric columns [optional]
    :param corr_cutoff: %(corr_cutoff)s
    :param corr_as_alpha: whether to set alpha value of bars to scale with correlation [optional]
    :param xlim: xlim scale for plot, defaults to (-1, 1) to show the absolute scale of the correlations.
        set to None if you want the plot x limits to scale to the highest correlation values [optional]
    :param ax: %(ax_in)s
    :return: %(ax_out)s
    """
    _df_corr = get_df_corr(data, target=target)
    if corr_cutoff is not None:
        _df_corr = _df_corr[_df_corr['corr_abs'] >= corr_cutoff]
    # work on an explicit copy so adding the label column does not write into a slice of the filtered frame
    _df_corr = _df_corr.copy()

    if target is None:
        _df_corr['label'] = concat_cols(_df_corr, ['col_0', 'col_1'], sep=' X ')
    else:
        _df_corr['label'] = _df_corr['col_1']

    # filter columns (if applicable)
    if columns is not None:
        _columns = list(columns)
        if target is not None and target not in _columns:
            _columns.append(target)
        _df_corr = _df_corr[_df_corr['col_0'].isin(_columns) & _df_corr['col_1'].isin(_columns)]

    # RGBA colors: blue for positive, red for non-positive correlations
    _positive = (_df_corr['corr'] > 0.).values
    _rgba_colors = np.zeros((len(_df_corr), 4))
    _rgba_colors[:, 0] = np.where(_positive, 0., 1.)  # red channel
    _rgba_colors[:, 2] = np.where(_positive, 1., 0.)  # blue channel
    # alpha channel, optionally scaled with the absolute correlation (at least .1 so that bars stay visible)
    if corr_as_alpha:
        _rgba_colors[:, 3] = _df_corr['corr_abs'].where(lambda _: _ > .1, .1)
    else:
        _rgba_colors[:, 3] = 1
    _rgba_colors = np.round(_rgba_colors, 2)

    if ax is None:
        ax = plt.gca()

    ax.barh(_df_corr['label'], _df_corr['corr'], color=_rgba_colors)
    # first row of the correlation frame on top
    ax.invert_yaxis()

    if xlim:
        # noinspection PyTypeChecker
        ax.set_xlim(xlim)

    if target is not None:
        ax.set_title('Correlations with {} by Absolute Value'.format(target))
        ax.set_xlabel('corr × {}'.format(target))
    else:
        ax.set_title('Correlations by Absolute Value')

    return ax
@docstr
@export
def pairwise_corrplot(data: pd.DataFrame, corr_cutoff: float = rcParams['corr_cutoff'], col_wrap: int = 4,
                      hue: str = None, hue_order: Union[list, str] = None, width: float = rcParams['fig_width'],
                      height: float = rcParams['fig_height'], trendline: bool = True, alpha: float = .75,
                      ax: plt.Axes = None, target: str = None,
                      palette: Union[Mapping, Sequence, str] = rcParams['palette'], max_n: int = rcParams['max_n'],
                      random_state: int = rcParams['max_n.random_state'],
                      sample_warn: bool = rcParams['max_n.sample_warn'],
                      return_fig_ax: bool = rcParams['return_fig_ax'], **kwargs) -> Union[tuple, None]:
    """
    Scatter plot every pair of variables in the df, by default only those pairs whose absolute correlation
    coefficient is >= corr_cutoff

    :param data: %(data_novec)s
    :param corr_cutoff: %(corr_cutoff)s
    :param col_wrap: %(col_wrap)s
    :param hue: %(hue)s
    :param hue_order: %(order)s
    :param width: %(subplot_width)s
    :param height: %(subplot_height)s
    :param trendline: %(trendline)s
    :param alpha: %(alpha)s
    :param ax: %(ax_in)s
    :param target: %(corr_target)s
    :param palette: %(palette)s
    :param max_n: %(max_n)s
    :param random_state: %(max_n_random_state)s
    :param sample_warn: %(max_n_sample_warn)s
    :param return_fig_ax: %(return_fig_ax)s
    :param kwargs: other keyword arguments passed to pyplot.subplots
    :return: %(fig_ax_out)s
    """

    # actual plot function
    def _f_plot(_f_x, _f_y, _f_data, _f_color, _f_color_trendline, _f_label, _f_ax):
        _data = _f_data.copy()
        # limit plot points; max_n False / None / 0 disables the limit (as documented)
        if max_n and len(_data) > max_n:
            if sample_warn:
                warnings.warn(
                    'Limiting Scatter Plot to {:,} randomly selected points. '
                    'Turn this off with max_n=None or suppress this warning with sample_warn=False.'.format(max_n))
            _data = _data.sample(max_n, random_state=random_state)
        _f_ax.scatter(_f_x, _f_y, data=_data, alpha=alpha, color=_f_color, label=_f_label)
        if trendline:
            # the trendline is fitted on all points, not only on the sampled ones
            _f_ax.plot(_f_data[_f_x], lfit(_f_data[_f_x], _f_data[_f_y]), color=_f_color_trendline, linestyle=':')
        return _f_ax

    # avoid inplace operations
    _df = data.copy()

    # per hue level: the data subset and its correlation frame
    _hues = None
    _df_hues = {}
    _df_corrs = {}
    if hue is not None:
        _hues = _get_ordered_levels(_df, hue, hue_order)
        for _hue in _hues:
            _df_hue = _df[_df[hue] == _hue]
            _df_hues[_hue] = _df_hue.copy()
            _df_corrs[_hue] = get_df_corr(_df_hue, target=target).copy()

    # get df corr
    _df_corr = get_df_corr(_df, target=target)
    if corr_cutoff is not None:
        _df_corr = _df_corr[_df_corr['corr_abs'] >= corr_cutoff]

    # warning for empty df
    if len(_df_corr) == 0:
        warnings.warn('Correlation DataFrame is Empty. Do you need a lower corr_cutoff?')
        return None

    # edge case for less plots than col_wrap
    _ncols = min(len(_df_corr), col_wrap)
    _nrows = int(np.ceil(len(_df_corr) / _ncols))
    # size the figure by the columns actually used, otherwise few plots get a much too wide figure
    _figsize = (width * _ncols, height * _nrows)

    if ax is None:
        fig, ax = plt.subplots(nrows=_nrows, ncols=_ncols, figsize=_figsize, **kwargs)
    else:
        fig = plt.gcf()

    # with a single column the axes have to be addressed by row
    _rows_prio = _ncols == 1

    _row = None
    _col = None
    for _it in range(len(_df_corr)):
        _col = _it % _ncols
        _row = _it // _ncols
        _x = _df_corr.iloc[_it]['col_1']
        _y = _df_corr.iloc[_it]['col_0']  # so that target (if available) becomes y
        _corr = _df_corr.iloc[_it]['corr']

        _ax = get_subax(ax, _row, _col, rows_prio=_rows_prio)
        _ax.set_xlabel(_x)
        _ax.set_ylabel(_y)
        _ax.set_title('corr = {:.3f}'.format(_corr))

        if hue is None:
            _f_plot(_f_x=_x, _f_y=_y, _f_data=_df, _f_color=None, _f_color_trendline='k', _f_label=None, _f_ax=_ax)
            continue

        for _hue_it, _hue in enumerate(_hues):
            if isinstance(palette, Mapping):
                _color = palette[_hue]
            elif is_list_like(palette):
                _color = palette[_hue_it % len(palette)]
            else:
                _color = palette

            _df_hue = _df_hues[_hue]
            _df_corr_hue = _df_corrs[_hue]
            # the correlation of this pair may not be available for a hue level (e.g. a column without variance
            # in that level); an empty lookup makes iloc[0] raise IndexError, which falls back to 0
            try:
                _df_corr_hue = _df_corr_hue[(_df_corr_hue['col_1'] == _x) & (_df_corr_hue['col_0'] == _y)]
                _corr_hue = _df_corr_hue['corr'].iloc[0]
            except (KeyError, IndexError, ValueError):
                _corr_hue = 0

            _f_plot(_f_x=_x, _f_y=_y, _f_data=_df_hue, _f_color=_color, _f_color_trendline=_color,
                    _f_label='{} corr: {:.3f}'.format(_hue, _corr_hue), _f_ax=_ax)
        _ax.legend()

    # hide unused axes in the last row
    for _col_unused in range(_col + 1, _ncols):
        get_subax(ax, _row, _col_unused, rows_prio=False).set_axis_off()

    if return_fig_ax:
        return fig, ax
    else:
        plt.show()
@docstr
@export
def distplot(x: Union[Sequence, str], data: pd.DataFrame = None, hue: str = None,
             hue_order: Union[Sequence, str] = 'sorted', palette: Union[Mapping, Sequence, str] = None,
             linecolor: str = 'black', edgecolor: str = 'black', alpha: float = None,
             bins: Union[Sequence, int] = 40, perc: bool = None, top_nr: int = None, other_name: str = 'other',
             title: bool = True, title_prefix: str = '', std_cutoff: float = None, hist: bool = None,
             distfit: Union[str, bool, None] = 'kde', fill: bool = True, legend: bool = True,
             legend_loc: str = rcParams['distplot.legend_loc'],
             legend_space: float = rcParams['legend_outside.legend_space'], legend_ncol: int = 1,
             agg_func: str = 'mean', number_format: str = rcParams['float_format'], kde_steps: int = 1000,
             max_n: int = 100000, random_state: int = None, sample_warn: bool = True, xlim: Sequence = None,
             linestyle: str = None, label_style: str = rcParams['distplot.label_style'],
             x_offset_perc: float = .025, ax: plt.Axes = None, **kwargs) -> plt.Axes:
    """
    Similar to seaborn.distplot but supports hues and some other things. Plots a combination of a histogram and
    a kernel density estimation.

    :param x: the name of the variable(s) in data or vector data, if data is provided and x is a list of columns
        the DataFrame is automatically melted and the newly generated column used as hue. i.e. you plot
        the distributions of multiple columns on the same axis
    :param data: %(data)s
    :param hue: %(hue)s
    :param hue_order: %(order)s
    :param palette: %(palette)s
    :param linecolor: Color of the kde fit line, overwritten with palette by hue level if hue is specified [optional]
    :param edgecolor: Color of the histogram edges [optional]
    :param alpha: %(alpha)s
    :param bins: %(bins)s
    :param perc: Whether to display the y-axes as percentage, if false count is displayed.
        Defaults if hue: True, else False [optional]
    :param top_nr: limit hue to the top_nr most frequent levels, the rest will be cast to other_name [optional]
    :param other_name: name of the group the levels outside of top_nr are cast to [optional]
    :param title: whether to set the plot title equal to x's name [optional]
    :param title_prefix: prefix to be used in plot title [optional]
    :param std_cutoff: automatically cutoff data outside of the std_cutoff standard deviations range,
        by default this is off but a recommended value for a good visual experience without outliers is 3 [optional]
    :param hist: whether to show the histogram, default False if hue else True [optional]
    :param distfit: one of %(valid_distfits)s. If 'kde' fits a kernel density distribution to the data.
        If gauss fits a gaussian distribution with the observed mean and std to the data.
        False or None disable the fit [optional]
    :param fill: whether to fill the area under the distfit curve, ignored if hist is True [optional]
    :param legend: %(legend)s
    :param legend_loc: %(legend_loc)s
    :param legend_space: %(legend_space)s
    :param legend_ncol: %(legend_ncol)s
    :param agg_func: one of ['mean', 'median']. The agg function used to find the center of the distribution
        [optional]
    :param number_format: %(number_format)s
    :param kde_steps: %(kde_steps)s
    :param max_n: %(max_n)s
    :param random_state: %(max_n_random_state)s
    :param sample_warn: %(max_n_sample_warn)s
    :param xlim: %(xlim)s
    :param linestyle: %(linestyle)s
    :param label_style: one of ['mu_sigma', 'plain']. If mu_sigma then the mean (or median) and std value are
        displayed inside the label [optional]
    :param x_offset_perc: the amount whitespace to display next to x_min and x_max in percent of x_range [optional]
    :param ax: %(ax_in)s
    :param kwargs: additional keyword arguments passed to pyplot.plot
    :return: %(ax_out)s
    :raises ValueError: if distfit is not one of the valid values
    """
    # -- validation (str() so that the documented False / None pass, they are listed as strings)
    if str(distfit) not in validations['valid_distfits']:
        raise ValueError('distfit must be one of {}'.format(validations['valid_distfits']))
    # False / None (or their string forms) disable the distribution fit
    if str(distfit) in ['False', 'None']:
        distfit = None

    # -- defaults
    if palette is None:
        palette = rcParams['palette']
    if not top_nr:
        top_nr = None

    if data is None:
        # case: vector data; a named Series keeps its name, anything else becomes column 'x' without label
        if getattr(x, 'name', None) is not None:
            # noinspection PyUnresolvedReferences
            _x = x.name
            _x_name = _x
        else:
            _x = 'x'
            _x_name = None
        _df = pd.DataFrame.from_dict({_x: x})
    else:
        # case: DataFrame
        _df = data.copy()  # avoid inplace operations
        del data
        if is_list_like(x) and len(x) > 1:
            # several columns: melt them and use the column name as hue
            hue = '__variable'
            _x = '__value'
            _x_name = _x
            hue_order = x
            title = False
            _df = pd.melt(_df, value_vars=x, value_name=_x, var_name=hue)
        else:
            # a single column, possibly given as a one element list
            _x = x[0] if is_list_like(x) else x
            _x_name = _x
    del x

    # handle hue and default values
    if hue is None:
        if perc is None:
            perc = False
        if hist is None:
            hist = True
        if alpha is None:
            alpha = .75
    else:
        _df = _df[~_df[hue].isnull()]
        if perc is None:
            perc = True
        if hist is None:
            hist = False
        if alpha is None:
            alpha = .5

    # case more than max_n samples: take a random sample for calc speed
    if max_n and (len(_df) > max_n):
        if sample_warn:
            warnings.warn(
                'Limiting samples to {:,} for calc speed. Turn this off with max_n=None or suppress this warning '
                'with sample_warn=False.'.format(max_n))
        _df = _df.sample(max_n, random_state=random_state)

    # the actual plot (one call per hue level)
    def _f_distplot(_f_x, _f_data, _f_x_label, _f_facecolor, _f_distfit_color, _f_bins, _f_std_cutoff, _f_xlim,
                    _f_distfit_line, _f_ax, _f_ax2, _f_fill):
        # make a copy to avoid inplace operations
        _df_i = _f_data.copy()

        # center and spread of the data
        _mu = _df_i[_f_x].agg(agg_func)
        _sigma = _df_i[_f_x].std()

        # apply xlim / sigma cutoff
        if (_f_std_cutoff is not None) or (_f_xlim is not None):
            if _f_xlim is not None:
                _cut_min, _cut_max = _f_xlim[0], _f_xlim[1]
            elif is_list_like(_f_std_cutoff):
                _cut_min, _cut_max = _f_std_cutoff[0], _f_std_cutoff[1]
            else:
                _cut_min = _mu - _f_std_cutoff * _sigma
                _cut_max = _mu + _f_std_cutoff * _sigma
            _df_i = _df_i[(_df_i[_f_x] >= _cut_min) & (_df_i[_f_x] <= _cut_max)]

        # for plot trimming (the data column is _f_x; the label _x_name may be None for vector data)
        _x_mins.append(_df_i[_f_x].min())
        _x_maxs.append(_df_i[_f_x].max())

        # handle label
        try:
            _mu_label = format(_mu, number_format)
        except ValueError:
            _mu_label = '0'
        try:
            _sigma_label = format(_sigma, number_format)
        except ValueError:
            _sigma_label = '0'
        _mu_symbol = r'\ \mu' if agg_func == 'mean' else r'\ \nu'
        if label_style == 'mu_sigma':
            _label = r'{}: $ {}={},\ \sigma={}$'.format(_f_x_label, _mu_symbol, _mu_label, _sigma_label)
        else:
            _label = _f_x_label

        # plot histogram; if it exists it carries the legend label and the distfit line is dashed
        if hist:
            _hist_n = _f_ax.hist(_df_i[_f_x], _f_bins, density=perc, facecolor=_f_facecolor, edgecolor=edgecolor,
                                 alpha=alpha, label=_label)[0]
            _label_2 = '__nolegend___'  # labels starting with an underscore are ignored by the legend
            if _f_distfit_line is None:
                _f_distfit_line = '--'
        else:
            _hist_n = None
            _label_2 = _label
            if _f_distfit_line is None:
                _f_distfit_line = '-'

        # plot distfit
        if distfit:
            # if a histogram was plot on the primary axis, the distfit goes on the secondary axis
            _ax = _f_ax2 if hist else _f_ax
            if distfit == 'gauss':
                # pdf of a normal distribution with the observed center and std, evaluated at the bin edges
                _fit_x = _f_bins
                _fit_y = stats.norm.pdf(_f_bins, _mu, _sigma)
            else:  # 'kde'
                _kde = kde(x=_f_x, df=_df_i, x_steps=kde_steps)[0]
                _fit_x = _kde[_f_x]
                _fit_y = _kde['value']
            _ax.plot(_fit_x, _fit_y, linestyle=_f_distfit_line, color=_f_distfit_color, alpha=alpha, linewidth=2,
                     label=_label_2, **kwargs)
            if not hist:
                _ax.set_ylabel('pdf')
                # fill is documented as ignored when a histogram is drawn
                if _f_fill:
                    _ax.fill_between(_fit_x, _fit_y, color=_f_facecolor, alpha=alpha)
        _f_ax2.get_yaxis().set_visible(False)

        if perc and hist:
            # relabel the density axis so that the highest bin shows its share of all samples
            _y_max = np.max(_hist_n) / np.sum(_hist_n) * 100
            _y_ticklabels = [float(_) for _ in _f_ax.get_yticks()]
            _factor = _y_max / np.nanmax(_y_ticklabels)
            if np.isnan(_factor):
                _factor = 1
            _y_ticklabels = [format(int(_ * _factor), ',') for _ in _y_ticklabels]
            _f_ax.set_yticklabels(_y_ticklabels)
            _f_ax.set_ylabel('%')
        elif hist:
            _f_ax.set_ylabel('count')

        # adjust xlims if necessary; _plot_x_min / _plot_x_max come from the enclosing function and are based on
        # the 'parent' DataFrame with all hue levels
        _xlim = list(_f_ax.get_xlim())
        if _xlim[0] <= _plot_x_min:
            _xlim[0] = _plot_x_min
        if _xlim[1] >= _plot_x_max:
            _xlim[1] = _plot_x_max
        _f_ax.set_xlim(_xlim)

        return _f_ax, _f_ax2

    def _f_bins_and_plot_limits(_f_x_min, _f_x_max):
        # bin edges plus the outer limits the axis may show; explicit bins are used as given
        if is_list_like(bins):
            return bins, np.min(bins), np.max(bins)
        _x_step = (_f_x_max - _f_x_min) / bins
        return (np.arange(_f_x_min, _f_x_max + _x_step, _x_step),
                _df[_x].min() - _x_step, _df[_x].max() + _x_step)

    # -- preparing the data frame
    # drop nan values; copy so later column assignments do not write into a slice
    _df = _df[np.isfinite(_df[_x])].copy()

    # init plot
    if ax is None:
        ax = plt.gca()
    ax2 = ax.twinx()

    # for plot trimming
    _x_mins = []
    _x_maxs = []

    if hue is None:
        # handle x limits
        if xlim is not None:
            _x_min, _x_max = xlim[0], xlim[1]
        elif std_cutoff is not None:
            _x_min = _df[_x].mean() - _df[_x].std() * std_cutoff
            _x_max = _df[_x].mean() + _df[_x].std() * std_cutoff
        else:
            _x_min = _df[_x].min()
            _x_max = _df[_x].max()

        _bins, _plot_x_min, _plot_x_max = _f_bins_and_plot_limits(_x_min, _x_max)

        # handle palette / color
        if isinstance(palette, Mapping):
            _color = list(palette.values())[0]
        elif is_list_like(palette):
            _color = palette[0]
        else:
            _color = palette

        ax, ax2 = _f_distplot(_f_x=_x, _f_data=_df, _f_x_label=_x_name, _f_facecolor=_color,
                              _f_distfit_color=linecolor, _f_bins=_bins, _f_std_cutoff=std_cutoff, _f_xlim=xlim,
                              _f_distfit_line=linestyle, _f_ax=ax, _f_ax2=ax2, _f_fill=fill)
    else:
        if top_nr is not None:
            # group values outside of the top_nr most frequent levels to other_name. Most frequent first, ties by
            # level; built explicitly because value_counts().reset_index() column names differ across pandas
            _counts = _df[hue].value_counts()
            _hues = (pd.DataFrame({'level': _counts.index, 'count': _counts.values})
                     .sort_values(by=['count', 'level'], ascending=[False, True])['level'].tolist())
            # the plus 1 is there to avoid the other group having exactly 1 entry
            if (top_nr + 1) < len(_hues):
                _hues = _hues[:top_nr]
                _df[hue] = np.where(_df[hue].isin(_hues), _df[hue], other_name)
                _df[hue] = _df[hue].astype('str')
                # the hue column holds strings now, so the levels have to as well
                _hues = [str(_) for _ in _hues] + [other_name]
        else:
            # parse hue order
            _hues = _get_ordered_levels(_df, hue, hue_order, _x)

        # find shared _x_min ; _x_max
        if xlim is not None:
            _std_cutoff_hues = None
            _x_min, _x_max = xlim[0], xlim[1]
        elif std_cutoff is None:
            _std_cutoff_hues = None
            _x_min = _df[_x].min()
            _x_max = _df[_x].max()
        else:
            # widest std_cutoff range over all hue levels
            _df_agg = _df.groupby(hue)[_x].agg(['mean', 'std'])
            _x_min = (_df_agg['mean'] - _df_agg['std'] * std_cutoff).min()
            _x_max = (_df_agg['mean'] + _df_agg['std'] * std_cutoff).max()
            _std_cutoff_hues = [_x_min, _x_max]

        _bins, _plot_x_min, _plot_x_max = _f_bins_and_plot_limits(_x_min, _x_max)

        # loop hues; list palettes / linestyles are cycled if there are more levels than entries
        for _it, _hue in enumerate(_hues):
            if isinstance(palette, Mapping):
                _color = palette[_hue]
            elif is_list_like(palette):
                _color = palette[_it % len(palette)]
            else:
                _color = palette

            if isinstance(linestyle, Mapping):
                _linestyle = linestyle[_hue]
            elif is_list_like(linestyle):
                _linestyle = linestyle[_it % len(linestyle)]
            else:
                _linestyle = linestyle

            _df_hue = _df[_df[hue] == _hue]

            # one plot per hue
            ax, ax2 = _f_distplot(_f_x=_x, _f_data=_df_hue, _f_x_label=_hue, _f_facecolor=_color,
                                  _f_distfit_color=_color, _f_bins=_bins, _f_std_cutoff=_std_cutoff_hues,
                                  _f_xlim=xlim, _f_distfit_line=_linestyle, _f_ax=ax, _f_ax2=ax2, _f_fill=fill)

    # -- postprocessing
    # handle legend
    if legend:
        if legend_loc in ['bottom', 'right']:
            legend_outside(ax, loc=legend_loc, legend_space=legend_space, ncol=legend_ncol)
            legend_outside(ax2, loc=legend_loc, legend_space=legend_space, ncol=legend_ncol)
        else:
            _, _labels = ax.get_legend_handles_labels()
            if len(_labels) > 0:
                ax.legend(loc=legend_loc, ncol=legend_ncol)
            _, _labels = ax2.get_legend_handles_labels()
            if len(_labels) > 0:
                ax2.legend(loc=legend_loc, ncol=legend_ncol)

    # handle title
    if title:
        _title = '{}{}'.format(title_prefix, _x_name)
        if hue is not None:
            _title += ' by {}'.format(hue)
        ax.set_title(_title)

    # handle xlim
    if xlim is not None:
        # noinspection PyTypeChecker
        ax.set_xlim(xlim)
    elif np.isfinite(_x_mins).any() and np.isfinite(_x_maxs).any():
        # trim to the plotted data plus a small offset; nan entries stem from empty hue levels and are ignored
        _x_min = np.nanmin(_x_mins)
        _x_max = np.nanmax(_x_maxs)
        _x_offset = (_x_max - _x_min) * x_offset_perc
        # noinspection PyTypeChecker
        ax.set_xlim((_x_min - _x_offset, _x_max + _x_offset))

    return ax
@docstr
@export
def hist_2d(x: str, y: str, data: pd.DataFrame, bins: int = 100, std_cutoff: int = 3, cutoff_perc: float = .01,
            cutoff_abs: float = 0, cmap: str = 'rainbow', ax: plt.Axes = None, color_sigma: str = 'xkcd:red',
            draw_sigma: bool = True, **kwargs) -> plt.Axes:
    """
    generic 2d histogram created by splitting the 2d area into equal sized cells, counting data points in them and
    drawn using pyplot.pcolormesh

    :param x: %(x)s
    :param y: %(y)s
    :param data: %(data)s
    :param bins: %(bins)s
    :param std_cutoff: %(std_cutoff)s
    :param cutoff_perc: if less than this percentage of data points is in the cell then the data is ignored [optional]
    :param cutoff_abs: if less than this amount of data points is in the cell then the data is ignored [optional]
    :param cmap: %(cmap)s
    :param ax: %(ax_in)s
    :param color_sigma: color to highlight the sigma range in, must be a valid pyplot.plot color [optional]
    :param draw_sigma: whether to highlight the 1 sigma range around the medians as an ellipse [optional]
    :param kwargs: other keyword arguments passed to pyplot.pcolormesh [optional]
    :return: %(ax_out)s
    """
    _df = data.copy()
    del data

    if std_cutoff is not None:
        _x_min = _df[x].mean() - _df[x].std() * std_cutoff
        _x_max = _df[x].mean() + _df[x].std() * std_cutoff
        _y_min = _df[y].mean() - _df[y].std() * std_cutoff
        _y_max = _df[y].mean() + _df[y].std() * std_cutoff
        # keep only points for which both x and y lie inside their std range
        _df = _df[
            (_df[x] >= _x_min) & (_df[x] <= _x_max) & (_df[y] >= _y_min) & (_df[y] <= _y_max)
        ].reset_index(drop=True)

    _x = _df[x]
    _y = _df[y]

    # Estimate the 2D histogram
    _hist, _x_edges, _y_edges = np.histogram2d(_x, _y, bins=bins)

    # histogram2d indexes the counts as [x bin, y bin] while pcolormesh expects [row = y, column = x]
    _hist = _hist.T

    # Mask too small counts
    if cutoff_abs is not None:
        _hist = np.ma.masked_where(_hist <= cutoff_abs, _hist)
    if cutoff_perc is not None:
        _hist = np.ma.masked_where(_hist <= _hist.max() * cutoff_perc, _hist)

    # Plot 2D histogram
    if ax is None:
        ax = plt.gca()

    _mappable = ax.pcolormesh(_x_edges, _y_edges, _hist, cmap=cmap, **kwargs)
    ax.set_xlabel(x)
    ax.set_ylabel(y)

    _cbar = plt.colorbar(mappable=_mappable, ax=ax)
    _cbar.ax.set_ylabel('count')

    # draw ellipse to mark the 1 sigma area around the medians; Ellipse takes the full axis lengths (diameters),
    # so one std in each direction requires twice the std
    if draw_sigma:
        _ellipse = patches.Ellipse(xy=(_x.median(), _y.median()), width=2 * _x.std(), height=2 * _y.std(),
                                   edgecolor=color_sigma, fc='None', lw=2, ls=':')
        ax.add_patch(_ellipse)

    return ax
@docstr
@export
def paired_plot(data: pd.DataFrame, cols: Sequence, color: str = None, cmap: str = None, alpha: float = 1,
                **kwargs) -> sns.PairGrid:
    """
    create a pair grid to analyze various aspects of correlation between two variables using seaborn.PairGrid

    :param data: %(data)s
    :param cols: list of exactly two variables to be compared
    :param color: %(color)s
    :param cmap: %(cmap)s
    :param alpha: %(alpha)s
    :param kwargs: other arguments passed to seaborn.PairGrid
    :return: seaborn PairGrid object with the plots on it
    """

    def _f_corr(_f_x, _f_y, _f_s=10, **_f_kwargs):
        # annotate the current axis with the Pearson correlation coefficient of the two vectors
        _coef = np.corrcoef(_f_x, _f_y)[0][1]
        _label = r'$\rho$ = ' + str(round(_coef, 2))
        _ax = plt.gca()
        # position in axes coordinates; _f_s shifts the annotation down for larger sizes
        _ax.annotate(_label, xy=(0.2, 0.95 - (_f_s - 10.) / 10.), size=20, xycoords=_ax.transAxes, **_f_kwargs)

    # Create an instance of the PairGrid class.
    _grid = sns.PairGrid(data=data, vars=cols, **kwargs)

    # Map a scatter plot and the correlation coefficient to the upper triangle
    _grid = _grid.map_upper(plt.scatter, alpha=alpha, color=color)
    _grid = _grid.map_upper(_f_corr)

    # density = True might not be supported in older versions of seaborn / matplotlib
    _grid = _grid.map_diag(plt.hist, bins=30, color=color, alpha=alpha, edgecolor='k', density=True)

    # Map a density plot to the lower triangle
    _grid = _grid.map_lower(sns.kdeplot, cmap=cmap, alpha=alpha)

    # add legend
    _grid.add_legend()

    return _grid
@export
def q_plim(s: pd.Series, q_min: float = .1, q_max: float = .9, offset_perc: float = .1, limit_min_max: bool = False,
           offset=True) -> tuple:
    """
    returns quick x limits for plotting (cut off data not in q_min to q_max quantile)

    :param s: pandas Series to truncate
    :param q_min: lower bound quantile [optional]
    :param q_max: upper bound quantile [optional]
    :param offset_perc: share of the (low, high) span added as padding to the left and right of the quantile
        boundaries
    :param limit_min_max: whether to truncate the plot limits at the data limits
    :param offset: whether to apply the offset
    :return: a tuple (low, high) containing the x limits
    """
    _low = floor_signif(s.quantile(q=q_min))
    _high = ceil_signif(s.quantile(q=q_max))

    # degenerate quantile range: fall back to the full data range
    if _low == _high:
        _low, _high = s.min(), s.max()

    if limit_min_max:
        _high = min(_high, s.max())
        _low = max(_low, s.min())

    _pad = (_high - _low) * offset_perc if offset else 0

    return _low - _pad, _high + _pad
@docstr
@export
def levelplot(data: pd.DataFrame, level: str, cols: Union[list, str] = None, hue: str = None,
              order: Union[list, str] = None, hue_order: Union[list, str] = None, func: Callable = distplot,
              summary_title: bool = True, level_title: bool = True, do_print: bool = False,
              width: int = rcParams['fig_width'], height: int = rcParams['fig_height'],
              return_fig_ax: bool = rcParams['return_fig_ax'], kwargs_subplots_adjust: Mapping = None,
              kwargs_summary: Mapping = None, **kwargs) -> Union[None, tuple]:
    """
    Plots a plot for each specified column for each level of a certain column plus a summary plot.
    Each column gets one row of subplots: the summary plot (hue = level) in the first column, followed by one plot
    per level.

    :param data: %(data)s
    :param level: the name of the column to split the plots by, must be in data
    :param cols: the column / columns to create plots for, defaults to all numeric columns other than level and
        hue [optional]
    :param hue: %(hue)s
    :param order: %(order)s
    :param hue_order: %(order)s
    :param func: function to use for plotting, must support 1 positional argument, data, hue, ax and kwargs
        [optional]
    :param summary_title: whether to automatically set the summary plot title [optional]
    :param level_title: whether to automatically set the level plot title [optional]
    :param do_print: %(do_print)s
    :param width: %(subplot_width)s
    :param height: %(subplot_height)s
    :param return_fig_ax: %(return_fig_ax)s
    :param kwargs_subplots_adjust: other keyword arguments passed to pyplot.subplots_adjust [optional]
    :param kwargs_summary: other keyword arguments passed to summary distplot, if None uses kwargs [optional]
    :param kwargs: other keyword arguments passed to func [optional]
    :return: see return_fig_ax
    """
    if kwargs_summary is None:
        kwargs_summary = kwargs

    _df = data.copy()
    del data

    if cols is None:
        # column names (not the numeric sub-DataFrame, whose len() would count rows)
        cols = [_c for _c in _df.select_dtypes(include=np.number).columns if _c not in [level, hue]]
    elif isinstance(cols, str):
        cols = [cols]

    _levels = _get_ordered_levels(data=_df, level=level, order=order)
    _hue_str = '' if hue is None else ' by {}'.format(hue)

    _nrows = len(cols)
    _ncols = len(_levels) + 1  # column 0 holds the summary plot
    _it_max = _nrows * len(_levels)  # _it only counts the level plots

    fig, ax = plt.subplots(nrows=_nrows, ncols=_ncols, figsize=(_ncols * width, _nrows * height))
    # plt.subplots squeezes a single row / column into a 1d array: index it by column when there is only one row
    # and by row when there is only one column
    _rows_prio = _ncols == 1

    _it = -1
    for _col_i, _col in enumerate(cols):
        # summary plot in column 0 of the current row
        _ax_summary = get_subax(ax, _col_i, 0, rows_prio=_rows_prio)
        func(_col, data=_df, hue=level, ax=_ax_summary, **kwargs_summary)
        if summary_title:
            _ax_summary.set_title('{} by {}'.format(_col, level))

        for _level_i, _level in enumerate(_levels):
            _it += 1
            if do_print:
                progressbar(_it, _it_max, print_prefix='{}_{}'.format(_col, _level))

            _df_level = _df[_df[level] == _level]
            _ax = get_subax(ax, _col_i, _level_i + 1, rows_prio=_rows_prio)

            # level plot
            func(_col, data=_df_level, hue=hue, ax=_ax, **kwargs)
            if level_title:
                _ax.set_title('{}{} - {}={}'.format(_col, _hue_str, level, _level))

    if kwargs_subplots_adjust is not None:
        plt.subplots_adjust(**kwargs_subplots_adjust)

    if do_print:
        progressbar()

    if return_fig_ax:
        return fig, ax
    else:
        plt.show()
@docstr
@export
def get_legends(ax: plt.Axes = None) -> list:
    """
    collects every Legend child of an axes, useful if you have a secondary axis

    :param ax: %(ax_in)s
    :return: list of legends
    """
    _target = plt.gca() if ax is None else ax
    return list(filter(lambda _child: isinstance(_child, Legend), _target.get_children()))
# a plot to compare four components of a DataFrame
def four_comp_plot(data, x_1, y_1, x_2, y_2, hue_1=None, hue_2=None, lim=None,
                   return_fig_ax=rcParams['return_fig_ax'], **kwargs):
    """
    Draw a 2 x 2 grid of scatterplots comparing two pairs of columns under two different hue codings.
    The grid columns correspond to the pairs (x_1, y_1) and (x_2, y_2), the grid rows to hue_1 and hue_2.

    :param data: pandas DataFrame containing the columns
    :param x_1: x column of the first pair
    :param y_1: y column of the first pair
    :param x_2: x column of the second pair
    :param y_2: y column of the second pair
    :param hue_1: hue column of the first row; if None each point is coded by the largest multiple (up to 3) of
        the std of x_1 / y_1 that its absolute x_1 or y_1 value exceeds ('0_std' ... '3_std')
    :param hue_2: hue column of the second row; if None each point is coded by the signs of x_1 and y_1
        ('- -', '- +', '+ -', '+ +')
    :param lim: dict with keys 'x_1', 'x_2', 'y_1', 'y_2'. Each value is 'default' (+/- 4 std around 0, clipped
        to the data range), None (leave the limits untouched) or an explicit (low, high) pair.
        Missing keys fall back to 'default'.
    :param return_fig_ax: whether to return (fig, ax); otherwise the figure is shown
    :param kwargs: other keyword arguments passed to seaborn.scatterplot
    :return: (fig, ax) if return_fig_ax else None
    """
    _lim = {'x_1': 'default', 'x_2': 'default', 'y_1': 'default', 'y_2': 'default'}
    if lim is not None:
        _lim.update(lim)

    def _resolve_lim(_lim_value, _s):
        # turn a lim entry into a (low, high) pair or None (= leave the axis limits untouched)
        if _lim_value is None:
            return None
        if isinstance(_lim_value, str):
            if _lim_value != 'default':
                return None
            _low = -_s.std() * 4
            if _s.min() > _low:
                _low = _s.min()
            _high = _s.std() * 4
            if _s.max() < _high:
                _high = _s.max()
            return [_low, _high]
        return _lim_value

    # four components, ie 2 x 2
    _nrows = 2
    _ncols = 2
    fig, ax = plt.subplots(ncols=_ncols, nrows=_nrows)

    # make a copy to avoid inplace operations
    _df_plot = data.copy()

    _x_std = _df_plot[x_1].std()
    _y_std = _df_plot[y_1].std()

    # type 1: split by size in relation to std
    if hue_1 is None:
        _abs_x = np.abs(_df_plot[x_1])
        _abs_y = np.abs(_df_plot[y_1])
        _df_plot['std'] = np.where((_abs_x <= _x_std) & (_abs_y <= _y_std), '0_std', 'Null')
        for _n in [1, 2, 3]:
            _df_plot['std'] = np.where((_abs_x > _n * _x_std) | (_abs_y > _n * _y_std), '{}_std'.format(_n),
                                       _df_plot['std'])
        _df_plot['std'] = _df_plot['std'].astype('category')
        hue_1 = 'std'

    # type 2: split by plus minus
    if hue_2 is None:
        _df_plot['plus_minus'] = np.where((_df_plot[x_1] <= 0) & (_df_plot[y_1] <= 0), '- -', 'Null')
        _df_plot['plus_minus'] = np.where((_df_plot[x_1] <= 0) & (_df_plot[y_1] > 0), '- +',
                                          _df_plot['plus_minus'])
        _df_plot['plus_minus'] = np.where((_df_plot[x_1] > 0) & (_df_plot[y_1] <= 0), '+ -',
                                          _df_plot['plus_minus'])
        _df_plot['plus_minus'] = np.where((_df_plot[x_1] > 0) & (_df_plot[y_1] > 0), '+ +',
                                          _df_plot['plus_minus'])
        _df_plot['plus_minus'] = _df_plot['plus_minus'].astype('category')
        hue_2 = 'plus_minus'

    _xs = [x_1, x_2]
    _ys = [y_1, y_2]
    _hues = [hue_1, hue_2]
    _xlims = [_lim['x_1'], _lim['x_2']]
    _ylims = [_lim['y_1'], _lim['y_2']]

    for _row in range(_nrows):
        for _col in range(_ncols):
            _ax = get_subax(ax, _row, _col)
            _x_name = _xs[_col]
            _y_name = _ys[_col]
            _hue = _hues[_row]
            _x = _df_plot[_x_name]
            _y = _df_plot[_y_name]

            _ax = sns.scatterplot(data=_df_plot, x=_x_name, y=_y_name, hue=_hue, marker='.', ax=_ax, **kwargs)

            # grid 0 line
            _ax.axvline(0, color='k', alpha=.5, linestyle=':')
            _ax.axhline(0, color='k', alpha=.5, linestyle=':')

            _ax.set_title('%s vs %s, hue: %s' % (_x_name, _y_name, _hue))
            _ax.set_xlabel(_x_name)
            _ax.set_ylabel(_y_name)

            _xlim = _resolve_lim(_xlims[_col], _x)
            if _xlim is not None:
                _ax.set_xlim(_xlim)
            _ylim = _resolve_lim(_ylims[_col], _y)
            if _ylim is not None:
                _ax.set_ylim(_ylim)

    if return_fig_ax:
        return fig, ax
    else:
        plt.tight_layout()
        plt.show()
@docstr
@export
def facet_wrap(func: Callable, data: pd.DataFrame, facet: Union[list, str], *args, facet_type: str = None,
               col_wrap: int = 4, width: int = rcParams['fig_width'], height: int = rcParams['fig_height'],
               catch_error: bool = True, return_fig_ax: bool = rcParams['return_fig_ax'], sharex: bool = False,
               sharey: bool = False, show_xlabel: bool = True, x_tick_rotation: int = None,
               y_tick_rotation: int = None, ax_title: str = 'set', order: Union[list, str] = None,
               subplots_kws: Mapping = None, **kwargs):
    """
    modeled after r's facet_wrap function. Wraps a number of subplots onto a 2d grid of subplots while creating
    a new line after col_wrap columns. Uses a given plot function and creates a new plot for each facet level.

    :param func: Any plot function. Must support keyword arguments data and ax
    :param data: %(data)s
    :param facet: The column / list of columns to facet over.
    :param args: passed to func
    :param facet_type: one of ['group', 'cols', None].
        If group facet is treated as the column creating the facet levels and a subplot is created for each level.
        If cols each facet is in turn passed as the first positional argument to the plot function func.
        If None then the facet_type is inferred: a single facet value will be treated as group and multiple
        facet values will be treated as cols.
    :param col_wrap: %(col_wrap)s
    :param width: %(subplot_width)s
    :param height: %(subplot_height)s
    :param catch_error: whether to keep going in case of an error being encountered in the plot function
        [optional]
    :param return_fig_ax: %(return_fig_ax)s
    :param sharex: %(sharex)s
    :param sharey: %(sharey)s
    :param show_xlabel: whether to show the x label for each subplot
    :param x_tick_rotation: x tick rotation for each subplot
    :param y_tick_rotation: y tick rotation for each subplot
    :param ax_title: one of ['set','hide'], if set sets axis title to facet name, if hide forcefully hides axis
        title
    :param order: %(order)s
    :param subplots_kws: other keyword arguments passed to pyplot.subplots
    :param kwargs: other keyword arguments passed to func
    :return: %(fig_ax_out)s

    **examples**

    Check out the `example notebook <https://colab.research.google.com/drive/1bAEFRoWJgwPzkEqOoPBHVX849qQjxLYC>`_
    """
    if subplots_kws is None:
        subplots_kws = {}

    _df = data.copy()
    del data

    # infer the facet type: a list of column names is plotted column by column
    if facet_type is None:
        facet_type = 'cols' if is_list_like(facet) else 'group'

    # process the facets
    if facet_type == 'cols':
        _facets = facet
    else:
        _df['_facet'] = concat_cols(_df, facet)
        facet = '_facet'
        _facets = _get_ordered_levels(_df, facet, order)

    # init a grid
    if len(_facets) > col_wrap:
        _ncols = col_wrap
        _nrows = int(np.ceil(len(_facets) / _ncols))
    else:
        _ncols = len(_facets)
        _nrows = 1

    fig, ax = plt.subplots(ncols=_ncols, nrows=_nrows, figsize=(width * _ncols, height * _nrows), **subplots_kws)
    _ax_list = ax_as_list(ax)

    for _ax, _facet in zip(_ax_list, _facets):

        # in cols mode the facet (column name) becomes the first positional argument of func
        if facet_type == 'cols':
            _df_facet = _df.copy()
            _args = force_list(_facet) + list(args)
        else:
            _df_facet = _df[_df[facet] == _facet]
            _args = args

        if catch_error:
            try:
                func(*_args, data=_df_facet, ax=_ax, **kwargs)
            except Exception as _exc:
                warnings.warn('could not plot facet {} with exception {}, skipping. '
                              'For details use catch_error=False'.format(_facet, _exc))
                _ax.set_axis_off()
                continue
        else:
            func(*_args, data=_df_facet, ax=_ax, **kwargs)

        # set axis title to facet or hide it or do nothing (depending on preference)
        if ax_title == 'set':
            _ax.set_title(_facet)
        elif ax_title == 'hide':
            _ax.set_title('')

        # tick rotation
        if x_tick_rotation is not None:
            _ax.xaxis.set_tick_params(rotation=x_tick_rotation)
        if y_tick_rotation is not None:
            _ax.yaxis.set_tick_params(rotation=y_tick_rotation)

        # hide x label (if appropriate)
        if not show_xlabel:
            _ax.set_xlabel('')

    # hide the axes that did not receive a facet (the last grid row may be incomplete)
    for _ax in _ax_list[len(_facets):]:
        _ax.set_axis_off()

    if sharex or sharey:
        share_xy(ax, x=sharex, y=sharey)

    if return_fig_ax:
        return fig, ax
    else:
        plt.show()
@docstr
@export
def get_subax(ax: Union[plt.Axes, np.ndarray], row: int = None, col: int = None, rows_prio: bool = True) -> plt.Axes:
    """
    shorthand to get around the fact that ax can be a single Axes, a 1D array or a 2D array
    (for subplots that can be 1x1, 1xn, nx1 or nxm)

    :param ax: %(ax_in)s
    :param row: %(row)s
    :param col: %(col)s
    :param rows_prio: decides if to use row or col in case of a 1xn / nx1 shape (False means cols get priority)
    :return: %(ax_out)s
    """
    _ndim = len(ax.shape) if isinstance(ax, np.ndarray) else 0

    if _ndim == 0:
        return ax
    if _ndim == 1:
        return ax[row if rows_prio else col]
    return ax[row, col]
@docstr
@export
def ax_as_list(ax: Union[plt.Axes, np.ndarray]) -> list:
    """
    flattens a single Axes or an array of Axes into a flat list (row by row for 2D arrays)

    :param ax: %(ax_in)s
    :return: List containing the subaxes
    """
    if isinstance(ax, np.ndarray) and ax.ndim > 0:
        return list(ax) if ax.ndim == 1 else list(ax.flatten())
    return [ax]
@docstr
@export
def ax_as_array(ax: Union[plt.Axes, np.ndarray]) -> np.ndarray:
    """
    turns a single Axes or an array of Axes into a numpy 2D array; 2D arrays are returned unchanged, anything else
    becomes a single column

    :param ax: %(ax_in)s
    :return: Numpy 2D array containing the subaxes
    """
    if not isinstance(ax, np.ndarray):
        return np.array([ax]).reshape(-1, 1)
    return ax if len(ax.shape) == 2 else ax.reshape(-1, 1)
# bubble plot
def bubbleplot(x, y, hue, s, text=None, text_as_label=False, data=None, s_factor=250, palette=None, hue_order=None,
               x_range_factor=5, y_range_factor=5, show_std=False, ax=None, legend_loc='right', text_kws=None):
    """
    Draw one bubble per row of data at (x, y) whose marker size is proportional to column s.

    :param x: name of the x column
    :param y: name of the y column
    :param hue: name of the column identifying the rows; used for Mapping palettes, hue_order and as title
    :param s: name of the size column, multiplied by s_factor to get the marker size
    :param text: name of a column with text to write on each bubble (or to use as legend label) [optional]
    :param text_as_label: if True the text is used for legend entries instead of being written on the bubbles
    :param data: pandas DataFrame containing the columns
    :param s_factor: factor applied to s to get the marker size
    :param palette: a Mapping hue level -> color, a list of colors (assigned row by row, cycling if needed) or a
        single color; defaults to rcParams['palette']
    :param hue_order: order of the rows by hue level; must contain every hue value [optional]
    :param x_range_factor: the x limits are padded by (x range) / x_range_factor on each side
    :param y_range_factor: the y limits are padded by (y range) / y_range_factor on each side
    :param show_std: if True, draw a dotted ellipse of +/- 1 std per row instead of the bubbles
        (requires the columns x + '_std' and y + '_std')
    :param ax: axes to plot on, defaults to the current axes
    :param legend_loc: legend location; with text_as_label 'bottom' / 'right' place the legend outside the axes
    :param text_kws: other keyword arguments passed to ax.text
    :return: the axes plotted on
    """
    if palette is None:
        palette = rcParams['palette']
    if text_kws is None:
        text_kws = {}
    if ax is None:
        ax = plt.gca()

    _df = data.copy()
    # drop rows where position or size is missing
    _df = _df[~(_df[x].isnull() | _df[y].isnull() | _df[s].isnull())].reset_index(drop=True)

    if hue_order is not None:
        _df['_sort'] = _df[hue].apply(lambda _: hue_order.index(_))
        _df = _df.sort_values(by=['_sort']).reset_index(drop=True)

    _x = _df[x]
    _y = _df[y]
    _s = _df[s] * s_factor
    _text = _df[text] if text is not None else pd.Series(dtype=object)

    if isinstance(palette, Mapping):
        _df['_color'] = _df[hue].apply(lambda _: palette[_])
    elif is_list_like(palette):
        # one color per row in row order, cycling if there are more rows than colors
        _df['_color'] = [palette[_i % len(palette)] for _i in range(len(_df))]
    else:
        # a single color for all bubbles
        _df['_color'] = palette

    if show_std:
        # draw an ellipse per row to mark its 1 sigma area
        for _, _row in _df.iterrows():
            _ellipse = patches.Ellipse(xy=(_row[x], _row[y]), width=_row[x + '_std'] * 2,
                                       height=_row[y + '_std'] * 2, edgecolor=_row['_color'], fc='None', lw=2,
                                       ls=':')
            ax.add_patch(_ellipse)
        # limits: outermost ellipse edges plus a small margin (1.05 times the std)
        _x_min = (_df[x] - _df[x + '_std'] * 1.05).min()
        _x_max = (_df[x] + _df[x + '_std'] * 1.05).max()
        _y_min = (_df[y] - _df[y + '_std'] * 1.05).min()
        _y_max = (_df[y] + _df[y + '_std'] * 1.05).max()
    else:
        # scatter for bubbles
        ax.scatter(x=_x, y=_y, s=_s, label='__nolegend__', facecolor=_df['_color'], edgecolor='black', alpha=.75)
        _x_range = _x.max() - _x.min()
        _x_min = _x.min() - _x_range / x_range_factor
        _x_max = _x.max() + _x_range / x_range_factor
        _y_range = _y.max() - _y.min()
        _y_min = _y.min() - _y_range / y_range_factor
        _y_max = _y.max() + _y_range / y_range_factor

    # plot fake data for legend (a little hacky): points at x=-9999 are outside the restored x limits
    if text_as_label:
        _xlim_before = ax.get_xlim()
        for _it in range(len(_x)):
            ax.scatter(x=-9999, y=_y.iloc[_it], label=_text.iloc[_it], facecolor=_df['_color'].iloc[_it], s=200,
                       edgecolor='black', alpha=.75)
        ax.set_xlim(_xlim_before)

    if (text is not None) and (not text_as_label):
        for _x_i, _y_i, _text_i in zip(_x, _y, _text):
            if pd.notnull(_x_i) and pd.notnull(_y_i):
                ax.text(x=_x_i, y=_y_i, s=_text_i, horizontalalignment='center', verticalalignment='center',
                        **text_kws)

    ax.set_xlim(_x_min, _x_max)
    ax.set_ylim(_y_min, _y_max)

    ax.set_xlabel(_x.name)
    ax.set_ylabel(_y.name)

    if text_as_label and (legend_loc in ['bottom', 'right']):
        legend_outside(ax, loc=legend_loc)
    else:
        ax.legend(loc=legend_loc)

    ax.set_title(hue)

    return ax


def bubblecountplot(x, y, hue, data, agg_function='median', show_std=True, top_nr=None, n_quantiles=10,
                    other_name='other', dropna=True, float_format='.2f', text_end='', **kwargs):
    """
    Aggregate x and y per level of hue and draw the result with bubbleplot, the bubble size being the share
    of rows belonging to each level.

    :param x: name of the x column
    :param y: name of the y column
    :param hue: name of the column to aggregate by; numeric columns are first split into quantiles
    :param data: pandas DataFrame containing the columns
    :param agg_function: aggregation applied to x and y per hue level (also used to fill their missing values)
    :param show_std: passed to bubbleplot; also switches the bubble text to 'aggregate +/- std'
    :param top_nr: if given, only the top_nr hue levels are kept, the rest is coded as other_name
    :param n_quantiles: number of quantiles used to split a numeric hue
    :param other_name: name of the level collecting the levels outside the top_nr
    :param dropna: whether to drop rows where hue is missing
    :param float_format: format string for the numbers in the bubble text
    :param text_end: string appended to each bubble text
    :param kwargs: other keyword arguments passed to bubbleplot
    :return: the axes returned by bubbleplot
    """
    _df = data.copy()

    if dropna:
        _df = _df[~_df[hue].isnull()]

    # numeric hues are split into quantiles
    if hue in _df.select_dtypes(include=np.number):
        _n = n_quantiles
        if top_nr is not None and top_nr < n_quantiles:
            _n = top_nr
        _df[hue] = quantile_split(_df[hue], _n)

    if top_nr is not None:
        _df[hue] = top_n_coding(_df[hue], n=top_nr, other_name=other_name)

    # handle na
    _df[x] = _df[x].fillna(_df[x].dropna().agg(agg_function))
    _df[y] = _df[y].fillna(_df[y].dropna().agg(agg_function))

    # aggregate per hue level
    _df['_count'] = 1
    _df = _df.groupby([hue]).agg({x: [agg_function, 'std'], y: [agg_function, 'std'],
                                  '_count': 'count'}).reset_index()
    # if x == y the aggregation dict holds only one entry for it
    if x != y:
        _df.columns = [hue, x, x + '_std', y, y + '_std', '_count']
    else:
        _df.columns = [hue, x, x + '_std', '_count']

    _df['_perc'] = _df['_count'] / _df['_count'].sum() * 100

    _df['_count_text'] = _df['_count'].apply('{:,}'.format)
    _df['_perc_text'] = np.round(_df['_perc'], 2)
    _df['_perc_text'] = _df['_perc_text'].astype(str) + '%'

    if show_std:
        _df['_text'] = _df[hue].astype(str) + '(' + _df['_count_text'] + ')' + '\n' \
                       + 'x:' + _df[x].apply(lambda _: format(_, float_format)) + r'$\pm$' \
                       + _df[x + '_std'].apply(lambda _: format(_, float_format)) + '\n' \
                       + 'y:' + _df[y].apply(lambda _: format(_, float_format)) + r'$\pm$' \
                       + _df[y + '_std'].apply(lambda _: format(_, float_format))
    else:
        _df['_text'] = _df[hue].astype(str) + '\n' + _df['_count_text'] + '\n' + _df['_perc_text']

    _df['_text'] += text_end

    return bubbleplot(x=x, y=y, hue=hue, s='_perc', text='_text', data=_df, show_std=show_std, **kwargs)
@docstr
@export
def rmsdplot(x: str, data: pd.DataFrame, groups: Union[Sequence, str] = None, hue: str = None,
             hue_order: Union[Sequence, str] = None, cutoff: float = 0, ax: plt.Axes = None,
             color_as_balance: bool = False, balance_cutoff: float = None, rmsd_as_alpha: bool = False,
             sort_by_hue: bool = False, palette=None, barh_kws=None, **kwargs):
    """
    creates a horizontal barplot showing the rmsd calculated by hhpy.ds.df_rmsd

    :param x: %(x)s
    :param data: %(data)s
    :param groups: the columns to calculate the rmsd for, defaults to all columns [optional]
    :param hue: %(hue)s
    :param hue_order: %(order)s
    :param cutoff: drop rmsd values smaller than cutoff [optional]
    :param ax: %(ax_in)s
    :param color_as_balance: whether to color the bars based on how balanced the levels are [optional]
    :param balance_cutoff: if specified the balance coloring red for worse balance than balance cutoff [optional]
    :param rmsd_as_alpha: whether to use set the alpha values of the columns based on the rmsd value [optional]
    :param sort_by_hue: passed to hhpy.ds.df_rmsd [optional]
    :param palette: %(palette)s
    :param barh_kws: other keyword arguments passed to pyplot.barh [optional]
    :param kwargs: other keyword arguments passed to hhpy.ds.df_rmsd [optional]
    :return: %(ax_out)s
    """
    if palette is None:
        palette = rcParams['palette']
    if barh_kws is None:
        barh_kws = {}

    _data = data.copy()
    del data

    if hue is not None and hue_order is not None:
        _data = _data.query('{} in @hue_order'.format(hue))

    _df_rmsd = df_rmsd(x=x, df=_data, groups=groups, hue=hue, sort_by_hue=sort_by_hue, **kwargs)
    # reset the index so the bar positions below stay contiguous after dropping rows
    _df_rmsd = _df_rmsd[_df_rmsd['rmsd'] >= cutoff].reset_index(drop=True)

    if hue is not None:
        _df_rmsd_no_hue = df_rmsd(x=x, df=_data, groups=groups, include_rmsd=False, **kwargs)
    else:
        _df_rmsd_no_hue = pd.DataFrame()

    if is_list_like(x):
        if hue is None:
            _df_rmsd['label'] = concat_cols(_df_rmsd, ['x', 'group'], sep=' X ')
        else:
            _df_rmsd['label'] = concat_cols(_df_rmsd, ['x', 'group', hue], sep=' X ')
    else:
        if hue is None:
            _df_rmsd['label'] = _df_rmsd['group']
        else:
            _df_rmsd['label'] = concat_cols(_df_rmsd, ['group', hue], sep=' X ')

    _df_rmsd['rmsd_scaled'] = _df_rmsd['rmsd'] / _df_rmsd['rmsd'].max()

    # get colors (rgba rows, one per bar)
    _rgba_colors = np.zeros((len(_df_rmsd), 4))
    _hues = []

    if hue is not None:
        _hues = _get_ordered_levels(data=_df_rmsd, level=hue, order=hue_order, x=x)
        if isinstance(palette, Mapping):
            _df_rmsd['_color'] = _df_rmsd[hue].apply(lambda _: palette[_])
        elif is_list_like(palette):
            _hue_list = list(_hues)
            _df_rmsd['_color'] = _df_rmsd[hue].apply(lambda _: palette[_hue_list.index(_)])
        else:
            # a single color for all bars
            _df_rmsd['_color'] = palette

        _rgba_colors[:, 0] = _df_rmsd['_color'].apply(lambda _: Color(_).red)
        _rgba_colors[:, 1] = _df_rmsd['_color'].apply(lambda _: Color(_).green)
        _rgba_colors[:, 2] = _df_rmsd['_color'].apply(lambda _: Color(_).blue)
    elif color_as_balance:
        if balance_cutoff is None:
            _rgba_colors[:, 0] = _df_rmsd['maxperc']  # red share grows with maxperc
            _rgba_colors[:, 2] = 1 - _df_rmsd['maxperc']  # blue share shrinks with maxperc
        else:
            _rgba_colors[:, 0] = np.where(_df_rmsd['maxperc'] >= balance_cutoff, 1, 0)
            _rgba_colors[:, 2] = np.where(_df_rmsd['maxperc'] < balance_cutoff, 1, 0)
    else:
        _rgba_colors[:, 2] = 1  # plain blue

    # the fourth column holds the alphas
    if rmsd_as_alpha:
        _rgba_colors[:, 3] = _df_rmsd['rmsd_scaled']
    else:
        _rgba_colors[:, 3] = 1

    if ax is None:
        ax = plt.gca()

    # make positions from labels; with hue, bars are narrower and every change of group leaves a gap
    _pos_factor = .8 if hue is not None else 1
    _pos = np.arange(len(_df_rmsd)) * _pos_factor
    if (hue is not None) and (not sort_by_hue):
        _groups = _df_rmsd['group'].tolist()
        _gap = 0
        for _i in range(1, len(_groups)):
            if _groups[_i] != _groups[_i - 1]:
                _gap += _pos_factor
            _pos[_i] += _gap
    _df_rmsd['pos'] = _pos

    if (hue is not None) and (not sort_by_hue):
        # average position per group for the tick labels, with maxperc from the global (no hue) values
        _df_ticks = _df_rmsd.groupby('group').agg({'pos': 'mean'}).reset_index()
        _df_ticks = pd.merge(_df_ticks, _df_rmsd_no_hue[['group', 'maxperc']])
    else:
        _df_ticks = pd.DataFrame()

    ax.barh(_df_rmsd['pos'], _df_rmsd['rmsd'], color=_rgba_colors, **barh_kws)

    _y_colors = None
    if (hue is not None) and (not sort_by_hue):
        _y_pos = _df_ticks['pos']
        _y_lab = _df_ticks['group']
        if balance_cutoff is not None:
            _y_colors = np.where(_df_ticks['maxperc'] > balance_cutoff, sns.xkcd_rgb['red'], 'k')
    else:
        _y_pos = _df_rmsd['pos']
        if not is_list_like(x):
            _y_lab = _df_rmsd['group']
        elif not is_list_like(groups):
            _y_lab = _df_rmsd['x']
        else:
            _y_lab = concat_cols(_df_rmsd, ['x', 'group'], sep=' X ')

    ax.set_yticks(_y_pos)
    ax.set_yticklabels(_y_lab)

    if _y_colors is not None:
        for _y_tick, _color in zip(ax.get_yticklabels(), _y_colors):
            _y_tick.set_color(_color)

    _offset = _pos_factor if hue is None else _pos_factor * len(_hues)
    # noinspection PyTypeChecker
    ax.set_ylim([_y_pos.min() - _offset, _y_pos.max() + _offset])
    ax.invert_yaxis()

    # create legend for hues
    if hue is not None:
        _patches = []
        for _hue, _color, _count in _df_rmsd[[hue, '_color', 'count']].drop_duplicates().values:
            _patches.append(patches.Patch(color=_color, label='{} (n={:,})'.format(_hue, _count)))
        ax.legend(handles=_patches)

    # check if standardized
    _x_label_suffix = ' [std]' if kwargs.get('standardize') else ''

    if not is_list_like(x):
        ax.set_title('Root Mean Square Difference for {}'.format(x))
        ax.set_xlabel('RMSD: {}{}'.format(x, _x_label_suffix))
    elif not is_list_like(groups):
        ax.set_title('Root Mean Square Difference for {}'.format(groups))
        ax.set_xlabel('RMSD: {}{}'.format(groups, _x_label_suffix))
    else:
        ax.set_title('Root Mean Square Difference')

    return ax
# plot agg
def aggplot(x, data, group, hue=None, hue_order=None, width=16, height=9 / 2, p_1_0=True, palette=None,
            sort_by_hue=False, return_fig_ax=rcParams['return_fig_ax'], agg=None, p=False, legend_loc='upper right',
            aggkws=None, subplots_kws=None, subplots_adjust_kws=None, **kwargs):
    """
    Plot aggregated statistics of x per level of group (and optionally hue) as horizontal bar charts, one subplot
    per statistic, based on hhpy.ds.df_agg. Either x or group may be a list; each element then gets its own
    column of subplots.

    :param x: name (or list of names) of the column(s) to aggregate
    :param data: pandas DataFrame containing the columns
    :param group: name (or list of names) of the column(s) to group by
    :param hue: name of a column to further split each group by [optional]
    :param hue_order: order of the hue levels [optional]
    :param width: width of each subplot
    :param height: height of each subplot
    :param p_1_0: whether to fix the x limits of the 'p' subplot to [0, 1]
    :param palette: a Mapping hue level -> color, a list of colors or a single color; defaults to
        rcParams['palette']
    :param sort_by_hue: whether to order the bars by hue first instead of by group first
    :param return_fig_ax: whether to return (fig, ax); otherwise the figure is shown
    :param agg: list of aggregations passed to df_agg, defaults to ['mean', 'median', 'std']
    :param p: passed to df_agg; if True one additional subplot is reserved for it
    :param legend_loc: legend location when no hue is given
    :param aggkws: other keyword arguments passed to df_agg
    :param subplots_kws: other keyword arguments passed to pyplot.subplots
    :param subplots_adjust_kws: keyword arguments passed to pyplot.subplots_adjust
    :param kwargs: other keyword arguments passed to Axes.barh
    :return: (fig, ax) if return_fig_ax else None
    """
    if palette is None:
        palette = rcParams['palette']
    if agg is None:
        agg = ['mean', 'median', 'std']
    if aggkws is None:
        aggkws = {}
    if subplots_kws is None:
        subplots_kws = {}
    if subplots_adjust_kws is None:
        subplots_adjust_kws = {'top': .95, 'hspace': .25, 'wspace': .35}

    # avoid inplace operations
    _df = data.copy()
    # number of plots per x / group: the agg columns, one more column from df_agg and the optional p column
    _len = len(agg) + 1 + p

    _x = x
    _group = group

    # EITHER x OR group can be a list (hue cannot be a list)
    if is_list_like(x) and is_list_like(group):
        warnings.warn('both x and group cannot be a list, setting group = {}'.format(group[0]))
        _x_is_list = True
        _group_is_list = False
        _group = group[0]
        _ncols = len(x)
        _nrows = _len
    elif is_list_like(x):
        _x_is_list = True
        _group_is_list = False
        _ncols = len(x)
        _nrows = _len
    elif is_list_like(group):
        _x_is_list = False
        _group_is_list = True
        _ncols = len(group)
        _nrows = _len
    else:
        _x_is_list = False
        _group_is_list = False
        # roughly square grid; add rows if it could not hold all _len plots (e.g. _len of 2 or 3)
        _ncols = max(1, int(np.floor(_len / 2)))
        _nrows = max(int(np.ceil(_len / 2)), int(np.ceil(_len / _ncols)))

    fig, ax = plt.subplots(figsize=(width * _ncols, height * _nrows), nrows=_nrows, ncols=_ncols, **subplots_kws)
    # a single row / column is squeezed into a 1d array by plt.subplots
    _rows_prio = _ncols == 1

    _it = -1
    for _col in range(_ncols):

        if _x_is_list:
            _x = x[_col]
        if _group_is_list:
            _group = group[_col]

        _df_agg = df_agg(x=_x, group=_group, hue=hue, df=_df, agg=agg, p=p, **aggkws)

        if hue is not None:
            _sort_by = [hue, _group] if sort_by_hue else [_group, hue]
            _df_agg = _df_agg.sort_values(by=_sort_by).reset_index(drop=True)
            _df_agg['_label'] = concat_cols(_df_agg, [_group, hue], sep='_').astype('category')
            _hues = _get_ordered_levels(data=_df, level=hue, order=hue_order, x=x)

            if isinstance(palette, Mapping):
                _df_agg['_color'] = _df_agg[hue].apply(lambda _: palette[_])
            elif is_list_like(palette):
                _hue_list = list(_hues)
                _df_agg['_color'] = _df_agg[hue].apply(lambda _: palette[_hue_list.index(_)])
            else:
                _df_agg['_color'] = palette

        for _row in range(_nrows):

            _it += 1

            # in list mode each row is one statistic, otherwise the grid is filled column by column
            _index = _row if (_x_is_list or _group_is_list) else _it

            _ax = get_subax(ax, _row, _col, rows_prio=_rows_prio)

            if _index >= _len:
                _ax.set_axis_off()
                continue

            # the first column of _df_agg is skipped
            _agg = list(_df_agg)[1:][_index]

            # one color per graph (if no hue)
            if hue is None:
                _df_agg['_color'] = palette[_index]

            # handle hue grouping: every change of group leaves a gap of _pos_factor
            _pos_factor = .8 if hue is not None else 1
            _df_agg['pos'] = _df_agg.index

            if (hue is not None) and (not sort_by_hue):
                _groups = _df_agg[_group].tolist()
                _offsets = [0.0] * len(_groups)
                for _i in range(1, len(_groups)):
                    _offsets[_i] = _offsets[_i - 1] + (_pos_factor if _groups[_i] != _groups[_i - 1] else 0)
                _df_agg['pos'] = _df_agg['pos'] + _offsets
                # average position of each group for the tick labels
                _df_ticks = _df_agg.groupby(_group).agg({'pos': 'mean'}).reset_index()
            else:
                _df_ticks = pd.DataFrame()

            _ax.barh('pos', _agg, color='_color', label=_agg, data=_df_agg, **kwargs)

            if (hue is not None) and (not sort_by_hue):
                _ax.set_yticks(_df_ticks['pos'])
                _ax.set_yticklabels(_df_ticks[_group])
            else:
                _ax.set_yticks(_df_agg['pos'])
                _ax.set_yticklabels(_df_agg[_group])

            _ax.invert_yaxis()
            _ax.set_xlabel(_x + '_' + _agg)
            _ax.set_ylabel(_group)

            # create legend for hues
            if hue is not None:
                _patches = []
                for _hue, _color in _df_agg[[hue, '_color']].drop_duplicates().values:
                    _patches.append(patches.Patch(color=_color, label=_hue))
                _ax.legend(handles=_patches)
            else:
                _ax.legend(loc=legend_loc)

            # range of p is between 0 and 1
            if _agg == 'p' and p_1_0:
                # noinspection PyTypeChecker
                _ax.set_xlim([0, 1])

    _x_title = ','.join(x) if _x_is_list else _x
    _group_title = ','.join(group) if _group_is_list else _group

    _title = _x_title + ' by ' + _group_title
    if hue is not None:
        _title = _title + ' per ' + hue

    plt.suptitle(_title, size=16)
    plt.subplots_adjust(**subplots_adjust_kws)

    if return_fig_ax:
        return fig, ax
    else:
        plt.show()


def aggplot2d(x, y, data, aggfunc='mean', ax=None, x_int=None, time_int=None, color=rcParams['palette'][0],
              as_abs=False):
    """
    Plot an aggregate of y per value of x as a line, with a band of +/- 1 std around it.

    :param x: name of the x column
    :param y: name of the y column
    :param data: pandas DataFrame containing the columns
    :param aggfunc: aggregation applied to y per x value
    :param ax: axes to plot on, defaults to the current axes
    :param x_int: if given, x is rounded to multiples of x_int before aggregating [optional]
    :param time_int: numpy datetime unit (e.g. 'D', see the numpy datetime documentation); x is cast to
        '<M8[time_int]' before aggregating [optional]
    :param color: color of the aggregate line
    :param as_abs: whether to aggregate the absolute values of y
    :return: the axes plotted on
    """
    _y_agg = '{}_{}'.format(y, aggfunc)
    _y_std = '{}_std'.format(y)

    # preprocessing
    _df = data.copy()

    if as_abs:
        _df[y] = np.abs(_df[y])
    if x_int is not None:
        _df[x] = np.round(_df[x] / x_int) * x_int
    if time_int is not None:
        _df[x] = _df[x].astype('<M8[{}]'.format(time_int))

    # agg; the flat column names are assigned directly (set_axis(..., inplace=...) was removed in pandas 2.0)
    _df = _df.groupby([x]).agg({y: [aggfunc, 'std']})
    _df.columns = [_y_agg, _y_std]
    _df = _df.reset_index()

    if ax is None:
        ax = plt.gca()

    ax.plot(_df[x], _df[_y_agg], color=color, label=_y_agg)
    ax.fill_between(_df[x], _df[_y_agg] + _df[_y_std], _df[_y_agg] - _df[_y_std], color='xkcd:cyan', label=_y_std)

    ax.set_xlabel(x)
    ax.set_ylabel(y)

    ax.legend()

    return ax
@export
def insert_linebreak(s: str, pos: int = None, frac: float = None, max_breaks: int = None) -> str:
    """
    used to insert linebreaks in strings, useful for formatting axes labels

    :param s: string to insert linebreaks into
    :param pos: inserts a linebreak every pos characters [optional]
    :param frac: splits the string into at most frac chunks of equal length, i.e. inserts a linebreak every
        ceil(len(s) / frac) characters; only used if pos is None [optional]
    :param max_breaks: maximum number of linebreaks to insert [optional]
    :return: string with the linebreaks inserted (a single trailing newline is removed);
        s is returned unchanged if neither pos nor frac is given
    """
    _s = s + ''  # also ensures s is a str

    if pos is not None:
        _pos = int(pos)
    elif frac is not None:
        if not _s:
            return _s
        _pos = int(np.ceil(len(_s) / frac))
    else:
        return _s

    if _pos < 1 or not _s:
        return _s

    # chunks of _pos characters, joined by linebreaks
    _chunks = [_s[_i:_i + _pos] for _i in range(0, len(_s), _pos)]

    # keep only the first max_breaks breaks, the remainder stays in one piece
    if max_breaks is not None and len(_chunks) - 1 > max_breaks:
        _n = max(int(max_breaks), 0)
        _chunks = _chunks[:_n] + [_s[_n * _pos:]]

    _result = '\n'.join(_chunks)

    # remove trailing newline
    if _result[-1:] == '\n':
        _result = _result[:-1]

    return _result
@docstr
@export
def ax_tick_linebreaks(ax: plt.Axes = None, x: bool = True, y: bool = True, **kwargs) -> None:
    """
    uses insert_linebreak to insert linebreaks into the tick labels of an axes

    :param ax: %(ax_in)s
    :param x: whether to insert linebreaks into the x tick labels [optional]
    :param y: whether to insert linebreaks into the y tick labels [optional]
    :param kwargs: other keyword arguments passed to insert_linebreak
    :return: None
    """
    _target = ax if ax is not None else plt.gca()

    _axis_handlers = [
        (x, _target.get_xticklabels, _target.set_xticklabels),
        (y, _target.get_yticklabels, _target.set_yticklabels),
    ]
    for _enabled, _get_labels, _set_labels in _axis_handlers:
        if _enabled:
            _set_labels([insert_linebreak(_label.get_text(), **kwargs) for _label in _get_labels()])
@docstr
@export
def annotate_barplot(ax: plt.Axes = None, x: Sequence = None, y: Sequence = None, ci: bool = True,
                     ci_newline: bool = True, adj_ylim: float = .05, nr_format: str = rcParams['float_format'],
                     ha: str = 'center', va: str = 'center', offset: int = plt.rcParams['font.size'],
                     **kwargs) -> plt.Axes:
    """
    automatically annotates a barplot with bar values and error bars (if present). Currently does not work with
    ticks!

    :param ax: %(ax_in)s
    :param x: %(x)s
    :param y: %(y)s
    :param ci: whether to annotate error bars [optional]
    :param ci_newline: whether to add a newline between values and error bar values [optional]
    :param adj_ylim: fraction of the y range added to the y limits (on the side of positive / negative bars) to
        fit the annotations; a falsy value disables the adjustment [optional]
    :param nr_format: %(number_format)s
    :param ha: horizontal alignment [optional]
    :param va: vertical alignment [optional]
    :param offset: offset between bar top and annotation center [optional]
    :param kwargs: other keyword arguments passed to pyplot.annotate
    :return: %(ax_out)s
    """
    # silence font warnings while annotating; the previous root logger level is restored afterwards
    _root_logger = logging.getLogger()
    _previous_level = _root_logger.level
    _root_logger.setLevel(logging.CRITICAL)

    try:
        if ax is None:
            ax = plt.gca()

        _adj_plus = False
        _adj_minus = False

        if ci_newline:
            _ci_sep = '\n'
            _offset = offset + 5
        else:
            _ci_sep = ''
            _offset = offset

        for _it, _patch in enumerate(ax.patches):
            try:
                if x is None:
                    _x = _patch.get_x() + _patch.get_width() / 2.
                elif is_list_like(x):
                    _x = x[_it]
                else:
                    _x = x

                if y is None:
                    _y = _patch.get_height()
                elif is_list_like(y):
                    _y = y[_it]
                else:
                    _y = y

                _val = _patch.get_height()
                if _val > 0:
                    _adj_plus = True
                if _val < 0:
                    _adj_minus = True
                if np.isnan(_val):
                    continue

                _val_text = format(_val, nr_format)
                _annotate = r'${}$'.format(_val_text)

                # TODO: HANDLE CAPS
                # NOTE(review): assumes the i-th line on the axes is the error bar of the i-th patch - confirm
                if ci and len(ax.lines) > _it:
                    _line_y = ax.lines[_it].get_xydata()[:, 1]
                    _ci = (_line_y[1] - _line_y[0]) / 2
                    if not np.isnan(_ci):
                        _ci_text = format(_ci, nr_format)
                        _annotate = r'${}$'.format(_val_text) + _ci_sep + r'$\pm{}$'.format(_ci_text)

                # place the text above positive and below negative bars
                ax.annotate(_annotate, (_x, _y), ha=ha, va=va, xytext=(0, np.sign(_val) * _offset),
                            textcoords='offset points', **kwargs)
            except Exception as _exc:
                # best effort: skip bars that cannot be annotated
                warnings.warn('could not annotate bar {}: {}'.format(_it, _exc))

        if adj_ylim:
            _ylim = list(ax.get_ylim())
            _y_adj = (_ylim[1] - _ylim[0]) * adj_ylim
            if _adj_minus:
                _ylim[0] = _ylim[0] - _y_adj
            if _adj_plus:
                _ylim[1] = _ylim[1] + _y_adj
            # noinspection PyTypeChecker
            ax.set_ylim(_ylim)
    finally:
        _root_logger.setLevel(_previous_level)

    return ax
@docstr
@export
def animplot(data: pd.DataFrame = None, x: str = 'x', y: str = 'y', t: str = 't', lines: Mapping = None,
             max_interval: int = None, time_per_frame: int = 200, mode: str = rcParams['animplot.mode'],
             title: bool = True, title_prefix: str = '', t_format: str = None, fig: plt.Figure = None,
             ax: plt.Axes = None, color: str = None, label: str = None, legend: bool = False,
             legend_out: bool = False, legend_kws: Mapping = None, xlim: tuple = None, ylim: tuple = None,
             ax_facecolor: Union[str, Mapping] = None, grid: bool = False, vline: Union[Sequence, float] = None,
             **kwargs) -> Union[HTML, FuncAnimation]:
    """
    wrapper for FuncAnimation to be used with pandas DataFrames. Assumes that you have a DataFrame containing
    one data point for each x-y-t combination.

    If mode is set to jshtml the function is optimized for use with Jupyter Notebook and returns an
    Interactive JavaScript Widget.

    :param data: %(data)s
    :param x: %(x_novec)s
    :param y: %(y_novec)s
    :param t: %(t_novec)s
    :param lines: you can also pass lines that you want to animate as a list of dicts. Recognized keys are
        'line', 'ax', 'data', 'x', 'y', 't', 'label', 'legend', 'vline', 'title' and 'title_prefix'; missing
        'data' / 'x' / 'y' / 't' fall back to the function arguments and any other key is passed to pyplot.plot
        when the line has to be created [optional]
    :param max_interval: max interval at which to abort the animation [optional]
    :param time_per_frame: time per frame [optional]
    :param mode: one of the below [optional]

        * ``matplotlib``: Return the matplotlib FuncAnimation object
        * ``html``: Returns an HTML5 movie (You need to install ffmpeg for this to work)
        * ``jshtml``: Returns an interactive Javascript Widget

    :param title: whether to set the time as plot title [optional]
    :param title_prefix: title prefix to be put in front of the time if title is true [optional]
    :param t_format: format string used to format the time variable in the title [optional]
    :param fig: figure to plot on [optional]
    :param ax: axes to plot on [optional]
    :param color: %(color)s
    :param label: %(label)s
    :param legend: %(legend)s
    :param legend_out: %(legend_out)s
    :param legend_kws: %(legend_kws)s
    :param xlim: %(xlim)s
    :param ylim: %(ylim)s
    :param ax_facecolor: passed to ax.set_facecolor, can also be a conditional mapping to change the facecolor at
        specific timepoints t [optional]
    :param grid: %(grid)s
    :param vline: %(vline)s
    :param kwargs: other keyword arguments passed to pyplot.plot
    :return: see mode
    :raises ImportError: if mode is 'html' or 'jshtml' and IPython could not be imported

    **examples**

    Check out the `example notebook <https://drive.google.com/open?id=1hJRfZn3Zwnc1n4cK7h2-UPSEj4BmsxhY>`_
    """
    # example for lines (a list of dicts)
    # lines = [{'line':line,'data':data,'x':'x','y':'y','t':'t'}]
    if legend_kws is None:
        legend_kws = {}
    # fail before building the animation: HTML is None when IPython could not be imported
    if mode in ['html', 'jshtml'] and HTML is None:
        raise ImportError('animplot: mode {} requires IPython'.format(mode))

    _args = {'data': data, 'x': x, 'y': y, 't': t}
    # line dict keys consumed by animplot itself, never forwarded to pyplot.plot
    _reserved_keys = ['ax', 'line', 'ts', 'data', 'x', 'y', 't', 'legend', 'vline', 'title', 'title_prefix']

    # init fig,ax
    if fig is None:
        fig = plt.gcf()
    if ax is None:
        ax = plt.gca()
    _ax_list = ax_as_list(ax)

    # time points of all lines, combined with one pd.concat once all lines are known
    # (Series.append was removed in pandas 2.0)
    _ts_parts = []

    # init lines
    if lines is None:
        _ax = _ax_list[0]
        lines = []
        _len = 1
        if is_list_like(x):
            _len = np.max([_len, len(x)])
        if is_list_like(y):
            _len = np.max([_len, len(y)])
        for _it in range(_len):
            _x = x[_it] if is_list_like(x) else x
            _y = y[_it] if is_list_like(y) else y
            _vline = vline[_it] if is_list_like(vline) else vline
            if isinstance(color, Mapping):
                _color = color[_y] if _y in color.keys() else None
            elif is_list_like(color):
                _color = color[_it]
            else:
                _color = color
            _kwargs = deepcopy(kwargs)
            _kwargs_keys = list(_kwargs.keys())
            # defaults: markers take the line color unless the caller set them (long or short alias)
            if len(list_intersection(['markerfacecolor', 'mfc'], _kwargs_keys)) == 0:
                _kwargs['markerfacecolor'] = _color
            if len(list_intersection(['markeredgecolor', 'mec'], _kwargs_keys)) == 0:
                _kwargs['markeredgecolor'] = _color
            if len(list_intersection(['markeredgewidth', 'mew'], _kwargs_keys)) == 0:
                _kwargs['markeredgewidth'] = 1
            if label is None:
                _label = _y
            elif isinstance(label, Mapping):
                _label = label[_y]
            elif is_list_like(label):
                _label = label[_it]
            else:
                _label = label
            lines += [{
                'line': _ax.plot([], [], label=_label, color=_color, **_kwargs)[0],
                'ax': _ax,
                'data': data,
                'x': _x,
                'y': _y,
                't': t,
                'vline': _vline,
                'title': title,
                'title_prefix': title_prefix,
            }]
        _ts_parts.append(pd.Series(data[t].unique()))

    for _line in lines:
        _keys = list(_line.keys())
        # default: label = y
        if 'label' not in _keys:
            if 'y' in _keys:
                _line['label'] = _line['y']
            elif y is not None:
                _line['label'] = y
        # update keys
        _keys = list(_line.keys())
        # plot kws: the line's own non-reserved keys, plus function kwargs the line does not override
        _line_kw_keys = [_key for _key in _keys if _key not in _reserved_keys]
        _line_kws = {_key: _line[_key] for _key in _line_kw_keys}
        for _kw_key in kwargs.keys():
            if _kw_key not in _line_kw_keys:
                _line_kws[_kw_key] = kwargs[_kw_key]
        if 'ax' not in _keys:
            _line['ax'] = _ax_list[0]
        if 'line' not in _keys:
            _line['line'] = _line['ax'].plot([], [], **_line_kws)[0]
        if is_list_like(_line['line']):
            _line['line'] = _line['line'][0]
        for _arg in _args.keys():
            if _arg not in _keys:
                _line[_arg] = _args[_arg]
        _line['ts'] = _line['data'][_line['t']].drop_duplicates().sort_values().reset_index(drop=True)
        _ts_parts.append(_line['ts'])

    if _ts_parts:
        _ts = pd.concat(_ts_parts, ignore_index=True).drop_duplicates().sort_values().reset_index(drop=True)
    else:
        _ts = pd.Series(dtype=object)

    # get max interval
    _max_interval = _ts.shape[0]
    if max_interval is not None and max_interval < _max_interval:
        _max_interval = max_interval

    # unchanging stuff goes here
    def init():
        for _ax_j in _ax_list:
            _xylim_set = False
            _x_min = None
            _x_max = None
            _y_min = None
            _y_max = None
            _legend = legend
            for _line_j in lines:
                # -- xy lims: union of the data ranges of all lines drawn on this axis
                if _ax_j == _line_j['ax']:
                    _data_j = _line_j['data']
                    _x_min_j = _data_j[_line_j['x']].min()
                    _x_max_j = _data_j[_line_j['x']].max()
                    _y_min_j = _data_j[_line_j['y']].min()
                    _y_max_j = _data_j[_line_j['y']].max()
                    if not _xylim_set:
                        _x_min, _x_max, _y_min, _y_max = _x_min_j, _x_max_j, _y_min_j, _y_max_j
                        _xylim_set = True
                    else:
                        _x_min = min(_x_min, _x_min_j)
                        _y_min = min(_y_min, _y_min_j)
                        _x_max = max(_x_max, _x_max_j)
                        _y_max = max(_y_max, _y_max_j)
                # -- legend: a line may override the global legend flag
                if 'legend' in _line_j.keys():
                    _legend = _line_j['legend']
                # -- vlines
                _vline_j = _line_j.get('vline')
                if _vline_j is not None:
                    if not is_list_like(_vline_j):
                        _vline_j = [_vline_j]
                    for _vline_k in _vline_j:
                        _ax_j.axvline(_vline_k, color='k', linestyle=':')
            # -- legend, drawn once per axis after all line overrides are known
            if _legend:
                if legend_out:
                    legend_outside(_ax_j, width=.995)
                else:
                    _ax_j.legend(**legend_kws)
            # -- lims: explicit limits if truthy, data range if falsy-but-not-None
            if xlim is not None:
                if xlim:
                    _ax_j.set_xlim(xlim)
                else:
                    _ax_j.set_xlim([_x_min, _x_max])
            if ylim is not None:
                if ylim:
                    _ax_j.set_ylim(ylim)
                else:
                    _ax_j.set_ylim([_y_min, _y_max])
            # -- grid
            if grid:
                _ax_j.grid()
            # -- ax facecolor
            if isinstance(ax_facecolor, str):
                _ax_j.set_facecolor(ax_facecolor)
        return ()

    def animate(_i):
        _t = _ts[_i]
        for _line_i in lines:
            _data = _line_i['data']
            _data = _data[_data[_line_i['t']] == _t]
            _line_i['line'].set_data(_data[_line_i['x']], _data[_line_i['y']])
            _ax_i = _line_i['ax'] if 'ax' in _line_i.keys() else plt.gca()
            # -- title
            _title = _line_i.get('title', title)
            _title_prefix = _line_i.get('title_prefix', title_prefix)
            if t_format is not None:
                _t_str = pd.to_datetime(_t).strftime(t_format)
            else:
                _t_str = _t
            if _title:
                _ax_i.set_title('{}{}'.format(_title_prefix, _t_str))
            # -- facecolor
            # NOTE(review): every key that is None or greater than the current t applies (later keys win);
            # confirm whether `_key_i <= _t` was intended
            if isinstance(ax_facecolor, Mapping):
                for _key_i, _ax_facecolor in ax_facecolor.items():
                    if (_key_i is None) or (_key_i > _t):
                        _ax_i.set_facecolor(_ax_facecolor)
        return ()

    _anim = FuncAnimation(fig, animate, init_func=init, frames=_max_interval, interval=time_per_frame, blit=True)
    plt.close('all')

    if mode == 'html':
        return HTML(_anim.to_html5_video())
    elif mode == 'jshtml':
        return HTML(_anim.to_jshtml())
    else:
        return _anim
@docstr
@export
def legend_outside(ax: plt.Axes = None, width: float = .85, loc: str = 'right',
                   legend_space: float = rcParams['legend_outside.legend_space'], offset_x: float = 0,
                   offset_y: float = 0, **kwargs):
    """
    draws a legend outside of the subplot

    :param ax: %(ax_in)s
    :param width: how far to shrink down the subplot if loc=='right'
    :param loc: one of ['right','bottom'], where to put the legend; any other value triggers a warning and a
        regular Axes.legend call with that loc
    :param legend_space: how far below the subplot to put the legend if loc=='bottom'
    :param offset_x: x offset for the legend
    :param offset_y: y offset for the legend
    :param kwargs: other keyword arguments passed to pyplot.legend
    :return: None
    """
    # resolve the default Axes first so the fallback below also works for ax=None
    if ax is None:
        ax = plt.gca()

    # -- init
    if loc not in ['bottom', 'right']:
        warnings.warn('legend_outside: legend loc not recognized')
        for _ax in ax_as_list(ax):
            _ax.legend(loc=loc, **kwargs)
        return None

    _loc = {'bottom': 'upper center', 'right': 'center left'}[loc]
    # anchor below (bottom) or right of (right) the shrunk Axes, shifted by the requested offsets
    _bbox_to_anchor = {'bottom': (0.5 + offset_x, - .15 + offset_y),
                       'right': (1 + offset_x, 0.5 + offset_y)}[loc]

    _root_logger = logging.getLogger()
    for _ax in ax_as_list(ax):
        # -- shrink box to make room for the legend
        _box = _ax.get_position()
        _pos = {
            'bottom': [_box.x0, _box.y0, _box.width, _box.height * (1 - legend_space)],
            'right': [_box.x0, _box.y0, _box.width * width, _box.height]
        }[loc]
        _ax.set_position(_pos)

        # -- legend, only if there is something to show; font warnings are silenced and the caller's
        # logging level is restored afterwards
        _old_level = _root_logger.level
        _root_logger.setLevel(logging.CRITICAL)
        try:
            _, _labels = _ax.get_legend_handles_labels()
            if len(_labels) > 0:
                _ax.legend(loc=_loc, bbox_to_anchor=_bbox_to_anchor, **kwargs)
        finally:
            _root_logger.setLevel(_old_level)

    return None
@docstr
@export
def set_ax_sym(ax: plt.Axes, x: bool = True, y: bool = True):
    """
    make the selected axes of a plot symmetrical around zero

    The new half-range is the largest absolute value among the current limits, so nothing currently in view
    is cut off.

    :param ax: %(ax_in)s
    :param x: whether to set x axis to be symmetrical
    :param y: whether to set y axis to be symmetrical
    :return: None
    """
    _jobs = ((x, ax.get_xlim, ax.set_xlim), (y, ax.get_ylim, ax.set_ylim))
    for _enabled, _get_lim, _set_lim in _jobs:
        if not _enabled:
            continue
        _bound = np.max(np.abs(np.array(_get_lim())))
        # noinspection PyTypeChecker
        _set_lim((-_bound, _bound))
@docstr
@export
def custom_legend(colors: Union[list, str], labels: Union[list, str], do_show=True) -> Union[list, None]:
    """
    build legend handles from colored patches, one per (color, label) pair

    :param colors: list of matplotlib colors to use for the legend
    :param labels: list of labels to use for the legend
    :param do_show: whether to draw the legend on the current Axes right away
    :return: None if do_show, else the list of patch handles
    """
    _pairs = zip(force_list(colors), force_list(labels))
    _handles = [patches.Patch(color=_col, label=_lbl) for _col, _lbl in _pairs]
    if not do_show:
        return _handles
    plt.legend(handles=_handles)
def lcurveplot(train, test, labels=None, legend='upper right', ax=None):
    """
    plot a learning curve: train and test curve, a linear fit of the test curve and its minimum

    :param train: train curve values; a named pandas Series supplies its own label
    :param test: test curve values; a named pandas Series supplies its own label
    :param labels: None (use the curves' names, falling back to 'train' / 'test'), a Mapping with keys 'train'
        and 'test', a list-like of two labels, or a single label used for both curves
    :param legend: legend location string, any other truthy value for an automatically placed legend,
        falsy for no legend
    :param ax: Axes to plot on, defaults to pyplot.gca()
    :return: the Axes
    """
    if labels is None:
        # an unnamed pandas Series has name None, which would break the '_lfit' label below
        _label_train = getattr(train, 'name', None)
        if _label_train is None:
            _label_train = 'train'
        _label_test = getattr(test, 'name', None)
        if _label_test is None:
            _label_test = 'test'
    elif isinstance(labels, Mapping):
        _label_train = labels['train']
        _label_test = labels['test']
    elif is_list_like(labels):
        _label_train = labels[0]
        _label_test = labels[1]
    else:
        _label_train = labels
        _label_test = labels

    if ax is None:
        ax = plt.gca()

    ax.plot(train, color='xkcd:blue', label=_label_train)
    ax.plot(test, color='xkcd:red', label=_label_test)
    ax.plot(lfit(test), color='xkcd:red', ls='--', alpha=.75, label='{}_lfit'.format(_label_test))

    # mark the minimum of the test curve; a Series is plotted against its index, so use the index label there
    if isinstance(test, pd.Series):
        _x_best = test.idxmin()
    else:
        _x_best = np.argmin(test)
    ax.axhline(np.min(test), color='xkcd:red', ls=':', alpha=.5)
    ax.axvline(_x_best, color='xkcd:red', ls=':', alpha=.5)

    if legend:
        _loc = legend if isinstance(legend, str) else None
        ax.legend(loc=_loc)

    return ax


def dic_to_lcurveplot(dic, width=16, height=9 / 2, **kwargs):
    """
    draw one lcurveplot row per target in dic['curves']

    :param dic: dict with key 'curves' mapping each target name to a dict with keys 'train' and 'test'
    :param width: figure width
    :param height: figure height per row
    :param kwargs: other keyword arguments passed to lcurveplot
    :return: None (warns and returns early if 'curves' is missing)
    """
    if 'curves' not in dic.keys():
        warnings.warn('key curves not found, stopping')
        return None

    _targets = list(dic['curves'].keys())
    _nrows = len(_targets)
    _, ax = plt.subplots(nrows=_nrows, figsize=(width, height * _nrows))
    _ax_list = ax_as_list(ax)

    for _it, _target in enumerate(_targets):
        lcurveplot(dic['curves'][_target]['train'], dic['curves'][_target]['test'],
                   labels=['{}_train'.format(_target), '{}_test'.format(_target)], ax=_ax_list[_it], **kwargs)

    plt.show()
@docstr
@export
def stemplot(x, y, data=None, ax=None, color=rcParams['palette'][0], baseline=0, kwline=None, **kwargs):
    """
    modeled after pyplot.stemplot but more customizeable

    :param x: %(x)s
    :param y: %(y)s
    :param data: %(data)s
    :param ax: %(ax_in)s
    :param color: %(color)s
    :param baseline: where to draw the baseline for the stemplot
    :param kwline: other keyword arguments passed to pyplot.plot
    :param kwargs: other keyword arguments passed to pyplot.scatter
    :return: %(ax_out)s
    """
    if kwline is None:
        kwline = {}

    if data is None:
        # column names: the inputs' names if they are named Series, else 'x' / 'y';
        # they must differ or the y column would overwrite the x column
        _x = getattr(x, 'name', None)
        if _x is None:
            _x = 'x'
        _y = getattr(y, 'name', None)
        if _y is None:
            _y = 'y'
        if _y == _x:
            _y = '{}_y'.format(_y)
        _data = pd.DataFrame({_x: x, _y: y})
    else:
        _x = x
        _y = y
        _data = data.copy()

    if ax is None:
        ax = plt.gca()

    # baseline
    ax.axhline(baseline, color='k', ls='--', alpha=.5)

    # one vertical stem per data point, from the baseline to the value
    for _, _row in _data.iterrows():
        ax.plot([_row[_x], _row[_x]], [baseline, _row[_y]], color=color, label='__nolegend__', **kwline)

    # scatterplot for markers
    ax.scatter(x=_x, y=_y, data=_data, facecolor=color, **kwargs)

    return ax
def _legend_label_and_color(name, seen: list, palette):
    """
    return the legend label and color for one row of a labelled plot, registering new names in ``seen``

    Only the first occurrence of a name gets a visible legend label. A list-like palette is indexed by the order
    in which names were first seen (cycling); rows without a name (None) get the first palette entry.
    """
    _label = '__nolabel__'
    if name is not None and name not in seen:
        _label = str(name)
        seen.append(name)
    if isinstance(palette, Mapping):
        _color = palette[name]
    elif is_list_like(palette):
        _position = seen.index(name) if name in seen else 0
        _color = palette[_position % len(palette)]
    else:
        _color = palette
    return _label, _color


def from_to_plot(data: pd.DataFrame, x_from='x_from', x_to='x_to', y_from=0, y_to=1, palette=None, label=None,
                 legend=True, legend_loc=None, ax=None, **kwargs):
    """
    draw one band per row of data, spanning x from row[x_from] to row[x_to] and y from y_from to y_to

    :param data: DataFrame with one row per band
    :param x_from: column holding the band start
    :param x_to: column holding the band end
    :param y_from: lower y value of all bands
    :param y_to: upper y value of all bands
    :param palette: Mapping from label value to color, list of colors (cycled per label), or a single color;
        defaults to rcParams['palette']
    :param label: column whose values are used as legend labels and to pick colors [optional]
    :param legend: whether to draw a legend (only if label is given)
    :param legend_loc: passed to Axes.legend as loc
    :param ax: Axes to plot on, defaults to pyplot.gca()
    :param kwargs: other keyword arguments passed to Axes.fill_betweenx
    :return: the Axes
    """
    # defaults
    if ax is None:
        ax = plt.gca()
    if palette is None:
        palette = rcParams['palette']

    _labels = []
    # iterrows (not itertuples) yields (index, row) pairs whose rows can be indexed by column name
    for _, _row in data.iterrows():
        _name = _row[label] if label is not None else None
        _label, _color = _legend_label_and_color(_name, _labels, palette)
        ax.fill_betweenx([y_from, y_to], _row[x_from], _row[x_to], label=_label, color=_color, **kwargs)

    if legend and label:
        ax.legend(loc=legend_loc)

    return ax


def vlineplot(data, palette=None, label=None, legend=True, legend_loc=None, ax=None, x='x', **kwargs):
    """
    draw one vertical line per row of data at row[x]

    :param data: DataFrame with one row per line
    :param palette: Mapping from label value to color, list of colors (cycled per label), or a single color;
        defaults to rcParams['palette']
    :param label: column whose values are used as legend labels and to pick colors [optional]
    :param legend: whether to draw a legend (only if label is given)
    :param legend_loc: passed to Axes.legend as loc
    :param ax: Axes to plot on, defaults to pyplot.gca()
    :param x: column holding the x position of each line [optional]
    :param kwargs: other keyword arguments passed to Axes.axvline
    :return: the Axes
    """
    # defaults
    if ax is None:
        ax = plt.gca()
    if palette is None:
        palette = rcParams['palette']

    _labels = []
    for _, _row in data.iterrows():
        _name = _row[label] if label is not None else None
        _label, _color = _legend_label_and_color(_name, _labels, palette)
        ax.axvline(_row[x], label=_label, color=_color, **kwargs)

    if legend and label:
        ax.legend(loc=legend_loc)

    return ax


def show_ax_ticklabels(ax, x=None, y=None):
    """
    show or hide the tick labels of one or several Axes

    :param ax: Axes or collection of Axes
    :param x: visibility of the x tick labels, None leaves them unchanged
    :param y: visibility of the y tick labels, None leaves them unchanged
    :return: None
    """
    for _ax in ax_as_list(ax):
        if x is not None:
            plt.setp(_ax.get_xticklabels(), visible=x)
        if y is not None:
            plt.setp(_ax.get_yticklabels(), visible=y)
@docstr
@export
def get_twin(ax: plt.Axes) -> Union[plt.Axes, None]:
    """
    find the twin of an Axes, i.e. another Axes of the same figure occupying exactly the same bounding box

    :param ax: %(ax_in)s
    :return: the first matching Axes, or None if there is none
    """
    _matches = (_candidate for _candidate in ax.figure.axes
                if _candidate is not ax and _candidate.bbox.bounds == ax.bbox.bounds)
    return next(_matches, None)
@docstr
@export
def get_axlim(ax: plt.Axes, xy: Union[str, None] = None) -> Union[tuple, Mapping]:
    """
    read the x limits, the y limits, or both in one call

    :param ax: %(ax_in)s
    :param xy: one of ['x', 'y', 'xy', None]
    :return: the x or y limits as tuple for 'x' / 'y', otherwise a dict with keys 'x' and 'y'
    """
    if xy in ('x', 'y'):
        return ax.get_xlim() if xy == 'x' else ax.get_ylim()
    return {'x': ax.get_xlim(), 'y': ax.get_ylim()}
@docstr
@export
def set_axlim(ax: plt.Axes, lim: Union[Sequence, Mapping], xy: Union[str, None] = None):
    """
    Wrapper function to set x limits, y limits or both with one call

    :param ax: %(ax_in)s
    :param lim: axes limits as tuple (with xy='x' / 'y') or Mapping with keys 'x' and 'y'
    :param xy: one of ['x', 'y', 'xy', None]
    :return: None
    :raises ValueError: if xy is neither 'x' nor 'y' and lim is not a Mapping
    """
    if xy == 'x':
        # noinspection PyTypeChecker
        ax.set_xlim(lim)
    elif xy == 'y':
        # noinspection PyTypeChecker
        ax.set_ylim(lim)
    elif isinstance(lim, Mapping):
        # each entry goes to its own axis
        ax.set_xlim(lim['x'])
        ax.set_ylim(lim['y'])
    else:
        raise ValueError('Specify xy parameter or pass a dictionary')
@docstr
@export
def share_xy(ax: plt.Axes, x: bool = True, y: bool = True, mode: str = 'all', adj_twin_ax: bool = True):
    """
    give the subplots on the Axes common x and/or y limits WITHOUT sharing x and y legends.
    If you want that please use pyplot.subplots(share_x=True,share_y=True) when creating the plots.

    The common limits span the smallest lower and the largest upper limit within each group of Axes.

    :param ax: %(ax_in)s
    :param x: whether to share x limits [optional]
    :param y: whether to share y limits [optional]
    :param mode: one of ['all', 'row', 'col'], if all shares across all subplots, else just across rows / columns
        (row / col only apply to a 2d grid of Axes; any other mode on a grid changes nothing)
    :param adj_twin_ax: whether to also adjust twin axes, rescaling their limits by the same factor as the
        primary axis limits
    :return: None
    """
    _dims_to_share = [_name for _name, _flag in (('x', x), ('y', y)) if _flag]
    _ndim = len(ax.shape) if isinstance(ax, np.ndarray) else 0

    # groups of Axes that share limits: everything, or each row / column of a 2d grid
    if _ndim <= 1 or mode == 'all':
        _groups = [ax_as_list(ax)]
    elif mode == 'row':
        _groups = [ax_as_list(ax[_r, :]) for _r in range(ax.shape[0])]
    elif mode == 'col':
        _groups = [ax_as_list(ax[:, _c]) for _c in range(ax.shape[1])]
    else:
        _groups = []

    for _group in _groups:
        if not _group:
            continue

        _current = [get_axlim(_member) for _member in _group]
        _lo = {_dim: min(_lims[_dim][0] for _lims in _current) for _dim in _dims_to_share}
        _hi = {_dim: max(_lims[_dim][1] for _lims in _current) for _dim in _dims_to_share}

        for _member in _group:
            _twin = get_twin(_member) if adj_twin_ax else False
            for _dim in _dims_to_share:
                _before = list(get_axlim(_member, xy=_dim))
                _after = [_lo[_dim], _hi[_dim]]
                set_axlim(_member, lim=_after, xy=_dim)
                if _twin:
                    # scale each twin limit by the ratio new / old of the primary limit (0 stays 0)
                    _twin_before = list(get_axlim(_twin, xy=_dim))
                    _twin_after = [0 if _old == 0 else _new / _old * _old_twin
                                   for _new, _old, _old_twin in zip(_after, _before, _twin_before)]
                    set_axlim(_twin, lim=_twin_after, xy=_dim)
@docstr
@export
def share_legend(ax: plt.Axes, keep_i: int = None):
    """
    removes all legends except for one from an Axes object

    :param ax: %(ax_in)s
    :param keep_i: position of the plot whose legend you want to keep, counted from 1 (keep_i=1 keeps the legend
        of the first plot); defaults to half the number of plots (integer division)
    :return: None
    """
    _axes = ax_as_list(ax)
    if keep_i is None:
        keep_i = len(_axes) // 2
    for _pos, _sub_ax in enumerate(_axes, start=1):
        _legend = _sub_ax.get_legend()
        if _legend is not None and _pos != keep_i:
            _legend.remove()
def _mapped_ticklabels(ticklabels, mapping) -> list:
    """
    compute replacement texts for a sequence of tick label Text objects

    :param ticklabels: tick labels as returned by Axes.get_xticklabels / Axes.get_yticklabels
    :param mapping: a Mapping from current label text to new text (texts without an entry are kept), or a
        sequence of new texts indexed by tick position
    :return: list of new label texts
    """
    _texts = [_label.get_text() for _label in ticklabels]
    if isinstance(mapping, Mapping):
        return [mapping[_text] if _text in mapping.keys() else _text for _text in _texts]
    return [mapping[_it] for _it in range(len(_texts))]


def replace_xticklabels(ax, mapping):
    """
    replace the x tick labels of ax

    :param ax: Axes to modify
    :param mapping: Mapping from old to new label text (unknown texts are kept) or sequence of new texts by position
    :return: None
    """
    ax.set_xticklabels(_mapped_ticklabels(list(ax.get_xticklabels()), mapping))


def replace_yticklabels(ax, mapping):
    """
    replace the y tick labels of ax

    :param ax: Axes to modify
    :param mapping: Mapping from old to new label text (unknown texts are kept) or sequence of new texts by position
    :return: None
    """
    ax.set_yticklabels(_mapped_ticklabels(list(ax.get_yticklabels()), mapping))


def kdeplot(x, data=None, *args, hue=None, hue_order=None, bins=40, adj_x_range=False, baseline=0,
            highlight_peaks=True, show_kde=True, hist=True, show_area=False, area_center='mean', ha='center',
            va='center', legend_loc='upper right', palette=None, text_offset=15, nr_format=',.2f', kwline=None,
            perc=False, facecolor=None, sigma_color='xkcd:blue', sigma_2_color='xkcd:cyan', kde_color='black',
            edgecolor='black', alpha=.5, ax=None, ax2=None, kwhist=None, **kwargs):
    """
    histogram plus kernel density estimate (scaled to a maximum of 1), with optional peak and sigma highlighting

    The kde curve and its peak table come from hhpy.ds.kde; when a histogram is drawn the kde goes on a twin
    axis that is hidden at the end.

    :param x: column name in data, or the values themselves if data is None
    :param data: DataFrame containing x (and hue) [optional]
    :param args: passed to hhpy.ds.kde
    :param hue: column to split the data by, one kde per level [optional]
    :param hue_order: order of the hue levels, defaults to sorted levels [optional]
    :param bins: number of histogram bins [optional]
    :param adj_x_range: if truthy, restrict the plotted data to the range spanned by the kde peak ranges, widened
        by adj_x_range times its width on each side (True means 2) [optional]
    :param baseline: y value where the peak stems start [optional]
    :param highlight_peaks: mark kde peaks with stems and annotations; 'max' omits the +- range [optional]
    :param show_kde: together with show_area, leaves the histogram out of the legend [optional]
    :param hist: whether to draw the histogram [optional]
    :param show_area: whether to shade the 1 sigma / 2 sigma areas around area_center [optional]
    :param area_center: 'mean', 'max' or a number, center of the shaded areas [optional]
    :param ha: horizontal alignment of the peak annotations [optional]
    :param va: vertical alignment of the peak annotations [optional]
    :param legend_loc: passed to Axes.legend as loc [optional]
    :param palette: colors per hue level (Mapping, list or single color) [optional]
    :param text_offset: annotation offset in points, also the percentage of headroom added to the kde y limit
    :param nr_format: format of the peak annotations [optional]
    :param kwline: other keyword arguments passed to the stem / range lines [optional]
    :param perc: passed to Axes.hist as density [optional]
    :param facecolor: histogram face color without hue [optional]
    :param sigma_color: color of the 1 sigma area [optional]
    :param sigma_2_color: color of the 2 sigma areas [optional]
    :param kde_color: color of the kde line without hue and of the peak markers [optional]
    :param edgecolor: histogram edge color without hue [optional]
    :param alpha: alpha of the sigma areas [optional]
    :param ax: Axes to plot on, defaults to pyplot.gca() [optional]
    :param ax2: Axes for the kde when a histogram is drawn, defaults to a twin of ax [optional]
    :param kwhist: other keyword arguments passed to Axes.hist [optional]
    :param kwargs: passed to hhpy.ds.kde, the kde line and the peak markers
    :return: the primary Axes
    """
    # -- init
    if palette is None:
        palette = rcParams['palette']
    if kwline is None:
        kwline = {}
    if kwhist is None:
        kwhist = {}

    if data is not None:
        _df = data.copy()
        _x_name = x
    else:
        # an unnamed Series has name None, which cannot serve as a column name
        _x_name = getattr(x, 'name', None)
        if _x_name is None:
            _x_name = 'x'
        _df = pd.DataFrame({_x_name: x})
    _df = _df.dropna(subset=[_x_name])

    if hue is None:
        hue = '_dummy'
        _df[hue] = 1
    if hue_order is None:
        hue_order = sorted(_df[hue].unique())

    if facecolor is None:
        facecolor = 'None' if show_area else 'xkcd:cyan'
    _label_hist = '__nolabel__' if (show_kde and show_area) else _x_name

    # default
    if adj_x_range and isinstance(adj_x_range, bool):
        adj_x_range = 2

    if ax is None:
        ax = plt.gca()

    _twinx = False
    for _hue_pos, _hue in enumerate(hue_order):
        _df_hue = _df[_df[hue] == _hue]
        # kde of this hue subset, addressed by column name (x itself holds raw values when data is None)
        _df_kde, _df_kde_ex = kde(x=_x_name, df=_df_hue, *args, **kwargs)

        if isinstance(palette, Mapping):
            _color = palette[_hue]
        elif is_list_like(palette):
            _color = palette[_hue_pos % len(palette)]
        else:
            _color = palette

        if hue == '_dummy':
            _kde_color = kde_color
            _edgecolor = edgecolor
            _facecolor = facecolor
        else:
            _kde_color = _color
            _edgecolor = _color
            _facecolor = 'None'

        # scale the kde to a peak of 1 and the peak table by the same factor so stems line up with the curve
        # NOTE(review): value_min / value_max are assumed to be kde heights in the same units as 'value'
        # (they position each peak's +- range line on the kde axis) - confirm against hhpy.ds.kde
        _kde_max = _df_kde['value'].max()
        _df_kde['value'] = _df_kde['value'] / _kde_max
        for _col in ['value', 'value_min', 'value_max']:
            if _col in _df_kde_ex.columns:
                _df_kde_ex[_col] = _df_kde_ex[_col] / _kde_max

        # legend label of the kde line: mean +- std of this hue level, before any x range filtering
        _kde_label = '{} ; '.format(_x_name) + r'${:,.2f}\pm{:,.2f}$'.format(_df_hue[_x_name].mean(),
                                                                            _df_hue[_x_name].std())

        if adj_x_range:
            _x_min = _df_kde_ex['range_min'].min()
            _x_max = _df_kde_ex['range_max'].max()
            _x_step = (_x_max - _x_min) / bins
            _x_range_min = _x_min - _x_step * adj_x_range * bins
            _x_range_max = _x_max + _x_step * adj_x_range * bins
            _df_hue = _df_hue[(_df_hue[_x_name] >= _x_range_min) & (_df_hue[_x_name] <= _x_range_max)]
            _df_kde = _df_kde[(_df_kde[_x_name] >= _x_range_min) & (_df_kde[_x_name] <= _x_range_max)]

        # -- plot
        # hist
        if hist:
            ax.hist(_df_hue[_x_name], bins, density=perc, facecolor=_facecolor, edgecolor=_edgecolor,
                    label=_label_hist, **kwhist)
            _twinx = True
        else:
            _twinx = False

        # with a histogram the kde goes on a twin axis: the caller's ax2, or one created once and reused
        if not _twinx:
            ax2 = ax
        elif ax2 is None:
            ax2 = ax.twinx()

        # kde
        ax2.plot(_df_kde[_x_name], _df_kde['value'], ls='--', label=_kde_label, color=_kde_color, **kwargs)
        _ylim = list(ax2.get_ylim())
        _ylim[0] = 0
        # headroom for the peak annotations
        _ylim[1] = _ylim[1] * (100 + text_offset) / 100.
        ax2.set_ylim(_ylim)

        # area
        if show_area:
            if area_center == 'max':
                _area_center = _df_kde[_df_kde['value'] == _df_kde['value'].max()].index[0]
            else:
                _ref = _df_hue[_x_name].mean() if area_center == 'mean' else area_center
                # the kde point closest to the reference value is the area center
                _df_area = _df_kde.copy()
                _df_area['diff'] = (_df_area[_x_name] - _ref).abs()
                _df_area = _df_area.sort_values(by='diff', ascending=True)
                _area_center = _df_area.index[0]

            # grow a symmetric window around the center until it holds 68.26 / 95.44 percent of the kde mass
            # NOTE(review): index labels are used as positions, which assumes _df_kde has a RangeIndex
            # (not true after adj_x_range filtering) - confirm
            _n_kde = _df_kde.shape[0]
            _kde_sum = _df_kde['value'].sum()
            _sigma = None
            _2_sigma = None
            for _step in range(1, _n_kde):
                _window = _df_kde.iloc[max(0, _area_center - _step):min(_n_kde, _area_center + _step + 1)]
                _perc_data = _window['value'].sum() / _kde_sum
                if (_perc_data >= .6826) and (_sigma is None):
                    _sigma = _step
                if (_perc_data >= .9544) and (_2_sigma is None):
                    _2_sigma = _step
                    break
                if _step == _n_kde - 1:
                    _2_sigma = _step

            _df_sigma = _df_kde.loc[max(0, _area_center - _sigma):min(_n_kde, _area_center + _sigma)]
            _df_2_sigma_left = _df_kde.loc[max(0, _area_center - _2_sigma):min(_n_kde, _area_center - _sigma)]
            _df_2_sigma_right = _df_kde.loc[max(0, _area_center + _sigma):min(_n_kde, _area_center + _2_sigma)]

            _2_sigma_min = _df_2_sigma_left[_x_name].min()
            _2_sigma_max = _df_2_sigma_right[_x_name].max()
            if np.isnan(_2_sigma_min):
                _2_sigma_min = _df[_x_name].min()
            if np.isnan(_2_sigma_max):
                _2_sigma_max = _df[_x_name].max()

            _sigma_range = ': {:,.2f} to {:,.2f}'.format(_df_sigma[_x_name].min(), _df_sigma[_x_name].max())
            _2_sigma_range = ': {:,.2f} to {:,.2f}'.format(_2_sigma_min, _2_sigma_max)

            ax2.fill_between(_x_name, 'value', data=_df_sigma, color=sigma_color,
                             label=r'$1\sigma(68\%)$' + _sigma_range, alpha=alpha)
            ax2.fill_between(_x_name, 'value', data=_df_2_sigma_left, color=sigma_2_color,
                             label=r'$2\sigma(95\%)$' + _2_sigma_range, alpha=alpha)
            ax2.fill_between(_x_name, 'value', data=_df_2_sigma_right, color=sigma_2_color,
                             label='__nolegend__', alpha=alpha)
            ax2.legend(loc=legend_loc)

        # iterate over the peaks so you can draw the lines
        if highlight_peaks:
            for _, _row in _df_kde_ex.iterrows():
                _mu = _row[_x_name]
                _value_std = np.min([_row['value_min'], _row['value_max']])
                # stem (max)
                ax2.plot([_mu, _mu], [baseline, _row['value']], color=kde_color, label='__nolegend__', ls=':',
                         **kwline)
                # std
                if highlight_peaks != 'max':
                    ax2.plot([_row['range_min'], _row['range_max']], [_value_std, _value_std],
                             color=kde_color, label='__nolegend__', ls=':', **kwline)
                # scatterplot for markers
                ax2.scatter(x=_mu, y=_row['value'], facecolor=kde_color, **kwargs)

                _mean_str = format(_mu, nr_format)
                _std_str = format(_row['range'] / 2., nr_format)
                _annotate = r'${}$'.format(_mean_str)
                if highlight_peaks != 'max':
                    _annotate += '\n' + r'$\pm{}$'.format(_std_str)
                ax2.annotate(_annotate, (_mu, _row['value']), ha=ha, va=va, xytext=(0, text_offset),
                             textcoords='offset points')

    if _twinx:
        ax2.legend(loc=legend_loc)
        ax2.set_axis_off()
    else:
        ax.legend(loc=legend_loc)

    return ax


def draw_ellipse(ax, *args, **kwargs):
    """
    add a matplotlib.patches.Ellipse to ax

    :param ax: Axes to draw on
    :param args: positional arguments passed to matplotlib.patches.Ellipse
    :param kwargs: keyword arguments passed to matplotlib.patches.Ellipse
    :return: the Ellipse patch
    """
    _e = patches.Ellipse(*args, **kwargs)
    ax.add_artist(_e)
    return _e
@docstr
@export
def barplot_err(x: str, y: str, xerr: str = None, yerr: str = None, data: pd.DataFrame = None,
                **kwargs) -> plt.Axes:
    """
    extension on `seaborn barplot <https://seaborn.pydata.org/generated/seaborn.barplot.html>`_ that allows
    for plotting errorbars with preprocessed data. The idea is based on this `StackOverflow question
    <https://datascience.stackexchange.com/questions/31736/unable-to-generate-error-bars-with-seaborn/64128>`_

    :param x: %(x_novec)s
    :param y: %(y_novec)s
    :param xerr: variable to use as x error bars [optional]
    :param yerr: variable to use as y error bars [optional]
    :param data: %(data_novec)s
    :param kwargs: other keyword arguments passed to `seaborn barplot
        <https://seaborn.pydata.org/generated/seaborn.barplot.html>`_
    :return: %(ax_out)s
    """
    # every row becomes three rows v - err, v, v + err, whose sample standard deviation is exactly err,
    # so seaborn's ci='sd' draws the precomputed error bars
    _data = []
    for _pos in range(len(data)):
        # positional access: label based .loc breaks (or returns several rows) for duplicate index labels
        _data_i = pd.concat([data.iloc[[_pos]]] * 3, ignore_index=True, sort=False)
        _row = data.iloc[_pos]
        if xerr is not None:
            _data_i[x] = [_row[x] - _row[xerr], _row[x], _row[x] + _row[xerr]]
        if yerr is not None:
            _data_i[y] = [_row[y] - _row[yerr], _row[y], _row[y] + _row[yerr]]
        _data.append(_data_i)

    _data = pd.concat(_data, ignore_index=True, sort=False)

    _ax = sns.barplot(x=x, y=y, data=_data, ci='sd', **kwargs)

    return _ax
def q_barplot(pd_series, ax=None, sort=False, percentage=False, **kwargs):
    """
    barplot of the value counts of a pandas Series

    :param pd_series: the Series whose values are counted
    :param ax: Axes to plot on, defaults to pyplot.gca()
    :param sort: whether to sort the plot data by value before plotting
    :param percentage: whether to plot percentages (rounded to 2 decimals) instead of counts
    :param kwargs: other keyword arguments passed to seaborn.barplot
    :return: the Axes
    """
    # an unnamed Series would otherwise give a None column and break the '_perc' suffix
    _name = pd_series.name if pd_series.name is not None else 'count'

    if ax is None:
        ax = plt.gca()

    # build the plot frame explicitly: the columns of value_counts().reset_index() differ between pandas
    # versions ('index' / <name> before 2.0, <name> / 'count' afterwards)
    _counts = pd_series.value_counts()
    _df_plot = pd.DataFrame({'index': _counts.index, _name: _counts.values})

    if sort:
        _df_plot = _df_plot.sort_values(['index'])

    if percentage:
        _y_name = _name + '_perc'
        _df_plot[_y_name] = _df_plot[_name] / _df_plot[_name].sum() * 100
        _df_plot[_y_name] = _df_plot[_y_name].round(2)
    else:
        _y_name = _name

    sns.barplot(data=_df_plot, x='index', y=_y_name, ax=ax, **kwargs)

    return ax


def histplot(x=None, data=None, hue=None, hue_order=None, ax=None, bins=30, use_q_xlim=False, legend_kws=None,
             **kwargs):
    """
    histogram, optionally one semi-transparent histogram per hue level on shared bins

    :param x: column name in data, or the values themselves if data is None
    :param data: DataFrame containing x (and hue) [optional]
    :param hue: column to split the data by [optional]
    :param hue_order: order of the hue levels, defaults to sorted levels [optional]
    :param ax: Axes to plot on, defaults to pyplot.gca()
    :param bins: number of equally wide bins between the data min and max, or anything else Axes.hist accepts
        as bins (e.g. a sequence of bin edges)
    :param use_q_xlim: whether to drop values outside q_plim(x) before plotting
    :param legend_kws: other keyword arguments passed to Axes.legend (only with hue) [optional]
    :param kwargs: other keyword arguments passed to Axes.hist; with hue, alpha defaults to .5
    :return: the Axes
    """
    if legend_kws is None:
        legend_kws = {}

    # long or short format
    if data is not None:
        # avoid inplace operations
        _df_plot = data.copy()
        _x = x
    else:
        # create dummy df
        _df_plot = pd.DataFrame.from_dict({'x': x})
        _x = 'x'

    _xs = _df_plot[_x]

    # if applicable: filter data
    if use_q_xlim:
        _x_lim = q_plim(_xs)
        _df_plot = _df_plot[(_df_plot[_x] >= _x_lim[0]) & (_df_plot[_x] <= _x_lim[1])]
        _xs = _df_plot[_x]

    # an integer number of bins becomes bins + 1 shared edges so all hue levels use the same bins
    if isinstance(bins, (int, np.integer)):
        bins = np.linspace(_xs.min(), _xs.max(), bins + 1)

    # if an axis has not been passed initialize one
    if ax is None:
        ax = plt.gca()

    if hue is not None:
        # if no hue order has been passed use default sorting
        if hue_order is None:
            hue_order = sorted(_df_plot[hue].unique())
        # overlapping histograms are semi-transparent unless the caller chose an alpha
        _hist_kws = {'alpha': .5}
        _hist_kws.update(kwargs)
        for _hue in hue_order:
            _xs = _df_plot[_df_plot[hue] == _hue][_x]
            ax.hist(_xs, label=_hue, bins=bins, **_hist_kws)
        ax.legend(**legend_kws)
    else:
        ax.hist(_xs, bins=bins, **kwargs)

    return ax
@docstr
@export
def countplot(x: Union[Sequence, str] = None, data: pd.DataFrame = None, hue: str = None, ax: plt.Axes = None,
              order: Union[Sequence, str] = None, hue_order: Union[Sequence, str] = None, normalize_x: bool = False,
              normalize_hue: bool = False, palette: Union[Mapping, Sequence, str] = None, x_tick_rotation: int = None,
              count_twinx: bool = False, hide_legend: bool = False, annotate: bool = True,
              annotate_format: str = rcParams['int_format'], legend_kws: Mapping = None, barplot_kws: Mapping = None,
              count_twinx_kws: Mapping = None, **kwargs):
    """
    Based on seaborn barplot but with a few more options

    :param x: %(x)s
    :param data: %(data)s
    :param hue: %(hue)s
    :param ax: %(ax_in)s
    :param order: %(order)s
    :param hue_order: %(order)s
    :param normalize_x: whether to normalize x, causes the sum of each x group to be 100 percent [optional]
    :param normalize_hue: whether to normalize hue, causes the sum of each hue group to be 100 percent [optional]
    :param palette: %(palette)s
    :param x_tick_rotation: %(x_tick_rotation)s
    :param count_twinx: whether to plot the count values on the second axis (if using normalize) [optional]
    :param hide_legend: whether to hide the legend [optional]
    :param annotate: whether to use annotate_barplot [optional]
    :param annotate_format: %(number_format)s
    :param legend_kws: additional keyword arguments passed to pyplot.legend [optional]
    :param barplot_kws: additional keyword arguments passed to seaborn.barplot [optional]
    :param count_twinx_kws: additional keyword arguments passed to pyplot.plot [optional]
    :param kwargs: additional keyword arguments passed to hhpy.ds.df_count [optional]
    :return: %(ax_out)s
    """
    # -- init
    if legend_kws is None:
        legend_kws = {}
    if barplot_kws is None:
        barplot_kws = {}
    if count_twinx_kws is None:
        count_twinx_kws = {}

    # long or short format
    if data is not None:
        # avoid inplace operations
        _df = data.copy()
        if x is None:
            # no x: count all rows as a single group
            _x = '_dummy'
            _df = _df.assign(_dummy=1)
        else:
            _x = x
    else:
        # create dummy df
        _df = pd.DataFrame.from_dict({'x': x})
        _x = 'x'
    _count_x = 'count_{}'.format(_x)

    # if an axis has not been passed initialize one
    if ax is None:
        ax = plt.gca()

    if normalize_x:
        _y = 'perc_{}'.format(_x)
    elif normalize_hue:
        _y = 'perc_{}'.format(hue)
    else:
        _y = 'count'

    # the code below relies on df_count providing the columns _x, _count_x and _y
    _df_count = df_count(x=_x, df=_df, hue=hue, **kwargs)

    # x order: by descending count (default), sorted, or as given; the isinstance checks avoid elementwise
    # comparisons when order is an array
    if order is None or (isinstance(order, str) and order == 'count'):
        _order = _df_count[[_x, _count_x]].drop_duplicates().sort_values(by=[_count_x], ascending=False)[_x].tolist()
    elif isinstance(order, str) and order == 'sorted':
        _order = _df_count[_x].drop_duplicates().sort_values().tolist()
    else:
        _order = order

    # a string hue_order is resolved to concrete levels; seaborn only accepts a sequence
    _hue_order = hue_order
    if hue is not None and isinstance(hue_order, str):
        _hue_order = _get_ordered_levels(data=_df, level=hue, order=hue_order, x=_x)

    if palette is None:
        palette = rcParams['palette'] * 5

    sns.barplot(data=_df_count, x=_x, y=_y, hue=hue, order=_order, hue_order=_hue_order, palette=palette, ax=ax,
                **barplot_kws)
    ax.set_xlabel('')

    # cleanup for x=None
    if x is None:
        ax.get_xaxis().set_visible(False)

    if normalize_x or (hue is None and normalize_hue):
        ax.set_ylabel('perc')

    if annotate:
        # add annotation
        annotate_barplot(ax, nr_format=annotate_format)
        # enlarge ylims to make room for the annotations
        _ylim = list(ax.get_ylim())
        _ylim[1] = _ylim[1] * 1.1
        # noinspection PyTypeChecker
        ax.set_ylim(_ylim)

    if hide_legend:
        # nothing to remove when no legend was drawn (e.g. without hue)
        _legend = ax.get_legend()
        if _legend is not None:
            _legend.remove()
    elif hue is not None:
        ax.legend(**legend_kws)

    # tick rotation
    if x_tick_rotation is not None:
        ax.xaxis.set_tick_params(rotation=x_tick_rotation)

    # total count on secaxis
    if count_twinx:
        _ax = ax.twinx()
        # marker defaults, without mutating the caller's dict
        _scatter_kws = {'marker': '_', 'color': 'k', 'alpha': .5}
        _scatter_kws.update(count_twinx_kws)
        _ax.scatter(_x, _count_x, data=_df_count[[_x, _count_x]].drop_duplicates(), **_scatter_kws)
        _ax.set_ylabel('count')

    return ax
@docstr
@export
def quantile_plot(x: Union[Sequence, str], data: pd.DataFrame = None, qs: Union[Sequence, float] = None,
                  x2: str = None, hue: str = None, hue_order: Union[Sequence, str] = None, to_abs: bool = False,
                  ax: plt.Axes = None, **kwargs) -> plt.Axes:
    """
    plots the specified quantiles of a Series using seaborn.barplot

    :param x: %(x)s
    :param data: %(data)s
    :param qs: Quantile levels, a sequence or a single level [optional]
    :param x2: if specified: subtracts x2 from x before calculating quantiles [optional]
    :param hue: %(hue)s
    :param hue_order: %(order)s
    :param to_abs: %(to_abs)s
    :param ax: %(ax_in)s
    :param kwargs: other keyword arguments passed to seaborn.barplot
    :return: %(ax_out)s
    """
    if qs is None:
        qs = [.1, .25, .5, .75, .9]
    elif np.ndim(qs) == 0:
        # a single quantile level
        qs = [qs]

    # long or short format
    if data is not None:
        # avoid inplace operations
        _df = data.copy()
        if x2 is None:
            _x = x
        else:
            _x = '{} - {}'.format(x, x2)
            _df[_x] = _df[x] - _df[x2]
    else:
        # create dummy df
        if x2 is None:
            _df = pd.DataFrame({'x': x})
            _x = 'x'
        else:
            # x minus x2, consistent with the long format branch above
            _df = pd.DataFrame({'x': x, 'x2': x2}).eval('x_delta=x-x2')
            _x = 'x_delta'

    if ax is None:
        ax = plt.gca()

    _label = _x
    if to_abs:
        _df[_x] = _df[_x].abs()
        _label = '|{}|'.format(_x)

    # reset_index turns the quantile levels into the column 'index'
    if hue is None:
        _df_q = _df[_x].quantile(qs).reset_index()
    else:
        _hues = _get_ordered_levels(data=_df, level=hue, order=hue_order, x=_x)
        _df_q = []
        for _hue in _hues:
            _df_i = _df[_df[hue] == _hue][_x].quantile(qs).reset_index()
            _df_i[hue] = _hue
            _df_q.append(_df_i)
        _df_q = pd.concat(_df_q, ignore_index=True, sort=False)

    sns.barplot(x='index', y=_x, data=_df_q, hue=hue, ax=ax, **kwargs)

    # '{:g}' avoids float truncation such as int(0.29 * 100) == 28
    ax.set_xticklabels(['q{:g}'.format(_q * 100) for _q in qs])
    ax.set_xlabel('')
    ax.set_ylabel(_label)

    return ax