NeMo/nemo/core/config/optimizers.py
Tomasz Kornuta be289129ad
removed the # ==== showstoppers in all headers (#924)
Signed-off-by: Tomasz Kornuta <tkornuta@nvidia.com>
2020-07-27 12:55:51 -07:00

257 lines
7.8 KiB
Python

# Copyright (c) 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, Optional, Tuple, Type, Union
# Names exported by a star-import of this module: the parameter dataclasses only.
__all__ = [
    "OptimizerParams",
    "AdamParams",
    "NovogradParams",
    "SGDParams",
    "AdadeltaParams",
    "AdamaxParams",
    "AdagradParams",
    "AdamWParams",
    "RMSpropParams",
    "RpropParams",
]
@dataclass
class OptimizerParams:
    """
    Empty base class for optimizer parameter sets.

    It declares no fields of its own. A user can choose it to explicitly override
    values via command line arguments.
    """
@dataclass
class SGDParams(OptimizerParams):
    """
    Default configuration for SGD optimizer.

    It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).

    .. note::
        For the details on the function/meanings of the arguments, please refer to:
        https://pytorch.org/docs/stable/optim.html?highlight=sgd#torch.optim.SGD
    """

    momentum: float = 0
    dampening: float = 0
    weight_decay: float = 0
    nesterov: bool = False
@dataclass
class AdamParams(OptimizerParams):
    """
    Parameter set holding the defaults used for the Adam optimizer.

    This is a plain dataclass rather than a Config subclass: it is not a NeMo object
    and, in particular, needs no name.

    .. note::
        The role of every argument is described in the PyTorch documentation:
        https://pytorch.org/docs/stable/optim.html?highlight=adam#torch.optim.Adam
    """

    # NOTE(review): `betas` is disabled below, so it is not a field of this class -
    # confirm whether that is intentional before re-enabling it.
    # betas: Tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0
    amsgrad: bool = False
@dataclass
class AdamWParams(OptimizerParams):
    """
    Parameter set holding the defaults used for the AdamW optimizer.

    This is a plain dataclass rather than a Config subclass: it is not a NeMo object
    and, in particular, needs no name.

    .. note::
        The role of every argument is described in the PyTorch documentation:
        https://pytorch.org/docs/stable/optim.html#torch.optim.AdamW
    """

    betas: Tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0
    amsgrad: bool = False
@dataclass
class AdadeltaParams(OptimizerParams):
    """
    Parameter set holding the defaults used for the Adadelta optimizer.

    This is a plain dataclass rather than a Config subclass: it is not a NeMo object
    and, in particular, needs no name.

    .. note::
        The role of every argument is described in the PyTorch documentation:
        https://pytorch.org/docs/stable/optim.html#torch.optim.Adadelta
    """

    rho: float = 0.9
    eps: float = 1e-6
    weight_decay: float = 0
@dataclass
class AdamaxParams(OptimizerParams):
    """
    Parameter set holding the defaults used for the Adamax optimizer.

    This is a plain dataclass rather than a Config subclass: it is not a NeMo object
    and, in particular, needs no name.

    .. note::
        The role of every argument is described in the PyTorch documentation:
        https://pytorch.org/docs/stable/optim.html#torch.optim.Adamax
    """

    betas: Tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0
@dataclass
class AdagradParams(OptimizerParams):
    """
    Parameter set holding the defaults used for the Adagrad optimizer.

    This is a plain dataclass rather than a Config subclass: it is not a NeMo object
    and, in particular, needs no name.

    .. note::
        The role of every argument is described in the PyTorch documentation:
        https://pytorch.org/docs/stable/optim.html#torch.optim.Adagrad
    """

    lr_decay: float = 0
    weight_decay: float = 0
    initial_accumulator_value: float = 0
    eps: float = 1e-10
@dataclass
class RMSpropParams(OptimizerParams):
    """
    Parameter set holding the defaults used for the RMSprop optimizer.

    This is a plain dataclass rather than a Config subclass: it is not a NeMo object
    and, in particular, needs no name.

    .. note::
        The role of every argument is described in the PyTorch documentation:
        https://pytorch.org/docs/stable/optim.html#torch.optim.RMSprop
    """

    alpha: float = 0.99
    eps: float = 1e-8
    weight_decay: float = 0
    momentum: float = 0
    centered: bool = False
@dataclass
class RpropParams(OptimizerParams):
    """
    Parameter set holding the defaults used for the Rprop optimizer.

    This is a plain dataclass rather than a Config subclass: it is not a NeMo object
    and, in particular, needs no name.

    .. note::
        The role of every argument is described in the PyTorch documentation:
        https://pytorch.org/docs/stable/optim.html#torch.optim.Rprop
    """

    etas: Tuple[float, float] = (0.5, 1.2)
    step_sizes: Tuple[float, float] = (1e-6, 50)
@dataclass
class NovogradParams(OptimizerParams):
    """
    Configuration of the Novograd optimizer.

    It has been proposed in "Stochastic Gradient Methods with Layer-wise
    Adaptive Moments for Training of Deep Networks"
    (https://arxiv.org/abs/1905.11286)

    The learning rate is not a field of this class.

    Args:
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.95, 0.98))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        grad_averaging (bool, optional): (default: False)
            NOTE(review): presumably toggles gradient averaging in the update
            rule - confirm against the Novograd implementation.
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper "On the Convergence of Adam and Beyond"
            (default: False)
        luc (bool, optional): (default: False)
            NOTE(review): presumably enables layer-wise update clipping -
            confirm against the Novograd implementation.
        luc_trust (float, optional): (default: 1e-3)
            NOTE(review): presumably the trust coefficient used when `luc` is
            enabled - confirm against the Novograd implementation.
        luc_eps (float, optional): (default: 1e-8)
            NOTE(review): presumably a numerical-stability term used when `luc`
            is enabled - confirm against the Novograd implementation.
    """

    betas: Tuple[float, float] = (0.95, 0.98)
    eps: float = 1e-8
    weight_decay: float = 0
    grad_averaging: bool = False
    amsgrad: bool = False
    luc: bool = False
    luc_trust: float = 1e-3
    luc_eps: float = 1e-8
def register_optimizer_params(name: str, optimizer_params: Type[OptimizerParams]):
    """
    Checks if the optimizer param name exists in the registry, and if it doesn't, adds it.

    This allows custom optimizer params to be added and called by name during instantiation.

    Args:
        name: Name of the optimizer. Will be used as key to retrieve the optimizer.
        optimizer_params: Optimizer params class (the class itself, not an instance).

    Raises:
        ValueError: If `name` is already a key of AVAILABLE_OPTIMIZER_PARAMS.
    """
    # Existing entries are never replaced; a name clash is reported to the caller.
    if name in AVAILABLE_OPTIMIZER_PARAMS:
        raise ValueError(f"Cannot override pre-existing optimizers. Conflicting optimizer name = {name}")

    AVAILABLE_OPTIMIZER_PARAMS[name] = optimizer_params
def get_optimizer_config(name: Optional[str], **kwargs: Any) -> Union[Dict[str, Any], partial]:
    """
    Convenience method to obtain a OptimizerParams class and partially instantiate it with optimizer kwargs.

    Args:
        name: Name of the OptimizerParams in the registry. When None, no registry
            lookup is performed and the kwargs are returned unchanged.
        kwargs: Optional kwargs of the optimizer used during instantiation.

    Returns:
        The kwargs dictionary itself when `name` is None; otherwise a
        `functools.partial` wrapping the registered OptimizerParams class with
        the kwargs bound. Calling the partial creates the params instance.

    Raises:
        ValueError: If `name` is not None and is not a key of AVAILABLE_OPTIMIZER_PARAMS.
    """
    # Without a name there is nothing to look up: hand the raw kwargs back.
    if name is None:
        return kwargs

    if name not in AVAILABLE_OPTIMIZER_PARAMS:
        raise ValueError(
            f"Cannot resolve optimizer parameters '{name}'. Available optimizer parameters are : "
            f"{AVAILABLE_OPTIMIZER_PARAMS.keys()}"
        )

    optimizer_params_cls = AVAILABLE_OPTIMIZER_PARAMS[name]
    # Bind the kwargs now and defer construction of the params object to the caller.
    return partial(optimizer_params_cls, **kwargs)
# Registry mapping a lookup name to its optimizer params class.
# register_optimizer_params() adds entries; get_optimizer_config() reads them.
AVAILABLE_OPTIMIZER_PARAMS = {
    "optim_params": OptimizerParams,
    "adam_params": AdamParams,
    "novograd_params": NovogradParams,
    "sgd_params": SGDParams,
    "adadelta_params": AdadeltaParams,
    "adamax_params": AdamaxParams,
    "adagrad_params": AdagradParams,
    "adamw_params": AdamWParams,
    "rmsprop_params": RMSpropParams,
    "rprop_params": RpropParams,
}