[ML] Data Frame Analytics creation wizard: adds support for extended hyper-parameters (#90843)

* add support for new hyperparameters in the creation wizard

* fix translation error
This commit is contained in:
Melissa Alvarez 2021-02-10 10:52:46 -05:00 committed by GitHub
parent e94a164b7e
commit 2a93ebe43b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 340 additions and 6 deletions

View file

@ -59,11 +59,21 @@ export const extractErrorProperties = (error: ErrorType): MLErrorObject => {
typeof error.body.attributes === 'object' &&
typeof error.body.attributes.body?.error?.reason === 'string'
) {
return {
const errObj: MLErrorObject = {
message: error.body.attributes.body.error.reason,
statusCode: error.body.statusCode,
fullError: error.body.attributes.body,
};
if (
typeof error.body.attributes.body.error.caused_by === 'object' &&
(typeof error.body.attributes.body.error.caused_by?.reason === 'string' ||
typeof error.body.attributes.body.error.caused_by?.caused_by?.reason === 'string')
) {
errObj.causedBy =
error.body.attributes.body.error.caused_by?.caused_by?.reason ||
error.body.attributes.body.error.caused_by?.reason;
}
return errObj;
} else {
return {
message: error.body.message,

View file

@ -11,6 +11,7 @@ import Boom from '@hapi/boom';
export interface EsErrorRootCause {
type: string;
reason: string;
caused_by?: EsErrorRootCause;
}
export interface EsErrorBody {
@ -37,6 +38,7 @@ export interface ErrorMessage {
}
export interface MLErrorObject {
causedBy?: string;
message: string;
statusCode?: number;
fullError?: EsErrorBody;

View file

@ -33,18 +33,24 @@ export { getAnalysisType } from '../../../../common/util/analytics_utils';
export type IndexPattern = string;
export enum ANALYSIS_ADVANCED_FIELDS {
ALPHA = 'alpha',
ETA = 'eta',
ETA_GROWTH_RATE_PER_TREE = 'eta_growth_rate_per_tree',
DOWNSAMPLE_FACTOR = 'downsample_factor',
FEATURE_BAG_FRACTION = 'feature_bag_fraction',
FEATURE_INFLUENCE_THRESHOLD = 'feature_influence_threshold',
GAMMA = 'gamma',
LAMBDA = 'lambda',
MAX_TREES = 'max_trees',
MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = 'max_optimization_rounds_per_hyperparameter',
METHOD = 'method',
N_NEIGHBORS = 'n_neighbors',
NUM_TOP_CLASSES = 'num_top_classes',
NUM_TOP_FEATURE_IMPORTANCE_VALUES = 'num_top_feature_importance_values',
OUTLIER_FRACTION = 'outlier_fraction',
RANDOMIZE_SEED = 'randomize_seed',
SOFT_TREE_DEPTH_LIMIT = 'soft_tree_depth_limit',
SOFT_TREE_DEPTH_TOLERANCE = 'soft_tree_depth_tolerance',
}
export enum OUTLIER_ANALYSIS_METHOD {

View file

@ -138,14 +138,18 @@ export const AdvancedStepForm: FC<CreateAnalyticsStepProps> = ({
const { setEstimatedModelMemoryLimit, setFormState } = actions;
const { form, isJobCreated, estimatedModelMemoryLimit } = state;
const {
alpha,
computeFeatureInfluence,
downsampleFactor,
eta,
etaGrowthRatePerTree,
featureBagFraction,
featureInfluenceThreshold,
gamma,
jobType,
lambda,
maxNumThreads,
maxOptimizationRoundsPerHyperparameter,
maxTrees,
method,
modelMemoryLimit,
@ -157,6 +161,8 @@ export const AdvancedStepForm: FC<CreateAnalyticsStepProps> = ({
outlierFraction,
predictionFieldName,
randomizeSeed,
softTreeDepthLimit,
softTreeDepthTolerance,
useEstimatedMml,
} = form;
@ -197,7 +203,7 @@ export const AdvancedStepForm: FC<CreateAnalyticsStepProps> = ({
useEffect(() => {
setFetchingAdvancedParamErrors(true);
(async function () {
const { success, errorMessage, expectedMemory } = await fetchExplainData(form);
const { success, errorMessage, errorReason, expectedMemory } = await fetchExplainData(form);
const paramErrors: AdvancedParamErrors = {};
if (success) {
@ -212,6 +218,8 @@ export const AdvancedStepForm: FC<CreateAnalyticsStepProps> = ({
Object.values(ANALYSIS_ADVANCED_FIELDS).forEach((param) => {
if (errorMessage.includes(`[${param}]`)) {
paramErrors[param] = errorMessage;
} else if (errorReason?.includes(`[${param}]`)) {
paramErrors[param] = errorReason;
}
});
}
@ -219,12 +227,16 @@ export const AdvancedStepForm: FC<CreateAnalyticsStepProps> = ({
setAdvancedParamErrors(paramErrors);
})();
}, [
alpha,
downsampleFactor,
eta,
etaGrowthRatePerTree,
featureBagFraction,
featureInfluenceThreshold,
gamma,
lambda,
maxNumThreads,
maxOptimizationRoundsPerHyperparameter,
maxTrees,
method,
nNeighbors,
@ -232,6 +244,8 @@ export const AdvancedStepForm: FC<CreateAnalyticsStepProps> = ({
numTopFeatureImportanceValues,
outlierFraction,
randomizeSeed,
softTreeDepthLimit,
softTreeDepthTolerance,
]);
const outlierDetectionAdvancedConfig = (

View file

@ -21,7 +21,20 @@ interface Props extends CreateAnalyticsFormProps {
export const HyperParameters: FC<Props> = ({ actions, state, advancedParamErrors }) => {
const { setFormState } = actions;
const { eta, featureBagFraction, gamma, lambda, maxTrees, randomizeSeed } = state.form;
const {
alpha,
downsampleFactor,
eta,
etaGrowthRatePerTree,
featureBagFraction,
gamma,
lambda,
maxOptimizationRoundsPerHyperparameter,
maxTrees,
randomizeSeed,
softTreeDepthLimit,
softTreeDepthTolerance,
} = state.form;
return (
<Fragment>
@ -203,6 +216,215 @@ export const HyperParameters: FC<Props> = ({ actions, state, advancedParamErrors
/>
</EuiFormRow>
</EuiFlexItem>
<EuiFlexItem>
<EuiFormRow
label={i18n.translate('xpack.ml.dataframe.analytics.create.alphaLabel', {
defaultMessage: 'Alpha',
})}
helpText={i18n.translate('xpack.ml.dataframe.analytics.create.alphaText', {
defaultMessage:
'Multiplies a term based on tree depth in the regularized loss. Higher values result in shallower trees and faster training times. Must be greater than or equal to 0. ',
})}
isInvalid={advancedParamErrors[ANALYSIS_ADVANCED_FIELDS.ALPHA] !== undefined}
error={advancedParamErrors[ANALYSIS_ADVANCED_FIELDS.ALPHA]}
>
<EuiFieldNumber
aria-label={i18n.translate('xpack.ml.dataframe.analytics.create.alphaInputAriaLabel', {
defaultMessage: 'Multiplies a term based on tree depth in the regularized loss',
})}
data-test-subj="mlAnalyticsCreateJobWizardAlphaInput"
onChange={(e) =>
setFormState({ alpha: e.target.value === '' ? undefined : +e.target.value })
}
step={0.001}
min={0}
value={getNumberValue(alpha)}
/>
</EuiFormRow>
</EuiFlexItem>
<EuiFlexItem>
<EuiFormRow
label={i18n.translate('xpack.ml.dataframe.analytics.create.downsampleFactorLabel', {
defaultMessage: 'Downsample factor',
})}
helpText={i18n.translate('xpack.ml.dataframe.analytics.create.downsampleFactorText', {
defaultMessage:
'Controls the fraction of data that is used to compute the derivatives of the loss function for tree training. Must be between 0 and 1.',
})}
isInvalid={advancedParamErrors[ANALYSIS_ADVANCED_FIELDS.DOWNSAMPLE_FACTOR] !== undefined}
error={advancedParamErrors[ANALYSIS_ADVANCED_FIELDS.DOWNSAMPLE_FACTOR]}
>
<EuiFieldNumber
aria-label={i18n.translate(
'xpack.ml.dataframe.analytics.create.downsampleFactorInputAriaLabel',
{
defaultMessage:
'Controls the fraction of data that is used to compute the derivatives of the loss function for tree training',
}
)}
data-test-subj="mlAnalyticsCreateJobWizardDownsampleFactorInput"
onChange={(e) =>
setFormState({
downsampleFactor: e.target.value === '' ? undefined : +e.target.value,
})
}
step={0.001}
min={0}
max={1}
value={getNumberValue(downsampleFactor)}
/>
</EuiFormRow>
</EuiFlexItem>
<EuiFlexItem>
<EuiFormRow
label={i18n.translate('xpack.ml.dataframe.analytics.create.etaGrowthRatePerTreeLabel', {
defaultMessage: 'Eta growth rate per tree',
})}
helpText={i18n.translate('xpack.ml.dataframe.analytics.create.etaGrowthRatePerTreeText', {
defaultMessage:
'Specifies the rate at which eta increases for each new tree that is added to the forest. Must be between 0.5 and 2.',
})}
isInvalid={
advancedParamErrors[ANALYSIS_ADVANCED_FIELDS.ETA_GROWTH_RATE_PER_TREE] !== undefined
}
error={advancedParamErrors[ANALYSIS_ADVANCED_FIELDS.ETA_GROWTH_RATE_PER_TREE]}
>
<EuiFieldNumber
aria-label={i18n.translate(
'xpack.ml.dataframe.analytics.create.etaGrowthRatePerTreeInputAriaLabel',
{
defaultMessage:
'Specifies the rate at which eta increases for each new tree that is added to the forest.',
}
)}
data-test-subj="mlAnalyticsCreateJobWizardEtaGrowthRatePerTreeInput"
onChange={(e) =>
setFormState({
etaGrowthRatePerTree: e.target.value === '' ? undefined : +e.target.value,
})
}
step={0.001}
min={0.5}
max={2}
value={getNumberValue(etaGrowthRatePerTree)}
/>
</EuiFormRow>
</EuiFlexItem>
<EuiFlexItem>
<EuiFormRow
label={i18n.translate(
'xpack.ml.dataframe.analytics.create.maxOptimizationRoundsPerHyperparameterLabel',
{
defaultMessage: 'Max optimization rounds per hyperparameter',
}
)}
helpText={i18n.translate(
'xpack.ml.dataframe.analytics.create.maxOptimizationRoundsPerHyperparameterText',
{
defaultMessage:
'Multiplier responsible for determining the maximum number of hyperparameter optimization steps in the Bayesian optimization procedure.',
}
)}
isInvalid={
advancedParamErrors[
ANALYSIS_ADVANCED_FIELDS.MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER
] !== undefined
}
error={
advancedParamErrors[ANALYSIS_ADVANCED_FIELDS.MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER]
}
>
<EuiFieldNumber
aria-label={i18n.translate(
'xpack.ml.dataframe.analytics.create.maxOptimizationRoundsPerHyperparameterInputAriaLabel',
{
defaultMessage:
'Multiplier responsible for determining the maximum number of hyperparameter optimization steps in the Bayesian optimization procedure. Must be an integer between 0 and 20.',
}
)}
data-test-subj="mlAnalyticsCreateJobWizardMaxOptimizationRoundsPerHyperparameterInput"
onChange={(e) =>
setFormState({
maxOptimizationRoundsPerHyperparameter:
e.target.value === '' ? undefined : +e.target.value,
})
}
min={0}
max={20}
step={1}
value={getNumberValue(maxOptimizationRoundsPerHyperparameter)}
/>
</EuiFormRow>
</EuiFlexItem>
<EuiFlexItem>
<EuiFormRow
label={i18n.translate('xpack.ml.dataframe.analytics.create.softTreeDepthLimitLabel', {
defaultMessage: 'Soft tree depth limit',
})}
helpText={i18n.translate('xpack.ml.dataframe.analytics.create.softTreeDepthLimitText', {
defaultMessage:
'Tree depth limit that increases regularized loss when exceeded. Must be greater than or equal to 0. ',
})}
isInvalid={
advancedParamErrors[ANALYSIS_ADVANCED_FIELDS.SOFT_TREE_DEPTH_LIMIT] !== undefined
}
error={advancedParamErrors[ANALYSIS_ADVANCED_FIELDS.SOFT_TREE_DEPTH_LIMIT]}
>
<EuiFieldNumber
aria-label={i18n.translate(
'xpack.ml.dataframe.analytics.create.softTreeDepthLimitInputAriaLabel',
{
defaultMessage: 'Tree depth limit that increases regularized loss when exceeded',
}
)}
data-test-subj="mlAnalyticsCreateJobWizardSoftTreeDepthLimitInput"
onChange={(e) =>
setFormState({
softTreeDepthLimit: e.target.value === '' ? undefined : +e.target.value,
})
}
step={0.001}
min={0}
value={getNumberValue(softTreeDepthLimit)}
/>
</EuiFormRow>
</EuiFlexItem>
<EuiFlexItem>
<EuiFormRow
label={i18n.translate('xpack.ml.dataframe.analytics.create.softTreeDepthToleranceLabel', {
defaultMessage: 'Soft tree depth tolerance',
})}
helpText={i18n.translate(
'xpack.ml.dataframe.analytics.create.softTreeDepthToleranceText',
{
defaultMessage:
'Controls how quickly the regularized loss increases when the tree depth exceeds soft_tree_depth_limit. Must be greater than or equal to 0.01. ',
}
)}
isInvalid={
advancedParamErrors[ANALYSIS_ADVANCED_FIELDS.SOFT_TREE_DEPTH_TOLERANCE] !== undefined
}
error={advancedParamErrors[ANALYSIS_ADVANCED_FIELDS.SOFT_TREE_DEPTH_TOLERANCE]}
>
<EuiFieldNumber
aria-label={i18n.translate(
'xpack.ml.dataframe.analytics.create.softTreeDepthToleranceInputAriaLabel',
{
defaultMessage: 'Tree depth limit that increases regularized loss when exceeded',
}
)}
data-test-subj="mlAnalyticsCreateJobWizardSoftTreeDepthToleranceInput"
onChange={(e) =>
setFormState({
softTreeDepthTolerance: e.target.value === '' ? undefined : +e.target.value,
})
}
step={0.001}
min={0.01}
value={getNumberValue(softTreeDepthTolerance)}
/>
</EuiFormRow>
</EuiFlexItem>
</Fragment>
);
};

View file

@ -6,7 +6,7 @@
*/
import { ml } from '../../../../../services/ml_api_service';
import { extractErrorMessage } from '../../../../../../../common/util/errors';
import { extractErrorProperties } from '../../../../../../../common/util/errors';
import { DfAnalyticsExplainResponse, FieldSelectionItem } from '../../../../common/analytics';
import {
getJobConfigFromFormState,
@ -23,6 +23,7 @@ export interface FetchExplainDataReturnType {
export const fetchExplainData = async (formState: State['form']) => {
const jobConfig = getJobConfigFromFormState(formState);
let errorMessage = '';
let errorReason = '';
let success = true;
let expectedMemory = '';
let fieldSelection: FieldSelectionItem[] = [];
@ -36,8 +37,12 @@ export const fetchExplainData = async (formState: State['form']) => {
expectedMemory = resp.memory_estimation?.expected_memory_without_disk;
fieldSelection = resp.field_selection || [];
} catch (error) {
const errObj = extractErrorProperties(error);
success = false;
errorMessage = extractErrorMessage(error);
errorMessage = errObj.message;
if (errObj.causedBy) {
errorReason = errObj.causedBy;
}
}
return {
@ -45,5 +50,6 @@ export const fetchExplainData = async (formState: State['form']) => {
expectedMemory,
fieldSelection,
errorMessage,
errorReason,
};
};

View file

@ -121,6 +121,30 @@ const getAnalyticsJobMeta = (config: CloneDataFrameAnalyticsConfig): AnalyticsJo
optional: true,
ignore: true,
},
alpha: {
optional: true,
formKey: 'alpha',
},
downsample_factor: {
optional: true,
formKey: 'downsampleFactor',
},
eta_growth_rate_per_tree: {
optional: true,
formKey: 'etaGrowthRatePerTree',
},
max_optimization_rounds_per_hyperparameter: {
optional: true,
formKey: 'maxOptimizationRoundsPerHyperparameter',
},
soft_tree_depth_limit: {
optional: true,
formKey: 'softTreeDepthLimit',
},
soft_tree_depth_tolerance: {
optional: true,
formKey: 'softTreeDepthTolerance',
},
},
}
: {}),
@ -215,6 +239,30 @@ const getAnalyticsJobMeta = (config: CloneDataFrameAnalyticsConfig): AnalyticsJo
optional: true,
ignore: true,
},
alpha: {
optional: true,
formKey: 'alpha',
},
downsample_factor: {
optional: true,
formKey: 'downsampleFactor',
},
eta_growth_rate_per_tree: {
optional: true,
formKey: 'etaGrowthRatePerTree',
},
max_optimization_rounds_per_hyperparameter: {
optional: true,
formKey: 'maxOptimizationRoundsPerHyperparameter',
},
soft_tree_depth_limit: {
optional: true,
formKey: 'softTreeDepthLimit',
},
soft_tree_depth_tolerance: {
optional: true,
formKey: 'softTreeDepthTolerance',
},
},
}
: {}),

View file

@ -47,6 +47,7 @@ export interface State {
advancedEditorRawString: string;
disableSwitchToForm: boolean;
form: {
alpha: undefined | number;
computeFeatureInfluence: string;
createIndexPattern: boolean;
dependentVariable: DependentVariable;
@ -57,7 +58,9 @@ export interface State {
destinationIndexNameValid: boolean;
destinationIndexPatternTitleExists: boolean;
earlyStoppingEnabled: undefined | boolean;
downsampleFactor: undefined | number;
eta: undefined | number;
etaGrowthRatePerTree: undefined | number;
featureBagFraction: undefined | number;
featureInfluenceThreshold: undefined | number;
gamma: undefined | number;
@ -73,6 +76,7 @@ export interface State {
lambda: number | undefined;
loadingFieldOptions: boolean;
maxNumThreads: undefined | number;
maxOptimizationRoundsPerHyperparameter: undefined | number;
maxTrees: undefined | number;
method: undefined | string;
modelMemoryLimit: string | undefined;
@ -88,6 +92,8 @@ export interface State {
requiredFieldsError: string | undefined;
randomizeSeed: undefined | number;
resultsField: undefined | string;
softTreeDepthLimit: undefined | number;
softTreeDepthTolerance: undefined | number;
sourceIndex: EsIndexName;
sourceIndexNameEmpty: boolean;
sourceIndexNameValid: boolean;
@ -117,6 +123,7 @@ export const getInitialState = (): State => ({
advancedEditorRawString: '',
disableSwitchToForm: false,
form: {
alpha: undefined,
computeFeatureInfluence: 'true',
createIndexPattern: true,
dependentVariable: '',
@ -127,7 +134,9 @@ export const getInitialState = (): State => ({
destinationIndexNameValid: false,
destinationIndexPatternTitleExists: false,
earlyStoppingEnabled: undefined,
downsampleFactor: undefined,
eta: undefined,
etaGrowthRatePerTree: undefined,
featureBagFraction: undefined,
featureInfluenceThreshold: undefined,
gamma: undefined,
@ -143,6 +152,7 @@ export const getInitialState = (): State => ({
lambda: undefined,
loadingFieldOptions: false,
maxNumThreads: DEFAULT_MAX_NUM_THREADS,
maxOptimizationRoundsPerHyperparameter: undefined,
maxTrees: undefined,
method: undefined,
modelMemoryLimit: undefined,
@ -158,6 +168,8 @@ export const getInitialState = (): State => ({
requiredFieldsError: undefined,
randomizeSeed: undefined,
resultsField: undefined,
softTreeDepthLimit: undefined,
softTreeDepthTolerance: undefined,
sourceIndex: '',
sourceIndexNameEmpty: true,
sourceIndexNameValid: false,
@ -233,17 +245,31 @@ export const getJobConfigFromFormState = (
analysis = Object.assign(
analysis,
formState.predictionFieldName && { prediction_field_name: formState.predictionFieldName },
formState.alpha && { alpha: formState.alpha },
formState.eta && { eta: formState.eta },
formState.etaGrowthRatePerTree && {
eta_growth_rate_per_tree: formState.etaGrowthRatePerTree,
},
formState.downsampleFactor && { downsample_factor: formState.downsampleFactor },
formState.featureBagFraction && {
feature_bag_fraction: formState.featureBagFraction,
},
formState.gamma && { gamma: formState.gamma },
formState.lambda && { lambda: formState.lambda },
formState.maxOptimizationRoundsPerHyperparameter && {
max_optimization_rounds_per_hyperparameter:
formState.maxOptimizationRoundsPerHyperparameter,
},
formState.maxTrees && { max_trees: formState.maxTrees },
formState.randomizeSeed && { randomize_seed: formState.randomizeSeed },
formState.earlyStoppingEnabled !== undefined && {
early_stopping_enabled: formState.earlyStoppingEnabled,
},
formState.predictionFieldName && { prediction_field_name: formState.predictionFieldName },
formState.randomizeSeed && { randomize_seed: formState.randomizeSeed },
formState.softTreeDepthLimit && { soft_tree_depth_limit: formState.softTreeDepthLimit },
formState.softTreeDepthTolerance && {
soft_tree_depth_tolerance: formState.softTreeDepthTolerance,
}
);