DeepLearningExamples/Tools/PyTorch/TimeSeriesPredictionPlatform/models/stat_models.py
2021-11-08 14:08:58 -08:00

28 lines
608 B
Python
Executable file

# SPDX-License-Identifier: Apache-2.0
from abc import ABC
import pmdarima as pm
class StatModel(ABC):
def __init__(self, config):
self.horizon = config.dataset.example_length - config.dataset.encoder_length
self.config = config
def fit(self, endog, exog):
return
def predict(self, exog):
return
class AutoARIMA(StatModel):
def __init__(self, config):
super().__init__(config)
def fit(self, endog, exog):
self.model = pm.auto_arima(endog, X=exog)
def predict(self, exog):
return self.model.predict(self.horizon, X=exog)