1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34 import tempfile
35 import random
36 import csv
37 import pickle
38 import os
39 import unittest
40 import numpy as np
41 try:
42 import svm
43 except:
44 import libsvm.svm as svm
45
46 import pyvision as pv
47 from pyvision.vector.VectorClassifier import VectorClassifier, TYPE_TWOCLASS, TYPE_MULTICLASS, TYPE_REGRESSION
48 from pyvision.analysis.Table import Table
49
50
# libsvm formulation identifiers, re-exported so callers can select an SVM
# type without importing libsvm themselves.
TYPE_C_SVC, TYPE_NU_SVC = svm.C_SVC, svm.NU_SVC
TYPE_EPSILON_SVR, TYPE_NU_SVR = svm.EPSILON_SVR, svm.NU_SVR

# Shorthand defaults for classification and regression.
TYPE_SVC = TYPE_C_SVC
TYPE_SVR = TYPE_NU_SVR

# Kernel selectors compared against SVM.kernel when training.
KERNEL_LINEAR = 'LINEAR'
KERNEL_RBF = 'RBF'
62
63
class SVM(VectorClassifier):
    '''
    Support vector machine classifier or regressor built on the libsvm
    python bindings.

    Training tunes the libsvm hyper-parameters by grid search. The samples
    are shuffled and split into a fitting part and a held-out validation
    part. One model is fit per grid point: C for the linear kernel, or C and
    gamma for the RBF kernel. The model that scores best on the validation
    part is kept in self.svm.

    Classifiers are scored by the fraction of validation samples predicted
    exactly. Regressors are scored by the mean squared error on the
    validation samples.
    '''

    def __init__(self, svm_type=TYPE_SVC, kernel=KERNEL_RBF, svr_epsilon=0.1,
                 nu=0.5, random_seed=None, validation_size=0.33, **kwargs):
        '''
        Create an svm.

        Make sure you choose "classification" or "regression" through
        svm_type. TYPE_C_SVC / TYPE_NU_SVC make a TYPE_MULTICLASS
        VectorClassifier; any other value makes a TYPE_REGRESSION one.

        kernel          -- KERNEL_RBF or KERNEL_LINEAR.
        svr_epsilon     -- passed to libsvm as the parameter p.
        nu              -- passed to libsvm as the parameter nu.
        random_seed     -- seed for the train/validation shuffle. None gives
                           an unseeded, non-reproducible shuffle.
        validation_size -- a float in (0, 1) is the fraction of the samples
                           held out for tuning. An int is the number of
                           samples held out; it must be at least 1 and less
                           than the number of samples.

        Also passes keyword args to VectorClassifier.
        '''
        self.svm = None
        self.svm_type = svm_type
        self.kernel = kernel
        self.epsilon = svr_epsilon
        self.nu = nu
        self.random_seed = random_seed
        self.validation_size = validation_size

        if svm_type in (TYPE_C_SVC, TYPE_NU_SVC):
            VectorClassifier.__init__(self, TYPE_MULTICLASS, **kwargs)
        else:
            VectorClassifier.__init__(self, TYPE_REGRESSION, **kwargs)

    # ------------------------------------------------------------------
    # Pickling
    # ------------------------------------------------------------------

    def __getstate__(self):
        '''
        This function is necessary for pickling.

        Returns a copy of the instance dict. The trained libsvm model is
        replaced by the text of libsvm's own save file. An untrained SVM
        (self.svm is None) is stored with svm = None instead of failing.
        '''
        state = dict(self.__dict__)
        if state.get('svm') is not None:
            state['svm'] = self._modelToText(state['svm'])
        return state

    def __setstate__(self, state):
        '''
        This function is necessary for pickling.

        Restores the instance dict, rebuilding the libsvm model from the
        text written by __getstate__.
        '''
        state = dict(state)
        if state.get('svm') is not None:
            state['svm'] = self._modelFromText(state['svm'])
        self.__dict__.update(state)

    @staticmethod
    def _modelToText(model):
        '''Return the contents of libsvm's save file for model.'''
        # mkstemp creates the file itself. The deprecated mktemp only picks
        # a name, which another process could claim before it is used.
        handle, filename = tempfile.mkstemp()
        os.close(handle)
        try:
            model.save(filename)
            with open(filename) as model_file:
                return model_file.read()
        finally:
            os.remove(filename)

    @staticmethod
    def _modelFromText(text):
        '''Load a libsvm model from text produced by _modelToText.'''
        handle, filename = tempfile.mkstemp()
        try:
            with os.fdopen(handle, 'w') as model_file:
                model_file.write(text)
            return svm.svm_model(filename)
        finally:
            os.remove(filename)

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    def trainClassifer(self, labels, vectors, ilog=None, verbose=False,
                       callback=None,
                       C_range=2.0**np.arange(-5, 16, 1),
                       G_range=2.0**np.arange(-15, 4, 1)):
        '''
        Do not call this function instead call train.

        Converts every vector to a list of python floats and runs the grid
        search for the configured svm_type and kernel.

        C_range (and G_range for the RBF kernel) are the C and gamma values
        searched. callback, if given, receives the integer percentage of
        grid points finished after each one. ilog is accepted but not used
        here.

        Raises NotImplementedError for an unknown svm_type / kernel pair.
        '''
        self.training_size = len(labels)

        if verbose:
            print("")
            print("Training the SVM")

        vectors = [self._asFloats(vec) for vec in vectors]

        classification = self.svm_type in (TYPE_C_SVC, TYPE_NU_SVC)
        regression = self.svm_type in (TYPE_NU_SVR, TYPE_EPSILON_SVR)

        if classification and self.kernel == KERNEL_RBF:
            if verbose:
                print("TRAINING SVC RBF")
            self.train_SVC_RBF(labels, vectors, verbose, C_range, G_range, callback=callback)
        elif classification and self.kernel == KERNEL_LINEAR:
            if verbose:
                print("TRAINING SVC Linear")
            self.train_SVC_Linear(labels, vectors, verbose, C_range, callback=callback)
        elif regression and self.kernel == KERNEL_RBF:
            if verbose:
                print("TRAINING SVR RBF")
            self.train_SVR_RBF(labels, vectors, verbose, C_range, G_range, callback=callback)
        elif regression and self.kernel == KERNEL_LINEAR:
            if verbose:
                print("TRAINING SVR Linear")
            self.train_SVR_Linear(labels, vectors, verbose, C_range, callback=callback)
        else:
            raise NotImplementedError("Unknown SVM type or kernel")

    def train_SVC_RBF(self, labels, vectors, verbose, C_range, G_range, callback=None):
        '''
        Private use only.

        Tune C and gamma of an RBF classifier by maximizing the validation
        rate. Sets self.C, self.gamma, self.tuning_rate, self.svm,
        self.training_info and self.training_table.
        '''
        best, self.svm = self._gridSearch(labels, vectors, verbose, C_range, G_range, False, callback)
        self.C, self.gamma, self.tuning_rate = best

    def train_SVR_RBF(self, labels, vectors, verbose, C_range, G_range, callback=None):
        '''
        Private use only.

        Tune C and gamma of an RBF regressor by minimizing the validation
        mean squared error. Sets self.C, self.gamma, self.error, self.svm,
        self.training_info and self.training_table.
        '''
        best, self.svm = self._gridSearch(labels, vectors, verbose, C_range, G_range, True, callback)
        self.C, self.gamma, self.error = best

    def train_SVC_Linear(self, labels, vectors, verbose, C_range, callback=None):
        '''
        Private use only.

        Tune C of a linear classifier by maximizing the validation rate.
        Sets self.C, self.tuned_rate, self.svm, self.training_info and
        self.training_table.
        '''
        best, self.svm = self._gridSearch(labels, vectors, verbose, C_range, None, False, callback)
        # Named tuned_rate (not tuning_rate as in train_SVC_RBF); both names
        # are kept as they were for existing callers.
        self.C, self.tuned_rate = best

    def train_SVR_Linear(self, labels, vectors, verbose, C_range, callback=None):
        '''
        Private use only.

        Tune C of a linear regressor by minimizing the validation mean
        squared error. Sets self.C, self.error, self.svm,
        self.training_info and self.training_table.
        '''
        best, self.svm = self._gridSearch(labels, vectors, verbose, C_range, None, True, callback)
        self.C, self.error = best

    def _gridSearch(self, labels, vectors, verbose, C_range, G_range, regression, callback):
        '''
        Fit one libsvm model per grid point and pick the best one.

        G_range is None for the linear kernel (no gamma). Sets
        self.training_table and self.training_info.

        Returns (best_row, best_model). best_row is [C, gamma, score] for
        RBF or [C, score] for linear. The score is the validation rate
        (higher is better) for classification, or the validation MSE
        (lower is better) for regression. Ties keep the earliest grid point.
        '''
        prob, validation = self._splitData(labels, vectors, verbose)

        linear = G_range is None
        if linear:
            grid = [(C,) for C in C_range]
        else:
            grid = [(C, G) for C in C_range for G in G_range]

        table = Table()
        self.training_table = table
        info = []
        models = []
        for i, point in enumerate(grid):
            if linear:
                param = svm.svm_parameter(svm_type=self.svm_type, kernel_type=svm.LINEAR,
                                          C=point[0], p=self.epsilon, nu=self.nu)
            else:
                param = svm.svm_parameter(svm_type=self.svm_type, kernel_type=svm.RBF,
                                          C=point[0], gamma=point[1], p=self.epsilon, nu=self.nu)
            model = svm.svm_model(prob, param)

            if regression:
                score = self._validationMSE(model, validation)
            else:
                score = self._validationRate(model, validation)

            # Per-point progress is only reported for the RBF grid.
            if verbose and not linear:
                if regression:
                    print("Testing: %10.5f %10.5f %15.8e" % (point[0], point[1], score))
                else:
                    print("%s %s %s" % (point[0], point[1], score))

            models.append(model)
            info.append(list(point) + [score])
            table.setElement(i, 'C', point[0])
            if not linear:
                table.setElement(i, 'G', point[1])
            table.setElement(i, 'mse' if regression else 'rate', score)

            if callback is not None:
                callback(int(100 * float(i + 1) / len(grid)))

        # Report layout: one "%8.3e" column per hyper-parameter, then the score.
        names = ['C'] if linear else ['C', 'gamma']
        header = " " + " ".join(names + ['error' if regression else 'rate'])
        row_format = " " + " ".join(["%8.3e"] * len(names) + ["%0.8f"])
        rule = "------------------------------"

        if verbose:
            for line in ("", rule, " Tuning Information:", header, rule):
                print(line)

        best_index = 0
        for i, row in enumerate(info):
            if verbose:
                print(row_format % tuple(row))
            best_score = info[best_index][-1]
            # Strict comparison so that ties keep the earliest grid point.
            improved = (row[-1] < best_score) if regression else (row[-1] > best_score)
            if improved:
                best_index = i

        best = info[best_index]
        if verbose:
            for line in (rule, "", rule, " Best Tuning:", header, rule,
                         row_format % tuple(best), rule, ""):
                print(line)

        self.training_info = info
        return best, models[best_index]

    def _splitData(self, labels, vectors, verbose):
        '''
        Shuffle the samples and split them in two.

        Returns (prob, validation). prob is a libsvm problem built from the
        fitting part. validation is a list of (label, vector) pairs for the
        held-out part. The shuffle is reproducible when self.random_seed is
        not None.
        '''
        data = list(zip(labels, vectors))
        random.Random(self.random_seed).shuffle(data)

        cutoff = self._trainingCutoff(len(data))
        if verbose:
            print("Training Cutoff: %s %s" % (len(labels), cutoff))

        fit = data[:cutoff]
        validation = data[cutoff:]
        prob = svm.svm_problem([label for label, _ in fit], [vector for _, vector in fit])
        return prob, validation

    def _trainingCutoff(self, count):
        '''
        Return how many of count shuffled samples are used for fitting; the
        rest form the validation set.

        Raises NotImplementedError unless validation_size is either a float
        in (0, 1) with count > 0, or an int in [1, count). Both cases
        guarantee a non-empty validation set.
        '''
        size = self.validation_size
        # bool is a subclass of int but is never a meaningful size.
        if not isinstance(size, bool):
            if isinstance(size, (float, np.floating)) and 0.0 < size < 1.0 and count > 0:
                return int(count * (1.0 - size))
            if isinstance(size, (int, np.integer)) and 0 < size < count:
                return count - size
        raise NotImplementedError("Cannot determine validation set from %s" % self.validation_size)

    @staticmethod
    def _validationRate(model, validation):
        '''Fraction of (label, vector) pairs whose prediction equals the label.'''
        hits = 0.0
        for label, vector in validation:
            if model.predict(vector) == label:
                hits += 1.0
        return hits / len(validation)

    @staticmethod
    def _validationMSE(model, validation):
        '''Mean squared difference between labels and predictions.'''
        total = 0.0
        for label, vector in validation:
            error = label - model.predict(vector)
            total += error * error
        return total / len(validation)

    @staticmethod
    def _asFloats(values):
        '''Return values as a list of python floats.'''
        return [float(value) for value in values]

    # ------------------------------------------------------------------
    # Prediction
    # NOTE(review): confirm the three predict* method names below against
    # VectorClassifier and existing callers.
    # ------------------------------------------------------------------

    def predictValue(self, data, ilog=None):
        '''
        Please call predict instead.

        Returns libsvm's prediction for data. Unlike the two methods below,
        this does not call normalizeVector. The caller presumably passes an
        already-normalized vector (confirm).
        '''
        assert self.svm is not None, "The SVM has not been trained."
        return self.svm.predict(self._asFloats(data))

    def predictProbabilities(self, data):
        '''
        Normalize data and return (prd, prb) exactly as libsvm's
        predict_probability returns them. Nothing is mapped back through
        invertClass.
        '''
        # NOTE(review): libsvm can only answer this for models trained with
        # probability estimates enabled. The svm_parameter calls in this
        # class do not request that, so confirm this path works.
        assert self.svm is not None, "The SVM has not been trained."
        prd, prb = self.svm.predict_probability(self._asFloats(self.normalizeVector(data)))
        return prd, prb

    def predictSVMValues(self, data):
        '''
        Normalize data and return libsvm's predict_values output as a dict.

        Each key is rebuilt as a tuple. For two-class and multi-class
        classifiers, every element of the key is mapped back to the user's
        class label with invertClass.
        '''
        assert self.svm is not None, "The SVM has not been trained."
        values = self.svm.predict_values(self._asFloats(self.normalizeVector(data)))

        classifier = self.type in (TYPE_TWOCLASS, TYPE_MULTICLASS)
        result = {}
        for key, value in values.items():
            if classifier:
                new_key = tuple(self.invertClass(each) for each in key)
            else:
                new_key = tuple(key)
            result[new_key] = value
        return result
572
573
class _TestSVM(unittest.TestCase):
    ''' Unit tests for SVM '''

    # The fixed random_seed values make the tuning shuffle reproducible. The
    # exact expected scores below rely on that.

    # (point, expected label) pairs for the XOR problem.
    _XOR_EXPECTED = [([0, 0], 0), ([1, 1], 0), ([1, 0], 1), ([0, 1], 1)]

    @staticmethod
    def _trainXor(**kwargs):
        '''Train SVM(**kwargs) on 20 copies of the four XOR points.'''
        xor = SVM(**kwargs)
        for _ in range(20):
            xor.addTraining(0, [0, 0])
            xor.addTraining(0, [1, 1])
            xor.addTraining(1, [0, 1])
            xor.addTraining(1, [1, 0])
        xor.train()
        return xor

    def _assertPredictions(self, model, expected):
        '''Assert model.predict(point) == label for every (point, label).'''
        for point, label in expected:
            self.assertEqual(model.predict(point), label)

    @staticmethod
    def _loadRegression():
        '''
        Return (labels, vectors) from data/synthetic/regression.dat.

        The label is column 0; the features are columns 3, 4 and 5.
        '''
        filename = os.path.join(pv.__path__[0], 'data', 'synthetic', 'regression.dat')
        labels = []
        vectors = []
        with open(filename, 'r') as reg_file:
            for line in reg_file:
                datapoint = line.split()
                labels.append(float(datapoint[0]))
                vectors.append([float(datapoint[3]), float(datapoint[4]), float(datapoint[5])])
        return labels, vectors

    def _regressionMSE(self, kernel):
        '''Train an epsilon-SVR on the first 50 samples; return the MSE on the rest.'''
        labels, vectors = self._loadRegression()
        rega = SVM(svm_type=TYPE_EPSILON_SVR, kernel=kernel, random_seed=0)
        for i in range(50):
            rega.addTraining(labels[i], vectors[i])
        rega.train()

        mse = 0.0
        for label, vector in zip(labels[50:], vectors[50:]):
            e = rega.predict(vector) - label
            mse += e * e
        return mse / (len(labels) - 50)

    @staticmethod
    def _loadBreastCancer():
        '''
        Return (labels, data) from data/ml/breast-cancer-wisconsin.data.

        The label is the last column. The features are row[1:-2] with '?'
        read as 0. Note that this slice also drops the column just before
        the label; the expected rates below assume this exact slicing.
        '''
        filename = os.path.join(pv.__path__[0], 'data', 'ml', 'breast-cancer-wisconsin.data')
        labels = []
        data = []
        # Binary mode is what the Python 2 csv module expects.
        with open(filename, "rb") as csv_file:
            for row in csv.reader(csv_file):
                data.append([0 if item == '?' else int(item) for item in row[1:-2]])
                labels.append(int(row[-1]))
        return labels, data

    def _breastCancerRate(self, **kwargs):
        '''Train SVM(**kwargs) on the first 300 rows; return accuracy on the rest.'''
        labels, data = self._loadBreastCancer()
        cancer = SVM(random_seed=0, **kwargs)
        for i in range(300):
            cancer.addTraining(labels[i], data[i])
        cancer.train()

        success = 0.0
        for label, vector in zip(labels[300:], data[300:]):
            if cancer.predict(vector) == label:
                success += 1
        return success / (len(labels) - 300)

    def test_xor_rbf(self):
        xor = self._trainXor(random_seed=0)
        self._assertPredictions(xor, self._XOR_EXPECTED)

    def test_xor_pickle(self):
        xor = self._trainXor(random_seed=0)
        xor = pickle.loads(pickle.dumps(xor))
        self._assertPredictions(xor, self._XOR_EXPECTED)

    def test_xor_linear(self):
        # A linear kernel cannot separate XOR. This pins the predictions it
        # settles on rather than the XOR truth table.
        xor = self._trainXor(kernel=KERNEL_LINEAR, random_seed=1)
        self._assertPredictions(xor, [([0, 0], 0), ([1, 1], 0), ([1, 0], 1), ([0, 1], 0)])

    def test_regression_rbf(self):
        self.assertAlmostEqual(self._regressionMSE(KERNEL_RBF), 0.47066712325873877, places=4)

    def test_regression_linear(self):
        self.assertAlmostEqual(self._regressionMSE(KERNEL_LINEAR), 0.52674701087510767, places=4)

    def test_gender(self):
        gender = SVM(svm_type=TYPE_SVC, random_seed=0)
        base = os.path.join(pv.__path__[0], 'data', 'csuScrapShots')
        labels = []
        vectors = []
        with open(os.path.join(base, 'gender.txt'), 'r') as f:
            for line in f:
                im_name, class_name = line.split()
                im = pv.Image(os.path.join(base, im_name))
                im = pv.Image(im.asPIL().resize((200, 200)))
                labels.append(class_name)
                vectors.append(im)

        for i in range(100):
            gender.addTraining(labels[i], vectors[i])
        gender.train()

        successes = 0.0
        for label, vector in zip(labels[100:], vectors[100:]):
            if gender.predict(vector) == label:
                successes += 1
        self.assertAlmostEqual(successes / (len(labels) - 100), 0.86301369863013699, places=4)

    def test_breast_cancer_rbf(self):
        rate = self._breastCancerRate(svm_type=TYPE_SVC)
        self.assertAlmostEqual(rate, 0.97744360902255634, places=4)

    def test_breast_cancer_linear(self):
        rate = self._breastCancerRate(svm_type=TYPE_SVC, kernel=KERNEL_LINEAR)
        self.assertAlmostEqual(rate, 0.97994987468671679, places=4)

    def test_breast_cancer_nu_svc(self):
        rate = self._breastCancerRate(svm_type=TYPE_NU_SVC)
        self.assertAlmostEqual(rate, 0.97994987468671679, places=4)
821