(1) reduce the validation loss within an epoch, (2) convert global-batch-based iteration counts to micro-batch-based (#3055)

Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Sangkug Lym 2021-10-26 11:05:27 -07:00 committed by GitHub
parent 0d0de9b367
commit e3ccd0a90d

@@ -137,6 +137,8 @@ class MegatronGPTModel(NLPModel):
         tokens, labels, loss_mask, attention_mask, position_ids = self.process_batch(batch)
         output_tensor = self(tokens, position_ids, attention_mask, labels)
         loss = self.loss_func(loss_mask, output_tensor)
+        reduced_loss = average_losses_across_data_parallel_group([loss])
+
         # TODO: add text generation
         # take the first k tokens and then generate text - compare with ground truth (won't look similar in general)
         """
@@ -149,11 +151,11 @@ class MegatronGPTModel(NLPModel):
             context_tokens.append(next_token)
             k += 1
         """
-        return loss
+        return reduced_loss
 
     def validation_epoch_end(self, outputs):
-        averaged_loss = average_losses_across_data_parallel_group(outputs)
-        self.log('val_loss', averaged_loss[0], prog_bar=True)
+        averaged_loss = torch.stack(outputs).mean()
+        self.log('val_loss', averaged_loss, prog_bar=True)
         self.log('consumed_samples', self.compute_consumed_samples(self.trainer.global_step))
 
     def test_step(self, batch, batch_idx):
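
Note on the first change: `average_losses_across_data_parallel_group` all-reduces the scalar loss over the data-parallel ranks, so every rank returns the same reduced value from `validation_step`, and `torch.stack(outputs).mean()` then averages those values over all validation micro-batches within the epoch. A minimal sketch of such a helper (an assumption about its shape, not the exact NeMo/Megatron source; it presumes `torch.distributed` is initialized and that the caller would pass the data-parallel process group):

import torch

def average_losses_across_data_parallel_group(losses, group=None):
    # Collect the per-micro-batch scalar losses into one tensor.
    averaged = torch.cat([loss.clone().detach().view(1) for loss in losses])
    # Sum across the ranks in `group`, then divide by the group size to
    # get the mean. `group=None` falls back to the default process group
    # here; the real call would use the data-parallel group.
    torch.distributed.all_reduce(averaged, group=group)
    averaged = averaged / torch.distributed.get_world_size(group=group)
    return averaged

Because the reduction now happens per validation step, `validation_epoch_end` no longer needs to reduce across ranks itself and can simply average the per-step losses.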
@@ -198,10 +200,13 @@ class MegatronGPTModel(NLPModel):
     def build_train_valid_test_datasets(self):
         logging.info('Building GPT datasets.')
         global_batch_size = self.trainer.world_size * self.cfg.micro_batch_size / self.cfg.tensor_model_parallel_size
-        eval_iters = (self.trainer.max_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches
+
+        # Compute training micro-batch steps: total_global_batch_steps x grad_accums_per_global_batch
+        max_train_steps = self.trainer.max_steps * self.trainer.accumulate_grad_batches
+        eval_iters = (max_train_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches
         test_iters = self.trainer.limit_test_batches
         train_valid_test_num_samples = [
-            self.trainer.max_steps * global_batch_size,
+            max_train_steps * global_batch_size,
             eval_iters * global_batch_size,
             test_iters * global_batch_size,
         ]
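
Worked example of the second change, with hypothetical values (none of these numbers come from the commit). With gradient accumulation, `max_steps` counts global-batch (optimizer) steps while `val_check_interval` counts micro-batch steps, so the dataset sizes must be scaled by `accumulate_grad_batches`:

# Hypothetical trainer/config values, for illustration only.
world_size = 16                   # total GPUs
micro_batch_size = 4
tensor_model_parallel_size = 2
max_steps = 1000                  # optimizer (global-batch) steps
accumulate_grad_batches = 8       # micro-batches per global batch
val_check_interval = 800          # measured in micro-batch steps
limit_val_batches = 50
limit_test_batches = 25

# Samples consumed per micro-batch step across the data-parallel ranks
# (world_size / tensor_model_parallel_size = 8 data-parallel ranks).
global_batch_size = world_size * micro_batch_size / tensor_model_parallel_size   # 32.0

# Micro-batch training steps: 1000 global steps x 8 grad accumulations.
max_train_steps = max_steps * accumulate_grad_batches                            # 8000

# Validation runs every val_check_interval micro-batch steps, plus once more.
eval_iters = (max_train_steps // val_check_interval + 1) * limit_val_batches     # 550
test_iters = limit_test_batches

train_valid_test_num_samples = [
    max_train_steps * global_batch_size,   # 256000.0 training samples
    eval_iters * global_batch_size,        # 17600.0 validation samples
    test_iters * global_batch_size,        # 800.0 test samples
]

Before this change, the training sample count and `eval_iters` used `max_steps` directly, which under-allocates dataset samples by a factor of `accumulate_grad_batches` whenever accumulation is greater than 1.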