Package pyvision :: Package vector :: Module id3
[hide private]
[frames] | [no frames]

Source Code for Module pyvision.vector.id3

  1  # PyVision License 
  2  # 
  3  # Copyright (c) 2006-2008 David S. Bolme 
  4  # All rights reserved. 
  5  # 
  6  # Redistribution and use in source and binary forms, with or without 
  7  # modification, are permitted provided that the following conditions 
  8  # are met: 
  9  #  
 10  # 1. Redistributions of source code must retain the above copyright 
 11  # notice, this list of conditions and the following disclaimer. 
 12  #  
 13  # 2. Redistributions in binary form must reproduce the above copyright 
 14  # notice, this list of conditions and the following disclaimer in the 
 15  # documentation and/or other materials provided with the distribution. 
 16  #  
 17  # 3. Neither name of copyright holders nor the names of its contributors 
 18  # may be used to endorse or promote products derived from this software 
 19  # without specific prior written permission. 
 20  #  
 21  #  
 22  # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
 23  # ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 
 24  # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 
 25  # A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR 
 26  # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 
 27  # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 
 28  # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 
 29  # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 
 30  # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 
 31  # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 
 32  # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
 33   
 34  # TODO: Needs some work. 
 35  import math 
 36   
 37   
def lg(x):
    """Return the base-2 logarithm of x.

    Uses math.log2, which is more accurate than the previous
    math.log(x)/math.log(2) form (exact powers of two now yield
    exact integer-valued floats).

    @param x: a positive number.
    @return: log2(x) as a float.
    @raise ValueError: if x <= 0.
    """
    return math.log2(x)
40
def entropy(labels):
    """Compute the Shannon entropy (in bits) of a sequence of labels.

    @param labels: non-empty iterable of hashable label values.
    @return: entropy in bits as a float.
    """
    # Tally the occurrences of every distinct label.
    counts = {}
    total = 0.0
    for label in labels:
        counts[label] = counts.get(label, 0.0) + 1.0
        total += 1.0

    # Shannon entropy: -sum(p_i * log2(p_i)) over the label frequencies.
    result = 0.0
    for occurrences in counts.values():
        p = occurrences / total
        result -= p * (math.log(p) / math.log(2))
    return result
59 60 61
def maxValue(labels):
    """Return the most frequent label in a non-empty sequence.

    Fixes Python 3 incompatibility: dict.iteritems() no longer
    exists; use dict.items() instead.

    @param labels: non-empty sequence of hashable labels.
    @return: the label with the highest count (first strict maximum
        in counting order wins ties; labels[0] if all counts tie at 0,
        which cannot happen for non-empty input).
    @raise IndexError: if labels is empty.
    """
    # Count occurrences of each label.
    counts = {}
    for label in labels:
        counts[label] = counts.get(label, 0) + 1

    best_label = labels[0]
    best_count = 0
    for label, count in counts.items():
        if count > best_count:
            best_count = count
            best_label = label
    return best_label
82
def getLabels(features):
    """Extract the label from every (label, feature_vector) pair.

    @param features: sequence of (label, values) pairs.
    @return: list of labels, in the same order as the input.
    """
    # Unpack each pair rather than indexing into it.
    return [label for label, _values in features]
86 87
def splitFeatures(feature, features):
    """Partition (label, values) pairs by the value of one feature.

    Fixes Python 3 incompatibility: dict.has_key() was removed; use
    dict.setdefault() to group in a single lookup.

    @param feature: index into each feature vector.
    @param features: sequence of (label, values) pairs.
    @return: dict mapping each observed feature value to the list of
        [label, values] pairs that take that value.
    """
    split = {}
    for label, values in features:
        # Group each sample under the value it takes for this feature.
        split.setdefault(values[feature], []).append([label, values])
    return split
97
class ID3:
    """A simple ID3 decision-tree classifier.

    Collect (label, feature_vector) pairs with addTraining()/addTesting(),
    build the tree with train(), then use classify()/test().
    """

    def __init__(self):
        # (label, feature_vector) pairs used to build the tree.
        self.training_data = []
        # (label, feature_vector) pairs held out for test().
        self.testing_data = []
        # Every label seen in the training data.
        self.labels = set()
        # Root Node of the trained tree; None until train() is called.
        self.top = None

    def addTraining(self, label, feature):
        '''Add one labeled feature vector to the training set.'''
        self.training_data.append((label, feature))
        # Fixed: was "self.labels |= self.labels | set([label])" -- a
        # redundant double union; a plain add() is equivalent and clearer.
        self.labels.add(label)

    def addTesting(self, label, feature):
        '''Add one labeled feature vector to the testing set.'''
        # Docstring fixed: it previously said "Training Data".
        self.testing_data.append((label, feature))
        #self.labels |= self.labels | set([label])

    def train(self):
        '''Train the classifier on the current data'''
        self.top = Node(self.training_data)

    def classify(self, feature):
        '''Classify the feature vector'''
        return self.top.classify(feature)

    def test(self, data=None):
        '''Measure accuracy on data (defaults to the testing set).

        @param data: optional list of (label, feature_vector) pairs.
        @return: fraction of samples classified correctly.
        '''
        if data is None:  # fixed: compare to None with "is", not "=="
            data = self.testing_data
        correct = 0
        wrong = 0
        for label, feature in data:
            c, _ = self.classify(feature)
            if c == label:
                correct += 1
            else:
                wrong += 1
        # Converted to the Python 3 print function (was a py2 statement).
        print("Test: %d/%d" % (correct, correct + wrong))
        return float(correct) / float(correct + wrong)
139 140 141
class Node:
    """One node of an ID3 decision tree.

    Each node stores the majority label of the samples that reached it.
    Internal nodes additionally store the feature they split on and a
    dict of child nodes keyed by feature value.
    """

    def __init__(self, features):
        # Stop splitting below this many samples / below this entropy.
        self.cutoff = 2
        self.min_entropy = 0.2

        self.feature = None   # index of the feature split on; None for a leaf
        self.entropy = None   # entropy of the labels at this node
        self.label = None     # majority label at this node
        self.children = None  # dict: feature value -> child Node

        self.train(features)

    def train(self, features):
        '''Recursively build the subtree from (label, values) pairs.'''
        labels = getLabels(features)

        self.label = maxValue(labels)
        self.entropy = entropy(labels)
        print("Ent:", self.entropy)
        print("Max:", self.label)

        # Leaf: too few samples, or the labels are already nearly pure.
        if len(features) < self.cutoff or self.entropy < self.min_entropy:
            return

        no_feature = len(features[0][1])

        # Choose the feature with the highest information gain.
        max_gain = 0.0
        # Fixed: was initialized to 0, which made "no informative feature"
        # indistinguishable from "split on feature 0".
        max_feature = None
        max_children = {}
        for i in range(no_feature):
            gain = self.entropy
            s = splitFeatures(i, features)
            for _, vals in s.items():  # fixed: iteritems() is Python 2 only
                scale = float(len(vals)) / float(len(features))
                gain -= scale * entropy(getLabels(vals))
            if max_gain < gain:
                max_gain = gain
                max_feature = i
                max_children = s
        print("Gain: ", max_gain, max_feature)

        if max_feature is None:
            # No feature reduces entropy; remain a leaf.
            return

        self.feature = max_feature
        self.gain = max_gain

        # Recurse into each partition. Loop variable renamed: it
        # previously shadowed the "features" parameter.
        self.children = {}
        for value, subset in max_children.items():
            self.children[value] = Node(subset)

    def classify(self, feature):
        '''Classify the feature vector; returns (label, None).

        Falls back to this node's majority label when the value is
        unseen or the node is a leaf.
        '''
        # Fixed: "if self.feature:" wrongly treated a split on feature
        # index 0 as a leaf, because 0 is falsy. Compare against None.
        if self.feature is not None:
            val = feature[self.feature]
            if val in self.children:  # fixed: has_key() is Python 2 only
                return self.children[val].classify(feature)
        return self.label, None
200 201 202 203
def toBits(val, bits=4):
    """Convert a non-negative integer to a big-endian list of bits.

    @param val: integer whose low bits are extracted.
    @param bits: number of bits to emit (default 4).
    @return: list of ints (0 or 1), most significant bit first.
    """
    # Read the bits most-significant-first directly, instead of
    # collecting them LSB-first and reversing at the end.
    return [(val >> shift) & 1 for shift in range(bits - 1, -1, -1)]
213