当前位置 博文首页 > boysoft2002的专栏:南大《探索数据的奥秘》课件示例代码笔记17

    boysoft2002的专栏:南大《探索数据的奥秘》课件示例代码笔记17

    作者:[db:作者] 时间:2021-06-11 21:16

    In [6]: import numpy as np
    import pandas as pd
    from matplotlib import pyplot as plt
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.model_selection import cross_val_score
    df=pd.read_csv('C:\Python\Scripts\my_data\iris.csv',header=None,
    names=['sepal_length','sepal_width','petal_length','petal_width','
    target'])
    my_data=df[['sepal_length','sepal_width']]
    my_class=[]
    for n in range(150):
    if n<50:
    my_class.append(1)
    elif n<100:
    my_class.append(2)
    else:
    my_class.append(3)
    k_range=range(1,30)
    errors=[]
    for k in k_range:
    knn=KNeighborsClassifier(n_neighbors=k)
    scores=cross_val_score(knn,df[['sepal_length','sepal_width']],my_class,cv=5,
    scoring='accuracy')
    accuracy=np.mean(scores)
    error=1-accuracy
    errors.append(error)
    plt.figure()
    plt.plot(k_range,errors) # 从图看 KNN 中近邻数对 error 的影响
    plt.xlabel('k')
    plt.ylabel('error rates')
    Out[6]: Text(0,0.5,'error rates')

    In [34]: from sklearn.model_selection import GridSearchCV
    knn=KNeighborsClassifier()
    k_range=range(1,30)
    param_grid=dict(n_neighbors=k_range)
    grid=GridSearchCV(knn,param_grid,cv=5,scoring='accuracy')
    grid.fit(my_data[['sepal_length','sepal_width']],my_class)
    Out[34]: GridSearchCV(cv=5, error_score='raise-deprecating',
    estimator=KNeighborsClassifier(algorithm='auto', leaf_size=30,
    metric='minkowski',
    metric_params=None, n_jobs=None,
    n_neighbors=5, p=2,
    weights='uniform'),
    iid='warn', n_jobs=None, param_grid={'n_neighbors': range(1, 30)},
    pre_dispatch='2*n_jobs', refit=True, return_train_score=False,
    scoring='accuracy', verbose=0)
    
    In [42]: pd.DataFrame(grid.cv_results_).head(29)
    Out[42]: mean_fit_time std_fit_time mean_score_time std_score_time \
    0 0.001196 4.007471e-04 0.001199 3.980875e-04
    1 0.001396 4.974370e-04 0.001589 8.019150e-04
    2 0.001600 8.093282e-04 0.000797 3.986157e-04
    3 0.001784 3.930052e-04 0.001216 4.061510e-04
    4 0.001203 4.268555e-04 0.001196 3.978265e-04
    5 0.000999 2.887130e-06 0.000995 3.635908e-05
    6 0.000985 2.348654e-05 0.000598 4.881907e-04
    7 0.000997 2.081404e-05 0.001004 1.305838e-05
    8 0.000997 2.431402e-07 0.000998 2.138815e-05
    9 0.000798 3.992108e-04 0.000798 3.987560e-04
    10 0.000990 1.368896e-05 0.001005 1.342667e-05
    11 0.000798 3.991871e-04 0.000598 4.884969e-04
    12 0.000998 1.723224e-06 0.000798 3.990436e-04
    13 0.000997 1.144409e-06 0.000997 2.126630e-05
    14 0.000997 6.217196e-07 0.000605 4.942987e-04
    15 0.000998 9.608003e-07 0.000990 1.295875e-05
    16 0.000798 3.987558e-04 0.000805 4.024717e-04
    17 0.000990 1.435953e-05 0.000799 3.992572e-04
    18 0.000805 4.028115e-04 0.000990 1.312126e-05
    19 0.000997 2.048736e-06 0.000799 3.994484e-04
    20 0.000998 2.183664e-05 0.000997 8.064048e-07
    21 0.000798 3.988554e-04 0.000798 3.990891e-04
    22 0.000997 5.091228e-07 0.000997 2.112017e-05
    23 0.000798 3.988763e-04 0.001010 2.363564e-05
    24 0.000997 3.021809e-06 0.000997 1.843085e-06
    25 0.000982 4.103033e-05 0.001012 2.276996e-05
    26 0.001005 1.443124e-05 0.000985 2.403921e-05
    27 0.000997 3.173744e-06 0.000997 1.256174e-06
    28 0.000791 3.956460e-04 0.001012 1.680579e-05
    param_n_neighbors params split0_test_score \
    0 1 {'n_neighbors': 1} 0.733333
    1 2 {'n_neighbors': 2} 0.700000
    2 3 {'n_neighbors': 3} 0.666667
    3 4 {'n_neighbors': 4} 0.666667
    4 5 {'n_neighbors': 5} 0.700000
    5 6 {'n_neighbors': 6} 0.733333
    6 7 {'n_neighbors': 7} 0.733333
    7 8 {'n_neighbors': 8} 0.733333
    8 9 {'n_neighbors': 9} 0.733333
    9 10 {'n_neighbors': 10} 0.700000
    10 11 {'n_neighbors': 11} 0.733333
    11 12 {'n_neighbors': 12} 0.766667
    12 13 {'n_neighbors': 13} 0.733333
    13 14 {'n_neighbors': 14} 0.733333
    14 15 {'n_neighbors': 15} 0.733333
    15 16 {'n_neighbors': 16} 0.733333
    16 17 {'n_neighbors': 17} 0.733333
    17 18 {'n_neighbors': 18} 0.733333
    18 19 {'n_neighbors': 19} 0.733333
    19 20 {'n_neighbors': 20} 0.733333
    20 21 {'n_neighbors': 21} 0.733333
    21 22 {'n_neighbors': 22} 0.733333
    22 23 {'n_neighbors': 23} 0.733333
    23 24 {'n_neighbors': 24} 0.700000
    24 25 {'n_neighbors': 25} 0.700000
    25 26 {'n_neighbors': 26} 0.700000
    26 27 {'n_neighbors': 27} 0.733333
    27 28 {'n_neighbors': 28} 0.700000
    28 29 {'n_neighbors': 29} 0.733333
    split1_test_score split2_test_score split3_test_score \
    0 0.733333 0.666667 0.833333
    1 0.733333 0.666667 0.766667
    2 0.800000 0.633333 0.866667
    3 0.800000 0.733333 0.800000
    4 0.766667 0.733333 0.866667
    5 0.866667 0.833333 0.900000
    6 0.833333 0.800000 0.866667
    7 0.800000 0.766667 0.866667
    8 0.766667 0.766667 0.866667
    9 0.766667 0.800000 0.833333
    10 0.766667 0.733333 0.833333
    11 0.800000 0.700000 0.833333
    12 0.766667 0.733333 0.833333
    13 0.733333 0.700000 0.866667
    14 0.800000 0.733333 0.866667
    15 0.833333 0.766667 0.900000
    16 0.800000 0.766667 0.933333
    17 0.800000 0.766667 0.866667
    18 0.766667 0.766667 0.866667
    19 0.833333 0.766667 0.800000
    20 0.800000 0.766667 0.866667
    21 0.833333 0.833333 0.833333
    22 0.833333 0.800000 0.866667
    23 0.800000 0.833333 0.833333
    24 0.833333 0.800000 0.833333
    25 0.833333 0.800000 0.833333
    26 0.833333 0.733333 0.866667
    27 0.833333 0.733333 0.833333
    28 0.800000 0.733333 0.866667
    split4_test_score mean_test_score std_test_score rank_test_score
    0 0.666667 0.726667 0.061101 26
    1 0.633333 0.700000 0.047140 29
    2 0.666667 0.726667 0.090431 26
    3 0.600000 0.720000 0.077746 28
    4 0.766667 0.766667 0.055777 23
    5 0.733333 0.813333 0.068638 2
    6 0.733333 0.793333 0.053333 8
    7 0.700000 0.773333 0.057349 19
    8 0.733333 0.773333 0.048990 19
    9 0.733333 0.766667 0.047140 23
    10 0.800000 0.773333 0.038873 19
    11 0.766667 0.773333 0.044222 19
    12 0.833333 0.780000 0.045216 18
    13 0.733333 0.753333 0.058119 25
    14 0.800000 0.786667 0.049889 14
    15 0.766667 0.800000 0.059628 5
    16 0.866667 0.820000 0.071802 1
    17 0.800000 0.793333 0.044222 8
    18 0.833333 0.793333 0.048990 8
    19 0.800000 0.786667 0.033993 14
    20 0.766667 0.786667 0.045216 14
    21 0.733333 0.793333 0.048990 8
    22 0.833333 0.813333 0.045216 2
    23 0.833333 0.800000 0.051640 5
    24 0.866667 0.806667 0.057349 4
    25 0.800000 0.793333 0.048990 8
    26 0.833333 0.800000 0.055777 5
    27 0.833333 0.786667 0.058119 14
    28 0.833333 0.793333 0.053333 8
    
    In [35]: print(grid.cv_results_)
    {'mean_fit_time': array([0.001196 , 0.00139551, 0.00159965, 0.00178361, 0.00120277,
    0.00099864, 0.0009851 , 0.0009975 , 0.0009974 , 0.00079842,
    0.00098987, 0.00079837, 0.00099764, 0.00099716, 0.0009973 ,
    0.00099835, 0.00079751, 0.00099001, 0.00080519, 0.00099697,
    0.00099769, 0.0007977 , 0.00099735, 0.00079775, 0.00099697,
    0.00098243, 0.00100498, 0.0009973 , 0.00079083]), 'std_fit_time':
    array([4.00747083e-04, 4.97437049e-04, 8.09328245e-04, 3.93005211e-04,
    4.26855489e-04, 2.88712988e-06, 2.34865415e-05, 2.08140375e-05,
    2.43140197e-07, 3.99210774e-04, 1.36889642e-05, 3.99187143e-04,
    1.72322378e-06, 1.14440918e-06, 6.21719590e-07, 9.60800251e-07,
    3.98755797e-04, 1.43595296e-05, 4.02811513e-04, 2.04873572e-06,
    2.18366430e-05, 3.98855440e-04, 5.09122765e-07, 3.98876264e-04,
    3.02180853e-06, 4.10303343e-05, 1.44312385e-05, 3.17374445e-06,
    3.95645988e-04]), 'mean_score_time': array([0.00119882, 0.00158925, 0.00079722,
    0.0012157 , 0.00119586,
    0.00099535, 0.00059791, 0.00100389, 0.00099802, 0.00079751,
    0.0010046 , 0.00059767, 0.00079808, 0.00099688, 0.0006052 ,
    0.00098953, 0.00080452, 0.00079851, 0.00098991, 0.00079889,
    0.00099697, 0.00079818, 0.0009973 , 0.00100989, 0.00099688,
    0.00101218, 0.000985 , 0.00099711, 0.00101242]), 'std_score_time':
    array([3.98087535e-04, 8.01915004e-04, 3.98615655e-04, 4.06150958e-04,
    3.97826513e-04, 3.63590756e-05, 4.88190680e-04, 1.30583752e-05,
    2.13881532e-05, 3.98755968e-04, 1.34266709e-05, 4.88496859e-04,
    3.99043637e-04, 2.12662958e-05, 4.94298713e-04, 1.29587461e-05,
    4.02471736e-04, 3.99257191e-04, 1.31212555e-05, 3.99448380e-04,
    8.06404806e-07, 3.99089127e-04, 2.11201723e-05, 2.36356400e-05,
    1.84308511e-06, 2.27699604e-05, 2.40392146e-05, 1.25617408e-06,
    1.68057872e-05]), 'param_n_neighbors': masked_array(data=[1, 2, 3, 4, 5,
    6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
    mask=[False, False, False, False, False, False, False, False,
    False, False, False, False, False, False, False, False,
    False, False, False, False, False, False, False, False,
    False, False, False, False, False],
    fill_value='?',
    dtype=object), 'params': [{'n_neighbors': 1}, {'n_neighbors': 2},
    {'n_neighbors': 3}, {'n_neighbors': 4}, {'n_neighbors': 5},
    {'n_neighbors': 6}, {'n_neighbors': 7}, {'n_neighbors': 8},
    {'n_neighbors': 9}, {'n_neighbors': 10}, {'n_neighbors': 11},
    {'n_neighbors': 12}, {'n_neighbors': 13}, {'n_neighbors': 14},
    {'n_neighbors': 15}, {'n_neighbors': 16}, {'n_neighbors': 17},
    {'n_neighbors': 18}, {'n_neighbors': 19}, {'n_neighbors': 20},
    {'n_neighbors': 21}, {'n_neighbors': 22}, {'n_neighbors': 23},
    {'n_neighbors': 24}, {'n_neighbors': 25}, {'n_neighbors': 26},
    {'n_neighbors': 27}, {'n_neighbors': 28}, {'n_neighbors': 29}],
    'split0_test_score':
    array([0.73333333, 0.7 , 0.66666667, 0.66666667, 0.7 ,
    0.73333333, 0.73333333, 0.73333333, 0.73333333, 0.7 ,
    0.73333333, 0.76666667, 0.73333333, 0.73333333, 0.73333333,
    0.73333333, 0.73333333, 0.73333333, 0.73333333, 0.73333333,
    0.73333333, 0.73333333, 0.73333333, 0.7 , 0.7 ,
    0.7 , 0.73333333, 0.7 , 0.73333333]), 'split1_test_score':
    array([0.73333333, 0.73333333, 0.8 , 0.8 , 0.76666667,
    0.86666667, 0.83333333, 0.8 , 0.76666667, 0.76666667,
    0.76666667, 0.8 , 0.76666667, 0.73333333, 0.8 ,
    0.83333333, 0.8 , 0.8 , 0.76666667, 0.83333333,
    0.8 , 0.83333333, 0.83333333, 0.8 , 0.83333333,
    0.83333333, 0.83333333, 0.83333333, 0.8 ]), 'split2_test_score':
    array([0.66666667, 0.66666667, 0.63333333, 0.73333333, 0.73333333,
    0.83333333, 0.8 , 0.76666667, 0.76666667, 0.8 ,
    0.73333333, 0.7 , 0.73333333, 0.7 , 0.73333333,
    0.76666667, 0.76666667, 0.76666667, 0.76666667, 0.76666667,
    0.76666667, 0.83333333, 0.8 , 0.83333333, 0.8 ,
    0.8 , 0.73333333, 0.73333333, 0.73333333]), 'split3_test_score':
    array([0.83333333, 0.76666667, 0.86666667, 0.8 , 0.86666667,
    0.9 , 0.86666667, 0.86666667, 0.86666667, 0.83333333,
    0.83333333, 0.83333333, 0.83333333, 0.86666667, 0.86666667,
    0.9 , 0.93333333, 0.86666667, 0.86666667, 0.8 ,
    0.86666667, 0.83333333, 0.86666667, 0.83333333, 0.83333333,
    0.83333333, 0.86666667, 0.83333333, 0.86666667]), 'split4_test_score':
    array([0.66666667, 0.63333333, 0.66666667, 0.6 , 0.76666667,
    0.73333333, 0.73333333, 0.7 , 0.73333333, 0.73333333,
    0.8 , 0.76666667, 0.83333333, 0.73333333, 0.8 ,
    0.76666667, 0.86666667, 0.8 , 0.83333333, 0.8 ,
    0.76666667, 0.73333333, 0.83333333, 0.83333333, 0.86666667,
    0.8 , 0.83333333, 0.83333333, 0.83333333]), 'mean_test_score':
    array([0.72666667, 0.7 , 0.72666667, 0.72 , 0.76666667,
    0.81333333, 0.79333333, 0.77333333, 0.77333333, 0.76666667,
    0.77333333, 0.77333333, 0.78 , 0.75333333, 0.78666667,
    0.8 , 0.82 , 0.79333333, 0.79333333, 0.78666667,
    0.78666667, 0.79333333, 0.81333333, 0.8 , 0.80666667,
    0.79333333, 0.8 , 0.78666667, 0.79333333]), 'std_test_score':
    array([0.06110101, 0.04714045, 0.09043107, 0.07774603, 0.05577734,
    0.06863753, 0.05333333, 0.05734884, 0.04898979, 0.04714045,
    0.03887301, 0.04422166, 0.04521553, 0.05811865, 0.04988877,
    0.05962848, 0.0718022 , 0.04422166, 0.04898979, 0.03399346,
    0.04521553, 0.04898979, 0.04521553, 0.05163978, 0.05734884,
    0.04898979, 0.05577734, 0.05811865, 0.05333333]), 'rank_test_score':
    array([26, 29, 26, 28, 23, 2, 8, 19, 19, 23, 19, 19, 18, 25, 14, 5, 1,
    8, 8, 14, 14, 8, 2, 5, 4, 8, 5, 14, 8])}
    
    In [39]: grid_mean_scores= grid.cv_results_['mean_test_score']
    print(grid_mean_scores,'\n')
    plt.figure()
    plt.xlabel('Tuning Parameter: N nearest neighbors')
    plt.ylabel('Classification Accuracy')
    plt.plot(k_range,grid_mean_scores)
    print('最高得分是近邻值取 k =',grid.best_params_['n_neighbors'],'时的得分'
    ,grid.best_score_)
    plt.plot(grid.best_params_['n_neighbors'],grid.best_score_,'ro',
    markersize=12,markeredgewidth=1.5,
    markerfacecolor='None',markeredgecolor='r')
    [0.72666667 0.7 0.72666667 0.72 0.76666667 0.81333333
    0.79333333 0.77333333 0.77333333 0.76666667 0.77333333 0.77333333
    0.78 0.75333333 0.78666667 0.8 0.82 0.79333333
    0.79333333 0.78666667 0.78666667 0.79333333 0.81333333 0.8
    0.80666667 0.79333333 0.8 0.78666667 0.79333333]
    最高得分是近邻值取 k = 17 时的得分 0.82
    Out[39]: [<matplotlib.lines.Line2D at 0x2042e7b5518>]

    In [43]: print(grid.best_params_)
    print(grid.best_score_)
    print(grid.best_estimator_)
    #p=2 欧氏距离, weights='uniform' 或者'distance'(反距离加权)
    {'n_neighbors': 17}
    0.82
    KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
    metric_params=None, n_jobs=None, n_neighbors=17, p=2,
    weights='uniform')
    
    In [47]: knn=KNeighborsClassifier()
    k_range=range(1,30)
    algorithm_opt=['kd_tree','ball_tree']
    p_range=range(1,5)
    weight_range=['uniform','distance']
    param_grid=dict(n_neighbors=k_range,weights=weight_range,
    algorithm=algorithm_opt,p=p_range)
    print(param_grid)
    {'n_neighbors': range(1, 30), 'weights': ['uniform', 'distance'], 'algorithm':
    ['kd_tree', 'ball_tree'], 'p': range(1, 5)}
    
    In [48]: grid=GridSearchCV(knn,param_grid,cv=5,scoring='accuracy')
    grid.fit(my_data[['sepal_length','sepal_width']],my_class)
    print(grid.best_score_)
    print(grid.best_estimator_)
    0.82
    KNeighborsClassifier(algorithm='kd_tree', leaf_size=30, metric='minkowski',
    metric_params=None, n_jobs=None, n_neighbors=17, p=2,
    weights='uniform')

    ?