@@ -28,8 +28,12 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
2828
2929 private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index" ;
3030
31- private static final String ACTUAL_CLASS_FIELD = "actual_class_field" ;
32- private static final String PREDICTED_CLASS_FIELD = "predicted_class_field" ;
31+ private static final String ANIMAL_NAME_FIELD = "animal_name" ;
32+ private static final String ANIMAL_NAME_PREDICTION_FIELD = "animal_name_prediction" ;
33+ private static final String NO_LEGS_FIELD = "no_legs" ;
34+ private static final String NO_LEGS_PREDICTION_FIELD = "no_legs_prediction" ;
35+ private static final String IS_PREDATOR_FIELD = "predator" ;
36+ private static final String IS_PREDATOR_PREDICTION_FIELD = "predator_prediction" ;
3337
3438 @ Before
3539 public void setup () {
@@ -41,9 +45,9 @@ public void cleanup() {
4145 cleanUp ();
4246 }
4347
44- public void testEvaluate_MulticlassClassification_DefaultMetrics () {
48+ public void testEvaluate_DefaultMetrics () {
4549 EvaluateDataFrameAction .Response evaluateDataFrameResponse =
46- evaluateDataFrame (ANIMALS_DATA_INDEX , new Classification (ACTUAL_CLASS_FIELD , PREDICTED_CLASS_FIELD , null ));
50+ evaluateDataFrame (ANIMALS_DATA_INDEX , new Classification (ANIMAL_NAME_FIELD , ANIMAL_NAME_PREDICTION_FIELD , null ));
4751
4852 assertThat (evaluateDataFrameResponse .getEvaluationName (), equalTo (Classification .NAME .getPreferredName ()));
4953 assertThat (evaluateDataFrameResponse .getMetrics (), hasSize (1 ));
@@ -52,10 +56,10 @@ public void testEvaluate_MulticlassClassification_DefaultMetrics() {
5256 equalTo (MulticlassConfusionMatrix .NAME .getPreferredName ()));
5357 }
5458
55- public void testEvaluate_MulticlassClassification_Accuracy () {
59+ public void testEvaluate_Accuracy_KeywordField () {
5660 EvaluateDataFrameAction .Response evaluateDataFrameResponse =
5761 evaluateDataFrame (
58- ANIMALS_DATA_INDEX , new Classification (ACTUAL_CLASS_FIELD , PREDICTED_CLASS_FIELD , Arrays .asList (new Accuracy ())));
62+ ANIMALS_DATA_INDEX , new Classification (ANIMAL_NAME_FIELD , ANIMAL_NAME_PREDICTION_FIELD , Arrays .asList (new Accuracy ())));
5963
6064 assertThat (evaluateDataFrameResponse .getEvaluationName (), equalTo (Classification .NAME .getPreferredName ()));
6165 assertThat (evaluateDataFrameResponse .getMetrics (), hasSize (1 ));
@@ -74,11 +78,50 @@ public void testEvaluate_MulticlassClassification_Accuracy() {
7478 assertThat (accuracyResult .getOverallAccuracy (), equalTo (5.0 / 75 ));
7579 }
7680
77- public void testEvaluate_MulticlassClassification_AccuracyAndConfusionMatrixMetricWithDefaultSize () {
81+ public void testEvaluate_Accuracy_IntegerField () {
82+ EvaluateDataFrameAction .Response evaluateDataFrameResponse =
83+ evaluateDataFrame (
84+ ANIMALS_DATA_INDEX , new Classification (NO_LEGS_FIELD , NO_LEGS_PREDICTION_FIELD , Arrays .asList (new Accuracy ())));
85+
86+ assertThat (evaluateDataFrameResponse .getEvaluationName (), equalTo (Classification .NAME .getPreferredName ()));
87+ assertThat (evaluateDataFrameResponse .getMetrics (), hasSize (1 ));
88+
89+ Accuracy .Result accuracyResult = (Accuracy .Result ) evaluateDataFrameResponse .getMetrics ().get (0 );
90+ assertThat (accuracyResult .getMetricName (), equalTo (Accuracy .NAME .getPreferredName ()));
91+ assertThat (
92+ accuracyResult .getActualClasses (),
93+ equalTo (Arrays .asList (
94+ new Accuracy .ActualClass ("1" , 15 , 1.0 / 15 ),
95+ new Accuracy .ActualClass ("2" , 15 , 2.0 / 15 ),
96+ new Accuracy .ActualClass ("3" , 15 , 3.0 / 15 ),
97+ new Accuracy .ActualClass ("4" , 15 , 4.0 / 15 ),
98+ new Accuracy .ActualClass ("5" , 15 , 5.0 / 15 ))));
99+ assertThat (accuracyResult .getOverallAccuracy (), equalTo (15.0 / 75 ));
100+ }
101+
102+ public void testEvaluate_Accuracy_BooleanField () {
103+ EvaluateDataFrameAction .Response evaluateDataFrameResponse =
104+ evaluateDataFrame (
105+ ANIMALS_DATA_INDEX , new Classification (IS_PREDATOR_FIELD , IS_PREDATOR_PREDICTION_FIELD , Arrays .asList (new Accuracy ())));
106+
107+ assertThat (evaluateDataFrameResponse .getEvaluationName (), equalTo (Classification .NAME .getPreferredName ()));
108+ assertThat (evaluateDataFrameResponse .getMetrics (), hasSize (1 ));
109+
110+ Accuracy .Result accuracyResult = (Accuracy .Result ) evaluateDataFrameResponse .getMetrics ().get (0 );
111+ assertThat (accuracyResult .getMetricName (), equalTo (Accuracy .NAME .getPreferredName ()));
112+ assertThat (
113+ accuracyResult .getActualClasses (),
114+ equalTo (Arrays .asList (
115+ new Accuracy .ActualClass ("true" , 45 , 27.0 / 45 ),
116+ new Accuracy .ActualClass ("false" , 30 , 18.0 / 30 ))));
117+ assertThat (accuracyResult .getOverallAccuracy (), equalTo (45.0 / 75 ));
118+ }
119+
120+ public void testEvaluate_ConfusionMatrixMetricWithDefaultSize () {
78121 EvaluateDataFrameAction .Response evaluateDataFrameResponse =
79122 evaluateDataFrame (
80123 ANIMALS_DATA_INDEX ,
81- new Classification (ACTUAL_CLASS_FIELD , PREDICTED_CLASS_FIELD , Arrays .asList (new MulticlassConfusionMatrix ())));
124+ new Classification (ANIMAL_NAME_FIELD , ANIMAL_NAME_PREDICTION_FIELD , Arrays .asList (new MulticlassConfusionMatrix ())));
82125
83126 assertThat (evaluateDataFrameResponse .getEvaluationName (), equalTo (Classification .NAME .getPreferredName ()));
84127 assertThat (evaluateDataFrameResponse .getMetrics (), hasSize (1 ));
@@ -137,11 +180,11 @@ public void testEvaluate_MulticlassClassification_AccuracyAndConfusionMatrixMetr
137180 assertThat (confusionMatrixResult .getOtherActualClassCount (), equalTo (0L ));
138181 }
139182
140- public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize () {
183+ public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize () {
141184 EvaluateDataFrameAction .Response evaluateDataFrameResponse =
142185 evaluateDataFrame (
143186 ANIMALS_DATA_INDEX ,
144- new Classification (ACTUAL_CLASS_FIELD , PREDICTED_CLASS_FIELD , Arrays .asList (new MulticlassConfusionMatrix (3 ))));
187+ new Classification (ANIMAL_NAME_FIELD , ANIMAL_NAME_PREDICTION_FIELD , Arrays .asList (new MulticlassConfusionMatrix (3 ))));
145188
146189 assertThat (evaluateDataFrameResponse .getEvaluationName (), equalTo (Classification .NAME .getPreferredName ()));
147190 assertThat (evaluateDataFrameResponse .getMetrics (), hasSize (1 ));
@@ -168,20 +211,30 @@ public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserP
168211
169212 private static void indexAnimalsData (String indexName ) {
170213 client ().admin ().indices ().prepareCreate (indexName )
171- .addMapping ("_doc" , ACTUAL_CLASS_FIELD , "type=keyword" , PREDICTED_CLASS_FIELD , "type=keyword" )
214+ .addMapping ("_doc" ,
215+ ANIMAL_NAME_FIELD , "type=keyword" ,
216+ ANIMAL_NAME_PREDICTION_FIELD , "type=keyword" ,
217+ NO_LEGS_FIELD , "type=integer" ,
218+ NO_LEGS_PREDICTION_FIELD , "type=integer" ,
219+ IS_PREDATOR_FIELD , "type=boolean" ,
220+ IS_PREDATOR_PREDICTION_FIELD , "type=boolean" )
172221 .get ();
173222
174- List <String > classNames = Arrays .asList ("dog" , "cat" , "mouse" , "ant" , "fox" );
223+ List <String > animalNames = Arrays .asList ("dog" , "cat" , "mouse" , "ant" , "fox" );
175224 BulkRequestBuilder bulkRequestBuilder = client ().prepareBulk ()
176225 .setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE );
177- for (int i = 0 ; i < classNames .size (); i ++) {
178- for (int j = 0 ; j < classNames .size (); j ++) {
226+ for (int i = 0 ; i < animalNames .size (); i ++) {
227+ for (int j = 0 ; j < animalNames .size (); j ++) {
179228 for (int k = 0 ; k < j + 1 ; k ++) {
180229 bulkRequestBuilder .add (
181230 new IndexRequest (indexName )
182231 .source (
183- ACTUAL_CLASS_FIELD , classNames .get (i ),
184- PREDICTED_CLASS_FIELD , classNames .get ((i + j ) % classNames .size ())));
232+ ANIMAL_NAME_FIELD , animalNames .get (i ),
233+ ANIMAL_NAME_PREDICTION_FIELD , animalNames .get ((i + j ) % animalNames .size ()),
234+ NO_LEGS_FIELD , i + 1 ,
235+ NO_LEGS_PREDICTION_FIELD , j + 1 ,
236+ IS_PREDATOR_FIELD , i % 2 == 0 ,
237+ IS_PREDATOR_PREDICTION_FIELD , (i + j ) % 2 == 0 ));
185238 }
186239 }
187240 }
0 commit comments