From 6a97ee87e8d81f772d67c06e703fa2397c32cbf2 Mon Sep 17 00:00:00 2001 From: Pierre Sassoulas Date: Sat, 14 Jul 2018 13:29:53 +0200 Subject: [PATCH] Refactor - Using pysankey from the pypi package Since pysankey is now published on PyPI, we no longer need to copy-paste its code, and it will be easier to pick up the latest version automatically. We use the latest version, but the code will have to change when the new 1.0.0 version is released (easier way to import it, figure_name renamed to figureName, and removing the dependency in the requirements file). --- requirements.txt | 1 + survey/exporter/tex/question2tex_sankey.py | 2 +- survey/exporter/tex/sankey.py | 225 --------------------- 3 files changed, 2 insertions(+), 226 deletions(-) delete mode 100644 survey/exporter/tex/sankey.py diff --git a/requirements.txt b/requirements.txt index e29c6238..90ebf832 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ django-registration==2.2 # account logic, views and workflows pytz==2017.2 ordereddict==1.1 PyYAML==3.12 +pysankey==0.0.1 matplotlib==2.1.0rc1 seaborn==0.8.1 numpy==1.13.3 diff --git a/survey/exporter/tex/question2tex_sankey.py b/survey/exporter/tex/question2tex_sankey.py index 2d5713b2..747cc032 100755 --- a/survey/exporter/tex/question2tex_sankey.py +++ b/survey/exporter/tex/question2tex_sankey.py @@ -6,7 +6,7 @@ from pandas.core.frame import DataFrame from survey.exporter.tex.question2tex import Question2Tex -from survey.exporter.tex.sankey import sankey +from pySankey.sankey import sankey from survey.models.question import Question LOGGER = logging.getLogger(__name__) diff --git a/survey/exporter/tex/sankey.py b/survey/exporter/tex/sankey.py deleted file mode 100644 index bef34865..00000000 --- a/survey/exporter/tex/sankey.py +++ /dev/null @@ -1,225 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Produces simple Sankey Diagrams with matplotlib. -@author: Anneya Golob & marcomanz & pierre-sassoulas - .-. - .--.( ).--. - <-. .-.-.(.-> )_ .--. 
- `-`( )-' `) ) - (o o ) `)`-' - ( ) ,) - ( () ) ) - `---"\ , , ,/` - `--' `--' `--' - | | | | - | | | | - ' | ' | - isort:skip_file -""" - -from collections import defaultdict - -import matplotlib - -# matplotlib.use('Agg') must be before the other imports, do not change order -matplotlib.use("Agg") - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import seaborn as sns - - -def sankey( - left, - right, - leftWeight=None, - rightWeight=None, - colorDict=None, - leftLabels=None, - rightLabels=None, - aspect=4, - rightColor=False, - fontsize=14, - figure_name="pysankey", -): - """ - Make Sankey Diagram showing flow from left-->right - - Inputs: - left = NumPy array of object labels on the left of the diagram - right = NumPy array of corresponding labels on the right of the diagram - len(right) == len(left) - leftWeight = NumPy array of weights for each strip starting from the - left of the diagram, if not specified 1 is assigned - rightWeight = NumPy array of weights for each strip starting from the - right of the diagram, if not specified the corresponding leftWeight - is assigned - colorDict = Dictionary of colors to use for each label - {'label':'color'} - leftLabels = order of the left labels in the diagram - rightLabels = order of the right labels in the diagram - aspect = vertical extent of the diagram in units of horizontal extent - rightColor = If true, each strip in the diagram will be be colored - according to its left label - Ouput: - None - """ - if leftWeight is None: - leftWeight = [] - if rightWeight is None: - rightWeight = [] - if leftLabels is None: - leftLabels = [] - if rightLabels is None: - rightLabels = [] - # Check weights - if len(leftWeight) == 0: - leftWeight = np.ones(len(left)) - - if len(rightWeight) == 0: - rightWeight = leftWeight - - plt.figure() - plt.rc("text", usetex=False) - plt.rc("font", family="serif") - - # Create Dataframe - df = pd.DataFrame( - { - "left": left, - "right": right, - "leftWeight": 
leftWeight, - "rightWeight": rightWeight, - }, - index=list(range(len(left))), - ) - - # Identify all labels that appear 'left' or 'right' - allLabels = pd.Series(np.r_[df.left.unique(), df.right.unique()]).unique() - - # Identify left labels - if len(leftLabels) == 0: - leftLabels = pd.Series(df.left.unique()).unique() - - # Identify right labels - if len(rightLabels) == 0: - rightLabels = pd.Series(df.right.unique()).unique() - - # If no colorDict given, make one - if colorDict is None: - colorDict = {} - pal = "hls" - cls = sns.color_palette(pal, len(allLabels)) - for i, l in enumerate(allLabels): - colorDict[l] = cls[i] - - # Determine widths of individual strips - ns_l = defaultdict() - ns_r = defaultdict() - for l in leftLabels: - myD_l = {} - myD_r = {} - for l2 in rightLabels: - myD_l[l2] = df[(df.left == l) & (df.right == l2)].leftWeight.sum() - myD_r[l2] = df[(df.left == l) & (df.right == l2)].rightWeight.sum() - ns_l[l] = myD_l - ns_r[l] = myD_r - - # Determine positions of left label patches and total widths - widths_left = defaultdict() - for i, l in enumerate(leftLabels): - myD = {} - myD["left"] = df[df.left == l].leftWeight.sum() - if i == 0: - myD["bottom"] = 0 - myD["top"] = myD["left"] - else: - myD["bottom"] = widths_left[leftLabels[i - 1]]["top"] + 0.02 * len(df) - myD["top"] = myD["bottom"] + myD["left"] - topEdge = myD["top"] - widths_left[l] = myD - - # Determine positions of right label patches and total widths - widths_right = defaultdict() - for i, l in enumerate(rightLabels): - myD = {} - myD["right"] = df[df.right == l].rightWeight.sum() - if i == 0: - myD["bottom"] = 0 - myD["top"] = myD["right"] - else: - myD["bottom"] = widths_right[rightLabels[i - 1]]["top"] + 0.02 * len(df) - myD["top"] = myD["bottom"] + myD["right"] - topEdge = myD["top"] - widths_right[l] = myD - - # Total vertical extent of diagram - xMax = topEdge / aspect - - # Draw vertical bars on left and right of each label's section & print label - for l in leftLabels: - 
plt.fill_between( - [-0.02 * xMax, 0], - 2 * [widths_left[l]["bottom"]], - 2 * [widths_left[l]["bottom"] + widths_left[l]["left"]], - color=colorDict[l], - alpha=0.99, - ) - plt.text( - -0.05 * xMax, - widths_left[l]["bottom"] + 0.5 * widths_left[l]["left"], - l, - {"ha": "right", "va": "center"}, - fontsize=fontsize, - ) - for l in rightLabels: - plt.fill_between( - [xMax, 1.02 * xMax], - 2 * [widths_right[l]["bottom"]], - 2 * [widths_right[l]["bottom"] + widths_right[l]["right"]], - color=colorDict[l], - alpha=0.99, - ) - plt.text( - 1.05 * xMax, - widths_right[l]["bottom"] + 0.5 * widths_right[l]["right"], - l, - {"ha": "left", "va": "center"}, - fontsize=fontsize, - ) - - # Plot strips - for l in leftLabels: - for l2 in rightLabels: - lc = l - if rightColor: - lc = l2 - if len(df[(df.left == l) & (df.right == l2)]) > 0: - # Create array of y values for each strip, half at left value, half at right, convolve - ys_d = np.array( - 50 * [widths_left[l]["bottom"]] + 50 * [widths_right[l2]["bottom"]] - ) - ys_d = np.convolve(ys_d, 0.05 * np.ones(20), mode="valid") - ys_d = np.convolve(ys_d, 0.05 * np.ones(20), mode="valid") - ys_u = np.array( - 50 * [widths_left[l]["bottom"] + ns_l[l][l2]] - + 50 * [widths_right[l2]["bottom"] + ns_r[l][l2]] - ) - ys_u = np.convolve(ys_u, 0.05 * np.ones(20), mode="valid") - ys_u = np.convolve(ys_u, 0.05 * np.ones(20), mode="valid") - - # Update bottom edges at each label so next strip starts at the right place - widths_left[l]["bottom"] += ns_l[l][l2] - widths_right[l2]["bottom"] += ns_r[l][l2] - plt.fill_between( - np.linspace(0, xMax, len(ys_d)), - ys_d, - ys_u, - alpha=0.65, - color=colorDict[lc], - ) - plt.gca().axis("off") - plt.gcf().set_size_inches(6, 6) - plt.savefig("{}.png".format(figure_name), bbox_inches="tight", dpi=150) - plt.close()