From dddac8d30d6f26b43ce6a1340a7643ef00ea60a5 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@stud.unibas.ch>
Date: Mon, 22 Apr 2013 13:27:09 +0200
Subject: [PATCH] reimplementation of PlotBar function

The previous function was hard to understand. Every specified column gets
plotted at one position in the plot. If there are several rows, several
row indices are passed as argument respectively, the values of the
resulting column get plotted in parallel at this position. Note, that
the number off rows is limited to 7.
---
 modules/base/pymod/table.py | 183 ++++++++++++++++++++----------------
 1 file changed, 104 insertions(+), 79 deletions(-)

diff --git a/modules/base/pymod/table.py b/modules/base/pymod/table.py
index 82f65d43d..39a8f40df 100644
--- a/modules/base/pymod/table.py
+++ b/modules/base/pymod/table.py
@@ -1470,51 +1470,68 @@ Statistics for column %(col)s
         max_idx = i
     return max_val, max_idx
 
-  def PlotBar(self, cols, x_labels=None, x_labels_rotation='horizontal', y_title=None, title=None, 
-              colors=None, yerr_cols=None, width=0.8, bottom=0, 
-              legend=True, save=False):
+  def PlotBar(self, cols=None, rows=None, xlabels=None, set_xlabels=True, xlabels_rotation='horizontal', y_title=None, title=None, 
+              colors=None, width=0.8, bottom=0, legend=False, legend_names=None, show=False, save=False):
 
     """
-    Create a barplot of the data in cols. Every element of a column will be represented
-    as a single bar. If there are several columns, each row will be grouped together.
+    Create a barplot of the data in cols. Every column will be represented
+    at one position. If there are several rows, each column will be grouped 
+    together.
 
-    :param cols: Column names with data. If cols is a string, every element of that column
-                 will be represented as a single bar. If cols is a list, every row resulting
-                 of these columns will be grouped together. Every value of the table still
-                 is represented by a single bar.
+    :param cols: List of column names. Every column will be represented as a 
+                 single bar. If cols is None, every column of the table gets 
+                 plotted.
+    :type cols: :class:`list`
 
-    :param x_labels: Label for every row on x-axis.
-    :type x_labels: :class:`list`
-    
-    :param x_labels_rotation: Can either be 'horizontal', 'vertical' or a number that 
-                              describes the rotation in degrees.
+    :param rows: List of row indices. Values from given rows will be plotted 
+                 in parallel at one column position. If set to None, all rows 
+                 of the table will be plotted. Note, that the maximum number 
+                 of rows is 7.
+    :type rows: :class:`list`
+
+    :param xlabels: Label for every col on x-axis. If set to None, the column 
+                    names are used. The xlabel plotting can be supressed by 
+                    the parameter set_xlabel.
+    :type xlabels: :class:`list`
+
+    :param set_xlabels: Controls whether xlabels are plotted or not.
+    :type set_xlabels: :class:`bool`
+
+    :param x_labels_rotation: Can either be 'horizontal', 'vertical' or an 
+                              integer, that describes the rotation in degrees.
 
     :param y_title: Y-axis description
     :type y_title: :class:`str`
 
-    :title: Title
+    :title: Title of the plot. No title appears if set to None
     :type title: :class:`str`
 
-    :param colors: Colors of the different bars in each group. Must be a list of valid
-                   colornames in matplotlib. Length of color and cols must be consistent.
+    :param colors: Colors of the different bars in each group. Must be a list 
+                   of valid colors in matplotlib. Length of color and rows must 
+                   be consistent.
     :type colors: :class:`list`
 
-    :param yerr_cols: Columns containing the y-error information. Can either be a string
-                      if only one column is plotted or a list otherwise. Length of
-                      yerr_cols and cols must be consistent.
-
-    :param width: The available space for the groups on the x-axis is divided by the exact
-                  number of groups. The parameters width is the fraction of what is actually
-                  used. If it would be 1.0 the bars of the different groups would touch each other.
+    :param width: The available space for the groups on the x-axis is divided 
+                  by the exact number of groups. The parameters width is the 
+                  fraction of what is actually used. If it would be 1.0 the 
+                  bars of the different groups would touch each other.
+                  Value must be in [0;1]
     :type width: :class:`float`
 
     :param bottom: Bottom
     :type bottom: :class:`float`
 
-    :param legend: Legend for color explanation, the corresponding column respectively.
+    :param legend: Legend for color explanation, the corresponding row 
+                   respectively. If set to True, legend_names must be provided.
     :type legend: :class:`bool`
 
-    :param save: If set, a png image with name $save in the current working directory will be saved.
+    :param legend_names: List of names, that describe the differently colored 
+                         bars. Length must be consistent with number of rows.
+
+    :param show: If set to True, the plot is directly displayed.
+
+    :param save: If set, a png image with name save in the current working 
+                 directory will be saved.
     :type save: :class:`str`
 
     """
@@ -1522,47 +1539,50 @@ Statistics for column %(col)s
       import numpy as np
       import matplotlib.pyplot as plt
     except:
-      raise ImportError('PlotBar relies on numpy and matplotlib, but I could not import it!')
-    
-    if len(cols)>7:
-      raise ValueError('More than seven bars at one position looks rather meaningless...')
+      raise ImportError('PlotBar relies on numpy and matplotlib, but I could' \
+                        'not import it!')
       
     standard_colors=['b','g','y','c','m','r','k']
     data=[]
-    yerr_data=[]
 
-    if not isinstance(cols, list):
-      cols=[cols]
-      
-    if yerr_cols:
-      if not isinstance(yerr_cols, list):
-        yerr_cols=[yerr_cols]
-      if len(yerr_cols)!=len(cols):
-        raise RuntimeError ('Number of cols and number of error columns must be consistent!')
-      
-    for c in cols:
-      cid=self.GetColIndex(c)
+    if cols==None:
+      cols=self.col_names
+
+    if width<=0 or width>1:
+      raise ValueError('Width must be in [0;1]')
+
+    if rows==None:
+      if len(self.rows)>7:
+        raise ValueError('Table contains too many rows to represent them at one '\
+                         'bar position in parallel. You can Select a Subtable or '\
+                         'specify the parameter rows with a list of row indices '\
+                         '(max 7)')
+      else:
+        rows=range(len(self.rows))
+    else:
+      if not isinstance(rows,list):
+        rows=[rows]
+      if len(rows)>7:
+        raise ValueError('Too many rows to represent (max 7). Please note, that '\
+                         'data from multiple rows from one column gets '\
+                         'represented at one position in parallel.')
+
+    for r_idx in rows:
+      row=self.rows[r_idx] 
       temp=list()
-      for r in self.rows:
-        temp.append(r[cid])
+      for c in cols:
+        try:
+          c_idx=self.GetColIndex(c)
+        except:
+          raise ValueError('Cannot find column with name '+str(c))
+        temp.append(row[c_idx])
       data.append(temp)  
-      
-    if yerr_cols:
-      for c in yerr_cols:
-        cid=self.GetColIndex(c)
-        temp=list()
-        for r in self.rows:
-          temp.append(r[cid])
-        yerr_data.append(temp)
-    else:
-      for i in range(len(cols)):
-        yerr_data.append(None)
 
-    if not colors:
-      colors=standard_colors[:len(cols)]
+    if colors==None:
+      colors=standard_colors[:len(rows)]
 
-    if len(cols)!=len(colors):
-      raise RuntimeError("Number of columns and number of colors must be consistent!")
+    if len(rows)!=len(colors):
+      raise ValueError("Number of rows and number of colors must be consistent!")
 
     ind=np.arange(len(data[0]))
     single_bar_width=float(width)/len(data)
@@ -1570,40 +1590,45 @@ Statistics for column %(col)s
     fig=plt.figure()
     ax=fig.add_subplot(111)
     legend_data=[]
+
     for i in range(len(data)):
-      legend_data.append(ax.bar(ind+i*single_bar_width,data[i],single_bar_width,bottom=bottom,color=colors[i],yerr=yerr_data[i], ecolor='black')[0])
+      legend_data.append(ax.bar(ind+i*single_bar_width+(1-width)/2,data[i],single_bar_width,bottom=bottom,color=colors[i])[0])
       
     if title!=None:
-      nice_title=title
-    else:
-      nice_title="coolest barplot on earth"
-    ax.set_title(nice_title, size='x-large', fontweight='bold')  
+      ax.set_title(title, size='x-large', fontweight='bold')  
     
     if y_title!=None:
       nice_y=y_title
     else:
-      nice_y="score" 
+      nice_y="value" 
     ax.set_ylabel(nice_y)
     
-    if x_labels:
-      if len(data[0])!=len(x_labels):
-        raise ValueError('Number of xlabels is not consistent with number of rows!')
+    if xlabels:
+      if len(data[0])!=len(xlabels):
+        raise ValueError('Number of xlabels is not consistent with number of cols!')
     else:
-      x_labels=list()
-      for i in range(1,len(data[0])+1):
-        x_labels.append('Row '+str(i))
+      xlabels=cols
       
-    ax.set_xticks(ind+width*0.5)
-    ax.set_xticklabels(x_labels, rotation = x_labels_rotation)
+    if set_xlabels:
+      ax.set_xticks(ind+0.5)
+      ax.set_xticklabels(xlabels, rotation = xlabels_rotation)
+    else:
+      ax.set_xticks([])
       
-    if legend:
-      if legend == True:
-        ax.legend(legend_data, cols)   
-      else:
-        ax.legend(legend_data, legend)
-        #pass
+    if legend == True:
+      if legend_names==None:
+        raise ValueError('You must provide legend names! e.g. names for the rows, '\
+                         'that are printed in parallel.')
+      if len(legend_names)!=len(data):
+        raise ValueError('length of legend_names must be consistent with number '\
+                         'of plotted rows!')
+      ax.legend(legend_data, legend_names)   
+
     if save:
       plt.savefig(save)
+
+    if show:
+      plt.show()
     
     return plt
       
-- 
GitLab