# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from typing import Dict, Tuple, Optional, List

if os.environ.get("TFT_SCRIPTING", False):
    from torch.nn import LayerNorm
else:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm


class MaybeLayerNorm(nn.Module):
    def __init__(self, output_size, hidden_size, eps):
        super().__init__()
        if output_size and output_size == 1:
            self.ln = nn.Identity()
        else:
            self.ln = LayerNorm(output_size if output_size else hidden_size, eps=eps)

    def forward(self, x):
        return self.ln(x)


class GLU(nn.Module):
    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.lin = nn.Linear(hidden_size, output_size * 2)

    def forward(self, x: Tensor) -> Tensor:
        x = self.lin(x)
        x = F.glu(x)
        return x


class GRN(nn.Module):
    def __init__(self,
                 input_size,
                 hidden_size,
                 output_size=None,
                 context_hidden_size=None,
                 dropout=0):
        super().__init__()
        self.layer_norm = MaybeLayerNorm(output_size, hidden_size, eps=1e-3)
        self.lin_a = nn.Linear(input_size, hidden_size)
        if context_hidden_size is not None:
            self.lin_c = nn.Linear(context_hidden_size, hidden_size, bias=False)
        self.lin_i = nn.Linear(hidden_size, hidden_size)
        self.glu = GLU(hidden_size, output_size if output_size else hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(input_size, output_size) if output_size else None

    def forward(self, a: Tensor, c: Optional[Tensor] = None):
        x = self.lin_a(a)
        if c is not None:
            x = x + self.lin_c(c).unsqueeze(1)
        x = F.elu(x)
        x = self.lin_i(x)
        x = self.dropout(x)
        x = self.glu(x)
        y = a if not self.out_proj else self.out_proj(a)
        x = x + y
        x = self.layer_norm(x)
        return x
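
# Usage sketch (illustrative only, sizes assumed): a GRN maps its input to
# `output_size` features (or `hidden_size` when no output_size is given),
# optionally conditioned on a static context that is broadcast over the time
# dimension:
#
#   grn = GRN(input_size=16, hidden_size=16, output_size=8, context_hidden_size=16)
#   y = grn(torch.randn(4, 48, 16), c=torch.randn(4, 16))  # y: [4, 48, 8]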


class TFTEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.s_cat_inp_lens = config.static_categorical_inp_lens
        self.t_cat_k_inp_lens = config.temporal_known_categorical_inp_lens
        self.t_cat_o_inp_lens = config.temporal_observed_categorical_inp_lens
        self.s_cont_inp_size = config.static_continuous_inp_size
        self.t_cont_k_inp_size = config.temporal_known_continuous_inp_size
        self.t_cont_o_inp_size = config.temporal_observed_continuous_inp_size
        self.t_tgt_size = config.temporal_target_size

        self.hidden_size = config.hidden_size

        # There are 7 types of input:
        # 1. Static categorical
        # 2. Static continuous
        # 3. Temporal known a priori categorical
        # 4. Temporal known a priori continuous
        # 5. Temporal observed categorical
        # 6. Temporal observed continuous
        # 7. Temporal observed targets (time series observed so far)

        self.s_cat_embed = nn.ModuleList([
            nn.Embedding(n, self.hidden_size) for n in self.s_cat_inp_lens]) if self.s_cat_inp_lens else None
        self.t_cat_k_embed = nn.ModuleList([
            nn.Embedding(n, self.hidden_size) for n in self.t_cat_k_inp_lens]) if self.t_cat_k_inp_lens else None
        self.t_cat_o_embed = nn.ModuleList([
            nn.Embedding(n, self.hidden_size) for n in self.t_cat_o_inp_lens]) if self.t_cat_o_inp_lens else None

        self.s_cont_embedding_vectors = nn.Parameter(torch.Tensor(self.s_cont_inp_size, self.hidden_size)) if self.s_cont_inp_size else None
        self.t_cont_k_embedding_vectors = nn.Parameter(torch.Tensor(self.t_cont_k_inp_size, self.hidden_size)) if self.t_cont_k_inp_size else None
        self.t_cont_o_embedding_vectors = nn.Parameter(torch.Tensor(self.t_cont_o_inp_size, self.hidden_size)) if self.t_cont_o_inp_size else None
        self.t_tgt_embedding_vectors = nn.Parameter(torch.Tensor(self.t_tgt_size, self.hidden_size))

        self.s_cont_embedding_bias = nn.Parameter(torch.zeros(self.s_cont_inp_size, self.hidden_size)) if self.s_cont_inp_size else None
        self.t_cont_k_embedding_bias = nn.Parameter(torch.zeros(self.t_cont_k_inp_size, self.hidden_size)) if self.t_cont_k_inp_size else None
        self.t_cont_o_embedding_bias = nn.Parameter(torch.zeros(self.t_cont_o_inp_size, self.hidden_size)) if self.t_cont_o_inp_size else None
        self.t_tgt_embedding_bias = nn.Parameter(torch.zeros(self.t_tgt_size, self.hidden_size))

        if self.s_cont_embedding_vectors is not None:
            torch.nn.init.xavier_normal_(self.s_cont_embedding_vectors)
        if self.t_cont_k_embedding_vectors is not None:
            torch.nn.init.xavier_normal_(self.t_cont_k_embedding_vectors)
        if self.t_cont_o_embedding_vectors is not None:
            torch.nn.init.xavier_normal_(self.t_cont_o_embedding_vectors)
        torch.nn.init.xavier_normal_(self.t_tgt_embedding_vectors)

    def _apply_embedding(self,
                         cat: Optional[Tensor],
                         cont: Optional[Tensor],
                         cat_emb: Optional[nn.ModuleList],
                         cont_emb: Tensor,
                         cont_bias: Tensor,
                         ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        e_cat = torch.stack([embed(cat[..., i]) for i, embed in enumerate(cat_emb)], dim=-2) if cat is not None else None
        if cont is not None:
            # the line below is equivalent to the following einsums
            # e_cont = torch.einsum('btf,fh->bthf', cont, cont_emb)
            # e_cont = torch.einsum('bf,fh->bhf', cont, cont_emb)
            e_cont = torch.mul(cont.unsqueeze(-1), cont_emb)
            e_cont = e_cont + cont_bias
        else:
            e_cont = None

        if e_cat is not None and e_cont is not None:
            return torch.cat([e_cat, e_cont], dim=-2)
        elif e_cat is not None:
            return e_cat
        elif e_cont is not None:
            return e_cont
        else:
            return None

    def forward(self, x: Dict[str, Tensor]):
        # temporal/static categorical/continuous known/observed input
        s_cat_inp = x.get('s_cat', None)
        s_cont_inp = x.get('s_cont', None)
        t_cat_k_inp = x.get('k_cat', None)
        t_cont_k_inp = x.get('k_cont', None)
        t_cat_o_inp = x.get('o_cat', None)
        t_cont_o_inp = x.get('o_cont', None)
        t_tgt_obs = x['target']  # has to be present

        # Static inputs are expected to be equal for all timesteps.
        # For memory efficiency there is no assert statement.
        s_cat_inp = s_cat_inp[:, 0, :] if s_cat_inp is not None else None
        s_cont_inp = s_cont_inp[:, 0, :] if s_cont_inp is not None else None

        s_inp = self._apply_embedding(s_cat_inp,
                                      s_cont_inp,
                                      self.s_cat_embed,
                                      self.s_cont_embedding_vectors,
                                      self.s_cont_embedding_bias)
        t_known_inp = self._apply_embedding(t_cat_k_inp,
                                            t_cont_k_inp,
                                            self.t_cat_k_embed,
                                            self.t_cont_k_embedding_vectors,
                                            self.t_cont_k_embedding_bias)
        t_observed_inp = self._apply_embedding(t_cat_o_inp,
                                               t_cont_o_inp,
                                               self.t_cat_o_embed,
                                               self.t_cont_o_embedding_vectors,
                                               self.t_cont_o_embedding_bias)

        # Temporal observed targets
        # t_observed_tgt = torch.einsum('btf,fh->btfh', t_tgt_obs, self.t_tgt_embedding_vectors)
        t_observed_tgt = torch.matmul(t_tgt_obs.unsqueeze(3).unsqueeze(4), self.t_tgt_embedding_vectors.unsqueeze(1)).squeeze(3)
        t_observed_tgt = t_observed_tgt + self.t_tgt_embedding_bias

        return s_inp, t_known_inp, t_observed_inp, t_observed_tgt
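
# Shape sketch for TFTEmbedding.forward (illustrative, H = hidden_size):
#   x = {'s_cat': [B, T, n_s_cat] (long), 's_cont': [B, T, n_s_cont],
#        'k_cat': [B, T, n_k_cat] (long), 'k_cont': [B, T, n_k_cont],
#        'o_cat': [B, T, n_o_cat] (long), 'o_cont': [B, T, n_o_cont],
#        'target': [B, T, n_tgt]}  # only 'target' is mandatory
# returns:
#   s_inp           [B, n_s_cat + n_s_cont, H]  (static features, first timestep only)
#   t_known_inp     [B, T, n_k_cat + n_k_cont, H]
#   t_observed_inp  [B, T, n_o_cat + n_o_cont, H] or None
#   t_observed_tgt  [B, T, n_tgt, H]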


class VariableSelectionNetwork(nn.Module):
    def __init__(self, config, num_inputs):
        super().__init__()
        self.joint_grn = GRN(config.hidden_size * num_inputs,
                             config.hidden_size,
                             output_size=num_inputs,
                             context_hidden_size=config.hidden_size)
        self.var_grns = nn.ModuleList([
            GRN(config.hidden_size, config.hidden_size, dropout=config.dropout) for _ in range(num_inputs)])

    def forward(self, x: Tensor, context: Optional[Tensor] = None):
        Xi = x.reshape(*x.shape[:-2], -1)
        grn_outputs = self.joint_grn(Xi, c=context)
        sparse_weights = F.softmax(grn_outputs, dim=-1)
        transformed_embed_list = [m(x[..., i, :]) for i, m in enumerate(self.var_grns)]
        transformed_embed = torch.stack(transformed_embed_list, dim=-1)
        # the line below performs batched matrix-vector multiplication
        # for temporal features it's bthf,btf->bth
        # for static features it's bhf,bf->bh
        variable_ctx = torch.matmul(transformed_embed, sparse_weights.unsqueeze(-1)).squeeze(-1)

        return variable_ctx, sparse_weights


class StaticCovariateEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.vsn = VariableSelectionNetwork(config, config.num_static_vars)
        self.context_grns = nn.ModuleList([
            GRN(config.hidden_size, config.hidden_size, dropout=config.dropout) for _ in range(4)])

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        variable_ctx, sparse_weights = self.vsn(x)

        # Context vectors:
        # variable selection context
        # enrichment context
        # state_c context
        # state_h context
        cs, ce, ch, cc = tuple(m(variable_ctx) for m in self.context_grns)

        return cs, ce, ch, cc


class InterpretableMultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        assert config.hidden_size % config.n_head == 0
        self.d_head = config.hidden_size // config.n_head
        self.qkv_linears = nn.Linear(config.hidden_size, (2 * self.n_head + 1) * self.d_head, bias=False)
        self.out_proj = nn.Linear(self.d_head, config.hidden_size, bias=False)
        self.attn_dropout = nn.Dropout(config.attn_dropout)
        self.out_dropout = nn.Dropout(config.dropout)
        self.scale = self.d_head**-0.5
        self.register_buffer("_mask", torch.triu(torch.full((config.example_length, config.example_length), float('-inf')), 1).unsqueeze(0))

    def forward(self, x: Tensor, mask_future_timesteps: bool = True) -> Tuple[Tensor, Tensor]:
        bs, t, h_size = x.shape
        qkv = self.qkv_linears(x)
        q, k, v = qkv.split((self.n_head * self.d_head, self.n_head * self.d_head, self.d_head), dim=-1)
        q = q.view(bs, t, self.n_head, self.d_head)
        k = k.view(bs, t, self.n_head, self.d_head)
        v = v.view(bs, t, self.d_head)

        # attn_score = torch.einsum('bind,bjnd->bnij', q, k)
        attn_score = torch.matmul(q.permute((0, 2, 1, 3)), k.permute((0, 2, 3, 1)))
        attn_score.mul_(self.scale)

        if mask_future_timesteps:
            attn_score = attn_score + self._mask

        attn_prob = F.softmax(attn_score, dim=3)
        attn_prob = self.attn_dropout(attn_prob)

        # attn_vec = torch.einsum('bnij,bjd->bnid', attn_prob, v)
        attn_vec = torch.matmul(attn_prob, v.unsqueeze(1))
        m_attn_vec = torch.mean(attn_vec, dim=1)
        out = self.out_proj(m_attn_vec)
        out = self.out_dropout(out)

        return out, attn_vec
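
# Note on the attention variant above: unlike standard multi-head attention,
# all heads share a single value projection of size d_head, and the per-head
# outputs are averaged before the output projection. This is the
# "interpretable" formulation from the TFT paper, which lets the attention
# weights be inspected as a single pattern per timestep.
# Shape sketch (illustrative): x [B, T, hidden_size] -> out [B, T, hidden_size],
# attn_vec [B, n_head, T, d_head].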


class TemporalFusionTransformer(nn.Module):
    """
    Implementation of https://arxiv.org/abs/1912.09363
    """
    def __init__(self, config):
        super().__init__()

        if hasattr(config, 'model'):
            config = config.model

        self.encoder_length = config.encoder_length  # determines how far into the past the encoder looks

        self.embedding = TFTEmbedding(config)
        self.static_encoder = StaticCovariateEncoder(config)

        self.history_vsn = VariableSelectionNetwork(config, config.num_historic_vars)
        self.history_encoder = nn.LSTM(config.hidden_size, config.hidden_size, batch_first=True)
        self.future_vsn = VariableSelectionNetwork(config, config.num_future_vars)
        self.future_encoder = nn.LSTM(config.hidden_size, config.hidden_size, batch_first=True)

        self.input_gate = GLU(config.hidden_size, config.hidden_size)
        self.input_gate_ln = LayerNorm(config.hidden_size, eps=1e-3)

        self.enrichment_grn = GRN(config.hidden_size,
                                  config.hidden_size,
                                  context_hidden_size=config.hidden_size,
                                  dropout=config.dropout)
        self.attention = InterpretableMultiHeadAttention(config)
        self.attention_gate = GLU(config.hidden_size, config.hidden_size)
        self.attention_ln = LayerNorm(config.hidden_size, eps=1e-3)

        self.positionwise_grn = GRN(config.hidden_size,
                                    config.hidden_size,
                                    dropout=config.dropout)

        self.decoder_gate = GLU(config.hidden_size, config.hidden_size)
        self.decoder_ln = LayerNorm(config.hidden_size, eps=1e-3)

        self.quantile_proj = nn.Linear(config.hidden_size, len(config.quantiles))

    def forward(self, x: Dict[str, Tensor]) -> Tensor:
        s_inp, t_known_inp, t_observed_inp, t_observed_tgt = self.embedding(x)

        # Static context
        cs, ce, ch, cc = self.static_encoder(s_inp)
        ch, cc = ch.unsqueeze(0), cc.unsqueeze(0)  # lstm initial states

        # Temporal input
        _historical_inputs = [t_known_inp[:, :self.encoder_length, :], t_observed_tgt[:, :self.encoder_length, :]]
        if t_observed_inp is not None:
            _historical_inputs.insert(0, t_observed_inp[:, :self.encoder_length, :])

        historical_inputs = torch.cat(_historical_inputs, dim=-2)
        future_inputs = t_known_inp[:, self.encoder_length:]

        # Encoders
        historical_features, _ = self.history_vsn(historical_inputs, cs)
        history, state = self.history_encoder(historical_features, (ch, cc))
        future_features, _ = self.future_vsn(future_inputs, cs)
        future, _ = self.future_encoder(future_features, state)
        torch.cuda.synchronize()  # this call gives a perf boost for unknown reasons

        # skip connection
        input_embedding = torch.cat([historical_features, future_features], dim=1)
        temporal_features = torch.cat([history, future], dim=1)
        temporal_features = self.input_gate(temporal_features)
        temporal_features = temporal_features + input_embedding
        temporal_features = self.input_gate_ln(temporal_features)

        # Static enrichment
        enriched = self.enrichment_grn(temporal_features, c=ce)

        # Temporal self attention
        x, _ = self.attention(enriched, mask_future_timesteps=True)

        # Don't compute historical quantiles
        x = x[:, self.encoder_length:, :]
        temporal_features = temporal_features[:, self.encoder_length:, :]
        enriched = enriched[:, self.encoder_length:, :]

        x = self.attention_gate(x)
        x = x + enriched
        x = self.attention_ln(x)

        # Position-wise feed-forward
        x = self.positionwise_grn(x)

        # Final skip connection
        x = self.decoder_gate(x)
        x = x + temporal_features
        x = self.decoder_ln(x)

        out = self.quantile_proj(x)

        return out
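

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original pipeline. All
    # hyperparameter values below are made up for illustration. Run with
    # TFT_SCRIPTING=1 if Apex is not installed; a CUDA device is required
    # because forward() calls torch.cuda.synchronize().
    from types import SimpleNamespace

    config = SimpleNamespace(
        static_categorical_inp_lens=[4],
        temporal_known_categorical_inp_lens=[],
        temporal_observed_categorical_inp_lens=[],
        static_continuous_inp_size=1,
        temporal_known_continuous_inp_size=2,
        temporal_observed_continuous_inp_size=0,
        temporal_target_size=1,
        hidden_size=16,
        n_head=4,
        dropout=0.1,
        attn_dropout=0.0,
        encoder_length=24,    # history length
        example_length=32,    # history + horizon
        quantiles=[0.1, 0.5, 0.9],
        num_static_vars=2,    # 1 categorical + 1 continuous
        num_historic_vars=3,  # 2 known continuous + 1 target
        num_future_vars=2,    # known continuous only
    )

    if torch.cuda.is_available():
        model = TemporalFusionTransformer(config).cuda()
        batch = {
            's_cat': torch.randint(0, 4, (8, 32, 1), device='cuda'),
            's_cont': torch.randn(8, 32, 1, device='cuda'),
            'k_cont': torch.randn(8, 32, 2, device='cuda'),
            'target': torch.randn(8, 32, 1, device='cuda'),
        }
        out = model(batch)
        print(out.shape)  # expected: [8, 8, 3] = [batch, horizon, quantiles]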