diff --git a/modules/base/pymod/table.py b/modules/base/pymod/table.py index 43a94db670535e5dfed9b2f964229c7c34133875..9338a1f5d776ee9dcc2ca26d4e4bb419cf6024f2 100644 --- a/modules/base/pymod/table.py +++ b/modules/base/pymod/table.py @@ -894,6 +894,9 @@ class Table: there are not enough data points to calculate a correlation coefficient, None is returned. """ + if IsStringLike(col1) and IsStringLike(col2): + col1 = self.GetColIndex(col1) + col2 = self.GetColIndex(col2) vals1, vals2=([],[]) for v1, v2 in zip(self[col1], self[col2]): if v1!=None and v2!=None: @@ -1203,6 +1206,12 @@ class Table: class_val = row[class_idx] score_val = row[score_idx] if class_val!=None: + if old_score_val==None: + old_score_val = score_val + if score_val!=old_score_val: + x.append(fp) + y.append(tp) + old_score_val = score_val if class_type=='bool': if class_val==True: tp += 1 @@ -1213,10 +1222,8 @@ class Table: tp += 1 else: fp += 1 - if score_val!=old_score_val: - x.append(fp) - y.append(tp) - old_score_val = score_val + x.append(fp) + y.append(tp) x = [float(v)/x[-1] for v in x] y = [float(v)/y[-1] for v in y] return x,y diff --git a/modules/base/tests/test_table.py b/modules/base/tests/test_table.py index f41418193b943e309bb0930c2448e4e84bab4e5d..617f110dd1412858a97ad3b1c9d4d5eab009e857 100644 --- a/modules/base/tests/test_table.py +++ b/modules/base/tests/test_table.py @@ -980,7 +980,7 @@ class TestTable(unittest.TestCase): self.CompareImages(img1, img2) #pl.show() - def testPlotROCSameValue(self): + def testPlotROCSameValues(self): if not HAS_MPL or not HAS_PIL: return tab = Table(['classific', 'score'], 'bf', @@ -1004,10 +1004,25 @@ class TestTable(unittest.TestCase): auc = tab.ComputeROCAUC(score_col='score', score_dir='+', class_col='classific') self.assertAlmostEquals(auc, auc_ref) - def testCalcROCAUCSameValue(self): + def testCalcROC(self): if not HAS_NUMPY: return - auc_ref = 0.66 + tab = Table(['classific', 'score'], 'ff', + classific=[0.9, 0.8, 0.7, 0.6, 0.55, 0.54, 0.53, 0.52, 0.51, 0.505, 0.4, 0.39, 0.38, 0.37, 0.36, 0.35, 0.34, 0.33, 0.30, 0.1], + score=[0.9, 0.8, 0.7, 0.6, 0.55, 0.54, 0.53, 0.52, 0.51, 0.505, 0.4, 0.39, 0.38, 0.37, 0.36, 0.35, 0.34, 0.33, 0.30, 0.1]) + auc = tab.ComputeROCAUC(score_col='score', class_col='classific', class_cutoff=0.5) + self.assertEquals(auc, 1.0) + + def testCalcROCFromFile(self): + tab = Table.Load(os.path.join('testfiles','roc_table.dat')) + auc = tab.ComputeROCAUC(score_col='prediction', class_col='reference', class_cutoff=0.4) + self.assertEquals(auc, 1.0) + + + def testCalcROCAUCSameValues(self): + if not HAS_NUMPY: + return + auc_ref = 0.685 tab = Table(['classific', 'score'], 'bf', classific=[True, True, False, True, True, True, False, False, True, False, True, False, True, False, False, False, True, False, True, False], score=[0.9, 0.8, 0.7, 0.7, 0.7, 0.7, 0.53, 0.52, 0.51, 0.505, 0.4, 0.4, 0.4, 0.4, 0.36, 0.35, 0.34, 0.33, 0.30, 0.1]) diff --git a/modules/base/tests/testfiles/roc-same-val.png b/modules/base/tests/testfiles/roc-same-val.png index 9599e26e92887cd2f97702c9e55c44dcef31557c..b459c553da696100743b9b41dc09e8dcbeaef5a3 100644 Binary files a/modules/base/tests/testfiles/roc-same-val.png and b/modules/base/tests/testfiles/roc-same-val.png differ diff --git a/modules/base/tests/testfiles/roc_table.dat b/modules/base/tests/testfiles/roc_table.dat new file mode 100644 index 0000000000000000000000000000000000000000..2de91a2691f99683a6f86658e6f71b136cc57a6d --- /dev/null +++ b/modules/base/tests/testfiles/roc_table.dat @@ -0,0 +1,152 @@ +rnum[int] reference[float] prediction[float] +1 0.0 0.0 +2 0.0 0.0 +3 0.0 0.0 +4 0.0 0.0 +5 0.2 0.2 +6 0.0 0.0 +7 0.0 0.0 +8 1.0 1.0 +9 0.2 0.2 +10 0.2 0.2 +11 0.2 0.2 +12 1.0 1.0 +13 0.0 0.0 +14 0.0 0.0 +15 0.0 0.0 +16 0.0 0.0 +17 0.0 0.0 +18 0.0 0.0 +19 0.0 0.0 +20 0.0 0.0 +21 0.0 0.0 +22 0.0 0.0 +23 0.0 0.0 +24 0.0 0.0 +25 0.0 0.0 +26 0.0 0.0 +27 0.0 0.0 +28 0.0 0.0 +29 0.0 0.0 +30 0.0 0.0 +31 0.0 0.0 +32 0.2 0.2 +33 0.0 0.0 +34 0.0 0.0 +35 0.0 0.0 +36 0.0 0.0 +37 0.0 0.0 +38 0.0 0.0 +39 0.0 0.0 +40 0.0 0.0 +41 0.0 0.0 +42 0.0 0.0 +43 0.0 0.0 +44 0.0 0.0 +45 0.0 0.0 +46 0.0 0.0 +47 0.0 0.0 +48 0.0 0.0 +49 0.0 0.0 +50 0.0 0.0 +51 0.0 0.0 +52 0.0 0.0 +53 0.0 0.0 +54 0.0 0.0 +55 0.0 0.0 +56 0.0 0.0 +57 0.0 0.0 +58 0.0 0.0 +59 0.0 0.0 +60 0.0 0.0 +61 0.0 0.0 +62 0.0 0.0 +63 0.0 0.0 +64 0.0 0.0 +65 0.0 0.0 +66 0.0 0.0 +67 0.0 0.0 +68 0.0 0.0 +69 0.0 0.0 +70 0.0 0.0 +71 0.2 0.2 +72 0.0 0.0 +73 0.2 0.2 +74 0.0 0.0 +75 0.0 0.0 +76 0.0 0.0 +77 0.0 0.0 +78 0.0 0.0 +79 0.0 0.0 +80 0.0 0.0 +81 0.0 0.0 +82 0.0 0.0 +83 0.0 0.0 +84 0.0 0.0 +85 0.0 0.0 +86 0.0 0.0 +87 0.0 0.0 +88 0.0 0.0 +89 0.0 0.0 +90 0.0 0.0 +91 0.2 0.2 +92 0.0 0.0 +93 0.0 0.0 +94 0.0 0.0 +95 0.0 0.0 +96 0.0 0.0 +97 0.0 0.0 +98 0.0 0.0 +99 0.0 0.0 +100 0.0 0.0 +101 0.0 0.0 +102 0.0 0.0 +103 0.2 0.2 +104 0.2 0.2 +105 0.5 0.5 +106 0.0 0.0 +107 0.0 0.0 +108 0.0 0.0 +109 0.0 0.0 +110 0.0 0.0 +111 0.2 0.2 +112 0.0 0.0 +113 0.5 0.5 +114 0.2 0.2 +115 1.0 1.0 +116 0.5 0.5 +117 0.5 0.5 +118 0.5 0.5 +119 1.0 1.0 +120 0.0 0.0 +121 0.0 0.0 +122 0.0 0.0 +123 0.0 0.0 +124 0.0 0.0 +125 0.0 0.0 +126 0.0 0.0 +127 0.0 0.0 +128 0.0 0.0 +129 0.0 0.0 +130 0.0 0.0 +131 0.0 0.0 +132 0.0 0.0 +133 0.0 0.0 +134 0.0 0.0 +135 0.0 0.0 +136 0.0 0.0 +137 0.0 0.0 +138 0.0 0.0 +139 0.0 0.0 +140 0.0 0.0 +141 0.0 0.0 +142 0.0 0.0 +143 0.0 0.0 +144 0.0 0.0 +145 0.0 0.0 +146 0.0 0.0 +147 0.0 0.0 +148 0.0 0.0 +149 0.0 0.0 +150 0.0 0.0 +151 0.0 0.0