1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34 import numpy as np
35 import pyvision as pv
36 import unittest
37
39 positive = []
40 negative = []
41 for i in range(len(names)):
42 for j in range(i+1,len(names)):
43 if class_equal(names[i],names[j]):
44 positive.append(matrix[i][j])
45 else:
46 negative.append(matrix[i][j])
47 return positive,negative
48
50 from os import listdir
51 from os.path import join
52
53 filenames = []
54 for each in listdir(directory):
55 if each[-4:] == ".sfi":
56 filenames.append(each)
57
58 filenames.sort()
59 matrix = []
60 for filename in filenames:
61 f = open(join(directory,filename),'r')
62 row = []
63 count = 0
64 for line in f:
65 fname,sim = line.split()
66 assert fname == filenames[count]
67 sim = -float(sim)
68 row.append(sim)
69 count += 1
70 f.close()
71 assert len(row) == len(filenames)
72 matrix.append(row)
73 assert len(matrix) == len(filenames)
74
75 return filenames,matrix
76
def __init__(self,nscore,nidx,n,far,mscore,midx,m,frr):
    """Hold a single ROC operating point.

    nscore: nonmatch-side threshold score, or None.
    nidx, n: number of accepted nonmatch scores and total nonmatch count.
    far: false accept rate at this point.
    mscore: match-side score, or None.
    midx, m: number of accepted match scores and total match count.
    frr: false reject rate at this point.

    The complementary rates tar = 1 - frr and trr = 1 - far are derived here.
    """
    self.nscore = nscore
    self.nidx = nidx
    self.n = n
    self.far = far
    self.mscore = mscore
    self.midx = midx
    self.m = m
    self.frr = frr
    # True accept / true reject rates are the complements of the error rates.
    self.tar = 1.0 - frr
    self.trr = 1.0 - far
82
84 return "ROCPoint %f FRR at %f FAR"%(self.frr,self.far)
85
# Sampling strategies understood by getCurve(method=...):
#   ROC_LOG_SAMPLED     -- target FARs on a log grid from 1e-6 to 1
#   ROC_MATCH_SAMPLED   -- one point per match score
#   ROC_PRECISE_SAMPLED -- exact curve, interior points of flat runs removed
#   ROC_PRECISE_ALL     -- exact curve, one point per score
ROC_LOG_SAMPLED, ROC_MATCH_SAMPLED, ROC_PRECISE_SAMPLED, ROC_PRECISE_ALL = 1, 2, 3, 4
90
92
93
94
def __init__(self,match,nonmatch,is_distance=True):
    """Build an ROC from match and nonmatch score lists.

    Both inputs are copied into NumPy arrays and sorted ascending, so the
    caller's data is never modified. The curve computations accept the low
    end of the sorted scores; when is_distance is False (higher score means
    a better match, e.g. similarity scores) every score is negated before
    sorting so that the same convention applies.
    """
    self.is_distance = is_distance
    prepared = []
    for raw in (match, nonmatch):
        values = np.array(raw)      # np.array copies its input by default
        if not is_distance:
            values = -values        # flip similarities into distance order
        values.sort()
        prepared.append(values)
    self.match, self.nonmatch = prepared
107
108
110 """
111 returns header,rows
112 """
113 header = ["score","frr","far","trr","tar"]
114 rows = []
115
116 if method == ROC_LOG_SAMPLED:
117 for far in 10**np.arange(-6,0.0000001,0.01):
118 point = self.getFAR(far)
119
120 row = [point.nscore,point.frr,point.far,point.trr,point.tar]
121 rows.append(row)
122
123 if method == ROC_MATCH_SAMPLED:
124 for score in self.match:
125 if self.is_distance:
126 point = self.getMatch(score)
127 else:
128 point = self.getMatch(-score)
129 row = [point.nscore,point.frr,point.far,point.trr,point.tar]
130 rows.append(row)
131
132 if method == ROC_PRECISE_SAMPLED:
133 m = len(self.match)
134 n = len(self.nonmatch)
135 both = np.concatenate([self.match,self.nonmatch])
136 matches = np.array(len(self.match)*[1]+len(self.nonmatch)*[0])
137 nonmatches = np.array(len(self.match)*[0]+len(self.nonmatch)*[1])
138 order = both.argsort()
139 scores = both[order]
140 matches = matches[order]
141 nonmatches = nonmatches[order]
142 tar = matches.cumsum()/float(m)
143 far = nonmatches.cumsum()/float(n)
144 keep = np.ones(len(tar),dtype=np.bool)
145 keep[1:-1][(far[:-2] == far[1:-1]) & (far[2:] == far[1:-1])] = False
146 keep[1:-1][(tar[:-2] == tar[1:-1]) & (tar[2:] == tar[1:-1])] = False
147 scores = scores[keep]
148 tar = tar[keep]
149 far = far[keep]
150 rows = np.array([scores,1.0-tar,far,1.0-far,tar]).T
151
152 if method == ROC_PRECISE_ALL:
153 m = len(self.match)
154 n = len(self.nonmatch)
155 both = np.concatenate([self.match,self.nonmatch])
156 matches = np.array(len(self.match)*[1]+len(self.nonmatch)*[0])
157 nonmatches = np.array(len(self.match)*[0]+len(self.nonmatch)*[1])
158 order = both.argsort()
159 scores = both[order]
160 matches = matches[order]
161 nonmatches = nonmatches[order]
162 tar = matches.cumsum()/float(m)
163 far = nonmatches.cumsum()/float(n)
164 rows = np.array([scores,1.0-tar,far,1.0-far,tar]).T
165
166 return header,rows
167
168
170 match = self.match
171 nonmatch = self.nonmatch
172
173
174 m = len(match)
175 n = len(nonmatch)
176
177 nidx = int(round(far*n))
178 far = float(nidx)/n
179 if nidx >= len(nonmatch):
180 nscore = None
181
182
183 else:
184 nscore = nonmatch[nidx]
185
186 if nscore != None:
187 midx = np.searchsorted(match,nscore,side='left')
188 else:
189 midx = m
190
191 frr = 1.0-float(midx)/m
192 if midx >= len(match):
193 mscore = None
194 else:
195 mscore = match[midx]
196
197
198
199
200
201
202
203 if self.is_distance:
204 return ROCPoint(nscore,nidx,n,far,mscore,midx,m,frr)
205 else:
206 if nscore != None:
207 nscore = -nscore
208 if mscore != None:
209 mscore = -mscore
210 return ROCPoint(nscore,nidx,n,far,mscore,midx,m,frr)
211
213 match = self.match
214 nonmatch = self.nonmatch
215
216 m = len(match)
217 n = len(nonmatch)
218
219 midx = int(round((1.0-frr)*m))
220 frr = 1.0 - float(midx)/m
221 if midx >= len(match):
222 mscore = None
223 else:
224 mscore = match[midx-1]
225
226 nidx = np.searchsorted(nonmatch,mscore)
227 far = float(nidx)/n
228 if nidx-1 < 0:
229 nscore = None
230 else:
231 nscore = nonmatch[nidx-1]
232
233 assert nscore == None or mscore >= nscore
234
235 if self.is_distance:
236 return ROCPoint(nscore,nidx,n,far,mscore,midx,m,frr)
237 else:
238 if nscore != None:
239 nscore = -nscore
240 if mscore != None:
241 mscore = -mscore
242 return ROCPoint(nscore,nidx,n,far,mscore,midx,m,frr)
243
244
246 _,curve = self.getCurve(method=ROC_PRECISE_SAMPLED)
247
248 for _,frr,far,_,_ in curve:
249 if far > frr:
250 break
251
252 return far
253
254
256 if not self.is_distance:
257 mscore = -mscore
258 match = self.match
259 nonmatch = self.nonmatch
260
261 m = len(match)
262 n = len(nonmatch)
263
264 midx = np.searchsorted(match,mscore)
265
266 frr = 1.0 - float(midx)/m
267
268 nidx = np.searchsorted(nonmatch,mscore)
269 far = float(nidx)/n
270 if nidx-1 < 0:
271 nscore = None
272 else:
273 nscore = nonmatch[nidx-1]
274
275 assert nscore == None or mscore >= nscore
276
277 if self.is_distance:
278 return ROCPoint(nscore,nidx,n,far,mscore,midx,m,frr)
279 else:
280 if nscore != None:
281 nscore = -nscore
282 if mscore != None:
283 mscore = -mscore
284 return ROCPoint(nscore,nidx,n,far,mscore,midx,m,frr)
285
286
306
311
312
322
323
325
# Fixed score fixture used by the ROC tests: 100 match scores and 100
# nonmatch scores. The exact FAR/TAR/EER/AUC values asserted by those tests
# are tied to these numbers, so changing any value invalidates them.
self.match = [-0.3126333774819825, 1.0777130777174635, 1.1045667643589598, 1.022042510130833, -0.58552060836929942,
    -0.59682041549981257, -1.4873074501595509, -0.49958344415133116, 0.36814022366653204, 0.9292572191289511,
    0.56740023418734642, -1.3117888037744228, 1.7695340517922449, 0.4098641799520919, 0.43642273019233646,
    -0.14893755966202349, -1.3490978540595631, 0.18192684849996424, 1.4547096287864199, 1.1698331636208563,
    0.40439133210485323, -1.2333503530027063, -0.1765228044654879, 0.070450455376130774, -0.85038212096409027,
    1.6679580794589872, 1.1589669301436729, 1.1756719870079611, -1.0799654160891785, -0.11025751625199756,
    0.098294009710337069, -0.49832134960232527, -1.4626964355118197, 1.1064208531539006, -0.4251178714268497,
    1.297279496554774, -1.9318553699779215, -1.2787762925010133, 0.92426958166955997, 0.38501300779378478,
    -1.7823019361063408, -0.43568112010605503, 0.65785964631537774, -0.63359960475947019, -0.02194247979690072,
    -0.55595471945130093, -0.8043184500851891, 0.13759846217215868, 0.12524112107182517, 0.48665310853849575,
    -1.2285460272311253, -1.7721136485547013, 1.4552123210449597, -0.38319646950962838, 0.96456771860484702,
    0.24739740122504011, -0.38962322566309304, -0.49974207901118639, -1.4515801398271369, 1.0736452978649289,
    0.55985898085565033, -0.43789279416506094, 0.48021091037667496, 1.8414133735020126, 1.8695789066643793,
    0.56021842531028732, -0.678323243576336, -0.94407219986362523, -0.33987307773274095, -0.71991668517746144,
    1.0625139713435376, -1.8026944722350828, 1.8903853852837578, 0.2475468598692494, -0.70834534737086463,
    -0.62816536381195498, 0.37297277354517611, 0.034474071621219016, 0.47274333081191594, -2.3662542473841786,
    1.8813720711684221, -0.29916037509951754, -0.57712027528715559, 0.27431335749394231, 0.46414272602323764,
    -0.61367838919068374, -0.48441048748772131, 0.7807315137448595, 0.5057878952931828, -0.33232362411214894,
    -0.77896199497583019, 0.81373804337730904, -1.9957402084896527, 1.7976405059518497, 1.2302892852847949,
    0.67699419193473098, -0.51325483082725243, 0.857942641750577, 1.4295866533235857, -0.76819949833721834]

self.nonmatch = [1.417052501689489, -0.043563190732366364, 1.6036891630756054, 0.66248145163751671, 0.052384028443254405,
    0.59629061593353161, 0.82947993373378237, 1.115113519426044, 0.67551158941676637, 1.8422107418890203,
    0.84941135662024392, 1.1996391852657751, 0.94030154845981673, 2.3269103026602771, -0.030603020790364033,
    1.258988565904706, 2.9637747860603456, 1.8173999730963109, 0.71892491243068934, 0.81740037138666277,
    1.7601258039962009, 3.1707523951166898, 0.66982205389142613, 1.6097271105344255, 1.189734646321116,
    0.22708332837080747, 0.84698202914050347, 1.7635878414797439, 2.3830213681725447, 2.5497367162352944,
    2.635862209152271, -0.21290078686666103, 1.4048627271264558, 0.72941226308255924, 0.85692961327062467,
    0.97820944194897774, -0.15500601865255503, 0.58435763771835081, 2.5992330339800831, -0.87305656967588074,
    0.69311232136547551, 1.1302262899327531, 0.71334154902008384, 0.35695476951005345, -0.5187124559973717,
    2.024435812626129, 0.26963199371831936, -0.46510024343728285, -0.19970133295471326, 2.0355468834785726,
    0.82313200923780616, 0.30440704254838935, 0.93632925544825862, 1.9575547911114448, 1.2245628328633855,
    1.0878755116923233, 2.1602536867629665, 0.04070893565830036, 2.3369676117570961, 1.9724448182299648,
    1.9850705023975075, 1.015833476781514, 2.4223167168334743, 0.061707944792565028, 0.94626273945251693,
    1.210865335077099, 1.1145727311637936, 2.8519712553054348, 0.93533306721111675, -0.0060786748305075022,
    1.9322277720024843, 0.65603343285714444, 1.194849545457592, 0.27772775162736463, 0.078490050192145722,
    -1.4721630242727111, 1.854285772101625, 1.6112593328478453, 1.8560106579121847, 2.540591694748537,
    1.7772416902829931, -0.20781473501608816, 1.5221307283377219, 0.1579604472392464, -0.30160614059311297,
    0.80127729857699337, 1.269704867230514, 2.0490141432761941, 2.0273848755661028, 1.0147875805479856,
    -0.06676206771791926, 2.1662293957716994, 2.1413537986988493, 0.9046180315857989, -1.0291168800124986,
    1.0301894509766261, 1.1930459134315883, 0.66868219673238327, 0.43346537494032156, -1.0433576612271738]
368
369
# getFAR at two target rates; each case is (requested far, expected tar).
roc = pv.ROC(self.match, self.nonmatch, is_distance=True)
for target_far, expected_tar in ((0.1, 0.45), (0.01, 0.15)):
    point = roc.getFAR(target_far)
    self.assertAlmostEqual(point.far, target_far)
    self.assertAlmostEqual(point.tar, expected_tar)
380
# getFRR at two target rates; each case is (requested frr, expected far, expected tar).
roc = pv.ROC(self.match, self.nonmatch, is_distance=True)
for wanted_frr, expected_far, expected_tar in ((0.5, 0.18, 0.50), (0.80, 0.04, 0.20)):
    point = roc.getFRR(wanted_frr)
    self.assertAlmostEqual(point.far, expected_far)
    self.assertAlmostEqual(point.tar, expected_tar)
391
# Equal error rate of the fixed score set.
eer = pv.ROC(self.match, self.nonmatch, is_distance=True).getEER()
self.assertAlmostEqual(eer, 0.29)
397
# getAUC on the fixed score set.
area = pv.ROC(self.match, self.nonmatch, is_distance=True).getAUC()
self.assertAlmostEqual(area, 0.7608)
402
403
# Smoke test (no assertions): draw both precise curves, a reference line and
# the EER point, then render the plot to an image.
roc = pv.ROC(self.match, self.nonmatch, is_distance=True)
canvas = pv.Plot()
for curve_method, style in ((ROC_PRECISE_SAMPLED, {"color": "red", "width": 5}),
                            (ROC_PRECISE_ALL, {"color": 'black'})):
    roc.plot(canvas, method=curve_method, **style)
canvas.lines([[0, 1], [1, 0]])
eer = roc.getEER()
canvas.point([eer, 1 - eer])
canvas.asImage()
413