DeepLearningExamples/Tools/PyTorch/TimeSeriesPredictionPlatform/criterion.py
2021-11-08 14:08:58 -08:00

77 lines
1.9 KiB
Python
Executable file

# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn
import torch.nn.functional as F
def MAE_wrapper(config):
return nn.L1Loss()
def MSE_wrapper(config):
return nn.MSELoss()
def quantile_wrapper(config):
return QuantileLoss(config)
# assumed for anomaly detection task
def cross_entropy_wrapper(config):
return nn.CrossEntropyLoss()
def huber_wrapper(config):
return nn.SmoothL1Loss()
def GLL_wrapper(config):
return GaussianLogLikelihood()
class QuantileLoss(nn.Module):
def __init__(self, config):
super().__init__()
self.quantiles = list(config.model.quantiles)
def forward(self, predictions, targets):
losses = []
for i, q in enumerate(self.quantiles):
diff = predictions[..., i] - targets[..., 0]
loss = ((1 - q) * F.relu(diff) + q * F.relu(-diff)).mean()
losses.append(loss)
losses = torch.stack(losses)
return losses
class WeightedL1Loss(nn.Module):
def __init__(self):
super().__init__()
def forward(self, inputs, targets, weights):
x = F.l1_loss(inputs, targets, reduction="none")
x = x * weights # broadcasted along 0th dimension
x = x.mean()
return x
class GaussianLogLikelihood(nn.Module):
def __init__(self):
super().__init__()
def forward(self, inputs, targets):
# Inputs with shape [BS, window, 2] (mean + std)
# Targets with shape [BS, window, 1]
zero_index = targets[..., 0:1] != 0
mu = inputs[..., 0:1]
sigma = inputs[..., 1:2]
distribution = torch.distributions.normal.Normal(mu[zero_index], sigma[zero_index])
likelihood = distribution.log_prob(targets[zero_index])
likelihood = -likelihood.view(inputs.shape[0], inputs.shape[1])
loss = likelihood.mean(0).mean()
return loss