Source code for hhpy.plotting

"""
hhpy.plotting.py
~~~~~~~~~~~~~~~~

Contains plotting functions using matplotlib.pyplot

"""

# -- imports
# - standard imports
from copy import deepcopy

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import logging
import warnings
import itertools

# - third party imports
from matplotlib import patches, colors as mpl_colors
from matplotlib.animation import FuncAnimation
from matplotlib.legend import Legend
from colour import Color
from scipy import stats
from typing import Union, Sequence, Mapping, Callable, List

# local imports
from hhpy.main import export, concat_cols, is_list_like, floor_signif, ceil_signif, list_intersection, \
    assert_list, progressbar, DocstringProcessor, Scalar, SequenceOrScalar
from hhpy.ds import get_df_corr, lfit, kde, df_count, quantile_split, top_n_coding, df_rmsd, df_agg

# - optional imports
logger = logging.getLogger('hhpy.plotting')
try:
    from IPython.core.display import HTML
except ImportError:
    # noinspection PyPep8Naming
    def HTML(obj):
        logger.warning('Missing optional dependency IPython.core.display.HTML')
        return obj

try:
    # noinspection PyPackageRequirements
    from plotly import graph_objects as go
except ImportError:
    logger.warning('Missing optional dependency plotly')
    go = None

# --- constants

rcParams = {
    'palette': [
        'xkcd:blue', 'xkcd:red', 'xkcd:green', 'xkcd:cyan', 'xkcd:magenta',
        'xkcd:golden yellow', 'xkcd:dark cyan', 'xkcd:red orange', 'xkcd:dark yellow', 'xkcd:easter green',
        'xkcd:baby blue', 'xkcd:light brown', 'xkcd:strong pink', 'xkcd:light navy blue', 'xkcd:deep blue',
        'xkcd:deep red', 'xkcd:ultramarine blue', 'xkcd:sea green', 'xkcd:plum', 'xkcd:old pink',
        'xkcd:lawn green', 'xkcd:amber', 'xkcd:green blue', 'xkcd:yellow green', 'xkcd:dark mustard',
        'xkcd:bright lime', 'xkcd:aquamarine', 'xkcd:very light blue', 'xkcd:light grey blue', 'xkcd:dark sage',
        'xkcd:dark peach', 'xkcd:shocking pink'
    ],
    'hatches': ['/', '\\', '|', '-', '+', 'x', 'o', 'O', '.', '*'],
    'figsize_square': (7, 7),
    'fig_width': 7,
    'fig_height': 7,
    'float_format': '.2f',
    'int_format': ',.0f',
    'legend_outside.legend_space': .1,
    'distplot.label_style': 'mu_sigma',
    'distplot.legend_loc': None,
    'max_n': 10000,
    'max_n.random_state': None,
    'max_n.sample_warn': True,
    'return_fig_ax': True,
    'corr_cutoff': 0,
    'animplot.mode': 'jshtml',
}

validations = {
    'distplot__distfit': ['kde', 'gauss', 'False', 'None'],
    'cat_to_color__out_type': [None, 'hex', 'rgb', 'rgba', 'rgba_array']
}

docstr = DocstringProcessor(
    ax_in='The matplotlib.pyplot.Axes object to plot on, defaults to current axis [optional]',
    ax_out='The matplotlib.pyplot.Axes object with the plot on it',
    fig_ax_out='if return_fig_ax: figure and axes objects as tuple, else None',
    x='Name of the x variable in data or vector data',
    y='Name of the y variable in data or vector data',
    t='Name of the t variable in data or vector data',
    x_novec='Name of the x variable in data',
    y_novec='Name of the y variable in data',
    t_novec='Name of the t variable in data',
    data='Pandas DataFrame containing named data, optional if vector data is used',
    data_novec='Pandas DataFrame containing named data',
    hue='Further split the plot by the levels of this variable [optional]',
    order='Either a string describing how the (hue) levels or to be ordered or an explicit list of levels to be ' 
    'used for plotting. Accepted strings are: '
    '''
    
        * ``sorted``: following python standard sorting conventions (alphabetical for string, ascending for value)
        
        * ``inv``: following python standard sorting conventions but in inverse order
        
        * ``count``: sorted by value counts
        
        * ``mean``, ``mean_ascending``, ``mean_descending``: sorted by mean value, defaults to descending
        
        * ``median``, ``mean_ascending``, ``median_descending``: sorted by median value, defaults to descending
        
    ''',
    color='Color used for plotting, must be known to matplotlib.pyplot [optional]',
    palette='Collection of colors to be used for plotting. Can be a dictionary for with names for each level or '
            'a list of colors or an individual color name. Must be valid colors known to pyplot [optional]',
    cmap='Color map to use [optional]',
    annotations='Whether to display annotations [optional]',
    number_format='The format string used for annotations [optional]',
    float_format='The format string used for displaying floats [optional]',
    int_format='The format string used for displaying floats [optional]',
    corr_target='Target variable name, if specified only correlations with the target are shown [optional]',
    corr_cutoff='Filter all correlation whose absolute value is below the cutoff [optional]',
    col_wrap='After how many columns to create a new line of subplots [optional]',
    subplot_width='Width of each individual subplot [optional]',
    subplot_height='Height of each individual subplot [optional]',
    trendline='Whether to add a trendline [optional]',
    alpha='Alpha transparency level [optional]',
    max_n='''Maximum number of samples to be used for plotting, if this number is exceeded max_n samples are drawn '
          'at random from the data which triggers a warning unless sample_warn is set to False. '
          'Set to False or None to use all samples for plotting. [optional]''',
    max_n_random_state='Random state (seed) used for drawing the random samples [optional]',
    max_n_sample_warn='Whether to trigger a warning if the data has more samples than max_n [optional]',
    return_fig_ax='Whether to return the figure and axes objects as tuple to be captured as fig,ax = ..., '
                  'If False pyplot.show() is called and the plot returns None [optional]',
    legend='Whether to show a legend [optional]',
    legend_loc='Location of the legend, one of [bottom, right] or accepted value of pyplot.legend'
               'If in [bottom, right] legend_outside is used, else pyplot.legend [optional]',
    legend_ncol='Number of columns to use in legend [optional]',
    legend_space='Only valid if legend_loc is bottom. The space between the main plot and the legend [optional]',
    kde_steps='Nr of steps the range is split into for kde fitting [optional]',
    linestyle='Linestyle used, must a valid linestyle recognized by pyplot.plot [optional]',
    bins='Nr of bins of the histogram [optional]',
    sharex='Whether to share the x axis [optional]',
    sharey='Whether to share the y axis [optional]',
    row='Row index [optional]',
    col='Column index [optional]',
    legend_out='Whether to draw the legend outside of the axis, can also be a location string [optional]',
    legend_kws='Other keyword arguments passed to pyplot.legend [optional]',
    xlim='X limits for the axis as tuple, passed to ax.set_xlim() [optional]',
    ylim='Y limits for the axis as tuple, passed to ax.set_ylim() [optional]',
    grid='Whether to toggle ax.grid() [optional]',
    vline='A list of x positions to draw vlines at [optional]',
    to_abs='Whether to cast the values to absolute before proceeding [optional]',
    label='Label to use for the data [optional]',
    x_tick_rotation='Set x tick label rotation to this value [optional]',
    std_cutoff='Remove data outside of std_cutoff standard deviations, for a good visual experience try 3 [optional]',
    do_print='Whether to print intermediate steps to console [optional]',
    x_min='Lower limit for the x axis [optional]',
    x_max='Upper limit for the x axis [optional]',
    y_min='Lower limit for the y axis [optional]',
    y_max='Upper limit for the y axis [optional]',
    title_plotly='Figure title, passed to plotly.Figure.update_layout [optional]',
    xaxis_title='x axis title, passed to plotly.Figure.update_layout [optional]',
    yaxis_title='y axis title, passed to plotly.Figure.update_layout [optional]',
    fig_plotly='The plotly.Figure object to draw the plot on [optional]',
    **validations
)


# --- functions
def _get_ordered_levels(data: pd.DataFrame, level: str, order: Union[list, str, None], x: str = None) -> list:
    """
    internal function for getting the ordered levels of a categorical like column in a pandas DataFrame

    :param data: pandas DataFrame
    :param level: name of the column
    :param order: how to order it, details see below
    :param x: secondary column name, used to aggregate before sorting
    :return: list of ordered levels
    """
    if order is None or order == 'sorted':
        _hues = data[level].drop_duplicates().sort_values().tolist()
    elif order == 'inv':
        _hues = data[level].drop_duplicates().sort_values().tolist()[::-1]
    elif order == 'count':
        _hues = data[level].value_counts().reset_index().sort_values(by=[level, 'index'])['index'].tolist()
    elif order in ['mean', 'mean_descending']:
        _hues = data.groupby(level)[x].mean().reset_index().sort_values(by=[x, level], ascending=[False, True]
                                                                        )[level].tolist()
    elif order == 'mean_ascending':
        _hues = data.groupby(level)[x].mean().reset_index().sort_values(by=[x, level])[level].tolist()
    elif order in ['median', 'median_descending']:
        _hues = data.groupby(level)[x].median().reset_index().sort_values(by=[x, level], ascending=[False, True]
                                                                          )[level].tolist()
    elif order == 'median_ascending':
        _hues = data.groupby(level)[x].median().reset_index().sort_values(by=[x, level])[level].tolist()
    else:
        _hues = order

    return _hues


[docs]@docstr @export def heatmap(x: str, y: str, z: str, data: pd.DataFrame, ax: plt.Axes = None, cmap: object = None, agg_func: str = 'mean', invert_y: bool = True, **kwargs) -> plt.Axes: """ Wrapper for seaborn heatmap in x-y-z format :param x: Variable name for x axis value :param y: Variable name for y axis value :param z: Variable name for z value, used to color the heatmap :param data: %(data)s :param ax: %(ax_in)s :param cmap: %(cmap)s :param agg_func: If more than one z value per x,y pair exists agg_func is used to aggregate the data. Must be a function name understood by pandas.DataFrame.agg :param invert_y: Whether to call ax.invert_yaxis (orders the heatmap as expected) :param kwargs: Other keyword arguments passed to seaborn heatmap :return: %(ax_out)s """ if cmap is None: cmap = sns.diverging_palette(10, 220, as_cmap=True) _df = data.groupby([x, y]).agg({z: agg_func}).reset_index().pivot(x, y, z) if ax is None: ax = plt.gca() sns.heatmap(_df, ax=ax, cmap=cmap, **kwargs) ax.set_title(z) if invert_y: ax.invert_yaxis() return ax
[docs]@docstr @export def corrplot(data: pd.DataFrame, annotations: bool = True, number_format: str = rcParams['float_format'], ax=None): """ function to create a correlation plot using a seaborn heatmap based on: https://www.linkedin.com/pulse/generating-correlation-heatmaps-seaborn-python-andrew-holt :param number_format: %(number_format)s :param data: %(data_novec)s :param annotations: %(annotations)s :param ax: %(ax_in)s :return: %(ax_out)s """ # Create Correlation df _corr = data.corr() if ax is None: ax = plt.gca() # Generate Color Map _colormap = sns.diverging_palette(220, 10, as_cmap=True) # mask _mask = np.zeros_like(_corr) _mask[np.triu_indices_from(_mask)] = True # Generate Heat Map, allow annotations and place floats in map sns.heatmap(_corr, cmap=_colormap, annot=annotations, fmt=number_format, mask=_mask, ax=ax) # Adjust tick labels ax.set_xticks(ax.get_xticks()[:-1]) _yticklabels = ax.get_yticklabels()[1:] ax.set_yticks(ax.get_yticks()[1:]) ax.set_yticklabels(_yticklabels) return ax
[docs]@docstr @export def corrplot_bar(data: pd.DataFrame, target: str = None, columns: List[str] = None, corr_cutoff: float = rcParams['corr_cutoff'], corr_as_alpha: bool = False, xlim: tuple = (-1, 1), ax: plt.Axes = None): """ Correlation plot as barchart based on :func:`~hhpy.ds.get_df_corr` :param data: %(data)s :param target: %(corr_target)s :param columns: Columns for which to calculate the correlations, defaults to all numeric columns [optional] :param corr_cutoff: %(corr_cutoff)s :param corr_as_alpha: Whether to set alpha value of bars to scale with correlation [optional] :param xlim: xlim scale for plot, defaults to (-1, 1) to show the absolute scale of the correlations. set to None if you want the plot x limits to scale to the highest correlation values [optional] :param ax: %(ax_in)s :return: %(ax_out)s """ _df_corr = get_df_corr(data, target=target) _df_corr = _df_corr[_df_corr['corr_abs'] >= corr_cutoff] if target is None: _df_corr['label'] = concat_cols(_df_corr, ['col_0', 'col_1'], sep=' X ') else: _df_corr['label'] = _df_corr['col_1'] # filter columns (if applicable) if columns is not None: _columns = columns + [] if target is not None and target not in _columns: _columns.append(target) _df_corr = _df_corr[(_df_corr['col_0'].isin(_columns)) & (_df_corr['col_1'].isin(_columns))] # get colors _rgba_colors = np.zeros((len(_df_corr), 4)) # for red the first column needs to be one _rgba_colors[:, 0] = np.where(_df_corr['corr'] > 0., 0., 1.) # for blue the third column needs to be one _rgba_colors[:, 2] = np.where(_df_corr['corr'] > 0., 1., 0.) # the fourth column needs to be alphas if corr_as_alpha: _rgba_colors[:, 3] = _df_corr['corr_abs'].where(lambda _: _ > .1, .1) else: _rgba_colors[:, 3] = 1 if ax is None: ax = plt.gca() _rgba_colors = np.round(_rgba_colors, 2) _plot = ax.barh(_df_corr['label'], _df_corr['corr'], color=_rgba_colors) ax.invert_yaxis() if xlim: # noinspection PyTypeChecker ax.set_xlim(xlim) if target is not None: ax.set_title('Correlations with {} by Absolute Value'.format(target)) ax.set_xlabel('corr × {}'.format(target)) else: ax.set_title('Correlations by Absolute Value') return ax
[docs]@docstr @export def pairwise_corrplot(data: pd.DataFrame, corr_cutoff: float = rcParams['corr_cutoff'], col_wrap: int = 4, hue: str = None, hue_order: Union[list, str] = None, width: float = rcParams['fig_width'], height: float = rcParams['fig_height'], trendline: bool = True, alpha: float = .75, ax: plt.Axes = None, target: str = None, palette: Union[Mapping, Sequence, str] = rcParams['palette'], max_n: int = rcParams['max_n'], random_state: int = rcParams['max_n.random_state'], sample_warn: bool = rcParams['max_n.sample_warn'], return_fig_ax: bool = rcParams['return_fig_ax'], **kwargs) -> Union[tuple, None]: """ print a pairwise_corrplot to for all variables in the df, by default only plots those with a correlation coefficient of >= corr_cutoff :param data: %(data_novec)s :param corr_cutoff: %(corr_cutoff)s :param col_wrap: %(col_wrap)s :param hue: %(hue)s :param hue_order: %(order)s :param width: %(subplot_width)s :param height: %(subplot_height)s :param trendline: %(trendline)s :param alpha: %(alpha)s :param ax: %(ax_in)s :param target: %(corr_target)s :param palette: %(palette)s :param max_n: %(max_n)s :param random_state: %(max_n_random_state)s :param sample_warn: %(max_n_sample_warn)s :param return_fig_ax: %(return_fig_ax)s :param kwargs: other keyword arguments passed to pyplot.subplots :return: %(fig_ax_out)s """ # actual plot function def _f_plot(_f_x, _f_y, _f_data, _f_color, _f_color_trendline, _f_label, _f_ax): _data = _f_data.copy() # limit plot points if max_n is not None: if len(_data) > max_n: if sample_warn: warnings.warn( 'Limiting Scatter Plot to {:,} randomly selected points. ' 'Turn this off with max_n=None or suppress this warning with sample_warn=False.'.format( max_n)) _data = _data.sample(max_n, random_state=random_state) _f_ax.scatter(_f_x, _f_y, data=_data, alpha=alpha, color=_f_color, label=_f_label) if trendline: _f_ax.plot(_f_data[_f_x], lfit(_f_data[_f_x], _f_data[_f_y]), color=_f_color_trendline, linestyle=':') return _f_ax # avoid inplace operations _df = data.copy() _df_hues = pd.DataFrame() _df_corrs = pd.DataFrame() _hues = None if hue is not None: _hues = _get_ordered_levels(_df, hue, hue_order) _df_hues = {} _df_corrs = {} for _hue in _hues: _df_hue = _df[_df[hue] == _hue] _df_corr_hue = get_df_corr(_df_hue, target=target) _df_hues[_hue] = _df_hue.copy() _df_corrs[_hue] = _df_corr_hue.copy() # get df corr _df_corr = get_df_corr(_df, target=target) if corr_cutoff is not None: _df_corr = _df_corr[_df_corr['corr_abs'] >= corr_cutoff] # warning for empty df if len(_df_corr) == 0: warnings.warn('Correlation DataFrame is Empty. Do you need a lower corr_cutoff?') return None # edge case for less plots than ncols if len(_df_corr) < col_wrap: _ncols = len(_df_corr) else: _ncols = col_wrap # calculate nrows _nrows = int(np.ceil(len(_df_corr) / _ncols)) _figsize = (width * col_wrap, height * _nrows) if ax is None: fig, ax = plt.subplots(nrows=_nrows, ncols=_ncols, figsize=_figsize, **kwargs) else: fig = plt.gcf() _row = None _col = None for _it in range(len(_df_corr)): _col = _it % _ncols _row = _it // _ncols _x = _df_corr.iloc[_it]['col_1'] _y = _df_corr.iloc[_it]['col_0'] # so that target (if available) becomes y _corr = _df_corr.iloc[_it]['corr'] if _ncols == 1: _rows_prio = True else: _rows_prio = False _ax = get_subax(ax, _row, _col, rows_prio=_rows_prio) _ax.set_xlabel(_x) _ax.set_ylabel(_y) _ax.set_title('corr = {:.3f}'.format(_corr)) # hue if if hue is None: # actual plot _f_plot(_f_x=_x, _f_y=_y, _f_data=_df, _f_color=None, _f_color_trendline='k', _f_label=None, _f_ax=_ax) else: for _hue_it, _hue in enumerate(_hues): if isinstance(palette, Mapping): _color = palette[_hue] elif is_list_like(palette): _color = palette[_hue_it % len(palette)] else: _color = palette _df_hue = _df_hues[_hue] _df_corr_hue = _df_corrs[_hue].copy() # sometimes it can happen that the correlation is not possible to calculate because # one of those values does not change in the hue level # i.e. use try except try: _df_corr_hue = _df_corr_hue[_df_corr_hue['col_1'] == _x] _df_corr_hue = _df_corr_hue[_df_corr_hue['col_0'] == _y] _corr_hue = _df_corr_hue['corr'].iloc[0] except ValueError: _corr_hue = 0 # actual plot _f_plot(_f_x=_x, _f_y=_y, _f_data=_df_hue, _f_color=_color, _f_color_trendline=_color, _f_label='{} corr: {:.3f}'.format(_hue, _corr_hue), _f_ax=_ax) _ax.legend() # hide unused axis for __col in range(_col + 1, _ncols): get_subax(ax, _row, __col, rows_prio=False).set_axis_off() if return_fig_ax: return fig, ax else: plt.show()
[docs]@docstr @export def distplot(x: Union[Sequence, str], data: pd.DataFrame = None, hue: str = None, hue_order: Union[Sequence, str] = 'sorted', palette: Union[Mapping, Sequence, str] = None, linecolor: str = 'black', edgecolor: str = 'black', alpha: float = None, bins: Union[Sequence, int] = 40, perc: bool = None, top_nr: int = None, other_name: str = 'other', title: bool = True, title_prefix: str = '', std_cutoff: float = None, hist: bool = None, distfit: Union[str, bool, None] = 'kde', fill: bool = True, legend: bool = True, legend_loc: str = rcParams['distplot.legend_loc'], legend_space: float = rcParams['legend_outside.legend_space'], legend_ncol: int = 1, agg_func: str = 'mean', number_format: str = rcParams['float_format'], kde_steps: int = 1000, max_n: int = 100000, random_state: int = None, sample_warn: bool = True, xlim: Sequence = None, linestyle: str = None, label_style: str = rcParams['distplot.label_style'], x_offset_perc: float = .025, ax: plt.Axes = None, **kwargs) -> plt.Axes: """ Similar to seaborn.distplot but supports hues and some other things. Plots a combination of a histogram and a kernel density estimation. :param x: the name of the variable(s) in data or vector data, if data is provided and x is a list of columns the DataFrame is automatically melted and the newly generated column used as hue. i.e. you plot the distributions of multiple columns on the same axis :param data: %(data)s :param hue: %(hue)s :param hue_order: %(order)s :param palette: %(palette)s :param linecolor: Color of the kde fit line, overwritten with palette by hue level if hue is specified [optional] :param edgecolor: Color of the histogram edges [optional] :param alpha: %(alpha)s :param bins: %(bins)s :param perc: Whether to display the y-axes as percentage, if false count is displayed. Defaults if hue: True, else False [optional] :param top_nr: limit hue to top_nr levels using hhpy.ds.top_n, the rest will be cast to other [optional] :param other_name: name of the other group created by hhpy.ds.top_n [optional] :param title: whether to set the plot title equal to x's name [optional] :param title_prefix: prefix to be used in plot title [optional] :param std_cutoff: automatically cutoff data outside of the std_cutoff standard deviations range, by default this is off but a recommended value for a good visual experience without outliers is 3 [optional] :param hist: whether to show the histogram, default False if hue else True [optional] :param distfit: one of %(distplot__distfit)s. If 'kde' fits a kernel density distribution to the data. If gauss fits a gaussian distribution with the observed mean and std to the data. [optional] :param fill: whether to fill the area under the distfit curve, ignored if hist is True [optional] :param legend: %(legend)s :param legend_loc: %(legend_loc)s :param legend_space: %(legend_space)s :param legend_ncol: %(legend_ncol)s :param agg_func: one of ['mean', 'median']. The agg function used to find the center of the distribution [optional] :param number_format: %(number_format)s :param kde_steps: %(kde_steps)s :param max_n: %(max_n)s :param random_state: %(max_n_random_state)s :param sample_warn: %(max_n_sample_warn)s :param xlim: %(xlim)s :param linestyle: %(linestyle)s :param label_style: one of ['mu_sigma', 'plain']. If mu_sigma then the mean (or median) and std value are displayed inside the label [optional] :param x_offset_perc: the amount whitespace to display next to x_min and x_max in percent of x_range [optional] :param ax: %(ax_in)s :param kwargs: additional keyword arguments passed to pyplot.plot :return: %(ax_out)s """ # -- asserts if distfit not in validations['distplot__distfit']: raise ValueError(f"distfit must be one of {validations['distplot__distfit']}") # -- defaults if palette is None: palette = rcParams['palette'] if not top_nr: top_nr = None # case: vector data if data is None: if hasattr(x, 'name'): _x_name = x.name else: _x_name = 'x' _df = pd.DataFrame({_x_name: x}) x = _x_name # data: DataFrame else: _df = data.copy() # avoid inplace operations del data if is_list_like(x) and len(x) > 1: hue_order = assert_list(x) title = False hue = '__variable__' x = '__value__' _df = pd.melt(_df, value_vars=x, value_name=x, var_name=hue) # handle hue and default values if hue is None: if perc is None: perc = False if hist is None: hist = True if alpha is None: alpha = .75 else: _df = _df[~_df[hue].isnull()] if perc is None: perc = True if hist is None: hist = False if alpha is None: alpha = .5 # case more than max_n samples: take a random sample for calc speed if max_n and (len(_df) > max_n): if sample_warn: warnings.warn( f"Limiting samples to {max_n:,} for calc speed. Turn this off with max_n=None or suppress this warning " "with sample_warn=False.") _df = _df.sample(max_n, random_state=random_state) # the actual plot def _f_distplot(_f_x, _f_data, _f_x_label, _f_facecolor, _f_distfit_color, _f_bins, _f_std_cutoff, _f_xlim, _f_distfit_line, _f_ax, _f_ax2, _f_fill): # make a copy to avoid inplace operations _df_i = _f_data.copy() # best fit of data _mu = _df_i.agg({_f_x: agg_func})[0] _sigma = _df_i.agg({_f_x: 'std'})[0] # apply sigma cutoff if (_f_std_cutoff is not None) or (_f_xlim is not None): if _f_xlim is not None: __x_min = _f_xlim[0] __x_max = _f_xlim[1] elif is_list_like(_f_std_cutoff): __x_min = _f_std_cutoff[0] __x_max = _f_std_cutoff[1] else: __x_min = _mu - _f_std_cutoff * _sigma __x_max = _mu + _f_std_cutoff * _sigma _df_i = _df_i[ (_df_i[_f_x] >= __x_min) & (_df_i[_f_x] <= __x_max) ] # for plot trimming _x_mins.append(_df_i[x].min()) _x_maxs.append(_df_i[x].max()) # handle label try: _mu_label = format(_mu, number_format) except ValueError: _mu_label = '0' try: _sigma_label = format(_sigma, number_format) except ValueError: _sigma_label = '0' if agg_func == 'mean': _mu_symbol = r'\ \mu' else: _mu_symbol = r'\ \nu' if label_style == 'mu_sigma': _label = r'{}: $ {}={},\ \sigma={}$'.format(_f_x_label, _mu_symbol, _mu_label, _sigma_label) else: _label = _f_x_label # plot histogram if hist: _hist_n, _hist_bins = _f_ax.hist(_df_i[_f_x], _f_bins, density=perc, facecolor=_f_facecolor, edgecolor=edgecolor, alpha=alpha, label=_label)[:2] _label_2 = '__nolegend___' if _f_distfit_line is None: _f_distfit_line = '--' else: _hist_n = None _hist_bins = None _label_2 = _label + '' if _f_distfit_line is None: _f_distfit_line = '-' # plot distfit if distfit: # if a histogram was plot on the primary axis, the distfit goes on the secondary axis if hist: _ax = _f_ax2 else: _ax = _f_ax if distfit == 'gauss': # add a 'best fit' line __x = _f_bins _y = stats.norm.pdf(_f_bins, _mu, _sigma) # _hist_bins _ax.plot(__x, _y, linestyle=_f_distfit_line, color=_f_distfit_color, alpha=alpha, linewidth=2, label=_label_2, **kwargs) elif distfit == 'kde': _kde = kde(x=_f_x, df=_df_i, x_steps=kde_steps)[0] __x = _kde[_f_x] _y = _kde['value'] _ax.plot(__x, _y, linestyle=_f_distfit_line, color=_f_distfit_color, alpha=alpha, linewidth=2, label=_label_2, **kwargs) if not hist: _ax.set_ylabel('pdf') if _f_fill: # noinspection PyUnboundLocalVariable _ax.fill_between(__x, _y, color=_f_facecolor, alpha=alpha) _f_ax2.get_yaxis().set_visible(False) if perc and hist: _y_max = np.max(_hist_n) / np.sum(_hist_n) * 100 _y_ticklabels = list(_f_ax.get_yticks()) _y_ticklabels = [float(_) for _ in _y_ticklabels] _factor = _y_max / np.nanmax(_y_ticklabels) if np.isnan(_factor): _factor = 1 _y_ticklabels = [format(int(_ * _factor), ',') for _ in _y_ticklabels] _f_ax.set_yticklabels(_y_ticklabels) _f_ax.set_ylabel('%') elif hist: _f_ax.set_ylabel('count') # adjust xlims if necessary _xlim = list(_f_ax.get_xlim()) # here _df is used to access the 'parent' DataFrame with all hue levels if _xlim[0] <= _plot_x_min: _xlim[0] = _plot_x_min if _xlim[1] >= _plot_x_max: _xlim[1] = _plot_x_max _f_ax.set_xlim(_xlim) return _f_ax, _f_ax2 # -- preparing the data frame # drop nan values _df = _df[np.isfinite(_df[x])] # init plot if ax is None: ax = plt.gca() ax2 = ax.twinx() # for plot trimming _x_mins = [] _x_maxs = [] if hue is None: # handle x limits if xlim is not None: _x_min = xlim[0] _x_max = xlim[1] elif std_cutoff is not None: _x_min = _df[x].mean() - _df[x].std() * std_cutoff _x_max = _df[x].mean() + _df[x].std() * std_cutoff else: _x_min = _df[x].min() _x_max = _df[x].max() # edge case if _x_min == _x_max: warnings.warn('Distribution min and max are equal') _x_min -= 1 _x_max += 1 # handle bins if not is_list_like(bins): _x_step = (_x_max - _x_min) / bins _bins = np.arange(_x_min, _x_max + _x_step, _x_step) _plot_x_min = _df[x].min() - _x_step _plot_x_max = _df[x].max() + _x_step else: _bins = bins _plot_x_min = np.min(bins) _plot_x_max = np.max(bins) # handle palette / color if isinstance(palette, Mapping): _color = list(palette.values())[0] elif is_list_like(palette): _color = palette[0] else: _color = palette # plot ax, ax2 = _f_distplot(_f_x=x, _f_data=_df, _f_x_label=x, _f_facecolor=_color, _f_distfit_color=linecolor, _f_bins=_bins, _f_std_cutoff=std_cutoff, _f_xlim=xlim, _f_distfit_line=linestyle, _f_ax=ax, _f_ax2=ax2, _f_fill=fill) else: # no hue # group values outside of top_n to other_name if top_nr is not None: _hues = _df[hue].value_counts().reset_index().sort_values(by=[hue, 'index'])['index'].tolist() if (top_nr + 1) < len(_hues): # the plus 1 is there to avoid the other group having exactly 1 entry _hues = pd.Series(_hues)[0:top_nr] _df[hue] = np.where(_df[hue].isin(_hues), _df[hue], other_name) _df[hue] = _df[hue].astype('str') _hues = list(_hues) + [other_name] # parse hue order else: _hues = _get_ordered_levels(_df, hue, hue_order, x) # find shared _x_min ; _x_max if xlim is not None: _std_cutoff_hues = None _x_min = xlim[0] _x_max = xlim[1] elif std_cutoff is None: _std_cutoff_hues = None _x_min = _df[x].min() _x_max = _df[x].max() else: _df_agg = _df.groupby(hue).agg({x: ['mean', 'std']}).reset_index() _df_agg.columns = [hue, 'mean', 'std'] _df_agg['x_min'] = _df_agg['mean'] - _df_agg['std'] * std_cutoff _df_agg['x_max'] = _df_agg['mean'] + _df_agg['std'] * std_cutoff _df_agg['x_range'] = _df_agg['x_max'] - _df_agg['x_min'] _x_min = _df_agg['x_min'].min() _x_max = _df_agg['x_max'].max() _std_cutoff_hues = [_x_min, _x_max] # handle bins _x_step = (_x_max - _x_min) / bins _plot_x_min = _df[x].min() - _x_step _plot_x_max = _df[x].max() + _x_step _bins = np.arange(_x_min, _x_max + _x_step, _x_step) # loop hues for _it, _hue in enumerate(_hues): if isinstance(palette, Mapping): _color = palette[_hue] elif is_list_like(palette): _color = palette[_it] else: _color = palette if isinstance(linestyle, Mapping): _linestyle = linestyle[_hue] elif is_list_like(linestyle): _linestyle = linestyle[_it] else: _linestyle = linestyle _df_hue = _df[_df[hue] == _hue] # one plot per hue ax, ax2 = _f_distplot(_f_x=x, _f_data=_df_hue, _f_x_label=_hue, _f_facecolor=_color, _f_distfit_color=_color, _f_bins=_bins, _f_std_cutoff=_std_cutoff_hues, _f_xlim=xlim, _f_distfit_line=_linestyle, _f_ax=ax, _f_ax2=ax2, _f_fill=fill) # -- postprocessing # handle legend if legend: if legend_loc in ['bottom', 'right']: legend_outside(ax, loc=legend_loc, legend_space=legend_space, ncol=legend_ncol) legend_outside(ax2, loc=legend_loc, legend_space=legend_space, ncol=legend_ncol) else: _, _labels = ax.get_legend_handles_labels() if len(_labels) > 0: ax.legend(loc=legend_loc, ncol=legend_ncol) _, _labels = ax2.get_legend_handles_labels() if len(_labels) > 0: ax2.legend(loc=legend_loc, ncol=legend_ncol) # handle title if title: _title = f"{title_prefix}{x}" if hue is not None: _title += f" by {hue}" ax.set_title(_title) # handle xlim if xlim is not None and xlim: # noinspection PyTypeChecker ax.set_xlim(xlim) else: _x_min = np.min(_x_mins) _x_max = np.max(_x_maxs) _x_offset = (_x_max - _x_min) * x_offset_perc # noinspection PyTypeChecker ax.set_xlim((_x_min - _x_offset, _x_max + _x_offset)) return ax
[docs]@docstr @export def hist_2d(x: str, y: str, data: pd.DataFrame, bins: int = 100, std_cutoff: int = 3, cutoff_perc: float = .01, cutoff_abs: float = 0, cmap: str = 'rainbow', ax: plt.Axes = None, color_sigma: str = 'xkcd:red', draw_sigma: bool = True, **kwargs) -> plt.Axes: """ generic 2d histogram created by splitting the 2d area into equal sized cells, counting data points in them and drawn using pyplot.pcolormesh :param x: %(x)s :param y: %(y)s :param data: %(data)s :param bins: %(bins)s :param std_cutoff: %(std_cutoff)s :param cutoff_perc: if less than this percentage of data points is in the cell then the data is ignored [optional] :param cutoff_abs: if less than this amount of data points is in the cell then the data is ignored [optional] :param cmap: %(cmap)s :param ax: %(ax_in)s :param color_sigma: color to highlight the sigma range in, must be a valid pyplot.plot color [optional] :param draw_sigma: whether to highlight the sigma range [optional] :param kwargs: other keyword arguments passed to pyplot.pcolormesh [optional] :return: %(ax_out)s """ _df = data.copy() del data if std_cutoff is not None: _x_min = _df[x].mean() - _df[x].std() * std_cutoff _x_max = _df[x].mean() + _df[x].std() * std_cutoff _y_min = _df[y].mean() - _df[y].std() * std_cutoff _y_max = _df[y].mean() + _df[y].std() * std_cutoff # x or y should be in std range _df = _df[ ((_df[x] >= _x_min) & (_df[x] <= _x_max) & (_df[y] >= _y_min) & (_df[y] <= _y_max)) ].reset_index(drop=True) _x = _df[x] _y = _df[y] # Estimate the 2D histogram _hist, _x_edges, _y_edges = np.histogram2d(_x, _y, bins=bins) # hist needs to be rotated and flipped _hist = np.rot90(_hist) _hist = np.flipud(_hist) # Mask too small counts if cutoff_abs is not None: _hist = np.ma.masked_where(_hist <= cutoff_abs, _hist) if cutoff_perc is not None: _hist = np.ma.masked_where(_hist <= _hist.max() * cutoff_perc, _hist) # Plot 2D histogram using pcolor if ax is None: ax = plt.gca() _mappable = ax.pcolormesh(_x_edges, _y_edges, _hist, cmap=cmap, **kwargs) ax.set_xlabel(x) ax.set_ylabel(y) _cbar = plt.colorbar(mappable=_mappable, ax=ax) _cbar.ax.set_ylabel('count') # draw ellipse to mark 1 sigma area if draw_sigma: _ellipse = patches.Ellipse(xy=(_x.median(), _y.median()), width=_x.std(), height=_y.std(), edgecolor=color_sigma, fc='None', lw=2, ls=':') ax.add_patch(_ellipse) return ax
[docs]@docstr @export def paired_plot(data: pd.DataFrame, cols: Sequence, color: str = None, cmap: str = None, alpha: float = 1, **kwargs) -> sns.FacetGrid: """ create a facet grid to analyze various aspects of correlation between two variables using seaborn.PairGrid :param data: %(data)s :param cols: list of exactly two variables to be compared :param color: %(color)s :param cmap: %(cmap)s :param alpha: %(alpha)s :param kwargs: other arguments passed to seaborn.PairGrid :return: seaborn FacetGrid object with the plots on it """ def _f_corr(_f_x, _f_y, _f_s=10, **_f_kwargs): # Calculate the value _coef = np.corrcoef(_f_x, _f_y)[0][1] # Make the label _label = r'$\rho$ = ' + str(round(_coef, 2)) # Add the label to the plot _ax = plt.gca() _ax.annotate(_label, xy=(0.2, 0.95 - (_f_s - 10.) / 10.), size=20, xycoords=_ax.transAxes, **_f_kwargs) # Create an instance of the PairGrid class. _grid = sns.PairGrid(data=data, vars=cols, **kwargs) # Map a scatter plot to the upper triangle _grid = _grid.map_upper(plt.scatter, alpha=alpha, color=color) # Map a corr coef _grid = _grid.map_upper(_f_corr) # density = True might not be supported in older versions of seaborn / matplotlib _grid = _grid.map_diag(plt.hist, bins=30, color=color, alpha=alpha, edgecolor='k', density=True) # Map a density plot to the lower triangle _grid = _grid.map_lower(sns.kdeplot, cmap=cmap, alpha=alpha) # add legend _grid.add_legend() return _grid
[docs]@export def q_plim(s: pd.Series, q_min: float = .1, q_max: float = .9, offset_perc: float = .1, limit_min_max: bool = False, offset=True) -> tuple: """ returns quick x limits for plotting (cut off data not in q_min to q_max quantile) :param s: pandas Series to truncate :param q_min: lower bound quantile [optional] :param q_max: upper bound quantile [optional] :param offset_perc: percentage of offset to the left and right of the quantile boundaries :param limit_min_max: whether to truncate the plot limits at the data limits :param offset: whether to apply the offset :return: a tuple containing the x limits """ _lower_bound = floor_signif(s.quantile(q=q_min)) _upper_bound = ceil_signif(s.quantile(q=q_max)) if _upper_bound == _lower_bound: _upper_bound = s.max() _lower_bound = s.min() if limit_min_max: if _upper_bound > s.max(): _upper_bound = s.max() if _lower_bound < s.min(): _lower_bound = s.min() if offset: _offset = (_upper_bound - _lower_bound) * offset_perc else: _offset = 0 return _lower_bound - _offset, _upper_bound + _offset
[docs]@docstr @export def levelplot(data: pd.DataFrame, level: str, cols: Union[list, str], hue: str = None, order: Union[list, str] = None, hue_order: Union[list, str] = None, func: Callable = distplot, summary_title: bool = True, level_title: bool = True, do_print: bool = False, width: int = None, height: int = None, return_fig_ax: bool = None, kwargs_subplots_adjust: Mapping = None, kwargs_summary: Mapping = None, **kwargs) -> Union[None, tuple]: """ Plots a plot for each specified column for each level of a certain column plus a summary plot :param data: %(data)s :param level: the name of the column to split the plots by, must be in data :param cols: the columns to create plots for, defaults to all numeric columns [optional] :param hue: %(hue)s :param order: %(order)s :param hue_order: %(order)s :param func: function to use for plotting, must support 1 positional argument, data, hue, ax and kwargs [optional] :param summary_title: whether to automatically set the summary plot title [optional] :param level_title: whether to automatically set the level plot title [optional] :param do_print: %(do_print)s :param width: %(subplot_width)s :param height: %(subplot_height)s :param return_fig_ax: %(return_fig_ax)s :param kwargs_subplots_adjust: other keyword arguments passed to pyplot.subplots_adjust [optional] :param kwargs_summary: other keyword arguments passed to summary distplot, if None uses kwargs [optional] :param kwargs: other keyword arguments passed to func [optional] :return: see return_fig_ax """ # -- init # - defaults if kwargs_summary is None: kwargs_summary = kwargs if width is None: width = rcParams['fig_width'] if height is None: height = rcParams['fig_height'] if return_fig_ax is None: return_fig_ax = rcParams['return_fig_ax'] # handle no inplace data = pd.DataFrame(data).copy() if cols is None: cols = data.select_dtypes(include=np.number) _levels = _get_ordered_levels(data=data, level=level, order=order) if hue is not None: _hues = _get_ordered_levels(data=data, level=hue, order=hue_order) _hue_str = ' by {}'.format(hue) else: _hue_str = '' _nrows = len(cols) _ncols = len(_levels) + 1 _it_max = _nrows * _ncols fig, ax = plt.subplots(nrows=_nrows, ncols=_ncols, figsize=(_ncols * width, _nrows * height)) _it = -1 for _col_i, _col in enumerate(cols): _ax_summary = get_subax(ax, _col_i, 0, rows_prio=False) # always plot to col 0 of current row # summary plot func(_col, data=data, hue=level, ax=_ax_summary, **kwargs_summary) if summary_title: _ax_summary.set_title('{} by {}'.format(_col, level)) for _level_i, _level in enumerate(_levels): _it += 1 if do_print: progressbar(_it, _it_max, print_prefix='{}_{}'.format(_col, _level)) _df_level = data[data[level] == _level] _ax = get_subax(ax, _col_i, _level_i + 1) # level plot func(_col, data=_df_level, hue=hue, ax=_ax, **kwargs) if level_title: _ax.set_title('{}{} - {}={}'.format(_col, _hue_str, level, _level)) if kwargs_subplots_adjust is not None: plt.subplots_adjust(**kwargs_subplots_adjust) if do_print: progressbar() if return_fig_ax: return fig, ax else: plt.show()
[docs]@docstr @export def get_legends(ax: plt.Axes = None) -> list: """ returns all legends on a given axis, useful if you have a secaxis :param ax: %(ax_in)s :return: list of legends """ if ax is None: ax = plt.gca() return [_ for _ in ax.get_children() if isinstance(_, Legend)]
# a plot to compare four components of a DataFrame def four_comp_plot(data, x_1, y_1, x_2, y_2, hue_1=None, hue_2=None, lim=None, return_fig_ax=None, **kwargs): # you can pass the hues to use or if none are given the default ones (std,plus/minus) are used # you can pass xlim and ylim or assume default (4 std) # four components, ie 2 x 2 if return_fig_ax is None: return_fig_ax = rcParams['return_fig_ax'] if lim is None: lim = {'x_1': 'default', 'x_2': 'default', 'y_1': 'default', 'y_2': 'default'} _nrows = 2 _ncols = 2 # init plot fig, ax = plt.subplots(ncols=_ncols, nrows=_nrows) # make a copy yo avoid inplace operations _df_plot = data.copy() _x_std = _df_plot[x_1].std() _y_std = _df_plot[y_1].std() # type 1: split by size in relation to std if hue_1 is None: _df_plot['std'] = np.where((np.abs(_df_plot[x_1]) <= 1 * _x_std) & (np.abs(_df_plot[y_1]) <= 1 * _y_std), '0_std', 'Null') _df_plot['std'] = np.where((np.abs(_df_plot[x_1]) > 1 * _x_std) | (np.abs(_df_plot[y_1]) > 1 * _y_std), '1_std', _df_plot['std']) _df_plot['std'] = np.where((np.abs(_df_plot[x_1]) > 2 * _x_std) | (np.abs(_df_plot[y_1]) > 2 * _y_std), '2_std', _df_plot['std']) _df_plot['std'] = np.where((np.abs(_df_plot[x_1]) > 3 * _x_std) | (np.abs(_df_plot[y_1]) > 3 * _y_std), '3_std', _df_plot['std']) _df_plot['std'] = _df_plot['std'].astype('category') hue_1 = 'std' # type 2: split by plus minus if hue_2 is None: _df_plot['plus_minus'] = np.where((_df_plot[x_1] <= 0) & (_df_plot[y_1] <= 0), '- -', 'Null') _df_plot['plus_minus'] = np.where((_df_plot[x_1] <= 0) & (_df_plot[y_1] > 0), '- +', _df_plot['plus_minus']) _df_plot['plus_minus'] = np.where((_df_plot[x_1] > 0) & (_df_plot[y_1] <= 0), '+ -', _df_plot['plus_minus']) _df_plot['plus_minus'] = np.where((_df_plot[x_1] > 0) & (_df_plot[y_1] > 0), '+ +', _df_plot['plus_minus']) _df_plot['plus_minus'] = _df_plot['plus_minus'].astype('category') hue_2 = 'plus_minus' _xs = [x_1, x_2] _ys = [y_1, y_2] _hues = [hue_1, hue_2] _xlims = [lim['x_1'], lim['x_2']] _ylims = [lim['y_1'], lim['y_2']] for _row in range(_nrows): for _col in range(_ncols): # init _ax = get_subax(ax, _row, _col) _x_name = _xs[_col] _y_name = _ys[_col] _hue = _hues[_row] _x = _df_plot[_x_name] _y = _df_plot[_y_name] # scatterplot _ax = sns.scatterplot(data=_df_plot, x=_x_name, y=_y_name, hue=_hue, marker='.', ax=_ax, **kwargs) # grid 0 line _ax.axvline(0, color='k', alpha=.5, linestyle=':') _ax.axhline(0, color='k', alpha=.5, linestyle=':') # title _ax.set_title('%s vs %s, hue: %s' % (_x_name, _y_name, _hue)) # labels _ax.set_xlabel(_x_name) _ax.set_ylabel(_y_name) # set limits to be 4 std range if _xlims[_col] == 'default': _x_low = -_x.std() * 4 if _x.min() > _x_low: _x_low = _x.min() _x_high = _x.std() * 4 if _x.max() < _x_high: _x_high = _x.max() _ax.set_xlim([_x_low, _x_high]) if _ylims[_col] == 'default': _y_low = -_y.std() * 4 if _y.min() > _y_low: _y_low = _y.min() _y_high = _y.std() * 4 if _y.max() < _y_high: _y_high = _y.max() _ax.set_ylim([_y_low, _y_high]) if return_fig_ax: return fig, ax else: plt.tight_layout() plt.show()
[docs]@docstr @export def facet_wrap(func: Callable, data: pd.DataFrame, facet: Union[list, str], *args, facet_type: str = None, col_wrap: int = 4, width: int = None, height: int = None, catch_error: bool = True, return_fig_ax: bool = None, sharex: bool = False, sharey: bool = False, show_xlabel: bool = True, x_tick_rotation: int = None, y_tick_rotation: int = None, ax_title: str = 'set', order: Union[list, str] = None, subplots_kws: Mapping = None, **kwargs): """ modeled after r's facet_wrap function. Wraps a number of subplots onto a 2d grid of subplots while creating a new line after col_wrap columns. Uses a given plot function and creates a new plot for each facet level. :param func: Any plot function. Must support keyword arguments data and ax :param data: %(data)s :param facet: The column / list of columns to facet over. :param args: passed to func :param facet_type: one of ['group', 'cols', None]. If group facet is treated as the column creating the facet levels and a subplot is created for each level. If cols each facet is in turn passed as the first positional argument to the plot function func. If None then the facet_type is inferred: a single facet value will be treated as group and multiple facet values will be treated as cols. :param col_wrap: %(col_wrap)s :param width: %(subplot_width)s :param height: %(subplot_height)s :param catch_error: whether to keep going in case of an error being encountered in the plot function [optional] :param return_fig_ax: %(return_fig_ax)s :param sharex: %(sharex)s :param sharey: %(sharey)s :param show_xlabel: whether to show the x label for each subplot :param x_tick_rotation: x tick rotation for each subplot :param y_tick_rotation: y tick rotation for each subplot :param ax_title: one of ['set','hide'], if set sets axis title to facet name, if hide forcefully hids axis title :param order: %(order)s :param subplots_kws: other keyword arguments passed to pyplot.subplots :param kwargs: other keyword arguments passed to func :return: %(fig_ax_out)s **Examples** Check out the `example notebook <https://colab.research.google.com/drive/1bAEFRoWJgwPzkEqOoPBHVX849qQjxLYC>`_ """ # -- init # - defaults if width is None: width = rcParams['fig_width'] if height is None: height = rcParams['fig_height'] if return_fig_ax is None: return_fig_ax = rcParams['return_fig_ax'] if subplots_kws is None: subplots_kws = {} # - handle no inplace _df = data.copy() del data _facet = None _row = None _col = None # if it is a list of column names we will melt the df together if facet_type is None: if is_list_like(facet): facet_type = 'cols' else: facet_type = 'group' # process the facets if facet_type == 'cols': _facets = facet else: _df['_facet'] = concat_cols(_df, facet) facet = '_facet' _facets = _get_ordered_levels(_df, facet, order) # init a grid if len(_facets) > col_wrap: _ncols = col_wrap _nrows = int(np.ceil(len(_facets) / _ncols)) else: _ncols = len(_facets) _nrows = 1 fig, ax = plt.subplots(ncols=_ncols, nrows=_nrows, figsize=(width * _ncols, height * _nrows), **subplots_kws) _ax_list = ax_as_list(ax) # loop facets for _it, _facet in enumerate(_facets): _col = _it % _ncols _row = _it // _ncols _ax = _ax_list[_it] # get df facet _facet = _facets[_it] # for list set target to be in line with facet to ensure proper naming if facet_type == 'cols': _df_facet = _df.copy() _args = assert_list(_facet) + list(args) else: _df_facet = _df[_df[facet] == _facet] _args = args # apply function on target (try catch) if catch_error: try: func(*_args, data=_df_facet, ax=_ax, **kwargs) except Exception as _exc: warnings.warn('could not plot facet {} with exception {}, skipping. ' 'For details use catch_error=False'.format(_exc, _facet)) _ax.set_axis_off() continue else: func(*_args, data=_df_facet, ax=_ax, **kwargs) # set axis title to facet or hide it or do nothing (depending on preference) if ax_title == 'set': _ax.set_title(_facet) elif ax_title == 'hide': _ax.set_title('') # tick rotation if x_tick_rotation is not None: _ax.xaxis.set_tick_params(rotation=x_tick_rotation) if y_tick_rotation is not None: _ax.yaxis.set_tick_params(rotation=y_tick_rotation) # hide x label (if appropriate) if not show_xlabel: _ax.set_xlabel('') # hide unused axes for __col in range(_col + 1, _ncols): ax[_row, __col].set_axis_off() # share xy if sharex or sharey: share_xy(ax, x=sharex, y=sharey) if return_fig_ax: return fig, ax else: plt.show()
[docs]@docstr @export def get_subax(ax: Union[plt.Axes, np.ndarray], row: int = None, col: int = None, rows_prio: bool = True) -> plt.Axes: """ shorthand to get around the fact that ax can be a 1D array or a 2D array (for subplots that can be 1x1,1xn,nx1) :param ax: %(ax_in)s :param row: %(row)s :param col: %(col)s :param rows_prio: decides if to use row or col in case of a 1xn / nx1 shape (False means cols get priority) :return: %(ax_out)s """ if isinstance(ax, np.ndarray): _dims = len(ax.shape) else: _dims = 0 if _dims == 0: _ax = ax elif _dims == 1: if rows_prio: _ax = ax[row] else: _ax = ax[col] else: _ax = ax[row, col] return _ax
[docs]@docstr @export def ax_as_list(ax: Union[plt.Axes, np.ndarray]) -> list: """ takes any Axes and turns them into a list :param ax: %(ax_in)s :return: List containing the subaxes """ if isinstance(ax, np.ndarray): _dims = len(ax.shape) else: _dims = 0 if _dims == 0: _ax_list = [ax] elif _dims == 1: _ax_list = list(ax) else: _ax_list = list(ax.flatten()) return _ax_list
[docs]@docstr @export def ax_as_array(ax: Union[plt.Axes, np.ndarray]) -> np.ndarray: """ takes any Axes and turns them into a numpy 2D array :param ax: %(ax_in)s :return: Numpy 2D array containing the subaxes """ if isinstance(ax, np.ndarray): if len(ax.shape) == 2: return ax else: return ax.reshape(-1, 1) else: return np.array([ax]).reshape(-1, 1)
# bubble plot def bubbleplot(x, y, hue, s, text=None, text_as_label=False, data=None, s_factor=250, palette=None, hue_order=None, x_range_factor=5, y_range_factor=5, show_std=False, ax=None, legend_loc='right', text_kws=None): if palette is None: palette = rcParams['palette'] if text_kws is None: text_kws = {} if ax is None: ax = plt.gca() _df = data.copy() _df = _df[~((_df[x].isnull()) | (_df[y].isnull()) | (_df[s].isnull()))].reset_index(drop=True) if hue_order is not None: _df['_sort'] = _df[hue].apply(lambda _: hue_order.index(_)) _df = _df.sort_values(by=['_sort']) _df = _df.reset_index(drop=True) _x = _df[x] _y = _df[y] _s = _df[s] * s_factor if text is not None: _text = _df[text] else: _text = pd.Series() if isinstance(palette, Mapping): _df['_color'] = _df[hue].apply(lambda _: palette[_]) elif is_list_like(palette): _df['_color'] = palette[:_df.index.max() + 1] else: _df['color'] = palette # draw ellipse to mark 1 sigma area if show_std: _x_min = None _x_max = None _y_min = None _y_max = None for _index, _row in _df.iterrows(): _ellipse = patches.Ellipse(xy=(_row[x], _row[y]), width=_row[x + '_std'] * 2, height=_row[y + '_std'] * 2, edgecolor=_row['_color'], fc='None', lw=2, ls=':') ax.add_patch(_ellipse) _x_min_i = _row[x] - _row[x + '_std'] * 1.05 _x_max_i = _row[x] + _row[x + '_std'] * 1.05 _y_min_i = _row[y] - _row[y + '_std'] * 1.05 _y_max_i = _row[y] + _row[y + '_std'] * 1.05 if _x_min is None: _x_min = _x_min_i elif _x_min_i < _x_min: _x_min = _x_min_i if _x_max is None: _x_max = _x_max_i elif _x_max_i > _x_max: _x_max = _x_max_i if _y_min is None: _y_min = _y_min_i elif _y_min_i < _y_min: _y_min = _y_min_i if _y_max is None: _y_max = _y_max_i elif _y_max_i > _y_max: _y_max = _y_max_i else: # scatter for bubbles ax.scatter(x=_x, y=_y, s=_s, label='__nolegend__', facecolor=_df['_color'], edgecolor='black', alpha=.75) _x_range = _x.max() - _x.min() _x_min = _x.min() - _x_range / x_range_factor _x_max = _x.max() + _x_range / x_range_factor _y_range = _y.max() - _y.min() _y_min = _y.min() - _y_range / y_range_factor _y_max = _y.max() + _y_range / y_range_factor # plot fake data for legend (a little hacky) if text_as_label: _xlim_before = ax.get_xlim() for _it in range(len(_x)): _label = _text[_it] # fake data ax.scatter(x=-9999, y=_y[_it], label=_label, facecolor=_df['_color'].loc[_it], s=200, edgecolor='black', alpha=.75) ax.set_xlim(_xlim_before) if (text is not None) and (not text_as_label): for _it in range(len(_text)): _ = '' if (not np.isnan(_x.iloc[_it])) and (not np.isnan(_y.iloc[_it])): ax.text(x=_x.iloc[_it], y=_y.iloc[_it], s=_text.iloc[_it], horizontalalignment='center', verticalalignment='center', **text_kws) # print(_x_min,_x_max) ax.set_xlim(_x_min, _x_max) ax.set_ylim(_y_min, _y_max) ax.set_xlabel(_x.name) ax.set_ylabel(_y.name) if text_as_label and (legend_loc in ['bottom', 'right']): legend_outside(ax, loc=legend_loc) else: ax.legend(loc=legend_loc) # title ax.set_title(hue) return ax def bubblecountplot(x, y, hue, data, agg_function='median', show_std=True, top_nr=None, n_quantiles=10, other_name='other', dropna=True, float_format='.2f', text_end='', **kwargs): _df = data.copy() if dropna: _df = _df[~_df[hue].isnull()] if hue in _df.select_dtypes(include=np.number): _n = n_quantiles if top_nr is not None: if top_nr < n_quantiles: _n = top_nr _df[hue] = quantile_split(_df[hue], _n) if top_nr is not None: _df[hue] = top_n_coding(_df[hue], n=top_nr, other_name=other_name) # handle na _df[x] = _df[x].fillna(_df[x].dropna().agg(agg_function)) _df[y] = _df[y].fillna(_df[y].dropna().agg(agg_function)) # build agg dict _df['_count'] = 1 _df = _df.groupby([hue]).agg({x: [agg_function, 'std'], y: [agg_function, 'std'], '_count': 'count'}).reset_index() if x != y: _columns = [hue, x, x + '_std', y, y + '_std', '_count'] else: _columns = [hue, x, x + '_std', '_count'] _df.columns = _columns _df['_perc'] = _df['_count'] / _df['_count'].sum() * 100 _df['_count_text'] = _df.apply(lambda _: "{:,}".format(_['_count']), axis=1) _df['_perc_text'] = np.round(_df['_perc'], 2) _df['_perc_text'] = _df['_perc_text'].astype(str) + '%' if show_std: _df['_text'] = _df[hue].astype(str) + '(' + _df['_count_text'] + ')' + '\n' \ + 'x:' + _df[x].apply(lambda _: format(_, float_format)) + r'$\pm$' + _df[x + '_std'].apply( lambda _: format(_, float_format)) + '\n' \ + 'y:' + _df[y].apply(lambda _: format(_, float_format)) + r'$\pm$' + _df[y + '_std'].apply( lambda _: format(_, float_format)) else: _df['_text'] = _df[hue].astype(str) + '\n' + _df['_count_text'] + '\n' + _df['_perc_text'] _df['_text'] += text_end bubbleplot(x=x, y=y, hue=hue, s='_perc', text='_text', data=_df, show_std=show_std, **kwargs)
[docs]@docstr @export def rmsdplot(x: str, data: pd.DataFrame, groups: Union[Sequence, str] = None, hue: str = None, hue_order: Union[Sequence, str] = None, cutoff: float = 0, ax: plt.Axes = None, color_as_balance: bool = False, balance_cutoff: float = None, rmsd_as_alpha: bool = False, sort_by_hue: bool = False, palette=None, barh_kws=None, **kwargs): """ creates a seaborn.barplot showing the rmsd calculating :func:`~hhpy.ds.df_rmsd` :param x: %(x)s :param data: %(data)s :param groups: the columns to calculate the rmsd for, defaults to all columns [optional] :param hue: %(hue)s :param hue_order: %(order)s :param cutoff: drop rmsd values smaller than cutoff [optional] :param ax: %(ax_in)s :param color_as_balance: Whether to color the bars based on how balanced (based on maxperc values) the levels are [optional] :param balance_cutoff: If specified: all bars with worse balance (based on maxperc values) than cutoff are shown in red [optional] :param rmsd_as_alpha: Whether to use set the alpha values of the columns based on the rmsd value [optional] :param sort_by_hue: Whether to sort the plot by hue value [optional] :param palette: %(palette)s :param barh_kws: other keyword arguments passed to seaborn.barplot [optional] :param kwargs: other keyword arguments passed to :func:`hhpy.ds.rf_rmsd` [optional] :return: %(ax_out)s **Examples** Check out the `example notebook <https://colab.research.google.com/drive/1wvkYK80if0okXJGf1j2Kl-SxXZdl-97k>`_ """ if palette is None: palette = rcParams['palette'] if barh_kws is None: barh_kws = {} _data = data.copy() del data if hue is not None and hue_order is not None: _data = _data.query('{} in @hue_order'.format(hue)) _df_rmsd = df_rmsd(x=x, df=_data, groups=groups, hue=hue, sort_by_hue=sort_by_hue, **kwargs) _df_rmsd = _df_rmsd[_df_rmsd['rmsd'] >= cutoff] if hue is not None: _df_rmsd_no_hue = df_rmsd(x=x, df=_data, groups=groups, include_rmsd=False, **kwargs) else: _df_rmsd_no_hue = pd.DataFrame() if isinstance(x, list): if hue is None: _df_rmsd['label'] = concat_cols(_df_rmsd, ['x', 'group'], sep=' X ') else: _df_rmsd['label'] = concat_cols(_df_rmsd, ['x', 'group', hue], sep=' X ') else: if hue is None: _df_rmsd['label'] = _df_rmsd['group'] else: _df_rmsd['label'] = concat_cols(_df_rmsd, ['group', hue], sep=' X ') _df_rmsd['rmsd_scaled'] = _df_rmsd['rmsd'] / _df_rmsd['rmsd'].max() # get colors _rgba_colors = np.zeros((len(_df_rmsd), 4)) _hues = [] if hue is not None: _hues = _get_ordered_levels(data=_df_rmsd, level=hue, order=hue_order, x=x) if isinstance(palette, Mapping): _df_rmsd['_color'] = _df_rmsd[hue].apply(lambda _: palette[_]) elif is_list_like(palette): _df_rmsd['_color'] = _df_rmsd[hue].apply(lambda _: palette[list(_hues).index(_)]) else: _df_rmsd['_color'] = palette _rgba_colors[:, 0] = _df_rmsd['_color'].apply(lambda _: Color(_).red) _rgba_colors[:, 1] = _df_rmsd['_color'].apply(lambda _: Color(_).green) _rgba_colors[:, 2] = _df_rmsd['_color'].apply(lambda _: Color(_).blue) elif color_as_balance: if balance_cutoff is None: _rgba_colors[:, 0] = _df_rmsd['maxperc'] # for red the first column needs to be one _rgba_colors[:, 2] = 1 - _df_rmsd['maxperc'] # for blue the third column needs to be one else: _rgba_colors[:, 0] = np.where(_df_rmsd['maxperc'] >= balance_cutoff, 1, 0) _rgba_colors[:, 2] = np.where(_df_rmsd['maxperc'] < balance_cutoff, 1, 0) else: _rgba_colors[:, 2] = 1 # for blue the third column needs to be one # the fourth column needs to be alphas if rmsd_as_alpha: _rgba_colors[:, 3] = _df_rmsd['rmsd_scaled'] else: _rgba_colors[:, 3] = 1 if ax is None: ax = plt.gca() # make positions from labels if hue is not None: _pos_factor = .8 else: _pos_factor = 1 _df_rmsd['pos'] = _df_rmsd.index * _pos_factor if (hue is not None) and (not sort_by_hue): # iterate over rows and add to pos if label changes for _row in range(1, len(_df_rmsd)): if _df_rmsd['group'].iloc[_row] != _df_rmsd['group'].iloc[_row - 1]: _df_rmsd['pos'][_row:] = _df_rmsd['pos'][_row:] + _pos_factor # make a df of the average positions for each group _df_ticks = _df_rmsd.groupby('group').agg({'pos': 'mean'}).reset_index() # 'maxperc':'max' _df_ticks = pd.merge(_df_ticks, _df_rmsd_no_hue[['group', 'maxperc']]) # get maxperc from global value else: _df_ticks = pd.DataFrame() ax.barh(_df_rmsd['pos'], _df_rmsd['rmsd'], color=_rgba_colors, **barh_kws) _y_colors = None if (hue is not None) and (not sort_by_hue): _y_pos = _df_ticks['pos'] _y_lab = _df_ticks['group'] # color if balance_cutoff is not None: _y_colors = np.where(_df_ticks['maxperc'] > balance_cutoff, sns.xkcd_rgb['red'], 'k') else: _y_pos = _df_rmsd['pos'] if not is_list_like(x): _y_lab = _df_rmsd['group'] elif not is_list_like(groups): _y_lab = _df_rmsd['x'] else: _y_lab = concat_cols(_df_rmsd, ['x', 'group'], sep=' X ') ax.set_yticks(_y_pos) ax.set_yticklabels(_y_lab) if _y_colors is not None: for _y_tick, _color in zip(ax.get_yticklabels(), _y_colors): _y_tick.set_color(_color) if hue is None: _offset = _pos_factor else: _offset = _pos_factor * len(_hues) # noinspection PyTypeChecker ax.set_ylim([_y_pos.min() - _offset, _y_pos.max() + _offset]) ax.invert_yaxis() # create legend for hues if hue is not None: _patches = [] for _hue, _color, _count in _df_rmsd[[hue, '_color', 'count']].drop_duplicates().values: _patches.append(patches.Patch(color=_color, label='{} (n={:,})'.format(_hue, _count))) ax.legend(handles=_patches) # check if standardized _x_label_suffix = '' if 'standardize' in kwargs.keys(): if kwargs['standardize']: _x_label_suffix += ' [std]' if not is_list_like(x): ax.set_title('Root Mean Square Difference for {}'.format(x)) ax.set_xlabel('RMSD: {}{}'.format(x, _x_label_suffix)) elif not is_list_like(groups): ax.set_title('Root Mean Square Difference for {}'.format(groups)) ax.set_xlabel('RMSD: {}{}'.format(groups, _x_label_suffix)) else: ax.set_title('Root Mean Square Difference') return ax
# plot agg def aggplot(x, data, group, hue=None, hue_order=None, width=16, height=9 / 2, p_1_0=True, palette=None, sort_by_hue=False, return_fig_ax=None, agg=None, p=False, legend_loc='upper right', aggkws=None, subplots_kws=None, subplots_adjust_kws=None, **kwargs): if return_fig_ax is None: return_fig_ax = rcParams['return_fig_ax'] if palette is None: palette = rcParams['palette'] if agg is None: agg = ['mean', 'median', 'std'] if aggkws is None: aggkws = {} if subplots_kws is None: subplots_kws = {} if subplots_adjust_kws is None: subplots_adjust_kws = {'top': .95, 'hspace': .25, 'wspace': .35} # avoid inplace operations data = pd.DataFrame(data).copy() _len = len(agg) + 1 + p _x = x _group = group # EITHER x OR group can be a list (hue cannot be a lists) if is_list_like(x) and is_list_like(group): warnings.warn('both x and group cannot be a list, setting group = {}'.format(group[0])) _x_is_list = True _group_is_list = False _group = group[0] _ncols = len(x) _nrows = _len elif isinstance(x, list): _x_is_list = True _group_is_list = False _group = group _ncols = len(x) _nrows = _len elif isinstance(group, list): _x_is_list = False _group_is_list = True _ncols = len(group) _nrows = _len else: _x_is_list = False _group_is_list = False _ncols = int(np.floor(_len / 2)) _nrows = int(np.ceil(_len / 2)) fig, ax = plt.subplots(figsize=(width * _ncols, height * _nrows), nrows=_nrows, ncols=_ncols, **subplots_kws) _it = -1 for _col in range(_ncols): if _x_is_list: _x = x[_col] if _group_is_list: _group = group[_col] _df_agg = df_agg(x=_x, group=_group, hue=hue, df=data, agg=agg, p=p, **aggkws) if hue is not None: if sort_by_hue: _sort_by = [hue, _group] else: _sort_by = [_group, hue] _df_agg = _df_agg.sort_values(by=_sort_by).reset_index(drop=True) _label = '_label' _df_agg[_label] = concat_cols(_df_agg, [_group, hue], sep='_').astype('category') _hues = _get_ordered_levels(data=data, level=hue, order=hue_order, x=x) if isinstance(palette, Mapping): _df_agg['_color'] = _df_agg[hue].apply(lambda _: palette[_]) elif is_list_like(palette): _df_agg['_color'] = _df_agg[hue].apply(lambda _: palette[list(_hues).index(_)]) else: _df_agg['_color'] = palette else: _label = _group for _row in range(_nrows): _it += 1 if _x_is_list or _group_is_list: _index = _row else: _index = _it _ax = get_subax(ax, _row, _col) if _index >= _len: _ax.set_axis_off() continue _agg = list(_df_agg)[1:][_index] # one color per graph (if no hue) if hue is None: _df_agg['_color'] = palette[_index] # handle hue grouping if hue is not None: _pos_factor = .8 else: _pos_factor = 1 _df_agg['pos'] = _df_agg.index if (hue is not None) and (not sort_by_hue): # iterate over rows and add to pos if label changes for _row_2 in range(1, len(_df_agg)): if _df_agg[_group].iloc[_row_2] != _df_agg[_group].iloc[_row_2 - 1]: _df_agg['pos'][_row_2:] = _df_agg['pos'][_row_2:] + _pos_factor # make a df of the average positions for each group _df_ticks = _df_agg.groupby(_group).agg({'pos': 'mean'}).reset_index() else: _df_ticks = pd.DataFrame() _ax.barh('pos', _agg, color='_color', label=_agg, data=_df_agg, **kwargs) if (hue is not None) and (not sort_by_hue): _ax.set_yticks(_df_ticks['pos']) _ax.set_yticklabels(_df_ticks[_group]) else: _ax.set_yticks(_df_agg['pos']) _ax.set_yticklabels(_df_agg[_group]) _ax.invert_yaxis() _ax.set_xlabel(_x + '_' + _agg) _ax.set_ylabel(_group) # create legend for hues if hue is not None: _patches = [] for _hue, _color in _df_agg[[hue, '_color']].drop_duplicates().values: _patches.append(patches.Patch(color=_color, label=_hue)) _ax.legend(handles=_patches) else: _ax.legend(loc=legend_loc) # range of p is between 0 and 1 if _agg == 'p' and p_1_0: # noinspection PyTypeChecker _ax.set_xlim([0, 1]) if _x_is_list: _x_title = ','.join(x) else: _x_title = _x if _group_is_list: _group_title = ','.join(group) else: _group_title = _group _title = _x_title + ' by ' + _group_title if hue is not None: _title = _title + ' per ' + hue plt.suptitle(_title, size=16) plt.subplots_adjust(**subplots_adjust_kws) if return_fig_ax: return fig, ax else: plt.show() def aggplot2d(x, y, data, aggfunc='mean', ax=None, x_int=None, time_int=None, color=None, as_abs=False): if color is None: color = rcParams['palette'][0] # time int should be something like '<M8[D]' # D can be any datetime unit from numpy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.datetime.html _y_agg = '{}_{}'.format(y, aggfunc) _y_std = '{}_std'.format(y) # preprocessing data = pd.DataFrame(data).copy() if as_abs: data[y] = np.abs(data[y]) if x_int is not None: data[x] = np.round(data[x] / x_int) * x_int if time_int is not None: data[x] = data[x].astype('<M8[{}]'.format(time_int)) # agg data = data.groupby([x]).agg({y: [aggfunc, 'std']}).set_axis([_y_agg, _y_std], axis=1, inplace=False).reset_index() if ax is None: ax = plt.gca() ax.plot(data[x], data[_y_agg], color=color, label=_y_agg) ax.fill_between(data[x], data[_y_agg] + data[_y_std], data[_y_agg] - data[_y_std], color='xkcd:cyan', label=_y_std) ax.set_xlabel(x) ax.set_ylabel(y) ax.legend() return ax
[docs]@export def insert_linebreak(s: str, pos: int = None, frac: float = None, max_breaks: int = None) -> str: """ used to insert linebreaks in strings, useful for formatting axes labels :param s: string to insert linebreaks into :param pos: inserts a linebreak every pos characters [optional] :param frac: inserts a linebreak after frac percent of characters [optional] :param max_breaks: maximum number of linebreaks to insert [optional] :return: string with the linebreaks inserted """ _s = s + '' if pos is not None: _pos = pos _frac = int(np.ceil(len(_s) / _pos)) elif frac is not None: _pos = int(np.ceil(len(_s) / frac)) _frac = frac else: _pos = None _frac = None _pos_i = 0 if max_breaks is not None: _max = np.min([max_breaks, _frac - 1]) else: _max = _frac - 1 for _it in range(_max): _pos_i += _pos if _it > 0: _pos_i += 1 # needed because of from 0 indexing _s = _s[:_pos_i] + '\n' + _s[_pos_i:] # remove trailing newlines if _s[-1:] == '\n': _s = _s[:-1] return _s
[docs]@docstr @export def ax_tick_linebreaks(ax: plt.Axes = None, x: bool = True, y: bool = True, **kwargs) -> None: """ uses insert_linebreaks to insert linebreaks into the axes ticklabels :param ax: %(ax_in)s :param x: whether to insert linebreaks into the x axis label [optional] :param y: whether to insert linebreaks into the y axis label [optional] :param kwargs: other keyword arguments passed to insert_linebreaks :return: None """ if ax is None: ax = plt.gca() if x: ax.set_xticklabels([insert_linebreak(_item.get_text(), **kwargs) for _item in ax.get_xticklabels()]) if y: ax.set_yticklabels([insert_linebreak(_item.get_text(), **kwargs) for _item in ax.get_yticklabels()])
[docs]@docstr @export def annotate_barplot(ax: plt.Axes = None, x: Sequence = None, y: Sequence = None, ci: bool = True, ci_newline: bool = True, adj_ylim: float = .05, nr_format: str = None, ha: str = 'center', va: str = 'center', offset: int = None, **kwargs) -> plt.Axes: """ automatically annotates a barplot with bar values and error bars (if present). Currently does not work with ticks! :param ax: %(ax_in)s :param x: %(x)s :param y: %(y)s :param ci: whether to annotate error bars [optional] :param ci_newline: whether to add a newline between values and error bar values [optional] :param adj_ylim: whether to automatically adjust the plot y limits to fit the annotations [optional] :param nr_format: %(number_format)s :param ha: horizontal alignment [optional] :param va: vertical alignment [optional] :param offset: offset between bar top and annotation center, defaults to rcParams[font.size] [optional] :param kwargs: other keyword arguments passed to pyplot.annotate :return: %(ax_out)s """ # -- init # - defaults if nr_format is None: nr_format = rcParams['float_format'] if offset is None: offset = plt.rcParams['font.size'] if ax is None: ax = plt.gca() # catch font warnings logging.getLogger().setLevel(logging.CRITICAL) _adj_plus = False _adj_minus = False if ci_newline: _ci_sep = '\n' _offset = offset + 5 else: _ci_sep = '' _offset = offset for _it, _patch in enumerate(ax.patches): try: if x is None: _x = _patch.get_x() + _patch.get_width() / 2. elif is_list_like(x): _x = x[_it] else: _x = x if y is None: _y = _patch.get_height() elif is_list_like(y): _y = y[_it] else: _y = y _val = _patch.get_height() if _val > 0: _adj_plus = True if _val < 0: _adj_minus = True if np.isnan(_val): continue _val_text = format(_val, nr_format) _annotate = r'${}$'.format(_val_text) # TODO: HANDLE CAPS if ci and ax.lines.__len__() > _it: _line = ax.lines[_it] _line_y = _line.get_xydata()[:, 1] _ci = (_line_y[1] - _line_y[0]) / 2 if not np.isnan(_ci): _ci_text = format(_ci, nr_format) _annotate = r'${}$'.format(_val_text) + _ci_sep + r'$\pm{}$'.format(_ci_text) ax.annotate(_annotate, (_x, _y), ha=ha, va=va, xytext=(0, np.sign(_val) * _offset), textcoords='offset points', **kwargs) except Exception as exc: print(exc) if adj_ylim: _ylim = list(ax.get_ylim()) _y_adj = (_ylim[1] - _ylim[0]) * adj_ylim if _adj_minus: _ylim[0] = _ylim[0] - _y_adj if _adj_plus: _ylim[1] = _ylim[1] + _y_adj # noinspection PyTypeChecker ax.set_ylim(_ylim) logging.getLogger().setLevel(logging.DEBUG) return ax
[docs]@docstr @export def animplot(data: pd.DataFrame = None, x: str = 'x', y: str = 'y', t: str = 't', lines: Mapping = None, max_interval: int = None, time_per_frame: int = 200, mode: str = None, title: bool = True, title_prefix: str = '', t_format: str = None, fig: plt.Figure = None, ax: plt.Axes = None, color: str = None, label: str = None, legend: bool = False, legend_out: bool = False, legend_kws: Mapping = None, xlim: tuple = None, ylim: tuple = None, ax_facecolor: Union[str, Mapping] = None, grid: bool = False, vline: Union[Sequence, float] = None, **kwargs) -> Union[HTML, FuncAnimation]: """ wrapper for FuncAnimation to be used with pandas DataFrames. Assumes that you have a DataFrame containing one data point for each x-y-t combination. If mode is set to jshtml the function is optimized for use with Jupyter Notebook and returns an Interactive JavaScript Widget. :param data: %(data)s :param x: %(x_novec)s :param y: %(y_novec)s :param t: %(t_novec)s :param lines: you can also pass lines that you want to animate. Details to follow [optional] :param max_interval: max interval at which to abort the animation [optional] :param time_per_frame: time per frame [optional] :param mode: one of the below [optional] * ``matplotlib``: Return the matplotlib FuncAnimation object * ``html``: Returns an HTML5 movie (You need to install ffmpeg for this to work) * ``jshtml``: Returns an interactive Javascript Widget :param title: whether to set the time as plot title [optional] :param title_prefix: title prefix to be put in front of the time if title is true [optional] :param t_format: format string used to format the time variable in the title [optional] :param fig: figure to plot on [optional] :param ax: axes to plot on [optional] :param color: %(color)s :param label: %(label)s :param legend: %(legend)s :param legend_out: %(legend_out)s :param legend_kws: %(legend_kws)s :param xlim: %(xlim)s :param ylim: %(ylim)s :param ax_facecolor: passed to ax.set_facecolor, can also be a conditional mapping to change the facecolor at specific timepoints t [optional] :param grid: %(grid)s :param vline: %(vline)s :param kwargs: other keyword arguments passed to pyplot.plot :return: see mode **Examples** Check out the `example notebook <https://drive.google.com/open?id=1hJRfZn3Zwnc1n4cK7h2-UPSEj4BmsxhY>`_ """ # example for lines (a list of dicts) # lines = [{'line':line,'data':data,'x':'x','y':'y','t':'t'}] # -- init # - defaults if mode is None: mode = rcParams['animplot.mode'] if fig is None: fig = plt.gcf() if ax is None: ax = plt.gca() if legend_kws is None: legend_kws = {} # - handle no inplace data = pd.DataFrame(data).copy() # - preprocessing # if t is the index: save to regular column if (t == 'index') and (t not in data.columns): data[t] = data.index _args = {'data': data, 'x': x, 'y': y, 't': t} _ax_list = ax_as_list(ax) # init lines if lines is None: _ax = _ax_list[0] lines = [] _len = 1 if is_list_like(x): _len = np.max([_len, len(x)]) if is_list_like(y): _len = np.max([_len, len(y)]) for _it in range(_len): if is_list_like(x): _x = x[_it] else: _x = x if is_list_like(y): _y = y[_it] else: _y = y if is_list_like(vline): _vline = vline[_it] else: _vline = vline if isinstance(color, Mapping): if _y in color.keys(): _color = color[_y] else: _color = None elif is_list_like(color): _color = color[_it] else: _color = color _kwargs = deepcopy(kwargs) _kwargs_keys = list(_kwargs.keys()) # defaults if len(list_intersection(['markerfacecolor', 'mfc'], _kwargs_keys)) == 0: _kwargs['markerfacecolor'] = _color if len(list_intersection(['markeredgecolor', 'mec'], _kwargs_keys)) == 0: _kwargs['markeredgecolor'] = _color if len(list_intersection(['markeredgewidth', 'mew'], _kwargs_keys)) == 0: _kwargs['markeredgewidth'] = 1 if label is None: _label = _y elif isinstance(label, Mapping): _label = label[_y] elif is_list_like(label): _label = label[_it] else: _label = label lines += [{ 'line': _ax.plot([], [], label=_label, color=_color, **_kwargs)[0], 'ax': _ax, 'data': data, 'x': _x, 'y': _y, 't': t, 'vline': _vline, 'title': title, 'title_prefix': title_prefix, }] _ts = pd.Series(data[t].unique()).sort_values() else: _ts = pd.Series() for _line in lines: _keys = list(_line.keys()) # default: label = y if 'label' not in _keys: if 'y' in _keys: _line['label'] = _line['y'] elif y is not None: _line['label'] = y # update keys _keys = list(_line.keys()) # get kws _line_kws = {} _line_kw_keys = [_ for _ in _keys if _ not in ['ax', 'line', 'ts', 'data', 'x', 'y', 't']] _kw_keys = [_ for _ in list(kwargs.keys()) if _ not in _line_kw_keys] for _key in _line_kw_keys: _line_kws[_key] = _line[_key] for _kw_key in _kw_keys: _line_kws[_kw_key] = kwargs[_kw_key] if 'ax' not in _keys: _line['ax'] = _ax_list[0] if 'line' not in _keys: _line['line'] = _line['ax'].plot([], [], **_line_kws)[0], if is_list_like(_line['line']): _line['line'] = _line['line'][0] for _arg in list(_args.keys()): if _arg not in _keys: _line[_arg] = _args[_arg] _line['ts'] = _line['data'][_line['t']].drop_duplicates().sort_values().reset_index(drop=True) _ts = _ts.append(_line['ts']).drop_duplicates().sort_values().reset_index(drop=True) # get max interval if max_interval is not None: if max_interval < _ts.shape[0]: _max_interval = max_interval else: _max_interval = _ts.shape[0] else: _max_interval = _ts.shape[0] # unchanging stuff goes here def init(): for __ax in _ax_list: _xylim_set = False _x_min = None _x_max = None _y_min = None _y_max = None _legend = legend for __line in lines: # -- xy lims -- if __ax == __line['ax']: if not _xylim_set: # init with limits of first line _x_min = __line['data'][__line['x']].min() _x_max = __line['data'][__line['x']].max() _y_min = __line['data'][__line['y']].min() _y_max = __line['data'][__line['y']].max() _xylim_set = True else: # compare with x y lims of other lines if __line['data'][__line['x']].min() < _x_min: _x_min = __line['data'][__line['x']].min() if __line['data'][__line['y']].min() < _y_min: _y_min = __line['data'][__line['y']].min() if __line['data'][__line['x']].max() > _x_max: _x_max = __line['data'][__line['x']].max() if __line['data'][__line['y']].max() > _y_max: _y_max = __line['data'][__line['y']].max() # -- legend -- if 'legend' in list(__line.keys()): _legend = __line['legend'] if _legend: if legend_out: legend_outside(__ax, width=.995) else: __ax.legend(**legend_kws) # -- vlines -- if 'vline' in __line.keys(): _vline_i = __line['vline'] if _vline_i is not None: if not is_list_like(_vline_i): _vline_i = [_vline_i] for _vline_j in _vline_i: __ax.axvline(_vline_j, color='k', linestyle=':') # -- lims -- if xlim is not None: if xlim: __ax.set_xlim(xlim) else: __ax.set_xlim([_x_min, _x_max]) if ylim is not None: if ylim: __ax.set_ylim(ylim) else: __ax.set_ylim([_y_min, _y_max]) # -- grid -- if grid: __ax.grid() # -- ax facecolor -- if isinstance(ax_facecolor, str): __ax.set_facecolor(ax_facecolor) return () def animate(_i): _t = _ts[_i] for _line_i in lines: _line_keys_i = list(_line_i.keys()) _data = _line_i['data'].copy() _data = _data[_data[_line_i['t']] == _t] _line_i['line'].set_data(_data[_line_i['x']], _data[_line_i['y']]) if 'ax' in _line_keys_i: _ax_i = _line_i['ax'] else: _ax_i = plt.gca() # -- title -- _title = title _title_prefix = title_prefix if 'title' in list(_line_i.keys()): _title = _line_i['title'] if 'title_prefix' in list(_line_i.keys()): _title_prefix = _line_i['title_prefix'] if t_format is not None: _t_str = pd.to_datetime(_t).strftime(t_format) else: _t_str = _t if _title: _ax_i.set_title('{}{}'.format(_title_prefix, _t_str)) # -- facecolor -- if isinstance(ax_facecolor, Mapping): for _key_i in list(ax_facecolor.keys()): _ax_facecolor = ax_facecolor[_key_i] if (_key_i is None) or (_key_i > _t): _ax_i.set_facecolor(_ax_facecolor) return () # - get correct ax for each line for _line in lines: if 'ax' in list(_line.keys()): _ax = _line['ax'] else: _ax = plt.gca() # - create main FuncAnimation object _anim = FuncAnimation(fig, animate, init_func=init, frames=_max_interval, interval=time_per_frame, blit=True) # - close plots plt.close('all') # -- return # -handle return mode if mode == 'html': return HTML(_anim.to_html5_video()) elif mode == 'jshtml': return HTML(_anim.to_jshtml()) else: return _anim
[docs]@docstr @export def legend_outside(ax: plt.Axes = None, width: float = .85, loc: str = 'right', legend_space: float = None, offset_x: float = 0, offset_y: float = 0, loc_warn: bool = True, **kwargs): """ draws a legend outside of the subplot :param ax: %(ax_in)s :param width: how far to shrink down the subplot if loc=='right' :param loc: one of ['right','bottom'], where to put the legend :param legend_space: how far below the subplot to put the legend if loc=='bottom' :param offset_x: x offset for the legend :param offset_y: y offset for the legend :param loc_warn: Whether to trigger a warning if legend loc is not recognized :param kwargs: other keyword arguments passed to pyplot.legend :return: None """ # -- init # - defaults if legend_space is None: legend_space = rcParams['legend_outside.legend_space'] if ax is None: ax = plt.gca() # - check if loc is legend_outside specific, if not treat as inside loc and call regular ax.legend if loc not in ['bottom', 'right']: if loc_warn: warnings.warn('legend_outside: legend loc not recognized, defaulting to plt.legend') ax.legend(loc=loc, **kwargs) return None # -- main # - get loc and bbox _loc = {'bottom': 'upper center', 'right': 'center left'}[loc] _bbox_to_anchor = {'bottom': (0.5 + offset_x, - .15 + offset_y), 'right': (1, 0.5)}[loc] # - loop axes for _ax in ax_as_list(ax): # -- shrink box _box = _ax.get_position() _pos = { 'bottom': [_box.x0, _box.y0, _box.width, _box.height * (1 - legend_space)], # 'bottom':[_box.x0, _box.y0 + _box.height * legend_space,_box.width, _box.height * (1-legend_space)], 'right': [_box.x0, _box.y0, _box.width * width, _box.height] }[loc] _ax.set_position(_pos) # -- legend logging.getLogger().setLevel(logging.CRITICAL) _, _labels = _ax.get_legend_handles_labels() if len(_labels) > 0: _ax.legend(loc=_loc, bbox_to_anchor=_bbox_to_anchor, **kwargs) logging.getLogger().setLevel(logging.DEBUG)
[docs]@docstr @export def set_ax_sym(ax: plt.Axes, x: bool = True, y: bool = True): """ automatically sets the select axes to be symmetrical :param ax: %(ax_in)s :param x: whether to set x axis to be symmetrical :param y: whether to set y axis to be symmetrical :return: None """ if x: _x_max = np.max(np.abs(np.array(ax.get_xlim()))) # noinspection PyTypeChecker ax.set_xlim((-_x_max, _x_max)) if y: _y_max = np.max(np.abs(np.array(ax.get_ylim()))) # noinspection PyTypeChecker ax.set_ylim((-_y_max, _y_max))
[docs]@docstr @export def custom_legend(colors: Union[list, str], labels: Union[list, str], do_show=True) -> Union[list, None]: """ uses patches to create a custom legend with the specified colors :param colors: list of matplotlib colors to use for the legend :param labels: list of labels to use for the legend :param do_show: whether to show the created legend :return: if do_show: None, else handles """ _handles = [] for _color, _label in zip(assert_list(colors), assert_list(labels)): _handles.append(patches.Patch(color=_color, label=_label)) if do_show: plt.legend(handles=_handles) else: return _handles
def lcurveplot(train, test, labels=None, legend='upper right', ax=None): if labels is None: if 'name' in dir(train): _label_train = train.name else: _label_train = 'train' if 'name' in dir(test): _label_test = test.name else: _label_test = 'test' elif isinstance(labels, Mapping): _label_train = labels['train'] _label_test = labels['test'] elif is_list_like(labels): _label_train = labels[0] _label_test = labels[1] else: _label_train = labels _label_test = labels if ax is None: ax = plt.gca() ax.plot(train, color='xkcd:blue', label=_label_train) ax.plot(test, color='xkcd:red', label=_label_test) ax.plot(lfit(test), color='xkcd:red', ls='--', alpha=.75, label=_label_test + '_lfit') ax.axhline(np.min(test), color='xkcd:red', ls=':', alpha=.5) ax.axvline(np.argmin(test), color='xkcd:red', ls=':', alpha=.5) if legend: if isinstance(legend, str): _loc = legend else: _loc = None ax.legend(loc=_loc) return ax def dic_to_lcurveplot(dic, width=16, height=9 / 2, **kwargs): if 'curves' not in dic.keys(): warnings.warn('key curves not found, stopping') return None _targets = list(dic['curves'].keys()) _nrows = len(_targets) _, ax = plt.subplots(nrows=_nrows, figsize=(width, height * _nrows)) _ax_list = ax_as_list(ax) for _it, _target in enumerate(_targets): _ax = _ax_list[_it] lcurveplot(dic['curves'][_target]['train'], dic['curves'][_target]['test'], labels=['{}_train'.format(_target), '{}_test'.format(_target)], ax=_ax, **kwargs) plt.show()
[docs]@docstr @export def stemplot(x, y, data=None, ax=None, color=rcParams['palette'][0], baseline=0, kwline=None, **kwargs): """ modeled after pyplot.stemplot but more customizeable :param x: %(x)s :param y: %(y)s :param data: %(data)s :param ax: %(ax_in)s :param color: %(color)s :param baseline: where to draw the baseline for the stemplot :param kwline: other keyword arguments passed to pyplot.plot :param kwargs: other keyword arguments passed to pyplot.scatter :return: %(ax_out)s """ if kwline is None: kwline = {} if data is None: if 'name' in dir(x): _x = x.name else: _x = 'x' if 'name' in dir(y): _y = y.name else: _y = 'x' _data = pd.DataFrame({_x: x, _y: y}) else: _x = x _y = y _data = data.copy() if ax is None: ax = plt.gca() # baseline ax.axhline(baseline, color='k', ls='--', alpha=.5) # iterate over data so you can draw the lines for _it, _row in _data.iterrows(): ax.plot([_row[_x], _row[_x]], [baseline, _row[_y]], color=color, label='__nolegend__', **kwline) # scatterplot for markers ax.scatter(x=_x, y=_y, data=_data, facecolor=color, **kwargs) return ax
def from_to_plot(data: pd.DataFrame, x_from='x_from', x_to='x_to', y_from=0, y_to=1, palette=None, label=None, legend=True, legend_loc=None, ax=None, **kwargs): # defaults if ax is None: ax = plt.gca() if palette is None: palette = rcParams['palette'] _labels = [] for _, _row in data.itertuples(): _label = '__nolabel__' _name = None if label is not None: _name = _row[label] if _name not in _labels: _label = _name + '' _labels.append(_label) if isinstance(palette, Mapping): _color = palette[_name] elif is_list_like(palette): _color = palette[_labels.index(_name) % len(palette)] else: _color = palette ax.fill_betweenx([y_from, y_to], _row[x_from], _row[x_to], label=_label, color=_color, **kwargs) if legend and label: ax.legend(loc=legend_loc) return ax def vlineplot(data, palette=None, label=None, legend=True, legend_loc=None, ax=None, **kwargs): # defaults if ax is None: ax = plt.gca() if palette is None: palette = rcParams['palette'] _labels = [] _name = None for _, _row in data.iterrows(): _label = '__nolabel__' if label is not None: _name = _row[label] if _name not in _labels: _label = _name + '' _labels.append(_label) if isinstance(palette, Mapping): _color = palette[_name] elif is_list_like(palette): _color = palette[_labels.index(_name) % len(palette)] else: _color = palette ax.axvline(_row['x'], label=_label, color=_color, **kwargs) if legend and label: ax.legend(loc=legend_loc) return ax def show_ax_ticklabels(ax, x=None, y=None): _ax_list = ax_as_list(ax) for _ax in _ax_list: if x is not None: plt.setp(_ax.get_xticklabels(), visible=x) if y is not None: plt.setp(_ax.get_yticklabels(), visible=y)
[docs]@docstr @export def get_twin(ax: plt.Axes) -> Union[plt.Axes, None]: """ get the twin axis from an Axes object :param ax: %(ax_in)s :return: the twin axis if it exists, else None """ for _other_ax in ax.figure.axes: if _other_ax is ax: continue if _other_ax.bbox.bounds == ax.bbox.bounds: return _other_ax return None
[docs]@docstr @export def get_axlim(ax: plt.Axes, xy: Union[str, None] = None) -> Union[tuple, Mapping]: """ Wrapper function to get x limits, y limits or both with one function call :param ax: %(ax_in)s :param xy: one of ['x', 'y', 'xy', None] :return: if xy is 'xy' or None returns a dictionary else returns x or y lims as tuple """ if xy == 'x': return ax.get_xlim() elif xy == 'y': return ax.get_ylim() else: return {'x': ax.get_xlim(), 'y': ax.get_ylim()}
[docs]@docstr @export def set_axlim(ax: plt.Axes, lim: Union[Sequence, Mapping], xy: Union[str, None] = None): """ Wrapper function to set both x and y limits with one call :param ax: %(ax_in)s :param lim: axes limits as tuple or Mapping :param xy: one of ['x', 'y', 'xy', None] :return: None """ if xy == 'x': # noinspection PyTypeChecker ax.set_xlim(lim) elif xy == 'y': # noinspection PyTypeChecker ax.set_ylim(lim) else: if isinstance(lim, Mapping): ax.set_xlim(lim['x']) ax.set_xlim(lim['y']) else: raise ValueError('Specify xy parameter or pass a dictionary')
[docs]@docstr @export def share_xy(ax: plt.Axes, x: bool = True, y: bool = True, mode: str = 'all', adj_twin_ax: bool = True): """ set the subplots on the Axes to share x and/or y limits WITHOUT sharing x and y legends. If you want that please use pyplot.subplots(share_x=True,share_y=True) when creating the plots. :param ax: %(ax_in)s :param x: whether to share x limits [optional] :param y: whether to share y limits [optional] :param mode: one of ['all', 'row', 'col'], if all shares across all subplots, else just across rows / columns :param adj_twin_ax: whether to also adjust twin axes :return: None """ _xys = [] if x: _xys.append('x') if y: _xys.append('y') if isinstance(ax, np.ndarray): _dims = len(ax.shape) else: _dims = 0 # slice for mode row / col (only applicable if shape==2) _ax_lists = [] if (_dims <= 1) or (mode == 'all'): _ax_lists += [ax_as_list(ax)] elif mode == 'row': for _row in range(ax.shape[0]): _ax_lists += [ax_as_list(ax[_row, :])] elif mode == 'col': for _col in range(ax.shape[1]): _ax_lists += [ax_as_list(ax[:, _col])] # we can have different subsets (by row or col) that share x / y min for _ax_list in _ax_lists: # init as None _xy_min = {'x': None, 'y': None} _xy_max = {'x': None, 'y': None} # get min max for _ax in _ax_list: _lims = get_axlim(_ax) for _xy in _xys: _xy_min_i = _lims[_xy][0] _xy_max_i = _lims[_xy][1] if _xy_min[_xy] is None: _xy_min[_xy] = _xy_min_i elif _xy_min[_xy] > _xy_min_i: _xy_min[_xy] = _xy_min_i if _xy_max[_xy] is None: _xy_max[_xy] = _xy_max_i elif _xy_max[_xy] < _xy_max_i: _xy_max[_xy] = _xy_max_i # set min max for _ax in _ax_list: if adj_twin_ax: _ax2 = get_twin(_ax) else: _ax2 = False # collect xy funcs for _xy in _xys: # save old lim _old_lim = list(get_axlim(_ax, xy=_xy)) # set new lim _new_lim = [_xy_min[_xy], _xy_max[_xy]] set_axlim(_ax, lim=_new_lim, xy=_xy) # adjust twin axis if _ax2: _old_lim_2 = list(get_axlim(_ax2, xy=_xy)) _new_lim_2 = [0 if _old == 0 else _new / _old * _old2 for _new, _old, _old2 in zip(_new_lim, _old_lim, _old_lim_2)] set_axlim(_ax2, lim=_new_lim_2, xy=_xy)
[docs]@docstr @export def share_legend(ax: plt.Axes, keep_i: int = None): """ removes all legends except for i from an Axes object :param ax: %(ax_in)s :param keep_i: index of the plot whose legend you want to keep :return: None """ _ax_list = ax_as_list(ax) if keep_i is None: keep_i = len(_ax_list) // 2 for _it, _ax in enumerate(_ax_list): _it += 1 _legend = _ax.get_legend() if _it != keep_i and (_legend is not None): _legend.remove()
def replace_xticklabels(ax, mapping): _new_labels = [] for _it, _label in enumerate(list(ax.get_xticklabels())): _text = _label.get_text() if isinstance(mapping, Mapping): if _text in mapping.keys(): _new_label = mapping[_text] else: _new_label = _text else: _new_label = mapping[_it] _new_labels.append(_new_label) ax.set_xticklabels(_new_labels) def replace_yticklabels(ax, mapping): _new_labels = [] for _it, _label in enumerate(list(ax.get_yticklabels())): _text = _label.get_text() if isinstance(mapping, Mapping): if _text in mapping.keys(): _new_label = mapping[_text] else: _new_label = _text else: _new_label = mapping[_it] _new_labels.append(_new_label) ax.set_yticklabels(_new_labels) def kdeplot(x, data=None, *args, hue=None, hue_order=None, bins=40, adj_x_range=False, baseline=0, highlight_peaks=True, show_kde=True, hist=True, show_area=False, area_center='mean', ha='center', va='center', legend_loc='upper right', palette=None, text_offset=15, nr_format=',.2f', kwline=None, perc=False, facecolor=None, sigma_color='xkcd:blue', sigma_2_color='xkcd:cyan', kde_color='black', edgecolor='black', alpha=.5, ax=None, ax2=None, kwhist=None, **kwargs): # -- init if palette is None: palette = rcParams['palette'] if kwline is None: kwline = {} if kwhist is None: kwhist = {} if data is not None: _df = data.copy() del data _x_name = x else: if 'name' in dir(x): _x_name = x.name else: _x_name = 'x' _df = pd.DataFrame({_x_name: x}) _df = _df.dropna(subset=[_x_name]) if hue is None: hue = '_dummy' _df[hue] = 1 if hue_order is None: hue_order = sorted(_df[hue].unique()) _x = _df[_x_name] if facecolor is None: if show_area: facecolor = 'None' else: facecolor = 'xkcd:cyan' if show_kde and show_area: _label_hist = '__nolabel__' else: _label_hist = _x_name # default if adj_x_range and isinstance(adj_x_range, bool): adj_x_range = 2 # -- get kde _it = -1 _twinx = False for _hue in hue_order: _it += 1 _df_hue = _df.query('{}==@_hue'.format(hue)) _df_kde, _df_kde_ex = kde(x=x, df=_df_hue, *args, **kwargs) if isinstance(palette, Mapping): _color = palette[_hue] elif is_list_like(palette): _color = palette[_it % len(palette)] else: _color = palette if hue == '_dummy': _kde_color = kde_color _edgecolor = edgecolor _facecolor = facecolor else: _kde_color = _color _edgecolor = _color _facecolor = 'None' _df_kde['value'] = _df_kde['value'] / _df_kde['value'].max() _df_kde_ex['value'] = _df_kde_ex['value'] / _df_kde['value'].max() if adj_x_range: _x_min = _df_kde_ex['range_min'].min() _x_max = _df_kde_ex['range_max'].max() _x_step = (_x_max - _x_min) / bins _x_range_min = _x_min - _x_step * adj_x_range * bins _x_range_max = _x_max + _x_step * adj_x_range * bins _df_hue = _df_hue.query('{}>=@_x_range_min & {}<=@_x_range_max'.format(_x_name, _x_name)) _df_kde = _df_kde.query('{}>=@_x_range_min & {}<=@_x_range_max'.format(_x_name, _x_name)) # -- plot if ax is None: ax = plt.gca() # hist if hist: ax.hist(_df_hue[_x_name], bins, density=perc, facecolor=_facecolor, edgecolor=_edgecolor, label=_label_hist, **kwhist) _twinx = True else: _twinx = False if _twinx and (ax2 is None): ax2 = ax.twinx() else: ax2 = ax _kde_label = '{} ; '.format(_x_name) + r'${:,.2f}\pm{:,.2f}$'.format(_df[_x_name].mean(), _df[_x_name].std()) # kde ax2.plot(_df_kde[_x_name], _df_kde['value'], ls='--', label=_kde_label, color=_kde_color, **kwargs) _ylim = list(ax2.get_ylim()) _ylim[0] = 0 _ylim[1] = _ylim[1] * (100 + text_offset) / 100. ax2.set_ylim(_ylim) # area if show_area: # get max if area_center == 'max': _area_center = _df_kde[_df_kde['value'] == _df_kde['value'].max()].index[0] else: if area_center == 'mean': _ref = _df_hue[_x_name].mean() else: _ref = area_center _df_area = _df_kde.copy() _df_area['diff'] = (_df_area[_x_name] - area_center).abs() _df_area = _df_area.sort_values(by='diff', ascending=True) _area_center = _df_area.index[0] _sigma = None _2_sigma = None for _it in range(1, _df_kde.shape[0]): _perc_data = \ _df_kde[np.max([0, _area_center - _it]):np.min([_df_kde.shape[0], _area_center + _it + 1])][ 'value'].sum() / _df_kde['value'].sum() if (_perc_data >= .6826) and (_sigma is None): _sigma = _it + 0 if (_perc_data >= .9544) and (_2_sigma is None): _2_sigma = _it + 0 break if _it == _df_kde.shape[0] - 1: _2_sigma = _it + 0 _df_sigma = _df_kde.loc[ np.max([0, _area_center - _sigma]):np.min([_df_kde.shape[0], _area_center + _sigma])] _df_2_sigma_left = _df_kde.loc[ np.max([0, _area_center - _2_sigma]):np.min([_df_kde.shape[0], _area_center - _sigma])] _df_2_sigma_right = _df_kde.loc[ np.max([0, _area_center + _sigma]):np.min([_df_kde.shape[0], _area_center + _2_sigma])] _2_sigma_min = _df_2_sigma_left[_x_name].min() _2_sigma_max = _df_2_sigma_right[_x_name].max() if np.isnan(_2_sigma_min): _2_sigma_min = _df[_x_name].min() if np.isnan(_2_sigma_max): _2_sigma_max = _df[_x_name].max() _sigma_range = ': {:,.2f} to {:,.2f}'.format(_df_sigma[_x_name].min(), _df_sigma[_x_name].max()) _2_sigma_range = ': {:,.2f} to {:,.2f}'.format(_2_sigma_min, _2_sigma_max) ax2.fill_between(_x_name, 'value', data=_df_sigma, color=sigma_color, label=r'$1\sigma(68\%)$' + _sigma_range, alpha=alpha) ax2.fill_between(_x_name, 'value', data=_df_2_sigma_left, color=sigma_2_color, label=r'$2\sigma(95\%)$' + _2_sigma_range, alpha=alpha) ax2.fill_between(_x_name, 'value', data=_df_2_sigma_right, color=sigma_2_color, label='__nolegend__', alpha=alpha) ax2.legend(loc=legend_loc) # iterate over data so you can draw the lines if highlight_peaks: for _it, _row in _df_kde_ex.iterrows(): _mu = _row[_x_name] _value_std = np.min([_row['value_min'], _row['value_max']]) # stem (max) ax2.plot([_mu, _mu], [baseline, _row['value']], color=kde_color, label='__nolegend__', ls=':', **kwline) # std if highlight_peaks != 'max': ax2.plot([_row['range_min'], _row['range_max']], [_value_std, _value_std], color=kde_color, label='__nolegend__', ls=':', **kwline) # scatterplot for markers ax2.scatter(x=_mu, y=_row['value'], facecolor=kde_color, **kwargs) _mean_str = format(_mu, nr_format) _std_str = format(_row['range'] / 2., nr_format) _annotate = r'${}$'.format(_mean_str) if highlight_peaks != 'max': _annotate += '\n' + r'$\pm{}$'.format(_std_str) ax2.annotate(_annotate, (_mu, _row['value']), ha=ha, va=va, xytext=(0, text_offset), textcoords='offset points') if _twinx: ax2.legend(loc=legend_loc) ax2.set_axis_off() else: ax.legend(loc=legend_loc) return ax def draw_ellipse(ax, *args, **kwargs): _e = patches.Ellipse(*args, **kwargs) ax.add_artist(_e)
[docs]@docstr @export def barplot_err(x: str, y: str, xerr: str = None, yerr: str = None, data: pd.DataFrame = None, **kwargs) -> plt.Axes: """ extension on `seaborn barplot <https://seaborn.pydata.org/generated/seaborn.barplot.html>`_ that allows for plotting errorbars with preprocessed data. The idea is based on this `StackOverflow question <https://datascience.stackexchange.com/questions/31736/unable-to-generate-error-bars-with-seaborn/64128>`_ :param x: %(x_novec)s :param y: %(y_novec)s :param xerr: variable to use as x error bars [optional] :param yerr: variable to use as y error bars [optional] :param data: %(data_novec)s :param kwargs: other keyword arguments passed to `seaborn barplot <https://seaborn.pydata.org/generated/seaborn.barplot.html>`_ :return: %(ax_out)s """ _data = [] for _it in data.index: _data_i = pd.concat([data.loc[_it:_it]] * 3, ignore_index=True, sort=False) _row = data.loc[_it] if xerr is not None: _data_i[x] = [_row[x] - _row[xerr], _row[x], _row[x] + _row[xerr]] if yerr is not None: _data_i[y] = [_row[y] - _row[yerr], _row[y], _row[y] + _row[yerr]] _data.append(_data_i) _data = pd.concat(_data, ignore_index=True, sort=False) _ax = sns.barplot(x=x, y=y, data=_data, ci='sd', **kwargs) return _ax
def q_barplot(pd_series, ax=None, sort=False, percentage=False, **kwargs): _name = pd_series.name if ax is None: ax = plt.gca() _df_plot = pd_series.value_counts().reset_index() if sort: _df_plot = _df_plot.sort_values(['index']) if percentage: _y_name = _name + '_perc' _df_plot[_y_name] = _df_plot[_name] / _df_plot[_name].sum() * 100 _df_plot[_y_name] = _df_plot[_y_name].round(2) else: _y_name = _name sns.barplot(data=_df_plot, x='index', y=_y_name, ax=ax, **kwargs) return ax def histplot(x=None, data=None, hue=None, hue_order=None, ax=None, bins=30, use_q_xlim=False, legend_kws=None, **kwargs): # long or short format if legend_kws is None: legend_kws = {} if data is not None: # avoid inplace operations _df_plot = data.copy() del data _x = x else: # create dummy df _df_plot = pd.DataFrame.from_dict({'x': x}) _x = 'x' _xs = _df_plot[_x] # if applicable: filter data if use_q_xlim: _x_lim = q_plim(_xs) _df_plot = _df_plot[(_df_plot[_x] >= _x_lim[0]) & (_df_plot[_x] <= _x_lim[1])] _xs = _df_plot[_x] # create bins if not isinstance(bins, list): bins = np.linspace(_xs.min(), _xs.max(), bins) # if an axis has not been passed initialize one if ax is None: ax = plt.gca() # if a hue has been passed loop them if hue is not None: # if no hue order has been passed use default sorting if hue_order is None: hue_order = sorted(_df_plot[hue].unique()) for _hue in hue_order: _xs = _df_plot[_df_plot[hue] == _hue][_x] ax.hist(_xs, label=_hue, alpha=.5, bins=bins, **kwargs) ax.legend(**legend_kws) else: ax.hist(_xs, bins=bins, **kwargs) return ax
[docs]@docstr @export def countplot(x: Union[Sequence, str] = None, data: pd.DataFrame = None, hue: str = None, ax: plt.Axes = None, order: Union[Sequence, str] = None, hue_order: Union[Sequence, str] = None, normalize_x: bool = False, normalize_hue: bool = False, palette: Union[Mapping, Sequence, str] = None, x_tick_rotation: int = None, count_twinx: bool = False, hide_legend: bool = False, annotate: bool = True, annotate_format: str = rcParams['int_format'], legend_loc: str = 'upper right', barplot_kws: Mapping = None, count_twinx_kws: Mapping = None, **kwargs): """ Based on seaborn barplot but with a few more options, uses :func:`~hhpy.ds.df_count` :param x: %(x)s :param data: %(data)s :param hue: %(hue)s :param ax: %(ax_in)s :param order: %(order)s :param hue_order: %(order)s :param normalize_x: Whether to normalize x, causes the sum of each x group to be 100 percent [optional] :param normalize_hue: Whether to normalize hue, causes the sum of each hue group to be 100 percent [optional] :param palette: %(palette)s :param x_tick_rotation: %(x_tick_rotation)s :param count_twinx: Whether to plot the count values on the second axis (if using normalize) [optional] :param hide_legend: Whether to hide the legend [optional] :param annotate: Whether to use annotate_barplot [optional] :param annotate_format: %(number_format)s :param legend_loc: %(legend_loc)s :param barplot_kws: Additional keyword arguments passed to seaborn.barplot [optional] :param count_twinx_kws: Additional keyword arguments passed to pyplot.plot [optional] :param kwargs: Additional keyword arguments passed to :func:`~hhpy.ds.df_count` [optional] :return: %(ax_out)s """ # -- init # defaults if barplot_kws is None: barplot_kws = {} if count_twinx_kws is None: count_twinx_kws = {} # long or short format if data is not None: # avoid inplace operations data = data.copy() # if x is not specified count each row if x is None: x = '_dummy' data = data.assign(_dummy=1) else: if isinstance(x, pd.DataFrame): # case: only a DataFrame is passed as first argument (count rows) data = x.copy().assign(_dummy=1) else: # assume passed object is a Sequence and create dummy df data = pd.DataFrame({'_dummy': x}) x = '_dummy' _count_x = 'count_{}'.format(x) _count_hue = 'count_{}'.format(hue) # if an axis has not been passed initialize one if ax is None: ax = plt.gca() if normalize_x: _y = 'perc_{}'.format(x) elif normalize_hue: _y = 'perc_{}'.format(hue) else: _y = 'count' _df_count = df_count(x=x, df=data, hue=hue, **kwargs) if order is None or order == 'count': _order = _df_count[[x, _count_x]].drop_duplicates().sort_values(by=[_count_x], ascending=False)[x].tolist() elif order == 'sorted': _order = _df_count[x].drop_duplicates().sort_values().tolist() else: _order = order if hue is not None: _hues = _get_ordered_levels(data=data, level=hue, order=hue_order, x=x) if palette is None: palette = rcParams['palette'] * 5 sns.barplot(data=_df_count, x=x, y=_y, hue=hue, order=_order, hue_order=hue_order, palette=palette, ax=ax, **barplot_kws) ax.set_xlabel('') # cleanup for x=None if x is None: ax.get_xaxis().set_visible(False) if normalize_x: ax.set_ylabel('perc') if hue is None and normalize_hue: ax.set_ylabel('perc') if annotate: # add annotation annotate_barplot(ax, nr_format=annotate_format) # enlarge ylims _ylim = list(ax.get_ylim()) _ylim[1] = _ylim[1] * 1.1 # noinspection PyTypeChecker ax.set_ylim(_ylim) # legend if hide_legend: ax.get_legend().remove() elif hue is not None: legend_outside(ax, loc=legend_loc, loc_warn=False) # tick rotation if x_tick_rotation is not None: ax.xaxis.set_tick_params(rotation=x_tick_rotation) # total count on secaxis if count_twinx: _ax = ax.twinx() _count_twinx_kws_keys = list(count_twinx_kws.keys()) if 'marker' not in _count_twinx_kws_keys: count_twinx_kws['marker'] = '_' if 'color' not in _count_twinx_kws_keys: count_twinx_kws['color'] = 'k' if 'alpha' not in _count_twinx_kws_keys: count_twinx_kws['alpha'] = .5 _ax.scatter(x, _count_x, data=_df_count[[x, _count_x]].drop_duplicates(), **count_twinx_kws) _ax.set_ylabel('count') return ax
[docs]@docstr @export def quantile_plot(x: Union[Sequence, str], data: pd.DataFrame = None, qs: Union[Sequence, float] = None, x2: str = None, hue: str = None, hue_order: Union[Sequence, str] = None, to_abs: bool = False, ax: plt.Axes = None, **kwargs) -> plt.Axes: """ plots the specified quantiles of a Series using seaborn.barplot :param x: %(x)s :param data: %(data)s :param qs: Quantile levels [optional] :param x2: if specified: subtracts x2 from x before calculating quantiles [optional] :param hue: %(hue)s :param hue_order: %(order)s :param to_abs: %(to_abs)s :param ax: %(ax_in)s :param kwargs: other keyword arguments passed to seaborn.barplot :return: %(ax_out)s """ # long or short format if qs is None: qs = [.1, .25, .5, .75, .9] if data is not None: # avoid inplace operations _df = data.copy() if x2 is None: _x = x else: _x = '{} - {}'.format(x, x2) _df[_x] = _df[x] - _df[x2] else: # create dummy df if x2 is None: _df = pd.DataFrame({'x': x}) _x = 'x' else: _df = pd.DataFrame({'x': x, 'x2': x2}).eval('x_delta=x2-x') _x = 'x_delta' if ax is None: ax = plt.gca() _label = _x if to_abs: _df[_x] = _df[_x].abs() _label = '|{}|'.format(_x) if hue is None: _df_q = _df[_x].quantile(qs).reset_index() else: _hues = _get_ordered_levels(data=_df, level=hue, order=hue_order, x=_x) _df_q = [] for _hue in _hues: _df_i = _df[_df[hue] == _hue][_x].quantile(qs).reset_index() _df_i[hue] = _hue _df_q.append(_df_i) _df_q = pd.concat(_df_q, ignore_index=True, sort=False) sns.barplot(x='index', y=_x, data=_df_q, hue=hue, ax=ax, **kwargs) ax.set_xticklabels(['q{}'.format(int(_ * 100)) for _ in qs]) ax.set_xlabel('') ax.set_ylabel(_label) return ax
[docs]@docstr @export def plotly_aggplot(data: pd.DataFrame, x: Scalar, y: Scalar, hue: Scalar = None, groupby: SequenceOrScalar = None, sep: str = ';', agg: str = 'sum', hue_order: Union[list, str] = None, x_min: Scalar = None, x_max: Scalar = None, y_min: Scalar = None, y_max: Scalar = None, mode: str = 'lines+markers', title: str = None, xaxis_title: str = None, yaxis_title: str = None, label_maxchar: int = 15, direction: str = 'up', showactive: bool = True, dropdown_x: float = 0, dropdown_y: float = -.1, fig: go.Figure = None, do_print: bool = True, kws_dropdown: Mapping = None, kws_fig: Mapping = None, **kwargs) -> go.Figure: """ create a (grouped) plotly aggplot that let's you select the groupby categories :param data: %(data)s :param x: %(x_novec)s :param y: %(y_novec)s :param hue: %(hue)s :param groupby: Column name(s) to split the plot by [optional] :param sep: Separator used for groupby columns [optional] :param agg: Aggregate function to use [optional] :param hue_order: %(order)s :param x_min: %(x_min)s :param x_max: %(x_max)s :param y_min: %(y_min)s :param y_max: %(y_max)s :param mode: plotly mode [optional] :param title: %(title_plotly)s :param xaxis_title: %(xaxis_title)s :param yaxis_title: %(yaxis_title)s :param label_maxchar: Maximum allowed number of characters of the labels [optional] :param direction: One of ['up', 'down'] , direction of the dropdown [optional] :param showactive: Whether to show the active selection in the dropdown [optional] :param dropdown_x: x position of the first dropdown [optional] :param dropdown_y: y position of the first dropdown [optional] :param fig: %(fig_plotly)s :param do_print: %(do_print)s :param kws_dropdown: Other keyword arguments passed to the dropdown updatemenu [optional] :param kws_fig: other keyword arguments passed to plotly.graph_objects.Figure [optional] :param kwargs: other keyword arguments passed to plotly.graph_objects.scatter [optional] :return: plotly Figure with the plot on it """ # -- assert if (y_min is not None and y_max is None) or (y_min is None and y_max is not None): raise ValueError('If you supply y_min or y_max you must also supply the other') # -- functions def _get_xy(fltr: tuple = None, hue_i: Scalar = None) -> tuple: _df = data.copy() if hue != '__dummy__': _df = _df[_df[hue] == hue_i] if fltr is not None: for __it, _value in enumerate(fltr): _key = groupby[__it] if _value != '<ALL>': _df = _df[_df[_key] == _value] _df_agg = _df.groupby(x).agg({y: agg}).reset_index() return _df_agg[x], _df_agg[y] # -- init # - no inplace data = pd.DataFrame(data).copy() # - defaults if kws_dropdown is None: kws_dropdown = {} if kws_fig is None: kws_fig = {} if title is None: title = f"{agg} of '{y}' over '{x}'" if groupby is not None: title += f", filtered by '{groupby}'" if groupby is not None: title += f", split by '{hue}'" if xaxis_title is None: xaxis_title = x elif xaxis_title in [False, 'None']: xaxis_title = None if yaxis_title is None: yaxis_title = y elif yaxis_title in [False, 'None']: yaxis_title = None if hue is None: hue = '__dummy__' data[hue] = 1 _hues = [1] else: _hues = _get_ordered_levels(data, hue, hue_order) if fig is None: fig = go.Figure(**kws_fig) # - force_list groupby = assert_list(groupby) # - x_min / x_max if x_min is not None: data = data[data[x] >= x_min] if x_max is not None: data = data[data[x] <= x_max] # -- main # - scatter for _hue in _hues: _x, _y = _get_xy(hue_i=_hue) fig.add_trace(go.Scatter(x=_x, y=_y, mode=mode, name=_hue, **kwargs)) # - concat groupbys _groupby_dict = {} for _groupby in groupby: _groupby_dict[_groupby] = ['<ALL>'] + data[_groupby].drop_duplicates().sort_values().tolist() _groupby_values = list(itertools.product(*list(_groupby_dict.values()))) _len_groupby_values = len(_groupby_values) # - updatemenus _updatemenus = [] _buttons = [] for _it_group, _category in enumerate(_groupby_values): # show progressbar if do_print: progressbar(_it_group, _len_groupby_values) # get x, y by hue _xs = [] _ys = [] for _hue in _hues: _x, _y = _get_xy(fltr=_category, hue_i=_hue) _xs.append(_x) _ys.append(_y) # get label _label = '' for _it_cat, _category_i in enumerate(assert_list(_category)): if _it_cat > 0: _label += sep _label_i = str(_category_i) if len(_label_i) > label_maxchar: _label_i = _label_i[:label_maxchar] + '...' _label += _label_i # create button _buttons.append({ 'method': 'restyle', 'label': _label, 'args': [{'x': _xs, 'y': _ys}] }) # print(_buttons) _updatemenus.append({ 'buttons': _buttons, 'direction': direction, 'showactive': showactive, 'x': dropdown_x, 'y': dropdown_y, **kws_dropdown }) # - fig # noinspection PyUnboundLocalVariable fig.update_layout(updatemenus=_updatemenus) # # - annotation (not properly aligned, therefore dropped for now) # _annotation = sep.join([str(_) for _ in force_list(groupby)]) # _fig.update_layout(annotations=[ # go.layout.Annotation(text=_annotation, showarrow=False, x=dropdown_x, y=dropdown_y+.1, xref="paper", # yref="paper", align="left") # ]) # - title / axis titles fig.update_layout(title=title, xaxis_title=xaxis_title, yaxis_title=yaxis_title) # - y_min / y_max if y_min is not None: fig.update_yaxes(range=[y_min, y_max]) # - final progressbar if do_print: progressbar() # -- return return fig
def cat_to_color(s: pd.Series, palette: SequenceOrScalar = None, out_type: str = None) -> pd.Series: """ Encodes a categorical column as colors of a specified palette :param s: pandas Series :param palette: %(palette)s :param out_type: Color output type, one of %(cat_to_color__out_type); defaults to None (no conversion) [optional] :return: pandas Series of color names """ # -- functions def _to_color(color_index: int): _color = palette[color_index % len(palette)] if out_type == 'hex': _color = mpl_colors.to_hex(_color) elif out_type == 'rgb': _color = mpl_colors.to_rgb(_color) elif out_type == 'rgba': _color = mpl_colors.to_rgba(_color) elif out_type == 'rgba_array': _color = mpl_colors.to_rgba_array(_color) return _color # -- assert # - no inplace s = pd.Series(s).copy() # - out_type if out_type not in validations['cat_to_color__out_type']: raise ValueError(f"out_type must be one of {validations['cat_to_color__out_type']}") # -- init # - defaults if palette is None: palette = rcParams['palette'] palette = assert_list(palette) s = s.astype('category') if len(s.cat.categories) > len(palette): warnings.warn('Not enough colors in palette, colors will be reused') return s.cat.codes.apply(_to_color).astype('category')