当前位置 博文首页 > pytorch实现线性回归

    pytorch实现线性回归

    作者:逝去〃年华 时间:2021-04-29 17:58

    pytorch实现线性回归代码练习实例,供大家参考,具体内容如下

    欢迎大家指正,希望可以通过小的练习提升对于pytorch的掌握

    # 随机初始化一个二维数据集,使用朋友torch训练一个回归模型
    import numpy as np
    import random
    import matplotlib.pyplot as plt
    
    x = np.arange(20)
    y = np.array([5*x[i] + random.randint(1,20) for i in range(len(x))])    # random.randint(参数1,参数2)函数返回参数1和参数2之间的任意整数
    print('-'*50)
    # 打印数据集
    print(x)
    print(y)
    
    import torch
    x_train = torch.from_numpy(x).float()
    y_train = torch.from_numpy(y).float()
    
    # model
    class LinearRegression(torch.nn.Module):
        def __init__(self):
            super(LinearRegression, self).__init__()
            # 输入与输出都是一维的
            self.linear = torch.nn.Linear(1,1)
        def forward(self,x):
            return self.linear(x)
    
    # 新建模型,误差函数,优化器
    model = LinearRegression()
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(),0.001)
    # 开始训练
    num_epoch = 20
    for i in range(num_epoch):
        input_data = x_train.unsqueeze(1)
        target = y_train.unsqueeze(1)           # unsqueeze(1)在第二维增加一个维度
        out = model(input_data)
        loss = criterion(out,target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print("Eopch:[{}/{},loss:[{:.4f}]".format(i+1,num_epoch,loss.item()))
        if ((i+1)%2 == 0):
            predict = model(input_data)
            plt.plot(x_train.data.numpy(),predict.squeeze(1).data.numpy(),"r")
            loss = criterion(predict,target)
            plt.title("Loss:{:.4f}".format(loss.item()))
            plt.xlabel("X")
            plt.ylabel("Y")
            plt.scatter(x_train,y_train)
            plt.show()

    实验结果:

    js