
    shelgi's blog: templates for several kinds of matplotlib + seaborn visualizations

    Date: 2021-07-28 08:45

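    Whether you are doing data analysis or writing a paper or report, data visualization is an unavoidable step, yet plain, simple charts can look dull. This post therefore collects several commonly used kinds of plots, so that next time one is needed the code can be reused as a template.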

    2. Code and plots

    First, the basic libraries every example needs; any additional imports are added in the individual examples.

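    import pandas as pd
    import numpy as np
    import matplotlib
    import matplotlib.pyplot as plt
    import seaborn as sns
    from matplotlib import patches
    from scipy.spatial import ConvexHull  # used by the bubble plot with encircling
    # Use a CJK-capable font (SimHei) so Chinese labels render, and draw the
    # minus sign as an ASCII '-' because many CJK fonts lack the Unicode minus glyph
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False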
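SimHei is typically installed on Windows but often missing on Linux and macOS; there matplotlib falls back to another font (with a "findfont" warning) and Chinese characters may render as empty boxes. A minimal sketch, assuming you want to use the first installed font from a hand-made candidate list. The helper `pick_cjk_font` and the font names in `CJK_CANDIDATES` are hypothetical illustrations, not part of the original post:

```python
import matplotlib
from matplotlib import font_manager

# Hypothetical candidate list: common CJK fonts on Windows, macOS and Linux
CJK_CANDIDATES = ("SimHei", "Microsoft YaHei", "PingFang SC",
                  "Noto Sans CJK SC", "WenQuanYi Zen Hei")


def pick_cjk_font(candidates=CJK_CANDIDATES):
    """Return the first candidate family matplotlib has registered, or None."""
    installed = {entry.name for entry in font_manager.fontManager.ttflist}
    for name in candidates:
        if name in installed:
            return name
    return None


font = pick_cjk_font()
if font is not None:
    matplotlib.rcParams['font.sans-serif'] = [font]
# ASCII hyphen for negative numbers, as in the original setup
matplotlib.rcParams['axes.unicode_minus'] = False
print("CJK font in use:", font)
```

If `font` comes back as `None`, none of the candidates is installed and a CJK font has to be added to the system first.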
    

    Below are some reference settings for the figure style:

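    # Font sizes (in points) used by the rcParams below
    large = 22
    med = 16
    small = 12
    params = {'axes.titlesize': large,
              'legend.fontsize': med,
              'figure.figsize': (16, 10),
              'axes.labelsize': med,
              'xtick.labelsize': med,
              'ytick.labelsize': med,
              'figure.titlesize': large}
    plt.rcParams.update(params)
    # matplotlib >= 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*';
    # on older versions use 'seaborn-whitegrid' instead
    plt.style.use('seaborn-v0_8-whitegrid')
    sns.set_style("white")
    # Both style calls above set their own font.sans-serif list, so re-apply the CJK font
    plt.rcParams['font.sans-serif'] = ['SimHei']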
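`plt.rcParams.update` changes the defaults for the rest of the session. If the settings should only apply to some figures, matplotlib's `rc_context` accepts the same kind of dictionary and restores the previous values when the `with` block ends. A minimal sketch; the `params` dictionary here is a trimmed copy made for illustration:

```python
import matplotlib
from matplotlib.figure import Figure

large, med = 22, 16
params = {'axes.titlesize': large,
          'axes.labelsize': med,
          'figure.figsize': (16, 10)}

before = matplotlib.rcParams['axes.titlesize']

with matplotlib.rc_context(params):
    # Figures and axes created inside the block pick up the temporary settings
    fig = Figure()
    ax = fig.subplots()
    ax.set_title("Temporary style")
    size_inside = tuple(fig.get_size_inches())
    title_size_inside = ax.title.get_fontsize()

# On exit, the previous global value is restored
after = matplotlib.rcParams['axes.titlesize']
print(size_inside, title_size_inside, before == after)
```

The global approach in the post is simpler when every figure in a notebook should share one style; `rc_context` avoids leaking settings into unrelated plots.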
    
    • 1. Scatter plot
    midwest = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/midwest_filter.csv")
    # One color from the tab10 colormap per category
    categories = np.unique(midwest['category'])
    colors = [plt.cm.tab10(i/float(len(categories)-1)) for i in range(len(categories))]

    plt.figure(figsize=(16, 10), dpi=80, facecolor='w', edgecolor='k')

    # Draw each category as its own scatter layer so it gets its own legend entry;
    # color= (rather than c=) marks the RGBA tuple as one color, not values to colormap
    for i, category in enumerate(categories):
        plt.scatter('area', 'poptotal',
                    data=midwest.loc[midwest.category==category, :],
                    s=20, color=colors[i], label=str(category))

    plt.gca().set(xlim=(0.0, 0.1), ylim=(0, 90000),
                  xlabel='Area', ylabel='Population')
    plt.xticks(fontsize=12); plt.yticks(fontsize=12)
    plt.title("Scatterplot of Midwest Area vs Population", fontsize=22)
    plt.legend(fontsize=12)
    plt.show()
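The loop above filters the whole DataFrame once per category with a boolean mask. A `groupby` loop produces the same one-layer-per-category plot and makes the grouping explicit. A minimal sketch on synthetic data: random values under the same column names `area`, `poptotal` and `category`, with made-up category labels, so it runs without downloading the CSV:

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Synthetic stand-in for midwest_filter.csv: random values, made-up categories
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "area": rng.uniform(0.0, 0.1, 60),
    "poptotal": rng.uniform(0, 90000, 60),
    "category": rng.choice(["A", "B", "C"], 60),
})

fig, ax = plt.subplots(figsize=(8, 5))
# groupby yields (key, sub-DataFrame) pairs in sorted key order;
# tab10 indexed by an integer gives its i-th distinct color
for i, (category, group) in enumerate(df.groupby("category")):
    ax.scatter(group["area"], group["poptotal"], s=20,
               color=plt.cm.tab10(i), label=str(category))

ax.set(xlim=(0.0, 0.1), ylim=(0, 90000), xlabel="Area", ylabel="Population")
ax.legend()
plt.show()
```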
    
    


    • 2. Bubble plot with encircling
    # 2. Bubble plot with encircling
    midwest = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/midwest_filter.csv")
    categories = np.unique(midwest['category'])
    colors = [plt.cm.tab10(i/float(len(categories)-1)) for i in range(len(categories))]
    fig = plt.figure(figsize=(16, 10), dpi=80, facecolor='w', edgecolor='k')
    # Marker size is read from the 'dot_size' column of the data (s='dot_size')
    for i, category in enumerate(categories):
        plt.scatter('area', 'poptotal', data=midwest.loc[midwest.category==category, :],
                    s='dot_size', color=colors[i], label=str(category),
                    edgecolors='black', linewidths=.5)

    def encircle(x, y, ax=None, **kw):
        """Draw the convex hull of the points (x, y) as a polygon patch on ax."""
        if ax is None:
            ax = plt.gca()
        p = np.c_[x, y]          # stack into an (n, 2) array of points
        hull = ConvexHull(p)     # hull.vertices: indices of the boundary points
        poly = plt.Polygon(p[hull.vertices, :], **kw)
        ax.add_patch(poly)

    # Encircle the counties of Indiana (state == 'IN')
    midwest_encircle_data = midwest.loc[midwest.state=='IN', :]

    # First call: translucent gold fill; second call: firebrick outline only
    encircle(midwest_encircle_data.area, midwest_encircle_data.poptotal, ec="k", fc="gold", alpha=0.1)
    encircle(midwest_encircle_data.area, midwest_encircle_data.poptotal, ec="firebrick", fc="none", linewidth=1.5)

    plt.gca().set(xlim=(0.0, 0.1), ylim=(0, 90000),
                  xlabel='Area', ylabel='Population')

    plt.xticks(fontsize=12); plt.yticks(fontsize=12)
    plt.title("Bubble Plot with Encircling", fontsize=22)
    plt.legend(fontsize=12)
    plt.show()
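The `encircle` helper relies on `scipy.spatial.ConvexHull`: `hull.vertices` holds the indices of the points on the hull boundary (in counter-clockwise order for 2-D input), so interior points are left out of the polygon. A small check on made-up points, the four corners of a unit square plus its center:

```python
import numpy as np
from scipy.spatial import ConvexHull

# Made-up points: the corners of a unit square plus one interior point
x = np.array([0.0, 1.0, 1.0, 0.0, 0.5])
y = np.array([0.0, 0.0, 1.0, 1.0, 0.5])

p = np.c_[x, y]                # same (n, 2) stacking as in encircle()
hull = ConvexHull(p)
corners = p[hull.vertices, :]  # the polygon that encircle() would draw

print(hull.vertices)           # indices of the boundary points only
# For 2-D input, .volume is the enclosed area and .area is the perimeter
print(hull.volume, hull.area)
```

The center point (index 4) never appears in `hull.vertices`, which is why the drawn polygon hugs only the outermost points of the selected group.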
    


    • 3. Scatter plot with a line of best fit
    # 3. Scatter plot with a line of best fit
    df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/mpg_ggplot2.csv")
    # Keep only 4- and 8-cylinder cars
    df_select = df.loc[df.cyl.isin([4,8]), :]
    sns.set_style("white")
    # One regression line per cylinder group (hue="cyl");
    # robust=True fits a robust regression and requires statsmodels
    gridobj = sns.lmplot(x="displ", y="hwy", hue="cyl", data=df_select,
                         height=7, aspect=1.6, robust=True, palette='tab10',
                         scatter_kws=dict(s=60, linewidths=.7, edgecolors='black'))

    gridobj.set(xlim=(0.5, 7.5), ylim=(0, 50))
    plt.title("Scatterplot with line of best fit grouped by number of cylinders", fontsize=20)
    plt.show()

    # Same fit, but each cylinder group in its own column (col="cyl")
    sns.set_style("white")
    gridobj = sns.lmplot(x="displ", y="hwy",
                         data=df_select,
                         height=7,
                         robust=True,
                         palette='Set1',
                         col="cyl",
                         scatter_kws=dict(s=60, linewidths=.7, edgecolors='black'))

    gridobj.set(xlim=(0.5, 7.5), ylim=(0, 50))
    plt.show()
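`lmplot` fits one regression line per `hue` (or `col`) level. With `robust=True` seaborn uses statsmodels to down-weight outliers, so its line can differ from ordinary least squares. To see the slope and intercept of a plain (non-robust) fit, a per-group `numpy.polyfit` is a minimal sketch; this swaps in a different tool for illustration and is not what seaborn calls internally. The data is made up so that each group lies exactly on a known line:

```python
import numpy as np
import pandas as pd

# Made-up stand-in for the mpg data: each cylinder group follows an exact line
df_select = pd.DataFrame({
    "displ": [1.8, 2.0, 2.4, 2.8, 4.6, 5.3, 5.7, 6.2],
    "cyl":   [4, 4, 4, 4, 8, 8, 8, 8],
})
df_select["hwy"] = np.where(df_select["cyl"] == 4,
                            40 - 5 * df_select["displ"],
                            30 - 2 * df_select["displ"])

fits = {}
for cyl, group in df_select.groupby("cyl"):
    # A degree-1 polyfit returns [slope, intercept] of the least-squares line
    slope, intercept = np.polyfit(group["displ"], group["hwy"], deg=1)
    fits[cyl] = (slope, intercept)
    print(f"cyl={cyl}: hwy = {slope:.2f} * displ + {intercept:.2f}")
```

On real data with outliers, comparing these numbers against the robust line drawn by `lmplot` shows how much `robust=True` changes the fit.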
    


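    • 4. Jitter plot and counts plot
    # 4. Jitter plot and counts plot
    df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/mpg_ggplot2.csv")
    fig, ax = plt.subplots(figsize=(16,10), dpi=80)
    # x and y must be passed as keywords in seaborn >= 0.12; jitter=0.25 spreads
    # overlapping points sideways so repeated (cty, hwy) pairs stay visible
    sns.stripplot(x=df.cty, y=df.hwy, jitter=0.25, size=8, ax=ax, linewidth=.5)
    plt.title('Use jittered plots to avoid overlapping of points', fontsize=22)
    plt.show()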
    
    df_counts = df.groupby(['hwy', 'cty']).size().reset_index(name='counts')
    fig, ax = plt.subplots(