24.9.20学习笔记

news/2024/9/21 16:29:00 标签: 学习, 笔记, lstm

隐藏状态和细胞状态是循环神经网络(RNN)及其变种,如长短期记忆网络(LSTM)中的概念,它们在处理序列数据时扮演着重要的角色。尽管它们都与网络的记忆能力相关,但它们之间存在一些关键的区别:

  1. 隐藏状态(Hidden State)

    • 在标准的RNN中,隐藏状态是网络在每个时间步的输出,它包含了该时间步的信息以及之前所有时间步的信息的累积。
    • 隐藏状态的维度通常与RNN层的维度相同。
    • 隐藏状态在时间步之间传递,用于影响后续时间步的输出。
    • 由于标准的RNN结构简单,它很难捕捉长期依赖关系,因为随着时间的推移,信息可能会逐渐丢失。
  2. 细胞状态(Cell State)

    • 细胞状态是LSTM特有的,它是一种更持久的记忆,可以跨越多个时间步传递信息。
    • 细胞状态在LSTM内部流动,通过一系列的门控机制(遗忘门、输入门、输出门)来控制信息的流动。
    • 细胞状态的维度通常与LSTM层的维度相同,但它可以更有效地保持长期的信息,因为它不会直接受到梯度消失问题的影响。
    • 细胞状态是LSTM能够学习长期依赖关系的关键,它允许网络记住或忘记信息,而不是简单地将所有信息累积在隐藏状态中。

简而言之,隐藏状态是RNN在每个时间步的“工作记忆”,而细胞状态是LSTM的“长期记忆”。在LSTM中,隐藏状态通常用于输出和传递到下一个时间步,而细胞状态则用于在时间步之间保持和传递重要的长期信息。这种区分使得LSTM在处理长序列数据时比标准RNN更有效。


 循环神经网络(Recurrent Neural Network, RNN)是一种用于处理序列数据的神经网络。与传统的前馈神经网络不同,RNN具有内部记忆功能,能够捕捉到输入数据中的时间依赖关系。这种特性使得RNN在处理诸如自然语言处理、语音识别、时间序列预测等任务时非常有效。

import torch
import torch.nn as nn
import torch.optim as optim

# 定义 RNN 模型
class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleRNN, self).__init__()
        self.hidden_size = hidden_size  # 隐藏层的维度
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)  # RNN 层,输入维度为 input_size,隐藏层维度为 hidden_size,batch_first=True 表示输入的第一个维度是 batch_size
        self.fc = nn.Linear(hidden_size, output_size)  # 全连接层,将隐藏层的输出转换为输出维度

    def forward(self, x, hidden):
        # x: (batch_size, seq_length, input_size) - 输入数据
        # hidden: (num_layers * num_directions, batch_size, hidden_size) - 初始隐藏状态
        out, hidden = self.rnn(x, hidden)  # 前向传播通过 RNN 层
        # out: (batch_size, seq_length, hidden_size) - RNN 层的输出
        # hidden: (num_layers * num_directions, batch_size, hidden_size) - 最终的隐藏状态
        
        # 我们只取最后一个时间步的输出
        out = out[:, -1, :]  # 取出每个样本在最后一个时间步的输出
        out = self.fc(out)  # 通过全连接层将 RNN 的输出转换为最终的输出
        return out, hidden  # 返回最终的输出和隐藏状态

    def init_hidden(self, batch_size):
        # 初始化隐藏状态
        return torch.zeros(1, batch_size, self.hidden_size)  # 初始化一个全零的隐藏状态,形状为 (1, batch_size, hidden_size)

# 参数设置
input_size = 10  # 输入特征的维度
hidden_size = 20  # 隐藏层的维度
output_size = 5  # 输出的维度
batch_size = 32  # 批量大小
seq_length = 50  # 序列长度
num_epochs = 100  # 训练轮数

# 创建模型实例
model = SimpleRNN(input_size, hidden_size, output_size)  # 实例化 RNN 模型

# 随机生成输入数据和标签
inputs = torch.randn(batch_size, seq_length, input_size)  # 随机生成输入数据,形状为 (batch_size, seq_length, input_size)
labels = torch.randn(batch_size, output_size)  # 随机生成标签,形状为 (batch_size, output_size)

# 定义损失函数和优化器
criterion = nn.MSELoss()  # 均方误差损失函数
optimizer = optim.Adam(model.parameters(), lr=0.01)  # Adam 优化器,学习率为 0.01

# 训练循环
for epoch in range(num_epochs):
    # 初始化隐藏状态
    hidden = model.init_hidden(batch_size)  # 初始化隐藏状态
    
    # 前向传播
    outputs, hidden = model(inputs, hidden)  # 前向传播,得到模型的输出和最终的隐藏状态
    
    # 计算损失
    loss = criterion(outputs, labels)  # 计算模型输出与标签之间的均方误差损失
    
    # 反向传播和优化
    optimizer.zero_grad()  # 清除梯度,防止梯度累积
    loss.backward()  # 反向传播,计算梯度
    optimizer.step()  # 更新模型参数
    
    # 打印损失
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')  # 每 10 个轮次打印一次损失值


http://www.niftyadmin.cn/n/5669061.html

相关文章

Ubuntu 22.04 源码下载、编译

Kernel/BuildYourOwnKernel - Ubuntu Wikihttps://wiki.ubuntu.com/Kernel/BuildYourOwnKernel 一、查询当前系统内核版本 rootubuntu22:~# uname -r 5.15.0-118-generic 二、查询本地软件包数据库中的内核源码信息 rootubuntu22:~# apt search linux-source Sorting... Do…

2010-2020年全国30个省以GDP为核心的区域经济韧性数据(含原始数据+代码+结果)

2010-2020年全国30个省以GDP为核心的区域经济韧性数据(含原始数据代码结果) 1、时间:2010-2022年 2、来源:统计年鉴、各省年鉴、国家统计局 3、指标:地区生产总值 4、范围:30省 5、参考文献: 数字经济及其内部耦…

错题集锦之C语言

直接寻址和立即寻址 算法的又穷性是指算法程序的运行时间是有限的 未经赋值的全局变量值不确定 集成测试是为了发现概要设计的错误 自然连接要求两个关系中进行比较的是相同的属性,并且进行等值连接,在结果中还要把重复的属性列去掉 赋值运算符 赋值…

【STM32 Blue Pill编程实例】-手机通过HC-05串口蓝牙控制LED

手机通过HC-05串口蓝牙控制LED 文章目录 手机通过HC-05串口蓝牙控制LED1、HC-05串口蓝牙模块介绍2、硬件准备和接线3、模块配置4、代码实现5、手机控制在本文中,我们介绍如何使用 STM32CubeIDE 和 HAL 库将 HC-05 蓝牙模块与 STM32 Blue Pill 开发板连接。 我们将使用 Android…

论文阅读 - MDFEND: Multi-domain Fake News Detection

https://arxiv.org/pdf/2201.00987 目录 ABSTRACT INTRODUCTION 2 RELATED WORK 3 WEIBO21: A NEW DATASET FOR MFND 3.1 Data Collection 3.2 Domain Annotation 4 MDFEND: MULTI-DOMAIN FAKE NEWS DETECTION MODEL 4.1 Representation Extraction 4.2 Domain Gate 4.…

数字电子技术-编码器

目录 编码器概述 二进制编码器 二-十进制编码器 优先编码器(Priority Encoder) 8线-3线优先编码器74LS148 优先编码器74LS148功能表 编码器概述 编码:用文字、符号或数字表示特定对象的过程。在数字电路中,采用二进制进行编码。 编码器(Encoder)…

STM32如何修改外部晶振频率和主频

对于STM32F10x系列的单片机,除了STM32F10x_CL单片机,其它的单片机一般外部晶振HSE的时钟频率都默认是8MHz。如果我们使用的外部晶振为12Mhz,那么可以把上图绿色标记改为:12000000 72MHz的主频8MHz的外部晶振HSE*倍频系数9。当然如果像上面把外…

C++ prime plus-2-编程练习

复习题&#xff1a; 1.它们叫作函数。 2.这将导致在最终的编译之前&#xff0c;使用iostream 文件的内容替换该编译指令。 3.它使得程序可以使用 std 名称空间中的定义。 4.cout << "Hello&#xff0c;world\n"; 或cout<<"Hello&#xff0c;wor…