最近频繁用到matplotlib绘图,梳理了下官网的tutorial,记录下学习笔记。主要是对下面链接的翻译和个人理解整理。

https://matplotlib.org/3.5.0/tutorials/introductory/usage.html

1. 基础知识

matplotlib绘图中包括两个概念:figure, axes. 其中figure表示一张图,axes对应一个绘图区域,一个figure中可以包括多个axes, 如下代码:

fig = plt.figure()  # an empty figure with no Axes
fig, ax = plt.subplots()  # a figure with a single Axes
fig, axs = plt.subplots(2, 2)  # a figure with a 2x2 grid of Axes

一个figure可以包含的组成成分如下:

注意:matplotlib绘制图片时,接受的数据最好是numpy格式(list和pandas等array-like数据,可能会出现异常情况)

1.1 面向对象接口和pyplot接口

matplotlib提供了两种形式的接口供调用:

  • 面向对象接口(object-oriented interface): 显示的创建figure对象和axes对象,调用figure和axes对象的方法
  • pyplot接口(pyplot interface):依靠pyplot来创建和管理figure和axes对象,调用pyplot函数
  • 面向对象接口风格代码如下:

    x = np.linspace(0, 2, 100)  # Sample data.
    # Note that even in the OO-style, we use `.pyplot.figure` to create the figure.
    fig, ax = plt.subplots()  # Create a figure and an axes.
    ax.plot(x, x, label='linear')  # Plot some data on the axes.
    ax.plot(x, x**2, label='quadratic')  # Plot more data on the axes...
    ax.plot(x, x**3, label='cubic')  # ... and some more.
    ax.set_xlabel('x label')  # Add an x-label to the axes.
    ax.set_ylabel('y label')  # Add a y-label to the axes.
    ax.set_title("Simple Plot")  # Add a title to the axes.
    ax.legend()  # Add a legend.
    

    pyplot接口风格代码:

    x = np.linspace(0, 2, 100)  # Sample data.
    plt.plot(x, x, label='linear')  # Plot some data on the (implicit) axes.
    plt.plot(x, x**2, label='quadratic')  # etc.
    plt.plot(x, x**3, label='cubic')
    plt.xlabel('x label')
    plt.ylabel('y label')
    plt.title("Simple Plot")
    plt.legend()
    

    2. pyplot接口

    matplotlib.pyplot 是一个使matplotlib像MATLAB一样工作的绘图接口(很多函数的集合),pyplot会自动追踪当前figure和axes, 其调用函数也是作用于当前axes。

    import matplotlib.pyplot as plt
    import numpy as np
    # 1. 定义一个图形窗口
    plt.figure()
    # 2. 绘制图形
    plt.plot([1, 3, 5, 4])   # [1, 3, 5, 4]会被当作y,x会被自动设置成[0, 1, 2,3],(x从0开始递增)
    plt.ylabel('some label')
    plt.plot([1, 2, 3, 4], [1, 3, 5, 4])   #x=[1, 2, 3, 4], y=[1, 3, 5, 4]
    # 3. 显示绘制图形
    plt.show()
    

    2.1 控制plot曲线的格式(style)

    对于每一组x,y, plot函数接受一个字符串参数fmt,设置绘制曲线的格式,其中fmt格式如下:

    fmt = '[marker][line][color]'     
    或者fmt = '[color][marker][line]'
    fmt默认设置:'b-
    

    marker:

    character description

    所有marker, line, color参考如下链接:https://matplotlib.org/3.5.0/api/_as_gen/matplotlib.pyplot.plot.html#matplotlib.pyplot.plot

    常用fmt示例:

    'b'    # blue markers with default shape
    'or'   # red circles
    '-g'   # green solid line
    '--'   # dashed line with default color
    '^k:'  # black triangle_up markers connected by a dotted line
    
    import matplotlib.pyplot as plt
    plt.figure()
    # 绘制红色的圆圈
    plt.plot([1, 2, 3, 4], [1, 4, 9, 16], 'ro')
    plt.show()
    
    import matplotlib.pyplot as plt
    import numpy as np
    plt.figure()
    t = np.arange(0., 5., 0.2)
    # 'r--'红色虚线; 'bs':蓝色的方框('s':square marker); 'g^':绿色上三角形('^': 上三角形marker)
    plt.plot(t, t, 'r--', t, t**2, 'bs', t, t**3, 'g^')
    plt.show()
    

    2.3 plot接受字典和字符串形式输入

    plt.scatter()函数参数data,可以接受一个字典输入,根据字典的key来索引字典中的数据,如下代码:

    import matplotlib.pyplot as plt
    import numpy as np
    plt.figure()
    data = {'a': np.arange(50),
            'c': np.random.randint(0, 50, 50),
            'd': np.random.randn(50),
    data['b'] = data['a']+10*np.random.randn(50)
    data['d'] = np.abs(data['d'])*100
    # 采用data['a'],data['b']表示每个点的x,y
    # c='c', 表示采用data['c']的值设置每个点的颜色
    # s='d', 表示采用data['d']的值设置每个点的大小
    plt.scatter('a', 'b', c='c', s='d', data=data)
    plt.xlabel('entry a')
    plt.ylabel('entry b')
    plt.show()
    

    plt.plot()的横坐标可以接收字符串形式的输入,如下:

    import matplotlib.pyplot as plt
    import numpy as np
    names = ['group_a', 'group_b', 'group_c']
    values = [1, 10, 100]
    plt.figure(figsize=(9, 3))
    # 131: 表示有1行3列,共3个子图,在3个子图的第一个子图中绘制
    plt.subplot(131)
    plt.bar(names, values)
    plt.subplot(132)
    plt.scatter(names, values)
    plt.subplot(133)
    plt.plot(names, values)
    plt.show()
    

    2.4 设置line的属性

    直线有很多属性可以设置,线的宽度,颜色,风格等,有下面三种设置方式,如下面代码:

    import matplotlib.pyplot as plt
    import numpy as np
    plt.figure()
    x1 = np.arange(10)
    x2 = np.arange(10)+np.random.rand(10)
    y1 = x1**2
    y2 = x2+np.random.randint(3, 10)
    # 方式一:通过关键字参数设置
    # plt.plot(x1, y1, linewidth=5.0)
    # 方式二:通过返回的line2D对象设置  (返回包含line2D对象的列表,有两条直线,所以列表里有两个对象)
    # line1, line2 = plt.plot(x1, y1, '-', x2, y2)
    # line1.set_antialiased(False)  # turn off antialiasing
    lines = plt.plot(x1, y1, x2, y2)
    # 方式三:通过plt.setp设置
    plt.setp(lines, color='r', linewidth=5.0)
    plt.setp(lines)  # plt.setp(lines):查看所有可以设置的属性名称
    plt.show()
    

    2.5 多个figure和多个axes

    pyplot中有figure和axes的概念,可以有多个figure,每个figure可以有多个axes,pyplot.gca()返回当前的axes对象( matplotlib.axes.Axes ), pyplot.gcf()返回当前的figure对象( matplotlib.figure.Figure ),pyplot总是绘制在当前axes。

    下面代码中利用plot.subplot()创建多个axes,其参数含义如下:

    plot.subplot(numrows, numcols, plot_number):

  • subplot(211): 等同于subplot(2, 1, 1), 表示创建2行1列,共2个axes, 在其中的第一个axes进行绘制
  • import matplotlib.pyplot as plt
    import numpy as np
    def func(t):
        return np.exp(-t)*np.cos(2*np.pi*t)
    t1 = np.arange(0.0, 5.0, 1.0)
    t2 = np.arange(0.0, 5.0, 0.02)
    plt.figure()
    plt.subplot(211)   # 相当于plt.subplot(2, 1, 1), 第一个axes
    plt.plot(t1, func(t1), 'bo', t2, func(t2), 'k')
    plt.subplot(212)  # 相当于plt.subplot(2, 1, 2),第二个axes
    plt.plot(t2, np.cos(2*np.pi*t2), 'r--')
    plt.show()
    

    下面示例代码中,绘制了两个figure, 每个figure中有自己的多个axes:

    import matplotlib.pyplot as plt
    import numpy as np
    plt.figure(1)                # the first figure
    plt.subplot(211)             # the first subplot in the first figure
    plt.plot([1, 2, 3])
    plt.subplot(212)             # the second subplot in the first figure
    plt.plot([4, 5, 6])
    plt.figure(2)                # a second figure
    plt.plot([4, 5, 6])          # creates a subplot() by default
    plt.figure(1)                # figure 1 current; subplot(212) still current
    plt.subplot(211)             # make subplot(211) in figure1 current
    plt.title('Easy as 1, 2, 3') # subplot 211 title
    plt.show()
    
    plt.clf() 清理当前figure中内容
    plt.cla() # 清理当前axes中内容
    plt.close(2) # 关闭figure2
    

    2.6 pyplot中text

    pyplot中可以通过如下函数添加文本:

  • plt.xlabel('Smarts'): x坐标轴名称
  • plt.ylabel('Probability'): y坐标轴名称
  • plt.title('Histogram of IQ'): axes标题
  • plt.text(60, .025, r'$\mu=100,\ \sigma=15$'): 在坐标(60, 0.025)处添加文本 ($$ 表示markdown文本格式)
  • 上面的四个text中都支持markdown格式,来设置数学表达式,如:

    plt.title(r'$\sigma_i=15$')
    
    import matplotlib.pyplot as plt
    import numpy as np
    plt.figure()
    mu, sigma = 100, 15
    x = mu + sigma*np.random.randn(10000)
    n, bins, patches = plt.hist(x, 50, density=1, facecolor='g', alpha=0.75)
    plt.xlabel('Smart')
    plt.ylabel('Probability')
    plt.title('Histogram of IQ')
    plt.text(60,  0.025, r'$\mu=100,\ \sigma=15$')
    plt.axis([40, 160, 0, 0.03])   # [xmin, xmax, ymin, ymax]
    plt.grid(True)  # 设置网格
    plt.show()
    

    和之前介绍的line属性设置一样,也可以通过关键字,或者pyplot.setp()来设置文本的属性,如下:

    t1 = plt.xlabel('my data', fontsize=14, color='red')
    # plt.setp(t1)  # 打印t1可以设置的属性
    t2 = plt.text(60,  0.025, r'$\mu=100,\ \sigma=15$', color='green')
    # plt.setp(t2)   # 打印t2可以设置的属性
    

    plt.annotate(text='local max', xy=(2, 1), xytext=(3, 1.5),):可以设置文本及其指向的位置

  • text: 文本字符串
  • xy: 文本指向的坐标
  • xytext:文本的坐标位置
  • annotate详细文档:https://matplotlib.org/3.5.0/tutorials/text/annotations.html#annotations-tutorial

    plt.figure()
    ax =plt.subplot()
    t = np.arange(0, 5.0, 0.01)
    s = np.cos(2*np.pi*t)
    line, = plt.plot(t, s, lw=2)
    plt.annotate('local max', xy=(2, 1), xytext=(3, 1.5),
                 arrowprops=dict(facecolor='black', shrink=0.05),
    plt.ylim(-2, 2)
    plt.show()
    

    2.7 Logarithmic scale

    pyplot的坐标轴除了支持linear scale,还支持log scale, symmetric log scale, logit scale:

    关于scale文档:https://matplotlib.org/3.5.0/api/scale_api.html#module-matplotlib.scale

    log scale: 即matplotlib.scale.LogScale, log坐标轴,注意其只绘制正数,会忽略负数

    symmentric log scale: 即matplotlib.scale.SymmetricalLogScale, 对称log坐标轴, 支持正负数

    logit scale: 即matplotlib.scale.LogitScale, logit坐标轴,[0,1]范围内,会将log数据映射[0, 1]范围内。

    ( logit = 1/(1+log(-x)) )

    示例代码如下:

    np.random.seed(19680801)
    y = np.random.normal(loc=0.5, scale=0.4, size=1000)  # shape(1000,)
    y = y[(y > 0) & (y < 1)]  # shape(799, ), 值在[0, 1]范围内
    y.sort()
    x = np.arange(len(y))
    plt.figure()
    # 线性坐标轴
    plt.subplot(221)
    plt.plot(x, y)
    plt.yscale('linear')
    plt.title('linear')
    plt.grid(True)
    # y轴为log坐标轴
    plt.subplot(222)
    plt.plot(x, y)
    plt.yscale('log')
    plt.title('log')
    plt.grid(True)
    # y轴为symlog坐标轴
    plt.subplot(223)
    plt.plot(x,  y - y.mean())  # y可以为正数和负数,均值变成0
    plt.yscale('symlog', linthresh=0.01)  # x趋近于0时,log(x)会趋近于负无穷,linthresh=0.01设置(-0.01, 0.01)范围内为线性值
    plt.title('symlog')
    plt.grid(True)
    # y轴为logit坐标轴
    plt.subplot(224)
    plt.plot(x, y)
    plt.yscale('logit')
    plt.title('logit')
    plt.grid(True)
    # 调整subplot之间的格式
    plt.subplots_adjust(top=0.92, bottom=0.08, left=0.10, right=0.95, hspace=0.25,
                        wspace=0.35)
    plt.show()
    

    2. 8 综合案例

    下面时工作中我绘制的一段代码截图,数据采用随机数进行了代替,包括了坐标轴的标题,图例显示,坐标轴刻度范围等。

    import numpy as np
    import matplotlib.pyplot as plt
    plt.figure(figsize=(12, 6))  # 设置fig的大小,width=12, height=6
    # quaternion linear曲线插值
    quater_linear_joint_y = np.random.randn(19)+0.2   # shape(19, )
    quater_linear_sequence_y = np.random.randn(481)+0.2  # shape(481, )
    # quaternion bezier曲线插值
    quater_bezier_joint_y = np.random.randn(19)    # shape(19, )
    quater_bezier_sequence_y = np.random.randn(481)   # shape(481, )
    quater_linear_joint_x = np.arange(quater_linear_joint_y.shape[0]) + 1    # shape(19, )
    quater_bezier_joint_x = np.arange(quater_bezier_joint_y.shape[0]) + 1     # shape(19, )
    quater_linear_sequence_x = np.arange(quater_linear_sequence_y.shape[0]) + 1    # shape(481, )
    quater_bezier_sequence_x = np.arange(quater_bezier_sequence_y.shape[0]) + 1   # shape(481, )
    # 线性插值
    pos3d_linear_joint_y = np.random.randn(19)+0.1   # shape(19, )
    pos3d_linear_sequence_y = np.random.randn(481)+0.1  # shape(481, )
    # bezier曲线插值
    pos3d_bezier_joint_y = np.random.randn(19)    # shape(19, )
    pos3d_bezier_sequence_y = np.random.randn(481)   # shape(481, )
    pos3d_linear_joint_x = np.arange(pos3d_linear_joint_y.shape[0]) + 1   # shape(19, )
    pos3d_bezier_joint_x = np.arange(pos3d_bezier_joint_y.shape[0]) + 1   # shape(19, )
    pos3d_linear_sequence_x = np.arange(pos3d_linear_sequence_y.shape[0]) + 1  # shape(481, )
    pos3d_bezier_sequence_x = np.arange(pos3d_bezier_sequence_y.shape[0]) + 1  # shape(481, )
    # 绘制joint
    plt.subplot(121)
    plt.title(f"difference in joint")  # 设置subplot标题
    plt.xlabel("joint number")    # 设置y轴标题
    plt.ylabel("l2 distance")   # 设置x轴标题
    ymin = min(np.min(quater_linear_joint_y), np.min(quater_bezier_joint_y),
               np.min(pos3d_linear_joint_y), np.min(pos3d_bezier_joint_y))
    ymax = max(np.max(quater_linear_joint_y), np.max(quater_bezier_joint_y),
               np.max(pos3d_linear_joint_y), np.max(pos3d_bezier_joint_y))
    plt.ylim((ymin, ymax))   # 设置y轴坐标轴刻度范围
    plt.xlim((1, 19))       # 设置x轴坐标轴刻度范围
    plt.plot(quater_linear_joint_x, quater_linear_joint_y, 'rs', label='quaternion linear')   #'rs'表示红色方框,label方便legend显示
    plt.plot(quater_bezier_joint_x, quater_bezier_joint_y, 'r^', label='quaternion bezier')   #'r^'表示红色三角形,label方便legend显示
    plt.plot(pos3d_linear_joint_x, pos3d_linear_joint_y, 'bs', label='pos3d linear')   #'bs'表示蓝方框,label方便legend显示
    plt.plot(pos3d_bezier_joint_x, pos3d_bezier_joint_y, 'b^', label='pos3d bezier')   #'b^'表示蓝色三角形,label方便legend显示
    plt.legend()  # 显示图例,根据plot时设置的label区分
    # 绘制sequence
    plt.subplot(122)
    plt.title(f"difference in frames")
    plt.xlabel("frames number")
    plt.ylabel("l2 distance")
    ymin = min(np.min(quater_linear_sequence_y), np.min(quater_bezier_sequence_y),
               np.min(pos3d_linear_sequence_y), np.min(pos3d_bezier_sequence_y))
    ymax = max(np.max(quater_linear_sequence_y), np.max(quater_bezier_sequence_y),
               np.max(pos3d_linear_sequence_y), np.max(pos3d_bezier_sequence_y))
    plt.ylim((ymin, ymax))   # 设置y轴坐标轴刻度范围
    plt.xlim((1, 481))
    plt.plot(quater_linear_sequence_x, quater_linear_sequence_y, 'rs', label='quaternion linear')
    plt.plot(quater_bezier_sequence_x, quater_bezier_sequence_y, 'r^', label='quaternion bezier')
    plt.plot(pos3d_linear_sequence_x, pos3d_linear_sequence_y, 'bs', label='pos3d linear')
    plt.plot(pos3d_bezier_sequence_x, pos3d_bezier_sequence_y, 'b^', label='pos3d bezier')
    plt.legend()
    # plt.savefig(f'./plot.png')  # 保存绘制图片
    plt.show()
    

    3. 图片显示

    matplotlib.image可以用来读取图片为numpy数据(底层依赖Pillow),其读取的图片numpy格式为RGB, 通过plt.imshow能显示图片numpy数据。

    注意:matplolib读取图片后,会将其缩放到[0, 1]范围内,并且转换为float32格式, 对应RGB图片matplotlib支持float32和uint8类型数据, 但对于灰度图,matplotlib只支持float32格式

    (若采用opencv读取,需要将BGR其转换为RGB,并将其转换为float32类型,缩放到[0,1]范围)

    示例代码如下:

    import matplotlib.pyplot as plt
    import matplotlib.image as mping
    img_path = r"./stinkbug.png"
    img = mping.imread(img_path)
    print(img.shape, img.dtype)
    imgplot = plt.imshow(img)
    plt.show()
    

    3.1 伪彩色机制

    对于灰度图图片,matplotlib可以通过伪彩色机制显示,能够分辨出图片中亮暗区域。如下代码所示:

    import matplotlib.pyplot as plt
    import matplotlib.image as mping
    img_path = r"./stinkbug.png"
    img = mping.imread(img_path)
    print(img.shape, img.dtype)
    # example 2
    gray_img = img[:, :, 0]
    plt.imshow(gray_img)   # 默认采用viridis伪彩色展示灰度图
    plt.colorbar()  # 显示颜色条
    plt.show()
    

    默认采用viridis伪彩色展示灰度图,可以设置伪彩色的机制, 如下所示:

    import matplotlib.pyplot as plt
    import matplotlib.image as mping
    img_path = r"./stinkbug.png"
    img = mping.imread(img_path)
    print(img.shape, img.dtype)
    gray_img = img[:, :, 0]
    plt.imshow(gray_img, cmap='hot')   # 采用hot伪彩色展示灰度图
    # plt.imshow(gray_img, cmap='gray')   # 展示原始灰度图
    # imgplot = plt.imshow(gray_img)
    # imgplot.set_cmap('nipy_spectral') # 采用nipy_spectral伪彩色展示灰度图
    plt.colorbar()  # 显示颜色条
    

    3.2 直方图统计和截取

    plt.hist能绘制直方图,统计每个像素出现的次数,plt.imshow函数中的clim参数能设置显示的像素区间, 如下面代码:

    plt.hist(lum_img.ravel(), bins=256, range=(0.0, 1.0), fc='k', ec='k')  # 绘制单通道灰度图直方图,256个区间
    imgplot = plt.imshow(lum_img, clim=(0.0, 0.7))  # 图像像素范围在[0, 1]区间,只显[0, 0.7]区间像素
    

    下面代码中,截取了原图像素(0, 0.7)的区间。(根据统计直方图可知,超过0.7范围内的像素很少, 相当于增加对比度?)

    import matplotlib.pyplot as plt
    import matplotlib.image as mping
    img_path = r"./stinkbug.png"
    img = mping.imread(img_path)
    gray_img = img[:, :, 0]
    fig = plt.figure()
    ax = fig.add_subplot(1, 2, 1)
    imgplot = plt.imshow(gray_img)
    ax.set_title('Before')
    plt.colorbar(ticks=[0.1, 0.3, 0.5, 0.7], orientation='horizontal')
    ax2 = fig.add_subplot(1, 2, 2)
    imgplot2 = plt.imshow(gray_img, clim=(0.0, 0.7))
    ax2.set_title('After')
    plt.colorbar(ticks=[0.1, 0.3, 0.5, 0.7], orientation='horizontal')
    plt.show()
    

    3.3 插值

    plt.imshow显示图片时,若图片尺寸发生变化,可以设置插值的方式, 示例代码如下:

    import matplotlib.pyplot as plt
    from PIL import Image
    img_path = r"./stinkbug.png"
    img = Image.open(img_path)
    print(img.size)  # shape(500, 375)
    img.thumbnail((64, 64), Image.ANTIALIAS)   # 保持长宽比缩放: shape(64, 48)
    print(img.size)
    # plt.imshow(img)   # 默认采用bilinear 插值
    plt.imshow(img, interpolation='nearest')   # 采用nearest插值
    # plt.imshow(img, interpolation='bicubic')  # 采用bicubic插值
    plt.show()