When using PCA in sklearn, it's easy to get out the components:

from sklearn import decomposition

pca = decomposition.PCA(n_components=n_components)
pca.fit(input_data)  # fit() returns the fitted estimator, not transformed data
pca_components = pca.components_

But I can't for the life of me figure out how to get the components out of LDA, as there is no components_ attribute. Is there a similar attribute in sklearn LDA?

1 Answer


If you are using Principal Component Analysis (PCA), then the rows of pca.components_ are the eigenvectors (the principal axes).
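To make the PCA statement concrete, here is a minimal sketch that compares pca.components_ with the leading eigenvectors of the sample covariance matrix. The iris data and the variable names are only illustrative choices, and the comparison is done up to sign because an eigenvector and its negation are equally valid.

```python
import numpy as np
from sklearn import datasets
from sklearn.decomposition import PCA

X = datasets.load_iris().data
pca = PCA(n_components=2).fit(X)

# Eigen-decomposition of the sample covariance matrix (eigh returns ascending order).
eigvals, eigvecs = np.linalg.eigh(np.cov(X, rowvar=False))
order = np.argsort(eigvals)[::-1][:2]   # indices of the two largest eigenvalues
top_eigvecs = eigvecs[:, order].T       # rows = leading eigenvectors

# Eigenvectors are only defined up to sign, so compare absolute values.
same_axes = np.allclose(np.abs(top_eigvecs), np.abs(pca.components_), atol=1e-6)
same_variances = np.allclose(eigvals[order], pca.explained_variance_)
print(same_axes, same_variances)
```

Note the layout: components_ has one row per component and one column per input feature.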

In the case of Linear Discriminant Analysis (LDA), the equivalent is the lda.scalings_ attribute, which holds one discriminant direction per column.
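If you want something shaped like PCA's components_ (one row per component), one option is to transpose scalings_. The sketch below assumes the default svd solver and uses the hypothetical name lda_components for the transposed matrix; it is not an attribute sklearn provides.

```python
from sklearn import datasets
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

X, y = datasets.load_iris(return_X_y=True)
lda = LinearDiscriminantAnalysis(n_components=2).fit(X, y)

# scalings_ is laid out as (n_features, n_directions); transpose it to get
# a PCA-like (n_components, n_features) matrix. lda_components is our own name.
lda_components = lda.scalings_.T[:2]
print(lda_components.shape)

# Share of between-class variance captured by each discriminant axis.
print(lda.explained_variance_ratio_)
```

With three classes, LDA can produce at most two discriminant axes, so the transposed matrix has two rows for the iris data.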

For example:

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

iris = datasets.load_iris()
X = iris.data
y = iris.target

# In general it is a good idea to scale the data
scaler = StandardScaler()
scaler.fit(X)
X = scaler.transform(X)

lda = LinearDiscriminantAnalysis()
lda.fit(X, y)
x_new = lda.transform(X)

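The columns of lda.scalings_ are the eigenvectors here. Because X was standardized (so its overall mean is zero), transforming the identity matrix returns the same values: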

print(lda.scalings_)
print(lda.transform(np.identity(4)))

Output:

[[-0.67614337  0.0271192 ]
 [-0.66890811  0.93115101]
 [ 3.84228173 -1.63586613]
 [ 2.17067434  2.13428251]]

[[-0.67614337  0.0271192 ]
 [-0.66890811  0.93115101]
 [ 3.84228173 -1.63586613]
 [ 2.17067434  2.13428251]]
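The two printouts agree only because the data were centered first. With the default svd solver, transform subtracts the overall mean (lda.xbar_) before multiplying by scalings_, so on unscaled data the identity trick returns a shifted matrix. A small sketch of that check follows; the names X_raw, lda_raw, probe and offset are only illustrative.

```python
import numpy as np
from sklearn import datasets
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Deliberately skip StandardScaler so the overall mean is not zero.
X_raw, y_raw = datasets.load_iris(return_X_y=True)
lda_raw = LinearDiscriminantAnalysis(solver="svd").fit(X_raw, y_raw)

probe = lda_raw.transform(np.identity(4))
# Every row of the probe is shifted by the projected overall mean.
offset = lda_raw.xbar_ @ lda_raw.scalings_

equals_scalings = np.allclose(probe, lda_raw.scalings_)
equals_shifted = np.allclose(probe, lda_raw.scalings_ - offset)
print(equals_scalings, equals_shifted)
```

Reading lda.scalings_ directly avoids this pitfall, so it is the safer way to get the directions.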

def myplot(score, coeff, labels=None):
    # score: projected samples, shape (n_samples, 2)
    # coeff: one row per original feature, e.g. lda.scalings_
    xs = score[:, 0]
    ys = score[:, 1]
    n = coeff.shape[0]

    plt.scatter(xs, ys, c=y)  # colored by the global class labels y; without scaling
    for i in range(n):
        plt.arrow(0, 0, coeff[i, 0], coeff[i, 1], color='r', alpha=0.5)
        label = "Var" + str(i + 1) if labels is None else labels[i]
        plt.text(coeff[i, 0] * 1.15, coeff[i, 1] * 1.15, label, color='g', ha='center', va='center')

    plt.xlabel("LD{}".format(1))
    plt.ylabel("LD{}".format(2))
    plt.grid()

# Call the function.
myplot(x_new[:, 0:2], lda.scalings_)
plt.show()

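Results:

[image: scatter plot of the iris samples on LD1 and LD2, with one red arrow per feature drawn from lda.scalings_]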
