机器翻译之创建Seq2Seq的编码器、解码器

news/2024/9/21 16:33:54 标签: python, 深度学习, pytorch, lstm, 人工智能, rnn, 算法

1.创建编码器、解码器的基类

1.1创建编码器的基类

python">from torch import nn


#构建编码器的基类
class Encoder(nn.Module):   #继承父类nn.Module
    def __init__(self, **kwargs):   #**kwargs:不定常的关键字参数
        super().__init__(**kwargs)
        
    def forward(self, X, *args):  #*args:不定常的位置参数
        #若继承了Encoder这个基类,就必须实现forward(),否则就会报下这个错
        raise  NotImplementedError          

1.2创建解码器的基类

python">#创建解码器的基类
#创建解码器的基类比创建编码器的基类多一个 state的初始化
class Decoder(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
    #初始化state
    def init_state(self, enc_outputs, *args):
        raise NotImplementedError
    
    #前向传播,解码器比编码器多传入一个state
    def forward(self, X, state):
        raise NotImplementedError

 1.3合并编码器和解码器的基类

python">class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, enc_X, dec_X, *args):
        """
        enc_X:编码器需传入的数据
        dec_X:解码器需传入的数据
        """
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)

 2.基于上述基类,正式创建Seq2Seq编码器与解码器的类

python">import collections
import math
import torch
import dltools

2.1创建Seq2Seq的编码器类 

python">class Seq2SeqEncoder(Encoder):  #继承父类Encoder
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        """
        vocab_size:词汇表大小
        embed_size:嵌入层大小
        num_hiddens:隐藏层的神经元数量
        num_layers:隐藏层的层数
        dropout=0 : 默认所有的神经元参与计算
        """
        #初始化嵌入层
        self.embedding = nn.Embedding(vocab_size, embed_size)
        #初始化神经网络层
        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)
        
    def forward(self, X, *args):
        #在进行embedding之前,X的shape=(batch_size, num_steps, vocab_size)
        X = self.embedding(X) 
        #X经过embedding处理,X的shape=(batch_size, num_steps, embed_size)
        X = X.permute(1, 0, 2)  
        #经过permute调换维度之后,X的shape=(num_steps, batch_size, embed_size)
        
        #此时, pytorch 会自动完成隐藏状态的初始化,即0, 不需要手动传入state
        outputs, state = self.rnn(X)
        #outputs的shape=(num_steps, batch_size, num_hiddens) ,最后一维是神经元的数量
        #state的shape=(num_layers, batch_size, num_hiddens)
        return outputs, state
python">#测试代码
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=32, num_layers=2)
encoder.eval()
# batch_size=4, num_steps=7
X = torch.zeros((4, 7), dtype=torch.long)
outputs, state = encoder(X)

print(outputs.shape, state.shape)
torch.Size([7, 4, 16]) torch.Size([2, 4, 16])

2.2 创建Seq2Seq的解码器类

python">class Seq2SeqDecoder(Decoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        #初始化嵌入层
        self.embedding = nn.Embedding(vocab_size, embed_size)
        #初始化神经网络层
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
        #初始化输出层
        self.dense = nn.Linear(num_hiddens, vocab_size)
        
    #定义函数:获取状态state
    def init_state(self, enc_outputs, *args):
        #编码器输出的结果有两个,第二个为state
        return enc_outputs[1]
    
    #前向传播
    def forward(self, X, state):
        #X的原始shape=(batch_size, num_steps, vocab_size)
        X = self.embedding(X)  #X的shape=(batch_size, num_steps, embed_size)
        X = X.permute(1, 0, 2)  #调整数据维度, X的shape=(num_steps, batch_size, embed_size)
       
        # 把X和state拼接到一起. 方便计算. 
        # X现在的形状(num_steps, batch_size, embed_size) , 
        # state的形状(batch_size, num_hiddens)
        # 要把state的形状扩充成三维. 变成(num_steps, batch_size, num_hiddens)
        context = state[-1].repeat(X.shape[0], 1, 1)  #扩充X.shape[0]=num_steps次,1:所对应的维度不变
        X_and_context = torch.cat((X, context), 2) #按照索引为2的维度合并
        #此时,X_and_context的shape=(num_steps, batch_size, embed_size+num_hiddens)
        #神经网络层
        outputs, state = self.rnn(X_and_context, state)
        #输出层
        outputs = self.dense(outputs).permute(1, 0, 2) #将数据维度重新调换过来
        #outputs的shape=(batch_size, num_steps, vocab_size)
        #state的shape=(num_layers, batch_size, num_hiddens)
        return outputs, state
python">#测试
decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=32, num_layers=2)
decoder.eval()
state = decoder.init_state(encoder(X))
outputs, state = decoder(X, state)
outputs.shape, state.shape
(torch.Size([4, 7, 10]), torch.Size([2, 4, 32]))

3.编码器 、解码器理论图

 

4.知识点个人理解

 


http://www.niftyadmin.cn/n/5669070.html

相关文章

uniapp使用uview2上传图片功能

官网地址Upload 上传 | uView 2.0 - 全面兼容 nvue 的 uni-app 生态框架 - uni-app UI 框架 前提&#xff0c;需要下载vuew2插件 <view class"upload"><view class"u-demo-block__content"><view class"u-page__upload-item"&…

锐尔15注册机 锐尔文档扫描影像处理系统15功能介绍

锐尔文档扫描影像处理系统是一款全中文操作界面的文件、档案扫描及影像优化处理软件&#xff0c;是目前国内档案数字化行业里专业且优秀的影像优化处理软件。 无论是从纸质文件制作高质量的影像文件&#xff0c;或是检查已经制作好的影像文件&#xff0c;锐尔文档扫描影像处理…

2024 “华为杯” 中国研究生数学建模竞赛(D题)深度剖析|大数据驱动的地理综合问题|数学建模完整代码+建模过程全解全析

当大家面临着复杂的数学建模问题时&#xff0c;你是否曾经感到茫然无措&#xff1f;作为2022年美国大学生数学建模比赛的O奖得主&#xff0c;我为大家提供了一套优秀的解题思路&#xff0c;让你轻松应对各种难题&#xff01; CS团队倾注了大量时间和心血&#xff0c;深入挖掘解…

英语<数词>

1.基数 one two three 整数 1 2 3 小数 1.1 2.2 3.2 分数 分子用基数&#xff0c;分母用序数 例子 1/3 one third 分子>1 2/3 two thirds 百分数 2.序数 first second

LeetCode118:杨辉三角

题目链接&#xff1a;118. 杨辉三角 - 力扣&#xff08;LeetCode&#xff09; 代码如下 class Solution {public:vector<vector<int>> generate(int numRows) {vector<vector<int>> dp(numRows);vector<int> temp(numRows);for (int i 0; i &…

如何设置 Django 错误邮件通知 ?

Django 是一个强大的 web 框架&#xff0c;非常适合那些想要完美快速完成任务的人。它有许多内置的工具和特性&#xff0c;一个有用的特性是 Django 可以在出现错误时发送电子邮件提醒。这对开发人员和管理员非常有用&#xff0c;因为如果出现问题&#xff0c;他们会立即得到通…

【C#生态园】虚拟现实与增强现实:C#开发库全面评估

C#编程与虚拟现实&#xff1a;六大库全面解析 前言 随着虚拟现实&#xff08;VR&#xff09;和增强现实&#xff08;AR&#xff09;技术的不断发展&#xff0c;C#编程语言在这一领域的应用也愈发广泛。本文将探讨几种用于C#开发的虚拟现实和增强现实库&#xff0c;以及它们的…

C#-日志系统

文章速览 Log 全局变量创建实例 坚持记录实属不易&#xff0c;希望友善多金的码友能够随手点一个赞。 共同创建氛围更加良好的开发者社区&#xff01; 谢谢~ Log 全局变量 /// <summary>/// 日志系统/// </summary>public ILogger Log { get; private set; }创建实…