Skip to content

Commit 0965a10

Browse files
authored
[7.x] Pass prediction_field_type to C++ analytics process (#49861) (#49981)
1 parent 049d854 commit 0965a10

File tree

21 files changed

+313
-127
lines changed

21 files changed

+313
-127
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ public static Classification fromXContent(XContentParser parser, boolean ignoreU
6767
.flatMap(Set::stream)
6868
.collect(Collectors.toSet()));
6969

70+
/**
71+
* Name of the parameter passed down to the C++ analytics process.
72+
* This parameter is used to decide which JSON data type from {string, int, bool} to use when writing the prediction.
73+
*/
74+
private static final String PREDICTION_FIELD_TYPE = "prediction_field_type";
75+
7076
/**
7177
* As long as we only support binary classification it makes sense to always report both classes with their probabilities.
7278
* This way the user can see if the prediction was made with the confidence they need.
@@ -154,17 +160,38 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
154160
}
155161

156162
@Override
157-
public Map<String, Object> getParams() {
163+
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
158164
Map<String, Object> params = new HashMap<>();
159165
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
160166
params.putAll(boostedTreeParams.getParams());
161167
params.put(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
162168
if (predictionFieldName != null) {
163169
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
164170
}
171+
String predictionFieldType = getPredictionFieldType(extractedFields.get(dependentVariable));
172+
if (predictionFieldType != null) {
173+
params.put(PREDICTION_FIELD_TYPE, predictionFieldType);
174+
}
165175
return params;
166176
}
167177

178+
private static String getPredictionFieldType(Set<String> dependentVariableTypes) {
179+
if (dependentVariableTypes == null) {
180+
return null;
181+
}
182+
if (Types.categorical().containsAll(dependentVariableTypes)) {
183+
return "string";
184+
}
185+
if (Types.bool().containsAll(dependentVariableTypes)) {
186+
return "bool";
187+
}
188+
if (Types.discreteNumerical().containsAll(dependentVariableTypes)) {
189+
// The C++ process uses the int64_t type, so it is safe for the dependent variable to hold long values.
190+
return "int";
191+
}
192+
return null;
193+
}
194+
168195
@Override
169196
public boolean supportsCategoricalFields() {
170197
return true;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
1616

1717
/**
1818
* @return The analysis parameters as a map
19+
* @param extractedFields map of (name, types) for all the extracted fields
1920
*/
20-
Map<String, Object> getParams();
21+
Map<String, Object> getParams(Map<String, Set<String>> extractedFields);
2122

2223
/**
2324
* @return {@code true} if this analysis supports fields with categorical values (i.e. text, keyword, ip)

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ public int hashCode() {
192192
}
193193

194194
@Override
195-
public Map<String, Object> getParams() {
195+
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
196196
Map<String, Object> params = new HashMap<>();
197197
if (nNeighbors != null) {
198198
params.put(N_NEIGHBORS.getPreferredName(), nNeighbors);

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
124124
}
125125

126126
@Override
127-
public Map<String, Object> getParams() {
127+
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
128128
Map<String, Object> params = new HashMap<>();
129129
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
130130
params.putAll(boostedTreeParams.getParams());

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,20 @@
88
import org.elasticsearch.ElasticsearchStatusException;
99
import org.elasticsearch.common.io.stream.Writeable;
1010
import org.elasticsearch.common.xcontent.XContentParser;
11+
import org.elasticsearch.index.mapper.BooleanFieldMapper;
12+
import org.elasticsearch.index.mapper.KeywordFieldMapper;
13+
import org.elasticsearch.index.mapper.NumberFieldMapper;
1114
import org.elasticsearch.test.AbstractSerializingTestCase;
15+
import org.hamcrest.Matchers;
1216

1317
import java.io.IOException;
18+
import java.util.Collections;
19+
import java.util.HashMap;
20+
import java.util.Map;
21+
import java.util.Set;
1422

1523
import static org.hamcrest.Matchers.equalTo;
24+
import static org.hamcrest.Matchers.hasEntry;
1625
import static org.hamcrest.Matchers.is;
1726
import static org.hamcrest.Matchers.not;
1827
import static org.hamcrest.Matchers.nullValue;
@@ -115,6 +124,34 @@ public void testGetTrainingPercent() {
115124
assertThat(classification.getTrainingPercent(), equalTo(100.0));
116125
}
117126

127+
public void testGetParams() {
128+
Map<String, Set<String>> extractedFields = new HashMap<>(3);
129+
extractedFields.put("foo", Collections.singleton(BooleanFieldMapper.CONTENT_TYPE));
130+
extractedFields.put("bar", Collections.singleton(NumberFieldMapper.NumberType.LONG.typeName()));
131+
extractedFields.put("baz", Collections.singleton(KeywordFieldMapper.CONTENT_TYPE));
132+
assertThat(
133+
new Classification("foo").getParams(extractedFields),
134+
Matchers.<Map<String, Object>>allOf(
135+
hasEntry("dependent_variable", "foo"),
136+
hasEntry("num_top_classes", 2),
137+
hasEntry("prediction_field_name", "foo_prediction"),
138+
hasEntry("prediction_field_type", "bool")));
139+
assertThat(
140+
new Classification("bar").getParams(extractedFields),
141+
Matchers.<Map<String, Object>>allOf(
142+
hasEntry("dependent_variable", "bar"),
143+
hasEntry("num_top_classes", 2),
144+
hasEntry("prediction_field_name", "bar_prediction"),
145+
hasEntry("prediction_field_type", "int")));
146+
assertThat(
147+
new Classification("baz").getParams(extractedFields),
148+
Matchers.<Map<String, Object>>allOf(
149+
hasEntry("dependent_variable", "baz"),
150+
hasEntry("num_top_classes", 2),
151+
hasEntry("prediction_field_name", "baz_prediction"),
152+
hasEntry("prediction_field_type", "string")));
153+
}
154+
118155
public void testFieldCardinalityLimitsIsNonNull() {
119156
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
120157
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ protected Writeable.Reader<OutlierDetection> instanceReader() {
5151

5252
public void testGetParams_GivenDefaults() {
5353
OutlierDetection outlierDetection = new OutlierDetection.Builder().build();
54-
Map<String, Object> params = outlierDetection.getParams();
54+
Map<String, Object> params = outlierDetection.getParams(null);
5555
assertThat(params.size(), equalTo(3));
5656
assertThat(params.containsKey("compute_feature_influence"), is(true));
5757
assertThat(params.get("compute_feature_influence"), is(true));
@@ -71,7 +71,7 @@ public void testGetParams_GivenExplicitValues() {
7171
.setStandardizationEnabled(false)
7272
.build();
7373

74-
Map<String, Object> params = outlierDetection.getParams();
74+
Map<String, Object> params = outlierDetection.getParams(null);
7575

7676
assertThat(params.size(), equalTo(6));
7777
assertThat(params.get(OutlierDetection.N_NEIGHBORS.getPreferredName()), equalTo(42));

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212

1313
import java.io.IOException;
1414

15+
import static org.hamcrest.Matchers.allOf;
1516
import static org.hamcrest.Matchers.equalTo;
17+
import static org.hamcrest.Matchers.hasEntry;
1618
import static org.hamcrest.Matchers.is;
1719
import static org.hamcrest.Matchers.not;
1820
import static org.hamcrest.Matchers.nullValue;
@@ -83,6 +85,12 @@ public void testGetTrainingPercent() {
8385
assertThat(regression.getTrainingPercent(), equalTo(100.0));
8486
}
8587

88+
public void testGetParams() {
89+
assertThat(
90+
new Regression("foo").getParams(null),
91+
allOf(hasEntry("dependent_variable", "foo"), hasEntry("prediction_field_name", "foo_prediction")));
92+
}
93+
8694
public void testFieldCardinalityLimitsIsNonNull() {
8795
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
8896
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
7+
8+
import org.elasticsearch.common.util.set.Sets;
9+
import org.elasticsearch.test.ESTestCase;
10+
11+
import static org.hamcrest.Matchers.empty;
12+
13+
public class TypesTests extends ESTestCase {
14+
15+
public void testTypes() {
16+
assertThat(Sets.intersection(Types.bool(), Types.categorical()), empty());
17+
assertThat(Sets.intersection(Types.categorical(), Types.numerical()), empty());
18+
assertThat(Sets.intersection(Types.numerical(), Types.bool()), empty());
19+
assertThat(Sets.difference(Types.discreteNumerical(), Types.numerical()), empty());
20+
}
21+
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,12 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
2828

2929
private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index";
3030

31-
private static final String ACTUAL_CLASS_FIELD = "actual_class_field";
32-
private static final String PREDICTED_CLASS_FIELD = "predicted_class_field";
31+
private static final String ANIMAL_NAME_FIELD = "animal_name";
32+
private static final String ANIMAL_NAME_PREDICTION_FIELD = "animal_name_prediction";
33+
private static final String NO_LEGS_FIELD = "no_legs";
34+
private static final String NO_LEGS_PREDICTION_FIELD = "no_legs_prediction";
35+
private static final String IS_PREDATOR_FIELD = "predator";
36+
private static final String IS_PREDATOR_PREDICTION_FIELD = "predator_prediction";
3337

3438
@Before
3539
public void setup() {
@@ -41,9 +45,9 @@ public void cleanup() {
4145
cleanUp();
4246
}
4347

44-
public void testEvaluate_MulticlassClassification_DefaultMetrics() {
48+
public void testEvaluate_DefaultMetrics() {
4549
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
46-
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, null));
50+
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null));
4751

4852
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
4953
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@@ -52,10 +56,10 @@ public void testEvaluate_MulticlassClassification_DefaultMetrics() {
5256
equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
5357
}
5458

55-
public void testEvaluate_MulticlassClassification_Accuracy() {
59+
public void testEvaluate_Accuracy_KeywordField() {
5660
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
5761
evaluateDataFrame(
58-
ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new Accuracy())));
62+
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
5963

6064
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
6165
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@@ -74,11 +78,50 @@ public void testEvaluate_MulticlassClassification_Accuracy() {
7478
assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75));
7579
}
7680

77-
public void testEvaluate_MulticlassClassification_AccuracyAndConfusionMatrixMetricWithDefaultSize() {
81+
public void testEvaluate_Accuracy_IntegerField() {
82+
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
83+
evaluateDataFrame(
84+
ANIMALS_DATA_INDEX, new Classification(NO_LEGS_FIELD, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
85+
86+
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
87+
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
88+
89+
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
90+
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
91+
assertThat(
92+
accuracyResult.getActualClasses(),
93+
equalTo(Arrays.asList(
94+
new Accuracy.ActualClass("1", 15, 1.0 / 15),
95+
new Accuracy.ActualClass("2", 15, 2.0 / 15),
96+
new Accuracy.ActualClass("3", 15, 3.0 / 15),
97+
new Accuracy.ActualClass("4", 15, 4.0 / 15),
98+
new Accuracy.ActualClass("5", 15, 5.0 / 15))));
99+
assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75));
100+
}
101+
102+
public void testEvaluate_Accuracy_BooleanField() {
103+
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
104+
evaluateDataFrame(
105+
ANIMALS_DATA_INDEX, new Classification(IS_PREDATOR_FIELD, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
106+
107+
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
108+
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
109+
110+
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
111+
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
112+
assertThat(
113+
accuracyResult.getActualClasses(),
114+
equalTo(Arrays.asList(
115+
new Accuracy.ActualClass("true", 45, 27.0 / 45),
116+
new Accuracy.ActualClass("false", 30, 18.0 / 30))));
117+
assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75));
118+
}
119+
120+
public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() {
78121
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
79122
evaluateDataFrame(
80123
ANIMALS_DATA_INDEX,
81-
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
124+
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
82125

83126
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
84127
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@@ -137,11 +180,11 @@ public void testEvaluate_MulticlassClassification_AccuracyAndConfusionMatrixMetr
137180
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
138181
}
139182

140-
public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize() {
183+
public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() {
141184
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
142185
evaluateDataFrame(
143186
ANIMALS_DATA_INDEX,
144-
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3))));
187+
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3))));
145188

146189
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
147190
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@@ -168,20 +211,30 @@ public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserP
168211

169212
private static void indexAnimalsData(String indexName) {
170213
client().admin().indices().prepareCreate(indexName)
171-
.addMapping("_doc", ACTUAL_CLASS_FIELD, "type=keyword", PREDICTED_CLASS_FIELD, "type=keyword")
214+
.addMapping("_doc",
215+
ANIMAL_NAME_FIELD, "type=keyword",
216+
ANIMAL_NAME_PREDICTION_FIELD, "type=keyword",
217+
NO_LEGS_FIELD, "type=integer",
218+
NO_LEGS_PREDICTION_FIELD, "type=integer",
219+
IS_PREDATOR_FIELD, "type=boolean",
220+
IS_PREDATOR_PREDICTION_FIELD, "type=boolean")
172221
.get();
173222

174-
List<String> classNames = Arrays.asList("dog", "cat", "mouse", "ant", "fox");
223+
List<String> animalNames = Arrays.asList("dog", "cat", "mouse", "ant", "fox");
175224
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
176225
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
177-
for (int i = 0; i < classNames.size(); i++) {
178-
for (int j = 0; j < classNames.size(); j++) {
226+
for (int i = 0; i < animalNames.size(); i++) {
227+
for (int j = 0; j < animalNames.size(); j++) {
179228
for (int k = 0; k < j + 1; k++) {
180229
bulkRequestBuilder.add(
181230
new IndexRequest(indexName)
182231
.source(
183-
ACTUAL_CLASS_FIELD, classNames.get(i),
184-
PREDICTED_CLASS_FIELD, classNames.get((i + j) % classNames.size())));
232+
ANIMAL_NAME_FIELD, animalNames.get(i),
233+
ANIMAL_NAME_PREDICTION_FIELD, animalNames.get((i + j) % animalNames.size()),
234+
NO_LEGS_FIELD, i + 1,
235+
NO_LEGS_PREDICTION_FIELD, j + 1,
236+
IS_PREDATOR_FIELD, i % 2 == 0,
237+
IS_PREDATOR_PREDICTION_FIELD, (i + j) % 2 == 0));
185238
}
186239
}
187240
}

0 commit comments

Comments
 (0)