博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
吴裕雄--天生自然 人工智能机器学习实战代码:线性判断分析LINEARDISCRIMINANTANALYSIS...
阅读量:4337 次
发布时间:2019-06-07

本文共 3296 字,大约阅读时间需要 10 分钟。

import numpy as npimport matplotlib.pyplot as pltfrom matplotlib import cmfrom mpl_toolkits.mplot3d import Axes3Dfrom sklearn.model_selection import train_test_splitfrom sklearn import datasets, linear_model,discriminant_analysisdef load_data():    # 使用 scikit-learn 自带的 iris 数据集    iris=datasets.load_iris()    X_train=iris.data    y_train=iris.target    return train_test_split(X_train, y_train,test_size=0.25,random_state=0,stratify=y_train)#线性判断分析LinearDiscriminantAnalysisdef test_LinearDiscriminantAnalysis(*data):    X_train,X_test,y_train,y_test=data    lda = discriminant_analysis.LinearDiscriminantAnalysis()    lda.fit(X_train, y_train)    print('Coefficients:%s, intercept %s'%(lda.coef_,lda.intercept_))    print('Score: %.2f' % lda.score(X_test, y_test))    # 产生用于分类的数据集X_train,X_test,y_train,y_test=load_data()# 调用 test_LinearDiscriminantAnalysistest_LinearDiscriminantAnalysis(X_train,X_test,y_train,y_test)

def plot_LDA(converted_X,y):    '''    绘制经过 LDA 转换后的数据    :param converted_X: 经过 LDA转换后的样本集    :param y: 样本集的标记    '''    fig=plt.figure()    ax=Axes3D(fig)    colors='rgb'    markers='o*s'    for target,color,marker in zip([0,1,2],colors,markers):        pos=(y==target).ravel()        X=converted_X[pos,:]        ax.scatter(X[:,0], X[:,1], X[:,2],color=color,marker=marker,label="Label %d"%target)    ax.legend(loc="best")    fig.suptitle("Iris After LDA")    plt.show()    def run_plot_LDA():    '''    执行 plot_LDA 。其中数据集来自于 load_data() 函数    '''    X_train,X_test,y_train,y_test=load_data()    X=np.vstack((X_train,X_test))    Y=np.vstack((y_train.reshape(y_train.size,1),y_test.reshape(y_test.size,1)))    lda = discriminant_analysis.LinearDiscriminantAnalysis()    lda.fit(X, Y)    converted_X=np.dot(X,np.transpose(lda.coef_))+lda.intercept_    plot_LDA(converted_X,Y)    # 调用 run_plot_LDArun_plot_LDA()

def test_LinearDiscriminantAnalysis_solver(*data):    '''    测试 LinearDiscriminantAnalysis 的预测性能随 solver 参数的影响    '''    X_train,X_test,y_train,y_test=data    solvers=['svd','lsqr','eigen']    for solver in solvers:        if(solver=='svd'):            lda = discriminant_analysis.LinearDiscriminantAnalysis(solver=solver)        else:            lda = discriminant_analysis.LinearDiscriminantAnalysis(solver=solver,shrinkage=None)        lda.fit(X_train, y_train)        print('Score at solver=%s: %.2f' %(solver, lda.score(X_test, y_test)))        # 调用 test_LinearDiscriminantAnalysis_solvertest_LinearDiscriminantAnalysis_solver(X_train,X_test,y_train,y_test)

def test_LinearDiscriminantAnalysis_shrinkage(*data):    '''    测试  LinearDiscriminantAnalysis 的预测性能随 shrinkage 参数的影响    '''    X_train,X_test,y_train,y_test=data    shrinkages=np.linspace(0.0,1.0,num=20)    scores=[]    for shrinkage in shrinkages:        lda = discriminant_analysis.LinearDiscriminantAnalysis(solver='lsqr',shrinkage=shrinkage)        lda.fit(X_train, y_train)        scores.append(lda.score(X_test, y_test))    ## 绘图    fig=plt.figure()    ax=fig.add_subplot(1,1,1)    ax.plot(shrinkages,scores)    ax.set_xlabel(r"shrinkage")    ax.set_ylabel(r"score")    ax.set_ylim(0,1.05)    ax.set_title("LinearDiscriminantAnalysis")    plt.show()# 调用 test_LinearDiscrtest_LinearDiscriminantAnalysis_shrinkage(X_train,X_test,y_train,y_test)

 

转载于:https://www.cnblogs.com/tszr/p/11177949.html

你可能感兴趣的文章
Ecust OJ
查看>>
P3384 【模板】树链剖分
查看>>
Thrift源码分析(二)-- 协议和编解码
查看>>
考勤系统之计算工作小时数
查看>>
4.1 分解条件式
查看>>
Equivalent Strings
查看>>
flume handler
查看>>
收藏其他博客园主写的代码,学习加自用。先表示感谢!!!
查看>>
H5 表单标签
查看>>
su 与 su - 区别
查看>>
C语言编程-9_4 字符统计
查看>>
在webconfig中写好连接后,在程序中如何调用?
查看>>
限制用户不能删除SharePoint列表中的条目(项目)
查看>>
feign调用spring clound eureka 注册中心服务
查看>>
ZT:Linux上安装JDK,最准确
查看>>
LimeJS指南3
查看>>
关于C++ const成员的一些细节
查看>>
《代码大全》学习摘要(五)软件构建中的设计(下)
查看>>
C#检测驱动是否安装的问题
查看>>
web-4. 装饰页面的图像
查看>>