[ML] Support trained model aliases (#99174)

* [ML] use toString to make sure boolean values are rendered

* [ML] extract metadata

* [ML] retrieve pipelines associated with model alias

* [ML] fix ts issues

* [ML] functional tests

* [ML] remove unused models definitions
This commit is contained in:
Dima Arnautov 2021-05-05 13:34:07 +02:00 committed by GitHub
parent 8ce3ea0238
commit 5ce926b9f7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 251 additions and 18 deletions

View file

@ -46,3 +46,7 @@ export interface ListingPageUrlState {
export type AppPageState<T> = {
[key in MlPages]?: Partial<T>;
};
type Without<T, U> = { [P in Exclude<keyof T, keyof U>]?: never };
export type XOR<T, U> = T | U extends object ? (Without<T, U> & U) | (Without<U, T> & T) : T | U;

View file

@ -7,6 +7,7 @@
import { DataFrameAnalyticsConfig } from './data_frame_analytics';
import { FeatureImportanceBaseline, TotalFeatureImportance } from './feature_importance';
import { XOR } from './common';
export interface IngestStats {
count: number;
@ -45,22 +46,54 @@ export interface TrainedModelStat {
};
}
type TreeNode = object;
export type PutTrainedModelConfig = {
description?: string;
metadata?: {
analytics_config: DataFrameAnalyticsConfig;
input: unknown;
total_feature_importance?: TotalFeatureImportance[];
feature_importance_baseline?: FeatureImportanceBaseline;
model_aliases?: string[];
} & Record<string, unknown>;
tags?: string[];
inference_config?: Record<string, unknown>;
input: { field_names: string[] };
} & XOR<
{ compressed_definition: string },
{
definition: {
preprocessors: object[];
trained_model: {
tree: {
classification_labels?: string;
feature_names: string;
target_type: string;
tree_structure: TreeNode[];
};
tree_node: TreeNode;
ensemble?: object;
};
};
}
>; // compressed_definition and definition are mutually exclusive
export interface TrainedModelConfigResponse {
description: string;
description?: string;
created_by: string;
create_time: string;
default_field_map: Record<string, string>;
estimated_heap_memory_usage_bytes: number;
estimated_operations: number;
license_level: string;
metadata?:
| {
analytics_config: DataFrameAnalyticsConfig;
input: any;
total_feature_importance?: TotalFeatureImportance[];
feature_importance_baseline?: FeatureImportanceBaseline;
}
| Record<string, any>;
metadata?: {
analytics_config: DataFrameAnalyticsConfig;
input: unknown;
total_feature_importance?: TotalFeatureImportance[];
feature_importance_baseline?: FeatureImportanceBaseline;
model_aliases?: string[];
} & Record<string, unknown>;
model_id: string;
tags: string[];
version: string;

View file

@ -31,6 +31,7 @@ export const AnalyticsNavigationBar: FC<{
defaultMessage: 'Jobs',
}),
path: '/data_frame_analytics',
testSubj: 'mlAnalyticsJobsTab',
},
{
id: 'models',
@ -38,6 +39,7 @@ export const AnalyticsNavigationBar: FC<{
defaultMessage: 'Models',
}),
path: '/data_frame_analytics/models',
testSubj: 'mlTrainedModelsTab',
},
];
if (jobId !== undefined || modelId !== undefined) {
@ -47,6 +49,7 @@ export const AnalyticsNavigationBar: FC<{
defaultMessage: 'Map',
}),
path: '/data_frame_analytics/map',
testSubj: '',
});
}
return navTabs;
@ -67,6 +70,7 @@ export const AnalyticsNavigationBar: FC<{
key={`tab-${tab.id}`}
isSelected={tab.id === selectedTabId}
onClick={onTabClick.bind(null, tab)}
data-test-subj={tab.testSubj}
>
{tab.name}
</EuiTab>

View file

@ -29,6 +29,7 @@ import { ModelItemFull } from './models_list';
import { useMlKibana } from '../../../../../contexts/kibana';
import { timeFormatter } from '../../../../../../../common/util/date_utils';
import { isDefined } from '../../../../../../../common/types/guards';
import { isPopulatedObject } from '../../../../../../../common';
interface ExpandedRowProps {
item: ModelItemFull;
@ -70,6 +71,8 @@ export const ExpandedRow: FC<ExpandedRowProps> = ({ item }) => {
description,
} = item;
const { analytics_config: analyticsConfig, ...restMetaData } = metadata ?? {};
const details = {
description,
tags,
@ -148,6 +151,26 @@ export const ExpandedRow: FC<ExpandedRowProps> = ({ item }) => {
/>
</EuiPanel>
</EuiFlexItem>
{isPopulatedObject(restMetaData) ? (
<EuiFlexItem>
<EuiPanel>
<EuiTitle size={'xs'}>
<h5>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.expandedRow.metadataTitle"
defaultMessage="Metadata"
/>
</h5>
</EuiTitle>
<EuiSpacer size={'m'} />
<EuiDescriptionList
compressed={true}
type="column"
listItems={formatToListItems(restMetaData)}
/>
</EuiPanel>
</EuiFlexItem>
) : null}
</EuiFlexGrid>
</>
),
@ -186,7 +209,7 @@ export const ExpandedRow: FC<ExpandedRowProps> = ({ item }) => {
/>
</EuiPanel>
</EuiFlexItem>
{metadata?.analytics_config && (
{analyticsConfig && (
<EuiFlexItem>
<EuiPanel>
<EuiTitle size={'xs'}>
@ -201,7 +224,7 @@ export const ExpandedRow: FC<ExpandedRowProps> = ({ item }) => {
<EuiDescriptionList
compressed={true}
type="column"
listItems={formatToListItems(metadata.analytics_config)}
listItems={formatToListItems(analyticsConfig)}
/>
</EuiPanel>
</EuiFlexItem>

View file

@ -292,7 +292,7 @@ export const ModelsList: FC = () => {
}),
icon: 'visTable',
type: 'icon',
available: (item) => item.metadata?.analytics_config?.id,
available: (item) => !!item.metadata?.analytics_config?.id,
onClick: async (item) => {
if (item.metadata?.analytics_config === undefined) return;
@ -327,7 +327,7 @@ export const ModelsList: FC = () => {
icon: 'graphApp',
type: 'icon',
isPrimary: true,
available: (item) => item.metadata?.analytics_config?.id,
available: (item) => !!item.metadata?.analytics_config?.id,
onClick: async (item) => {
const path = await mlUrlGenerator.createUrl({
page: ML_PAGES.DATA_FRAME_ANALYTICS_MAP,

View file

@ -11,8 +11,8 @@ import { PipelineDefinition } from '../../../common/types/trained_models';
export function modelsProvider(client: IScopedClusterClient) {
return {
/**
* Retrieves the map of model ids and associated pipelines.
* @param modelIds
* Retrieves the map of model ids and aliases with associated pipelines.
* @param modelIds - Array of models ids and model aliases.
*/
async getModelsPipelines(modelIds: string[]) {
const modelIdsMap = new Map<string, Record<string, PipelineDefinition> | null>(

View file

@ -13,6 +13,7 @@ import {
optionalModelIdSchema,
} from './schemas/inference_schema';
import { modelsProvider } from '../models/data_frame_analytics';
import { TrainedModelConfigResponse } from '../../common/types/trained_models';
export function trainedModelsRoutes({ router, routeGuard }: RouteInitialization) {
/**
@ -42,14 +43,32 @@ export function trainedModelsRoutes({ router, routeGuard }: RouteInitialization)
...query,
...(modelId ? { model_id: modelId } : {}),
});
const result = body.trained_model_configs;
const result = body.trained_model_configs as TrainedModelConfigResponse[];
try {
if (withPipelines) {
const modelIdsAndAliases: string[] = Array.from(
new Set(
result
.map(({ model_id: id, metadata }) => {
return [id, ...(metadata?.model_aliases ?? [])];
})
.flat()
)
);
const pipelinesResponse = await modelsProvider(client).getModelsPipelines(
result.map(({ model_id: id }: { model_id: string }) => id)
modelIdsAndAliases
);
for (const model of result) {
model.pipelines = pipelinesResponse.get(model.model_id)!;
model.pipelines = {
...(pipelinesResponse.get(model.model_id) ?? {}),
...(model.metadata?.model_aliases ?? []).reduce((acc, alias) => {
return {
...acc,
...(pipelinesResponse.get(alias) ?? {}),
};
}, {}),
};
}
}
} catch (e) {

View file

@ -16,5 +16,6 @@ export default function ({ loadTestFile }: FtrProviderContext) {
loadTestFile(require.resolve('./classification_creation'));
loadTestFile(require.resolve('./cloning'));
loadTestFile(require.resolve('./feature_importance'));
loadTestFile(require.resolve('./trained_models'));
});
}

View file

@ -0,0 +1,31 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { FtrProviderContext } from '../../../ftr_provider_context';
export default function ({ getService }: FtrProviderContext) {
const ml = getService('ml');
describe('trained models', function () {
before(async () => {
await ml.trainedModels.createdTestTrainedModels('classification', 15);
await ml.trainedModels.createdTestTrainedModels('regression', 15);
await ml.securityUI.loginAsMlPowerUser();
await ml.navigation.navigateToTrainedModels();
});
after(async () => {
await ml.api.cleanMlIndices();
});
it('renders trained models list', async () => {
await ml.trainedModels.assertRowsNumberPerPage(10);
// +1 because of the built-in model
await ml.trainedModels.assertStats(31);
});
});
}

View file

@ -23,6 +23,7 @@ import {
ML_ANNOTATIONS_INDEX_ALIAS_WRITE,
} from '../../../../plugins/ml/common/constants/index_patterns';
import { COMMON_REQUEST_HEADERS } from '../../../functional/services/ml/common_api';
import { PutTrainedModelConfig } from '../../../../plugins/ml/common/types/trained_models';
export function MachineLearningAPIProvider({ getService }: FtrProviderContext) {
const es = getService('es');
@ -935,5 +936,17 @@ export function MachineLearningAPIProvider({ getService }: FtrProviderContext) {
}
}
},
async createTrainedModel(modelId: string, body: PutTrainedModelConfig) {
log.debug(`Creating trained model with id "${modelId}"`);
const model = await esSupertest
.put(`/_ml/trained_models/${modelId}`)
.send(body)
.expect(200)
.then((res: any) => res.body);
log.debug('> Trained model crated');
return model;
},
};
}

View file

@ -245,5 +245,28 @@ export function MachineLearningCommonUIProvider({ getService }: FtrProviderConte
);
});
},
async assertRowsNumberPerPage(testSubj: string, rowsNumber: 10 | 25 | 100) {
const textContent = await testSubjects.getVisibleText(
`${testSubj} > tablePaginationPopoverButton`
);
expect(textContent).to.be(`Rows per page: ${rowsNumber}`);
},
async ensurePagePopupOpen(testSubj: string) {
await retry.tryForTime(5000, async () => {
const isOpen = await testSubjects.exists('tablePagination-10-rows');
if (!isOpen) {
await testSubjects.click(`${testSubj} > tablePaginationPopoverButton`);
await testSubjects.existOrFail('tablePagination-10-rows');
}
});
},
async setRowsNumberPerPage(testSubj: string, rowsNumber: 10 | 25 | 100) {
await this.ensurePagePopupOpen(testSubj);
await testSubjects.click(`tablePagination-${rowsNumber}-rows`);
await this.assertRowsNumberPerPage(testSubj, rowsNumber);
},
};
}

View file

@ -48,6 +48,7 @@ import { MachineLearningAlertingProvider } from './alerting';
import { SwimLaneProvider } from './swim_lane';
import { MachineLearningDashboardJobSelectionTableProvider } from './dashboard_job_selection_table';
import { MachineLearningDashboardEmbeddablesProvider } from './dashboard_embeddables';
import { TrainedModelsProvider } from './trained_models';
export function MachineLearningProvider(context: FtrProviderContext) {
const commonAPI = MachineLearningCommonAPIProvider(context);
@ -108,6 +109,7 @@ export function MachineLearningProvider(context: FtrProviderContext) {
const testResources = MachineLearningTestResourcesProvider(context);
const alerting = MachineLearningAlertingProvider(context, commonUI);
const swimLane = SwimLaneProvider(context);
const trainedModels = TrainedModelsProvider(context, api, commonUI);
return {
anomaliesTable,
@ -151,5 +153,6 @@ export function MachineLearningProvider(context: FtrProviderContext) {
swimLane,
testExecution,
testResources,
trainedModels,
};
}

View file

@ -115,6 +115,13 @@ export function MachineLearningNavigationProvider({
await this.navigateToArea('~mlMainTab & ~dataFrameAnalytics', 'mlPageDataFrameAnalytics');
},
async navigateToTrainedModels() {
await this.navigateToMl();
await this.navigateToDataFrameAnalytics();
await testSubjects.click('mlTrainedModelsTab');
await testSubjects.existOrFail('mlModelsTableContainer');
},
async navigateToDataVisualizer() {
await this.navigateToArea('~mlMainTab & ~dataVisualizer', 'mlPageDataVisualizerSelector');
},

View file

@ -0,0 +1 @@
H4sICOE6Ol8AA2NsZi5qc29uAD2MQQqAIBBF955CXHeCrhIhg44xYBo6LUK8e1rW4i/ef59fhJSKE1BAq/do0atZllY+NeJPjR0Cnwl1gB1zE8sQXcWoBq3Tt2dIG7Lm6+g3ynjImRwZYIrhnVfRU28zTg0thgAAAA==

View file

@ -0,0 +1 @@
H4sICOc8Ol8AA3JnLmpzb24APYxBCoAgFET3nkJcd4KuEiGCkwip8f0tQrx7WtZiFm/eMEVIqZiMj7A6JItdzbK08qmBnxpvMHwSdDQBuYlliK5SUoPW6duzIQfWfB39RhEcIWef4jutoqfeWtCVIIIAAAA=

View file

@ -0,0 +1,70 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import fs from 'fs';
import path from 'path';
import expect from '@kbn/expect';
import { FtrProviderContext } from '../../ftr_provider_context';
import { MlApi } from './api';
import { PutTrainedModelConfig } from '../../../../plugins/ml/common/types/trained_models';
import { MlCommonUI } from './common_ui';
type ModelType = 'regression' | 'classification';
export function TrainedModelsProvider(
{ getService }: FtrProviderContext,
mlApi: MlApi,
mlCommonUI: MlCommonUI
) {
const testSubjects = getService('testSubjects');
return {
async createdTestTrainedModels(modelType: ModelType, count: number = 10) {
const compressedDefinition = this.getCompressedModelDefinition(modelType);
const models = new Array(count).fill(null).map((v, i) => {
return {
model_id: `dfa_${modelType}_model_n_${i}`,
body: {
compressed_definition: compressedDefinition,
inference_config: {
[modelType]: {},
},
input: {
field_names: ['common_field'],
},
} as PutTrainedModelConfig,
};
});
for (const model of models) {
await mlApi.createTrainedModel(model.model_id, model.body);
}
},
getCompressedModelDefinition(modelType: ModelType) {
return fs.readFileSync(
path.resolve(
__dirname,
'resources',
'trained_model_definitions',
`minimum_valid_config_${modelType}.json.gz.b64`
),
'utf-8'
);
},
async assertStats(expectedTotalCount: number) {
const actualStats = await testSubjects.getVisibleText('mlInferenceModelsStatsBar');
expect(actualStats).to.eql(`Total trained models: ${expectedTotalCount}`);
},
async assertRowsNumberPerPage(rowsNumber: 10 | 25 | 100) {
await mlCommonUI.assertRowsNumberPerPage('mlModelsTableContainer', rowsNumber);
},
};
}