diff --git a/.gitignore b/.gitignore index 56410cfa551dd48078e87a6a355fa7eadf1823ab..fd4689e6cb74f35784d549a66b504c717d958196 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,7 @@ pov_*test.inc *-out.pdb *-out.sdf *-out.crd +*-out.png CMakeLists.txt.user OpenStructure.cbp DartConfiguration.tcl diff --git a/modules/base/pymod/table.py b/modules/base/pymod/table.py index 4d6f3bdd45a6cf0535813f2e576698fca75c9cee..12509583a90ae6e483cffe2d817d718dae450729 100644 --- a/modules/base/pymod/table.py +++ b/modules/base/pymod/table.py @@ -1133,6 +1133,136 @@ class Table: LogError("Function needs numpy, but I could not import it.") raise + def ComputeROC(self, score_col, class_col, score_dir='-', + class_dir='-', class_cutoff=2.0): + ''' + Computes the receiver operating characteristics of one column (e.g. score) + over all data points. + + For this it is necessary, that the datapoints are classified into positive + and negative points. This can be done in two ways: + + - by using one 'bool' column (class_col) which contains True for positives + and False for negatives + - by using a non-bool column (class_col), a cutoff value (class_cutoff) + and the classification columns direction (class_dir). This will generate + the classification on the fly + + * if class_dir=='-': values in the classification column that are + less than or equal to class_cutoff will be counted + as positives + * if class_dir=='+': values in the classification column that are + larger than or equal to class_cutoff will be counted + as positives + + During the calculation, the table will be sorted according to score_dir, + where a '-' values means smallest values first and therefore, the smaller + the value, the better. + + ''' + + ALLOWED_DIR = ['+','-'] + + score_idx = self.GetColIndex(score_col) + score_type = self.col_types[score_idx] + if score_type!='int' and score_type!='float': + raise TypeError("Score column must be numeric type") + + class_idx = self.GetColIndex(class_col) + class_type = self.col_types[class_idx] + if class_type!='int' and class_type!='float' and class_type!='bool': + raise TypeError("Classifier column must be numeric or bool type") + + if (score_dir not in ALLOWED_DIR) or (class_dir not in ALLOWED_DIR): + raise ValueError("Direction must be one of %s"%str(ALLOWED_DIR)) + + self.Sort(score_col, score_dir) + + x = [0] + y = [0] + tp = 0 + fp = 0 + old_score_val = None + + for i,row in enumerate(self.rows): + class_val = row[class_idx] + score_val = row[score_idx] + if class_val!=None: + if class_type=='bool': + if class_val==True: + tp += 1 + else: + fp += 1 + else: + if (class_dir=='-' and class_val<=class_cutoff) or (class_dir=='+' and class_val>=class_cutoff): + tp += 1 + else: + fp += 1 + if score_val!=old_score_val: + x.append(fp) + y.append(tp) + old_score_val = score_val + x = [float(v)/x[-1] for v in x] + y = [float(v)/y[-1] for v in y] + return x,y + + def ComputeROCAUC(self, score_col, class_col, score_dir='-', + class_dir='-', class_cutoff=2.0): + ''' + Computes the area under the curve of the receiver operating characteristics + using the trapezoidal rule + ''' + try: + import numpy as np + + rocx, rocy = self.ComputeROC(score_col, class_col, score_dir, + class_dir, class_cutoff) + + return np.trapz(rocy, rocx) + except ImportError: + LogError("Function needs numpy, but I could not import it.") + raise + + def PlotROC(self, score_col, class_col, score_dir='-', + class_dir='-', class_cutoff=2.0, + style='-', title=None, x_title=None, y_title=None, + clear=True, save=None): + ''' + Plot an ROC curve using matplotlib + ''' + + try: + import matplotlib.pyplot as plt + + enrx, enry = self.ComputeROC(score_col, class_col, score_dir, + class_dir, class_cutoff) + + if not title: + title = 'ROC of %s'%score_col + + if not x_title: + x_title = 'false positive rate' + + if not y_title: + y_title = 'true positive rate' + + if clear: + plt.clf() + + plt.plot(enrx, enry, style) + + plt.title(title, size='x-large', fontweight='bold') + plt.ylabel(y_title, size='x-large') + plt.xlabel(x_title, size='x-large') + + if save: + plt.savefig(save) + + return plt + except ImportError: + LogError("Function needs matplotlib, but I could not import it.") + raise + def IsEmpty(self, col_name=None, ignore_nan=True): ''' Checks if a table is empty. @@ -1269,4 +1399,4 @@ def Merge(table1, table2, by, only_matching=False): new_tab.AddRow(row) return new_tab - \ No newline at end of file + diff --git a/modules/base/tests/test_table.py b/modules/base/tests/test_table.py index 2baf48934eb8e58f5a0696b48792c9237059033d..f83b9512c5803fae95c96cb1f2bc84a0b6050f36 100644 --- a/modules/base/tests/test_table.py +++ b/modules/base/tests/test_table.py @@ -11,16 +11,26 @@ import ost HAS_NUMPY=True HAS_MPL=True +HAS_PIL=True try: import numpy as np except ImportError: HAS_NUMPY=False + print "Could not find numpy: ignoring some table class unit tests" try: import matplotlib matplotlib.use('Agg') except ImportError: HAS_MPL=False + print "Could not find matplotlib: ignoring some table class unit tests" + +try: + import Image + import ImageChops +except ImportError: + HAS_PIL=False + print "Could not find python imagine library: ignoring some table class unit tests" class TestTable(unittest.TestCase): @@ -124,6 +134,17 @@ class TestTable(unittest.TestCase): "column type (%s) at column %i, different from reference col type (%s)" \ %(t.col_types[idx], idx, ref_type)) + def CompareImages(self, img1, img2): + ''' + Compares two images based on all pixel values. This function needs the + python imaging library (PIL) package. + ''' + if not HAS_PIL: + return + diff = ImageChops.difference(img1, img2) + self.assertEqual(diff.getbbox(),None) + + def testZip(self): tab=Table(['col1', 'col2', 'col3', 'col4'], 'sssi') tab.AddRow(['a', 'b', 'c', 1]) @@ -914,7 +935,7 @@ class TestTable(unittest.TestCase): class_dir='y') def testPlotEnrichment(self): - if not HAS_MPL: + if not HAS_MPL or not HAS_PIL: return tab = Table(['score', 'rmsd', 'classific'], 'ffb', score=[2.64,1.11,2.17,0.45,0.15,0.85,1.13,2.90,0.50,1.03,1.46,2.83,1.15,2.04,0.67,1.27,2.22,1.90,0.68,0.36,1.04,2.46,0.91,0.60], @@ -923,7 +944,11 @@ class TestTable(unittest.TestCase): pl = tab.PlotEnrichment(score_col='score', score_dir='-', class_col='rmsd', class_cutoff=2.0, - class_dir='-') + class_dir='-', + save=os.path.join("testfiles","enrichment-out.png")) + img1 = Image.open(os.path.join("testfiles","enrichment-out.png")) + img2 = Image.open(os.path.join("testfiles","enrichment.png")) + self.CompareImages(img1, img2) #pl.show() def testCalcEnrichmentAUC(self): @@ -940,7 +965,55 @@ class TestTable(unittest.TestCase): class_dir='-') self.assertAlmostEquals(auc, auc_ref) - + + def testPlotROC(self): + if not HAS_MPL or not HAS_PIL: + return + tab = Table(['classific', 'score'], 'bf', + classific=[True, True, False, True, True, True, False, False, True, False, True, False, True, False, False, False, True, False, True, False], + score=[0.9, 0.8, 0.7, 0.6, 0.55, 0.54, 0.53, 0.52, 0.51, 0.505, 0.4, 0.39, 0.38, 0.37, 0.36, 0.35, 0.34, 0.33, 0.30, 0.1]) + pl = tab.PlotROC(score_col='score', score_dir='+', + class_col='classific', + save=os.path.join("testfiles","roc-out.png")) + img1 = Image.open(os.path.join("testfiles","roc-out.png")) + img2 = Image.open(os.path.join("testfiles","roc.png")) + self.CompareImages(img1, img2) + #pl.show() + + def testPlotROCSameValue(self): + if not HAS_MPL or not HAS_PIL: + return + tab = Table(['classific', 'score'], 'bf', + classific=[True, True, False, True, True, True, False, False, True, False, True, False, True, False, False, False, True, False, True, False], + score=[0.9, 0.8, 0.7, 0.7, 0.7, 0.7, 0.53, 0.52, 0.51, 0.505, 0.4, 0.4, 0.4, 0.4, 0.36, 0.35, 0.34, 0.33, 0.30, 0.1]) + pl = tab.PlotROC(score_col='score', score_dir='+', + class_col='classific', + save=os.path.join("testfiles","roc-same-val-out.png")) + img1 = Image.open(os.path.join("testfiles","roc-same-val-out.png")) + img2 = Image.open(os.path.join("testfiles","roc-same-val.png")) + self.CompareImages(img1, img2) + #pl.show() + + def testCalcROCAUC(self): + if not HAS_NUMPY: + return + auc_ref = 0.68 + tab = Table(['classific', 'score'], 'bf', + classific=[True, True, False, True, True, True, False, False, True, False, True, False, True, False, False, False, True, False, True, False], + score=[0.9, 0.8, 0.7, 0.6, 0.55, 0.54, 0.53, 0.52, 0.51, 0.505, 0.4, 0.39, 0.38, 0.37, 0.36, 0.35, 0.34, 0.33, 0.30, 0.1]) + auc = tab.ComputeROCAUC(score_col='score', score_dir='+', class_col='classific') + self.assertAlmostEquals(auc, auc_ref) + + def testCalcROCAUCSameValue(self): + if not HAS_NUMPY: + return + auc_ref = 0.66 + tab = Table(['classific', 'score'], 'bf', + classific=[True, True, False, True, True, True, False, False, True, False, True, False, True, False, False, False, True, False, True, False], + score=[0.9, 0.8, 0.7, 0.7, 0.7, 0.7, 0.53, 0.52, 0.51, 0.505, 0.4, 0.4, 0.4, 0.4, 0.36, 0.35, 0.34, 0.33, 0.30, 0.1]) + auc = tab.ComputeROCAUC(score_col='score', score_dir='+', class_col='classific') + self.assertAlmostEquals(auc, auc_ref) + def testTableAsNumpyMatrix(self): ''' diff --git a/modules/base/tests/testfiles/enrichment.png b/modules/base/tests/testfiles/enrichment.png new file mode 100644 index 0000000000000000000000000000000000000000..6fa5157f913a556a3d27eccf63479d92f4eb806d Binary files /dev/null and b/modules/base/tests/testfiles/enrichment.png differ diff --git a/modules/base/tests/testfiles/roc-same-val.png b/modules/base/tests/testfiles/roc-same-val.png new file mode 100644 index 0000000000000000000000000000000000000000..9599e26e92887cd2f97702c9e55c44dcef31557c Binary files /dev/null and b/modules/base/tests/testfiles/roc-same-val.png differ diff --git a/modules/base/tests/testfiles/roc.png b/modules/base/tests/testfiles/roc.png new file mode 100644 index 0000000000000000000000000000000000000000..d02985dacd945e6dfc57301d99994b098c0879dd Binary files /dev/null and b/modules/base/tests/testfiles/roc.png differ