diff --git a/x-pack/plugins/ml/common/types/data_frame_analytics.ts b/x-pack/plugins/ml/common/types/data_frame_analytics.ts index cacc5acb9768..95d82932a121 100644 --- a/x-pack/plugins/ml/common/types/data_frame_analytics.ts +++ b/x-pack/plugins/ml/common/types/data_frame_analytics.ts @@ -34,9 +34,10 @@ interface Regression { } interface Classification { + class_assignment_objective?: string; dependent_variable: string; training_percent?: number; - num_top_classes?: string; + num_top_classes?: number; num_top_feature_importance_values?: number; prediction_field_name?: string; } diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/hooks/use_create_analytics_form/state.ts b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/hooks/use_create_analytics_form/state.ts index 131da93a2328..40e13ea0e686 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/hooks/use_create_analytics_form/state.ts +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/hooks/use_create_analytics_form/state.ts @@ -16,6 +16,7 @@ import { DataFrameAnalyticsId, DataFrameAnalysisConfigType, } from '../../../../../../../common/types/data_frame_analytics'; +import { isClassificationAnalysis } from '../../../../../../../common/util/analytics_utils'; import { ANALYSIS_CONFIG_TYPE } from '../../../../../../../common/constants/data_frame_analytics'; export enum DEFAULT_MODEL_MEMORY_LIMIT { regression = '100mb', @@ -50,6 +51,7 @@ export interface State { alpha: undefined | number; computeFeatureInfluence: string; createIndexPattern: boolean; + classAssignmentObjective: undefined | string; dependentVariable: DependentVariable; description: string; destinationIndex: EsIndexName; @@ -126,6 +128,7 @@ export const getInitialState = (): State => ({ alpha: undefined, computeFeatureInfluence: 'true', createIndexPattern: true, + classAssignmentObjective: undefined, dependentVariable: '', description: '', destinationIndex: '', @@ -278,13 +281,14 @@ export const getJobConfigFromFormState = ( }; } - if ( - formState.jobType === ANALYSIS_CONFIG_TYPE.CLASSIFICATION && - jobConfig?.analysis?.classification !== undefined && - formState.numTopClasses !== undefined - ) { - // @ts-ignore - jobConfig.analysis.classification.num_top_classes = formState.numTopClasses; + if (jobConfig?.analysis !== undefined && isClassificationAnalysis(jobConfig?.analysis)) { + if (formState.numTopClasses !== undefined) { + jobConfig.analysis.classification.num_top_classes = formState.numTopClasses; + } + if (formState.classAssignmentObjective !== undefined) { + jobConfig.analysis.classification.class_assignment_objective = + formState.classAssignmentObjective; + } } if (formState.jobType === ANALYSIS_CONFIG_TYPE.OUTLIER_DETECTION) {