当前位置 博文首页 > shelgi的博客:tensorflow2.0的一些高级函数用法

    shelgi的博客:tensorflow2.0的一些高级函数用法

    作者:[db:作者] 时间:2021-07-28 08:49

    最近在学习tensorflow2.0的时候看到一些特别好用的高级函数,这里来记录一下它们的用法

    1.tf.gather()

    tf.gather(params,indices,validate_indices=None,name=None,axis=0)
    简单的理解一下,首先传入一个需要处理的张量,然后传入对他的选择操作,也就是一个索引张量。

    下面举个例子:

    考虑班级成绩册的例子,共有 4 个班级,每个班级 35 个学生,8 门科目,保存成绩册的张量 shape 为[4,35,8]。

    #创建成绩册
    record=tf.random.uniform([4,35,8],maxval=100)
    record.numpy
    

    在这里插入图片描述
    如果现在需要收集第 1,2 两个班级的成绩册,我们可以通过切片操作

    record1_2=record[0:2]
    record1_2.numpy
    

    在这里插入图片描述
    也可以使用tf.gather()得到一样的结果

    #从第一个维度(班级)选择前两个班级
    record1_2=tf.gather(record,[0,1],axis=0)
    record1_2.numpy
    

    在这里插入图片描述
    但是换个要求,需要抽查所有班级的第 1,4,9,12,13,27 号同学的成绩,这时候用切片就不好得到结果了,用gather还是很容易的

    #从第二个维度(学生)抽取
    score=tf.gather(record,[0,3,8,11,12,26],axis=1)
    score.numpy
    

    在这里插入图片描述

    2.tf.gather_nd()

    通过 tf.gather_nd(),可以通过指定每次采样的坐标来实现采样多个点的目的
    例子:得到班级 1,学生 1 的科目 2;班级 2,学生 2 的科目 3;班级 3,学生 3 的科目 4 的成绩

    score=tf.gather_nd(record,[[0,0,1],[1,1,2],[2,2,3]])
    score.numpy
    

    在这里插入图片描述

    3.tf.scatter_nd()

    通过 tf.scatter_nd(indices, updates, shape)可以高效地刷新张量的部分数据,但是只能在全 0 张量的白板上面刷新,因此可能需要结合其他操作来实现现有张量的数据刷新功能。

    #需要刷新的位置
    indices = tf.constant([[4], [3], [1], [7]])
    # 构造需要写入的数据
    updates = tf.constant([4.4, 3.3, 1.1, 7.7]) 
    # 在长度为 8 的全 0 向量上根据 indices 写入 updates
    tf.scatter_nd(indices, updates, [8])
    

    在这里插入图片描述

    4.tf.meshgrid()

    通过 tf.meshgrid 可以方便地生成二维网格采样点坐标,或者可以理解成为了满足矩阵相乘,把x按行重复y的列次,y按列重复x的行次(广播机制)
    例子:实现 z = s i n ( x 2 + y 2 ) x 2 + y 2 z=\frac{sin(x^2+y^2)}{x^2+y^2} z=x2+y2sin(x2+y2)?

    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    plt.rcParams['axes.unicode_minus']=False
    x = tf.linspace(-8.,8,100) # 设置 x 坐标的间隔
    y = tf.linspace(-8.,8,100) # 设置 y 坐标的间隔
    x,y = tf.meshgrid(x,y) # 生成网格点,并拆分后返回
    print(x.shape,y.shape) # 打印拆分后的所有点的 x,y 坐标张量 shape
    
    z = tf.sqrt(x**2+y**2) 
    z = tf.sin(z)/z # sinc 函数实现
    
    fig = plt.figure()
    ax = Axes3D(fig)
    # 根据网格点绘制 sinc 函数 3D 曲面
    ax.contour3D(x.numpy(), y.numpy(), z.numpy(), 50)
    plt.show()
    

    在这里插入图片描述
    或者来个简单的例子更能体现它的变换

    x=tf.constant([1,2,3])
    y=tf.constant([3,4,5])
    x,y = tf.meshgrid(x,y) 
    print(x.numpy,y.numpy)
    

    在这里插入图片描述
    这样meshgrid的作用就一目了然了

    cs
    下一篇:没有了