fixed device in tacotron2 inference

This commit is contained in:
gkarch 2020-04-21 22:07:11 +02:00
parent 362dfe6b3b
commit 46758dadf8

View file

@ -530,8 +530,8 @@ class Decoder(nn.Module):
self.initialize_decoder_states(memory, mask=None)
mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32).cuda()
not_finished = torch.ones([memory.size(0)], dtype=torch.int32).cuda()
mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32).to(memory.device)
not_finished = torch.ones([memory.size(0)], dtype=torch.int32).to(memory.device)
mel_outputs, gate_outputs, alignments = [], [], []
while True:
decoder_input = self.prenet(decoder_input)