Package pyvision :: Package ml :: Module lda
[hide private]
[frames] | no frames]

Source Code for Module pyvision.ml.lda

  1  # PyVision License 
  2  # 
  3  # Copyright (c) 2006-2008 David S. Bolme 
  4  # All rights reserved. 
  5  # 
  6  # Redistribution and use in source and binary forms, with or without 
  7  # modification, are permitted provided that the following conditions 
  8  # are met: 
  9  #  
 10  # 1. Redistributions of source code must retain the above copyright 
 11  # notice, this list of conditions and the following disclaimer. 
 12  #  
 13  # 2. Redistributions in binary form must reproduce the above copyright 
 14  # notice, this list of conditions and the following disclaimer in the 
 15  # documentation and/or other materials provided with the distribution. 
 16  #  
 17  # 3. Neither name of copyright holders nor the names of its contributors 
 18  # may be used to endorse or promote products derived from this software 
 19  # without specific prior written permission. 
 20  #  
 21  #  
 22  # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
 23  # ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 
 24  # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 
 25  # A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR 
 26  # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 
 27  # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 
 28  # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 
 29  # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 
 30  # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 
 31  # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 
 32  # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
 33  ''' 
 34  Created on Jan 27, 2011 
 35   
 36  @author: bolme 
 37  ''' 
 38  import numpy as np 
 39  #import scipy as sp 
 40  import scipy.linalg as la 
 41   
 42   
def lda(data,labels,reg=0.0):
    '''
    Compute the LDA (linear discriminant analysis) basis vectors.  Based on
    Wikipedia and verified against R.

    The basis is found by solving the generalized eigenvalue problem
    Sb v = l Sw v, where Sb is the between class covariance and Sw is the
    within class covariance.

    @param data: the data matrix with one data point per row and one
        feature per column (shape N x D).
    @type data: np.array
    @param labels: a corresponding 1D array of labels, one label per row in data
    @type labels: np.array (int or str)
    @param reg: value added to the diagonal of the within class covariance
        for numerical stability.  Negative values are ignored.
    @type reg: float
    @return: (lda_values,lda_basis,means,priors) where lda_values holds the
        C-1 largest generalized eigenvalues scaled to sum to one, lda_basis
        is a D x (C-1) array with the matching eigenvectors in its columns,
        means maps each label to its class mean and priors maps each label
        to the fraction of rows carrying that label.
    @rtype: (np.array,np.array,dict,dict)
    @raise ValueError: if there are fewer than two classes or if there are
        not more data points than classes.
    '''
    # Accept plain sequences: comparing a list to a key with == would yield a
    # single bool instead of an elementwise mask.
    labels = np.asarray(labels)
    data = np.asarray(data)

    # Class means are floating point, so integer data has to be promoted
    # before the means can be subtracted in place.  Floating point input
    # keeps its original precision.
    if not np.issubdtype(data.dtype, np.inexact):
        data = data.astype(float)

    means = {}
    priors = {}

    classes = sorted(set(labels))

    # number of classes
    C = len(classes)

    # number of data points
    N = data.shape[0]

    # number of dimensions
    D = data.shape[1]

    # Both covariance estimates below divide by these quantities, so fail
    # early with a clear message instead of producing NaNs.
    if C < 2:
        raise ValueError("lda requires at least two classes, found %d" % C)
    if N <= C:
        raise ValueError(
            "lda requires more data points (%d) than classes (%d)" % (N, C))

    for key in classes:
        mask = labels == key
        priors[key] = float(mask.sum())/labels.shape[0]
        means[key] = data[mask,:].mean(axis=0)

    # Compute the between class cov.  Iterating over the sorted class list
    # works on both Python 2 and Python 3 (dict.iteritems is Python 2 only);
    # the row order does not affect Sb.
    class_means = np.array([means[key] for key in classes])

    # mean of class means
    grand_mean = class_means.mean(axis=0)
    offsets = grand_mean - class_means
    Sb = np.dot(offsets.T,offsets)/(C-1)

    # size of cov matrix should be DxD
    assert Sb.shape == (D,D)

    # Compute the within class cov
    data_w = data.copy()
    for key in classes:
        # subtract the class mean (c_mean) from each data point
        c_mean = means[key].reshape(1,D)
        data_w[labels == key,:] -= c_mean

    Sw = np.dot(data_w.T,data_w) / (N-C) # within class scatter

    # Check the shape of Sw
    assert Sw.shape == (D,D)

    # Compute vectors using generalized eigenvector solver: Sb v = l Sw v
    if reg >= 0:
        Sw = Sw+reg*np.eye(Sw.shape[0]) # regularization for stability
    val,vec = la.eigh(Sb,Sw)

    # Reorder vectors so the most important comes first
    order = val.argsort()[::-1] # reverse order
    val = val[order]
    vec = vec[:,order]

    # Sb has rank at most C-1, so only the first C-1 directions carry signal
    val = val[:C-1]
    vec = vec[:,:C-1]

    # scale the eigen values
    val = val/val.sum()

    return val,vec,means,priors
114