# Last modified: 2022-03-11 14:44:58.233000 — Author: Mango
import torch
import torch.nn as nn
import torch.nn.functional as F
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a shared linear projection.

    The LSTM reads the sequence in both directions; the hidden states of
    the two directions are concatenated (hence ``2 * n_hidden``) and a
    single ``nn.Linear`` maps them to ``n_out`` features per time step.
    """

    def __init__(self, n_in, n_hidden, n_out):
        """
        Args:
            n_in: number of input features per time step.
            n_hidden: hidden size of each LSTM direction.
            n_out: output feature size of the linear projection.
        """
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(n_in, n_hidden, bidirectional=True)
        self.embedding = nn.Linear(n_hidden * 2, n_out)

    def forward(self, input):
        """Map a time-major ``(T, b, n_in)`` tensor to ``(T, b, n_out)``.

        NOTE(review): assumes ``batch_first=False`` layout (seq_len,
        batch, features), the nn.LSTM default — confirm with callers.
        """
        recurrent, _ = self.rnn(input)  # (T, b, 2 * n_hidden)
        T, b, h = recurrent.size()
        # Fold time and batch into one axis so a single Linear call
        # projects every time step at once, then restore the layout.
        flat = recurrent.view(T * b, h)
        projected = self.embedding(flat)  # (T * b, n_out)
        return projected.view(T, b, -1)