[7.x] [ML] Add decision path charts to exploration results table (#73561) (#77082)

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
This commit is contained in:
Quynh Nguyen 2020-09-09 16:06:08 -05:00 committed by GitHub
parent 2b765b18e7
commit ac5a6eae8a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
27 changed files with 1083 additions and 125 deletions

View file

@ -0,0 +1,7 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
/** Fallback name of the results field (`ml`), used when no results field is otherwise available. */
export const DEFAULT_RESULTS_FIELD = 'ml';

View file

@ -79,3 +79,9 @@ export interface DataFrameAnalyticsConfig {
version: string;
allow_lazy_start?: boolean;
}
/**
 * Supported data frame analytics analysis types.
 * Each value is the key that the analysis type guards compare against the
 * single top-level key of an analysis config object.
 */
export enum ANALYSIS_CONFIG_TYPE {
OUTLIER_DETECTION = 'outlier_detection',
REGRESSION = 'regression',
CLASSIFICATION = 'classification',
}

View file

@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
/** Importance of a feature for one specific class (an entry of `FeatureImportance.classes`). */
export interface ClassFeatureImportance {
// class label; may be a string or a boolean
class_name: string | boolean;
importance: number;
}
/**
 * Importance information for a single feature.
 * Carries either a single `importance` value or a per-class breakdown in
 * `classes`; both are optional.
 */
export interface FeatureImportance {
feature_name: string;
importance?: number;
classes?: ClassFeatureImportance[];
}
/** A single entry of the `top_classes` array read from a result row. */
export interface TopClass {
class_name: string;
class_probability: number;
class_score: number;
}
/** List of `TopClass` entries. */
export type TopClasses = TopClass[];

View file

@ -0,0 +1,79 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
import {
AnalysisConfig,
ClassificationAnalysis,
OutlierAnalysis,
RegressionAnalysis,
ANALYSIS_CONFIG_TYPE,
} from '../types/data_frame_analytics';
/**
 * Type guard for an outlier detection analysis config.
 *
 * @param arg - any value; typically the `analysis` section of a job config
 * @returns `true` only when `arg` has exactly one own enumerable key and that
 * key is `outlier_detection`. `null` and `undefined` yield `false` instead of
 * making `Object.keys` throw.
 */
export const isOutlierAnalysis = (arg: any): arg is OutlierAnalysis => {
  if (arg === undefined || arg === null) {
    return false;
  }
  const keys = Object.keys(arg);
  return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.OUTLIER_DETECTION;
};
/**
 * Type guard for a regression analysis config.
 *
 * @param arg - any value; typically the `analysis` section of a job config
 * @returns `true` only when `arg` has exactly one own enumerable key and that
 * key is `regression`. `null` and `undefined` yield `false` instead of making
 * `Object.keys` throw.
 */
export const isRegressionAnalysis = (arg: any): arg is RegressionAnalysis => {
  if (arg === undefined || arg === null) {
    return false;
  }
  const keys = Object.keys(arg);
  return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.REGRESSION;
};
/**
 * Type guard for a classification analysis config.
 *
 * @param arg - any value; typically the `analysis` section of a job config
 * @returns `true` only when `arg` has exactly one own enumerable key and that
 * key is `classification`. `null` and `undefined` yield `false` instead of
 * making `Object.keys` throw.
 */
export const isClassificationAnalysis = (arg: any): arg is ClassificationAnalysis => {
  if (arg === undefined || arg === null) {
    return false;
  }
  const keys = Object.keys(arg);
  return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.CLASSIFICATION;
};
/**
 * Resolves the dependent variable of a supervised analysis.
 *
 * @param analysis - analysis config to inspect
 * @returns the `dependent_variable` of a regression or classification
 * analysis, or an empty string for any other analysis
 */
export const getDependentVar = (
  analysis: AnalysisConfig
):
  | RegressionAnalysis['regression']['dependent_variable']
  | ClassificationAnalysis['classification']['dependent_variable'] => {
  // The two guards are mutually exclusive (each requires a single, different key).
  if (isClassificationAnalysis(analysis)) {
    return analysis.classification.dependent_variable;
  }
  if (isRegressionAnalysis(analysis)) {
    return analysis.regression.dependent_variable;
  }
  return '';
};
/**
 * Reads the explicitly configured prediction field name of a supervised analysis.
 *
 * @param analysis - analysis config to inspect
 * @returns the configured `prediction_field_name` of a regression or
 * classification analysis, or `undefined` when none is configured or the
 * analysis is of another type. Callers such as `getPredictedFieldName` fall
 * back to `<dependent_variable>_prediction` in that case.
 */
export const getPredictionFieldName = (
  analysis: AnalysisConfig
):
  | RegressionAnalysis['regression']['prediction_field_name']
  | ClassificationAnalysis['classification']['prediction_field_name'] => {
  if (isRegressionAnalysis(analysis)) {
    return analysis.regression.prediction_field_name;
  }
  if (isClassificationAnalysis(analysis)) {
    return analysis.classification.prediction_field_name;
  }
  return undefined;
};
/**
 * Builds the default prediction field name for an analysis:
 * the dependent variable followed by the `_prediction` suffix.
 */
export const getDefaultPredictionFieldName = (analysis: AnalysisConfig) => {
  const dependentVariable = getDependentVar(analysis);
  return `${dependentVariable}_prediction`;
};
/**
 * Builds the full path of the predicted field inside the results object.
 *
 * @param resultsField - name of the results field (default is 'ml')
 * @param analysis - analysis config the prediction field name is read from
 * @param forSort - currently unused by this function
 * @returns `<resultsField>.<prediction field name>`, using the default
 * prediction field name when none is configured (or it is empty)
 */
export const getPredictedFieldName = (
  resultsField: string,
  analysis: AnalysisConfig,
  forSort?: boolean
) => {
  const configuredName = getPredictionFieldName(analysis);
  // any falsy configured name (undefined or '') falls back to the default
  const fieldName = configuredName || getDefaultPredictionFieldName(analysis);
  return `${resultsField}.${fieldName}`;
};

View file

@ -119,13 +119,14 @@ export const getDataGridSchemasFromFieldTypes = (fieldTypes: FieldTypes, results
schema = 'numeric';
}
if (
field.includes(`${resultsField}.${FEATURE_IMPORTANCE}`) ||
field.includes(`${resultsField}.${TOP_CLASSES}`)
) {
if (field.includes(`${resultsField}.${TOP_CLASSES}`)) {
schema = 'json';
}
if (field.includes(`${resultsField}.${FEATURE_IMPORTANCE}`)) {
schema = 'featureImportance';
}
return { id: field, schema, isSortable };
});
};
@ -250,10 +251,6 @@ export const useRenderCellValue = (
return cellValue ? 'true' : 'false';
}
if (typeof cellValue === 'object' && cellValue !== null) {
return JSON.stringify(cellValue);
}
return cellValue;
};
}, [indexPattern?.fields, pagination.pageIndex, pagination.pageSize, tableItems]);

View file

@ -5,8 +5,7 @@
*/
import { isEqual } from 'lodash';
import React, { memo, useEffect, FC } from 'react';
import React, { memo, useEffect, FC, useMemo } from 'react';
import { i18n } from '@kbn/i18n';
import {
@ -24,13 +23,16 @@ import {
} from '@elastic/eui';
import { CoreSetup } from 'src/core/public';
import { DEFAULT_SAMPLER_SHARD_SIZE } from '../../../../common/constants/field_histograms';
import { INDEX_STATUS } from '../../data_frame_analytics/common';
import { ANALYSIS_CONFIG_TYPE, INDEX_STATUS } from '../../data_frame_analytics/common';
import { euiDataGridStyle, euiDataGridToolbarSettings } from './common';
import { UseIndexDataReturnType } from './types';
import { DecisionPathPopover } from './feature_importance/decision_path_popover';
import { TopClasses } from '../../../../common/types/feature_importance';
import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants/data_frame_analytics';
// TODO Fix row hovering + bar highlighting
// import { hoveredRow$ } from './column_chart';
@ -41,6 +43,9 @@ export const DataGridTitle: FC<{ title: string }> = ({ title }) => (
);
/**
 * Data grid props for the variant without a header
 * (presumably the counterpart of `PropsWithHeader` — confirm against the full file).
 */
interface PropsWithoutHeader extends UseIndexDataReturnType {
// baseline value forwarded to the decision path popover
baseline?: number;
// decides whether a feature importance popover is offered (regression / classification only)
analysisType?: ANALYSIS_CONFIG_TYPE;
// name of the results object inside a row; the grid falls back to DEFAULT_RESULTS_FIELD when missing
resultsField?: string;
dataTestSubj: string;
toastNotifications: CoreSetup['notifications']['toasts'];
}
@ -60,6 +65,7 @@ type Props = PropsWithHeader | PropsWithoutHeader;
export const DataGrid: FC<Props> = memo(
(props) => {
const {
baseline,
chartsVisible,
chartsButtonVisible,
columnsWithCharts,
@ -80,8 +86,10 @@ export const DataGrid: FC<Props> = memo(
toastNotifications,
toggleChartVisibility,
visibleColumns,
predictionFieldName,
resultsField,
analysisType,
} = props;
// TODO Fix row hovering + bar highlighting
// const getRowProps = (item: any) => {
// return {
@ -90,6 +98,45 @@ export const DataGrid: FC<Props> = memo(
// };
// };
// Popover renderers keyed by column schema. Only regression and classification
// results get a decision path popover for `featureImportance` cells; for any
// other analysis type no custom popover content is provided.
const popOverContent = useMemo(() => {
  return analysisType === ANALYSIS_CONFIG_TYPE.REGRESSION ||
    analysisType === ANALYSIS_CONFIG_TYPE.CLASSIFICATION
    ? {
        featureImportance: ({ children }: { cellContentsElement: any; children: any }) => {
          const rowIndex = children?.props?.visibleRowIndex;
          const row = data[rowIndex];
          if (!row) return <div />;
          // if resultsField for some reason is not available then use ml
          const mlResultsField = resultsField ?? DEFAULT_RESULTS_FIELD;
          const mlResults = row[mlResultsField];
          // a row without a results object has nothing to plot; avoid dereferencing undefined
          if (mlResults === undefined || mlResults === null) return <div />;
          const parsedFIArray = mlResults.feature_importance;
          let predictedValue: string | number | undefined;
          let topClasses: TopClasses = [];
          if (predictionFieldName !== undefined && mlResults[predictionFieldName] !== undefined) {
            predictedValue = mlResults[predictionFieldName];
            // regression rows carry no top_classes; keep the empty default in that case
            topClasses = mlResults.top_classes ?? [];
          }
          return (
            <DecisionPathPopover
              analysisType={analysisType}
              predictedValue={predictedValue}
              baseline={baseline}
              featureImportance={parsedFIArray}
              topClasses={topClasses}
              predictionFieldName={
                predictionFieldName ? predictionFieldName.replace('_prediction', '') : undefined
              }
            />
          );
        },
      }
    : undefined;
  // every value read inside the renderer must be listed, otherwise the popover
  // keeps using stale values after the job config has loaded
}, [analysisType, baseline, data, predictionFieldName, resultsField]);
useEffect(() => {
if (invalidSortingColumnns.length > 0) {
invalidSortingColumnns.forEach((columnId) => {
@ -225,6 +272,7 @@ export const DataGrid: FC<Props> = memo(
}
: {}),
}}
popoverContents={popOverContent}
pagination={{
...pagination,
pageSizeOptions: [5, 10, 25],

View file

@ -0,0 +1,166 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
import {
AnnotationDomainTypes,
Axis,
AxisStyle,
Chart,
LineAnnotation,
LineAnnotationStyle,
LineAnnotationDatum,
LineSeries,
PartialTheme,
Position,
RecursivePartial,
ScaleType,
Settings,
} from '@elastic/charts';
import { EuiIcon } from '@elastic/eui';
import React, { useCallback, useMemo } from 'react';
import { i18n } from '@kbn/i18n';
import euiVars from '@elastic/eui/dist/eui_theme_light.json';
import { DecisionPathPlotData } from './use_classification_path_data';
// Colors are taken from the EUI light theme variables.
const { euiColorFullShade, euiColorMediumShade } = euiVars;
const axisColor = euiColorMediumShade;
// Style of the line annotation that marks the baseline value in the chart.
const baselineStyle: LineAnnotationStyle = {
line: {
strokeWidth: 1,
stroke: euiColorFullShade,
opacity: 0.75,
},
details: {
fontFamily: 'Arial',
fontSize: 10,
fontStyle: 'bold',
fill: euiColorMediumShade,
padding: 0,
},
};
// Axis styling shared by the chart axes: muted axis color, small tick labels,
// dashed horizontal grid lines and no visible vertical grid lines.
const axes: RecursivePartial<AxisStyle> = {
axisLine: {
stroke: axisColor,
},
tickLabel: {
fontSize: 10,
fill: axisColor,
},
tickLine: {
stroke: axisColor,
},
gridLine: {
horizontal: {
dash: [1, 2],
},
vertical: {
strokeWidth: 0,
},
},
};
// Partial chart theme passed to the chart `Settings`.
const theme: PartialTheme = {
axes,
};
/** Props of `DecisionPathChart`. */
interface DecisionPathChartProps {
// rows of [feature name, importance, cumulative value]; the chart plots index 0 against index 2
decisionPathData: DecisionPathPlotData;
// inserted into the axis title
predictionFieldName?: string;
// when provided, drawn as a line annotation
baseline?: number;
// explicit bounds of the prediction axis; the domain is only applied when both are usable
minDomain: number | undefined;
maxDomain: number | undefined;
}
// Chart height = DECISION_PATH_MARGIN + number of rows * DECISION_PATH_ROW_HEIGHT.
const DECISION_PATH_MARGIN = 125;
const DECISION_PATH_ROW_HEIGHT = 10;
// Number of significant digits used for tick labels and the baseline header.
const NUM_PRECISION = 3;
// Marker element shown for the baseline line annotation.
const AnnotationBaselineMarker = <EuiIcon type="dot" size="m" />;
/**
 * Renders the decision path as a line chart rotated by 90 degrees: feature
 * names on the ordinal axis, the cumulative value (tuple index 2) on the
 * linear axis, and an optional line annotation for the baseline.
 *
 * @param decisionPathData - rows of [feature name, importance, cumulative value]
 * @param predictionFieldName - name inserted into the prediction axis title
 * @param minDomain - lower bound of the prediction axis
 * @param maxDomain - upper bound of the prediction axis
 * @param baseline - baseline value to annotate; `0` is a valid baseline
 */
export const DecisionPathChart = ({
  decisionPathData,
  predictionFieldName,
  minDomain,
  maxDomain,
  baseline,
}: DecisionPathChartProps) => {
  // Explicit number check instead of truthiness so that a baseline of 0 is
  // still annotated (and never rendered as a literal "0" by `{0 && ...}`).
  const baselineData: LineAnnotationDatum[] = useMemo(
    () =>
      typeof baseline === 'number' && !Number.isNaN(baseline)
        ? [
            {
              dataValue: baseline,
              header: baseline.toPrecision(NUM_PRECISION),
              details: i18n.translate(
                'xpack.ml.dataframe.analytics.explorationResults.decisionPathBaselineText',
                {
                  defaultMessage:
                    'baseline (average of predictions for all data points in the training data set)',
                }
              ),
            },
          ]
        : [],
    [baseline]
  );
  // guarantee up to num_precision significant digits
  // without having it in scientific notation
  const tickFormatter = useCallback(
    (d: number) => Number(d.toPrecision(NUM_PRECISION)).toString(),
    []
  );

  return (
    <Chart
      // adjust the height so it's compact for items with more features
      size={{ height: DECISION_PATH_MARGIN + decisionPathData.length * DECISION_PATH_ROW_HEIGHT }}
    >
      <Settings theme={theme} rotation={90} />
      {baselineData.length > 0 && (
        <LineAnnotation
          id="xpack.ml.dataframe.analytics.explorationResults.decisionPathBaseline"
          domainType={AnnotationDomainTypes.YDomain}
          dataValues={baselineData}
          style={baselineStyle}
          marker={AnnotationBaselineMarker}
        />
      )}

      <Axis
        id={'xpack.ml.dataframe.analytics.explorationResults.decisionPathXAxis'}
        tickFormat={tickFormatter}
        title={i18n.translate(
          'xpack.ml.dataframe.analytics.explorationResults.decisionPathXAxisTitle',
          {
            defaultMessage: "Prediction for '{predictionFieldName}'",
            values: { predictionFieldName },
          }
        )}
        showGridLines={false}
        position={Position.Top}
        showOverlappingTicks
        domain={
          // a bound of 0 is valid; only undefined / non-finite bounds disable the custom domain
          typeof minDomain === 'number' &&
          typeof maxDomain === 'number' &&
          Number.isFinite(minDomain) &&
          Number.isFinite(maxDomain)
            ? {
                min: minDomain,
                max: maxDomain,
              }
            : undefined
        }
      />
      <Axis showGridLines={true} id="left" position={Position.Left} />
      <LineSeries
        id={'xpack.ml.dataframe.analytics.explorationResults.decisionPathLine'}
        name={i18n.translate(
          'xpack.ml.dataframe.analytics.explorationResults.decisionPathLineTitle',
          {
            defaultMessage: 'Prediction',
          }
        )}
        xScaleType={ScaleType.Ordinal}
        yScaleType={ScaleType.Linear}
        xAccessor={0}
        yAccessors={[2]}
        data={decisionPathData}
      />
    </Chart>
  );
};

View file

@ -0,0 +1,105 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
import React, { FC, useMemo, useState } from 'react';
import { i18n } from '@kbn/i18n';
import { EuiHealth, EuiSpacer, EuiSuperSelect, EuiTitle } from '@elastic/eui';
import d3 from 'd3';
import {
isDecisionPathData,
useDecisionPathData,
getStringBasedClassName,
} from './use_classification_path_data';
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
import { DecisionPathChart } from './decision_path_chart';
import { MissingDecisionPathCallout } from './missing_decision_path_callout';
/** Props of `ClassificationDecisionPath`. */
interface ClassificationDecisionPathProps {
// predicted class of the row; the matching option is highlighted in the class selector
predictedValue: string | boolean;
predictionFieldName?: string;
featureImportance: FeatureImportance[];
// classes offered in the selector; the first entry is selected initially
topClasses: TopClasses;
}
/**
 * Decision path view for classification results: a class selector plus the
 * decision path chart for the currently selected class.
 *
 * @param featureImportance - feature importance values of the row
 * @param predictedValue - predicted class, highlighted in the selector
 * @param topClasses - classes offered in the selector; may be empty
 * @param predictionFieldName - name inserted into the chart's axis title
 */
export const ClassificationDecisionPath: FC<ClassificationDecisionPathProps> = ({
  featureImportance,
  predictedValue,
  topClasses,
  predictionFieldName,
}) => {
  // The first top class is selected initially. `topClasses` can be empty (the
  // data grid passes `[]` when no prediction is available), so guard the
  // access instead of throwing on `topClasses[0].class_name`.
  const [currentClass, setCurrentClass] = useState<string>(() =>
    getStringBasedClassName(
      Array.isArray(topClasses) && topClasses.length > 0 ? topClasses[0].class_name : undefined
    )
  );
  const { decisionPathData } = useDecisionPathData({
    featureImportance,
    predictedValue: currentClass,
  });
  const options = useMemo(() => {
    const predictionValueStr = getStringBasedClassName(predictedValue);

    return Array.isArray(topClasses)
      ? topClasses.map((c) => {
          const className = getStringBasedClassName(c.class_name);
          return {
            value: className,
            inputDisplay:
              // mark the class that was actually predicted for this row
              className === predictionValueStr ? (
                <EuiHealth color="success" style={{ lineHeight: 'inherit' }}>
                  {className}
                </EuiHealth>
              ) : (
                className
              ),
          };
        })
      : undefined;
  }, [topClasses, predictedValue]);

  const domain = useMemo(() => {
    let maxDomain;
    let minDomain;
    // if decisionPathData has calculated cumulative path
    if (decisionPathData && isDecisionPathData(decisionPathData)) {
      const [min, max] = d3.extent(decisionPathData, (d: [string, number, number]) => d[2]);
      // pad the axis by 10% of the value range on both sides
      const buffer = Math.abs(max - min) * 0.1;
      maxDomain = max + buffer;
      minDomain = min - buffer;
    }
    return { maxDomain, minDomain };
  }, [decisionPathData]);

  if (!decisionPathData) return <MissingDecisionPathCallout />;

  return (
    <>
      <EuiSpacer size={'xs'} />
      <EuiTitle size={'xxxs'}>
        <span>
          {i18n.translate(
            'xpack.ml.dataframe.analytics.explorationResults.classificationDecisionPathClassNameTitle',
            {
              defaultMessage: 'Class name',
            }
          )}
        </span>
      </EuiTitle>
      {options !== undefined && (
        <EuiSuperSelect
          compressed={true}
          options={options}
          valueOfSelected={currentClass}
          onChange={setCurrentClass}
        />
      )}
      <DecisionPathChart
        decisionPathData={decisionPathData}
        predictionFieldName={predictionFieldName}
        minDomain={domain.minDomain}
        maxDomain={domain.maxDomain}
      />
    </>
  );
};

View file

@ -0,0 +1,16 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
import React, { FC } from 'react';
import { EuiCodeBlock } from '@elastic/eui';
import { FeatureImportance } from '../../../../../common/types/feature_importance';
/** Props of `DecisionPathJSONViewer`. */
interface DecisionPathJSONViewerProps {
// feature importance values shown as serialized JSON
featureImportance: FeatureImportance[];
}
/** Shows the raw feature importance values as JSON inside a copyable code block. */
export const DecisionPathJSONViewer: FC<DecisionPathJSONViewerProps> = ({ featureImportance }) => {
  const serialized = JSON.stringify(featureImportance);
  return <EuiCodeBlock isCopyable={true}>{serialized}</EuiCodeBlock>;
};

View file

@ -0,0 +1,134 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
import React, { FC, useState } from 'react';
import { EuiLink, EuiTab, EuiTabs, EuiText } from '@elastic/eui';
import { FormattedMessage } from '@kbn/i18n/react';
import { RegressionDecisionPath } from './decision_path_regression';
import { DecisionPathJSONViewer } from './decision_path_json_viewer';
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
import { ANALYSIS_CONFIG_TYPE } from '../../../data_frame_analytics/common';
import { ClassificationDecisionPath } from './decision_path_classification';
import { useMlKibana } from '../../../contexts/kibana';
/** Props of `DecisionPathPopover`. */
interface DecisionPathPopoverProps {
featureImportance: FeatureImportance[];
// selects between the classification and the regression decision path
analysisType: ANALYSIS_CONFIG_TYPE;
predictionFieldName?: string;
// forwarded to the regression decision path only
baseline?: number;
predictedValue?: number | string | undefined;
// forwarded to the classification decision path only
topClasses?: TopClasses;
}
/** Ids of the tabs shown in the decision path popover. */
enum DECISION_PATH_TABS {
CHART = 'decision_path_chart',
JSON = 'decision_path_json',
}
/** Feature importance extended with the absolute importance used for sorting the decision path. */
export interface ExtendedFeatureImportance extends FeatureImportance {
absImportance?: number;
}
/**
 * Popover content for a feature importance cell.
 *
 * With fewer than two feature importance values only the JSON view is shown.
 * Otherwise two tabs are offered: a decision plot (classification or
 * regression, depending on `analysisType`) and the raw JSON.
 */
export const DecisionPathPopover: FC<DecisionPathPopoverProps> = (props) => {
  const {
    analysisType,
    baseline,
    featureImportance,
    predictedValue,
    predictionFieldName,
    topClasses,
  } = props;

  const [selectedTabId, setSelectedTabId] = useState(DECISION_PATH_TABS.CHART);
  const { ELASTIC_WEBSITE_URL, DOC_LINK_VERSION } = useMlKibana().services.docLinks;

  if (featureImportance.length < 2) {
    return <DecisionPathJSONViewer featureImportance={featureImportance} />;
  }

  const featureImportanceDocsUrl = `${ELASTIC_WEBSITE_URL}guide/en/machine-learning/${DOC_LINK_VERSION}/ml-feature-importance.html`;

  const chartTabName = (
    <FormattedMessage
      id="xpack.ml.dataframe.analytics.explorationResults.decisionPathPlotTab"
      defaultMessage="Decision plot"
    />
  );
  const jsonTabName = (
    <FormattedMessage
      id="xpack.ml.dataframe.analytics.explorationResults.decisionPathJSONTab"
      defaultMessage="JSON"
    />
  );
  const tabs = [
    { id: DECISION_PATH_TABS.CHART, name: chartTabName },
    { id: DECISION_PATH_TABS.JSON, name: jsonTabName },
  ];

  const isChartTabSelected = selectedTabId === DECISION_PATH_TABS.CHART;
  const isJsonTabSelected = selectedTabId === DECISION_PATH_TABS.JSON;
  const isClassification = analysisType === ANALYSIS_CONFIG_TYPE.CLASSIFICATION;
  const isRegression = analysisType === ANALYSIS_CONFIG_TYPE.REGRESSION;

  return (
    <>
      <div style={{ display: 'flex', width: 300 }}>
        <EuiTabs size={'s'}>
          {tabs.map(({ id, name }) => (
            <EuiTab isSelected={id === selectedTabId} onClick={() => setSelectedTabId(id)} key={id}>
              {name}
            </EuiTab>
          ))}
        </EuiTabs>
      </div>
      {isChartTabSelected && (
        <>
          <EuiText size={'xs'} color="subdued" style={{ paddingTop: 5 }}>
            <FormattedMessage
              id="xpack.ml.dataframe.analytics.explorationResults.decisionPathPlotHelpText"
              defaultMessage="SHAP decision plots use {linkedFeatureImportanceValues} to show how models arrive at the predicted value for '{predictionFieldName}'."
              values={{
                predictionFieldName,
                linkedFeatureImportanceValues: (
                  <EuiLink href={featureImportanceDocsUrl} target="_blank">
                    <FormattedMessage
                      id="xpack.ml.dataframe.analytics.explorationResults.linkedFeatureImportanceValues"
                      defaultMessage="feature importance values"
                    />
                  </EuiLink>
                ),
              }}
            />
          </EuiText>
          {isClassification && (
            <ClassificationDecisionPath
              featureImportance={featureImportance}
              topClasses={topClasses as TopClasses}
              predictedValue={predictedValue as string}
              predictionFieldName={predictionFieldName}
            />
          )}
          {isRegression && (
            <RegressionDecisionPath
              featureImportance={featureImportance}
              baseline={baseline}
              predictedValue={predictedValue as number}
              predictionFieldName={predictionFieldName}
            />
          )}
        </>
      )}
      {isJsonTabSelected && <DecisionPathJSONViewer featureImportance={featureImportance} />}
    </>
  );
};

View file

@ -0,0 +1,79 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
import React, { FC, useMemo } from 'react';
import { EuiCallOut } from '@elastic/eui';
import { FormattedMessage } from '@kbn/i18n/react';
import d3 from 'd3';
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
import { useDecisionPathData, isDecisionPathData } from './use_classification_path_data';
import { DecisionPathChart } from './decision_path_chart';
import { MissingDecisionPathCallout } from './missing_decision_path_callout';
/** Props of `RegressionDecisionPath`. */
interface RegressionDecisionPathProps {
predictionFieldName?: string;
// when undefined, a warning callout is shown above the chart
baseline?: number;
predictedValue?: number | undefined;
featureImportance: FeatureImportance[];
// not read by the component
topClasses?: TopClasses;
}
/**
 * Decision path view for regression results. Shows a warning callout when no
 * baseline is available and renders the decision path chart with an axis
 * domain that covers the path, the baseline (if any) and a 10% buffer.
 */
export const RegressionDecisionPath: FC<RegressionDecisionPathProps> = (props) => {
  const { baseline, featureImportance, predictedValue, predictionFieldName } = props;

  const { decisionPathData } = useDecisionPathData({
    baseline,
    featureImportance,
    predictedValue,
  });

  const domain = useMemo(() => {
    // without a calculated cumulative path there is no domain to derive
    if (!decisionPathData || !isDecisionPathData(decisionPathData)) {
      return { maxDomain: undefined, minDomain: undefined };
    }

    const [pathMin, pathMax] = d3.extent(
      decisionPathData,
      (d: [string, number, number]) => d[2]
    );
    // the buffer is based on the path's own range, before the baseline is considered
    const buffer = Math.abs(pathMax - pathMin) * 0.1;

    let upper = pathMax;
    let lower = pathMin;
    if (typeof baseline === 'number') {
      upper = Math.max(pathMax, baseline);
      lower = Math.min(pathMin, baseline);
    }

    return { maxDomain: upper + buffer, minDomain: lower - buffer };
  }, [decisionPathData, baseline]);

  if (!decisionPathData) return <MissingDecisionPathCallout />;

  const missingBaselineTitle = (
    <FormattedMessage
      id="xpack.ml.dataframe.analytics.explorationResults.missingBaselineCallout"
      defaultMessage="Unable to calculate baseline value, which might result in a shifted decision path."
    />
  );

  return (
    <>
      {baseline === undefined && (
        <EuiCallOut
          size={'s'}
          heading={'p'}
          title={missingBaselineTitle}
          color="warning"
          iconType="alert"
        />
      )}
      <DecisionPathChart
        decisionPathData={decisionPathData}
        predictionFieldName={predictionFieldName}
        minDomain={domain.minDomain}
        maxDomain={domain.maxDomain}
        baseline={baseline}
      />
    </>
  );
};

View file

@ -0,0 +1,20 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
import React from 'react';
import { EuiCallOut } from '@elastic/eui';
import { FormattedMessage } from '@kbn/i18n/react';
/** Warning callout shown when no decision path data is available. */
export const MissingDecisionPathCallout = () => (
  <EuiCallOut color="warning">
    <FormattedMessage
      id="xpack.ml.dataframe.analytics.explorationResults.regressionDecisionPathDataMissingCallout"
      defaultMessage="No decision path data available."
    />
  </EuiCallOut>
);

View file

@ -0,0 +1,173 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
import { useMemo } from 'react';
import { i18n } from '@kbn/i18n';
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
import { ExtendedFeatureImportance } from './decision_path_popover';
/** Plot rows of the decision path: [feature name, importance, cumulative sum]. */
export type DecisionPathPlotData = Array<[string, number, number]>;
/** Parameters of the `useDecisionPathData` hook. */
interface UseDecisionPathDataParams {
featureImportance: FeatureImportance[];
// selects the regression builder when provided; otherwise the classification builder is used
baseline?: number;
// numeric prediction for regression, class name for classification
predictedValue?: string | number | undefined;
// not read by the hook
topClasses?: TopClasses;
}
/** Parameters of `buildRegressionDecisionPathData`. */
interface RegressionDecisionPathProps {
baseline?: number;
predictedValue?: number | undefined;
featureImportance: FeatureImportance[];
// not read by the builder
topClasses?: TopClasses;
}
// Property names of a feature importance entry.
const FEATURE_NAME = 'feature_name';
const FEATURE_IMPORTANCE = 'importance';
/**
 * Checks whether a value looks like decision path plot data: a non-empty
 * array whose first entry has a length of exactly 3.
 * Only the first entry is inspected.
 */
export const isDecisionPathData = (decisionPathData: any): boolean => {
  if (!Array.isArray(decisionPathData)) {
    return false;
  }
  if (decisionPathData.length === 0) {
    return false;
  }
  const firstEntry = decisionPathData[0];
  return firstEntry.length === 3;
};
/**
 * Converts a class name to its display string so it matches the EUI display:
 * `undefined` becomes '', booleans become 'True' / 'False', numbers are
 * stringified and strings are returned unchanged.
 */
export const getStringBasedClassName = (v: string | boolean | undefined | number): string => {
  switch (typeof v) {
    case 'undefined':
      return '';
    case 'boolean':
      return v ? 'True' : 'False';
    case 'number':
      return v.toString();
    default:
      return v;
  }
};
/**
 * Hook that computes the decision path plot data for a row.
 *
 * When a numeric baseline is provided the regression builder is used;
 * otherwise the classification builder is used with `predictedValue` as the
 * current class.
 *
 * @param baseline - regression baseline; `0` is a valid baseline
 * @param featureImportance - feature importance values of the row
 * @param predictedValue - numeric prediction (regression) or class name (classification)
 * @returns the memoized plot data, recomputed when any input changes
 */
export const useDecisionPathData = ({
  baseline,
  featureImportance,
  predictedValue,
}: UseDecisionPathDataParams): { decisionPathData: DecisionPathPlotData | undefined } => {
  const decisionPathData = useMemo(() => {
    // Explicit number check: a truthiness test would treat a baseline of 0 as
    // "no baseline" and wrongly fall through to the classification builder.
    if (typeof baseline === 'number' && !Number.isNaN(baseline)) {
      return buildRegressionDecisionPathData({
        baseline,
        featureImportance,
        predictedValue: predictedValue as number | undefined,
      });
    }
    return buildClassificationDecisionPathData({
      featureImportance,
      currentClass: predictedValue as string | undefined,
    });
  }, [baseline, featureImportance, predictedValue]);

  return { decisionPathData };
};
/**
 * Turns feature importance entries into decision path plot rows.
 *
 * Entries are ordered by descending `absImportance`; each row is
 * `[feature name, importance, cumulative sum]`, where the cumulative sum is
 * accumulated from the last row towards the first.
 *
 * @param featureImportance - entries to plot; the array is not modified
 * @returns the plot rows, or an empty array for empty input
 */
export const buildDecisionPathData = (featureImportance: ExtendedFeatureImportance[]) => {
  // Sort a copy: Array.prototype.sort works in place and would otherwise
  // reorder the caller's array.
  const finalResult: DecisionPathPlotData = [...featureImportance]
    // sort by absolute importance so the path goes from bottom (baseline) to top
    .sort(
      (a: ExtendedFeatureImportance, b: ExtendedFeatureImportance) =>
        b.absImportance! - a.absImportance!
    )
    .map((d): [string, number, number] => [d.feature_name, d.importance as number, NaN]);

  // Start at the last row (the baseline, for regression) and accumulate upwards.
  // For regression the total at the first row adds up to the predicted value.
  let cumulativeSum = 0;
  for (let i = finalResult.length - 1; i >= 0; i--) {
    cumulativeSum += finalResult[i][1];
    finalResult[i][2] = cumulativeSum;
  }
  return finalResult;
};
/**
 * Builds the decision path plot data for a regression result.
 *
 * When both a numeric baseline and a finite predicted value are available, a
 * `baseline` entry is appended and, if the feature importances do not add up
 * to `predictedValue - baseline`, an `other` entry carrying the residual.
 *
 * @param baseline - baseline value; `0` is a valid baseline
 * @param featureImportance - feature importance values of the row
 * @param predictedValue - predicted value of the row
 * @returns the plot rows produced by `buildDecisionPathData`
 */
export const buildRegressionDecisionPathData = ({
  baseline,
  featureImportance,
  predictedValue,
}: RegressionDecisionPathProps): DecisionPathPlotData | undefined => {
  const mappedFeatureImportance: ExtendedFeatureImportance[] = featureImportance.map((d) => ({
    ...d,
    absImportance: Math.abs(d[FEATURE_IMPORTANCE] as number),
  }));

  // Explicit number check: a truthiness test would skip the baseline and
  // residual entries for a baseline of exactly 0.
  if (
    typeof baseline === 'number' &&
    !Number.isNaN(baseline) &&
    predictedValue !== undefined &&
    Number.isFinite(predictedValue)
  ) {
    // get the adjusted importance needed for when # of fields included in c++ analysis != max allowed
    // if num fields included = num features allowed exactly, adjustedImportance should be 0
    const adjustedImportance =
      predictedValue -
      mappedFeatureImportance.reduce(
        (accumulator, currentValue) => accumulator + currentValue.importance!,
        0
      ) -
      baseline;

    mappedFeatureImportance.push({
      [FEATURE_NAME]: i18n.translate(
        'xpack.ml.dataframe.analytics.decisionPathFeatureBaselineTitle',
        {
          defaultMessage: 'baseline',
        }
      ),
      [FEATURE_IMPORTANCE]: baseline,
      absImportance: -1, // lower than any real absolute importance, so the baseline sorts last
    });

    // if the difference is small enough then no need to plot the residual feature importance
    if (Math.abs(adjustedImportance) > 1e-5) {
      mappedFeatureImportance.push({
        [FEATURE_NAME]: i18n.translate(
          'xpack.ml.dataframe.analytics.decisionPathFeatureOtherTitle',
          {
            defaultMessage: 'other',
          }
        ),
        [FEATURE_IMPORTANCE]: adjustedImportance,
        absImportance: 0, // arbitrary importance so this will be of higher importance than baseline
      });
    }
  }

  return buildDecisionPathData(mappedFeatureImportance);
};
/**
 * Builds the decision path plot data for one class of a classification result.
 *
 * For features with a per-class breakdown the entry matching `currentClass`
 * is used; features without a breakdown contribute their own importance.
 * Features without a numeric importance for the class are left out.
 *
 * @param featureImportance - feature importance values of the row
 * @param currentClass - display name of the class to plot
 * @returns the plot rows, or an empty array when no class is given
 */
export const buildClassificationDecisionPathData = ({
  featureImportance,
  currentClass,
}: {
  featureImportance: FeatureImportance[];
  currentClass: string | undefined;
}): DecisionPathPlotData | undefined => {
  if (currentClass === undefined) return [];

  const classFeatureImportances: ExtendedFeatureImportance[] = [];
  for (const feature of featureImportance) {
    const importanceSource = Array.isArray(feature.classes)
      ? feature.classes.find((c) => getStringBasedClassName(c.class_name) === currentClass)
      : feature;
    if (!importanceSource) continue;

    const importance = importanceSource[FEATURE_IMPORTANCE];
    if (typeof importance !== 'number') continue;

    classFeatureImportances.push({
      [FEATURE_NAME]: feature[FEATURE_NAME],
      [FEATURE_IMPORTANCE]: importance,
      absImportance: Math.abs(importance),
    });
  }

  return buildDecisionPathData(classFeatureImportances);
};

View file

@ -74,6 +74,9 @@ export interface UseIndexDataReturnType
| 'tableItems'
| 'toggleChartVisibility'
| 'visibleColumns'
| 'baseline'
| 'predictionFieldName'
| 'resultsField'
> {
renderCellValue: RenderCellValue;
}
@ -105,4 +108,7 @@ export interface UseDataGridReturnType {
tableItems: DataGridItem[];
toggleChartVisibility: () => void;
visibleColumns: ColumnId[];
baseline?: number;
predictionFieldName?: string;
resultsField?: string;
}

View file

@ -15,18 +15,19 @@ import { SavedSearchQuery } from '../../contexts/ml';
import {
AnalysisConfig,
ClassificationAnalysis,
OutlierAnalysis,
RegressionAnalysis,
ANALYSIS_CONFIG_TYPE,
} from '../../../../common/types/data_frame_analytics';
import {
isOutlierAnalysis,
isRegressionAnalysis,
isClassificationAnalysis,
getPredictionFieldName,
getDependentVar,
getPredictedFieldName,
} from '../../../../common/util/analytics_utils';
export type IndexPattern = string;
export enum ANALYSIS_CONFIG_TYPE {
OUTLIER_DETECTION = 'outlier_detection',
REGRESSION = 'regression',
CLASSIFICATION = 'classification',
}
export enum ANALYSIS_ADVANCED_FIELDS {
ETA = 'eta',
FEATURE_BAG_FRACTION = 'feature_bag_fraction',
@ -156,23 +157,6 @@ export const getAnalysisType = (analysis: AnalysisConfig): string => {
return 'unknown';
};
export const getDependentVar = (
analysis: AnalysisConfig
):
| RegressionAnalysis['regression']['dependent_variable']
| ClassificationAnalysis['classification']['dependent_variable'] => {
let depVar = '';
if (isRegressionAnalysis(analysis)) {
depVar = analysis.regression.dependent_variable;
}
if (isClassificationAnalysis(analysis)) {
depVar = analysis.classification.dependent_variable;
}
return depVar;
};
export const getTrainingPercent = (
analysis: AnalysisConfig
):
@ -190,24 +174,6 @@ export const getTrainingPercent = (
return trainingPercent;
};
export const getPredictionFieldName = (
analysis: AnalysisConfig
):
| RegressionAnalysis['regression']['prediction_field_name']
| ClassificationAnalysis['classification']['prediction_field_name'] => {
// If undefined will be defaulted to dependent_variable when config is created
let predictionFieldName;
if (isRegressionAnalysis(analysis) && analysis.regression.prediction_field_name !== undefined) {
predictionFieldName = analysis.regression.prediction_field_name;
} else if (
isClassificationAnalysis(analysis) &&
analysis.classification.prediction_field_name !== undefined
) {
predictionFieldName = analysis.classification.prediction_field_name;
}
return predictionFieldName;
};
export const getNumTopClasses = (
analysis: AnalysisConfig
): ClassificationAnalysis['classification']['num_top_classes'] => {
@ -238,35 +204,6 @@ export const getNumTopFeatureImportanceValues = (
return numTopFeatureImportanceValues;
};
export const getPredictedFieldName = (
resultsField: string,
analysis: AnalysisConfig,
forSort?: boolean
) => {
// default is 'ml'
const predictionFieldName = getPredictionFieldName(analysis);
const defaultPredictionField = `${getDependentVar(analysis)}_prediction`;
const predictedField = `${resultsField}.${
predictionFieldName ? predictionFieldName : defaultPredictionField
}`;
return predictedField;
};
export const isOutlierAnalysis = (arg: any): arg is OutlierAnalysis => {
const keys = Object.keys(arg);
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.OUTLIER_DETECTION;
};
export const isRegressionAnalysis = (arg: any): arg is RegressionAnalysis => {
const keys = Object.keys(arg);
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.REGRESSION;
};
export const isClassificationAnalysis = (arg: any): arg is ClassificationAnalysis => {
const keys = Object.keys(arg);
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.CLASSIFICATION;
};
export const isResultsSearchBoolQuery = (arg: any): arg is ResultsSearchBoolQuery => {
if (arg === undefined) return false;
const keys = Object.keys(arg);
@ -607,3 +544,13 @@ export const loadDocsCount = async ({
};
}
};
export {
isOutlierAnalysis,
isRegressionAnalysis,
isClassificationAnalysis,
getPredictionFieldName,
ANALYSIS_CONFIG_TYPE,
getDependentVar,
getPredictedFieldName,
};

View file

@ -3,8 +3,6 @@
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
// Fallback name of the results field; callers use it when a job config does
// not set `dest.results_field`.
export const DEFAULT_RESULTS_FIELD = 'ml';
// NOTE(review): the three names below look like sub-field names found under
// the results field of analytics output documents — confirm against usages.
export const FEATURE_IMPORTANCE = 'feature_importance';
export const FEATURE_INFLUENCE = 'feature_influence';
export const TOP_CLASSES = 'top_classes';

View file

@ -4,17 +4,16 @@
* you may not use this file except in compliance with the Elastic License.
*/
import { getNumTopClasses, getNumTopFeatureImportanceValues } from './analytics';
import { Field } from '../../../../common/types/fields';
import {
getNumTopClasses,
getNumTopFeatureImportanceValues,
getPredictedFieldName,
getDependentVar,
getPredictionFieldName,
isClassificationAnalysis,
isOutlierAnalysis,
isRegressionAnalysis,
} from './analytics';
import { Field } from '../../../../common/types/fields';
} from '../../../../common/util/analytics_utils';
import { ES_FIELD_TYPES, KBN_FIELD_TYPES } from '../../../../../../../src/plugins/data/public';
import { newJobCapsService } from '../../services/new_job_capabilities_service';

View file

@ -9,7 +9,6 @@ import React, { FC } from 'react';
import { i18n } from '@kbn/i18n';
import { ExplorationPageWrapper } from '../exploration_page_wrapper';
import { EvaluatePanel } from './evaluate_panel';
interface Props {

View file

@ -51,7 +51,6 @@ export const ExplorationPageWrapper: FC<Props> = ({ jobId, title, EvaluatePanel
/>
);
}
return (
<>
{isLoadingJobConfig === true && jobConfig === undefined && <LoadingPanel />}

View file

@ -28,6 +28,8 @@ import {
INDEX_STATUS,
SEARCH_SIZE,
defaultSearchQuery,
getAnalysisType,
ANALYSIS_CONFIG_TYPE,
} from '../../../../common';
import { getTaskStateBadge } from '../../../analytics_management/components/analytics_list/use_columns';
import { DATA_FRAME_TASK_STATE } from '../../../analytics_management/components/analytics_list/common';
@ -36,6 +38,7 @@ import { ExplorationQueryBar } from '../exploration_query_bar';
import { IndexPatternPrompt } from '../index_pattern_prompt';
import { useExplorationResults } from './use_exploration_results';
import { useMlKibana } from '../../../../../contexts/kibana';
const showingDocs = i18n.translate(
'xpack.ml.dataframe.analytics.explorationResults.documentsShownHelpText',
@ -70,18 +73,27 @@ export const ExplorationResultsTable: FC<Props> = React.memo(
setEvaluateSearchQuery,
title,
}) => {
const {
services: {
mlServices: { mlApiServices },
},
} = useMlKibana();
const [searchQuery, setSearchQuery] = useState<SavedSearchQuery>(defaultSearchQuery);
useEffect(() => {
setEvaluateSearchQuery(searchQuery);
}, [JSON.stringify(searchQuery)]);
const analysisType = getAnalysisType(jobConfig.analysis);
const classificationData = useExplorationResults(
indexPattern,
jobConfig,
searchQuery,
getToastNotifications()
getToastNotifications(),
mlApiServices
);
const docFieldsCount = classificationData.columnsWithCharts.length;
const {
columnsWithCharts,
@ -94,7 +106,6 @@ export const ExplorationResultsTable: FC<Props> = React.memo(
if (jobConfig === undefined || classificationData === undefined) {
return null;
}
// if it's a searchBar syntax error leave the table visible so they can try again
if (status === INDEX_STATUS.ERROR && !errorMessage.includes('failed to create query')) {
return (
@ -184,6 +195,7 @@ export const ExplorationResultsTable: FC<Props> = React.memo(
{...classificationData}
dataTestSubj="mlExplorationDataGrid"
toastNotifications={getToastNotifications()}
analysisType={(analysisType as unknown) as ANALYSIS_CONFIG_TYPE}
/>
</EuiFlexItem>
</EuiFlexGroup>

View file

@ -4,12 +4,14 @@
* you may not use this file except in compliance with the Elastic License.
*/
import { useEffect, useMemo } from 'react';
import { useCallback, useEffect, useMemo, useState } from 'react';
import { EuiDataGridColumn } from '@elastic/eui';
import { CoreSetup } from 'src/core/public';
import { i18n } from '@kbn/i18n';
import { MlApiServices } from '../../../../../services/ml_api_service';
import { IndexPattern } from '../../../../../../../../../../src/plugins/data/public';
import { DataLoader } from '../../../../../datavisualizer/index_based/data_loader';
@ -23,21 +25,26 @@ import {
UseIndexDataReturnType,
} from '../../../../../components/data_grid';
import { SavedSearchQuery } from '../../../../../contexts/ml';
import { getIndexData, getIndexFields, DataFrameAnalyticsConfig } from '../../../../common';
import {
DEFAULT_RESULTS_FIELD,
FEATURE_IMPORTANCE,
TOP_CLASSES,
} from '../../../../common/constants';
getPredictionFieldName,
getDefaultPredictionFieldName,
} from '../../../../../../../common/util/analytics_utils';
import { FEATURE_IMPORTANCE, TOP_CLASSES } from '../../../../common/constants';
import { DEFAULT_RESULTS_FIELD } from '../../../../../../../common/constants/data_frame_analytics';
import { sortExplorationResultsFields, ML__ID_COPY } from '../../../../common/fields';
import { isRegressionAnalysis } from '../../../../common/analytics';
import { extractErrorMessage } from '../../../../../../../common/util/errors';
export const useExplorationResults = (
indexPattern: IndexPattern | undefined,
jobConfig: DataFrameAnalyticsConfig | undefined,
searchQuery: SavedSearchQuery,
toastNotifications: CoreSetup['notifications']['toasts']
toastNotifications: CoreSetup['notifications']['toasts'],
mlApiServices: MlApiServices
): UseIndexDataReturnType => {
const [baseline, setBaseLine] = useState();
const needsDestIndexFields =
indexPattern !== undefined && indexPattern.title === jobConfig?.source.index[0];
@ -52,7 +59,6 @@ export const useExplorationResults = (
)
);
}
const dataGrid = useDataGrid(
columns,
25,
@ -107,16 +113,60 @@ export const useExplorationResults = (
jobConfig?.dest.index,
JSON.stringify([searchQuery, dataGrid.visibleColumns]),
]);
// Prediction field name for the current job: the value reported by
// getPredictionFieldName() when it is not null/undefined, otherwise the
// default name for the analysis. Remains undefined while no job config is
// available. Recomputed only when `jobConfig` changes.
const predictionFieldName = useMemo(() => {
  if (!jobConfig) {
    return undefined;
  }
  const { analysis } = jobConfig;
  return getPredictionFieldName(analysis) ?? getDefaultPredictionFieldName(analysis);
}, [jobConfig]);
// Loads the feature importance baseline for the current job and stores it in
// component state. Only regression jobs are queried. Any failure is reported
// to the user through a danger toast instead of being rethrown.
const getAnalyticsBaseline = useCallback(async () => {
  try {
    // Nothing to fetch until a job config with an analysis section exists.
    if (jobConfig === undefined || jobConfig.analysis === undefined) {
      return;
    }
    if (!isRegressionAnalysis(jobConfig.analysis)) {
      return;
    }
    const response = await mlApiServices.dataFrameAnalytics.getAnalyticsBaseline(jobConfig.id);
    // A missing or otherwise falsy baseline leaves the current state untouched.
    if (response?.baseline) {
      setBaseLine(response.baseline);
    }
  } catch (e) {
    const text = extractErrorMessage(e);
    const title = i18n.translate(
      'xpack.ml.dataframe.analytics.explorationResults.baselineErrorMessageToast',
      {
        defaultMessage: 'An error occurred getting feature importance baseline',
      }
    );
    toastNotifications.addDanger({ title, text });
  }
}, [mlApiServices, jobConfig]);
// Request the baseline after the first render and again whenever `jobConfig`
// changes. The returned promise is intentionally not awaited;
// getAnalyticsBaseline handles its own errors.
useEffect(() => {
getAnalyticsBaseline();
}, [jobConfig]);
const resultsField = jobConfig?.dest.results_field ?? DEFAULT_RESULTS_FIELD;
const renderCellValue = useRenderCellValue(
indexPattern,
dataGrid.pagination,
dataGrid.tableItems,
jobConfig?.dest.results_field ?? DEFAULT_RESULTS_FIELD
resultsField
);
return {
...dataGrid,
renderCellValue,
baseline,
predictionFieldName,
resultsField,
};
};

View file

@ -29,7 +29,8 @@ import { SavedSearchQuery } from '../../../../../contexts/ml';
import { getToastNotifications } from '../../../../../util/dependency_cache';
import { getIndexData, getIndexFields, DataFrameAnalyticsConfig } from '../../../../common';
import { DEFAULT_RESULTS_FIELD, FEATURE_INFLUENCE } from '../../../../common/constants';
import { FEATURE_INFLUENCE } from '../../../../common/constants';
import { DEFAULT_RESULTS_FIELD } from '../../../../../../../common/constants/data_frame_analytics';
import { sortExplorationResultsFields, ML__ID_COPY } from '../../../../common/fields';
import { getFeatureCount, getOutlierScoreFieldName } from './common';

View file

@ -12,7 +12,7 @@ import { IIndexPattern } from 'src/plugins/data/common';
import { DeepReadonly } from '../../../../../../../common/types/common';
import { DataFrameAnalyticsConfig, isOutlierAnalysis } from '../../../../common';
import { isClassificationAnalysis, isRegressionAnalysis } from '../../../../common/analytics';
import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants';
import { DEFAULT_RESULTS_FIELD } from '../../../../../../../common/constants/data_frame_analytics';
import { useMlKibana, useNavigateToPath } from '../../../../../contexts/kibana';
import { DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES } from '../../hooks/use_create_analytics_form';
import { State } from '../../hooks/use_create_analytics_form/state';

View file

@ -135,4 +135,10 @@ export const dataFrameAnalytics = {
method: 'GET',
});
},
/**
 * Requests the feature importance baseline of a data frame analytics job.
 * Sends a POST to `<basePath>/data_frame/analytics/<analyticsId>/baseline`.
 *
 * @param analyticsId - ID of the analytics job, interpolated into the path.
 * @returns Whatever the `http` helper returns for the request (typed as `any`).
 */
getAnalyticsBaseline(analyticsId: string) {
  const path = `${basePath()}/data_frame/analytics/${analyticsId}/baseline`;
  return http<any>({ path, method: 'POST' });
},
};

View file

@ -0,0 +1,69 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
import { IScopedClusterClient } from 'kibana/server';
import {
getDefaultPredictionFieldName,
getPredictionFieldName,
isRegressionAnalysis,
} from '../../../common/util/analytics_utils';
import { DEFAULT_RESULTS_FIELD } from '../../../common/constants/data_frame_analytics';
/**
 * Obtains data for the data frame analytics feature importance functionalities
 * such as baseline, decision paths, or importance summary.
 *
 * @param client - Scoped cluster client. `asInternalUser` is used to read the
 *   job configuration; `asCurrentUser` is used to search the destination index.
 * @returns An object exposing `getRegressionAnalyticsBaseline`.
 */
export function analyticsFeatureImportanceProvider({
  asInternalUser,
  asCurrentUser,
}: IScopedClusterClient) {
  /**
   * Computes the feature importance baseline of a regression job as the
   * average of the predicted field over the documents flagged as training data
   * in the job's destination index.
   *
   * @param analyticsId - ID of the data frame analytics job.
   * @returns The average prediction, or `undefined` when the job config is
   *   missing, the job is not a regression job, or the aggregation has no value.
   */
  async function getRegressionAnalyticsBaseline(analyticsId: string): Promise<number | undefined> {
    const { body } = await asInternalUser.ml.getDataFrameAnalytics({
      id: analyticsId,
    });
    const jobConfig = body.data_frame_analytics[0];

    // The type guard has to be *called* with the analysis. Testing the function
    // reference itself is always truthy, so the early return never fired.
    if (jobConfig?.analysis === undefined || !isRegressionAnalysis(jobConfig.analysis)) {
      return undefined;
    }

    const destinationIndex = jobConfig.dest.index;
    const mlResultsField = jobConfig.dest.results_field ?? DEFAULT_RESULTS_FIELD;
    const predictionFieldName = getPredictionFieldName(jobConfig.analysis);
    const predictedField = `${mlResultsField}.${
      predictionFieldName ? predictionFieldName : getDefaultPredictionFieldName(jobConfig.analysis)
    }`;
    const isTrainingField = `${mlResultsField}.is_training`;

    const params = {
      index: destinationIndex,
      // Only the aggregation is needed, no hits.
      size: 0,
      body: {
        query: {
          bool: {
            filter: [
              {
                term: {
                  [isTrainingField]: true,
                },
              },
            ],
          },
        },
        aggs: {
          featureImportanceBaseline: {
            avg: {
              field: predictedField,
            },
          },
        },
      },
    };

    const { body: aggregationResult } = await asCurrentUser.search(params);
    // Tolerate a response without aggregations and normalise a null average
    // to undefined so the result matches the declared return type.
    return aggregationResult?.aggregations?.featureImportanceBaseline?.value ?? undefined;
  }

  return {
    getRegressionAnalyticsBaseline,
  };
}

View file

@ -20,6 +20,7 @@ import {
import { IndexPatternHandler } from '../models/data_frame_analytics/index_patterns';
import { DeleteDataFrameAnalyticsWithIndexStatus } from '../../common/types/data_frame_analytics';
import { getAuthorizationHeader } from '../lib/request_authorization';
import { analyticsFeatureImportanceProvider } from '../models/data_frame_analytics/feature_importance';
function getIndexPatternId(context: RequestHandlerContext, patternName: string) {
const iph = new IndexPatternHandler(context.core.savedObjects.client);
@ -545,4 +546,38 @@ export function dataFrameAnalyticsRoutes({ router, mlLicense }: RouteInitializat
}
})
);
/**
 * @apiGroup DataFrameAnalytics
 *
 * @api {post} /api/ml/data_frame/analytics/:analyticsId/baseline Get analytics's feature importance baseline
 * @apiName GetDataFrameAnalyticsBaseline
 * @apiDescription Returns the baseline for data frame analytics job.
 *
 * @apiSchema (params) analyticsIdSchema
 */
router.post(
{
path: '/api/ml/data_frame/analytics/{analyticsId}/baseline',
validate: {
params: analyticsIdSchema,
},
options: {
tags: ['access:ml:canGetDataFrameAnalytics'],
},
},
mlLicense.fullLicenseAPIGuard(async ({ client, request, response }) => {
try {
const { analyticsId } = request.params;
const { getRegressionAnalyticsBaseline } = analyticsFeatureImportanceProvider(client);
const baseline = await getRegressionAnalyticsBaseline(analyticsId);
// The response body has the shape `{ baseline }`.
return response.ok({
body: { baseline },
});
} catch (e) {
// Any error thrown while obtaining the baseline is wrapped and returned as a custom error response.
return response.customError(wrapError(e));
}
})
);
}

View file

@ -10,25 +10,9 @@ import { FtrProviderContext } from '../../ftr_provider_context';
import { MlCommonUI } from './common_ui';
import { MlApi } from './api';
import {
ClassificationAnalysis,
RegressionAnalysis,
} from '../../../../plugins/ml/common/types/data_frame_analytics';
enum ANALYSIS_CONFIG_TYPE {
OUTLIER_DETECTION = 'outlier_detection',
REGRESSION = 'regression',
CLASSIFICATION = 'classification',
}
const isRegressionAnalysis = (arg: any): arg is RegressionAnalysis => {
const keys = Object.keys(arg);
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.REGRESSION;
};
const isClassificationAnalysis = (arg: any): arg is ClassificationAnalysis => {
const keys = Object.keys(arg);
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.CLASSIFICATION;
};
isRegressionAnalysis,
isClassificationAnalysis,
} from '../../../../plugins/ml/common/util/analytics_utils';
export function MachineLearningDataFrameAnalyticsCreationProvider(
{ getService }: FtrProviderContext,