diff --git a/.gitignore b/.gitignore index 5e5a8ec8915e2268d121d721c1218afe4771074e..a97cae9c231396cee48897d3a2b6f3522cd8ba1d 100644 --- a/.gitignore +++ b/.gitignore @@ -48,4 +48,5 @@ RelWithDebInfo Debug *.cxx_parameters /deployment/win/create_archive.bat -/install_manifest.txt \ No newline at end of file +/install_manifest.txt +*_out.csv \ No newline at end of file diff --git a/modules/base/pymod/CMakeLists.txt b/modules/base/pymod/CMakeLists.txt index 22700288ae061f64b45fb741e13b5044b2bf8977..3f4f748a509f9e43796cb4bde0edf2c5cc162fc8 100644 --- a/modules/base/pymod/CMakeLists.txt +++ b/modules/base/pymod/CMakeLists.txt @@ -7,4 +7,4 @@ set(OST_BASE_PYMOD_SOURCES pymod(NAME base OUTPUT_DIR ost CPP ${OST_BASE_PYMOD_SOURCES} - PY __init__.py settings.py stutil.py) + PY __init__.py settings.py stutil.py table.py) diff --git a/modules/base/pymod/table.py b/modules/base/pymod/table.py new file mode 100644 index 0000000000000000000000000000000000000000..eb90be77cf7c19a917d2012e91e66b3a4d4f2b8f --- /dev/null +++ b/modules/base/pymod/table.py @@ -0,0 +1,1198 @@ +import csv +import re +from ost import stutil +import itertools +import operator +from ost import LogError, LogWarning, LogInfo, LogVerbose + +def MakeTitle(col_name): + return col_name.replace('_', ' ') + +def IsStringLike(value): + if isinstance(value, TableCol) or isinstance(value, BinaryColExpr): + return False + try: + value+'' + return True + except: + return False + +def IsScalar(value): + if IsStringLike(value): + return True + try: + if isinstance(value, TableCol) or isinstance(value, BinaryColExpr): + return False + iter(value) + return False + except: + return True + +class BinaryColExpr: + def __init__(self, op, lhs, rhs): + self.op=op + self.lhs=lhs + self.rhs=rhs + if IsScalar(lhs): + self.lhs=itertools.cyle([self.lhs]) + if IsScalar(rhs): + self.rhs=itertools.cycle([self.rhs]) + def __iter__(self): + for l, r in zip(self.lhs, self.rhs): + if l!=None and r!=None: + yield self.op(l, r) + else: + yield None + def __add__(self, rhs): + return BinaryColExpr(operator.add, self, rhs) + + def __sub__(self, rhs): + return BinaryColExpr(operator.sub, self, rhs) + + def __mul__(self, rhs): + return BinaryColExpr(operator.mul, self, rhs) + +class TableCol: + def __init__(self, table, col): + self._table=table + if type(col)==str: + self.col_index=self._table.GetColIndex(col) + else: + self.col_index=col + + def __iter__(self): + for row in self._table.rows: + yield row[self.col_index] + + def __len__(self): + return len(self._table.rows) + + def __getitem__(self, index): + return self._table.rows[index][self.col_index] + + def __setitem__(self, index, value): + self._table.rows[index][self.col_index]=value + + def __add__(self, rhs): + return BinaryColExpr(operator.add, self, rhs) + + def __sub__(self, rhs): + return BinaryColExpr(operator.sub, self, rhs) + + def __mul__(self, rhs): + return BinaryColExpr(operator.mul, self, rhs) + def __div__(self, rhs): + return BinaryColExpr(operator.div, self, rhs) + + +class Table: + """ + + The table class provides convenient access to data in tabular form. An empty + table can be easily constructed as follows + + .. code-block:: python + + tab=Table() + + If you want to add columns directly when creating the table, column names + and column types can be specified as follows + + .. code-block:: python + + tab=Table(['nameX','nameY','nameZ'], 'sfb') + + this will create three columns called nameX, nameY and nameZ of type string, + float and bool, respectively. There will be no data in the table and thus, + the table will not contain any rows. + + If you want to add data to the table in addition, use the following: + + .. code-block:: python + + tab=Table(['nameX','nameY','nameZ'], + 'sfb', + nameX=['a','b','c'], + nameY=[0.1, 1.2, 3.414], + nameZ=[True, False, False]) + + if values for one column is left out, they will be filled with NA, but if + values are specified, all values must be specified (i.e. same number of + values per column) + + """ + + SUPPORTED_TYPES=('int', 'float', 'bool', 'string',) + + + def __init__(self, col_names=None, col_types=None, **kwargs): + self.col_names=col_names + self.comment='' + self.name='' + + self.col_types = self._ParseColTypes(col_types) + self.rows=[] + if len(kwargs)>=0: + if not col_names: + self.col_names=[v for v in kwargs.keys()] + if not self.col_types: + self.col_types=['string' for u in range(len(self.col_names))] + if len(kwargs)>0: + self._AddRowsFromDict(kwargs) + + @staticmethod + def _ParseColTypes(types, exp_num=None): + if types==None: + return None + + short2long = {'s' : 'string', 'i': 'int', 'b' : 'bool', 'f' : 'float'} + allowed_short = short2long.keys() + allowed_long = short2long.values() + + type_list = [] + + # string type + if IsScalar(types): + if type(types)==str: + types = types.lower() + + # single value + if types in allowed_long: + type_list.append(types) + elif types in allowed_short: + type_list.append(short2long[types]) + + # comma separated list of long or short types + elif types.find(',')!=-1: + for t in types.split(','): + if t in allowed_long: + type_list.append(t) + elif t in allowed_short: + type_list.append(short2long[t]) + else: + raise ValueError('Unknown type %s in types %s'%(t,types)) + + # string of short types + else: + for t in types: + if t in allowed_short: + type_list.append(short2long[t]) + else: + raise ValueError('Unknown type %s in types %s'%(t,types)) + + # non-string type + else: + raise ValueError('Col type %s must be string or list'%types) + + # list type + else: + for t in types: + # must be string type + if type(t)==str: + t = t.lower() + if t in allowed_long: + type_list.append(t) + elif t in allowed_short: + type_list.append(short2long[t]) + else: + raise ValueError('Unknown type %s in types %s'%(t,types)) + + # non-string type + else: + raise ValueError('Col type %s must be string or list'%types) + + if exp_num: + if len(type_list)!=exp_num: + raise ValueError('Parsed number of col types (%i) differs from ' + \ + 'expected (%i) in types %s'%(len(type_list),exp_num,types)) + + return type_list + + def SetName(self, name): + self.name = name + + def GetName(self): + return self.name + + def _Coerce(self, value, ty): + if value=='NA' or value==None: + return None + if ty=='int': + return int(value) + if ty=='float': + return float(value) + if ty=='string': + return str(value) + if ty=='bool': + if isinstance(value, str) or isinstance(value, unicode): + if value.upper() in ('FALSE', 'NO',): + return False + return True + return bool(value) + raise ValueError('Unknown type %s' % ty) + + def GetColIndex(self, col): + if col not in self.col_names: + raise ValueError('Table has no column named "%s"' % col) + return self.col_names.index(col) + + def HasCol(self, col): + return col in self.col_names + + def __getitem__(self, k): + if type(k)==int: + return TableCol(self, self.col_names[k]) + else: + return TableCol(self, k) + + def __setitem__(self, k, value): + col_index=k + if type(k)!=int: + col_index=self.GetColIndex(k) + if IsScalar(value): + value=itertools.cycle([value]) + for r, v in zip(self.rows, value): + r[col_index]=v + + def ToString(self, float_format='%.3f', int_format='%d', rows=None): + widths=[len(cn) for cn in self.col_names] + sel_rows=self.rows + if rows: + sel_rows=self.rows[rows[0]:rows[1]] + for row in sel_rows: + for i, (ty, col) in enumerate(zip(self.col_types, row)): + if col==None: + widths[i]=max(widths[i], len('NA')) + elif ty=='float': + widths[i]=max(widths[i], len(float_format % col)) + elif ty=='int': + widths[i]=max(widths[i], len(int_format % col)) + else: + widths[i]=max(widths[i], len(str(col))) + s='' + if self.comment: + s+=''.join(['# %s\n' % l for l in self.comment.split('\n')]) + total_width=sum(widths)+2*len(widths) + for width, col_name in zip(widths, self.col_names): + s+=col_name.center(width+2) + s+='\n%s\n' % ('-'*total_width) + for row in sel_rows: + for width, ty, col in zip(widths, self.col_types, row): + cs='' + if col==None: + cs='NA'.center(width+2) + elif ty=='float': + cs=(float_format % col).rjust(width+2) + elif ty=='int': + cs=(int_format % col).rjust(width+2) + else: + cs=' '+str(col).ljust(width+1) + s+=cs + s+='\n' + return s + + def __str__(self): + return self.ToString() + + def _AddRowsFromDict(self, d, merge=False): + # get column indices + idxs = [self.GetColIndex(k) for k in d.keys()] + + # convert scalar values to list + old_len = None + for k,v in d.iteritems(): + if IsScalar(v): + d[k] = [v] + else: + if not old_len: + old_len = len(v) + elif old_len!=len(v): + raise ValueError("Cannot add rows: length of data must be equal " + \ + "for all columns in %s"%str(d)) + + # convert column based dict to row based dict and create row and add data + for i,data in enumerate(zip(*d.values())): + new_row = [None for a in range(len(self.col_names))] + for idx,v in zip(idxs,data): + new_row[idx] = self._Coerce(v, self.col_types[idx]) + + # partially overwrite existing row with new data + if merge: + merge_idx = self.GetColIndex(merge) + added = False + for i,r in enumerate(self.rows): + if r[merge_idx]==new_row[merge_idx]: + for j,e in enumerate(self.rows[i]): + if new_row[j]==None: + new_row[j] = e + self.rows[i] = new_row + added = True + break + + # if not merge or merge did not find appropriate row + if not merge or not added: + self.rows.append(new_row) + + + def AddRow(self, data, merge=None): + """ + Add a row to the table. *row* may either a dictionary in which case the keys + in the dictionary must match the column names. Columns not found in the dict + will be initialized to None. Alternatively, if data is a list-like object, + the row is initialized from the values in data. The number of items in data + must match the number of columns in the table. A :class:`ValuerError` is + raised otherwise. + """ + if type(data)==dict: + self._AddRowsFromDict(data, merge) + else: + if len(data)!=len(self.col_names): + msg='data array must have %d elements, not %d' + raise ValueError(msg % (len(self.col_names), len(self.data))) + new_row = [self._Coerce(v, t) for v, t in zip(data, self.col_types)] + + # fully overwrite existing row with new data + if merge: + merge_idx = self.GetColIndex(merge) + added = False + for i,r in enumerate(self.rows): + if r[merge_idx]==new_row[merge_idx]: + self.rows[i] = new_row + added = True + break + + # if not merge or merge did not find appropriate row + if not merge or not added: + self.rows.append(new_row) + + def RemoveCol(self, col): + """ + Remove column with the given name from the table + """ + idx = self.GetColIndex(col) + del self.col_names[idx] + del self.col_types[idx] + for row in self.rows: + del row[idx] + + def AddCol(self, col_name, col_type, data=None): + """ + Add a column to the right of the table. + + .. code-block:: python + + tab=Table(['x'], 'f', x=range(5)) + tab.AddCol('even', 'bool', itertools.cycle([True, False])) + print tab + + will produce the table + + ==== ==== + x even + ==== ==== + 0 True + 1 False + 2 True + 3 False + 4 True + ==== ==== + + if data is a constant instead of an iterable object, it's value + will be written into each row + """ + col_type = self._ParseColTypes(col_type, exp_num=1)[0] + self.col_names.append(col_name) + self.col_types.append(col_type) + if IsScalar(data): + for row in self.rows: + row.append(data) + else: + for row, d in zip(self.rows, data): + row.append(d) + + + + def Filter(self, *args, **kwargs): + """ + Returns a filtered table only containing rows matching all the predicates + in kwargs and args For example, + + .. code-block:: python + + tab.Filter(town='Basel') + + will return all the rows where the value of the column "town" is equal to + "Basel". Several predicates may be combined, i.e. + + .. code-block:: python + + tab.Filter(town='Basel', male=True) + + will return the rows with "town" equal to "Basel" and "male" equal to true. + args are unary callables returning true if the row should be included in the + result and false if not. + """ + filt_tab=Table(self.col_names, self.col_types) + for row in self.rows: + matches=True + for func in args: + if not func(row): + matches=False + break + for key, val in kwargs.iteritems(): + if row[self.GetColIndex(key)]!=val: + matches=False + break + if matches: + filt_tab.AddRow(row) + return filt_tab + + @staticmethod + def Load(stream): + fieldname_pattern=re.compile(r'(?P<name>[A-Za-z0-9_]+)(\[(?P<type>\w+)\])?') + if not hasattr(stream, 'read'): + stream=open(stream, 'r') + header=False + num_lines=0 + for line in stream: + line=line.strip() + if line.startswith('#'): + continue + if len(line)==0: + continue + num_lines+=1 + if not header: + fieldnames=[] + fieldtypes=[] + for col in line.split(): + match=fieldname_pattern.match(col) + if match: + if match.group('type'): + fieldtypes.append(match.group('type')) + else: + fieldtypes.append('str') + fieldnames.append(match.group('name')) + tab=Table(fieldnames, fieldtypes) + header=True + continue + tab.AddRow(line.split()) + if num_lines==0: + raise IOError("Cannot read table from empty stream") + return tab + + def Sort(self, by, order='+'): + """ + Performs an in-place sort of the table, based on column. + """ + sign=-1 + if order=='-': + sign=1 + key_index=self.GetColIndex(by) + def _key_cmp(lhs, rhs): + return sign*cmp(lhs[key_index], rhs[key_index]) + self.rows=sorted(self.rows, _key_cmp) + + + def Plot(self, x, y=None, z=None, style='.', x_title=None, y_title=None, + z_title=None, x_range=None, y_range=None, z_range=None, + num_z_levels=10, diag_line=False, labels=None, title=None, + clear=True, save=False): + """ + Plot x against y using matplot lib + """ + try: + import matplotlib.pyplot as plt + import matplotlib.mlab as mlab + import numpy as np + idx1 = self.GetColIndex(x) + xs = [] + ys = [] + zs = [] + + if clear: + plt.clf() + + if x_title: + nice_x=x_title + else: + nice_x=MakeTitle(x) + + if y_title: + nice_y=y_title + else: + if y: + nice_y=MakeTitle(y) + else: + nice_y=None + + if z_title: + nice_z = z_title + else: + if z: + nice_z = MakeTitle(z) + else: + nice_z = None + + if y and z: + idx3 = self.GetColIndex(z) + idx2 = self.GetColIndex(y) + for row in self.rows: + if row[idx1]!=None and row[idx2]!=None and row[idx3]!=None: + xs.append(row[idx1]) + ys.append(row[idx2]) + zs.append(row[idx3]) + levels = [] + if z_range: + z_spacing = (z_range[1] - z_range[0]) / num_z_levels + l = z_range[0] + else: + l = self.Min(z) + z_spacing = (self.Max(z) - l) / num_z_levels + + for i in range(0,num_z_levels+1): + levels.append(l) + l += z_spacing + + xi = np.linspace(min(xs)-0.1,max(xs)+0.1,len(xs)*10) + yi = np.linspace(min(ys)-0.1,max(ys)+0.1,len(ys)*10) + zi = mlab.griddata(xs, ys, zs, xi, yi, interp='linear') + + plt.contour(xi,yi,zi,levels,linewidths=0.5,colors='k') + plt.contourf(xi,yi,zi,levels,cmap=plt.cm.jet) + plt.colorbar(ticks=levels) + + elif y: + idx2=self.GetColIndex(y) + for row in self.rows: + if row[idx1]!=None and row[idx2]!=None: + xs.append(row[idx1]) + ys.append(row[idx2]) + plt.plot(xs, ys, style) + + else: + label_vals=[] + + if labels: + label_idx=self.GetColIndex(labels) + for row in self.rows: + if row[idx1]!=None: + xs.append(row[idx1]) + if labels: + label_vals.append(row[label_idx]) + plt.plot(xs, style) + if labels: + plt.xticks(np.arange(len(xs)), label_vals, rotation=45, size='x-small') + + if not title: + if nice_z: + title = '%s of %s vs. %s' % (nice_z, nice_x, nice_y) + elif nice_y: + title = '%s vs. %s' % (nice_x, nice_y) + else: + title = nice_x + + plt.title(title, size='x-large', fontweight='bold') + if x and y: + plt.xlabel(nice_x, size='x-large') + if x_range: + plt.xlim(x_range[0], x_range[1]) + if y_range: + plt.ylim(y_range[0], y_range[1]) + if diag_line: + plt.plot(x_range, y_range, '-') + + plt.ylabel(nice_y, size='x-large') + else: + plt.ylabel(nice_x, size='x-large') + if save: + plt.savefig(save) + return plt + except ImportError: + LogError("Function needs numpy and matplotlib, but I could not import it.") + raise + + def PlotHistogram(self, col, x_range=None, num_bins=10, normed=False, + histtype='stepfilled', align='mid', x_title=None, + y_title=None, title=None, clear=True, save=False): + """ + Create a histogram of the data in col for the range x_range, split into + num_bins bins and plot it using matplot lib + """ + try: + import matplotlib.pyplot as plt + import numpy as np + + if len(self.rows)==0: + return None + + idx = self.GetColIndex(col) + data = [] + for r in self.rows: + if r[idx]!=None: + data.append(r[idx]) + + if clear: + plt.clf() + + n, bins, patches = plt.hist(data, bins=num_bins, range=x_range, + normed=normed, histtype=histtype, align=align) + + if x_title: + nice_x=x_title + else: + nice_x=MakeTitle(col) + plt.xlabel(nice_x, size='x-large') + + if y_title: + nice_y=y_title + else: + nice_y="bin count" + plt.ylabel(nice_y, size='x-large') + + if title: + nice_title=title + else: + nice_title="Histogram of %s"%nice_x + plt.title(nice_title, size='x-large', fontweight='bold') + + if save: + plt.savefig(save) + return plt + except ImportError: + LogError("Function needs numpy and matplotlib, but I could not import it.") + raise + + def _Max(self, col): + if len(self.rows)==0: + return None, None + idx = self.GetColIndex(col) + col_type = self.col_types[idx] + if col_type=='int' or col_type=='float': + max_val = -float('inf') + elif col_type=='bool': + max_val = False + elif col_type=='string': + max_val = chr(0) + max_idx = None + for i in range(0, len(self.rows)): + if self.rows[i][idx]>max_val: + max_val = self.rows[i][idx] + max_idx = i + return max_val, max_idx + + def MaxRow(self, col): + """ + Returns the row containing the cell with the maximal value in col. If + several rows have the highest value, only the first one is returned. + None values are ignored. + """ + val, idx = self._Max(col) + return self.rows[idx] + + def Max(self, col): + """ + Returns the maximum value in col. If several rows have the highest value, + only the first one is returned. None values are ignored. + """ + val, idx = self._Max(col) + return val + + def MaxIdx(self, col): + """ + Returns the row index of the cell with the maximal value in col. If + several rows have the highest value, only the first one is returned. + None values are ignored. + """ + val, idx = self._Max(col) + return idx + + def _Min(self, col): + if len(self.rows)==0: + return None, None + idx=self.GetColIndex(col) + col_type = self.col_types[idx] + if col_type=='int' or col_type=='float': + min_val=float('inf') + elif col_type=='bool': + min_val=True + elif col_type=='string': + min_val=chr(255) + min_idx=None + for i,row in enumerate(self.rows): + if row[idx]!=None and row[idx]<min_val: + min_val=row[idx] + min_idx=i + return min_val, min_idx + + def Min(self, col): + """ + Returns the minimal value in col. If several rows have the lowest value, + only the first one is returned. None values are ignored. + """ + val, idx = self._Min(col) + return val + + def MinRow(self, col): + """ + Returns the row containing the cell with the minimal value in col. If + several rows have the lowest value, only the first one is returned. + None values are ignored. + """ + val, idx = self._Min(col) + return self.rows[idx] + + def MinIdx(self, col): + """ + Returns the row index of the cell with the minimal value in col. If + several rows have the lowest value, only the first one is returned. + None values are ignored. + """ + val, idx = self._Min(col) + return idx + + def Sum(self, col): + """ + Returns the sum of the given column. Cells with None are ignored. Returns + 0.0, if the column doesn't contain any elements. + """ + idx = self.GetColIndex(col) + col_type = self.col_types[idx] + if col_type!='int' and col_type!='float': + raise TypeError("Sum can only be used on numeric column types") + s = 0.0 + for r in self.rows: + if r[idx]!=None: + s += r[idx] + return s + + def Mean(self, col): + """ + Returns the mean of the given column. Cells with None are ignored. Returns + None, if the column doesn't contain any elements. + """ + idx = self.GetColIndex(col) + col_type = self.col_types[idx] + if col_type!='int' and col_type!='float': + raise TypeError("Mean can only be used on numeric column types") + + vals=[] + for v in self[col]: + if v!=None: + vals.append(v) + try: + return stutil.Mean(vals) + except: + return None + + def Median(self, col): + """ + Returns the median of the given column. Cells with None are ignored. Returns + None, if the column doesn't contain any elements. + """ + idx = self.GetColIndex(col) + col_type = self.col_types[idx] + if col_type!='int' and col_type!='float': + raise TypeError("Mean can only be used on numeric column types") + + vals=[] + for v in self[col]: + if v!=None: + vals.append(v) + stutil.Median(vals) + try: + return stutil.Median(vals) + except: + return None + + def StdDev(self, col): + """ + Returns the standard deviation of the given column. Cells with None are + ignored. Returns None, if the column doesn't contain any elements. + """ + idx = self.GetColIndex(col) + col_type = self.col_types[idx] + if col_type!='int' and col_type!='float': + raise TypeError("Mean can only be used on numeric column types") + + vals=[] + for v in self[col]: + if v!=None: + vals.append(v) + try: + return stutil.StdDev(vals) + except: + return None + + def Count(self, col): + """ + Count the number of cells in column that are not equal to None. + """ + count=0 + idx=self.GetColIndex(col) + for r in self.rows: + if r[idx]!=None: + count+=1 + return count + + def Correl(self, col1, col2): + """ + Calculate the Pearson correlation coefficient between col1 and col2, only + taking rows into account where both of the values are not equal to None. If + there are not enough data points to calculate a correlation coefficient, + None is returned. + """ + vals1, vals2=([],[]) + for v1, v2 in zip(self[col1], self[col2]): + if v1!=None and v2!=None: + vals1.append(v1) + vals2.append(v2) + try: + return stutil.Correl(vals1, vals2) + except: + return None + + def Save(self, stream): + """ + Save the table to stream or filename + """ + if hasattr(stream, 'write'): + writer=csv.writer(stream, delimiter=' ') + else: + stream=open(stream, 'w') + writer=csv.writer(stream, delimiter=' ') + if self.comment: + stream.write(''.join(['# %s\n' % l for l in self.comment.split('\n')])) + writer.writerow(['%s[%s]' % t for t in zip(self.col_names, self.col_types)]) + for row in self.rows: + row=list(row) + for i, c in enumerate(row): + if c==None: + row[i]='NA' + writer.writerow(row) + + + def GetNumpyMatrix(self, *args): + ''' + Returns a numpy matrix containing the selected columns from the table as + columns in the matrix. + Only columns of type int or float are supported. NA values in the table + will be converted to None values. + ''' + try: + import numpy as np + + if len(args)==0: + raise RuntimeError("At least one column must be specified.") + + idxs = [] + for arg in args: + idx = self.GetColIndex(arg) + col_type = self.col_types[idx] + if col_type!='int' and col_type!='float': + raise TypeError("Numpy matrix can only be generated from numeric column types") + idxs.append(idx) + m = np.matrix([list(self[i]) for i in idxs]) + return m.T + + except ImportError: + LogError("Function needs numpy, but I could not import it.") + raise + + def GetOptimalPrefactors(self, ref_col, *args, **kwargs): + ''' + This returns the optimal prefactor values (i.e. a, b, c, ...) for the + following equation + + .. math:: + :label: op1 + + a*u + b*v + c*w + ... = z + + where u, v, w and z are vectors. In matrix notation + + .. math:: + :label: op2 + + A*p = z + + where A contains the data from the table (u,v,w,...), p are the prefactors + to optimize (a,b,c,...) and z is the vector containing the result of + equation :eq:`op1`. + + The parameter ref_col equals to z in both equations, and \*args are columns + u, v and w (or A in :eq:`op2`). All columns must be specified by their names. + + **Example:** + + .. code-block:: python + + tab.GetOptimalPrefactors('colC', 'colA', 'colB') + + The function returns a list of containing the prefactors a, b, c, ... in + the correct order (i.e. same as columns were specified in \*args). + + Weighting: + If the kwarg weights="columX" is specified, the equations are weighted by + the values in that column. Each row is multiplied by the weight in that row, + which leads to :eq:`op3`: + + .. math:: + :label: op3 + + weight*a*u + weight*b*v + weight*c*w + ... = weight*z + + Weights must be float or int and can have any value. A value of 0 ignores + this equation, a value of 1 means the same as no weight. If all weights are + the same for each row, the same result will be obtained as with no weights. + + **Example:** + + .. code-block:: python + + tab.GetOptimalPrefactors('colC', 'colA', 'colB', weights='colD') + + ''' + try: + import numpy as np + + if len(args)==0: + raise RuntimeError("At least one column must be specified.") + + b = self.GetNumpyMatrix(ref_col) + a = self.GetNumpyMatrix(*args) + + if len(kwargs)!=0: + if kwargs.has_key('weights'): + w = self.GetNumpyMatrix(kwargs['weights']) + b = np.multiply(b,w) + a = np.multiply(a,w) + + else: + raise RuntimeError("specified unrecognized kwargs, use weights as key") + + k = (a.T*a).I*a.T*b + return list(np.array(k.T).reshape(-1)) + + except ImportError: + LogError("Function needs numpy, but I could not import it.") + raise + + def PlotEnrichment(self, score_col, class_col, score_dir='-', + class_dir='-', class_cutoff=2.0, + style='-', title=None, x_title=None, y_title=None, + clear=True, save=None): + ''' + Plot an enrichment curve using matplotlib + ''' + + try: + import matplotlib.pyplot as plt + + enrx, enry = self.ComputeEnrichment(score_col, class_col, score_dir, + class_dir, class_cutoff) + + if not title: + title = 'Enrichment of %s'%score_col + + if not x_title: + x_title = '% database' + + if not y_title: + y_title = '% positives' + + if clear: + plt.clf() + + plt.plot(enrx, enry, style) + + plt.title(title, size='x-large', fontweight='bold') + plt.ylabel(y_title, size='x-large') + plt.xlabel(x_title, size='x-large') + + if save: + plt.savefig(save) + + return plt + except ImportError: + LogError("Function needs matplotlib, but I could not import it.") + raise + + def ComputeEnrichment(self, score_col, class_col, score_dir='-', + class_dir='-', class_cutoff=2.0): + ''' + Computes the enrichment of one column (e.g. score) over all data points. + + For this it is necessary, that the datapoints are classified into positive + and negative points. This can be done in two ways: + + - by using one 'bool' column which contains True for positives and False + for negatives + - by specifying an additional column, a cutoff value and the columns + direction. This will generate the classification on the fly + + * if class_dir=='-': values in the classification column that are + less than or equal to class_cutoff will be counted + as positives + * if class_dir=='+': values in the classification column that are + larger than or equal to class_cutoff will be counted + as positives + + During the calculation, the table will be sorted according to score_dir, + where a '-' values means smallest values first and therefore, the smaller + the value, the better. + + ''' + + ALLOWED_DIR = ['+','-'] + + score_idx = self.GetColIndex(score_col) + score_type = self.col_types[score_idx] + if score_type!='int' and score_type!='float': + raise TypeError("Score column must be numeric type") + + class_idx = self.GetColIndex(class_col) + class_type = self.col_types[class_idx] + if class_type!='int' and class_type!='float' and class_type!='bool': + raise TypeError("Classifier column must be numeric or bool type") + + if (score_dir not in ALLOWED_DIR) or (class_dir not in ALLOWED_DIR): + raise ValueError("Direction must be one of %s"%str(ALLOWED_DIR)) + + self.Sort(score_col, score_dir) + + x = [0] + y = [0] + enr = 0 + for i,row in enumerate(self.rows): + class_val = row[class_idx] + if class_val!=None: + if class_type=='bool': + if class_val==True: + enr += 1 + else: + if (class_dir=='-' and class_val<=class_cutoff) or (class_dir=='+' and class_val>=class_cutoff): + enr += 1 + x.append(i+1) + y.append(enr) + x = [float(v)/x[-1] for v in x] + y = [float(v)/y[-1] for v in y] + return x,y + + def ComputeEnrichmentAUC(self, score_col, class_col, score_dir='-', + class_dir='-', class_cutoff=2.0): + ''' + Computes the area under the curve of the enrichment using the trapezoidal + rule + ''' + try: + import numpy as np + + enrx, enry = self.ComputeEnrichment(score_col, class_col, score_dir, + class_dir, class_cutoff) + + return np.trapz(enry, enrx) + except ImportError: + LogError("Function needs numpy, but I could not import it.") + raise + + + + +def Merge(table1, table2, by, only_matching=False): + """ + Returns a new table containing the data from both tables. The rows are + combined based on the common values in the column by. For example, the two + tables below + + ==== ==== + x y + ==== ==== + 1 10 + 2 15 + 3 20 + ==== ==== + + ==== ==== + x u + ==== ==== + 1 100 + 3 200 + 4 400 + ==== ==== + + ===== ===== ===== + x y u + ===== ===== ===== + 1 10 100 + 2 15 None + 3 20 200 + 4 None 400 + ===== ===== ===== + + when merged by column x, produce the following output: + """ + def _key(row, indices): + return tuple([row[i] for i in indices]) + def _keep(indices, cn, ct, ni): + ncn, nct, nni=([],[],[]) + for i in range(len(cn)): + if i not in indices: + ncn.append(cn[i]) + nct.append(ct[i]) + nni.append(ni[i]) + return ncn, nct, nni + col_names=list(table2.col_names) + col_types=list(table2.col_types) + new_index=[i for i in range(len(col_names))] + if isinstance(by, str): + common2_indices=[col_names.index(by)] + else: + common2_indices=[col_names.index(b) for b in by] + col_names, col_types, new_index=_keep(common2_indices, col_names, + col_types, new_index) + + for i, name in enumerate(col_names): + try_name=name + counter=1 + while try_name in table1.col_names: + counter+=1 + try_name='%s_%d' % (name, counter) + col_names[i]=try_name + common1={} + if isinstance(by, str): + common1_indices=[table1.col_names.index(by)] + else: + common1_indices=[table1.col_names.index(b) for b in by] + for row in table1.rows: + key=_key(row, common1_indices) + if key in common1: + raise ValueError('duplicate key "%s in first table"' % (str(key))) + common1[key]=row + common2={} + for row in table2.rows: + key=_key(row, common2_indices) + if key in common2: + raise ValueError('duplicate key "%s" in second table' % (str(key))) + common2[key]=row + new_tab=Table(table1.col_names+col_names, table1.col_types+col_types) + for k, v in common1.iteritems(): + row=v+[None for i in range(len(table2.col_names)-len(common2_indices))] + matched=False + if k in common2: + matched=True + row2=common2[k] + for i, index in enumerate(new_index): + row[len(table1.col_names)+i]=row2[index] + if only_matching and not matched: + continue + new_tab.AddRow(row) + if only_matching: + return new_tab + for k, v in common2.iteritems(): + if not k in common1: + v2=[v[i] for i in new_index] + row=[None for i in range(len(table1.col_names))]+v2 + for common1_index, common2_index in zip(common1_indices, common2_indices): + row[common1_index]=v[common2_index] + new_tab.AddRow(row) + return new_tab \ No newline at end of file diff --git a/modules/base/tests/CMakeLists.txt b/modules/base/tests/CMakeLists.txt index d18e16f0765d8c2cd562338655b90f6ed9e97750..578618b041ae61849a619804ac730afa17f3ceab 100644 --- a/modules/base/tests/CMakeLists.txt +++ b/modules/base/tests/CMakeLists.txt @@ -3,6 +3,7 @@ set(OST_BASE_UNIT_TESTS test_string_ref.cc test_pod_vector.cc test_stutil.py + test_table.py tests.cc ) diff --git a/modules/base/tests/test_table.py b/modules/base/tests/test_table.py new file mode 100644 index 0000000000000000000000000000000000000000..7f43726f71d7e9f18aea4ec764674473b6c7ee4b --- /dev/null +++ b/modules/base/tests/test_table.py @@ -0,0 +1,982 @@ +''' +Unit tests for Table class + +Author: Tobias Schmidt +''' + +import os +import unittest +from ost.table import * +import ost + +HAS_NUMPY=True +HAS_MPL=True +try: + import numpy as np +except ImportError: + HAS_NUMPY=False + +try: + import matplotlib + matplotlib.use('Agg') +except ImportError: + HAS_MPL=False + +class TestTable(unittest.TestCase): + + def setUp(self): + ost.PushVerbosityLevel(3) + + def CreateTestTable(self): + ''' + creates a table with some test data + + first second third + ---------------------- + x 3 NA + foo NA 2.200 + NA 9 3.300 + + ''' + tab = Table() + tab.AddCol('first', 'string') + tab.AddCol('second', 'int') + tab.AddCol('third', 'float', 3.141) + self.CompareColCount(tab, 3) + self.CompareRowCount(tab, 0) + self.CompareColTypes(tab, ['first','second', 'third'], 'sif') + tab.AddRow(['x',3, None], merge=None) + tab.AddRow(['foo',None, 2.2], merge=None) + tab.AddRow([None,9, 3.3], merge=None) + return tab + + def CompareRowCount(self, t, row_count): + ''' + Compare the number of rows + ''' + self.assertEqual(len(t.rows), + row_count, + "row count (%i) different from expected value (%i)" \ + %(len(t.rows), row_count)) + + def CompareColCount(self, t, col_count): + ''' + Compare the number of columns + ''' + self.assertEqual(len(t.col_names), + col_count, + "column count (%i) different from expected value (%i)" \ + %(len(t.col_names), col_count)) + + def CompareColNames(self, t, col_names): + ''' + Compare all column names of the table with a list of reference col names + ''' + self.CompareColCount(t, len(col_names)) + for i, (col_name, ref_name) in enumerate(zip(t.col_names, col_names)): + self.assertEqual(col_name, + ref_name, + "column name (%s) different from expected name (%s) at col %i" \ + %(col_name, ref_name, i)) + + def CompareDataFromDict(self, t, data_dict): + ''' + Compare all values of a table with reference values given in the form of a + dictionary containing a list of values for each column. + ''' + self.CompareColCount(t, len(data_dict)) + for k, v in data_dict.iteritems(): + self.CompareDataForCol(t, k, v) + + def CompareDataForCol(self, t, col_name, ref_data): + ''' + Compare the values of each row of ONE column specified by its name with + the reference values specified as a list of values for this column. + ''' + self.CompareRowCount(t, len(ref_data)) + idx = t.GetColIndex(col_name) + for i, (row, ref) in enumerate(zip(t.rows, ref_data)): + self.assertEqual(row[idx], + ref, + "data (%s) in col (%s), row (%i) different from expected value (%s)" \ + %(row[idx], col_name, i, ref)) + + def CompareColTypes(self, t, col_names, ref_types): + ''' + Compare the types of n columns specified by their names with reference + values specified either as a string consisting of the short type names + (e.g 'sfb') or a list of strings consisting of the long type names + (e.g. ['string','float','bool']) + ''' + if type(ref_types)==str: + trans = {'s' : 'string', 'i': 'int', 'b' : 'bool', 'f' : 'float'} + ref_types = [trans[rt] for rt in ref_types] + if type(col_names)==str: + col_names = [col_names] + self.assertEqual(len(col_names), + len(ref_types), + "number of col names (%i) different from number of reference col types (%i)" \ + %(len(col_names), len(ref_types))) + idxs = [t.GetColIndex(x) for x in col_names] + for idx, ref_type in zip(idxs, ref_types): + self.assertEqual(t.col_types[idx], + ref_type, + "column type (%s) at column %i, different from reference col type (%s)" \ + %(t.col_types[idx], idx, ref_type)) + + def testTableInitEmpty(self): + ''' + empty table + ''' + tab = Table() + self.CompareColCount(tab, 0) + self.CompareRowCount(tab, 0) + + def testTableInitSingleColEmpty(self): + ''' + empty table with one float column: + + x + --- + + ''' + tab = Table(['x'], 'f') + self.CompareColCount(tab, 1) + self.CompareRowCount(tab, 0) + self.CompareColNames(tab, ['x']) + self.CompareColTypes(tab, 'x', 'f') + + def testTableInitMultiColEmpty(self): + ''' + empty table with multiple column with different types: + + x y z a + ------------ + + ''' + tab = Table(['x','y','z','a'], 'sfbi') + self.CompareColCount(tab, 4) + self.CompareRowCount(tab, 0) + self.CompareColNames(tab, ['x','y','z','a']) + self.CompareColTypes(tab, ['x','y','z','a'], 'sfbi') + self.CompareColTypes(tab, ['x','y','z','a'], ['string','float','bool','int']) + + def testTableInitSingleColSingleValueNonEmpty(self): + ''' + table with one column and one row: + + x + ------- + 5.000 + + ''' + tab = Table(['x'], 'f', x=5) + self.CompareColCount(tab, 1) + self.CompareRowCount(tab, 1) + self.CompareColNames(tab, ['x']) + self.CompareColTypes(tab, 'x', 'f') + + def testTableInitMultiColSingleValueNonEmpty(self): + ''' + table with three columns and one row: + + x a z + --------------------- + 5.000 False 1.425 + + ''' + tab = Table(['x','a','z'], 'fbf', x=5, z=1.425, a=False) + self.CompareColCount(tab, 3) + self.CompareRowCount(tab, 1) + self.CompareColNames(tab, ['x','a','z']) + self.CompareColTypes(tab, ['z','x','a'], 'ffb') + self.CompareDataFromDict(tab, {'x': [5], 'z': [1.425], 'a': [False]}) + + def testTableInitMultiColSingleValueAndNoneNonEmpty(self): + ''' + table with three columns and one row with two None values: + + x a1 zzz + ---------------- + 5.000 NA NA + ''' + tab = Table(['x','a1','zzz'], 'fbf', x=5) + self.CompareColCount(tab, 3) + self.CompareRowCount(tab, 1) + self.CompareColNames(tab, ['x','a1','zzz']) + self.CompareColTypes(tab, ['zzz','x','a1'], 'ffb') + self.CompareDataFromDict(tab, {'x': [5], 'zzz': [None], 'a1': [None]}) + + def testTableInitSingleColMultiValueNonEmpty(self): + ''' + table with one column and five row: + + x + ------- + 0.000 + 1.000 + 2.000 + 3.000 + 4.000 + + ''' + tab = Table(['x'], 'f', x=range(5)) + self.CompareColCount(tab, 1) + self.CompareRowCount(tab, 5) + self.CompareColNames(tab, ['x']) + self.CompareColTypes(tab, 'x', 'f') + + def testTableInitMultiColMultiValueNonEmpty(self): + ''' + table with two column and four rows: + + foo bar + --------------- + i 10 + love 11 + unit 12 + tests 13 + + ''' + + tab = Table(['foo', 'bar'], 'si', bar=range(10,14), foo=['i','love','unit','tests']) + self.CompareColCount(tab, 2) + self.CompareRowCount(tab, 4) + self.CompareColNames(tab, ['foo','bar']) + self.CompareColTypes(tab, ['foo', 'bar'], 'si') + self.CompareDataFromDict(tab, {'bar': [10,11,12,13], 'foo': ['i','love','unit','tests']}) + + def testTableInitMultiColMissingMultiValue(self): + ''' + test if error is raised when creating rows with missing data + ''' + + self.assertRaises(ValueError, Table, ['foo', 'bar'], 'si', + bar=range(10,14), foo=['i','love','tests']) + + + def testTableInitMultiColMultiValueAndNoneNonEmpty(self): + ''' + table with two column and four rows with None values: + + foo bar + --------------- + i NA + love NA + unit NA + tests NA + + ''' + tab = Table(['foo', 'bar'], 'si', foo=['i','love','unit','tests']) + self.CompareColCount(tab, 2) + self.CompareRowCount(tab, 4) + self.CompareColNames(tab, ['foo','bar']) + self.CompareColTypes(tab, ['foo', 'bar'], 'si') + self.CompareDataFromDict(tab, {'bar': [None,None,None,None], 'foo': ['i','love','unit','tests']}) + + def testTableAddSingleCol(self): + ''' + init empty table, add one empty column: + + first + ------- + + ''' + tab = Table() + self.CompareColCount(tab, 0) + self.CompareRowCount(tab, 0) + tab.AddCol('first', 'string', 'AB C') + self.CompareColCount(tab, 1) + self.CompareRowCount(tab, 0) + self.CompareColNames(tab, ['first']) + self.CompareColTypes(tab, 'first', 's') + + def testTableAddSingleRow(self): + ''' + init table with one col, add one row: + + first + ------- + 2 + ''' + tab = Table(['first'],'i') + self.CompareColCount(tab, 1) + self.CompareRowCount(tab, 0) + tab.AddRow([2], merge=None) + self.CompareColCount(tab, 1) + self.CompareRowCount(tab, 1) + self.CompareColNames(tab, ['first']) + self.CompareColTypes(tab, 'first', 'i') + self.CompareDataFromDict(tab, {'first': [2]}) + + def testTableAddSingleColSingleRow(self): + ''' + init empty table, add one col, add one row: + + first + ------- + 2 + ''' + tab = Table() + tab.AddCol('first', 'int') + self.CompareColCount(tab, 1) + self.CompareRowCount(tab, 0) + tab.AddRow([2], merge=None) + self.CompareColCount(tab, 1) + self.CompareRowCount(tab, 1) + self.CompareColNames(tab, ['first']) + self.CompareColTypes(tab, 'first', 'i') + self.CompareDataFromDict(tab, {'first': [2]}) + + def testTableAddSingleColWithRow(self): + ''' + init table with two cols, add row with data, add third column: + + first second third + ---------------------- + x 3 3.141 + + ''' + tab = Table(['first','second'],'si') + self.CompareColCount(tab, 2) + self.CompareRowCount(tab, 0) + self.CompareColTypes(tab, ['first','second'], 'si') + tab.AddRow(['x',3], merge=None) + self.CompareColCount(tab, 2) + self.CompareRowCount(tab, 1) + tab.AddCol('third', 'float', 3.141) + self.CompareColCount(tab, 3) + self.CompareRowCount(tab, 1) + self.CompareColTypes(tab, ['first','third','second'], 'sfi') + self.CompareDataFromDict(tab, {'second': [3], 'first': ['x'], 'third': [3.141]}) + + def testTableAddMultiColMultiRow(self): + ''' + init empty table add three cols, add three rows with data: + + first second third + ---------------------- + x 3 1.000 + foo 6 2.200 + bar 9 3.300 + + ''' + tab = Table() + tab.AddCol('first', 'string') + tab.AddCol('second', 'int') + tab.AddCol('third', 'float', 3.141) + self.CompareColCount(tab, 3) + self.CompareRowCount(tab, 0) + self.CompareColTypes(tab, ['first','second', 'third'], 'sif') + tab.AddRow(['x',3, 1.0], merge=None) + tab.AddRow(['foo',6, 2.2], merge=None) + tab.AddRow(['bar',9, 3.3], merge=None) + self.CompareColCount(tab, 3) + self.CompareRowCount(tab, 3) + self.CompareDataFromDict(tab, {'second': [3,6,9], 'first': ['x','foo','bar'], 'third': [1,2.2,3.3]}) + + def testTableAddMultiColMultiRowFromDict(self): + ''' + init empty table add three cols, add three rows with data: + + first second third + ---------------------- + x 3 1.000 + foo 6 2.200 + bar 9 3.300 + + ''' + tab = Table() + tab.AddCol('first', 'string') + tab.AddCol('second', 'int') + tab.AddCol('aaa', 'float', 3.141) + self.CompareColCount(tab, 3) + self.CompareRowCount(tab, 0) + self.CompareColTypes(tab, ['first','second', 'aaa'], 'sif') + tab.AddRow({'first':'x','second':3, 'aaa':1.0}, merge=None) + tab.AddRow({'aaa':2.2, 'second':6, 'first':'foo'}, merge=None) + tab.AddRow({'second':9, 'aaa':3.3, 'first':'bar'}, merge=None) + self.CompareColCount(tab, 3) + self.CompareRowCount(tab, 3) + self.CompareDataFromDict(tab, {'second': [3,6,9], 'first': ['x','foo','bar'], 'aaa': [1,2.2,3.3]}) + + def testTableAddMultiRowMultiCol(self): + ''' + init empty table add one col, add three rows with data, + add one col without data, add one col with data: + + first second third + ---------------------- + x NA 3.141 + foo NA 3.141 + bar NA 3.141 + + ''' + tab = Table() + tab.AddCol('first', 'string') + self.CompareColCount(tab, 1) + self.CompareRowCount(tab, 0) + self.CompareColTypes(tab, ['first'], 's') + tab.AddRow(['x'], merge=None) + tab.AddRow(['foo'], merge=None) + tab.AddRow(['bar'], merge=None) + tab.AddCol('second', 'int') + tab.AddCol('third', 'float', 3.141) + self.CompareColCount(tab, 3) + self.CompareRowCount(tab, 3) + self.CompareDataFromDict(tab, {'second': [None,None,None], + 'first': ['x','foo','bar'], + 'third': [3.141, 3.141, 3.141]}) + + def testAddRowFromDictWithMerge(self): + ''' + add rows from dictionary with merge (i.e. overwrite third row with additional data) + + x foo bar + ------------------ + row1 True 1 + row2 NA 2 + row3 False 3 + + ''' + tab = Table() + tab.AddCol('x', 'string') + tab.AddCol('foo', 'bool') + tab.AddCol('bar', 'int') + tab.AddRow(['row1',True, 1]) + tab.AddRow(['row2',None, 2]) + tab.AddRow(['row3',False, None]) + self.CompareDataFromDict(tab, {'x': ['row1', 'row2', 'row3'], + 'foo': [True, None, False], + 'bar': [1, 2, None]}) + tab.AddRow({'x':'row3', 'bar':3}, merge='x') + self.CompareDataFromDict(tab, {'x': ['row1', 'row2', 'row3'], + 'foo': [True, None, False], + 'bar': [1, 2, 3]}) + + def testAddRowFromListWithMerge(self): + ''' + add rows from list with merge (i.e. overwrite third row with additional data) + + x foo bar + ------------------ + row1 True 1 + row2 NA 2 + row3 True 3 + + ''' + + tab = Table() + tab.AddCol('x', 'string') + tab.AddCol('foo', 'bool') + tab.AddCol('bar', 'int') + tab.AddRow(['row1',True, 1]) + tab.AddRow(['row2',None, 2]) + tab.AddRow(['row3',False, None]) + self.CompareDataFromDict(tab, {'x': ['row1', 'row2', 'row3'], + 'foo': [True, None, False], + 'bar': [1, 2, None]}) + tab.AddRow(['row3', True, 3], merge='x') + self.CompareDataFromDict(tab, {'x': ['row1', 'row2', 'row3'], + 'foo': [True, None, True], + 'bar': [1, 2, 3]}) + + + def testRaiseErrorOnWrongColumnTypes(self): + # wrong columns types in init + self.assertRaises(ValueError, Table, ['bla','bli'], 'ab') + + # wrong column types in Coerce + tab = Table() + self.assertRaises(ValueError, tab._Coerce, 'bla', 'a') + + # wrong column types in AddCol + self.assertRaises(ValueError, tab.AddCol, 'bla', 'a') + + def testParseColumnTypes(self): + types = Table._ParseColTypes(['i','f','s','b']) + self.assertEquals(types, ['int','float','string','bool']) + + types = Table._ParseColTypes(['int','float','string','bool']) + self.assertEquals(types, ['int','float','string','bool']) + + types = Table._ParseColTypes(['i','float','s','bool']) + self.assertEquals(types, ['int','float','string','bool']) + + types = Table._ParseColTypes(['i','fLOAT','S','bool']) + self.assertEquals(types, ['int','float','string','bool']) + + types = Table._ParseColTypes('ifsb') + self.assertEquals(types, ['int','float','string','bool']) + + types = Table._ParseColTypes('int,float,string,bool') + self.assertEquals(types, ['int','float','string','bool']) + + types = Table._ParseColTypes('int,f,s,bool') + self.assertEquals(types, ['int','float','string','bool']) + + types = Table._ParseColTypes('INT,F,s,bOOL') + self.assertEquals(types, ['int','float','string','bool']) + + types = Table._ParseColTypes('boOl') + self.assertEquals(types, ['bool']) + + types = Table._ParseColTypes('S') + self.assertEquals(types, ['string']) + + types = Table._ParseColTypes(['i']) + self.assertEquals(types, ['int']) + + types = Table._ParseColTypes(['FLOAT']) + self.assertEquals(types, ['float']) + + self.assertRaises(ValueError, Table._ParseColTypes, 'bfstring') + self.assertRaises(ValueError, Table._ParseColTypes, ['b,f,string']) + self.assertRaises(ValueError, Table._ParseColTypes, 'bi2') + self.assertRaises(ValueError, Table._ParseColTypes, ['b',2,'string']) + self.assertRaises(ValueError, Table._ParseColTypes, [['b'],['f','string']]) + self.assertRaises(ValueError, Table._ParseColTypes, 'a') + self.assertRaises(ValueError, Table._ParseColTypes, 'foo') + self.assertRaises(ValueError, Table._ParseColTypes, ['a']) + self.assertRaises(ValueError, Table._ParseColTypes, ['foo']) + + def testShortLongColumnTypes(self): + tab = Table(['x','y','z','a'],['i','f','s','b']) + self.CompareColTypes(tab, ['x','y','z','a'], 'ifsb') + + tab = Table(['x','y','z','a'],['int','float','string','bool']) + self.CompareColTypes(tab, ['x','y','z','a'], 'ifsb') + + tab = Table(['x','y','z','a'],['i','float','s','bool']) + self.CompareColTypes(tab, ['x','y','z','a'], 'ifsb') + + tab = Table(['x','y','z','a'],['i','fLOAT','S','bool']) + self.CompareColTypes(tab, ['x','y','z','a'], 'ifsb') + + tab = Table(['x','y','z','a'],'ifsb') + self.CompareColTypes(tab, ['x','y','z','a'], 'ifsb') + + tab = Table(['x','y','z','a'],'int,float,string,bool') + self.CompareColTypes(tab, ['x','y','z','a'], 'ifsb') + + tab = Table(['x','y','z','a'],'int,f,s,bool') + self.CompareColTypes(tab, ['x','y','z','a'], 'ifsb') + + tab = Table(['x','y','z','a'],'INT,F,s,bOOL') + self.CompareColTypes(tab, ['x','y','z','a'], 'ifsb') + + tab = Table(['x'], 'boOl') + self.CompareColTypes(tab, ['x'], 'b') + tab = Table(['x'], 'B') + self.CompareColTypes(tab, ['x'], 'b') + tab = Table(['x'], ['b']) + self.CompareColTypes(tab, ['x'], 'b') + tab = Table(['x'], ['Bool']) + self.CompareColTypes(tab, ['x'], 'b') + + self.assertRaises(ValueError, Table, ['x','y','z'], 'bfstring') + self.assertRaises(ValueError, Table, ['x','y','z'], ['b,f,string']) + self.assertRaises(ValueError, Table, ['x','y','z'], 'bi2') + self.assertRaises(ValueError, Table, ['x','y','z'], ['b',2,'string']) + self.assertRaises(ValueError, Table, ['x','y','z'], [['b'],['f','string']]) + self.assertRaises(ValueError, Table, ['x'], 'a') + self.assertRaises(ValueError, Table, ['x'], 'foo') + self.assertRaises(ValueError, Table, ['x'], ['a']) + self.assertRaises(ValueError, Table, ['x'], ['foo']) + + def testCoerce(self): + tab = Table() + + # None values + self.assertEquals(tab._Coerce('NA', 'x'), None) + self.assertEquals(tab._Coerce(None, 'x'), None) + + # int type + self.assertTrue(isinstance(tab._Coerce(2 ,'int'), int)) + self.assertEquals(tab._Coerce(2 ,'int'), 2) + self.assertTrue(isinstance(tab._Coerce(2.2 ,'int'), int)) + self.assertEquals(tab._Coerce(2.2 ,'int'), 2) + self.assertEquals(tab._Coerce(True ,'int'), 1) + self.assertEquals(tab._Coerce(False ,'int'), 0) + self.assertRaises(ValueError, tab._Coerce, "foo" , 'int') + + # float type + self.assertTrue(isinstance(tab._Coerce(2 ,'float'), float)) + self.assertEquals(tab._Coerce(2 ,'float'), 2.000) + self.assertTrue(isinstance(tab._Coerce(3.141 ,'float'), float)) + self.assertEquals(tab._Coerce(3.141 ,'float'), 3.141) + self.assertRaises(ValueError, tab._Coerce, "foo" , 'float') + + # string type + self.assertTrue(isinstance(tab._Coerce('foo' ,'string'), str)) + self.assertTrue(isinstance(tab._Coerce('this is a longer string' ,'string'), str)) + self.assertTrue(isinstance(tab._Coerce(2.2 ,'string'), str)) + self.assertTrue(isinstance(tab._Coerce(2 ,'string'), str)) + self.assertTrue(isinstance(tab._Coerce(True ,'string'), str)) + self.assertTrue(isinstance(tab._Coerce(False ,'string'), str)) + + # bool type + self.assertEquals(tab._Coerce(True ,'bool'), True) + self.assertEquals(tab._Coerce(False ,'bool'), False) + self.assertEquals(tab._Coerce('falSE' ,'bool'), False) + self.assertEquals(tab._Coerce('no' ,'bool'), False) + self.assertEquals(tab._Coerce('not false and not no','bool'), True) + self.assertEquals(tab._Coerce(0, 'bool'), False) + self.assertEquals(tab._Coerce(1, 'bool'), True) + + # unknown type + self.assertRaises(ValueError, tab._Coerce, 'bla', 'abc') + + def testRemoveCol(self): + tab = self.CreateTestTable() + self.CompareDataFromDict(tab, {'first': ['x','foo',None], 'second': [3,None,9], 'third': [None,2.2,3.3]}) + tab.RemoveCol("second") + self.CompareDataFromDict(tab, {'first': ['x','foo',None], 'third': [None,2.2,3.3]}) + + # raise error when column is unknown + tab = self.CreateTestTable() + self.assertRaises(ValueError, tab.RemoveCol, "unknown col") + + def testSortTable(self): + tab = self.CreateTestTable() + self.CompareDataFromDict(tab, {'first': ['x','foo',None], 'second': [3,None,9], 'third': [None,2.2,3.3]}) + tab.Sort('first', '-') + self.CompareDataFromDict(tab, {'first': [None,'foo','x'], 'second': [9,None,3], 'third': [3.3,2.2,None]}) + tab.Sort('first', '+') + self.CompareDataFromDict(tab, {'first': ['x','foo',None], 'second': [3,None,9], 'third': [None,2.2,3.3]}) + tab.Sort('third', '+') + self.CompareDataFromDict(tab, {'first': [None,'foo','x'], 'second': [9,None,3], 'third': [3.3,2.2,None]}) + + def testSaveLoadTable(self): + tab = self.CreateTestTable() + self.CompareDataFromDict(tab, {'first': ['x','foo',None], 'second': [3,None,9], 'third': [None,2.2,3.3]}) + + # write to disc + tab.Save("saveloadtable_filename_out.csv") + out_stream = open("saveloadtable_stream_out.csv", 'w') + tab.Save(out_stream) + out_stream.close() + + # read from disc + in_stream = open("saveloadtable_stream_out.csv", 'r') + tab_loaded_stream = Table.Load(in_stream) + in_stream.close() + tab_loaded_fname = Table.Load('saveloadtable_filename_out.csv') + + # check content + self.CompareDataFromDict(tab_loaded_stream, {'first': ['x','foo',None], 'second': [3,None,9], 'third': [None,2.2,3.3]}) + self.CompareDataFromDict(tab_loaded_fname, {'first': ['x','foo',None], 'second': [3,None,9], 'third': [None,2.2,3.3]}) + + # check Errors for empty/non existing files + self.assertRaises(IOError, Table.Load, 'nonexisting.file') + self.assertRaises(IOError, Table.Load, os.path.join('testfiles','emptytable.csv')) + in_stream = open(os.path.join('testfiles','emptytable.csv'), 'r') + self.assertRaises(IOError, Table.Load, in_stream) + + def testMergeTable(self): + ''' + Merge the following two tables: + + x y x u + ------- ------- + 1 | 10 1 | 100 + 2 | 15 3 | 200 + 3 | 20 4 | 400 + + to get (only_matching=False): + + x y u + --------------- + 1 | 10 | 100 + 2 | 15 | NA + 3 | 20 | 200 + 4 | NA | 400 + + or (only_matching=True): + + x y u + --------------- + 1 | 10 | 100 + 3 | 20 | 200 + + ''' + tab1 = Table(['x','y'],['int','int']) + tab1.AddRow([1,10]) + tab1.AddRow([2,15]) + tab1.AddRow([3,20]) + + tab2 = Table(['x','u'],['int','int']) + tab2.AddRow([1,100]) + tab2.AddRow([3,200]) + tab2.AddRow([4,400]) + + tab_merged = Merge(tab1, tab2, 'x', only_matching=False) + tab_merged.Sort('x', order='-') + self.CompareDataFromDict(tab_merged, {'x': [1,2,3,4], 'y': [10,15,20,None], 'u': [100,None,200,400]}) + + tab_merged = Merge(tab1, tab2, 'x', only_matching=True) + tab_merged.Sort('x', order='-') + self.CompareDataFromDict(tab_merged, {'x': [1,3], 'y': [10,20], 'u': [100,200]}) + + def testFilterTable(self): + tab = self.CreateTestTable() + tab.AddRow(['foo',1,5.15]) + tab.AddRow(['foo',0,1]) + tab.AddRow(['foo',1,12]) + + # filter on one column + tab_filtered = tab.Filter(first='foo') + self.CompareDataFromDict(tab_filtered, {'first':['foo','foo','foo','foo'], + 'second':[None,1,0,1], + 'third':[2.2,5.15,1.0,12.0]}) + + # filter on multiple columns + tab_filtered = tab.Filter(first='foo',second=1) + self.CompareDataFromDict(tab_filtered, {'first':['foo','foo'], + 'second':[1,1], + 'third':[5.15,12.0]}) + + # raise Error when using non existing column name for filtering + self.assertRaises(ValueError,tab.Filter,first='foo',nonexisting=1) + + def testMinTable(self): + tab = self.CreateTestTable() + tab.AddCol('fourth','bool',[True,True,False]) + + self.assertEquals(tab.Min('first'),'foo') + self.assertEquals(tab.Min('second'),3) + self.assertAlmostEquals(tab.Min('third'),2.2) + self.assertEquals(tab.Min('fourth'),False) + self.assertRaises(ValueError,tab.Min,'fifth') + + self.assertEquals(tab.MinIdx('first'),1) + self.assertEquals(tab.MinIdx('second'),0) + self.assertAlmostEquals(tab.MinIdx('third'),1) + self.assertEquals(tab.MinIdx('fourth'),2) + self.assertRaises(ValueError,tab.MinIdx,'fifth') + + self.assertEquals(tab.MinRow('first'),['foo', None, 2.20, True]) + self.assertEquals(tab.MinRow('second'),['x', 3, None, True]) + self.assertEquals(tab.MinRow('third'),['foo', None, 2.20, True]) + self.assertEquals(tab.MinRow('fourth'),[None, 9, 3.3, False]) + self.assertRaises(ValueError,tab.MinRow,'fifth') + + def testMaxTable(self): + tab = self.CreateTestTable() + tab.AddCol('fourth','bool',[False,True,True]) + + self.assertEquals(tab.Max('first'),'x') + self.assertEquals(tab.Max('second'),9) + self.assertAlmostEquals(tab.Max('third'),3.3) + self.assertEquals(tab.Max('fourth'),True) + self.assertRaises(ValueError,tab.Max,'fifth') + + self.assertEquals(tab.MaxIdx('first'),0) + self.assertEquals(tab.MaxIdx('second'),2) + self.assertAlmostEquals(tab.MaxIdx('third'),2) + self.assertEquals(tab.MaxIdx('fourth'),1) + self.assertRaises(ValueError,tab.MaxIdx,'fifth') + + self.assertEquals(tab.MaxRow('first'),['x', 3, None, False]) + self.assertEquals(tab.MaxRow('second'),[None, 9, 3.3, True]) + self.assertEquals(tab.MaxRow('third'),[None, 9, 3.3, True]) + self.assertEquals(tab.MaxRow('fourth'),['foo', None, 2.2, True]) + self.assertRaises(ValueError,tab.MaxRow,'fifth') + + def testSumTable(self): + tab = self.CreateTestTable() + tab.AddCol('fourth','bool',[False,True,False]) + + self.assertRaises(TypeError,tab.Sum,'first') + self.assertEquals(tab.Sum('second'),12) + self.assertAlmostEquals(tab.Sum('third'),5.5) + self.assertRaises(TypeError,tab.Sum,'fourth') + self.assertRaises(ValueError,tab.Sum,'fifth') + + def testMedianTable(self): + tab = self.CreateTestTable() + tab.AddCol('fourth','bool',[False,True,False]) + + self.assertRaises(TypeError,tab.Median,'first') + self.assertEquals(tab.Median('second'),6.0) + self.assertAlmostEquals(tab.Median('third'),2.75) + self.assertRaises(TypeError,tab.Median,'fourth') + self.assertRaises(ValueError,tab.Median,'fifth') + + def testMeanTable(self): + tab = self.CreateTestTable() + tab.AddCol('fourth','bool',[False,True,False]) + + self.assertRaises(TypeError,tab.Mean,'first') + self.assertAlmostEquals(tab.Mean('second'),6.0) + self.assertAlmostEquals(tab.Mean('third'),2.75) + self.assertRaises(TypeError,tab.Mean,'fourth') + self.assertRaises(ValueError,tab.Mean,'fifth') + + def testStdDevTable(self): + tab = self.CreateTestTable() + tab.AddCol('fourth','bool',[False,True,False]) + + self.assertRaises(TypeError,tab.StdDev,'first') + self.assertAlmostEquals(tab.StdDev('second'),3.0) + self.assertAlmostEquals(tab.StdDev('third'),0.55) + self.assertRaises(TypeError,tab.StdDev,'fourth') + self.assertRaises(ValueError,tab.StdDev,'fifth') + + def testCountTable(self): + tab = self.CreateTestTable() + tab.AddCol('fourth','bool',[False,True,False]) + + self.assertEquals(tab.Count('first'),2) + self.assertEquals(tab.Count('second'),2) + self.assertEquals(tab.Count('third'),2) + self.assertEquals(tab.Count('fourth'),3) + self.assertRaises(ValueError,tab.Count,'fifth') + + def testCalcEnrichment(self): + enrx_ref = [0.0, 0.041666666666666664, 0.083333333333333329, 0.125, 0.16666666666666666, 0.20833333333333334, 0.25, 0.29166666666666669, 0.33333333333333331, 0.375, 0.41666666666666669, 0.45833333333333331, 0.5, 0.54166666666666663, 0.58333333333333337, 0.625, 0.66666666666666663, 0.70833333333333337, 0.75, 0.79166666666666663, 0.83333333333333337, 0.875, 0.91666666666666663, 0.95833333333333337, 1.0] + enry_ref = [0.0, 0.16666666666666666, 0.33333333333333331, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.66666666666666663, 0.66666666666666663, 0.83333333333333337, 0.83333333333333337, 0.83333333333333337, 0.83333333333333337, 0.83333333333333337, 0.83333333333333337, 0.83333333333333337, 1.0, 1.0, 1.0, 1.0] + + tab = Table(['score', 'rmsd', 'classific'], 'ffb', + score=[2.64,1.11,2.17,0.45,0.15,0.85,1.13,2.90,0.50,1.03,1.46,2.83,1.15,2.04,0.67,1.27,2.22,1.90,0.68,0.36,1.04,2.46,0.91,0.60], + rmsd=[9.58,1.61,7.48,0.29,1.68,3.52,3.34,8.17,4.31,2.85,6.28,8.78,0.41,6.29,4.89,7.30,4.26,3.51,3.38,0.04,2.21,0.24,7.58,8.40], + classific=[False,True,False,True,True,False,False,False,False,False,False,False,True,False,False,False,False,False,False,True,False,True,False,False]) + + enrx,enry = tab.ComputeEnrichment(score_col='score', score_dir='-', + class_col='rmsd', class_cutoff=2.0, + class_dir='-') + + for x,y,refx,refy in zip(enrx,enry,enrx_ref,enry_ref): + self.assertAlmostEquals(x,refx) + self.assertAlmostEquals(y,refy) + + enrx,enry = tab.ComputeEnrichment(score_col='score', score_dir='-', + class_col='classific') + + for x,y,refx,refy in zip(enrx,enry,enrx_ref,enry_ref): + self.assertAlmostEquals(x,refx) + self.assertAlmostEquals(y,refy) + + tab.AddCol('bad','string','col') + + self.assertRaises(TypeError, tab.ComputeEnrichment, score_col='classific', + score_dir='-', class_col='rmsd', class_cutoff=2.0, + class_dir='-') + + self.assertRaises(TypeError, tab.ComputeEnrichment, score_col='bad', + score_dir='-', class_col='rmsd', class_cutoff=2.0, + class_dir='-') + + self.assertRaises(TypeError, tab.ComputeEnrichment, score_col='score', + score_dir='-', class_col='bad', class_cutoff=2.0, + class_dir='-') + + self.assertRaises(ValueError, tab.ComputeEnrichment, score_col='score', + score_dir='x', class_col='rmsd', class_cutoff=2.0, + class_dir='-') + + self.assertRaises(ValueError, tab.ComputeEnrichment, score_col='score', + score_dir='+', class_col='rmsd', class_cutoff=2.0, + class_dir='y') + + def testPlotEnrichment(self): + if not HAS_MPL: + return + tab = Table(['score', 'rmsd', 'classific'], 'ffb', + score=[2.64,1.11,2.17,0.45,0.15,0.85,1.13,2.90,0.50,1.03,1.46,2.83,1.15,2.04,0.67,1.27,2.22,1.90,0.68,0.36,1.04,2.46,0.91,0.60], + rmsd=[9.58,1.61,7.48,0.29,1.68,3.52,3.34,8.17,4.31,2.85,6.28,8.78,0.41,6.29,4.89,7.30,4.26,3.51,3.38,0.04,2.21,0.24,7.58,8.40], + classific=[False,True,False,True,True,False,False,False,False,False,False,False,True,False,False,False,False,False,False,True,False,True,False,False]) + + pl = tab.PlotEnrichment(score_col='score', score_dir='-', + class_col='rmsd', class_cutoff=2.0, + class_dir='-') + #pl.show() + + def testCalcEnrichmentAUC(self): + if not HAS_NUMPY: + return + auc_ref = 0.65277777777777779 + tab = Table(['score', 'rmsd', 'classific'], 'ffb', + score=[2.64,1.11,2.17,0.45,0.15,0.85,1.13,2.90,0.50,1.03,1.46,2.83,1.15,2.04,0.67,1.27,2.22,1.90,0.68,0.36,1.04,2.46,0.91,0.60], + rmsd=[9.58,1.61,7.48,0.29,1.68,3.52,3.34,8.17,4.31,2.85,6.28,8.78,0.41,6.29,4.89,7.30,4.26,3.51,3.38,0.04,2.21,0.24,7.58,8.40], + classific=[False,True,False,True,True,False,False,False,False,False,False,False,True,False,False,False,False,False,False,True,False,True,False,False]) + + auc = tab.ComputeEnrichmentAUC(score_col='score', score_dir='-', + class_col='rmsd', class_cutoff=2.0, + class_dir='-') + + self.assertAlmostEquals(auc, auc_ref) + + def testTableAsNumpyMatrix(self): + + ''' + checks numpy matrix + + first second third fourth + ------------------------------- + x 3 NA True + foo NA 2.200 False + NA 9 3.300 False + ''' + + tab = self.CreateTestTable() + tab.AddCol('fourth','b',[True, False, False]) + m = tab.GetNumpyMatrix('second') + mc = np.matrix([[3],[None],[9]]) + self.assertTrue(np.all(m==mc)) + mc = np.matrix([[3],[None],[10]]) + self.assertFalse(np.all(m==mc)) + m = tab.GetNumpyMatrix('third') + mc = np.matrix([[None],[2.200],[3.300]]) + self.assertTrue(np.all(m==mc)) + m = tab.GetNumpyMatrix('second','third') + mc = np.matrix([[3, None],[None, 2.200],[9, 3.300]]) + self.assertTrue(np.all(m==mc)) + m = tab.GetNumpyMatrix('third','second') + mc = np.matrix([[None, 3],[2.200, None],[3.300, 9]]) + self.assertTrue(np.all(m==mc)) + + self.assertRaises(TypeError, tab.GetNumpyMatrix, 'fourth') + self.assertRaises(TypeError, tab.GetNumpyMatrix, 'first') + self.assertRaises(RuntimeError, tab.GetNumpyMatrix) + + def testOptimalPrefactors(self): + if not HAS_NUMPY: + return + tab = Table(['a','b','c','d','e','f'], + 'ffffff', + a=[1,2,3,4,5,6,7,8,9], + b=[2,3,4,5,6,7,8,9,10], + c=[1,3,2,4,5,6,8,7,9], + d=[0.1,0.1,0.1,0.2,0.3,0.3,0.4,0.5,0.8], + e=[1,1,1,1,1,1,1,1,1], + f=[9,9,9,9,9,9,9,9,9]) + + pref = tab.GetOptimalPrefactors('c','a','b') + self.assertAlmostEquals(pref[0],0.799999999) + self.assertAlmostEquals(pref[1],0.166666666666) + + pref = tab.GetOptimalPrefactors('c','b','a') + self.assertAlmostEquals(pref[0],0.166666666666) + self.assertAlmostEquals(pref[1],0.799999999) + + pref = tab.GetOptimalPrefactors('c','b','a',weights='e') + self.assertAlmostEquals(pref[0],0.166666666666) + self.assertAlmostEquals(pref[1],0.799999999) + + pref = tab.GetOptimalPrefactors('c','b','a',weights='f') + self.assertAlmostEquals(pref[0],0.166666666666) + self.assertAlmostEquals(pref[1],0.799999999) + + pref = tab.GetOptimalPrefactors('c','a','b',weights='d') + self.assertAlmostEquals(pref[0],0.6078825445851) + self.assertAlmostEquals(pref[1],0.3394613806088) + + self.assertRaises(RuntimeError, tab.GetOptimalPrefactors, 'c','a','b',weight='d') + self.assertRaises(RuntimeError, tab.GetOptimalPrefactors, 'c',weights='d') + +if __name__ == "__main__": + try: + unittest.main() + except Exception, e: + print e \ No newline at end of file diff --git a/modules/base/tests/testfiles/emptytable.csv b/modules/base/tests/testfiles/emptytable.csv new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/doc/table.rst b/modules/doc/table.rst new file mode 100644 index 0000000000000000000000000000000000000000..08beaad21d3e558db80ae2e3acee0fb425587b4d --- /dev/null +++ b/modules/doc/table.rst @@ -0,0 +1,33 @@ +:mod:`~ost.table` - Working with tabular data +================================================================================ + +.. module:: ost.table + :synopsis: Working with tabular data + +This module defines the table class that provides convenient functionality to work with tabular data. It features functions to calculate statistical moments, e.g. mean, standard deviations as well as functionality to plot the data using matplotlib. + +Basic Usage +-------------------------------------------------------------------------------- + +.. code-block:: python + + from ost.table import * + # create table with two columns, x and y both of float type + tab=Table(['x', 'y'], 'ff') + for x in range(1000): + tab.AddRow([x, x**2]) + # create a plot + plt=tab.Plot('x', 'y') + # save resulting plot to png file + plt.savefig('x-vs-y.png') + + +The Table class +-------------------------------------------------------------------------------- + + +.. autoclass:: ost.table.Table + :members: + :undoc-members: SUPPORTED_TYPES + +.. autofunction:: ost.table.Merge \ No newline at end of file