动手学深度学习8-softmax分类pytorch简洁实现-程序员宅基地

import torch
from torch import nn
from torch.nn import init
import numpy as np
import sys
sys.path.append('..')
import d2lzh_pytorch as d2l
import torchvision
import torchvision.transforms as transforms
定义和初始化模型
#与上一节同样的数据集以及批量大小
batch_size= 256
mnist_train= torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',download=True,train=True,transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',download=True,train=False,transform=transforms.ToTensor())

if sys.platform.startswith('win'):
    num_worker=0   # 表示不用额外的进程来加速读取数据
    
else:
    num_worker=4
train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=batch_size,shuffle=True,num_workers=num_worker)
test_iter = torch.utils.data.DataLoader(mnist_test,batch_size=batch_size,shuffle=False,num_workers=num_worker)

softmax的输出层是一个全连接层,所以我们使用一个线性模块就可以,因为前面我们数据返回的每个batch的样本X的形状为(batch_size,1,28,28),我们先用view()将X转化为(batch_size,784)才送入全连接层

num_inputs = 784
num_outputs = 10

class LinearNet(nn.Module):
    def __init__(self,num_inputs,num_outputs):
        super(LinearNet,self).__init__()
        self.linear = nn.Linear(num_inputs,num_outputs)
    def forward(self,x):
        y = self.linear(x.view(x.shape[0],-1))
        return y
net = LinearNet(num_inputs,num_outputs)

# 我们将形状转化的这个功能定义成一个FlattenLayer
class FlattenLayer(nn.Module):
    def __init__(self):
        super(FlattenLayer,self).__init__()
    def forward(self,x):
        return x.view(x.shape[0],-1)
from collections import OrderedDict
net = nn.Sequential(
    OrderedDict(
[
    ('flatten',FlattenLayer()),
    ('linear',nn.Linear(num_inputs,num_outputs))  
])
)
# 之前线性回归的是num_output是1
init.normal_(net.linear.weight,mean=0,std=0.01)
init.constant_(net.linear.bias,val=0)
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)
print(net)
Sequential(
  (flatten): FlattenLayer()
  (linear): Linear(in_features=784, out_features=10, bias=True)
)
softamx和交叉熵损失函数
#pytorch提供了一个包括softmax预算和交叉熵损失计算的函数
loss = nn.CrossEntropyLoss()
定义优化算法
optimizer = torch.optim.SGD(net.parameters(),lr=0.1)
def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
        n += y.shape[0]
    return acc_sum / n
训练模型
num_epochs, lr = 5, 0.1
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
              params=None, lr=None, optimizer=None):
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            y_hat = net(X)
            l = loss(y_hat, y).sum()

            # 梯度清零
            if optimizer is not None:
                optimizer.zero_grad()
            elif params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()

            l.backward()
            if optimizer is None:  
                # 上节的代码optimizer is None,使用的手写的代码SGD
                sgd(params, lr, batch_size)
            else:
                # optimizer 非None,
                optimizer.step()  # “softmax回归的简洁实现”一节将用到


            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))


train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None,optimizer)
epoch 1, loss 0.0031, train acc 0.749, test acc 0.765
epoch 2, loss 0.0022, train acc 0.813, test acc 0.808
epoch 3, loss 0.0021, train acc 0.826, test acc 0.818
epoch 4, loss 0.0020, train acc 0.832, test acc 0.816
epoch 5, loss 0.0019, train acc 0.837, test acc 0.821
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/onemorepoint/article/details/103623441

智能推荐

关于JSESSIONID path问题-程序员宅基地

文章浏览阅读207次。tomcat path_jsessionid path

【视频教程】EasyDSS无法启动显示3s超时问题排查_dss平台断电后开机起不来了-程序员宅基地

文章浏览阅读128次。EasyDSS无法启动显示3s超时_dss平台断电后开机起不来了

解决卷积神经网络测试集正确率低_训练集和测试集的准确率差多少合适-程序员宅基地

文章浏览阅读9.7k次,点赞11次,收藏36次。解决卷积神经网络测试集正确率低问题描述可能的原因最终的原因结论问题描述在运用卷积神经网络进行图像识别的过程中,将数据集分为训练集与验证集,首先进行模型的训练,最终模型在训练集与验证集上的正确率均达到0.999,将训练好的模型保存到本地,测试过程中直接读取本地模型对视频帧进行图像识别,正确率却远达不到0.999,即模型在训练集与验证集上的正确率很高,但在测试集上的正确率却很低。可能的原因测试集与训练集的图片尺寸不一致;测试集与训练集的数组维度不一致;数据未归一化,训练过程中图像为归一化数据,即 _训练集和测试集的准确率差多少合适

基于STM32的RT-Thread-Nano移植-程序员宅基地

文章浏览阅读2.5k次。一、RT-Thread简介RT-Thread诞生于2006年,是一款以开源、中立、社区化发展起来的物联网操作系统。 RT-Thread主要采用 C 语言编写,浅显易懂,且具有方便移植的特性(可快速移植到多种主流 MCU 及模组芯片上)。RT-Thread把面向对象的设计方法应用到实时系统设计中,使得代码风格优雅、架构清晰、系统模块化并且可裁剪性非常好。RT-Thread有完整版和Nano版,对于资源受限的微控制器(MCU)系统,可通过简单易用的工具,裁剪出仅需要 3KB Flash、1...

docker运行Springboot项目以及jar更新_docker更换jar包-程序员宅基地

文章浏览阅读1.3k次。1.挂接目录liunx系统下创建 /data/docker/newhellolearn/package目录mkdir -p /data/docker/newhellolearn/package2.上传jar包将jar包放入hellolearn.sh的同一个文件夹目录即/data/docker/hellolearn/scriptdeploy/package,创建容器时再将该目录(宿主机目录)与容器的/data目录进行映射3.创建hellolearn.sh文件在/data/docker/newhe_docker更换jar包

解决Android Handler的handleMessage()方法内TextView.setText偶尔不执行的问题_handlemessage settext onmeasure 耗時-程序员宅基地

文章浏览阅读2.1k次,点赞2次,收藏5次。前言最近项目中要加一个体温测量的外设模块 利用android的串口通信 可以完美的取到测量的体温数据获取到数据后,在用Handler发送数据到View渲染时 发现一个问题 就是数据能测量到 但是渲染到TextView时有时无!android串口通信CH340转USB android串口通信CH340转USB(可参考)handleMessage在使用Handler通信时,handleMessage()这个方法内渲染TextView数据 偶尔会丢失数据渲染这个问题出发,寻找解决方案:看到这样一个_handlemessage settext onmeasure 耗時

随便推点

leetcode 86 分割链表 经典-程序员宅基地

文章浏览阅读61次。struct ListNode* partition(struct ListNode* head, int x){ struct ListNode* sml_now=(struct ListNode*)malloc(sizeof(struct ListNode)); struct ListNode* big_now=(struct ListNode*)malloc(sizeof(struct ListNode)); struct ListNode* sml_b

LaTeX基础一:安装与基本操作-程序员宅基地

文章浏览阅读108次。一、安装 1.首先下载texlive2015.iso文件。再在解压的镜像文件中运行install-tl-advanced.bat批处理命令。注意要关闭杀毒软件,否则可能会出现错误。 2.可以修改一下安装路径,只要更改一个,其他也随之更改:3.把不要安装的语言包去掉,还有一个陈旧的编辑器,后面会安装一个第三方编辑器:4.此时就可以安装了,经过漫长...

php语句html前后位置,HTML渲染效果与CSS代码前后位置的关系-程序员宅基地

文章浏览阅读64次。CSS 中某些样式的位置会使得 HTML 的渲染产生不同的效果的,特别是位置前后的不同或者载入顺序的不同。本篇文章讨论的是 CSS 在 HTML 上下文的位置问题,并不讨论 CSS 有多少种写法。这里的 CSS 主要是指 CSS 都是放在 标签中的情况,一般的网站也很少直接写 style 属性的。最近碰到的一个小小的不同的网页渲染效果,是我以前没有太注意的。HTML 中 标签是有很多状态的,一..._php渲染html和css

Qt入门系列--弹出菜单_qt 弹出菜单-程序员宅基地

文章浏览阅读3.4k次,点赞4次,收藏16次。弹出菜单与菜单栏处理基本相似,涉及的类为QMenu和QAction,一个弹出菜单可以看作一个菜单栏的菜单项,本篇内容介绍如何在窗口创建弹出菜单,先上效果图:左侧图片为窗口区域单击右键显示效果,与菜单栏的菜单项显示效果一样,右侧图片为单击“测试2”的响应,弹窗中的“测试2”为菜单项的显示数据。通过UI设计器创建通过设计器创建的方法见下图经过上述5步操作后,主窗口中会生成“..._qt 弹出菜单

LeetCode 121.Best Time to Buy and Sell Stock-程序员宅基地

文章浏览阅读97次。1.题目2.题意给定一个数组,它的第 i 个元素是一支给定股票第 i 天的价格。如果你最多只允许完成一笔交易(即买入和卖出一支股票),设计一个算法来计算你所能获取的最大利润。注意你不能在买入股票前卖出股票。3.我的解法int maxProfit(int* prices, int pricesSize) { int i,j,max; int Max=0; for (i=0;i...

sql%notfound与exception_when sql%notfound-程序员宅基地

文章浏览阅读410次。create or replace procedure obj2_06(no number) as job emp1.job%type; unconsistent exception;begin select job into job from emp where empno=no; if sql%notfound then dbms_output.put_line('11'); end if; exception when no_data_when sql%notfound