1 '''
2 This module includes some helper functions for training OpenCV's machine learning algorithms.
3 Created on Mar 25, 2013
4
5 @author: David S. Bolme
6 Oak Ridge National Laboratory
7 '''
8 import pyvision as pv
9 import cv2
10 import numpy as np
11 import tempfile
12 import shutil
13 import os
14
class StatsModelWrapper(object):
    '''
    Wrap an OpenCV stats model to support pickling and other pythonic
    features.

    The wrapped model is serialized through its own save()/load() methods
    using a temporary file.  The predict* methods simply forward to the
    wrapped model.
    '''

    def __init__(self, model):
        '''
        Init the wrapper with the model.

        @param model: a trained model object.  Pickling relies on it
            providing save(filename) and load(filename), and on its class
            being constructible with no arguments.
        '''
        self.model = model

    def predict(self, *args, **kwarg):
        '''
        Wrapper for the model's predict function.
        '''
        return self.model.predict(*args, **kwarg)

    def predict_prob(self, *args, **kwarg):
        '''
        Wrapper for the model's predict_prob function.
        '''
        return self.model.predict_prob(*args, **kwarg)

    def predict_all(self, *args, **kwarg):
        '''
        Wrapper for the model's predict_all function.
        '''
        return self.model.predict_all(*args, **kwarg)

    @staticmethod
    def _resolve_model_class(dotted_name):
        '''
        Look up a class from a dotted name such as "cv2.SVM".

        Only imports and attribute lookups are performed (no eval), so an
        expression smuggled into the pickled state is not executed.  Note
        that unpickling untrusted data is unsafe regardless.

        @raises ImportError: the top-level module cannot be imported.
        @raises AttributeError: a component of the name does not exist.
        @raises TypeError: the name does not refer to a class.
        '''
        parts = dotted_name.split('.')
        obj = __import__(parts[0])
        for i, attr in enumerate(parts[1:], 1):
            try:
                obj = getattr(obj, attr)
            except AttributeError:
                # The component may be a submodule that is not imported yet.
                __import__('.'.join(parts[:i + 1]))
                obj = getattr(obj, attr)
        if not isinstance(obj, type):
            raise TypeError("%r does not name a class" % (dotted_name,))
        return obj

    def __getstate__(self):
        '''
        Save the state for pickling.

        @returns: a dict with the model's class name ('model_class'), the
            bytes written by model.save() ('model_data'), and every other
            instance attribute except 'model'.
        '''
        state = {}
        # str(cls) looks like "<type 'cv2.SVM'>" / "<class 'cv2.SVM'>";
        # keep the quoted dotted name.
        state['model_class'] = str(self.model.__class__).split("'")[-2]

        # mkstemp creates the file atomically (mktemp is racy); close our
        # descriptor so the model can write the path itself.
        fd, filename = tempfile.mkstemp(suffix='.mod', prefix='tmp')
        os.close(fd)
        try:
            self.model.save(filename)
            with open(filename, 'rb') as f:
                state['model_data'] = f.read()
        finally:
            # Previously the temp file was left behind on every pickle.
            os.remove(filename)

        for key, value in self.__dict__.items():
            if key != 'model':
                state[key] = value

        return state

    def __setstate__(self, state):
        '''
        Load the state for pickling (inverse of __getstate__).

        @param state: dict produced by __getstate__.
        @raises TypeError: if state['model_class'] does not name a class.
        '''
        model_class = self._resolve_model_class(state['model_class'])
        self.model = model_class()
        fd, filename = tempfile.mkstemp(suffix='.mod', prefix='tmp')
        try:
            # Close (flush) the file before the model reads it back.
            with os.fdopen(fd, 'wb') as f:
                f.write(state['model_data'])
            self.model.load(filename)
        finally:
            os.remove(filename)
        for key, value in state.items():
            if key not in ('model_data', 'model_class'):
                setattr(self, key, value)

    def save(self, *args, **kwargs):
        '''
        Forward to the wrapped model's save method.
        '''
        self.model.save(*args, **kwargs)
74
75
def svc_rbf(data,responses):
    '''
    Auto trains an OpenCV SVM classifier (C_SVC) with an RBF kernel using
    cv2.SVM.train_auto().

    @param data: training samples (converted to float32).
    @param responses: labels, presumably one per row of data (converted to
        float32).
    @returns: the trained cv2.SVM wrapped in a StatsModelWrapper.
    '''
    # np.float32() returns a converted array rather than converting in
    # place; the original discarded the result, so no conversion happened.
    data = np.float32(data)
    responses = np.float32(responses)
    params = dict( kernel_type = cv2.SVM_RBF, svm_type = cv2.SVM_C_SVC )
    model = cv2.SVM()
    model.train_auto(data,responses,None,None,params)
    return StatsModelWrapper(model)
86
87
89 '''
90 Auto trains an OpenCV SVM.
91 '''
92 np.float32(data)
93 np.float32(responses)
94 params = dict( kernel_type = cv2.SVM_LINEAR, svm_type = cv2.SVM_C_SVC)
95 model = cv2.SVM()
96 model.train_auto(data,responses,None,None,params)
97 return StatsModelWrapper(model)
98
99
101 '''
102 Auto trains an OpenCV SVM.
103 '''
104 np.float32(data)
105 np.float32(responses)
106 params = dict( kernel_type = cv2.SVM_RBF, svm_type = cv2.SVM_EPS_SVR , p=1.0)
107 model = cv2.SVM()
108 model.train_auto(data,responses,None,None,params)
109 return StatsModelWrapper(model)
110
112 '''
113 Auto trains an OpenCV SVM.
114 '''
115 np.float32(data)
116 np.float32(responses)
117 params = dict( kernel_type = cv2.SVM_LINEAR, svm_type = cv2.SVM_EPS_SVR , p=1.0 )
118 model = cv2.SVM()
119 model.train_auto(data,responses,None,None,params)
120 return StatsModelWrapper(model)
121
122
124 '''
125 Auto trains an OpenCV SVM.
126 '''
127 np.float32(data)
128 np.float32(responses)
129 params = dict(max_num_of_trees_in_the_forest=n_trees,termcrit_type=cv2.TERM_CRITERIA_MAX_ITER)
130
131 model = cv2.RTrees()
132 model.train(data,cv2.CV_ROW_SAMPLE,responses,params=params)
133 return StatsModelWrapper(model)
134
135
def boost(data,responses,weak_count=100,max_depth=20,boost_type=cv2.BOOST_DISCRETE):
    '''
    Trains an OpenCV boosted model (cv2.Boost).

    @param data: training samples, laid out one sample per row
        (cv2.CV_ROW_SAMPLE); converted to float32.
    @param responses: one response per sample; converted to float32.
    @param weak_count: forwarded as the Boost 'weak_count' parameter.
    @param max_depth: forwarded as the Boost 'max_depth' parameter.
    @param boost_type: forwarded as the Boost 'boost_type' parameter.
    @returns: the trained cv2.Boost wrapped in a StatsModelWrapper.
    '''
    # np.float32() returns a converted array; keep the result (the original
    # discarded it, so the inputs were passed through unconverted).
    data = np.float32(data)
    responses = np.float32(responses)
    params = dict(boost_type=boost_type,weak_count=weak_count,max_depth=max_depth)
    model = cv2.Boost()
    model.train(data,cv2.CV_ROW_SAMPLE,responses,params=params)
    return StatsModelWrapper(model)
146
def gbtrees(data,responses,n_trees=100):
    '''
    Trains an OpenCV gradient boosted trees model (cv2.GBTrees).

    @param data: training samples, laid out one sample per row
        (cv2.CV_ROW_SAMPLE); converted to float32.
    @param responses: one response per sample; converted to float32.
    @param n_trees: number of trees, forwarded as the GBTrees 'weak_count'
        parameter.
    @returns: the trained cv2.GBTrees wrapped in a StatsModelWrapper.
    '''
    # np.float32() returns a converted array; keep the result (the original
    # discarded it, so the inputs were passed through unconverted).
    data = np.float32(data)
    responses = np.float32(responses)
    # GBTrees takes its tree count as 'weak_count'.  The random-forest keys
    # used previously (max_num_of_trees_in_the_forest, termcrit_type) are not
    # GBTrees parameters, so n_trees was not being applied.
    params = dict(weak_count=n_trees)

    model = cv2.GBTrees()
    model.train(data,cv2.CV_ROW_SAMPLE,responses,params=params)
    return StatsModelWrapper(model)
158
159
if __name__ == '__main__':
    # Demo: train an RBF SVM on the even rows of the iris data, round-trip
    # the wrapper through pickle, then predict the odd rows.
    try:
        import cPickle as pkl  # Python 2
    except ImportError:
        import pickle as pkl  # Python 3

    # Encode the iris species as 0 (setosa), 1 (versicolor), 2 (virginica).
    labels = np.float32((pv.IRIS_LABELS=='versicolor') + 2*(pv.IRIS_LABELS=='virginica'))

    model = svc_rbf(pv.IRIS_DATA[0::2,:],labels[0::2])

    buf = pkl.dumps(model)
    model = pkl.loads(buf)
    # Single-argument print(...) produces the same output on Python 2 and 3.
    print("Prediction: %s" % (np.float32([model.predict(s) for s in pv.IRIS_DATA[1::2,:]]),))
    print("Prediction: %s" % (model.predict_all(pv.IRIS_DATA[1::2,:]),))
    # The former trailing "assert 0" (a debugging leftover that made the
    # demo always exit with AssertionError) has been removed.
173