当前位置 主页 > 服务器问题 > Linux/apache问题 >

    PyTorch中permute的用法详解

    栏目:Linux/apache问题 时间:2020-01-06 16:09

    permute(dims)

    将tensor的维度换位。

    参数:参数是一系列的整数,代表原来张量的维度。比如三维就有0,1,2这些dimension。

    例:

    import torch
    import numpy as np
    a=np.array([[[1,2,3],[4,5,6]]])
    unpermuted=torch.tensor(a)
    print(unpermuted.size()) # ——> torch.Size([1, 2, 3])
    permuted=unpermuted.permute(2,0,1)
    print(permuted.size()) # ——> torch.Size([3, 1, 2])

    再比如图片img的size比如是(28,28,3)就可以利用img.permute(2,0,1)得到一个size为(3,28,28)的tensor。

    利用这个函数permute(1,3,2)可以把Tensor([[[1,2,3],[4,5,6]]]) 转换成

    tensor([[[1., 4.],
    [2., 5.],
    [3., 6.]]])

    如果使用view(1,3,2),可以得到

    tensor([[[1., 2.],
    [3., 4.],
    [5., 6.]]])
    

    以上这篇PyTorch中permute的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持IIS7站长之家。