LSTM原理及实现-程序员宅基地

技术标签: 机器学习  深度学习  

这篇博客前面原理部分是colah博客的翻译,后面一部分为自己结合实际实际代码的理解。

LSTM网络

long short term memory,即我们所称呼的LSTM,是为了解决长期以来问题而专门设计出来的,所有的RNN都具有一种重复神经网络模块的链式形式。在标准RNN中,这个重复的结构模块只有一个非常简单的结构,例如一个tanh层。

LSTM 同样是这样的结构,但是重复的模块拥有一个不同的结构。不同于单一神经网络层,这里是有四个,以一种非常特殊的方式进行交互。

不必担心这里的细节。我们会一步一步地剖析 LSTM 解析图。现在,我们先来熟悉一下图中使用的各种元素的图标。

在上面的图例中,每一条黑线传输着一整个向量,从一个节点的输出到其他节点的输入。粉色的圈代表 pointwise 的操作,诸如向量的和,而黄色的矩阵就是学习到的神经网络层。合在一起的线表示向量的连接,分开的线表示内容被复制,然后分发到不同的位置。

LSTM核心思想

LSTM的关键在于细胞的状态整个(绿色的图表示的是一个cell),和穿过细胞的那条水平线。

细胞状态类似于传送带。直接在整个链上运行,只有一些少量的线性交互。信息在上面流传保持不变会很容易。

若只有上面的那条水平线是没办法实现添加或者删除信息的。而是通过一种叫做 门(gates) 的结构来实现的。

门 可以实现选择性地让信息通过,主要是通过一个 sigmoid 的神经层 和一个逐点相乘的操作来实现的。

sigmoid 层输出(是一个向量)的每个元素都是一个在 0 和 1 之间的实数,表示让对应信息通过的权重(或者占比)。比如, 0 表示“不让任何信息通过”, 1 表示“让所有信息通过”。

LSTM通过三个这样的本结构来实现信息的保护和控制。这三个门分别输入门、遗忘门和输出门。

逐步理解LSTM

现在我们就开始通过三个门逐步的了解LSTM的原理

遗忘门

在我们 LSTM 中的第一步是决定我们会从细胞状态中丢弃什么信息。这个决定通过一个称为忘记门层完成。该门会读取 h t − 1 h t − 1 h t − 1 ht−1ht−1 h_{t-1} ht1ht1ht1ht=(1zt)ht1+zth^t

多层LSTM

**多层LSTM是将LSTM进行叠加,其优点是能够在高层更抽象的表达特征,并且减少神经元的个数,增加识别准确率并且降低训练时间。**具体信息参考[3]

LSTM实现手写数字

这里我们利用的数据集是tensorflow提供的一个手写数字数据集。该数据集是一个包含n张28*28的数据集。

设置LSTM参数

# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.contrib import rnn

import numpy as np
import input_data

configuration

O * W + b -> 10 labels for each image, O[? 28], W[28 10], B[10]

^ (O: output 28 vec from 28 vec input)

|

±+ ±+ ±-+

|1|->|2|-> … |28| time_step_size = 28

±+ ±+ ±-+

^ ^ … ^

| | |

img1:[28] [28] … [28]

img2:[28] [28] … [28]

img3:[28] [28] … [28]

img128 or img256 (batch_size or test_size 256)

each input size = input_vec_size=lstm_size=28

configuration variables

input_vec_size = lstm_size = 28 # 输入向量的维度
time_step_size = 28 # 循环层长度

batch_size = 128
test_size = 256

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

这里设置将batch_size设置为128,time_step_size表示的是lstm神经元的个数,这里设置为28个(和图片的尺寸有关?),input_vec_size表示一次输入的像素数。

初始化权值参数

def init_weights(shape):
    return tf.Variable(tf.random_normal(shape, stddev=0.01))

def model(X, W, B, lstm_size):
# X, input shape: (batch_size, time_step_size, input_vec_size)
# XT shape: (time_step_size, batch_size, input_vec_size)
#对这一步操作还不是太理解,为什么需要将第一行和第二行置换
XT = tf.transpose(X, [1, 0, 2]) # permute time_step_size and batch_size,[28, 128, 28]
# XR shape: (time_step_size * batch_size, input_vec_size)
XR = tf.reshape(XT, [-1, lstm_size]) # each row has input for each lstm cell (lstm_size=input_vec_size)

# Each array shape: (batch_size, input_vec_size)
X_split = tf.split(XR, time_step_size, 0) # split them to time_step_size (28 arrays),shape = [(128, 28),(128, 28)...]
# Make lstm with lstm_size (each input vector size). num_units=lstm_size; forget_bias=1.0
lstm = rnn.BasicLSTMCell(lstm_size, forget_bias=1.0, state_is_tuple=True)

# Get lstm cell output, time_step_size (28) arrays with lstm_size output: (batch_size, lstm_size)
# rnn..static_rnn()的输出对应于每一个timestep,如果只关心最后一步的输出,取outputs[-1]即可
outputs, _states = rnn.static_rnn(lstm, X_split, dtype=tf.float32)  # 时间序列上每个Cell的输出:[... shape=(128, 28)..]
# tanh activation
# Get the last output
return tf.matmul(outputs[-1], W) + B, lstm.state_size # State size to initialize the state
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

init_weigths函数利用正态分布随机生成参数的初始值,model的四个参数分别为:X为输入的数据,W表示的是28 * 10的权值(标签为0-9),B表示的是偏置,维度和W一样。这里首先将一批128*(28*28)的图片放进神经网络。然后进行相关的操作(注释已经写得很明白了,这里就不再赘述),然后利用WX+B求出预测结果,同时返回lstm的尺寸

训练

py_x, state_size = model(X, W, B, lstm_size)

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)

  • 1
  • 2
  • 3
  • 4

然后通过交叉熵计算误差,反复训练得到最优值。

源代码:
https://github.com/geroge-gao/deeplearning/tree/master/LSTM

参考资料

[1].https://www.jianshu.com/p/9dc9f41f0b29
[2].http://blog.csdn.net/Jerr__y/article/details/58598296
[3].Stacked Long Short-Term Memory Networks

        </div>
					<link href="https://csdnimg.cn/release/phoenix/mdeditor/markdown_views-258a4616f7.css" rel="stylesheet">
            </div>
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/u010947534/article/details/88788180

智能推荐

EFCore 报错 “Input string was not in a correct format”_完工入库历史记录失败:input string was not in a correct forma-程序员宅基地

文章浏览阅读2.6k次。今天遇到一个很奇葩的东西,项目从framework 迁移到 .netCore上后,执行下面的语句报错:“Input string was not in a correct format”context.Database.ExecuteSqlCommand(sqlStr)后来发现只要sql语句中含有 “{” 或 “}” 时就会出现这样的错误,不知道是否只是由于这样的原因造成的,先记录一下。..._完工入库历史记录失败:input string was not in a correct format. 完工入库

MySQL 中你应该使用什么数据类型表示时间?-程序员宅基地

文章浏览阅读1.1k次。当你需要保存日期时间数据时,一个问题来了:你应该使用 MySQL 中的什么类型?使用 MySQL 原生的 DATE 类型还是使用 INT 字段把日期和时间保存为一个纯数字呢?在这篇文章中,我将解释 MySQL 原生的方案,并给出一个最常用数据类型的对比表。我们也将对一些典型的查询做基准测试,然后得出在给定场景下应该使用什么数据类型的结论。如果你想直接看结论,请翻到文章最下方。原生的...

非阻塞I/O多路复用模型_非阻塞多路复用模型-程序员宅基地

文章浏览阅读2.2k次,点赞7次,收藏14次。目录1.传统阻塞I/O模型2.非阻塞I/O多路复用模型3.I/O多路复用模型的应用1.传统阻塞I/O模型传统的I/O阻塞模型中的“阻塞”的含义是指当前线程的阻塞性,下面用一图来说明这一阻塞I/O模型(读操作)。看图中易知:整个I/O请求过程中用户线程是阻塞的,所以cpu的利用率大打折扣。改进版本:上述过程中能否使得用户线程不是阻塞的呢?当然可以,见下图:..._非阻塞多路复用模型

深度揭秘android摄像头的autoFocus-----循环自动聚焦的实现(Android Camera AutoFo_camera autofocus参考-程序员宅基地

文章浏览阅读2k次。深度揭秘android摄像头的autoFocus-----循环自动聚焦的实现(Android Camera AutoFo_camera autofocus参考

namende和secondarynamenode可以正常启动,datanode无法启动解决方式_namenode和secondarynamenode启动成功,但是datanode启动失败-程序员宅基地

文章浏览阅读491次。今天在做实验的过程中,不知道怎么回事,HDFS就不能操作了。在万般无奈的情况下,将namenode格式化了,然而还是启动不起来,由于我的数据一点都不重要,故我尝试了将data目录和logs目录全部删除。问题得以解决。也许这是最差的一种情况,因为会导致所有的数据都丢失。datanode.log报错如下:2019-06-26 16:22:04,081 FATAL org.apache.hadoop..._namenode和secondarynamenode启动成功,但是datanode启动失败

可持续发展的程序员之路 _程序员环可持续发展-程序员宅基地

文章浏览阅读662次。 我觉得每个人的经历都是独特的,由于际遇,兴趣,志向的不同,对于技术方向的选择就会千差万别,谈个人的经历和选择对大部分人没有什么参考意义,甚至会有误导的嫌疑。不过我倒愿意与大家分享自己受挫的经历,或许更有启发。 入行1年不到的时候,我对公司的发展前景产生了悲观的情绪,老的项目一次次被delay,新的项目却遥遥无期,整天有些混日子的感觉。正好有次得到消息华为在金贸招_程序员环可持续发展

随便推点

java数据结构之单链表逆置算法_单链表的逆置java-程序员宅基地

文章浏览阅读4.7k次,点赞2次,收藏13次。单链表逆置算法1设计思想:在链表类中新加成员方法getNode(int i),用来获取指定位置的节点,新建一个空单链表,将原链表的每个节点按照从后往前的顺序依次取出,再把节点的数据依次添加到新的链表中。public Node getNode(int i)//获取指定位置的节点 { Node curr; curr=this.head.getNext......_单链表的逆置java

自考 计算机网络安全第5 章_依据入侵检行为的属性,入侵检测系统可分为-程序员宅基地

文章浏览阅读526次。第5 章 入侵检测技术一、 识记1、入侵检测的原理答:入侵检测是用于检测任何损害或者企图损害系统的保密性、完整性或可用性的一种网络安全技术。 它通过监视受保护系统的状态和活动,采用误用检测或异常检测的方式,发现非授权或恶意的系统及网络 行为,为防范入侵行为提供有效的手段。入侵检测系统,就是执行入侵检测任务的硬件或软件产品。 入侵检测提供了用于发现入侵攻击与合法用户滥用特权的一种方法,其应用前提是:入侵行为和合法 行为是可区分的,也即可以通过提取行为的模式特征来判断该行为的性质。入侵检测系统需要解决_依据入侵检行为的属性,入侵检测系统可分为

TensorFlow CPU版 安装环境 Ubuntu16.04 Python3.5-程序员宅基地

文章浏览阅读842次。tensorflow install quick startonly CPU ubutu16.04 python3.5 tensorflow1.51.sudo apt-get install python3-pip python3-dev2. pip up to v9.0.1sudo -H pip install --upgrade pip3.pip3 install tensorflow4.su...

SQL注入(二)-程序员宅基地

文章浏览阅读108次。接上一篇SQL注入(一)的内容说说防范SQL第5点 5.限制输入长度   如果在Web页面上使用文本框收集用户输入的数据,使用文本框的MaxLength属性来限制用户输入过长的字符也是一个很好的方法,因为用户的输入不够长,也就减少了贴入大量脚本的可能性。程序员可以针对需要收集的数据类型作出一个相应的限制策略。 6.URL重写技术 我们利用URL重写技术过滤一些SQ

Linux之grep命令详解_linux grep-程序员宅基地

文章浏览阅读2.6w次,点赞19次,收藏140次。注: 部分概念介绍来源于网络一、简介grep (global search regular expression(RE) and print out the line,全面搜索正则表达式并把行打印出来)是一种强大的文本搜索工具,它能使用正则表达式搜索文本,并把匹配的行打印出来。二、grep常用用法1、grep [-acinv] [--color=auto] '搜寻字符串' filename.txt选项与参数:-w :被匹配的文本只能是单词,而不能是单词中的某一部分,如文本中有liker,而我_linux grep

android开发6.0权限适配-程序员宅基地

文章浏览阅读542次。写博客只是为了方便记忆,希望自己能够坚持下去。在android开发中,如果项目的targetSdkVersion >= 23,在手机版本为android6.0+的手机上运行时,某些危险的权限需要用户授权,如果用户不同意,而直接运行某些代码,会造成程序的崩溃。 Runtime Permissions This release introduces a new permissions model,