Add pytorch examples

This commit is contained in:
Christian Sarofeen 2018-05-02 18:33:56 -07:00
parent 2fdaec9c85
commit 76a3d10a19
17 changed files with 45787 additions and 0 deletions

View file

@ -0,0 +1,25 @@
# Basic Multirpocess Example based on the MNIST example
This example is based on [PyTorch's MNIST Example](https://github.com/pytorch/examples/tree/master/mnist).
This example demonstrates how to modify a network to use a basic but effective distributed data parallel module. This parallel method is designed to easily run multi-gpu runs on a single node. It was created as current parallel methods integraded into pytorch can induce significant overhead due to python GIL lock. This method will reduce the influence of those overheads and potentially provide a benefit in performance, especially for networks with a significant number of fast running operations.
## Getting started
Prior to running please run
```pip install -r requirements.txt```
and start a single process run to allow the dataset to be downloaded (This will not work properly in multi-gpu. You can stop this job as soon as it starts iterating.).
```python main.py```
You can now the code multi-gpu with
```python -m multiproc main.py ...```
adding any normal option you'd like.
## Converting your own model
To understand how to convert your own model to use the distributed module included, please see all sections of main.py within ```#=====START: ADDED FOR DISTRIBUTED======``` and ```#=====END: ADDED FOR DISTRIBUTED======``` flags.
Copy the distributed.py and multiproc.py files from here to your local workspace.
## Requirements
Pytorch master branch built from source. This requirement is to use NCCL as a distributed backend.

View file

@ -0,0 +1,183 @@
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import torch.distributed as dist
from torch.nn.modules import Module
from torch.autograd import Variable
'''
This version of DistributedDataParallel is designed to be used in conjunction with the multiproc.py
launcher included with this example. It assumes that your run is using multiprocess with 1
GPU/process, that the model is on the correct device, and that torch.set_device has been
used to set the device.
Parameters are broadcasted to the other processes on initialization of DistributedDataParallel,
and will be allreduced at the finish of the backward pass.
'''
def flat_dist_call(tensors, call, extra_args=None):
flat_dist_call.warn_on_half = True
buckets = {}
for tensor in tensors:
tp = tensor.type()
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(tensor)
if flat_dist_call.warn_on_half:
if torch.cuda.HalfTensor in buckets:
print("WARNING: gloo dist backend for half parameters may be extremely slow." +
" It is recommended to use the NCCL backend in this case.")
flat_dist_call.warn_on_half = False
for tp in buckets:
bucket = buckets[tp]
coalesced = _flatten_dense_tensors(bucket)
if extra_args is not None:
call(coalesced, *extra_args)
else:
call(coalesced)
coalesced /= dist.get_world_size()
for buf, synced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)):
buf.copy_(synced)
class DistributedDataParallel(Module):
def __init__(self, module, message_size=10000000):
super(DistributedDataParallel, self).__init__()
self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
self.message_size = message_size
#reference to last iterations parameters to see if anything has changed
self.param_refs = []
self.reduction_stream = torch.cuda.Stream()
self.module = module
self.param_list = list(self.module.parameters())
if dist._backend == dist.dist_backend.NCCL:
for param in self.param_list:
assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."
self.record = []
self.create_hooks()
flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) )
def create_hooks(self):
#all reduce gradient hook
def allreduce_params():
if(self.needs_reduction):
self.needs_reduction = False
self.needs_refresh = False
else:
return
grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
flat_dist_call(grads, dist.all_reduce)
t_record = torch.cuda.IntTensor(self.record)
dist.broadcast(t_record, 0)
self.record = [int(entry) for entry in t_record]
def flush_buckets():
if not self.needs_reduction:
return
self.needs_reduction = False
ready = []
for i in range(len(self.param_state)):
if self.param_state[i] == 1:
param = self.param_list[self.record[i]]
if param.grad is not None:
ready.append(param.grad.data)
if(len(ready)>0):
orig_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.reduction_stream):
self.reduction_stream.wait_stream(orig_stream)
flat_dist_call(ready, dist.all_reduce)
torch.cuda.current_stream().wait_stream(self.reduction_stream)
for param_i, param in enumerate(list(self.module.parameters())):
def wrapper(param_i):
def allreduce_hook(*unused):
if self.needs_refresh:
self.record.append(param_i)
Variable._execution_engine.queue_callback(allreduce_params)
else:
Variable._execution_engine.queue_callback(flush_buckets)
self.param_state[self.record.index(param_i)] = 1
self.comm_ready_buckets()
if param.requires_grad:
param.register_hook(allreduce_hook)
wrapper(param_i)
def comm_ready_buckets(self):
ready = []
counter = 0
while counter < len(self.param_state) and self.param_state[counter] == 2:
counter += 1
while counter < len(self.param_state) and self.param_state[counter] == 1:
ready.append(counter)
counter += 1
if not ready:
return
grads = []
for ind in ready:
param_ind = self.record[ind]
if self.param_list[param_ind].grad is not None:
grads.append(self.param_list[param_ind].grad.data)
bucket = []
bucket_inds = []
while grads:
bucket.append(grads.pop(0))
bucket_inds.append(ready.pop(0))
cumm_size = 0
for ten in bucket:
cumm_size += ten.numel()
if cumm_size < self.message_size:
continue
evt = torch.cuda.Event()
evt.record(torch.cuda.current_stream())
evt.wait(stream=self.reduction_stream)
with torch.cuda.stream(self.reduction_stream):
flat_dist_call(bucket, dist.all_reduce)
for ind in bucket_inds:
self.param_state[ind] = 2
def forward(self, *inputs, **kwargs):
param_list = [param for param in list(self.module.parameters()) if param.requires_grad]
self.needs_refresh = True if not self.param_refs else any(
[param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)]
)
if self.needs_refresh:
self.record = []
self.param_state = [0 for i in range(len(param_list))]
self.param_refs = param_list
self.needs_reduction = True
return self.module(*inputs, **kwargs)

196
PyTorch/Distributed/main.py Normal file
View file

@ -0,0 +1,196 @@
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
#=====START: ADDED FOR DISTRIBUTED======
'''Add custom module for distributed'''
from distributed import DistributedDataParallel as DDP
'''Import distributed data loader'''
import torch.utils.data
import torch.utils.data.distributed
'''Import torch.distributed'''
import torch.distributed as dist
#=====END: ADDED FOR DISTRIBUTED======
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
#======START: ADDED FOR DISTRIBUTED======
'''
Add some distributed options. For explanation of dist-url and dist-backend please see
http://pytorch.org/tutorials/intermediate/dist_tuto.html
--world-size and --rank are required parameters as they will be used by the multiproc.py launcher
but do not have to be set explicitly.
'''
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
help='distributed backend')
parser.add_argument('--world-size', default=1, type=int,
help='Number of GPUs to use. Can either be manually set ' +
'or automatically set by using \'python -m multiproc\'.')
parser.add_argument('--rank', default=0, type=int,
help='Used for multi-process training. Can either be manually set ' +
'or automatically set by using \'python -m multiproc\'.')
#=====END: ADDED FOR DISTRIBUTED======
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
#======START: ADDED FOR DISTRIBUTED======
'''Add a convenience flag to see if we are running distributed'''
args.distributed = args.world_size > 1
'''Check that we are running with cuda, as distributed is only supported for cuda.'''
if args.distributed:
assert args.cuda, "Distributed mode requires running with CUDA."
if args.distributed:
'''
Set cuda device so everything is done on the right GPU.
THIS MUST BE DONE AS SOON AS POSSIBLE.
'''
torch.cuda.set_device(args.rank % torch.cuda.device_count())
'''Initialize distributed communication'''
dist.init_process_group(args.dist_backend, init_method=args.dist_url,
world_size=args.world_size)
#=====END: ADDED FOR DISTRIBUTED======
torch.manual_seed(args.seed)
if args.cuda:
torch.cuda.manual_seed(args.seed)
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
#=====START: ADDED FOR DISTRIBUTED======
'''
Change sampler to distributed if running distributed.
Shuffle data loader only if distributed.
'''
train_dataset = datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
else:
train_sampler = None
train_loader = torch.utils.data.DataLoader(
train_dataset, sampler=train_sampler,
batch_size=args.batch_size, shuffle=(train_sampler is None), **kwargs
)
#=====END: ADDED FOR DISTRIBUTED======
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.test_batch_size, shuffle=True, **kwargs)
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x)
model = Net()
if args.cuda:
model.cuda()
#=====START: ADDED FOR DISTRIBUTED======
'''
Wrap model in our version of DistributedDataParallel.
This must be done AFTER the model is converted to cuda.
'''
if args.distributed:
model = DDP(model)
#=====END: ADDED FOR DISTRIBUTED======
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
def train(epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
if args.cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data), Variable(target)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.data[0]))
def test():
model.eval()
test_loss = 0
correct = 0
for data, target in test_loader:
if args.cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data, volatile=True), Variable(target)
output = model(data)
test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss
pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
correct += pred.eq(target.data.view_as(pred)).cpu().sum()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
for epoch in range(1, args.epochs + 1):
train(epoch)
test()

View file

@ -0,0 +1,28 @@
import torch
import sys
import subprocess
argslist = list(sys.argv)[1:]
world_size = torch.cuda.device_count()
if '--world-size' in argslist:
argslist[argslist.index('--world-size')+1] = str(world_size)
else:
argslist.append('--world-size')
argslist.append(str(world_size))
workers = []
for i in range(world_size):
if '--rank' in argslist:
argslist[argslist.index('--rank')+1] = str(i)
else:
argslist.append('--rank')
argslist.append(str(i))
stdout = None if i == 0 else open("GPU_"+str(i)+".log", "w")
print(argslist)
p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout)
workers.append(p)
for p in workers:
p.wait()

View file

@ -0,0 +1,2 @@
torch
torchvision

View file

@ -0,0 +1 @@
python -m multiproc main.py

View file

@ -0,0 +1,58 @@
# Word-level language modeling RNN
This example is based on [PyTorch's Word-level language modeling RNN Example](https://github.com/pytorch/examples/tree/master/word_language_model).
This example trains a multi-layer RNN (Elman, GRU, or LSTM) on a language modeling task.
By default, the training script uses the Wikitext-2 dataset, provided.
The trained model can then be used by the generate script to generate new text.
```bash
python main.py --cuda --epochs 6 # Train a LSTM on Wikitext-2 with CUDA, reaching perplexity of 117.61
python main.py --cuda --epochs 6 --tied # Train a tied LSTM on Wikitext-2 with CUDA, reaching perplexity of 110.44
python main.py --cuda --tied # Train a tied LSTM on Wikitext-2 with CUDA for 40 epochs, reaching perplexity of 87.17
python generate.py # Generate samples from the trained LSTM model.
```
The model uses the `nn.RNN` module (and its sister modules `nn.GRU` and `nn.LSTM`)
which will automatically use the cuDNN backend if run on CUDA with cuDNN installed.
During training, if a keyboard interrupt (Ctrl-C) is received,
training is stopped and the current model is evaluated against the test dataset.
The `main.py` script accepts the following arguments:
```bash
optional arguments:
-h, --help show this help message and exit
--data DATA location of the data corpus
--model MODEL type of recurrent net (RNN_TANH, RNN_RELU, LSTM, GRU)
--emsize EMSIZE size of word embeddings
--nhid NHID number of hidden units per layer
--nlayers NLAYERS number of layers
--lr LR initial learning rate
--clip CLIP gradient clipping
--epochs EPOCHS upper epoch limit
--batch-size N batch size
--bptt BPTT sequence length
--dropout DROPOUT dropout applied to layers (0 = no dropout)
--decay DECAY learning rate decay per epoch
--tied tie the word embedding and softmax weights
--seed SEED random seed
--cuda use CUDA
--log-interval N report interval
--save SAVE path to save the final model
```
With these arguments, a variety of models can be tested.
As an example, the following arguments produce slower but better models:
```bash
python main.py --cuda --emsize 650 --nhid 650 --dropout 0.5 --epochs 40 # Test perplexity of 80.97
python main.py --cuda --emsize 650 --nhid 650 --dropout 0.5 --epochs 40 --tied # Test perplexity of 75.96
python main.py --cuda --emsize 1500 --nhid 1500 --dropout 0.65 --epochs 40 # Test perplexity of 77.42
python main.py --cuda --emsize 1500 --nhid 1500 --dropout 0.65 --epochs 40 --tied # Test perplexity of 72.30
```
Perplexities on PTB are equal or better than
[Recurrent Neural Network Regularization (Zaremba et al. 2014)](https://arxiv.org/pdf/1409.2329.pdf)
and are similar to [Using the Output Embedding to Improve Language Models (Press & Wolf 2016](https://arxiv.org/abs/1608.05859) and [Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling (Inan et al. 2016)](https://arxiv.org/pdf/1611.01462.pdf), though both of these papers have improved perplexities by using a form of recurrent dropout [(variational dropout)](http://papers.nips.cc/paper/6241-a-theoretically-grounded-application-of-dropout-in-recurrent-neural-networks).

View file

@ -0,0 +1,49 @@
import os
import torch
class Dictionary(object):
def __init__(self):
self.word2idx = {}
self.idx2word = []
def add_word(self, word):
if word not in self.word2idx:
self.idx2word.append(word)
self.word2idx[word] = len(self.idx2word) - 1
return self.word2idx[word]
def __len__(self):
return len(self.idx2word)
class Corpus(object):
def __init__(self, path):
self.dictionary = Dictionary()
self.train = self.tokenize(os.path.join(path, 'train.txt'))
self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
self.test = self.tokenize(os.path.join(path, 'test.txt'))
def tokenize(self, path):
"""Tokenizes a text file."""
assert os.path.exists(path)
# Add words to the dictionary
with open(path, 'r') as f:
tokens = 0
for line in f:
words = line.split() + ['<eos>']
tokens += len(words)
for word in words:
self.dictionary.add_word(word)
# Tokenize file content
with open(path, 'r') as f:
ids = torch.LongTensor(tokens)
token = 0
for line in f:
words = line.split() + ['<eos>']
for word in words:
ids[token] = self.dictionary.word2idx[word]
token += 1
return ids

View file

@ -0,0 +1,3 @@
This is raw data from the wikitext-2 dataset.
See https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,36 @@
import torch
def params_to_type(params, totype):
new_params = []
for param in params:
new_params.append(param.type(totype))
return new_params
def params_to_16(params):
return params_to_type(params, torch.cuda.HalfTensor)
def params_to_32(params):
return params_to_type(params, torch.cuda.FloatTensor)
def clone_params(net):
new_params = []
for param in list(net.parameters()):
new_params.append(param.data.clone())
return new_params
def clone_grads(net):
new_params = []
for param in list(net.parameters()):
new_params.append(param.grad.data.clone())
return new_params
def copy_in_params(net, params):
net_params = list(net.parameters())
for i in range(len(params)):
net_params[i].data.copy_(params[i])

View file

@ -0,0 +1,74 @@
###############################################################################
# Language Modeling on Penn Tree Bank
#
# This file generates new sentences sampled from the language model
#
###############################################################################
import argparse
import torch
from torch.autograd import Variable
import data
parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 Language Model')
# Model parameters.
parser.add_argument('--data', type=str, default='./data/wikitext-2',
help='location of the data corpus')
parser.add_argument('--checkpoint', type=str, default='./model.pt',
help='model checkpoint to use')
parser.add_argument('--outf', type=str, default='generated.txt',
help='output file for generated text')
parser.add_argument('--words', type=int, default='1000',
help='number of words to generate')
parser.add_argument('--seed', type=int, default=1111,
help='random seed')
parser.add_argument('--cuda', action='store_true',
help='use CUDA')
parser.add_argument('--temperature', type=float, default=1.0,
help='temperature - higher will increase diversity')
parser.add_argument('--log-interval', type=int, default=100,
help='reporting interval')
args = parser.parse_args()
# Set the random seed manually for reproducibility.
torch.manual_seed(args.seed)
if torch.cuda.is_available():
if not args.cuda:
print("WARNING: You have a CUDA device, so you should probably run with --cuda")
else:
torch.cuda.manual_seed(args.seed)
if args.temperature < 1e-3:
parser.error("--temperature has to be greater or equal 1e-3")
with open(args.checkpoint, 'rb') as f:
model = torch.load(f)
model.eval()
if args.cuda:
model.cuda()
else:
model.cpu()
corpus = data.Corpus(args.data)
ntokens = len(corpus.dictionary)
hidden = model.init_hidden(1)
input = Variable(torch.rand(1, 1).mul(ntokens).long(), volatile=True)
if args.cuda:
input.data = input.data.cuda()
with open(args.outf, 'w') as outf:
for i in range(args.words):
output, hidden = model(input, hidden)
word_weights = output.squeeze().data.div(args.temperature).exp().cpu()
word_idx = torch.multinomial(word_weights, 1)[0]
input.data.fill_(word_idx)
word = corpus.dictionary.idx2word[word_idx]
outf.write(word + ('\n' if i % 20 == 19 else ' '))
if i % args.log_interval == 0:
print('| Generated {}/{} words'.format(i, args.words))

View file

@ -0,0 +1,236 @@
# coding: utf-8
import argparse
import time
import math
import torch
import torch.nn as nn
from torch.autograd import Variable
from fp16util import *
import data
import model
parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 RNN/LSTM Language Model')
parser.add_argument('--data', type=str, default='./data/wikitext-2',
help='location of the data corpus')
parser.add_argument('--model', type=str, default='LSTM',
help='type of recurrent net (RNN_TANH, RNN_RELU, LSTM, GRU)')
parser.add_argument('--emsize', type=int, default=200,
help='size of word embeddings')
parser.add_argument('--nhid', type=int, default=200,
help='number of hidden units per layer')
parser.add_argument('--nlayers', type=int, default=2,
help='number of layers')
parser.add_argument('--lr', type=float, default=20,
help='initial learning rate')
parser.add_argument('--clip', type=float, default=0.25,
help='gradient clipping')
parser.add_argument('--epochs', type=int, default=40,
help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=20, metavar='N',
help='batch size')
parser.add_argument('--bptt', type=int, default=35,
help='sequence length')
parser.add_argument('--dropout', type=float, default=0.2,
help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--tied', action='store_true',
help='tie the word embedding and softmax weights')
parser.add_argument('--seed', type=int, default=1111,
help='random seed')
parser.add_argument('--cuda', action='store_true',
help='use CUDA')
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
help='report interval')
parser.add_argument('--save', type=str, default='model.pt',
help='path to save the final model')
parser.add_argument('--fp16', action='store_true',
help='Run model in pseudo-fp16 mode (fp16 storage fp32 math).')
parser.add_argument('--loss_scale', type=float, default=1,
help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
args = parser.parse_args()
# Set the random seed manually for reproducibility.
torch.manual_seed(args.seed)
if torch.cuda.is_available():
if not args.cuda:
print("WARNING: You have a CUDA device, so you should probably run with --cuda")
else:
torch.cuda.manual_seed(args.seed)
if args.fp16 and not args.cuda:
print("WARNING: --fp16 requires --cuda, ignoring --fp16 option")
###############################################################################
# Load data
###############################################################################
corpus = data.Corpus(args.data)
# Starting from sequential data, batchify arranges the dataset into columns.
# For instance, with the alphabet as the sequence and batch size 4, we'd get
# ┌ a g m s ┐
# │ b h n t │
# │ c i o u │
# │ d j p v │
# │ e k q w │
# └ f l r x ┘.
# These columns are treated as independent by the model, which means that the
# dependence of e. g. 'g' on 'f' can not be learned, but allows more efficient
# batch processing.
def batchify(data, bsz):
# Work out how cleanly we can divide the dataset into bsz parts.
nbatch = data.size(0) // bsz
# Trim off any extra elements that wouldn't cleanly fit (remainders).
data = data.narrow(0, 0, nbatch * bsz)
# Evenly divide the data across the bsz batches.
data = data.view(bsz, -1).t().contiguous()
if args.cuda:
data = data.cuda()
return data
eval_batch_size = 10
train_data = batchify(corpus.train, args.batch_size)
val_data = batchify(corpus.valid, eval_batch_size)
test_data = batchify(corpus.test, eval_batch_size)
###############################################################################
# Build the model
###############################################################################
ntokens = len(corpus.dictionary)
model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied)
if args.cuda and args.fp16:
model.type(torch.cuda.HalfTensor)
param_copy = params_to_32(clone_params(model))
elif args.cuda:
model.cuda()
criterion = nn.CrossEntropyLoss()
###############################################################################
# Training code
###############################################################################
def repackage_hidden(h):
"""Wraps hidden states in new Variables, to detach them from their history."""
if type(h) == Variable:
return Variable(h.data)
else:
return tuple(repackage_hidden(v) for v in h)
# get_batch subdivides the source data into chunks of length args.bptt.
# If source is equal to the example output of the batchify function, with
# a bptt-limit of 2, we'd get the following two Variables for i = 0:
# ┌ a g m s ┐ ┌ b h n t ┐
# └ b h n t ┘ └ c i o u ┘
# Note that despite the name of the function, the subdivison of data is not
# done along the batch dimension (i.e. dimension 1), since that was handled
# by the batchify function. The chunks are along dimension 0, corresponding
# to the seq_len dimension in the LSTM.
def get_batch(source, i, evaluation=False):
seq_len = min(args.bptt, len(source) - 1 - i)
data = Variable(source[i:i+seq_len], volatile=evaluation)
target = Variable(source[i+1:i+1+seq_len].view(-1))
return data, target
def evaluate(data_source):
# Turn on evaluation mode which disables dropout.
model.eval()
total_loss = 0
ntokens = len(corpus.dictionary)
hidden = model.init_hidden(eval_batch_size)
for i in range(0, data_source.size(0) - 1, args.bptt):
data, targets = get_batch(data_source, i, evaluation=True)
output, hidden = model(data, hidden)
output_flat = output.view(-1, ntokens)
#total loss can overflow if accumulated in fp16.
total_loss += len(data) * criterion(output_flat, targets).data.float()
hidden = repackage_hidden(hidden)
return total_loss[0] / len(data_source)
def train():
# Turn on training mode which enables dropout.
model.train()
total_loss = 0
start_time = time.time()
ntokens = len(corpus.dictionary)
hidden = model.init_hidden(args.batch_size)
for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)):
data, targets = get_batch(train_data, i)
# Starting each batch, we detach the hidden state from how it was previously produced.
# If we didn't, the model would try backpropagating all the way to start of the dataset.
hidden = repackage_hidden(hidden)
model.zero_grad()
output, hidden = model(data, hidden)
loss = criterion(output.view(-1, ntokens), targets)
loss = loss * args.loss_scale
loss.backward()
loss = loss / args.loss_scale
# `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
if args.fp16 and args.cuda:
grad = params_to_32(clone_grads(model))
for i, _ in enumerate(param_copy):
param_copy[i] = param_copy[i] - grad[i] * (lr/args.loss_scale)
copy_in_params(model, params_to_16(param_copy))
else:
for p in model.parameters():
p.data.add_(-lr/args.loss_scale, p.grad.data)
total_loss += loss.data
if batch % args.log_interval == 0 and batch > 0:
cur_loss = total_loss[0] / args.log_interval
elapsed = time.time() - start_time
print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
'loss {:5.2f} | ppl {:8.2f}'.format(
epoch, batch, len(train_data) // args.bptt, lr,
elapsed * 1000 / args.log_interval, cur_loss, math.exp(min(cur_loss, 20))))
total_loss = 0
start_time = time.time()
# Loop over epochs.
lr = args.lr
best_val_loss = None
# At any point you can hit Ctrl + C to break out of training early.
try:
for epoch in range(1, args.epochs+1):
epoch_start_time = time.time()
train()
val_loss = evaluate(val_data)
print('-' * 89)
print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
val_loss, math.exp(min(val_loss, 20))))
print('-' * 89)
# Save the model if the validation loss is the best we've seen so far.
if not best_val_loss or val_loss < best_val_loss:
with open(args.save, 'wb') as f:
torch.save(model, f)
best_val_loss = val_loss
else:
# Anneal the learning rate if no improvement has been seen in the validation dataset.
lr /= 4.0
except KeyboardInterrupt:
print('-' * 89)
print('Exiting from training early')
# Load the best saved model.
with open(args.save, 'rb') as f:
model = torch.load(f)
# Run on test data.
test_loss = evaluate(test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
test_loss, math.exp(test_loss)))
print('=' * 89)

View file

@ -0,0 +1,59 @@
import torch.nn as nn
from torch.autograd import Variable
class RNNModel(nn.Module):
"""Container module with an encoder, a recurrent module, and a decoder."""
def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False):
super(RNNModel, self).__init__()
self.drop = nn.Dropout(dropout)
self.encoder = nn.Embedding(ntoken, ninp)
if rnn_type in ['LSTM', 'GRU']:
self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
else:
try:
nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
except KeyError:
raise ValueError("""An invalid option for `--model` was supplied,
options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""")
self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
self.decoder = nn.Linear(nhid, ntoken)
# Optionally tie weights as in:
# "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
# https://arxiv.org/abs/1608.05859
# and
# "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
# https://arxiv.org/abs/1611.01462
if tie_weights:
if nhid != ninp:
raise ValueError('When using the tied flag, nhid must be equal to emsize')
self.decoder.weight = self.encoder.weight
self.init_weights()
self.rnn_type = rnn_type
self.nhid = nhid
self.nlayers = nlayers
def init_weights(self):
initrange = 0.1
self.encoder.weight.data.uniform_(-initrange, initrange)
self.decoder.bias.data.fill_(0)
self.decoder.weight.data.uniform_(-initrange, initrange)
def forward(self, input, hidden):
emb = self.drop(self.encoder(input))
output, hidden = self.rnn(emb, hidden)
output = self.drop(output)
decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden
def init_hidden(self, bsz):
weight = next(self.parameters()).data
if self.rnn_type == 'LSTM':
return (Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()),
Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()))
else:
return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())

View file

@ -0,0 +1 @@
torch