[FastPitch/PyT] Fix updated regulate_len
This commit is contained in:
parent
5a8521ee05
commit
8d8c524df6
|
@ -38,8 +38,9 @@ from fastpitch.transformer import FFTransformer
|
|||
def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None):
|
||||
"""If target=None, then predicted durations are applied"""
|
||||
dtype = enc_out.dtype
|
||||
reps = (durations.float() / pace + 0.5)
|
||||
dec_lens = reps.sum(dim=1).long()
|
||||
reps = durations.float() / pace
|
||||
reps = (reps + 0.5).long()
|
||||
dec_lens = reps.sum(dim=1)
|
||||
|
||||
max_len = dec_lens.max()
|
||||
reps_cumsum = torch.cumsum(F.pad(reps, (1, 0, 0, 0), value=0.0), dim=1)[:, None, :]
|
||||
|
|
Loading…
Reference in a new issue