Skip to content

Commit 1bb518b

Browse files
feat: adds ValueConverter utility and demo samples (#108)
* feat: adds value converter utility class and demo samples * feat: samples updated for EJCL * fix: removed local file references * feat: adds ValueConverter tests Co-authored-by: yoshi-code-bot <[email protected]>
1 parent d6602ce commit 1bb518b

File tree

4 files changed

+68
-30
lines changed

4 files changed

+68
-30
lines changed

aiplatform/snippets/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
<dependency>
2828
<groupId>com.google.cloud</groupId>
2929
<artifactId>google-cloud-aiplatform</artifactId>
30-
<version>0.1.0</version>
30+
<version>0.1.1-SNAPSHOT</version>
3131
</dependency>
3232
<!-- [END aiplatform_install_with_bom] -->
3333
<dependency>

aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
package aiplatform;
1818

1919
// [START aiplatform_create_training_pipeline_image_classification_sample]
20-
20+
import com.google.cloud.aiplatform.util.ValueConverter;
2121
import com.google.cloud.aiplatform.v1beta1.DeployedModelRef;
2222
import com.google.cloud.aiplatform.v1beta1.EnvVar;
2323
import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata;
@@ -38,8 +38,8 @@
3838
import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
3939
import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
4040
import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
41-
import com.google.protobuf.Value;
42-
import com.google.protobuf.util.JsonFormat;
41+
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationInputs;
42+
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationInputs.ModelType;
4343
import com.google.rpc.Status;
4444
import java.io.IOException;
4545

@@ -74,11 +74,13 @@ static void createTrainingPipelineImageClassificationSample(
7474
+ "automl_image_classification_1.0.0.yaml";
7575
LocationName locationName = LocationName.of(project, location);
7676

77-
String jsonString =
78-
"{\"multiLabel\": false, \"modelType\": \"CLOUD\", \"budgetMilliNodeHours\": 8000,"
79-
+ " \"disableEarlyStopping\": false}";
80-
Value.Builder trainingTaskInputs = Value.newBuilder();
81-
JsonFormat.parser().merge(jsonString, trainingTaskInputs);
77+
AutoMlImageClassificationInputs autoMlImageClassificationInputs =
78+
AutoMlImageClassificationInputs.newBuilder()
79+
.setModelType(ModelType.CLOUD)
80+
.setMultiLabel(false)
81+
.setBudgetMilliNodeHours(8000)
82+
.setDisableEarlyStopping(false)
83+
.build();
8284

8385
InputDataConfig trainingInputDataConfig =
8486
InputDataConfig.newBuilder().setDatasetId(datasetId).build();
@@ -87,7 +89,7 @@ static void createTrainingPipelineImageClassificationSample(
8789
TrainingPipeline.newBuilder()
8890
.setDisplayName(trainingPipelineDisplayName)
8991
.setTrainingTaskDefinition(trainingTaskDefinition)
90-
.setTrainingTaskInputs(trainingTaskInputs)
92+
.setTrainingTaskInputs(ValueConverter.toValue(autoMlImageClassificationInputs))
9193
.setInputDataConfig(trainingInputDataConfig)
9294
.setModelToUpload(model)
9395
.build();

aiplatform/snippets/src/main/java/aiplatform/PredictImageClassificationSample.java

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,15 @@
1919
// [START aiplatform_predict_image_classification_sample]
2020

2121
import com.google.api.client.util.Base64;
22+
import com.google.cloud.aiplatform.util.ValueConverter;
2223
import com.google.cloud.aiplatform.v1beta1.EndpointName;
2324
import com.google.cloud.aiplatform.v1beta1.PredictResponse;
2425
import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
2526
import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
27+
import com.google.cloud.aiplatform.v1beta1.schema.predict.instance.ImageClassificationPredictionInstance;
28+
import com.google.cloud.aiplatform.v1beta1.schema.predict.params.ImageClassificationPredictionParams;
29+
import com.google.cloud.aiplatform.v1beta1.schema.predict.prediction.ClassificationPredictionResult;
2630
import com.google.protobuf.Value;
27-
import com.google.protobuf.util.JsonFormat;
2831
import java.io.IOException;
2932
import java.nio.charset.StandardCharsets;
3033
import java.nio.file.Files;
@@ -60,23 +63,42 @@ static void predictImageClassification(String project, String fileName, String e
6063
byte[] contents = Base64.encodeBase64(Files.readAllBytes(Paths.get(fileName)));
6164
String content = new String(contents, StandardCharsets.UTF_8);
6265

63-
Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build();
64-
65-
String contentDict = "{\"content\": \"" + content + "\"}";
66-
Value.Builder instance = Value.newBuilder();
67-
JsonFormat.parser().merge(contentDict, instance);
66+
ImageClassificationPredictionInstance predictionInstance =
67+
ImageClassificationPredictionInstance.newBuilder()
68+
.setContent(content)
69+
.build();
6870

6971
List<Value> instances = new ArrayList<>();
70-
instances.add(instance.build());
72+
instances.add(ValueConverter.toValue(predictionInstance));
73+
74+
ImageClassificationPredictionParams predictionParams =
75+
ImageClassificationPredictionParams.newBuilder()
76+
.setConfidenceThreshold((float) 0.5)
77+
.setMaxPredictions(5)
78+
.build();
7179

7280
PredictResponse predictResponse =
73-
predictionServiceClient.predict(endpointName, instances, parameter);
81+
predictionServiceClient.predict(endpointName, instances,
82+
ValueConverter.toValue(predictionParams));
7483
System.out.println("Predict Image Classification Response");
7584
System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId());
7685

7786
System.out.println("Predictions");
7887
for (Value prediction : predictResponse.getPredictionsList()) {
79-
System.out.format("\tPrediction: %s\n", prediction);
88+
89+
ClassificationPredictionResult.Builder resultBuilder =
90+
ClassificationPredictionResult.newBuilder();
91+
// Display names and confidences values correspond to
92+
// IDs in the ID list.
93+
ClassificationPredictionResult result =
94+
(ClassificationPredictionResult) ValueConverter.fromValue(resultBuilder, prediction);
95+
int counter = 0;
96+
for (Long id : result.getIdsList()) {
97+
System.out.printf("Label ID: %d\n", id);
98+
System.out.printf("Label: %s\n", result.getDisplayNames(counter));
99+
System.out.printf("Confidence: %.4f\n", result.getConfidences(counter));
100+
counter++;
101+
}
80102
}
81103
}
82104
}

aiplatform/snippets/src/main/java/aiplatform/PredictTextClassificationSingleLabelSample.java

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717
package aiplatform;
1818

1919
// [START aiplatform_predict_text_classification_sample]
20-
20+
import com.google.cloud.aiplatform.util.ValueConverter;
2121
import com.google.cloud.aiplatform.v1beta1.EndpointName;
2222
import com.google.cloud.aiplatform.v1beta1.PredictResponse;
2323
import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
2424
import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
25+
import com.google.cloud.aiplatform.v1beta1.schema.predict.instance.TextClassificationPredictionInstance;
26+
import com.google.cloud.aiplatform.v1beta1.schema.predict.prediction.ClassificationPredictionResult;
2527
import com.google.protobuf.Value;
26-
import com.google.protobuf.util.JsonFormat;
2728
import java.io.IOException;
2829
import java.util.ArrayList;
2930
import java.util.List;
@@ -52,25 +53,38 @@ static void predictTextClassificationSingleLabel(
5253
try (PredictionServiceClient predictionServiceClient =
5354
PredictionServiceClient.create(predictionServiceSettings)) {
5455
String location = "us-central1";
55-
String jsonString = "{\"content\": \"" + content + "\"}";
56-
5756
EndpointName endpointName = EndpointName.of(project, location, endpointId);
5857

59-
Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build();
60-
Value.Builder instance = Value.newBuilder();
61-
JsonFormat.parser().merge(jsonString, instance);
58+
TextClassificationPredictionInstance predictionInstance = TextClassificationPredictionInstance
59+
.newBuilder()
60+
.setContent(content)
61+
.build();
6262

6363
List<Value> instances = new ArrayList<>();
64-
instances.add(instance.build());
64+
instances.add(ValueConverter.toValue(predictionInstance));
6565

6666
PredictResponse predictResponse =
67-
predictionServiceClient.predict(endpointName, instances, parameter);
67+
predictionServiceClient.predict(endpointName, instances, ValueConverter.EMPTY_VALUE);
6868
System.out.println("Predict Text Classification Response");
6969
System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId());
7070

71-
System.out.println("Predictions");
71+
System.out.println("Predictions:\n\n");
7272
for (Value prediction : predictResponse.getPredictionsList()) {
73-
System.out.format("\tPrediction: %s\n", prediction);
73+
74+
ClassificationPredictionResult.Builder resultBuilder =
75+
ClassificationPredictionResult.newBuilder();
76+
77+
// Display names and confidences values correspond to
78+
// IDs in the ID list.
79+
ClassificationPredictionResult result =
80+
(ClassificationPredictionResult) ValueConverter.fromValue(resultBuilder, prediction);
81+
int counter = 0;
82+
for (Long id : result.getIdsList()) {
83+
System.out.printf("Label ID: %d\n", id);
84+
System.out.printf("Label: %s\n", result.getDisplayNames(counter));
85+
System.out.printf("Confidence: %.4f\n", result.getConfidences(counter));
86+
counter++;
87+
}
7488
}
7589
}
7690
}

0 commit comments

Comments
 (0)