# coding: utf-8 import argparse import time import math import torch import torch.nn as nn from torch.autograd import Variable from fp16util import * import data import model parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 RNN/LSTM Language Model') parser.add_argument('--data', type=str, default='./data/wikitext-2', help='location of the data corpus') parser.add_argument('--model', type=str, default='LSTM', help='type of recurrent net (RNN_TANH, RNN_RELU, LSTM, GRU)') parser.add_argument('--emsize', type=int, default=200, help='size of word embeddings') parser.add_argument('--nhid', type=int, default=200, help='number of hidden units per layer') parser.add_argument('--nlayers', type=int, default=2, help='number of layers') parser.add_argument('--lr', type=float, default=20, help='initial learning rate') parser.add_argument('--clip', type=float, default=0.25, help='gradient clipping') parser.add_argument('--epochs', type=int, default=40, help='upper epoch limit') parser.add_argument('--batch_size', type=int, default=20, metavar='N', help='batch size') parser.add_argument('--bptt', type=int, default=35, help='sequence length') parser.add_argument('--dropout', type=float, default=0.2, help='dropout applied to layers (0 = no dropout)') parser.add_argument('--tied', action='store_true', help='tie the word embedding and softmax weights') parser.add_argument('--seed', type=int, default=1111, help='random seed') parser.add_argument('--cuda', action='store_true', help='use CUDA') parser.add_argument('--log-interval', type=int, default=200, metavar='N', help='report interval') parser.add_argument('--save', type=str, default='model.pt', help='path to save the final model') parser.add_argument('--fp16', action='store_true', help='Run model in pseudo-fp16 mode (fp16 storage fp32 math).') parser.add_argument('--loss_scale', type=float, default=1, help='Loss scaling, positive power of 2 values can improve fp16 convergence.') args = parser.parse_args() # Set the random seed manually for reproducibility. torch.manual_seed(args.seed) if torch.cuda.is_available(): if not args.cuda: print("WARNING: You have a CUDA device, so you should probably run with --cuda") else: torch.cuda.manual_seed(args.seed) if args.fp16 and not args.cuda: print("WARNING: --fp16 requires --cuda, ignoring --fp16 option") ############################################################################### # Load data ############################################################################### corpus = data.Corpus(args.data) # Starting from sequential data, batchify arranges the dataset into columns. # For instance, with the alphabet as the sequence and batch size 4, we'd get # ┌ a g m s ┐ # │ b h n t │ # │ c i o u │ # │ d j p v │ # │ e k q w │ # └ f l r x ┘. # These columns are treated as independent by the model, which means that the # dependence of e. g. 'g' on 'f' can not be learned, but allows more efficient # batch processing. def batchify(data, bsz): # Work out how cleanly we can divide the dataset into bsz parts. nbatch = data.size(0) // bsz # Trim off any extra elements that wouldn't cleanly fit (remainders). data = data.narrow(0, 0, nbatch * bsz) # Evenly divide the data across the bsz batches. data = data.view(bsz, -1).t().contiguous() if args.cuda: data = data.cuda() return data eval_batch_size = 10 train_data = batchify(corpus.train, args.batch_size) val_data = batchify(corpus.valid, eval_batch_size) test_data = batchify(corpus.test, eval_batch_size) ############################################################################### # Build the model ############################################################################### ntokens = len(corpus.dictionary) model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied) if args.cuda and args.fp16: model.type(torch.cuda.HalfTensor) param_copy = params_to_32(clone_params(model)) elif args.cuda: model.cuda() criterion = nn.CrossEntropyLoss() ############################################################################### # Training code ############################################################################### def repackage_hidden(h): """Wraps hidden states in new Variables, to detach them from their history.""" if type(h) == Variable: return Variable(h.data) else: return tuple(repackage_hidden(v) for v in h) # get_batch subdivides the source data into chunks of length args.bptt. # If source is equal to the example output of the batchify function, with # a bptt-limit of 2, we'd get the following two Variables for i = 0: # ┌ a g m s ┐ ┌ b h n t ┐ # └ b h n t ┘ └ c i o u ┘ # Note that despite the name of the function, the subdivison of data is not # done along the batch dimension (i.e. dimension 1), since that was handled # by the batchify function. The chunks are along dimension 0, corresponding # to the seq_len dimension in the LSTM. def get_batch(source, i, evaluation=False): seq_len = min(args.bptt, len(source) - 1 - i) data = Variable(source[i:i+seq_len], volatile=evaluation) target = Variable(source[i+1:i+1+seq_len].view(-1)) return data, target def evaluate(data_source): # Turn on evaluation mode which disables dropout. model.eval() total_loss = 0 ntokens = len(corpus.dictionary) hidden = model.init_hidden(eval_batch_size) for i in range(0, data_source.size(0) - 1, args.bptt): data, targets = get_batch(data_source, i, evaluation=True) output, hidden = model(data, hidden) output_flat = output.view(-1, ntokens) #total loss can overflow if accumulated in fp16. total_loss += len(data) * criterion(output_flat, targets).data.float() hidden = repackage_hidden(hidden) return total_loss[0] / len(data_source) def train(): # Turn on training mode which enables dropout. model.train() total_loss = 0 start_time = time.time() ntokens = len(corpus.dictionary) hidden = model.init_hidden(args.batch_size) for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)): data, targets = get_batch(train_data, i) # Starting each batch, we detach the hidden state from how it was previously produced. # If we didn't, the model would try backpropagating all the way to start of the dataset. hidden = repackage_hidden(hidden) model.zero_grad() output, hidden = model(data, hidden) loss = criterion(output.view(-1, ntokens), targets) loss = loss * args.loss_scale loss.backward() loss = loss / args.loss_scale # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. torch.nn.utils.clip_grad_norm(model.parameters(), args.clip) if args.fp16 and args.cuda: grad = params_to_32(clone_grads(model)) for i, _ in enumerate(param_copy): param_copy[i] = param_copy[i] - grad[i] * (lr/args.loss_scale) copy_in_params(model, params_to_16(param_copy)) else: for p in model.parameters(): p.data.add_(-lr/args.loss_scale, p.grad.data) total_loss += loss.data if batch % args.log_interval == 0 and batch > 0: cur_loss = total_loss[0] / args.log_interval elapsed = time.time() - start_time print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | ' 'loss {:5.2f} | ppl {:8.2f}'.format( epoch, batch, len(train_data) // args.bptt, lr, elapsed * 1000 / args.log_interval, cur_loss, math.exp(min(cur_loss, 20)))) total_loss = 0 start_time = time.time() # Loop over epochs. lr = args.lr best_val_loss = None # At any point you can hit Ctrl + C to break out of training early. try: for epoch in range(1, args.epochs+1): epoch_start_time = time.time() train() val_loss = evaluate(val_data) print('-' * 89) print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | ' 'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time), val_loss, math.exp(min(val_loss, 20)))) print('-' * 89) # Save the model if the validation loss is the best we've seen so far. if not best_val_loss or val_loss < best_val_loss: with open(args.save, 'wb') as f: torch.save(model, f) best_val_loss = val_loss else: # Anneal the learning rate if no improvement has been seen in the validation dataset. lr /= 4.0 except KeyboardInterrupt: print('-' * 89) print('Exiting from training early') # Load the best saved model. with open(args.save, 'rb') as f: model = torch.load(f) # Run on test data. test_loss = evaluate(test_data) print('=' * 89) print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format( test_loss, math.exp(test_loss))) print('=' * 89)