Source code for src.aequitas.plotting

import logging
import math
import numpy as np
import collections
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors
import matplotlib.cm


from aequitas import squarify_flipped as sf

logger = logging.getLogger(__name__)

__author__ = "Pedro Saleiro <saleiro@uchicago.edu>, Loren Hinkson"
__copyright__ = "Copyright \xa9 2018. The University of Chicago. All Rights Reserved."



# module-level function
def assemble_ref_groups(disparities_table, ref_group_flag='_ref_group_value',
                        specific_measures=None, label_score_ref=None):
    """
    Creates a dictionary of reference groups for each metric in a data_table.

    :param disparities_table: a disparity table. Output of bias.get_disparity
        or fairness.get_fairness functions.
    :param ref_group_flag: string indicating the suffix of columns that
        contain reference group values. Default is '_ref_group_value'.
    :param specific_measures: limits the reference dictionary to only the
        specified metrics in a data table. Default is None.
    :param label_score_ref: defines a metric, ex: 'fpr' (false positive rate),
        whose reference group should be mimicked for label_value and score.
        Used for statistical significance calculations in the Bias() class.
        Default is None.
    :return: A dictionary mapping each attribute name to a
        {metric: reference group} dictionary.
    """
    ref_groups = {}
    ref_group_cols = list(
        disparities_table.columns[
            disparities_table.columns.str.contains(ref_group_flag)])

    if specific_measures:
        ref_group_cols = [
            measure + ref_group_flag for measure in specific_measures
            if measure + ref_group_flag in ref_group_cols]

    attributes = list(disparities_table.attribute_name.unique())
    for attribute in attributes:
        attr_table = disparities_table.loc[
            disparities_table['attribute_name'] == attribute]

        attr_refs = {}
        for col in ref_group_cols:
            metric_key = "".join(col.split(ref_group_flag))
            attr_refs[metric_key] = attr_table.loc[
                attr_table['attribute_name'] == attribute, col].min()

        if label_score_ref:
            if label_score_ref + ref_group_flag in ref_group_cols:
                attr_refs['label_value'] = attr_refs[label_score_ref]
                attr_refs['score'] = attr_refs[label_score_ref]
            else:
                raise ValueError("The specified reference measure for label "
                                 "value and score is not included in the "
                                 "data frame.")
        ref_groups[attribute] = attr_refs

    return ref_groups
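# A minimal usage sketch for assemble_ref_groups(), kept as a comment so the
# module stays import-safe. The tiny DataFrame below is illustrative, not a
# real Aequitas output; only the '_ref_group_value' column convention is
# assumed from the code above.
#
#     import pandas as pd
#
#     disparities_table = pd.DataFrame({
#         'attribute_name': ['race', 'race'],
#         'attribute_value': ['white', 'black'],
#         'fpr_ref_group_value': ['white', 'white'],
#         'fnr_ref_group_value': ['white', 'white'],
#     })
#     refs = assemble_ref_groups(disparities_table,
#                                specific_measures=['fpr', 'fnr'])
#     # refs -> {'race': {'fpr': 'white', 'fnr': 'white'}}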
# Plot() class
class Plot(object):
    """
    Plotting object allows for visualization of absolute group bias metrics
    and relative disparities calculated by Aequitas Group(), Bias(), and
    Fairness() class instances.
    """
    default_absolute_metrics = ('pprev', 'ppr', 'fdr', 'for', 'fpr', 'fnr')
    default_disparities = ('pprev_disparity', 'ppr_disparity',
                           'fdr_disparity', 'for_disparity',
                           'fpr_disparity', 'fnr_disparity')

    # Define mapping for conditional coloring based on fairness
    # determinations
    _metric_parity_mapping = {
        'ppr_disparity': 'Statistical Parity',
        'pprev_disparity': 'Impact Parity',
        'precision_disparity': 'Precision Parity',
        'fdr_disparity': 'FDR Parity',
        'for_disparity': 'FOR Parity',
        'fpr_disparity': 'FPR Parity',
        'fnr_disparity': 'FNR Parity',
        'tpr_disparity': 'TPR Parity',
        'tnr_disparity': 'TNR Parity',
        'npv_disparity': 'NPV Parity',
        'ppr': 'Statistical Parity',
        'pprev': 'Impact Parity',
        'precision': 'Precision Parity',
        'fdr': 'FDR Parity',
        'for': 'FOR Parity',
        'fpr': 'FPR Parity',
        'fnr': 'FNR Parity',
        'tpr': 'TPR Parity',
        'tnr': 'TNR Parity',
        'npv': 'NPV Parity'
    }

    _significance_disparity_mapping = {
        'ppr_disparity': 'ppr_significance',
        'pprev_disparity': 'pprev_significance',
        'precision_disparity': 'precision_significance',
        'fdr_disparity': 'fdr_significance',
        'for_disparity': 'for_significance',
        'fpr_disparity': 'fpr_significance',
        'fnr_disparity': 'fnr_significance',
        'tpr_disparity': 'tpr_significance',
        'tnr_disparity': 'tnr_significance',
        'npv_disparity': 'npv_significance'
    }

    def __init__(self, key_metrics=default_absolute_metrics,
                 key_disparities=default_disparities):
        """
        :param key_metrics: default absolute group metrics for all subplots.
        :param key_disparities: default disparity metrics for all subplots.
        """
        self.key_metrics = key_metrics
        self.key_disparities = key_disparities

    @staticmethod
    def _nearest_quartile(x):
        '''
        Return the nearest quartile above a given value x.
        '''
        rounded = round(x * 4) / 4
        if rounded > x:
            return rounded
        else:
            return rounded + 1 / 4

    @staticmethod
    def _check_brightness(rgb_tuple):
        '''
        Determine the brightness of a background color in a plot.

        Adapted from https://trendct.org/2016/01/22/how-to-choose-a-label-color-to-contrast-with-background/
        '''
        r, g, b = rgb_tuple
        return (r * 299 + g * 587 + b * 114) / 1000

    @classmethod
    def _brightness_threshold(cls, rgb_tuple, min_brightness, light_color,
                              dark_color='black'):
        '''
        Determine the ideal plot label color (light or dark) based on the
        brightness of the background color, given a brightness threshold.

        Adapted from https://trendct.org/2016/01/22/how-to-choose-a-label-color-to-contrast-with-background/
        '''
        if cls._check_brightness(rgb_tuple) > min_brightness:
            return dark_color
        return light_color

    @staticmethod
    def _truncate_colormap(orig_cmap, min_value=0.0, max_value=1.0,
                           num_colors=100):
        '''
        Use only part of a colormap (min_value to max_value) across a given
        number of partitions.

        :param orig_cmap: an existing Matplotlib colormap.
        :param min_value: desired minimum value (0.0 to 1.0) for truncated
            colormap. Default is 0.0.
        :param max_value: desired maximum value (0.0 to 1.0) for truncated
            colormap. Default is 1.0.
        :param num_colors: number of colors to spread colormap gradient
            across before truncating. Default is 100.
        :return: Truncated colormap

        Attribution: adapted from https://stackoverflow.com/questions/18926031/how-to-extract-a-subset-of-a-colormap-as-a-new-colormap-in-matplotlib
        '''
        cmap = plt.get_cmap(orig_cmap)
        new_cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
            'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=min_value,
                                                b=max_value),
            cmap(np.linspace(min_value, max_value, num_colors)))
        return new_cmap

    @classmethod
    def _locate_ref_group_indices(cls, disparities_table, attribute_name,
                                  group_metric,
                                  ref_group_flag='_ref_group_value'):
        """
        Finds the relative index (row) of the reference group value for a
        given metric.

        :param disparities_table: a disparity table. Output of
            bias.get_disparity or fairness.get_fairness functions.
        :param attribute_name: the attribute to plot the metric against. Must
            be a column in the disparities_table.
        :param group_metric: the metric to plot. Must be a column in the
            disparities_table.
        :param ref_group_flag: string indicating the suffix of columns that
            contain reference group values. Default is '_ref_group_value'.
        :return: Tuple of the relative index of the reference group's row and
            the reference group's name.
        """
        df_models = disparities_table.model_id.unique()
        if len(df_models) == 1:
            model_id = df_models[0]
        else:
            raise ValueError(
                'This method requires one and only one model_id in the '
                'disparities table. Tip: disparities_table.model_id.unique() '
                'should return a one-element list.')

        # get absolute metric name from passed group metric (vs. a disparity
        # name)
        abs_metric = "".join(group_metric.split('_disparity'))

        all_ref_groups = assemble_ref_groups(disparities_table, ref_group_flag)
        ref_group_name = all_ref_groups[attribute_name][abs_metric]

        # get index for the row associated with the reference group for that
        # model
        ind = list(disparities_table.loc[
            (disparities_table['attribute_name'] == attribute_name) &
            (disparities_table['attribute_value'] == ref_group_name) &
            (disparities_table['model_id'] == model_id)].index)

        # there should only ever be one item in the list, but just in case,
        # select the first
        if len(ind) == 1:
            idx = ind[0]
        else:
            raise ValueError(
                f"Failed to find exactly one index for the reference group "
                f"for attribute_name = {attribute_name}, reference "
                f"attribute_value = {ref_group_name}, and "
                f"model_id = {model_id}.")

        relative_ind = disparities_table.index.get_loc(idx)
        return relative_ind, ref_group_name
    def plot_group_metric(self, group_table, group_metric, ax=None,
                          ax_lim=None, title=True, label_dict=None,
                          min_group_size=None):
        """
        Plot a single group metric across all attribute groups.

        :param group_table: group table. Output of group.get_crosstabs() or
            bias.get_disparity functions.
        :param group_metric: the metric to plot. Must be a column in the
            group_table.
        :param ax: a matplotlib Axis. If not passed, a new figure will be
            created.
        :param ax_lim: maximum value on the x-axis, used to match axes across
            subplots when plotting multiple metrics. Default is None.
        :param title: whether to include a title in visualizations. Default
            is True.
        :param label_dict: optional dictionary of replacement labels for
            data. Default is None.
        :param min_group_size: minimum size for groups to include in the
            visualization (as a proportion of the total sample).
        :return: A Matplotlib axis
        """
        df_models = group_table.model_id.unique()
        if len(df_models) != 1:
            raise ValueError(
                'This method requires one and only one model_id in the group '
                'table. Tip: group_table.model_id.unique() should return a '
                'one-element list.')

        if group_metric not in group_table.columns:
            raise ValueError(f"Specified metric '{group_metric}' not in "
                             f"'group_table'.")

        if group_table[group_metric].isnull().any():
            raise ValueError(f"Cannot plot {group_metric}, has NaN values.")

        if ax is None:
            (_fig, ax) = plt.subplots(figsize=(10, 5))

        height_of_bar = 1
        attribute_names = group_table.attribute_name.unique()
        tick_indices = []
        next_bar_height = 0

        if min_group_size:
            if min_group_size > (group_table.group_size.max() /
                                 group_table.group_size.sum()):
                raise ValueError(
                    f"'min_group_size' proportion specified: "
                    f"'{min_group_size}' is larger than all groups in "
                    f"sample.")

            min_size = min_group_size * group_table.group_size.sum()
            group_table = group_table.loc[
                group_table['group_size'] >= min_size]

        label_position_values = collections.deque(
            group_table[group_metric].values)

        lighter_coppers = self._truncate_colormap('copper_r', min_value=0,
                                                  max_value=0.65)

        norm = matplotlib.colors.Normalize(
            vmin=group_table['group_size'].min(),
            vmax=group_table['group_size'].max())
        mapping = matplotlib.cm.ScalarMappable(norm=norm,
                                               cmap=lighter_coppers)

        # Lock absolute value metric plot x-axis to (0, 1)
        if not ax_lim:
            ax_lim = 1
        ax.set_xlim(0, ax_lim)

        for attribute_name in attribute_names:
            attribute_data = group_table.loc[
                group_table['attribute_name'] == attribute_name]

            values = attribute_data[group_metric].values
            grp_sizes = attribute_data['group_size'].values

            attribute_indices = np.arange(
                next_bar_height, next_bar_height + attribute_data.shape[0],
                step=height_of_bar)
            attribute_tick_location = float(
                min(attribute_indices) + max(attribute_indices) +
                height_of_bar) / 2

            h_attribute = ax.barh(
                attribute_indices, width=values,
                # label=list(attribute_data['attribute_value'].values),
                align='edge', edgecolor='grey')

            label_colors = []
            min_brightness = 0.55

            for bar, g_size in zip(h_attribute, grp_sizes):
                my_col = mapping.to_rgba(g_size)
                bar.set_color(my_col)
                label_colors.append(self._brightness_threshold(
                    rgb_tuple=my_col[:3], min_brightness=min_brightness,
                    light_color=(1, 1, 1, 1)))

            if label_dict:
                labels = [label_dict.get(label, label) for label in
                          attribute_data['attribute_value'].values]
            else:
                labels = attribute_data['attribute_value'].values

            for y, label, value, text_color, g_size in zip(
                    attribute_indices, labels, values, label_colors,
                    grp_sizes):
                next_position = label_position_values.popleft()
                group_label = f"{label} ({g_size:,})"

                if ax_lim < 3:
                    CHAR_PLACEHOLDER = 0.03
                else:
                    CHAR_PLACEHOLDER = 0.25

                label_length = len(label) * CHAR_PLACEHOLDER
                max_val_length = 7 * CHAR_PLACEHOLDER
                indent_length = ax_lim * 0.025

                # bar long enough for label, enough space after bar for value
                if ((indent_length + label_length) <
                        (next_position - indent_length)) and (
                        (next_position + indent_length + max_val_length) <
                        (ax_lim - indent_length)):
                    ax.text(next_position + indent_length,
                            y + float(height_of_bar) / 2,
                            f"{value:.2f}", fontsize=12,
                            verticalalignment='top')
                    ax.text(indent_length, y + float(height_of_bar) / 2,
                            group_label, fontsize=11,
                            verticalalignment='top', color=text_color)

                # case when bar is too long for labels after the bar: print
                # all text inside the bar
                elif (next_position + indent_length + max_val_length) > (
                        ax_lim - indent_length):
                    ax.text(indent_length, y + float(height_of_bar) / 2,
                            f"{group_label}, {value:.2f}", fontsize=11,
                            verticalalignment='top', color=text_color)

                # case when bar is too small for labels inside the bar: print
                # after the bar
                else:
                    ax.text(next_position + indent_length,
                            y + float(height_of_bar) / 2,
                            f"{group_label}, {value:.2f}", fontsize=12,
                            verticalalignment='top')

            tick_indices.append((attribute_name, attribute_tick_location))
            next_bar_height = max(attribute_indices) + 2 * height_of_bar

        ax.yaxis.set_ticks(list(map(lambda x: x[1], tick_indices)))
        ax.yaxis.set_ticklabels(list(map(lambda x: x[0], tick_indices)),
                                fontsize=14)
        ax.set_axisbelow(True)
        ax.xaxis.grid(color='lightgray', which='major', linestyle='dashed')
        ax.set_xlabel("Absolute Metric Magnitude")

        if title:
            ax.set_title(f"{group_metric.upper()}", fontsize=20)

        return ax
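    # A minimal usage sketch for plot_group_metric(), assuming the Group()
    # crosstab API from aequitas.group and a scored DataFrame `df` with
    # 'score', 'label_value', and attribute columns (not shown here):
    #
    #     from aequitas.group import Group
    #     from aequitas.plotting import Plot
    #
    #     g = Group()
    #     xtab, _ = g.get_crosstabs(df)   # df is illustrative
    #     aqp = Plot()
    #     fpr_ax = aqp.plot_group_metric(xtab, 'fpr', min_group_size=0.01)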
    def plot_disparity(self, disparity_table, group_metric, attribute_name,
                       color_mapping=None, ax=None, fig=None, label_dict=None,
                       title=True, highlight_fairness=False,
                       min_group_size=None, significance_alpha=0.05):
        """
        Create a treemap based on a single bias disparity metric across
        attribute groups.

        Adapted from https://plot.ly/python/treemaps/,
        https://gist.github.com/gVallverdu/0b446d0061a785c808dbe79262a37eea,
        and https://fcpython.com/visualisation/python-treemaps-squarify-matplotlib

        :param disparity_table: a disparity table. Output of
            bias.get_disparity or fairness.get_fairness function.
        :param group_metric: the metric to plot. Must be a column in the
            disparity_table.
        :param attribute_name: which attribute to plot group_metric across.
        :param color_mapping: matplotlib colormapping for treemap value
            boxes.
        :param ax: a matplotlib Axis. If not passed, a new figure will be
            created.
        :param fig: a matplotlib Figure. If not passed, a new figure will be
            created.
        :param label_dict: optional dictionary of replacement labels for
            data. Default is None.
        :param title: whether to include a title in visualizations. Default
            is True.
        :param highlight_fairness: whether to highlight treemaps by disparity
            magnitude, or by related fairness determination.
        :param min_group_size: minimum proportion of total group size (all
            data) a population group must meet in order to be included in the
            bias metric visualization.
        :param significance_alpha: statistical significance level. Used to
            determine visual representation of significance (number of
            asterisks on treemap).
        :return: A Matplotlib axis
        """
        # Use matplotlib to truncate the colormap, scale metric values
        # between the min and max, then assign colors to individual values
        df_models = disparity_table.model_id.unique()
        if len(df_models) != 1:
            raise ValueError(
                'This method requires one and only one model_id in the '
                'disparities table. Tip: disparity_table.model_id.unique() '
                'should return a one-element list.')

        table_columns = set(disparity_table.columns)
        if group_metric not in table_columns:
            raise ValueError(f"Specified disparity metric {group_metric} not "
                             f"in 'disparity_table'.")

        attribute_table = disparity_table.loc[
            disparity_table['attribute_name'] == attribute_name]

        # sort by group size, as box size is indicative of group size
        sorted_df = attribute_table.sort_values('group_size', ascending=False)

        x = 0.
        y = 0.
        width = 100.
        height = 100.

        ref_group_rel_idx, ref_group_name = self._locate_ref_group_indices(
            disparities_table=sorted_df, attribute_name=attribute_name,
            group_metric=group_metric)

        if min_group_size:
            if min_group_size > (disparity_table.group_size.max() /
                                 disparity_table.group_size.sum()):
                raise ValueError(
                    f"'min_group_size' proportion specified: "
                    f"'{min_group_size}' is larger than all groups in "
                    f"sample.")

            min_size = min_group_size * disparity_table.group_size.sum()

            # raise a warning if the minimum group size specified would
            # exclude the reference group
            if any(sorted_df.loc[
                       sorted_df['attribute_value'] == ref_group_name,
                       ['group_size']].values < min_size):
                logger.warning(
                    f"Reference group size is smaller than 'min_group_size' "
                    f"proportion specified: '{min_group_size}'. Reference "
                    f"group '{ref_group_name}' was not excluded.")

            sorted_df = sorted_df.loc[
                (sorted_df['group_size'] >= min_size) |
                (sorted_df['attribute_value'] == ref_group_name)]

        # select group size as values for size of boxes
        values = sorted_df.loc[:, 'group_size']

        # get new index for the reference group
        ref_group_rel_idx, _ = self._locate_ref_group_indices(
            disparities_table=sorted_df, attribute_name=attribute_name,
            group_metric=group_metric)

        # labels for squares in the treemap: the label should always be the
        # disparity value (but boxes visualized should always be the metric's
        # absolute value, capped between 0.1x ref group and 10x ref group)
        if group_metric + '_disparity' not in attribute_table.columns:
            related_disparity = group_metric
        else:
            related_disparity = group_metric + '_disparity'

        if highlight_fairness:
            if not table_columns.intersection(
                    self._metric_parity_mapping.values()):
                raise ValueError("Data table must include at least one "
                                 "fairness determination to visualize metric "
                                 "parity.")

            # apply red for "False" fairness determinations and green for
            # "True" determinations
            cb_green = '#1b7837'
            cb_red = '#a50026'

            parity = self._metric_parity_mapping[group_metric]
            if parity not in table_columns:
                raise ValueError(
                    f"Related fairness determination for {group_metric} must "
                    f"be included in the data table to color the "
                    f"visualization based on metric fairness.")
            clrs = [cb_green if val else cb_red for val in sorted_df[parity]]
        else:
            aq_palette = sns.diverging_palette(225, 35, sep=10, as_cmap=True)

            if not color_mapping:
                norm = matplotlib.colors.Normalize(vmin=0, vmax=2)
                color_mapping = matplotlib.cm.ScalarMappable(norm=norm,
                                                             cmap=aq_palette)
            clrs = [color_mapping.to_rgba(val) for val in
                    sorted_df[related_disparity]]

        # color the reference group grey
        clrs[ref_group_rel_idx] = '#D3D3D3'

        compare_value = values.iloc[ref_group_rel_idx]

        scaled_values = [(0.1 * compare_value) if val < (0.1 * compare_value)
                         else (10 * compare_value) if val >= (10 * compare_value)
                         else val for val in values]

        label_values = ["(Ref)" if attr_val == ref_group_name
                        else f"{disp:.2f}" for attr_val, disp in
                        zip(sorted_df['attribute_value'],
                            sorted_df[related_disparity])]

        if label_dict:
            labels = [label_dict.get(label, label) for label in
                      sorted_df['attribute_value']]
        else:
            labels = sorted_df['attribute_value'].values

        # if the df includes significance columns, add stars to indicate
        # significance
        if sorted_df.columns.str.contains('_significance').any():
            # star significant disparities based on the significance level
            if 0.10 >= significance_alpha > 0.05:
                significance_stars = '*'
            elif 0.05 >= significance_alpha > 0.01:
                significance_stars = '**'
            elif significance_alpha <= 0.01:
                significance_stars = '***'
            else:
                significance_stars = ''

            significance_col = \
                self._significance_disparity_mapping[related_disparity]

            if np.issubdtype(sorted_df[significance_col].dtype, np.number):
                # unmasked significance: find indices where the related
                # significance value is smaller than significance_alpha
                to_star = sorted_df.loc[
                    sorted_df[significance_col] < significance_alpha
                ].index.tolist()
            else:
                # masked significance: find indices where attribute values
                # have a True value in the significance column
                to_star = sorted_df.loc[
                    sorted_df[significance_col] > 0].index.tolist()

            # add stars to label values where significant
            for idx in to_star:
                # convert idx location to a relative index in the sorted df
                # and the label_values list
                idx_adj = sorted_df.index.get_loc(idx)
                label_values[idx_adj] = \
                    label_values[idx_adj] + significance_stars

        normed = sf.normalize_sizes(scaled_values, width, height)
        padded_rects = sf.padded_squarify(normed, x, y, width, height)

        # make the plot
        if not ax or not fig:
            fig, ax = plt.subplots(figsize=(5, 4))

        ax = sf.squarify_plot_rects(padded_rects, color=clrs, labels=labels,
                                    values=label_values, ax=ax, alpha=0.8,
                                    acronyms=False)

        # TODO: build out in next phase (model comparison)
        # if model_id:
        #     ax.set_title(f"MODEL {model_id}, "
        #                  f"{(' ').join(group_metric.split('_')).upper()} "
        #                  f"({attribute_name.upper()})",
        #                  fontsize=23, fontweight="bold")

        if title:
            ax.set_title(
                f"{(' ').join(related_disparity.split('_')).upper()} "
                f"({attribute_name.upper()})",
                fontsize=23, fontweight="bold")

        if not highlight_fairness:
            # create a dummy invisible image with a colormap to leverage for
            # the color bar
            img = plt.imshow([[0, 2]], cmap=aq_palette, alpha=0.8)
            img.set_visible(False)
            fig.colorbar(img, orientation="vertical", shrink=.96, ax=ax)

        # remove axes and display the plot
        ax.axis('off')

        return ax
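    # A minimal usage sketch for plot_disparity(), assuming the Bias() API
    # from aequitas.bias and the `xtab`/`df` objects from the sketch above;
    # the reference-group choices are illustrative.
    #
    #     from aequitas.bias import Bias
    #
    #     b = Bias()
    #     bdf = b.get_disparity_predefined_groups(
    #         xtab, original_df=df,
    #         ref_groups_dict={'race': 'white', 'sex': 'male'}, alpha=0.05)
    #     aqp = Plot()
    #     ax = aqp.plot_disparity(bdf, group_metric='fpr_disparity',
    #                             attribute_name='race')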
    def plot_fairness_group(self, fairness_table, group_metric, ax=None,
                            ax_lim=None, title=False, label_dict=None,
                            min_group_size=None):
        '''
        This function plots absolute group metrics as indicated by the config
        file, colored based on calculated parity.

        :param fairness_table: a fairness table. Output of
            fairness.get_fairness function.
        :param group_metric: the fairness metric to plot. Must be a column in
            the fairness_table.
        :param ax: a matplotlib Axis. If not passed, a new figure will be
            created.
        :param ax_lim: maximum value on the x-axis, used to match axes across
            subplots when plotting multiple metrics. Default is None.
        :param title: whether to include a title in visualizations. Default
            is False.
        :param label_dict: optional dictionary of replacement values for
            data. Default is None.
        :param min_group_size: minimum proportion of total group size (all
            data) a population group must meet in order to be included in the
            fairness visualization.
        :return: A Matplotlib axis
        '''
        df_models = fairness_table.model_id.unique()
        if len(df_models) != 1:
            raise ValueError(
                'This method requires one and only one model_id in the '
                'fairness table. Tip: fairness_table.model_id.unique() '
                'should return a one-element list.')

        if group_metric not in fairness_table.columns:
            raise ValueError(f"Specified metric {group_metric} not in "
                             f"'fairness_table'.")

        if fairness_table[group_metric].isnull().any():
            raise ValueError(f"Cannot plot {group_metric}, has NaN values.")

        if ax is None:
            fig, ax = plt.subplots(figsize=(10, 5))

        height_of_bar = 1
        attributes = fairness_table.attribute_name.unique()
        tick_indices = []
        next_bar_height = 0

        if min_group_size:
            if min_group_size > (fairness_table.group_size.max() /
                                 fairness_table.group_size.sum()):
                raise ValueError(
                    f"'min_group_size' proportion specified: "
                    f"'{min_group_size}' is larger than all groups in "
                    f"sample.")

            min_size = min_group_size * fairness_table.group_size.sum()
            fairness_table = fairness_table.loc[
                fairness_table['group_size'] >= min_size]

        label_position_values = collections.deque(
            fairness_table[group_metric].values)

        # Lock absolute value metric plot x-axis to (0, 1)
        if not ax_lim:
            ax_lim = 1
        ax.set_xlim(0, ax_lim)

        for attribute in attributes:
            attribute_data = fairness_table.loc[
                fairness_table['attribute_name'] == attribute]
            values = attribute_data[group_metric].values
            grp_sizes = attribute_data['group_size'].values

            # apply red for "False" fairness determinations and green for
            # "True" determinations
            cb_green = '#1b7837'
            cb_red = '#a50026'

            parity = self._metric_parity_mapping[group_metric]
            parity_colors = [cb_green if val else cb_red for val in
                             attribute_data[parity]]

            # set black text for green bars and white text for red bars
            label_colors = [(0, 0, 0, 1) if val else (1, 1, 1, 1) for val in
                            attribute_data[parity]]

            attribute_indices = np.arange(
                next_bar_height, next_bar_height + attribute_data.shape[0],
                step=height_of_bar)
            attribute_tick_location = float(
                min(attribute_indices) + max(attribute_indices) +
                height_of_bar) / 2

            h_attribute = ax.barh(attribute_indices, width=values,
                                  color=parity_colors, align='edge',
                                  edgecolor='grey', alpha=0.8)

            if label_dict:
                labels = [label_dict.get(label, label) for label in
                          attribute_data['attribute_value'].values]
            else:
                labels = attribute_data['attribute_value'].values

            for y, label, value, text_color, g_size in zip(
                    attribute_indices, labels, values, label_colors,
                    grp_sizes):
                next_position = label_position_values.popleft()

                if ax_lim < 3:
                    CHAR_PLACEHOLDER = 0.03
                else:
                    CHAR_PLACEHOLDER = 0.25

                label_length = len(label) * CHAR_PLACEHOLDER
                max_val_length = 7 * CHAR_PLACEHOLDER
                indent_length = ax_lim * 0.025

                # bar long enough for label, enough space after bar for value
                if ((indent_length + label_length) <
                        (next_position - indent_length)) and (
                        (next_position + indent_length + max_val_length) <
                        (ax_lim - indent_length)):
                    ax.text(next_position + indent_length,
                            y + float(height_of_bar) / 2,
                            f"{value:.2f}", fontsize=12,
                            verticalalignment='top')
                    ax.text(indent_length, y + float(height_of_bar) / 2,
                            label, fontsize=11, verticalalignment='top',
                            color=text_color)

                # case when bar is too long for labels after the bar: print
                # all text inside the bar
                elif (next_position + indent_length + max_val_length) > (
                        ax_lim - indent_length):
                    ax.text(indent_length, y + float(height_of_bar) / 2,
                            f"{label}, {value:.2f}", fontsize=11,
                            verticalalignment='top', color=text_color)

                # case when bar is too small for labels inside the bar: print
                # all text after the bar
                else:
                    ax.text(next_position + indent_length,
                            y + float(height_of_bar) / 2,
                            f"{label}, {value:.2f}", fontsize=12,
                            verticalalignment='top')

            tick_indices.append((attribute, attribute_tick_location))
            next_bar_height = max(attribute_indices) + 2 * height_of_bar

        ax.yaxis.set_ticks(list(map(lambda x: x[1], tick_indices)))
        ax.yaxis.set_ticklabels(list(map(lambda x: x[0], tick_indices)),
                                fontsize=14)
        ax.set_axisbelow(True)
        ax.xaxis.grid(color='lightgray', which='major', linestyle='dashed')
        ax.set_xlabel('Absolute Metric Magnitude')

        if title:
            ax.set_title(f"{group_metric.upper()}", fontsize=20)

        return ax
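    # A minimal usage sketch for plot_fairness_group(), assuming the
    # Fairness() API from aequitas.fairness and the `bdf` disparity table
    # from the sketch above.
    #
    #     from aequitas.fairness import Fairness
    #
    #     f = Fairness()
    #     fdf = f.get_group_value_fairness(bdf)
    #     aqp = Plot()
    #     ax = aqp.plot_fairness_group(fdf, 'fpr', title=True)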
    def plot_fairness_disparity(self, fairness_table, group_metric,
                                attribute_name, ax=None, fig=None, title=True,
                                min_group_size=None, significance_alpha=0.05):
        """
        Plot disparity metrics colored based on calculated disparity.

        :param fairness_table: a fairness table. Output of
            fairness.get_fairness function.
        :param group_metric: the metric to plot. Must be a column in the
            fairness_table.
        :param attribute_name: which attribute to plot group_metric across.
        :param ax: a matplotlib Axis. If not passed, a new figure will be
            created.
        :param fig: a matplotlib Figure. If not passed, a new figure will be
            created.
        :param title: whether to include a title in visualizations. Default
            is True.
        :param min_group_size: minimum proportion of total group size (all
            data) a population group must meet in order to be included in the
            bias metric visualization.
        :param significance_alpha: statistical significance level. Used to
            determine visual representation of significance (number of
            asterisks on treemap).
        :return: A Matplotlib axis
        """
        df_models = fairness_table.model_id.unique()
        if len(df_models) != 1:
            raise ValueError(
                'This method requires one and only one model_id in the '
                'fairness table. Tip: fairness_table.model_id.unique() '
                'should return a one-element list.')

        return self.plot_disparity(disparity_table=fairness_table,
                                   group_metric=group_metric,
                                   attribute_name=attribute_name,
                                   color_mapping=None, ax=ax, fig=fig,
                                   highlight_fairness=True,
                                   min_group_size=min_group_size, title=title,
                                   significance_alpha=significance_alpha)
    def _plot_multiple(self, data_table, plot_fcn, metrics=None,
                       fillzeros=True, title=True, ncols=3, label_dict=None,
                       show_figure=True, min_group_size=None):
        """
        This function plots bar charts of absolute metrics indicated by the
        config file.

        :param data_table: output of group.get_crosstabs, bias.get_disparity,
            or fairness.get_fairness functions.
        :param plot_fcn: the single-metric plotting function to use for
            subplots.
        :param metrics: which metric(s) to plot, or 'all'. If this value is
            null, will plot the following absolute metrics (or related
            disparity measures):
                - Predicted Prevalence (pprev),
                - Predicted Positive Rate (ppr),
                - False Discovery Rate (fdr),
                - False Omission Rate (for),
                - False Positive Rate (fpr),
                - False Negative Rate (fnr)
        :param fillzeros: whether null values should be filled with zeros.
            Default is True.
        :param title: whether to display a title on each plot. Default is
            True.
        :param ncols: the number of subplots to plot per row in the
            visualization figure. Default is 3.
        :param label_dict: optional dictionary of label replacements. Default
            is None.
        :param show_figure: whether to show figure (plt.show()). Default is
            True.
        :param min_group_size: minimum proportion of total group size (all
            data) a population group must meet in order to be included in the
            visualization.
        :return: Returns a figure
        """
        df_models = data_table.model_id.unique()
        if len(df_models) != 1:
            raise ValueError(
                'This method requires one and only one model_id in the data '
                'table. Tip: data_table.model_id.unique() should return a '
                'one-element list.')

        if fillzeros:
            data_table = data_table.fillna(0)

        if plot_fcn in [self.plot_fairness_group, self.plot_group_metric]:
            if not metrics:
                metrics = [met for met in self.key_metrics if
                           met in data_table.columns]
            elif metrics == 'all':
                all_abs_metrics = ('pprev', 'ppr', 'fdr', 'for', 'fpr', 'fnr',
                                   'tpr', 'tnr', 'npv', 'precision')
                metrics = [met for met in all_abs_metrics if
                           met in data_table.columns]
            ax_lim = 1

        # else plot_fcn in [self.plot_fairness_disparity, self.plot_disparity]
        else:
            if not metrics:
                metrics = [disp for disp in self.key_disparities if
                           disp in data_table.columns]
            elif metrics == 'all':
                metrics = list(data_table.columns[
                    data_table.columns.str.contains('_disparity')])
            ax_lim = min(10, self._nearest_quartile(
                max(data_table[metrics].max())) + 0.1)

        num_metrics = len(metrics)
        rows = math.ceil(num_metrics / ncols)
        if ncols == 1 or (num_metrics % ncols == 0):
            axes_to_remove = 0
        else:
            axes_to_remove = ncols - (num_metrics % ncols)

        if not (0 < rows <= num_metrics):
            raise ValueError(
                "Plot must have at least one row. Please update the number "
                "of columns ('ncols') or check that at least one metric is "
                "specified in 'metrics'.")
        if not (0 < ncols <= num_metrics):
            raise ValueError(
                "Plot must have at least one column, and no more columns "
                "than subplots. Please update the number of columns ('ncols') "
                "or check that at least one metric is specified in "
                "'metrics'.")

        total_plot_width = 25

        fig, axs = plt.subplots(nrows=rows, ncols=ncols,
                                figsize=(total_plot_width, 6 * rows),
                                sharey=True,
                                gridspec_kw={'wspace': 0.075, 'hspace': 0.25})

        # set a different metric to be plotted in each subplot
        ax_col = 0
        ax_row = 0

        for group_metric in metrics:
            if (ax_col >= ncols) and ((ax_col + 1) % ncols) == 1:
                ax_row += 1
                ax_col = 0

            if rows == 1:
                current_subplot = axs[ax_col]
            elif ncols == 1:
                current_subplot = axs[ax_row]
                ax_row += 1
            else:
                current_subplot = axs[ax_row, ax_col]

            plot_fcn(data_table, group_metric=group_metric,
                     ax=current_subplot, ax_lim=ax_lim, title=title,
                     label_dict=label_dict, min_group_size=min_group_size)
            ax_col += 1

        # disable any axes not being used
        if axes_to_remove > 0:
            for i in np.arange(axes_to_remove):
                axs[-1, -(i + 1)].axis('off')

        if show_figure:
            plt.show()
        return fig

    def _plot_multiple_treemaps(self, data_table, plot_fcn, attributes=None,
                                metrics=None, fillzeros=True, title=True,
                                label_dict=None, highlight_fairness=False,
                                show_figure=True, min_group_size=None,
                                significance_alpha=0.05):
        """
        This function plots treemaps of disparities indicated by the config
        file.

        :param data_table: output of bias.get_disparity or
            fairness.get_fairness functions.
        :param plot_fcn: the plotting function to use to plot individual
            disparity or fairness treemaps in the grid.
        :param attributes: which attributes to plot against. Must be
            specified if no metrics are specified.
        :param metrics: which metric(s) to plot, or 'all'. Must be specified
            if no attributes are specified. If this value is null, the
            following absolute metrics/related disparity measures will be
            plotted against the specified attributes:
                - Predicted Prevalence (pprev),
                - Predicted Positive Rate (ppr),
                - False Discovery Rate (fdr),
                - False Omission Rate (for),
                - False Positive Rate (fpr),
                - False Negative Rate (fnr)
        :param fillzeros: whether null values should be filled with zeros.
            Default is True.
        :param title: whether to display a title on each plot. Default is
            True.
        :param label_dict: optional dictionary of label replacements. Default
            is None.
        :param highlight_fairness: whether to highlight treemaps by disparity
            magnitude, or by related fairness determination.
        :param show_figure: whether to show figure (plt.show()). Default is
            True.
        :param min_group_size: minimum proportion of total group size (all
            data) a population group must meet in order to be included in the
            visualization.
        :param significance_alpha: statistical significance level. Used to
            determine visual representation of significance (number of
            asterisks on treemap).
        :return: Returns a figure
        """
        df_models = data_table.model_id.unique()
        if len(df_models) != 1:
            raise ValueError(
                'This method requires one and only one model_id in the data '
                'table. Tip: data_table.model_id.unique() should return a '
                'one-element list.')

        if fillzeros:
            data_table = data_table.fillna(0)

        if all(v is None for v in [attributes, metrics]):
            raise ValueError("One of the following parameters must be "
                             "specified: 'attributes', 'metrics'.")

        if attributes:
            if not metrics:
                metrics = [abs_m for abs_m in self.key_metrics if
                           abs_m in data_table.columns]
                # metrics = list(set(self.input_group_metrics) &
                #                set(data_table.columns))
            elif metrics == 'all':
                all_disparities = ['tpr_disparity', 'tnr_disparity',
                                   'for_disparity', 'fdr_disparity',
                                   'fpr_disparity', 'fnr_disparity',
                                   'npv_disparity', 'precision_disparity',
                                   'ppr_disparity', 'pprev_disparity']
                metrics = [disp for disp in all_disparities if
                           disp in data_table.columns]
            viz_title = (
                f"DISPARITY METRICS by "
                f"{(', ').join(list(map(lambda x: x.upper(), attributes)))}")

        elif not attributes:
            attributes = list(data_table.attribute_name.unique())
            if metrics == 'all':
                all_disparities = ['tpr_disparity', 'tnr_disparity',
                                   'for_disparity', 'fdr_disparity',
                                   'fpr_disparity', 'fnr_disparity',
                                   'npv_disparity', 'precision_disparity',
                                   'ppr_disparity', 'pprev_disparity']
                metrics = [disparity for disparity in all_disparities if
                           disparity in data_table.columns]
            viz_title = (f"{(', ').join(map(lambda x: x.upper(), metrics))} "
                         f"ACROSS ATTRIBUTES")

        num_metrics = len(attributes) * len(metrics)
        if num_metrics > 1:
            ncols = 3
        else:
            ncols = 1

        rows = math.ceil(num_metrics / ncols)
        if ncols == 1 or (num_metrics % ncols == 0):
            axes_to_remove = 0
        else:
            axes_to_remove = ncols - (num_metrics % ncols)

        if not (0 < rows <= num_metrics):
            raise ValueError(
                "Plot must have at least one row. Please update the number "
                "of columns ('ncols'), the list of metrics to be plotted "
                "('metrics'), or the list of attributes to plot disparity "
                "metrics across.")
        if not (0 < ncols <= num_metrics):
            raise ValueError(
                "Plot must have at least one column, and no more columns "
                "than plots. Please update the number of columns ('ncols'), "
                "the list of metrics to be plotted ('metrics'), or the list "
                "of attributes to plot disparity metrics across.")

        total_plot_width = 25

        fig, axs = plt.subplots(nrows=rows, ncols=ncols,
                                figsize=(total_plot_width, 8 * rows),
                                gridspec_kw={'wspace': 0.025, 'hspace': 0.5},
                                subplot_kw={'aspect': 'equal'})

        if highlight_fairness:
            mapping = None
        else:
            aq_palette = sns.diverging_palette(225, 35, sep=10, as_cmap=True)
            norm = matplotlib.colors.Normalize(vmin=0, vmax=2)
            mapping = matplotlib.cm.ScalarMappable(norm=norm, cmap=aq_palette)

        # set a different metric to be plotted in each subplot
        ax_col = 0
        ax_row = 0

        for group_metric in metrics:
            for attr in attributes:
                if (ax_col >= ncols) and ((ax_col + 1) % ncols) == 1:
                    ax_row += 1
                    ax_col = 0

                if num_metrics == 1:
                    current_subplot = axs
                elif (num_metrics > 1) and (rows == 1):
                    current_subplot = axs[ax_col]
                elif (num_metrics > 1) and (ncols == 1):
                    current_subplot = axs[ax_row]
                    ax_row += 1
                else:
                    current_subplot = axs[ax_row, ax_col]

                plot_fcn(data_table, group_metric=group_metric,
                         attribute_name=attr, color_mapping=mapping,
                         ax=current_subplot, fig=fig, title=title,
                         label_dict=label_dict,
                         highlight_fairness=highlight_fairness,
                         min_group_size=min_group_size,
                         significance_alpha=significance_alpha)
                ax_col += 1

        # disable any axes not being used
        if axes_to_remove > 0:
            for i in np.arange(axes_to_remove):
                axs[-1, -(i + 1)].axis('off')

        plt.suptitle(f"{viz_title}", fontsize=25, fontweight="bold")

        # fig.tight_layout()
        if rows > 2:
            fig.subplots_adjust(top=0.95)
        else:
            fig.subplots_adjust(top=0.90)

        if show_figure:
            plt.show()
        return fig
    def plot_group_metric_all(self, data_table, metrics=None, fillzeros=True,
                              ncols=3, title=True, label_dict=None,
                              show_figure=True, min_group_size=None):
        """
        Plot multiple absolute group metrics at once from a group, bias, or
        fairness table.

        :param data_table: output of group.get_crosstabs, bias.get_disparity,
            or fairness.get_fairness functions.
        :param metrics: which metric(s) to plot, or 'all'. If this value is
            null, will plot:
                - Predicted Prevalence (pprev),
                - Predicted Positive Rate (ppr),
                - False Discovery Rate (fdr),
                - False Omission Rate (for),
                - False Positive Rate (fpr),
                - False Negative Rate (fnr)
        :param fillzeros: whether to fill null values with zeros. Default is
            True.
        :param ncols: number of subplots per row in figure. Default is 3.
        :param title: whether to display a title on each plot. Default is
            True.
        :param label_dict: optional dictionary of label replacements. Default
            is None.
        :param show_figure: whether to show figure (plt.show()). Default is
            True.
        :param min_group_size: minimum proportion of total group size (all
            data) a population group must meet in order to be included in the
            group metric visualization.
        :return: A Matplotlib figure
        """
        return self._plot_multiple(
            data_table, plot_fcn=self.plot_group_metric, metrics=metrics,
            fillzeros=fillzeros, title=title, ncols=ncols,
            label_dict=label_dict, show_figure=show_figure,
            min_group_size=min_group_size)
    def plot_disparity_all(self, data_table, attributes=None, metrics=None,
                           fillzeros=True, title=True, label_dict=None,
                           ncols=3, show_figure=True, min_group_size=None,
                           significance_alpha=0.05):
        """
        Plot multiple disparity treemaps at once from a bias or fairness
        table.

        :param data_table: output of bias.get_disparity or
            fairness.get_fairness functions.
        :param attributes: which attribute(s) to plot metrics for. If this
            value is null, will plot metrics against all attributes.
        :param metrics: which metric(s) to plot, or 'all'. If this value is
            null, will plot:
                - Predicted Prevalence Disparity (pprev_disparity),
                - Predicted Positive Rate Disparity (ppr_disparity),
                - False Discovery Rate Disparity (fdr_disparity),
                - False Omission Rate Disparity (for_disparity),
                - False Positive Rate Disparity (fpr_disparity),
                - False Negative Rate Disparity (fnr_disparity)
        :param fillzeros: whether to fill null values with zeros. Default is
            True.
        :param title: whether to display a title on each plot. Default is
            True.
        :param label_dict: optional dictionary of label replacements. Default
            is None.
        :param show_figure: whether to show figure (plt.show()). Default is
            True.
        :param min_group_size: minimum proportion of total group size (all
            data) a population group must meet in order to be included in the
            metric visualization.
        :param significance_alpha: statistical significance level. Used to
            determine visual representation of significance (number of
            asterisks on treemap).
        :return: A Matplotlib figure
        """
        return self._plot_multiple_treemaps(
            data_table, plot_fcn=self.plot_disparity, attributes=attributes,
            metrics=metrics, fillzeros=fillzeros, label_dict=label_dict,
            highlight_fairness=False, show_figure=show_figure, title=title,
            min_group_size=min_group_size,
            significance_alpha=significance_alpha)
    def plot_fairness_group_all(self, fairness_table, metrics=None,
                                fillzeros=True, ncols=3, title=True,
                                label_dict=None, show_figure=True,
                                min_group_size=None):
        """
        Plot multiple metrics at once from a fairness object table.

        :param fairness_table: output of fairness.get_fairness functions.
        :param metrics: which metric(s) to plot, or 'all'. If this value is
            null, will plot:
                - Predicted Prevalence (pprev),
                - Predicted Positive Rate (ppr),
                - False Discovery Rate (fdr),
                - False Omission Rate (for),
                - False Positive Rate (fpr),
                - False Negative Rate (fnr)
        :param fillzeros: whether to fill null values with zeros. Default is
            True.
        :param ncols: number of subplots per row in figure. Default is 3.
        :param title: whether to display a title on each plot. Default is
            True.
        :param label_dict: optional dictionary of label replacements. Default
            is None.
        :param show_figure: whether to show figure (plt.show()). Default is
            True.
        :param min_group_size: minimum proportion of total group size (all
            data) a population group must meet in order to be included in the
            fairness visualization.
        :return: A Matplotlib figure
        """
        return self._plot_multiple(
            fairness_table, plot_fcn=self.plot_fairness_group,
            metrics=metrics, fillzeros=fillzeros, title=title, ncols=ncols,
            label_dict=label_dict, show_figure=show_figure,
            min_group_size=min_group_size)
    def plot_fairness_disparity_all(self, fairness_table, attributes=None,
                                    metrics=None, fillzeros=True, title=True,
                                    label_dict=None, show_figure=True,
                                    min_group_size=None,
                                    significance_alpha=0.05):
        """
        Plot multiple metrics at once from a fairness object table.

        :param fairness_table: output of fairness.get_fairness functions.
        :param attributes: which attribute(s) to plot metrics for. If this
            value is null, will plot metrics against all attributes.
        :param metrics: which metric(s) to plot, or 'all'. If this value is
            null, will plot:
                - Predicted Prevalence Disparity (pprev_disparity),
                - Predicted Positive Rate Disparity (ppr_disparity),
                - False Discovery Rate Disparity (fdr_disparity),
                - False Omission Rate Disparity (for_disparity),
                - False Positive Rate Disparity (fpr_disparity),
                - False Negative Rate Disparity (fnr_disparity)
        :param fillzeros: whether to fill null values with zeros. Default is
            True.
        :param title: whether to display a title on each plot. Default is
            True.
        :param label_dict: optional dictionary of label replacements. Default
            is None.
        :param show_figure: whether to show figure (plt.show()). Default is
            True.
        :param min_group_size: minimum proportion of total group size (all
            data) a population group must meet in order to be included in the
            fairness visualization.
        :param significance_alpha: statistical significance level. Used to
            determine visual representation of significance (number of
            asterisks on treemap).
        :return: A Matplotlib figure
        """
        return self._plot_multiple_treemaps(
            fairness_table, plot_fcn=self.plot_disparity,
            attributes=attributes, metrics=metrics, fillzeros=fillzeros,
            label_dict=label_dict, title=title, highlight_fairness=True,
            show_figure=show_figure, min_group_size=min_group_size,
            significance_alpha=significance_alpha)
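# A minimal usage sketch for the convenience wrappers, reusing the
# illustrative `xtab`, `bdf`, and `fdf` frames from the sketches above.
#
#     aqp = Plot()
#     metrics_fig = aqp.plot_group_metric_all(xtab, ncols=3)
#     disparity_fig = aqp.plot_disparity_all(bdf, metrics='all')
#     fairness_fig = aqp.plot_fairness_group_all(fdf)
#     fairness_maps = aqp.plot_fairness_disparity_all(fdf,
#                                                     attributes=['race'])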