Package pyvision :: Package analysis :: Module roc
[hide private]
[frames | no frames]

Source Code for Module pyvision.analysis.roc

  1  # PyVision License 
  2  # 
  3  # Copyright (c) 2006-2008 David S. Bolme 
  4  # All rights reserved. 
  5  # 
  6  # Redistribution and use in source and binary forms, with or without 
  7  # modification, are permitted provided that the following conditions 
  8  # are met: 
  9  #  
 10  # 1. Redistributions of source code must retain the above copyright 
 11  # notice, this list of conditions and the following disclaimer. 
 12  #  
 13  # 2. Redistributions in binary form must reproduce the above copyright 
 14  # notice, this list of conditions and the following disclaimer in the 
 15  # documentation and/or other materials provided with the distribution. 
 16  #  
 17  # 3. Neither name of copyright holders nor the names of its contributors 
 18  # may be used to endorse or promote products derived from this software 
 19  # without specific prior written permission. 
 20  #  
 21  #  
 22  # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
 23  # ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 
 24  # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 
 25  # A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR 
 26  # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 
 27  # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 
 28  # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 
 29  # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 
 30  # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 
 31  # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 
 32  # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
 33   
 34  import numpy as np 
 35  import pyvision as pv 
 36  import unittest 
 37       
def buildPositiveNegativeLists(names, matrix, class_equal):
    """Split the upper triangle of a pairwise score matrix by class agreement.

    Every pair of positions (i, j) with i < j is visited once, row by row.
    The score matrix[i][j] goes to the first list when
    class_equal(names[i], names[j]) is true, and to the second list otherwise.

    @param names: labels aligned with the rows/columns of matrix.
    @param matrix: scores indexable as matrix[i][j] for positions in names.
    @param class_equal: predicate called as class_equal(name_a, name_b).
    @returns: (positive, negative) lists of scores.
    """
    same_class, different_class = [], []
    total = len(names)
    for row in range(total):
        for col in range(row + 1, total):
            bucket = same_class if class_equal(names[row], names[col]) else different_class
            bucket.append(matrix[row][col])
    return same_class, different_class
48
def readCsuDistanceMatrix(directory):
    """Read a similarity matrix stored as a directory of ``.sfi`` files.

    Every ``.sfi`` file in ``directory`` holds one row of the matrix.  Rows
    are ordered by sorted filename.  Each non-blank line of a row file has the
    form ``<filename> <similarity>``.  The filenames must list every ``.sfi``
    file of the directory, in the same sorted order.  Similarities are negated
    so that the returned values can be used as distances (lower = closer).

    @param directory: path of the directory containing the ``.sfi`` files.
    @returns: (filenames, matrix) where filenames is the sorted list of
        ``.sfi`` file names and matrix[i][j] is the negated similarity read
        from line j of file i.
    @raises ValueError: if a row file is malformed, lists names out of order,
        or does not contain exactly one entry per ``.sfi`` file.
    """
    from os import listdir
    from os.path import join

    filenames = sorted(name for name in listdir(directory) if name.endswith(".sfi"))

    matrix = []
    for filename in filenames:
        row = []
        # 'with' guarantees the file is closed even if parsing fails.
        with open(join(directory, filename), 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines, e.g. a trailing newline
                if len(fields) != 2:
                    raise ValueError("%s: expected '<filename> <similarity>', got %r"
                                     % (filename, line))
                fname, sim = fields
                count = len(row)
                if count >= len(filenames) or fname != filenames[count]:
                    raise ValueError("%s: entry %d is %r, which does not match the sorted .sfi files"
                                     % (filename, count, fname))
                row.append(-float(sim))
        if len(row) != len(filenames):
            raise ValueError("%s: expected %d entries, found %d"
                             % (filename, len(filenames), len(row)))
        matrix.append(row)

    return filenames, matrix
76
class ROCPoint:
    """One operating point on an ROC curve.

    Stores the nonmatch-side values (nscore, nidx, n, far) and the match-side
    values (mscore, midx, m, frr) exactly as given.  It also stores two
    derived rates: tar = 1 - frr and trr = 1 - far.
    """

    def __init__(self, nscore, nidx, n, far, mscore, midx, m, frr):
        # Nonmatch side of the operating point.
        self.nscore = nscore
        self.nidx = nidx
        self.n = n
        self.far = far
        # Match side of the operating point.
        self.mscore = mscore
        self.midx = midx
        self.m = m
        self.frr = frr
        # Complementary rates derived from the two error rates.
        self.tar = 1.0 - frr
        self.trr = 1.0 - far

    def __str__(self):
        """Describe the point as its FRR at its FAR."""
        rates = (self.frr, self.far)
        return "ROCPoint %f FRR at %f FAR" % rates
# Sampling strategies understood by ROC.getCurve() and ROC.plot().
ROC_LOG_SAMPLED = 1      # getFAR() at log-spaced FARs from 1e-6 up to 1.0
ROC_MATCH_SAMPLED = 2    # getMatch() at every match score
ROC_PRECISE_SAMPLED = 3  # exact step curve, redundant interior points pruned
ROC_PRECISE_ALL = 4      # exact step curve, one point per score


class ROC:
    """ROC analysis of match and nonmatch scores.

    Scores are kept internally as sorted numpy arrays in "distance"
    orientation (lower = more likely a match).  Similarity inputs
    (is_distance=False) are negated on the way in.  Every score returned to
    the caller is negated back, so callers always see their own scale.
    """

    def __init__(self, match, nonmatch, is_distance=True):
        """
        @param match: scores for matching pairs.
        @param nonmatch: scores for nonmatching pairs.
        @param is_distance: True if lower scores indicate a match; False if
            the scores are similarities (higher indicates a match).
        """
        # Explicit copies so the in-place sorts never touch the caller's arrays.
        match = np.array(match, copy=True)
        nonmatch = np.array(nonmatch, copy=True)
        self.is_distance = is_distance

        if not is_distance:
            match = -match
            nonmatch = -nonmatch

        match.sort()
        nonmatch.sort()
        self.match = match
        self.nonmatch = nonmatch

    def _toCallerScore(self, score):
        """Convert an internal score (or None) back to the caller's orientation."""
        if score is None or self.is_distance:
            return score
        return -score

    def _preciseCurve(self, prune):
        """Exact step ROC: one row per score, optionally pruned.

        When prune is true, interior points whose FAR or TAR equals that of
        both neighbours are dropped.  Returns an N x 5 array of
        [score, frr, far, trr, tar] rows, with scores in the caller's orientation.
        """
        m = len(self.match)
        n = len(self.nonmatch)
        both = np.concatenate([self.match, self.nonmatch])
        is_match = np.concatenate([np.ones(m, dtype=int), np.zeros(n, dtype=int)])
        order = both.argsort()
        scores = both[order]
        is_match = is_match[order]
        # Accepting every score up to and including position i gives these rates.
        tar = is_match.cumsum() / float(m)
        far = (1 - is_match).cumsum() / float(n)

        if prune:
            keep = np.ones(len(tar), dtype=bool)  # np.bool was removed in NumPy 1.24
            keep[1:-1][(far[:-2] == far[1:-1]) & (far[2:] == far[1:-1])] = False
            keep[1:-1][(tar[:-2] == tar[1:-1]) & (tar[2:] == tar[1:-1])] = False
            scores = scores[keep]
            tar = tar[keep]
            far = far[keep]

        if not self.is_distance:
            scores = -scores  # report similarities, like the sampled methods do
        return np.array([scores, 1.0 - tar, far, 1.0 - far, tar]).T

    def getCurve(self, method=ROC_LOG_SAMPLED):
        """Sample the ROC curve.

        @param method: ROC_LOG_SAMPLED, ROC_MATCH_SAMPLED, ROC_PRECISE_SAMPLED
            or ROC_PRECISE_ALL.
        @returns: header,rows where header is ["score","frr","far","trr","tar"].
            The sampled methods return rows as a list of lists.  The precise
            methods return rows as an N x 5 numpy array.  Scores are in the
            caller's orientation.  An unrecognized method yields no rows.
        """
        header = ["score", "frr", "far", "trr", "tar"]
        rows = []

        if method == ROC_LOG_SAMPLED:
            for far in 10 ** np.arange(-6, 0.0000001, 0.01):
                point = self.getFAR(far)
                rows.append([point.nscore, point.frr, point.far, point.trr, point.tar])
        elif method == ROC_MATCH_SAMPLED:
            for score in self.match:
                point = self.getMatch(self._toCallerScore(score))
                rows.append([point.nscore, point.frr, point.far, point.trr, point.tar])
        elif method == ROC_PRECISE_SAMPLED:
            rows = self._preciseCurve(prune=True)
        elif method == ROC_PRECISE_ALL:
            rows = self._preciseCurve(prune=False)

        return header, rows

    def getFAR(self, far):
        """Operating point nearest to a requested false accept rate.

        The requested rate is rounded to a whole number of accepted nonmatch
        scores, clamped to [0, n].  The returned point's far is the rate
        actually achieved.
        """
        match = self.match
        nonmatch = self.nonmatch
        m = len(match)
        n = len(nonmatch)

        # Clamp so out-of-range requests cannot index from the array's end.
        nidx = min(max(int(round(far * n)), 0), n)
        far = float(nidx) / n
        # The first rejected nonmatch score is the decision threshold.
        nscore = nonmatch[nidx] if nidx < n else None

        # Match scores strictly below the threshold are accepted.
        if nscore is not None:
            midx = np.searchsorted(match, nscore, side='left')
        else:
            midx = m
        frr = 1.0 - float(midx) / m
        mscore = match[midx] if midx < m else None

        return ROCPoint(self._toCallerScore(nscore), nidx, n, far,
                        self._toCallerScore(mscore), midx, m, frr)

    def getFRR(self, frr):
        """Operating point nearest to a requested false reject rate.

        The requested rate is rounded to a whole number of accepted match
        scores, clamped to [0, m].  The returned point's frr is the rate
        actually achieved.
        """
        match = self.match
        nonmatch = self.nonmatch
        m = len(match)
        n = len(nonmatch)

        midx = min(max(int(round((1.0 - frr) * m)), 0), m)
        frr = 1.0 - float(midx) / m

        if midx == 0:
            # No match is accepted, so no nonmatch is accepted either.
            # (Previously match[-1] was used here, i.e. the largest score.)
            mscore = None
            nidx = 0
            nscore = None
        else:
            # The largest accepted match score is the decision threshold.
            mscore = match[midx - 1]
            nidx = np.searchsorted(nonmatch, mscore)
            nscore = nonmatch[nidx - 1] if nidx > 0 else None
            assert nscore is None or mscore >= nscore
        far = float(nidx) / n

        return ROCPoint(self._toCallerScore(nscore), nidx, n, far,
                        self._toCallerScore(mscore), midx, m, frr)

    def getEER(self):
        """Approximate the equal error rate.

        Returns the FAR at the first point of the precise curve where FAR
        exceeds FRR, or the last FAR if that never happens.
        """
        _, curve = self.getCurve(method=ROC_PRECISE_SAMPLED)

        for _, frr, far, _, _ in curve:
            if far > frr:
                break

        return far

    def getMatch(self, mscore):
        """Operating point obtained by using mscore as the threshold.

        mscore is in the caller's orientation.  Scores strictly below it (in
        distance orientation) are accepted.
        """
        if not self.is_distance:
            mscore = -mscore
        match = self.match
        nonmatch = self.nonmatch
        m = len(match)
        n = len(nonmatch)

        midx = np.searchsorted(match, mscore)
        frr = 1.0 - float(midx) / m

        nidx = np.searchsorted(nonmatch, mscore)
        far = float(nidx) / n
        nscore = nonmatch[nidx - 1] if nidx > 0 else None

        assert nscore is None or mscore >= nscore

        return ROCPoint(self._toCallerScore(nscore), nidx, n, far,
                        self._toCallerScore(mscore), midx, m, frr)

    def results(self):
        """Summarize the ROC in a pv.Table.

        Rows 0-2 hold the TAR at FAR = 0.001, 0.01 and 0.1.  Row 3 holds the
        EER and row 4 the AUC.
        """
        table = pv.Table(default_value="")

        for row, far in enumerate((0.001, 0.01, 0.1)):
            table[row, 'FAR'] = far
            table[row, 'TAR'] = self.getFAR(far).tar

        table[3, 'EER'] = self.getEER()
        table[4, 'AUC'] = self.getAUC()

        return table

    def _curvePoints(self, method):
        """[far, tar] pairs of the curve, bracketed by (0,0) and (1,1)."""
        _, curve = self.getCurve(method=method)
        return [[0.0, 0.0]] + [[x, y] for _, _, x, _, y in curve] + [[1.0, 1.0]]

    def plot(self, plot, method=ROC_PRECISE_SAMPLED, **kwargs):
        """Draw TAR against FAR on plot; kwargs are passed to plot.lines()."""
        plot.lines(self._curvePoints(method), **kwargs)

    def getAUC(self, **kwargs):
        """Area under the TAR-vs-FAR curve, computed with the trapezoid rule.

        kwargs is accepted for backward compatibility and ignored.
        """
        points = self._curvePoints(ROC_PRECISE_SAMPLED)
        auc = 0.0
        for i in range(len(points) - 1):
            y = 0.5 * (points[i][1] + points[i + 1][1])
            dx = points[i + 1][0] - points[i][0]
            auc += y * dx
        return auc
322 323
class ROCTest(unittest.TestCase):
    """Regression tests for pv.ROC on a fixed set of 100 match and 100 nonmatch distances."""

    def setUp(self):
        # Fixed sample data; the expected values asserted below were computed
        # from exactly these numbers, so they must not be edited.
        self.match = [-0.3126333774819825, 1.0777130777174635, 1.1045667643589598, 1.022042510130833, -0.58552060836929942,
                      -0.59682041549981257, -1.4873074501595509, -0.49958344415133116, 0.36814022366653204, 0.9292572191289511,
                      0.56740023418734642, -1.3117888037744228, 1.7695340517922449, 0.4098641799520919, 0.43642273019233646,
                      -0.14893755966202349, -1.3490978540595631, 0.18192684849996424, 1.4547096287864199, 1.1698331636208563,
                      0.40439133210485323, -1.2333503530027063, -0.1765228044654879, 0.070450455376130774, -0.85038212096409027,
                      1.6679580794589872, 1.1589669301436729, 1.1756719870079611, -1.0799654160891785, -0.11025751625199756,
                      0.098294009710337069, -0.49832134960232527, -1.4626964355118197, 1.1064208531539006, -0.4251178714268497,
                      1.297279496554774, -1.9318553699779215, -1.2787762925010133, 0.92426958166955997, 0.38501300779378478,
                      -1.7823019361063408, -0.43568112010605503, 0.65785964631537774, -0.63359960475947019, -0.02194247979690072,
                      -0.55595471945130093, -0.8043184500851891, 0.13759846217215868, 0.12524112107182517, 0.48665310853849575,
                      -1.2285460272311253, -1.7721136485547013, 1.4552123210449597, -0.38319646950962838, 0.96456771860484702,
                      0.24739740122504011, -0.38962322566309304, -0.49974207901118639, -1.4515801398271369, 1.0736452978649289,
                      0.55985898085565033, -0.43789279416506094, 0.48021091037667496, 1.8414133735020126, 1.8695789066643793,
                      0.56021842531028732, -0.678323243576336, -0.94407219986362523, -0.33987307773274095, -0.71991668517746144,
                      1.0625139713435376, -1.8026944722350828, 1.8903853852837578, 0.2475468598692494, -0.70834534737086463,
                      -0.62816536381195498, 0.37297277354517611, 0.034474071621219016, 0.47274333081191594, -2.3662542473841786,
                      1.8813720711684221, -0.29916037509951754, -0.57712027528715559, 0.27431335749394231, 0.46414272602323764,
                      -0.61367838919068374, -0.48441048748772131, 0.7807315137448595, 0.5057878952931828, -0.33232362411214894,
                      -0.77896199497583019, 0.81373804337730904, -1.9957402084896527, 1.7976405059518497, 1.2302892852847949,
                      0.67699419193473098, -0.51325483082725243, 0.857942641750577, 1.4295866533235857, -0.76819949833721834]

        self.nonmatch = [1.417052501689489, -0.043563190732366364, 1.6036891630756054, 0.66248145163751671, 0.052384028443254405,
                         0.59629061593353161, 0.82947993373378237, 1.115113519426044, 0.67551158941676637, 1.8422107418890203,
                         0.84941135662024392, 1.1996391852657751, 0.94030154845981673, 2.3269103026602771, -0.030603020790364033,
                         1.258988565904706, 2.9637747860603456, 1.8173999730963109, 0.71892491243068934, 0.81740037138666277,
                         1.7601258039962009, 3.1707523951166898, 0.66982205389142613, 1.6097271105344255, 1.189734646321116,
                         0.22708332837080747, 0.84698202914050347, 1.7635878414797439, 2.3830213681725447, 2.5497367162352944,
                         2.635862209152271, -0.21290078686666103, 1.4048627271264558, 0.72941226308255924, 0.85692961327062467,
                         0.97820944194897774, -0.15500601865255503, 0.58435763771835081, 2.5992330339800831, -0.87305656967588074,
                         0.69311232136547551, 1.1302262899327531, 0.71334154902008384, 0.35695476951005345, -0.5187124559973717,
                         2.024435812626129, 0.26963199371831936, -0.46510024343728285, -0.19970133295471326, 2.0355468834785726,
                         0.82313200923780616, 0.30440704254838935, 0.93632925544825862, 1.9575547911114448, 1.2245628328633855,
                         1.0878755116923233, 2.1602536867629665, 0.04070893565830036, 2.3369676117570961, 1.9724448182299648,
                         1.9850705023975075, 1.015833476781514, 2.4223167168334743, 0.061707944792565028, 0.94626273945251693,
                         1.210865335077099, 1.1145727311637936, 2.8519712553054348, 0.93533306721111675, -0.0060786748305075022,
                         1.9322277720024843, 0.65603343285714444, 1.194849545457592, 0.27772775162736463, 0.078490050192145722,
                         -1.4721630242727111, 1.854285772101625, 1.6112593328478453, 1.8560106579121847, 2.540591694748537,
                         1.7772416902829931, -0.20781473501608816, 1.5221307283377219, 0.1579604472392464, -0.30160614059311297,
                         0.80127729857699337, 1.269704867230514, 2.0490141432761941, 2.0273848755661028, 1.0147875805479856,
                         -0.06676206771791926, 2.1662293957716994, 2.1413537986988493, 0.9046180315857989, -1.0291168800124986,
                         1.0301894509766261, 1.1930459134315883, 0.66868219673238327, 0.43346537494032156, -1.0433576612271738]

    def testFAR(self):
        curve = pv.ROC(self.match, self.nonmatch, is_distance=True)
        # (requested FAR, expected TAR) pairs, checked in order.
        for requested, expected_tar in ((0.1, 0.45), (0.01, 0.15)):
            point = curve.getFAR(requested)
            self.assertAlmostEqual(point.far, requested)
            self.assertAlmostEqual(point.tar, expected_tar)

    def testFRR(self):
        curve = pv.ROC(self.match, self.nonmatch, is_distance=True)
        # (requested FRR, expected FAR, expected TAR) triples, checked in order.
        for requested, expected_far, expected_tar in ((0.5, 0.18, 0.50), (0.80, 0.04, 0.20)):
            point = curve.getFRR(requested)
            self.assertAlmostEqual(point.far, expected_far)
            self.assertAlmostEqual(point.tar, expected_tar)

    def testEER(self):
        curve = pv.ROC(self.match, self.nonmatch, is_distance=True)
        self.assertAlmostEqual(curve.getEER(), 0.29)

    def testAUC(self):
        curve = pv.ROC(self.match, self.nonmatch, is_distance=True)
        self.assertAlmostEqual(curve.getAUC(), 0.7608)

    def testPlot(self):
        curve = pv.ROC(self.match, self.nonmatch, is_distance=True)
        canvas = pv.Plot()
        curve.plot(canvas, method=ROC_PRECISE_SAMPLED, color="red", width=5)
        curve.plot(canvas, method=ROC_PRECISE_ALL, color='black')
        # Diagonal FAR + TAR = 1; the EER point lies on it.
        canvas.lines([[0, 1], [1, 0]])
        equal_error = curve.getEER()
        canvas.point([equal_error, 1 - equal_error])
        canvas.asImage()
413