|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import numpy as np |
| 4 | + |
| 5 | +from pandas._typing import ( |
| 6 | + Dict, |
| 7 | + IndexLabel, |
| 8 | +) |
| 9 | + |
| 10 | +from pandas.core.dtypes.missing import remove_na_arraylike |
| 11 | + |
| 12 | +from pandas import ( |
| 13 | + DataFrame, |
| 14 | + MultiIndex, |
| 15 | + Series, |
| 16 | + concat, |
| 17 | +) |
| 18 | + |
| 19 | + |
| 20 | +def create_iter_data_given_by( |
| 21 | + data: DataFrame, kind: str = "hist" |
| 22 | +) -> Dict[str, DataFrame | Series]: |
| 23 | + """ |
| 24 | + Create data for iteration given `by` is assigned or not, and it is only |
| 25 | + used in both hist and boxplot. |
| 26 | +
|
| 27 | + If `by` is assigned, return a dictionary of DataFrames in which the key of |
| 28 | + dictionary is the values in groups. |
| 29 | + If `by` is not assigned, return input as is, and this preserves current |
| 30 | + status of iter_data. |
| 31 | +
|
| 32 | + Parameters |
| 33 | + ---------- |
| 34 | + data : reformatted grouped data from `_compute_plot_data` method. |
| 35 | + kind : str, plot kind. This function is only used for `hist` and `box` plots. |
| 36 | +
|
| 37 | + Returns |
| 38 | + ------- |
| 39 | + iter_data : DataFrame or Dictionary of DataFrames |
| 40 | +
|
| 41 | + Examples |
| 42 | + -------- |
| 43 | + If `by` is assigned: |
| 44 | +
|
| 45 | + >>> import numpy as np |
| 46 | + >>> tuples = [('h1', 'a'), ('h1', 'b'), ('h2', 'a'), ('h2', 'b')] |
| 47 | + >>> mi = MultiIndex.from_tuples(tuples) |
| 48 | + >>> value = [[1, 3, np.nan, np.nan], |
| 49 | + ... [3, 4, np.nan, np.nan], [np.nan, np.nan, 5, 6]] |
| 50 | + >>> data = DataFrame(value, columns=mi) |
| 51 | + >>> create_iter_data_given_by(data) |
| 52 | + {'h1': DataFrame({'a': [1, 3, np.nan], 'b': [3, 4, np.nan]}), |
| 53 | + 'h2': DataFrame({'a': [np.nan, np.nan, 5], 'b': [np.nan, np.nan, 6]})} |
| 54 | + """ |
| 55 | + |
| 56 | + # For `hist` plot, before transformation, the values in level 0 are values |
| 57 | + # in groups and subplot titles, and later used for column subselection and |
| 58 | + # iteration; For `box` plot, values in level 1 are column names to show, |
| 59 | + # and are used for iteration and as subplots titles. |
| 60 | + if kind == "hist": |
| 61 | + level = 0 |
| 62 | + else: |
| 63 | + level = 1 |
| 64 | + |
| 65 | + # Select sub-columns based on the value of level of MI, and if `by` is |
| 66 | + # assigned, data must be a MI DataFrame |
| 67 | + assert isinstance(data.columns, MultiIndex) |
| 68 | + return { |
| 69 | + col: data.loc[:, data.columns.get_level_values(level) == col] |
| 70 | + for col in data.columns.levels[level] |
| 71 | + } |
| 72 | + |
| 73 | + |
| 74 | +def reconstruct_data_with_by( |
| 75 | + data: DataFrame, by: IndexLabel, cols: IndexLabel |
| 76 | +) -> DataFrame: |
| 77 | + """ |
| 78 | + Internal function to group data, and reassign multiindex column names onto the |
| 79 | + result in order to let grouped data be used in _compute_plot_data method. |
| 80 | +
|
| 81 | + Parameters |
| 82 | + ---------- |
| 83 | + data : Original DataFrame to plot |
| 84 | + by : grouped `by` parameter selected by users |
| 85 | + cols : columns of data set (excluding columns used in `by`) |
| 86 | +
|
| 87 | + Returns |
| 88 | + ------- |
| 89 | + Output is the reconstructed DataFrame with MultiIndex columns. The first level |
| 90 | + of MI is unique values of groups, and second level of MI is the columns |
| 91 | + selected by users. |
| 92 | +
|
| 93 | + Examples |
| 94 | + -------- |
| 95 | + >>> d = {'h': ['h1', 'h1', 'h2'], 'a': [1, 3, 5], 'b': [3, 4, 6]} |
| 96 | + >>> df = DataFrame(d) |
| 97 | + >>> reconstruct_data_with_by(df, by='h', cols=['a', 'b']) |
| 98 | + h1 h2 |
| 99 | + a b a b |
| 100 | + 0 1 3 NaN NaN |
| 101 | + 1 3 4 NaN NaN |
| 102 | + 2 NaN NaN 5 6 |
| 103 | + """ |
| 104 | + grouped = data.groupby(by) |
| 105 | + |
| 106 | + data_list = [] |
| 107 | + for key, group in grouped: |
| 108 | + columns = MultiIndex.from_product([[key], cols]) |
| 109 | + sub_group = group[cols] |
| 110 | + sub_group.columns = columns |
| 111 | + data_list.append(sub_group) |
| 112 | + |
| 113 | + data = concat(data_list, axis=1) |
| 114 | + return data |
| 115 | + |
| 116 | + |
| 117 | +def reformat_hist_y_given_by( |
| 118 | + y: Series | np.ndarray, by: IndexLabel | None |
| 119 | +) -> Series | np.ndarray: |
| 120 | + """Internal function to reformat y given `by` is applied or not for hist plot. |
| 121 | +
|
| 122 | + If by is None, input y is 1-d with NaN removed; and if by is not None, groupby |
| 123 | + will take place and input y is multi-dimensional array. |
| 124 | + """ |
| 125 | + if by is not None and len(y.shape) > 1: |
| 126 | + return np.array([remove_na_arraylike(col) for col in y.T]).T |
| 127 | + return remove_na_arraylike(y) |
0 commit comments