郭震 AI公众号:郭震AI

3 LSTM原理解析

发布日期:

最近更新:

分类: 30个神经网络

预计阅读: 4 分钟

阅读次数: 0

系列进度

AI 30 个神经网络 · 第 3 / 62

预计阅读4 分钟
结构重点6 个
图文要点6 张
正文规模1.8k 字
LSTM原理解析结构图查看大图
LSTM原理解析结构图

LSTM 的重点不在名字,而在门控如何筛掉旧信息、写入新信息,再把当前状态交给下一步。读这类文章时,把每个时间步画出来,比只看公式更容易理解。这篇重点看结构。先把数据流、关键模块和输出层画清楚,再回头看公式或代码。

LSTM原理解析实操核对图查看大图
LSTM原理解析实操核对图

我会检查输入维度、序列长度、hidden size 和最后取哪一个时间步。四项说清楚,LSTM 代码才不容易跑偏。

在上一篇中,我们讨论了LSTM的应用场景,包括自然语言处理、序列预测和时间序列分析等。接下来,我们将深入解析LSTM的原理,为实际的代码实现做准备。

LSTM简介

LSTM(长短期记忆网络)是一种特殊的递归神经网络(RNN),它在处理和预测序列数据时克服了传统RNN的梯度消失和爆炸问题。LSTM通过引入一个新的结构单元,即“细胞状态”,能够有效地记住长期依赖信息。

LSTM的结构

LSTM的核心是一个特殊的单元,包括三个主要的门控机制:输入门、遗忘门和输出门。以下是这些门的描述:

  1. 遗忘门(Forget Gate):决定从细胞状态中丢弃多少信息。其计算公式为:

    ft=σ(Wf[ht1,xt]+bf)f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)

    其中,ht1h_{t-1}是上一时刻的隐藏状态,xtx_t是当前时刻的输入,WfW_fbfb_f分别是权重和偏置,σ\sigma是sigmoid激活函数。

  2. 输入门(Input Gate):决定多少新信息被存储在细胞状态中。其计算公式为:

    it=σ(Wi[ht1,xt]+bi)i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)

    生成新候选值(通过tanh激活):

    C~t=tanh(WC[ht1,xt]+bC)\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)
  3. 输出门(Output Gate):决定从细胞状态中输出多少信息作为当前时刻的隐藏状态。计算公式为:

ot=σ(Wo[ht1,xt]+bo)o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)

然后,当前的细胞状态CtC_t和输出hth_t的计算方式为:

Ct=ftCt1+itC~tC_t = f_t * C_{t-1} + i_t * \tilde{C}_t ht=ottanh(Ct)h_t = o_t * \tanh(C_t)

以上四个公式描述了LSTM的基本运作机制。细胞状态CtC_t被更新并决定了网络能够在多大程度上遗忘或记住信息。

LSTM的工作原理

在实际操作中,LSTM通过不断地接收输入并更新内部状态,从而在长序列中保持信息。具体地说,在时间步tt,LSTM根据之前的隐藏状态ht1h_{t-1}和当前输入xtx_t,计算出新的输出hth_t和更新后的细胞状态CtC_t

在自然语言处理的情境下,LSTM特别适合处理长文本,因为它能够有效捕捉到上下文的依赖性。例如,在句子生成任务中,LSTM会根据上下文信息生成连贯的文本。

案例:时间序列预测

为了更直观地理解LSTM的工作原理,我们考虑一个时间序列预测的案例,比如股价预测。假设我们要预测未来几天的股价,可以通过历史股价数据作为输入。

LSTM序列记忆判断卡查看大图
LSTM序列记忆判断卡

理解 LSTM 时,可以把输入、遗忘、输出和隐藏状态连成一条时间线。它的价值不在名称,而在于让序列里的重要信息保留下来。

在模型实现中,输入数据保持在时间序列的格式,LSTM就能够发现股价变化的趋势并做出有效预测。通过不断训练,LSTM可以捕捉到不同时间步之间的关系,从而提高预测的准确性。

伪代码展示

以下是一个伪代码,展示了如何用LSTM进行时间序列预测。

神经网络应用拆解卡查看大图
神经网络应用拆解卡

阅读《LSTM原理解析》前,可以先用配图确认主线;读完后再检查哪些步骤能直接操作,哪些还需要补资料。

# 假设我们的输入数据已经准备好
input_data = prepare_data(time_series)

# 创建LSTM模型
model = LSTM(units=50, return_sequences=True, input_shape=(timesteps, features))

# 编译模型
model.compile(optimizer='adam', loss='mean_squared_error')

# 训练模型
model.fit(input_data, target_data, epochs=50, batch_size=32)

# 进行预测
predicted_prices = model.predict(new_data)

在上述伪代码中,我们首先准备好时间序列数据,然后构建LSTM模型。通过指定单元数量和输入形状,我们搭建一个适合的LSTM网络。在训练模型后,我们可以使用新的数据进行股价预测。

LSTM原理解析应用复盘卡查看大图
LSTM原理解析应用复盘卡

如果《LSTM原理解析》还没完全消化,可以从这张卡片的四个动作重新走一遍。

LSTM原理解析应用检查卡查看大图
LSTM原理解析应用检查卡

回看《LSTM原理解析》时,不必一次做大项目,先用一条简单样例确认主线是否清楚。

总结

LSTM凭借其独特的结构和门控机制,成功解决了长序列数据中的长期依赖问题。通过理解LSTM的原理和内部结构,我们能够在各种时间序列任务和自然语言处理任务中有效应用LSTM。下一篇中,我们将继续深入,探索LSTM的代码实现。

相关教程

相关入口

AI 教程总索引

分享文章

转发到常用平台

微信/朋友圈可先复制链接

相关教程

AI 教程总索引

相关内容

相关 AI 教程

返回栏目

Reader Messages

读者留言

有问题、补充资料或实测结果,可以直接留下。这里不需要登录。

最多 800 字

为了防刷,每条留言会做长度、链接数量和提交频率限制。

0/800

留言列表

0
正在加载留言...