diff --git a/modules/base/pymod/table.py b/modules/base/pymod/table.py
index 9338a1f5d776ee9dcc2ca26d4e4bb419cf6024f2..87a425ebcf9e61852e8709ca26ab6bee526b0ce7 100644
--- a/modules/base/pymod/table.py
+++ b/modules/base/pymod/table.py
@@ -907,6 +907,42 @@ class Table:
       except:
         return None
 
+  def SpearmanCorrel(self, col1, col2):
+    """
+    Calculate the Spearman correlation coefficient between col1 and col2, only
+    taking rows into account where both of the values are not equal to None. If
+    there are not enough data points to calculate a correlation coefficient,
+    None is returned.
+
+    Raises ImportError (after logging an error) if scipy is not installed.
+    """
+    # guard only the import, so that an ImportError raised by any later
+    # statement is not misreported as a missing scipy installation
+    try:
+      import scipy.stats.mstats
+    except ImportError:
+      LogError("Function needs scipy, but I could not import it.")
+      raise
+
+    # column names are only translated to indices if both arguments are names
+    if IsStringLike(col1) and IsStringLike(col2):
+      col1 = self.GetColIndex(col1)
+      col2 = self.GetColIndex(col2)
+
+    # keep only the rows in which both columns hold a value
+    vals1, vals2 = [], []
+    for v1, v2 in zip(self[col1], self[col2]):
+      if v1 is not None and v2 is not None:
+        vals1.append(v1)
+        vals2.append(v2)
+
+    try:
+      return scipy.stats.mstats.spearmanr(vals1, vals2)[0]
+    except Exception:
+      # scipy could not compute a coefficient; catching Exception rather than
+      # using a bare except lets KeyboardInterrupt and SystemExit propagate
+      return None
+
   def Save(self, stream):
     """
     Save the table to stream or filename
diff --git a/modules/base/tests/test_table.py b/modules/base/tests/test_table.py
index 617f110dd1412858a97ad3b1c9d4d5eab009e857..4220de12fd65d5caa7d83b7d8a9973916dee24b4 100644
--- a/modules/base/tests/test_table.py
+++ b/modules/base/tests/test_table.py
@@ -10,6 +10,7 @@ from ost.table import *
 import ost
 
 HAS_NUMPY=True
+HAS_SCIPY=True
 HAS_MPL=True
 HAS_PIL=True
 try:
@@ -18,13 +19,19 @@ except ImportError:
   HAS_NUMPY=False
   print "Could not find numpy: ignoring some table class unit tests"
 
+try:
+  import scipy
+except ImportError:
+  HAS_SCIPY=False
+  print "Could not find scipy: ignoring some table class unit tests"
+
 try:
   import matplotlib
   matplotlib.use('Agg')
 except ImportError:
   HAS_MPL=False
   print "Could not find matplotlib: ignoring some table class unit tests"
-  
+
 try:
   import Image
   import ImageChops
@@ -1156,7 +1163,25 @@ class TestTable(unittest.TestCase):
self.assertEquals(tab.GetUnique('second', ignore_nan=False), [3,None,9,4,5]) self.assertEquals(tab.GetUnique('third'), [2.2, 3.3, 6.3]) self.assertEquals(tab.GetUnique('third', ignore_nan=False), [None, 2.2, 3.3, 6.3]) - + + def testCorrel(self): + tab = self.CreateTestTable() + self.assertEquals(tab.Correl('second','third'), None) + tab.AddRow(['foo',4, 3.3]) + tab.AddRow([None,5, 6.3]) + tab.AddRow([None,8, 2]) + self.assertAlmostEquals(tab.Correl('second','third'), -0.4954982578) + + def testSpearmanCorrel(self): + if not HAS_SCIPY: + return + tab = self.CreateTestTable() + self.assertEquals(tab.SpearmanCorrel('second','third'), None) + tab.AddRow(['foo',4, 3.3]) + tab.AddRow([None,5, 6.3]) + tab.AddRow([None,8, 2]) + self.assertAlmostEquals(tab.SpearmanCorrel('second','third'), -0.316227766) + if __name__ == "__main__": try: unittest.main()