diff --git a/PyTorch/LanguageModeling/BERT/run_pretraining.py b/PyTorch/LanguageModeling/BERT/run_pretraining.py index a3577886..c036b773 100755 --- a/PyTorch/LanguageModeling/BERT/run_pretraining.py +++ b/PyTorch/LanguageModeling/BERT/run_pretraining.py @@ -26,6 +26,7 @@ import os import time import argparse import random +import pickle import h5py from tqdm import tqdm, trange import os @@ -649,7 +650,9 @@ def main(): 'master params': list(amp.master_params(optimizer)), 'files': [f_id] + files, 'epoch': epoch, - 'data_loader': None if global_step >= args.max_steps else train_dataloader}, output_save_file) + 'data_loader': None if global_step >= args.max_steps else train_dataloader}, + output_save_file, + pickle_protocol=pickle.HIGHEST_PROTOCOL) most_recent_ckpts_paths.append(output_save_file) if len(most_recent_ckpts_paths) > 3: