[FastPitch/PyT] Fix updated regulate_len

This commit is contained in:
Adrian Lancucki 2021-05-21 11:35:47 +02:00 committed by GitHub
parent 5a8521ee05
commit 8d8c524df6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -38,8 +38,9 @@ from fastpitch.transformer import FFTransformer
def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None):
"""If target=None, then predicted durations are applied"""
dtype = enc_out.dtype
reps = (durations.float() / pace + 0.5)
dec_lens = reps.sum(dim=1).long()
reps = durations.float() / pace
reps = (reps + 0.5).long()
dec_lens = reps.sum(dim=1)
max_len = dec_lens.max()
reps_cumsum = torch.cumsum(F.pad(reps, (1, 0, 0, 0), value=0.0), dim=1)[:, None, :]