fixed device in tacotron2 inference
This commit is contained in:
parent
362dfe6b3b
commit
46758dadf8
|
@ -530,8 +530,8 @@ class Decoder(nn.Module):
|
|||
|
||||
self.initialize_decoder_states(memory, mask=None)
|
||||
|
||||
mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32).cuda()
|
||||
not_finished = torch.ones([memory.size(0)], dtype=torch.int32).cuda()
|
||||
mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32).to(memory.device)
|
||||
not_finished = torch.ones([memory.size(0)], dtype=torch.int32).to(memory.device)
|
||||
mel_outputs, gate_outputs, alignments = [], [], []
|
||||
while True:
|
||||
decoder_input = self.prenet(decoder_input)
|
||||
|
|
Loading…
Reference in a new issue