LogoLogo
  • README
  • 前端编程
    • 01 Node JS
    • 02-ES6详解
    • 03-NPM详解
    • 04-Babel详解
    • 05-前端模块化开发
    • 06-WebPack详解
    • 07-Vue详解
    • 08-Git详解
    • 09-微信小程序
  • 人工智能
    • 机器学习
      • 二次分配问题
      • 非负矩阵
      • 概率潜在语义分析
      • 概率图模型
      • 集成学习
      • 降维
      • 距离度量
      • 决策树
      • 逻辑回归
      • 马尔可夫决策过程
      • 马尔可夫链蒙特卡洛法
      • 朴素贝叶斯法
      • 谱聚类
      • 奇异值分解
      • 潜在狄利克雷分配
      • 潜在语义分析
      • 强化学习
      • 社区算法
      • 时间序列模型
      • 特征工程
      • 条件随机场
      • 图论基础
      • 线性分类
      • 线性回归
      • 信息论中的熵
      • 隐马尔科夫模型
      • 支持向量机
      • 主成分分析
      • EM算法
      • Hermite 矩阵的特征值不等式
      • k-means聚类
      • k近邻法
      • PageRank算法
    • 深度学习
      • Pytorch篇
        • 01-线性模型
        • 02-梯度下降法
        • 03-反向传播
        • 04-pytorch入门
        • 05-用pytorch实现线性回归
        • 06-logistic回归
        • 07-处理多维特征的输入
        • 08-加载数据集
        • 09-多分类问题
        • 10-卷积神经网络
        • 11-循环神经网络
    • 图神经网络
      • 图神经网络笔记01
        • 01-图(Graphs)的结构
        • 02-网络的性质和随机图模型
        • 03-网络工具
        • 04-网络中的主题和结构角色
        • 05-网络中的社区结构
      • 图神经网络笔记02
        • 01-深度学习引言
        • 02-神经网络基础
        • 03-卷积神经网络
        • 04-图信号处理与图卷积神经网络
        • 05-GNN的变体与框架-
        • [06-Google PPRGo 两分钟分类千万节点的最快GNN](人工智能/图神经网络/图神经网络笔记02/06-Google%20PPRGo 两分钟分类千万节点的最快GNN.md)
        • 07-序列模型
        • 08-变分自编码器
        • 09-对抗生成网络
  • 日常记录
    • 健身日记
    • 面经记录
    • 自动生成Summary文件
  • 实战项目
    • 谷粒商城
      • 00-项目概述
      • 01-分布式基础-全栈开发篇
      • 02-分布式高级-微服务架构篇
      • 03-高可用集群-架构师提升篇
  • 数据库
    • MySQL笔记
      • 01-MySQL基础篇
      • 02-MySQL架构篇
      • 03-MySQL索引及调优篇
      • 04-MySQL事务篇
      • 05-MySQL日志与备份篇
    • Redis笔记
      • 01-Redis基础篇
      • 02-Redis高级篇
    • 02-Redis篇
  • 算法笔记
    • 01-算法基础篇
    • 02-算法刷题篇
  • 职能扩展
    • 产品运营篇
  • Go编程
    • 01-Go基础
      • 01-Go基础篇
  • Java编程
    • 01-Java基础
      • 01-Java基础篇
      • 02-多线程篇
      • 03-注射与反解篇
      • 04-JUC并发编程篇
      • 05-JUC并发编程与源码分析
      • 06-JVM原理篇
      • 07-Netty原理篇
      • 08-设计模式篇
    • 02 Java Web
      • 01-Mybatis篇
      • 01-Mybatis篇(新版)
      • 02-Spring篇
      • 02-Spring篇(新版)
      • 03-SpringMVC篇
      • 04-MybatisPlus篇
    • 03-Java微服务
      • 01-SpringBoot篇
      • 01-SpringBoot篇(新版)
      • 02-SpringSecurity篇
      • 03-Shiro篇
      • 04-Swagger篇
      • 05-Zookeeper篇
      • 06-Dubbo篇
      • 07-SpringCloud篇
      • 08-SpringAlibaba篇
      • 09-SpringCloud篇(新版)
    • 04-Java中间件
      • 数据库篇
        • 01-分库分表概述
        • 02-MyCat篇
        • 03-MyCat2篇
        • 04-Sharding-jdbc篇
        • 05-ElasticSearch篇
      • 消息中间件篇
        • 01-MQ概述
        • 02-RabbitMQ篇
        • 03-Kafka篇
        • 04-RocketMQ篇
        • 05-Pulsar篇
    • 05-扩展篇
      • Dubbo篇
      • SpringBoot篇
      • SpringCloud篇
    • 06-第三方技术
      • 01-CDN技术篇
      • 02-POI技术篇
      • 03-第三方支付技术篇
      • 04-第三方登录技术篇
      • 05-第三方短信接入篇
      • 06-视频点播技术篇
      • 07-视频直播技术篇
    • 07-云原生
      • 01-Docker篇
      • 02-Kubernetes篇
      • 03-Kubesphere篇
  • Linux运维
    • 01-Linux篇
    • 02-Nginx篇
  • Python编程
    • 01-Python基础
      • 01.配置环境
      • 02.流程控制
      • 03.数值
      • 04.操作符
      • 05.列表
      • 06.元祖
      • 07.集合
      • 08.字典
      • 09.复制
      • 10.字符串
      • 11.函数
      • 12.常见内置函数
      • 13.变量
      • 14.异常和语法错误
      • 15.时间和日期
      • 16.正则表达式
    • 02 Python Web
      • flask篇
        • 01.前言
        • 02.路由
        • 03.模板
        • 04.视图进阶
        • 05.flask-sqlalchemy
        • 06.表单WTForms
        • 07.session与cookie
        • 08.上下文
        • 09.钩子函数
        • 10.flask 信号
        • 11.RESTFUL
        • 13.flask-mail
        • 14.flask+celery
        • 15.部署
        • 16.flask-login
        • 17.flask-cache
        • 18.flask-babel
        • 19.flask-dashed
        • 20.flask-pjax
        • 21.flask上传文件到第三方
        • 22.flask-restless
        • 23.flask-redis
        • 24.flask-flash
        • 25.消息通知
        • 26.分页
    • 03-Python数据分析
      • Matplotlib
      • Numpy
      • Pandas
      • Seaborn
    • 04-Python爬虫
      • 1.准备工作
      • 2.请求模块的使用
      • 3.解析模块的使用
      • 4.数据存储
      • 5.识别验证码
      • 6.爬取APP
      • 7.爬虫框架
      • 8.分布式爬虫
由 GitBook 提供支持
在本页
  • 用途
  • 循环神经网络模型
  • 递归神经网络
  • LSTM
  • 长时依赖问题
  • 模型
  • 核心思想
  • LSTM结构
  • 实现
  • 拓展阅读

这有帮助吗?

在GitHub上编辑
  1. 人工智能
  2. 图神经网络
  3. 图神经网络笔记02

07-序列模型

上一页[06-Google PPRGo 两分钟分类千万节点的最快GNN](人工智能/图神经网络/图神经网络笔记02/06-Google%20PPRGo 两分钟分类千万节点的最快GNN.md)下一页08-变分自编码器

最后更新于3年前

这有帮助吗?

用途

序列模型能够应用在许多领域,例如:

  • 语音识别

  • 音乐发生器

  • 情感分类

  • DNA序列分析

  • 机器翻译

  • 视频动作识别

  • 命名实体识别

比如语音识别 , 输入数据和输出数据都是序列数据 , X 是按时序播放的音频片段 , 输出 Y 是一系列单词 。

image-20210530145515352

比如音乐生成 , 只有输出时序列数据 , 输入数据可以是空集 , 也可以是单一整数 ( 指代音乐风格)。

这些序列模型基本都属于监督式学习,输入 x 和输出 y 不一定都是序列模型。如果都是序列模型的话,模型长度也可以不一致。

循环神经网络模型

递归神经网络

在传统神经网络中,模型不会关注上一时刻的处理会有什么信息可以用于下一时刻,每一次都只会关注当前时刻的处理。举个例子来说,我们想对一部影片中每一刻出现的事件进行分类,如果我们知道电影前面的事件信息,那么对当前时刻事件的分类就会非常容易。实际上,传统神经网络没有记忆功能,所以它对每一刻出现的事件进行分类时不会用到影片已经出现的信息,那么有什么方法可以让神经网络能够记住这些信息呢?答案就是Recurrent Neural Networks(RNNs)递归神经网络。

递归神经网络的结果与传统神经网络有一些不同,它带有一个指向自身的环,用来表示它可以传递当前时刻处理的信息给下一时刻使用,结构如下:

其中,$X_t$为输入,$A$为模型处理部分,$h_t$为输出。为了更容易地说明递归神经网络,把上图展开,得到:

这样的一条链状神经网络代表了一个递归神经网络,可以认为它是对相同神经网络的多重复制,每一时刻的神经网络会传递信息给下一时刻。如何理解它呢?假设有这样一个语言模型,我们要根据句子中已出现的词预测当前词是什么,递归神经网络的工作原理如下:

其中,$W$为各类权重,$x$表示输入,$y$表示输出,$h$表示隐层处理状态。递归神经网络因为具有一定的记忆功能,可以被用来解决很多问题,例如:语音识别、语言模型、机器翻译等。但是它并不能很好地处理长时依赖问题。

LSTM

长时依赖问题

长时依赖是这样的一个问题,当预测点与依赖的相关信息距离比较远的时候,就难以学到该相关信息。例如在句子”我出生在法国,……,我会说法语“中,若要预测末尾”法语“,我们需要用到上下文”法国“。理论上,递归神经网络是可以处理这样的问题的,但是实际上,常规的递归神经网络并不能很好地解决长时依赖,好的是LSTMs可以很好地解决这个问题。

模型

Long Short Term Mermory network(LSTM)是一种特殊的RNNs,可以很好地解决长时依赖问题。那么它与常规神经网络有什么不同?

首先我们来看RNNs具体一点的结构:

所有的递归神经网络都是由重复神经网络模块构成的一条链,可以看到它的处理层非常简单,通常是一个单tanh层,通过当前输入及上一时刻的输出来得到当前输出。与神经网络相比,经过简单地改造,它已经可以利用上一时刻学习到的信息进行当前时刻的学习了。LSTM的结构与上面相似,不同的是它的重复模块会比较复杂一点,它有四层结构:

其中,处理层出现的符号及表示意思如下:

核心思想

理解LSTMs的关键就是下面的矩形方框,被称为memory block(记忆块),主要包含了三个门(forget gate、input gate、output gate)与一个记忆单元(cell)。方框内上方的那条水平线,被称为cell state(单元状态),它就像一个传送带,可以控制信息传递给下一时刻。

这个矩形方框还可以表示为:

这两个图可以对应起来看,下图中心的$c_t$即cell,从下方输入($h_{t−1}$,$x_t$)到输出$h_t$的一条线即为$cell\ $$state$,$f_t$,$i_t$,$o_t$分别为遗忘门、输入门、输出门,用sigmoid层表示。上图中的两个$\tanh$层则分别对应cell的输入与输出。LSTM可以通过门控单元可以对cell添加和删除信息。通过门可以有选择地决定信息是否通过,它有一个sigmoid神经网络层和一个成对乘法操作组成,如下:

LSTM结构

  • 第一种是带遗忘门的Traditional LSTM

  • 带遗忘门的Peephole LSTM

实现

# 导入相应包
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F

# 相关配置
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,),(0.3081,))
])

# 导入数据
train_dataset = datasets.MNIST(
    root='./data/mnist',
    train=True,
    download=True,
    transform=transform
)
test_dataset = datasets.MNIST(
    root='./data/mnist',
    train=False,
    download=True,
    transform=transform
)

# 打乱顺序
train_loader = DataLoader(train_dataset,shuffle=True,batch_size=batch_size)
test_loader = DataLoader(test_dataset,shuffle=False,batch_size=batch_size)

class LstmModel(torch.nn.Module):
    def __init__(self,input_size,output_size,hidden_size,n_layer=2):
        super(LstmModel,self).__init__()
        self.n_layer = n_layer
        self.hidden_size = hidden_size
        self.lstm = torch.nn.LSTM(input_size,hidden_size,n_layer,batch_first=True)
        self.fc = torch.nn.Linear(hidden_size,output_size)

    def forward(self,x):
#         in_size = x.size(0)
#         x = x.view(in_size,-1)
        out,(hn,cn) = self.lstm(x)
        x = hn[-1,:,:]
        x = self.fc(x)
        return x

model = LstmModel(28,10,10)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.5)

def train(epoch):
    running_loss = 0.0
    for batch_idx,data in enumerate(train_loader,0):
        inputs,target = data
        optimizer.zero_grad()
        outputs = model(torch.squeeze(inputs,1))
        loss = criterion(outputs,target)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        if batch_idx % 300 == 299:
            print('[{:d},{:5d}] loss:{:.3f}'.format(epoch+1,batch_idx+1,
                                                    running_loss/300))
            running_loss = 0.0
            
            
def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            inputs,target = data
            outputs = model(torch.squeeze(inputs,1))
            _,predicted = torch.max(outputs.data,dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            
    print("Accuracy on test set:{:.2%}".format(correct/total))

if __name__ == '__main__':
    for epoch in range(20):
        train(epoch)
        test()

拓展阅读

image-20210530164450997
Ht=σ(XtWxh+Ht−1Whh+bh)Ot=HtWhq+bq\begin{align} &H_t = \sigma(X_tW_{xh} + H_{t-1} W_{hh}+b_h)\\ &O_t = H_tW_{hq} + b_q \end{align}​Ht​=σ(Xt​Wxh​+Ht−1​Whh​+bh​)Ot​=Ht​Whq​+bq​​​

img
img
img
img
img
preview
preview
img
ft=σ(Wfxt+Ufht−1+bf)it=σ(Wixt+Uiht−1+bi)ot=σ(Woxt+Uoht−1+bo)ct=ft∘ct−1+it∘σc(Wcxt+Ucht−1+bc)ht=ot∘σh(ct)\begin{align} &f_t = \sigma(W_fx_t+U_fh_{t-1}+b_f) \\ &i_t = \sigma(W_ix_t+U_ih_{t-1}+b_i) \\ &o_t = \sigma(W_ox_t+U_oh_{t-1}+b_o) \\ &c_t = f_t \circ c_{t-1} + i_t \circ \sigma_c(W_cx_t+U_ch_{t-1}+b_c) \\ &h_t = o_t\circ \sigma_h(c_t) \end{align}​ft​=σ(Wf​xt​+Uf​ht−1​+bf​)it​=σ(Wi​xt​+Ui​ht−1​+bi​)ot​=σ(Wo​xt​+Uo​ht−1​+bo​)ct​=ft​∘ct−1​+it​∘σc​(Wc​xt​+Uc​ht−1​+bc​)ht​=ot​∘σh​(ct​)​​
ft=σ(Wfxt+Ufct−1+bf)it=σ(Wixt+Uict−1+bi)ot=σ(Woxt+Uoct−1+bo)ct=ft∘ct−1+it∘σc(Wcxt+bc)ht=ot∘σh(ct)\begin{align} &f_t = \sigma(W_fx_t+U_fc_{t-1}+b_f) \\ &i_t = \sigma(W_ix_t+U_ic_{t-1}+b_i) \\ &o_t = \sigma(W_ox_t+U_oc_{t-1}+b_o) \\ &c_t = f_t \circ c_{t-1} + i_t \circ \sigma_c(W_cx_t+b_c) \\ &h_t = o_t\circ \sigma_h(c_t) \end{align}​ft​=σ(Wf​xt​+Uf​ct−1​+bf​)it​=σ(Wi​xt​+Ui​ct−1​+bi​)ot​=σ(Wo​xt​+Uo​ct−1​+bo​)ct​=ft​∘ct−1​+it​∘σc​(Wc​xt​+bc​)ht​=ot​∘σh​(ct​)​​

深度学习(五) - 序列模型
深度学习之GRU网络
注意力模型/Encoder与Decoder详解
这里写图片描述