(1) reduce the validation loss within an epoch, (2) convert global-batch-based iteration counts to micro-batch-based (#3055)
Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
parent 0d0de9b367
commit e3ccd0a90d
@@ -137,6 +137,8 @@ class MegatronGPTModel(NLPModel):
         tokens, labels, loss_mask, attention_mask, position_ids = self.process_batch(batch)
         output_tensor = self(tokens, position_ids, attention_mask, labels)
         loss = self.loss_func(loss_mask, output_tensor)
+        reduced_loss = average_losses_across_data_parallel_group([loss])
+
         # TODO: add text generation
         # take the first k tokens and then generate text - compare with ground truth (won't look similar in general)
         """
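For context, the helper added above averages a list of per-micro-batch losses across all data-parallel ranks with a single collective call. A minimal sketch of that reduction, assuming torch.distributed is already initialized; the explicit `group` argument is an illustrative assumption, since the real NeMo/Megatron helper resolves the data-parallel group internally:

    import torch
    import torch.distributed as dist

    def average_losses_across_data_parallel_group(losses, group=None):
        # Concatenate the scalar losses so one collective covers them all.
        averaged = torch.cat([loss.clone().detach().view(1) for loss in losses])
        # Sum across every data-parallel rank, then divide by the group size
        # to turn the sum into a mean.
        dist.all_reduce(averaged, group=group)
        averaged /= dist.get_world_size(group=group)
        return averaged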
@@ -149,11 +151,11 @@ class MegatronGPTModel(NLPModel):
             context_tokens.append(next_token)
             k += 1
         """
-        return loss
+        return reduced_loss

     def validation_epoch_end(self, outputs):
-        averaged_loss = average_losses_across_data_parallel_group(outputs)
-        self.log('val_loss', averaged_loss[0], prog_bar=True)
+        averaged_loss = torch.stack(outputs).mean()
+        self.log('val_loss', averaged_loss, prog_bar=True)
         self.log('consumed_samples', self.compute_consumed_samples(self.trainer.global_step))

     def test_step(self, batch, batch_idx):
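Because each validation_step now returns a loss that is already averaged across data-parallel ranks, validation_epoch_end no longer needs a collective call of its own; a plain mean over the per-step outputs suffices. A toy illustration with made-up values:

    import torch

    # Hypothetical per-step outputs, each already reduced across data-parallel ranks.
    outputs = [torch.tensor(2.0), torch.tensor(4.0), torch.tensor(3.0)]
    averaged_loss = torch.stack(outputs).mean()
    print(averaged_loss)  # tensor(3.)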
@@ -198,10 +200,13 @@ class MegatronGPTModel(NLPModel):
     def build_train_valid_test_datasets(self):
         logging.info('Building GPT datasets.')
         global_batch_size = self.trainer.world_size * self.cfg.micro_batch_size / self.cfg.tensor_model_parallel_size
-        eval_iters = (self.trainer.max_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches
+        # Compute training micro-batch steps: total_global_batch_steps x grad_accums_per_global_batch
+        max_train_steps = self.trainer.max_steps * self.trainer.accumulate_grad_batches
+        eval_iters = (max_train_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches
         test_iters = self.trainer.limit_test_batches

         train_valid_test_num_samples = [
-            self.trainer.max_steps * global_batch_size,
+            max_train_steps * global_batch_size,
             eval_iters * global_batch_size,
             test_iters * global_batch_size,
         ]
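The arithmetic in this hunk converts optimizer-step (global-batch) counts into micro-batch iteration counts, so dataset sizing matches what the data loader actually draws. A worked example with hypothetical values:

    # Hypothetical trainer/config values, for illustration only.
    world_size = 16                 # total GPUs
    micro_batch_size = 4            # samples per GPU per micro-batch
    tensor_model_parallel_size = 2  # GPUs that share one model replica
    max_steps = 1000                # optimizer steps (global batches)
    accumulate_grad_batches = 8     # micro-batches per optimizer step
    val_check_interval = 2000       # now measured in micro-batch steps
    limit_val_batches = 50

    # Samples consumed per micro-batch step across all data-parallel
    # replicas: 16 * 4 / 2 = 32.
    global_batch_size = world_size * micro_batch_size // tensor_model_parallel_size

    # 1000 optimizer steps x 8 grad accumulations = 8000 micro-batch steps.
    max_train_steps = max_steps * accumulate_grad_batches

    # (8000 // 2000 + 1) * 50 = 250 validation iterations.
    eval_iters = (max_train_steps // val_check_interval + 1) * limit_val_batches

    # 8000 * 32 = 256,000 training samples requested from the dataset.
    train_samples = max_train_steps * global_batch_size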