当前位置 主页 > 网站技术 > 代码类 >

    pytorch三层全连接层实现手写字母识别方式

    栏目:代码类 时间:2020-01-14 18:05

    先用最简单的三层全连接神经网络,然后添加激活层查看实验结果,最后加上批标准化验证是否有效

    首先根据已有的模板定义网络结构SimpleNet,命名为net.py

    import torch
    from torch.autograd import Variable
    import numpy as np
    import matplotlib.pyplot as plt
    from torch import nn,optim
    from torch.utils.data import DataLoader
    from torchvision import datasets,transforms
    #定义三层全连接神经网络
    class simpleNet(nn.Module):
     def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):#输入维度,第一层的神经元个数、第二层的神经元个数,以及第三层的神经元个数
      super(simpleNet,self).__init__()
      self.layer1=nn.Linear(in_dim,n_hidden_1)
      self.layer2=nn.Linear(n_hidden_1,n_hidden_2)
      self.layer3=nn.Linear(n_hidden_2,out_dim)
     def forward(self,x):
      x=self.layer1(x)
      x=self.layer2(x)
      x=self.layer3(x)
      return x
     
     
    #添加激活函数
    class Activation_Net(nn.Module):
     def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):
      super(NeutalNetwork,self).__init__()
      self.layer1=nn.Sequential(#Sequential组合结构
      nn.Linear(in_dim,n_hidden_1),nn.ReLU(True))
      self.layer2=nn.Sequential(
      nn.Linear(n_hidden_1,n_hidden_2),nn.ReLU(True))
      self.layer3=nn.Sequential(
      nn.Linear(n_hidden_2,out_dim))
     def forward(self,x):
      x=self.layer1(x)
      x=self.layer2(x)
      x=self.layer3(x)
      return x
    #添加批标准化处理模块,皮标准化放在全连接的后面,非线性的前面
    class Batch_Net(nn.Module):
     def _init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):
      super(Batch_net,self).__init__()
      self.layer1=nn.Sequential(nn.Linear(in_dim,n_hidden_1),nn.BatchNormld(n_hidden_1),nn.ReLU(True))
      self.layer2=nn.Sequential(nn.Linear(n_hidden_1,n_hidden_2),nn.BatchNormld(n_hidden_2),nn.ReLU(True))
      self.layer3=nn.Sequential(nn.Linear(n_hidden_2,out_dim))
     def forword(self,x):
      x=self.layer1(x)
      x=self.layer2(x)
      x=self.layer3(x)
      return x
      
      
    

    训练网络,

    import torch
    from torch.autograd import Variable
    import numpy as np
    import matplotlib.pyplot as plt
    %matplotlib inline
    from torch import nn,optim
    from torch.utils.data import DataLoader
    from torchvision import datasets,transforms
    #定义一些超参数
    import net
    batch_size=64
    learning_rate=1e-2
    num_epoches=20
    #预处理
    data_tf=transforms.Compose(
    [transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])#将图像转化成tensor,然后继续标准化,就是减均值,除以方差
    
    #读取数据集
    train_dataset=datasets.MNIST(root='./data',train=True,transform=data_tf,download=True)
    test_dataset=datasets.MNIST(root='./data',train=False,transform=data_tf)
    #使用内置的函数导入数据集
    train_loader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
    test_loader=DataLoader(test_dataset,batch_size=batch_size,shuffle=False)
    
    #导入网络,定义损失函数和优化方法
    model=net.simpleNet(28*28,300,100,10)
    if torch.cuda.is_available():#是否使用cuda加速
     model=model.cuda()
    criterion=nn.CrossEntropyLoss()
    optimizer=optim.SGD(model.parameters(),lr=learning_rate)
    import net
    n_epochs=5
    for epoch in range(n_epochs):
     running_loss=0.0
     running_correct=0
     print("epoch {}/{}".format(epoch,n_epochs))
     print("-"*10)
     for data in train_loader:
      img,label=data
      img=img.view(img.size(0),-1)
      if torch.cuda.is_available():
       img=img.cuda()
       label=label.cuda()
      else:
       img=Variable(img)
       label=Variable(label)
      out=model(img)#得到前向传播的结果
      loss=criterion(out,label)#得到损失函数
      print_loss=loss.data.item()
      optimizer.zero_grad()#归0梯度
      loss.backward()#反向传播
      optimizer.step()#优化
      running_loss+=loss.item()
      epoch+=1
      if epoch%50==0:
       print('epoch:{},loss:{:.4f}'.format(epoch,loss.data.item()))
     
    
    
    
    
    
CallmeJust:从一个面试官的角度谈软件工程师的面试 小创:React 入门-redux 和 react-redux Mr-Tsing:Redis-第五章节-8种数据类型 leesf:Lakehouse: 统一数据仓库和高级分析的新一代开放平台 嵌入式与Linux那些事:程序员如何写一份合格的简历?(附简历模 zzssdd2:QT串口助手(三):数据接收 NeilZhang:对“微信十年产品思考”的思考 r1chard:一种获取context中keys和values的高效方法 | golang 熊泽-学习中的苦与乐:温习数据算法―贪吃蛇 小林coding:图解 ECDHE 密钥交换算法 c语言程序从哪里开始执行 c++清屏函数是什么 c++中不能重载的运算符有哪些 企业需谨防域名被抢注 互联网时代创业 价值共创时代 linux远程拷贝文件命令rcp,远程文件复制 linux远程拷贝文件夹命令, scp远程拷贝文件及文件夹写法 linux远程拷贝文件断点续传(linux限速和断点续传) linux远程拷贝文件到本地命令(复制远程主机上的文件到本地) linux远程拷贝文件到本地命令及用法 利用Python函数实现一个万历表完整示例 Python字符串对齐、删除字符串不需要的内容以及格式化打印字符 Java中ArrayList集合的常用方法大全 一个简单的Spring容器初始化流程详解 js简单粗暴的发布订阅示例代码 godoc命令不存在的解决方法 python获取当前时间 教你用php读写csv格式的文件 网站优化中怎么提高网站的转化率? 互联网的归宿和轻重结合的新市场