Package pyvision :: Package ml :: Module opencv_ml
[hide private]
[frames] | no frames]

Source Code for Module pyvision.ml.opencv_ml

  1  ''' 
  2  This module includes some helper functions for training OpenCV's machine learning algorithms. 
  3  Created on Mar 25, 2013 
  4   
  5  @author: David S. Bolme 
  6  Oak Ridge National Laboratory 
  7  ''' 
  8  import pyvision as pv 
  9  import cv2 
 10  import numpy as np 
 11  import tempfile 
 12  import shutil 
 13  import os 
 14   
class StatsModelWrapper(object):
    '''
    This class wraps an opencv stats model to support pickling and other
    pythonic features.

    The wrapped model is only required to offer save(filename) and
    load(filename); pickling works by round-tripping the model through a
    temporary file and storing that file's bytes in the pickled state.
    '''

    def __init__(self,model):
        '''
        Init the wrapper with the model.

        @param model: the model object that all calls are delegated to.
        '''
        self.model = model

    def predict(self,*args,**kwarg):
        '''
        Wrapper for the predict function.  All arguments are forwarded
        unchanged to the wrapped model.
        '''
        return self.model.predict(*args,**kwarg)

    def predict_prob(self,*args,**kwarg):
        '''
        Wrapper for the predict_prob function.  All arguments are forwarded
        unchanged to the wrapped model.
        '''
        return self.model.predict_prob(*args,**kwarg)

    def predict_all(self,*args,**kwarg):
        '''
        Wrapper for the predict_all function.  All arguments are forwarded
        unchanged to the wrapped model.
        '''
        return self.model.predict_all(*args,**kwarg)

    def __getstate__(self):
        '''
        Save the state for pickling.

        @returns: a dict holding the model class name ('model_class'), the
                  bytes of the saved model file ('model_data') and every
                  other instance attribute except 'model' itself.
        '''
        state = {}
        # str() of the class looks like "<type 'cv2.SVM'>"; keep the quoted name.
        state['model_class'] = str(self.model.__class__).split("'")[-2]

        # mkstemp creates the file atomically (mktemp only picks a name and is
        # open to races).  The model writes by filename, so the descriptor is
        # closed immediately and only the path is used.
        fd, filename = tempfile.mkstemp(suffix='.mod', prefix='tmp')
        os.close(fd)
        try:
            self.model.save(filename)
            with open(filename,'rb') as f:
                state['model_data'] = f.read()
        finally:
            # Always clean up; previously every pickle leaked a temp file.
            if os.path.exists(filename):
                os.remove(filename)

        # items() (rather than iteritems()) works on both Python 2 and 3.
        for key,value in self.__dict__.items():
            if key != 'model':
                state[key] = value

        return state

    def __setstate__(self,state):
        '''
        Load the state for pickling.

        @param state: a dict in the format produced by __getstate__.
        '''
        # NOTE(security): the class name comes straight out of the pickle and
        # is passed to eval, so only unpickle data from a trusted source.
        self.model = eval(state['model_class']+"()")

        fd, filename = tempfile.mkstemp(suffix='.mod', prefix='tmp')
        try:
            with os.fdopen(fd,'wb') as f:
                f.write(state['model_data'])
            self.model.load(filename)
        finally:
            # Remove the temp file even when load() raises.
            if os.path.exists(filename):
                os.remove(filename)

        for key,value in state.items():
            if key not in ('model_data','model_class'):
                setattr(self,key,value)

    def save(self,*args,**kwargs):
        '''
        Wrapper for the save function.  All arguments are forwarded
        unchanged to the wrapped model.
        '''
        self.model.save(*args,**kwargs)
74 75
def svc_rbf(data,responses):
    '''
    Auto trains an OpenCV SVM classifier (C-SVC) with an RBF kernel.

    @param data: the training data; converted to float32 before training.
    @param responses: the training responses; converted to float32.
    @returns: a StatsModelWrapper around the trained cv2.SVM.
    '''
    # np.float32 returns a converted array and leaves its argument untouched,
    # so the result has to be bound back for the conversion to take effect.
    data = np.float32(data)
    responses = np.float32(responses)
    params = dict( kernel_type = cv2.SVM_RBF, svm_type = cv2.SVM_C_SVC )
    model = cv2.SVM()
    model.train_auto(data,responses,None,None,params)
    return StatsModelWrapper(model)
86 87
def svc_linear(data,responses):
    '''
    Auto trains an OpenCV SVM classifier (C-SVC) with a linear kernel.

    @param data: the training data; converted to float32 before training.
    @param responses: the training responses; converted to float32.
    @returns: a StatsModelWrapper around the trained cv2.SVM.
    '''
    # np.float32 returns a converted array and leaves its argument untouched,
    # so the result has to be bound back for the conversion to take effect.
    data = np.float32(data)
    responses = np.float32(responses)
    params = dict( kernel_type = cv2.SVM_LINEAR, svm_type = cv2.SVM_C_SVC)
    model = cv2.SVM()
    model.train_auto(data,responses,None,None,params)
    return StatsModelWrapper(model)
98 99
def svr_rbf(data,responses):
    '''
    Auto trains an OpenCV SVM regressor (epsilon-SVR, p=1.0) with an RBF
    kernel.

    @param data: the training data; converted to float32 before training.
    @param responses: the training responses; converted to float32.
    @returns: a StatsModelWrapper around the trained cv2.SVM.
    '''
    # np.float32 returns a converted array and leaves its argument untouched,
    # so the result has to be bound back for the conversion to take effect.
    data = np.float32(data)
    responses = np.float32(responses)
    params = dict( kernel_type = cv2.SVM_RBF, svm_type = cv2.SVM_EPS_SVR , p=1.0)
    model = cv2.SVM()
    model.train_auto(data,responses,None,None,params)
    return StatsModelWrapper(model)
110
def svr_linear(data,responses):
    '''
    Auto trains an OpenCV SVM regressor (epsilon-SVR, p=1.0) with a linear
    kernel.

    @param data: the training data; converted to float32 before training.
    @param responses: the training responses; converted to float32.
    @returns: a StatsModelWrapper around the trained cv2.SVM.
    '''
    # np.float32 returns a converted array and leaves its argument untouched,
    # so the result has to be bound back for the conversion to take effect.
    data = np.float32(data)
    responses = np.float32(responses)
    params = dict( kernel_type = cv2.SVM_LINEAR, svm_type = cv2.SVM_EPS_SVR , p=1.0 )
    model = cv2.SVM()
    model.train_auto(data,responses,None,None,params)
    return StatsModelWrapper(model)
121 122
def random_forest(data,responses,n_trees=100):
    '''
    Trains an OpenCV random forest (cv2.RTrees).

    @param data: the training data, one sample per row (CV_ROW_SAMPLE);
                 converted to float32 before training.
    @param responses: the training responses; converted to float32.
    @param n_trees: passed as max_num_of_trees_in_the_forest; training uses
                    the TERM_CRITERIA_MAX_ITER termination type.
    @returns: a StatsModelWrapper around the trained cv2.RTrees.
    '''
    # np.float32 returns a converted array and leaves its argument untouched,
    # so the result has to be bound back for the conversion to take effect.
    data = np.float32(data)
    responses = np.float32(responses)
    params = dict(max_num_of_trees_in_the_forest=n_trees,termcrit_type=cv2.TERM_CRITERIA_MAX_ITER)
    model = cv2.RTrees()
    model.train(data,cv2.CV_ROW_SAMPLE,responses,params=params)
    return StatsModelWrapper(model)
134 135
def boost(data,responses,weak_count=100,max_depth=20,boost_type=cv2.BOOST_DISCRETE):
    '''
    Trains an OpenCV boosted classifier (cv2.Boost).

    @param data: the training data, one sample per row (CV_ROW_SAMPLE);
                 converted to float32 before training.
    @param responses: the training responses; converted to float32.
    @param weak_count: forwarded to OpenCV as the weak_count parameter.
    @param max_depth: forwarded to OpenCV as the max_depth parameter.
    @param boost_type: forwarded to OpenCV as the boost_type parameter.
    @returns: a StatsModelWrapper around the trained cv2.Boost.
    '''
    # np.float32 returns a converted array and leaves its argument untouched,
    # so the result has to be bound back for the conversion to take effect.
    data = np.float32(data)
    responses = np.float32(responses)
    params = dict(boost_type=boost_type,weak_count=weak_count,max_depth=max_depth)
    model = cv2.Boost()
    model.train(data,cv2.CV_ROW_SAMPLE,responses,params=params)
    return StatsModelWrapper(model)
146
def gbtrees(data,responses,n_trees=100):
    '''
    Trains an OpenCV gradient boosted trees model (cv2.GBTrees).

    @param data: the training data, one sample per row (CV_ROW_SAMPLE);
                 converted to float32 before training.
    @param responses: the training responses; converted to float32.
    @param n_trees: passed as max_num_of_trees_in_the_forest; training uses
                    the TERM_CRITERIA_MAX_ITER termination type.
    @returns: a StatsModelWrapper around the trained cv2.GBTrees.
    '''
    # np.float32 returns a converted array and leaves its argument untouched,
    # so the result has to be bound back for the conversion to take effect.
    data = np.float32(data)
    responses = np.float32(responses)
    params = dict(max_num_of_trees_in_the_forest=n_trees,termcrit_type=cv2.TERM_CRITERIA_MAX_ITER)
    model = cv2.GBTrees()
    model.train(data,cv2.CV_ROW_SAMPLE,responses,params=params)
    return StatsModelWrapper(model)
if __name__ == '__main__':
    # Demo: train an RBF SVM on every other iris sample, round-trip the
    # wrapped model through pickle, then predict the held-out samples.

    # Encode the labels as 1.0 for versicolor, 2.0 for virginica and 0.0
    # for everything else.
    labels = np.float32((pv.IRIS_LABELS=='versicolor') + 2*(pv.IRIS_LABELS=='virginica'))

    model = svc_rbf(pv.IRIS_DATA[0::2,:],labels[0::2])

    # cPickle only exists on Python 2; fall back to pickle elsewhere.
    try:
        import cPickle as pkl
    except ImportError:
        import pickle as pkl
    buf = pkl.dumps(model)
    model = pkl.loads(buf)

    # A single pre-formatted argument prints the same text whether print is
    # a statement (Python 2) or a function.
    print("Prediction: %s" % (np.float32([model.predict(s) for s in pv.IRIS_DATA[1::2,:]]),))
    print("Prediction: %s" % (model.predict_all(pv.IRIS_DATA[1::2,:]),))