lmw0320 2019-08-09 11:55 采纳率: 75%
浏览 2564
已采纳

pytorch训练LSTM模型的代码疑问

原博文链接地址:https://blog.csdn.net/Sebastien23/article/details/80574918
其中有不少代码完全看不太懂,想来这里求教下各位大神~~

class Sequence(nn.Module):
    def __init__(self):
        super(Sequence,self).__init__()
        self.lstm1 = nn.LSTMCell(1,51)
        self.lstm2 = nn.LSTMCell(51,51)
        self.linear = nn.Linear(51,1)
        #上面三行代码是设置网络结构吧?为什么用的是LSTMCell,而不是LSTM??
    def forward(self,inputs,future= 0): 
    #这里的前向传播名称必须是forward,而不能随意更改??因为后面的模型调用过程中,并没有看到该方法的实现
        outputs = []
        h_t = torch.zeros(inputs.size(0),51)
        c_t = torch.zeros(inputs.size(0),51)
        h_t2 = torch.zeros(inputs.size(0),51)
        c_t2 = torch.zeros(inputs.size(0),51)
#下面的代码中,LSTM的原理是要求三个输入:前一层的细胞状态、隐藏层状态和当前层的数据输入。这里却只有2个输入??
        for i,input_t in enumerate(inputs.chunk(inputs.size(1),dim =1)):
            h_t,c_t = self.lstm1(input_t,(h_t,c_t))
            h_t2,c_t2 = self.lstm2(h_t,(h_t2,c_t2))
            output = self.linear(h_t2)
            outputs +=[output]            
        for i in range(future):
            h_t,c_t = self.lstm1(output,(h_t,c_t))
            h_t2,c_t2 = self.lstm2(h_t,(h_t2,c_t2))
            output = self.linear(h_t2)
            outputs +=[output]
#下面将所有的输出在第一维上相拼接,并剪除维度为2的数据??目的是什么?
        outputs = torch.stack(outputs,1).squeeze(2)
        return outputs
  • 写回答

2条回答 默认 最新

  • Bill_zhang5 2019-08-12 15:17
    关注
    def __init__(self):
            super(Sequence,self).__init__()
            self.lstm1 = nn.LSTMCell(1,51)
            self.lstm2 = nn.LSTMCell(51,51)
            self.linear = nn.Linear(51,1)
    

    应该是指单个LSTM cell具有多个hidden layer,为参数设置

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

悬赏问题

  • ¥15 angular项目错误
  • ¥20 需要帮我远程操控一下,运行一下我的那个代码,我觉得我无能为力了
  • ¥20 有偿:在ubuntu上安装arduino以及其常用库文件。
  • ¥15 请问用arcgis处理一些数据和图形,通常里面有一个根据点划泰森多边形的命令,直接划的弊端是只能执行一个完整的边界,但是我们有时候会用到需要在有很多边界内利用点来执行划泰森多边形的命令
  • ¥30 在wave2foam中执行setWaveField时遇到了如下的浮点异常问题,请问该如何解决呢?
  • ¥750 关于一道数论方面的问题,求解答!(关键词-数学方法)
  • ¥200 csgo2的viewmatrix值是否还有别的获取方式
  • ¥15 Stable Diffusion,用Ebsynth utility在视频选帧图重绘,第一步报错,蒙版和帧图没法生成,怎么处理啊
  • ¥15 请把下列每一行代码完整地读懂并注释出来
  • ¥15 寻找公式识别开发,自动识别整页文档、图像公式的软件