Package pyvision :: Package vector :: Module RANSAC
[hide private]
[frames] | [no frames]

Source Code for Module pyvision.vector.RANSAC

  1  #from numpy import np.dot, arange, abs 
  2  #from numpy.linalg import lstsq 
  3  import numpy as np 
  4  import random 
  5   
  6   
def computeErrorAndCount(A, b, x, group, tol):
    """Score a candidate solution ``x`` of ``A x = b``.

    Rows are taken in consecutive blocks of ``group`` rows, and each block is
    treated as one observation. The error of a block is the Euclidean norm of
    its absolute residuals. A block is an inlier when that norm is strictly
    below ``tol``.

    Args:
        A: (n, k) design matrix.
        b: right-hand side of shape (n,) or (n, 1).
        x: candidate solution, shaped so that ``np.dot(A, x)`` matches ``b``.
        group: number of consecutive rows per observation; the number of rows
            must be a multiple of it.
        tol: inlier threshold on the per-group residual norm.

    Returns:
        ``(mean_error, count, inliers)``, where:
        - ``mean_error`` is the mean residual norm over inlier groups
          (0.0 when there are none);
        - ``count`` is the number of inlier groups;
        - ``inliers`` is a boolean array shaped like ``b``, True on every row
          that belongs to an inlier group.
    """
    residuals = np.abs(b - np.dot(A, x))
    n_groups = len(residuals) // group
    # One row per observation; the norm of each row is that group's error.
    grouped = np.reshape(residuals ** 2, (n_groups, group))
    group_error = np.sqrt(np.sum(grouped, axis=1))

    is_inlier = group_error < tol
    count = int(np.count_nonzero(is_inlier))
    # Expand the per-group decision back onto every row of its group.
    # Plain ``bool`` is used because the ``np.bool`` alias was removed in
    # NumPy 1.24.
    inliers = np.repeat(is_inlier, group).reshape(np.shape(b))

    if count == 0:
        return 0.0, count, inliers
    return float(np.sum(group_error[is_inlier])) / count, count, inliers
##
# Uses the RANSAC algorithm to solve Ax=b
#
# M.A. Fischler and R.C. Bolles. Random sample consensus: A paradigm for
# model fitting with applications to image analysis and automated cartography.
# Communications of the ACM, 24(6): 381--395, 1981.
def RANSAC(A, b, count=None, tol=1.0, niter=None, group=1, verbose=False, full_output=False):
    """Robustly solve ``A x = b`` with random sample consensus (RANSAC).

    Rows are handled in consecutive blocks of ``group`` rows. Each block is one
    observation, scored by ``computeErrorAndCount``.

    The algorithm runs in three stages:
    1. The ordinary least-squares solution is tried first. It is returned
       immediately if every group is an inlier.
    2. Otherwise, ``niter`` random minimal samples are fitted. The candidate
       with the most inlier groups is kept, with ties broken by lower mean
       inlier error.
    3. Up to 10 rounds of least-squares refitting on the current inlier rows
       refine the result.

    Args:
        A: (n, k) design matrix with n >= k.
        b: right-hand side, shape (n,) or (n, 1).
        count: unused; kept for backward compatibility.
        tol: inlier threshold on the per-group residual norm.
        niter: number of random samples; defaults to the number of groups.
        group: rows per observation; must be positive and divide n.
        verbose: print progress whenever the best solution improves.
        full_output: also return the inlier count, mean inlier error and
            inlier mask.

    Returns:
        ``x``, or ``(x, count, error, inliers)`` when ``full_output`` is true.
    """
    A = np.asarray(A)
    b = np.asarray(b)
    n, k = A.shape
    group = int(group)
    assert group > 0
    assert n % group == 0
    assert n >= k
    n_groups = n // group
    # A minimal sample must contain at least k equations. Round up so that a
    # k that is not a multiple of group still yields a determined system
    # (e.g. k=3, group=2 needs 2 groups = 4 rows, not 1 group = 2 rows).
    # Since n >= k and group divides n, this never exceeds n_groups.
    sample_size = -(-k // group)
    if niter is None:
        niter = n_groups

    def report(label, n_in, err):
        if verbose:
            print("%s %s %s %s" % (label, n_in, err, float(n_in * group) / n))

    bestx = np.linalg.lstsq(A, b, rcond=-1)[0]
    besterror, bestcount, bestinliers = computeErrorAndCount(A, b, bestx, group, tol)
    report("New Best (LS):", bestcount, besterror)

    if bestcount == n_groups:
        if full_output:
            return bestx, bestcount, besterror, bestinliers
        return bestx

    # Random minimal samples. Indexing with b[rows] (not b[rows, :]) works
    # for both 1-D and column-vector right-hand sides.
    population = list(range(n_groups))
    for _ in range(int(niter)):
        chosen = random.sample(population, sample_size)
        rows = [group * j + offset for j in chosen for offset in range(group)]
        try:
            x = np.linalg.lstsq(A[rows], b[rows], rcond=-1)[0]
        except np.linalg.LinAlgError:
            continue
        error, n_inliers, inliers = computeErrorAndCount(A, b, x, group, tol)
        if n_inliers > bestcount or (n_inliers == bestcount and error < besterror):
            bestx, bestcount, besterror, bestinliers = x, n_inliers, error, inliers
            report(" New Best:", bestcount, besterror)

    # Refine: refit by least squares on the current inlier rows, re-score,
    # and repeat with the new inlier set (at most 10 rounds).
    inliers = bestinliers
    for _ in range(10):
        mask = np.reshape(inliers, -1)
        if np.count_nonzero(mask) < k:
            break  # too few rows to determine x
        try:
            x = np.linalg.lstsq(A[mask], b[mask], rcond=-1)[0]
        except np.linalg.LinAlgError:
            break  # refitting the same rows again would fail identically
        error, n_inliers, new_inliers = computeErrorAndCount(A, b, x, group, tol)
        if n_inliers > bestcount or (n_inliers == bestcount and error < besterror):
            bestx, bestcount, besterror, bestinliers = x, n_inliers, error, new_inliers
            report("Improved Best:", bestcount, besterror)
        if np.array_equal(new_inliers, inliers):
            break  # converged: further rounds would repeat this one exactly
        inliers = new_inliers

    if full_output:
        return bestx, bestcount, besterror, bestinliers
    return bestx
119 120
121 -def _quantile(errors,quantile):
122 123 errors = errors.copy() 124 i = int(quantile*errors.shape[0]) 125 errors.sort() 126 return errors[i]
##
# Uses the LMedS (least median of squares) algorithm to solve Ax=b
#
# P.J. Rousseeuw. Least median of squares regression. Journal of the
# American Statistical Association, 79(388): 871--880, 1984.
def LMeDs(A, b, quantile=0.75, N=None, verbose=True):
    """Robustly solve ``A x = b`` by minimizing a quantile of |residuals|.

    Each candidate x is scored with ``_quantile(abs(b - A x), quantile)``.
    With ``quantile=0.5`` this is the least-median criterion.

    The search runs in three stages:
    1. The least-squares solution is the initial candidate.
    2. ``N`` fits on random k-row samples are tried.
    3. A greedy local search toggles single rows in or out of the best row
       set, refitting each time. It stops when a full pass makes no
       improvement. Ties in score are broken in favour of the larger row set.

    Args:
        A: (n, k) design matrix.
        b: right-hand side, shape (n,) or (n, 1).
        quantile: residual quantile to minimize, in [0, 1].
        N: number of random samples; defaults to n.
        verbose: unused; kept for backward compatibility.

    Returns:
        The best solution x found.
    """
    A = np.asarray(A)
    b = np.asarray(b)
    n, k = A.shape
    if N is None:
        N = n

    def score(x):
        return _quantile(np.abs(b - np.dot(A, x)), quantile)

    indices = list(range(n))
    best_rows = indices  # the least-squares start uses every row
    bestx = np.linalg.lstsq(A, b, rcond=-1)[0]
    best_error = score(bestx)

    # Random k-row samples. Indexing with b[rows] (not b[rows, :]) works for
    # both 1-D and column-vector right-hand sides.
    for _ in range(N):
        rows = random.sample(indices, k)
        try:
            x = np.linalg.lstsq(A[rows], b[rows], rcond=-1)[0]
        except np.linalg.LinAlgError:
            continue
        err = score(x)
        if err < best_error:
            best_rows, best_error, bestx = rows, err, x

    # Greedy local search over row subsets, visiting rows in random order.
    # Each accepted move strictly improves (error, -size), so it terminates.
    best_mask = np.zeros(n, dtype=bool)
    best_mask[best_rows] = True
    order = list(indices)
    random.shuffle(order)

    improved = True
    while improved:
        improved = False
        for i in order:
            mask = best_mask.copy()
            mask[i] = not mask[i]
            if np.count_nonzero(mask) < k:
                continue  # fewer than k equations cannot determine x
            try:
                x = np.linalg.lstsq(A[mask], b[mask], rcond=-1)[0]
            except np.linalg.LinAlgError:
                continue
            err = score(x)
            if err < best_error or (err == best_error and mask.sum() > best_mask.sum()):
                improved = True
                best_mask, best_error, bestx = mask, err, x

    # Return the best solution, not the last one tried by the local search.
    return bestx
if __name__ == '__main__':
    # Demo 1: line b = 10*x + 5 with Gaussian noise (sigma 2) and one gross
    # outlier; compare plain least squares with RANSAC.
    A = np.array([[x, 1] for x in range(40)])
    b = np.array([10 * x + 5 + random.normalvariate(0.0, 2.0) for x in range(40)])
    b[0] = -20

    print(np.linalg.lstsq(A, b, rcond=-1)[0])
    print(RANSAC(A, b, tol=6.0))

    # Demo 2: plane b = 15*y + 10*x + 5 on a 20x20 grid with one huge
    # outlier, fitted with rows grouped in pairs.
    grid = [(x, y) for y in range(-10, 10) for x in range(-10, 10)]
    A = np.array([[x, y, 1] for x, y in grid])
    b = np.array([15 * y + 10 * x + 5 + random.normalvariate(0.0, 2.0) for x, y in grid])
    b[0] = -200000.

    print(np.linalg.lstsq(A, b, rcond=-1)[0])
    print(RANSAC(A, b, group=2, tol=6, full_output=True, verbose=True))