|
19 | 19 | // [START aiplatform_predict_image_classification_sample] |
20 | 20 |
|
21 | 21 | import com.google.api.client.util.Base64; |
| 22 | +import com.google.cloud.aiplatform.util.ValueConverter; |
22 | 23 | import com.google.cloud.aiplatform.v1beta1.EndpointName; |
23 | 24 | import com.google.cloud.aiplatform.v1beta1.PredictResponse; |
24 | 25 | import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient; |
25 | 26 | import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings; |
| 27 | +import com.google.cloud.aiplatform.v1beta1.schema.predict.instance.ImageClassificationPredictionInstance; |
| 28 | +import com.google.cloud.aiplatform.v1beta1.schema.predict.params.ImageClassificationPredictionParams; |
| 29 | +import com.google.cloud.aiplatform.v1beta1.schema.predict.prediction.ClassificationPredictionResult; |
26 | 30 | import com.google.protobuf.Value; |
27 | | -import com.google.protobuf.util.JsonFormat; |
28 | 31 | import java.io.IOException; |
29 | 32 | import java.nio.charset.StandardCharsets; |
30 | 33 | import java.nio.file.Files; |
@@ -60,23 +63,42 @@ static void predictImageClassification(String project, String fileName, String e |
60 | 63 | byte[] contents = Base64.encodeBase64(Files.readAllBytes(Paths.get(fileName))); |
61 | 64 | String content = new String(contents, StandardCharsets.UTF_8); |
62 | 65 |
|
63 | | - Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build(); |
64 | | - |
65 | | - String contentDict = "{\"content\": \"" + content + "\"}"; |
66 | | - Value.Builder instance = Value.newBuilder(); |
67 | | - JsonFormat.parser().merge(contentDict, instance); |
| 66 | + ImageClassificationPredictionInstance predictionInstance = |
| 67 | + ImageClassificationPredictionInstance.newBuilder() |
| 68 | + .setContent(content) |
| 69 | + .build(); |
68 | 70 |
|
69 | 71 | List<Value> instances = new ArrayList<>(); |
70 | | - instances.add(instance.build()); |
| 72 | + instances.add(ValueConverter.toValue(predictionInstance)); |
| 73 | + |
| 74 | + ImageClassificationPredictionParams predictionParams = |
| 75 | + ImageClassificationPredictionParams.newBuilder() |
| 76 | + .setConfidenceThreshold((float) 0.5) |
| 77 | + .setMaxPredictions(5) |
| 78 | + .build(); |
71 | 79 |
|
72 | 80 | PredictResponse predictResponse = |
73 | | - predictionServiceClient.predict(endpointName, instances, parameter); |
| 81 | + predictionServiceClient.predict(endpointName, instances, |
| 82 | + ValueConverter.toValue(predictionParams)); |
74 | 83 | System.out.println("Predict Image Classification Response"); |
75 | 84 | System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId()); |
76 | 85 |
|
77 | 86 | System.out.println("Predictions"); |
78 | 87 | for (Value prediction : predictResponse.getPredictionsList()) { |
79 | | - System.out.format("\tPrediction: %s\n", prediction); |
| 88 | + |
| 89 | + ClassificationPredictionResult.Builder resultBuilder = |
| 90 | + ClassificationPredictionResult.newBuilder(); |
| 91 | + // Display names and confidences values correspond to |
| 92 | + // IDs in the ID list. |
| 93 | + ClassificationPredictionResult result = |
| 94 | + (ClassificationPredictionResult) ValueConverter.fromValue(resultBuilder, prediction); |
| 95 | + int counter = 0; |
| 96 | + for (Long id : result.getIdsList()) { |
| 97 | + System.out.printf("Label ID: %d\n", id); |
| 98 | + System.out.printf("Label: %s\n", result.getDisplayNames(counter)); |
| 99 | + System.out.printf("Confidence: %.4f\n", result.getConfidences(counter)); |
| 100 | + counter++; |
| 101 | + } |
80 | 102 | } |
81 | 103 | } |
82 | 104 | } |
|
0 commit comments