Skip to content

Commit c9f50da

Browse files
dafrizmarkpollack
authored andcommitted
Add metadata to BedrockAnthropic3ChatModel response
This commit enhances the Bedrock Anthropic model's output: - Add response ID, model name, and usage data to ChatResponseMetadata - Introduce DefaultUsage class for token usage information - Update BedrockAnthropic3ChatModel to include new metadata - Add Jackson annotations for serialization/deserialization - Implement unit tests for DefaultUsage These changes provide structured, serializable metadata in the ChatResponse, improving the model's output with additional information.
1 parent 6d38c85 commit c9f50da

File tree

3 files changed

+246
-4
lines changed

3 files changed

+246
-4
lines changed

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,14 @@
1818
import java.util.ArrayList;
1919
import java.util.Base64;
2020
import java.util.List;
21-
import java.util.Map;
2221
import java.util.concurrent.atomic.AtomicReference;
2322
import java.util.stream.Collectors;
2423

24+
import org.springframework.ai.chat.messages.AssistantMessage;
2525
import org.springframework.ai.chat.messages.UserMessage;
26+
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
27+
import org.springframework.ai.chat.metadata.DefaultUsage;
28+
import org.springframework.ai.chat.metadata.Usage;
2629
import reactor.core.publisher.Flux;
2730

2831
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi;
@@ -82,11 +85,17 @@ public ChatResponse call(Prompt prompt) {
8285
AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request);
8386

8487
List<Generation> generations = response.content().stream().map(content -> {
85-
return new Generation(content.text(), Map.of())
86-
.withGenerationMetadata(ChatGenerationMetadata.from(response.stopReason(), null));
88+
return new Generation(new AssistantMessage(content.text()),
89+
ChatGenerationMetadata.from(response.stopReason(), null));
8790
}).toList();
8891

89-
return new ChatResponse(generations);
92+
ChatResponseMetadata metadata = ChatResponseMetadata.builder()
93+
.withId(response.id())
94+
.withModel(response.model())
95+
.withUsage(extractUsage(response))
96+
.build();
97+
98+
return new ChatResponse(generations, metadata);
9099
}
91100

92101
@Override
@@ -116,6 +125,11 @@ public Flux<ChatResponse> stream(Prompt prompt) {
116125
});
117126
}
118127

128+
protected Usage extractUsage(AnthropicChatResponse response) {
129+
return new DefaultUsage(response.usage().inputTokens().longValue(),
130+
response.usage().outputTokens().longValue());
131+
}
132+
119133
/**
120134
* Accessible for testing.
121135
*/
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
* Copyright 2024 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.chat.metadata;
17+
18+
import com.fasterxml.jackson.annotation.JsonCreator;
19+
import com.fasterxml.jackson.annotation.JsonProperty;
20+
21+
import java.util.Objects;
22+
23+
/**
24+
* Default implementation of the {@link Usage} interface.
25+
*
26+
* @author Mark Pollack
27+
* @since 1.0.0
28+
*/
29+
public class DefaultUsage implements Usage {
30+
31+
private final Long promptTokens;
32+
33+
private final Long generationTokens;
34+
35+
private final Long totalTokens;
36+
37+
/**
38+
* Create a new DefaultUsage with promptTokens and generationTokens.
39+
* @param promptTokens the number of tokens in the prompt, or {@code null} if not
40+
* available
41+
* @param generationTokens the number of tokens in the generation, or {@code null} if
42+
* not available
43+
*/
44+
public DefaultUsage(Long promptTokens, Long generationTokens) {
45+
this(promptTokens, generationTokens, null);
46+
}
47+
48+
/**
49+
* Create a new DefaultUsage with promptTokens, generationTokens, and totalTokens.
50+
* @param promptTokens the number of tokens in the prompt, or {@code null} if not
51+
* available
52+
* @param generationTokens the number of tokens in the generation, or {@code null} if
53+
* not available
54+
* @param totalTokens the total number of tokens, or {@code null} to calculate from
55+
* promptTokens and generationTokens
56+
*/
57+
@JsonCreator
58+
public DefaultUsage(@JsonProperty("promptTokens") Long promptTokens,
59+
@JsonProperty("generationTokens") Long generationTokens, @JsonProperty("totalTokens") Long totalTokens) {
60+
this.promptTokens = promptTokens != null ? promptTokens : 0L;
61+
this.generationTokens = generationTokens != null ? generationTokens : 0L;
62+
this.totalTokens = totalTokens != null ? totalTokens
63+
: calculateTotalTokens(this.promptTokens, this.generationTokens);
64+
}
65+
66+
@Override
67+
@JsonProperty("promptTokens")
68+
public Long getPromptTokens() {
69+
return promptTokens;
70+
}
71+
72+
@Override
73+
@JsonProperty("generationTokens")
74+
public Long getGenerationTokens() {
75+
return generationTokens;
76+
}
77+
78+
@Override
79+
@JsonProperty("totalTokens")
80+
public Long getTotalTokens() {
81+
return totalTokens;
82+
}
83+
84+
private Long calculateTotalTokens(Long promptTokens, Long generationTokens) {
85+
return promptTokens + generationTokens;
86+
}
87+
88+
@Override
89+
public boolean equals(Object o) {
90+
if (this == o)
91+
return true;
92+
if (o == null || getClass() != o.getClass())
93+
return false;
94+
DefaultUsage that = (DefaultUsage) o;
95+
return Objects.equals(promptTokens, that.promptTokens)
96+
&& Objects.equals(generationTokens, that.generationTokens)
97+
&& Objects.equals(totalTokens, that.totalTokens);
98+
}
99+
100+
@Override
101+
public int hashCode() {
102+
return Objects.hash(promptTokens, generationTokens, totalTokens);
103+
}
104+
105+
@Override
106+
public String toString() {
107+
return "DefaultUsage{" + "promptTokens=" + promptTokens + ", generationTokens=" + generationTokens
108+
+ ", totalTokens=" + totalTokens + '}';
109+
}
110+
111+
}
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Copyright 2024 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.chat.metadata;
17+
18+
import com.fasterxml.jackson.databind.ObjectMapper;
19+
import org.junit.jupiter.api.Test;
20+
import static org.junit.jupiter.api.Assertions.*;
21+
22+
public class DefaultUsageTests {
23+
24+
private final ObjectMapper objectMapper = new ObjectMapper();
25+
26+
@Test
27+
void testSerializationWithAllFields() throws Exception {
28+
DefaultUsage usage = new DefaultUsage(100L, 50L, 150L);
29+
String json = objectMapper.writeValueAsString(usage);
30+
assertEquals("{\"promptTokens\":100,\"generationTokens\":50,\"totalTokens\":150}", json);
31+
}
32+
33+
@Test
34+
void testDeserializationWithAllFields() throws Exception {
35+
String json = "{\"promptTokens\":100,\"generationTokens\":50,\"totalTokens\":150}";
36+
DefaultUsage usage = objectMapper.readValue(json, DefaultUsage.class);
37+
assertEquals(100L, usage.getPromptTokens());
38+
assertEquals(50L, usage.getGenerationTokens());
39+
assertEquals(150L, usage.getTotalTokens());
40+
}
41+
42+
@Test
43+
void testSerializationWithNullFields() throws Exception {
44+
DefaultUsage usage = new DefaultUsage(null, null, null);
45+
String json = objectMapper.writeValueAsString(usage);
46+
assertEquals("{\"promptTokens\":0,\"generationTokens\":0,\"totalTokens\":0}", json);
47+
}
48+
49+
@Test
50+
void testDeserializationWithMissingFields() throws Exception {
51+
String json = "{\"promptTokens\":100}";
52+
DefaultUsage usage = objectMapper.readValue(json, DefaultUsage.class);
53+
assertEquals(100L, usage.getPromptTokens());
54+
assertEquals(0L, usage.getGenerationTokens());
55+
assertEquals(100L, usage.getTotalTokens());
56+
}
57+
58+
@Test
59+
void testDeserializationWithNullFields() throws Exception {
60+
String json = "{\"promptTokens\":null,\"generationTokens\":null,\"totalTokens\":null}";
61+
DefaultUsage usage = objectMapper.readValue(json, DefaultUsage.class);
62+
assertEquals(0L, usage.getPromptTokens());
63+
assertEquals(0L, usage.getGenerationTokens());
64+
assertEquals(0L, usage.getTotalTokens());
65+
}
66+
67+
@Test
68+
void testRoundTripSerialization() throws Exception {
69+
DefaultUsage original = new DefaultUsage(100L, 50L, 150L);
70+
String json = objectMapper.writeValueAsString(original);
71+
DefaultUsage deserialized = objectMapper.readValue(json, DefaultUsage.class);
72+
assertEquals(original.getPromptTokens(), deserialized.getPromptTokens());
73+
assertEquals(original.getGenerationTokens(), deserialized.getGenerationTokens());
74+
assertEquals(original.getTotalTokens(), deserialized.getTotalTokens());
75+
}
76+
77+
@Test
78+
void testTwoArgumentConstructorAndSerialization() throws Exception {
79+
DefaultUsage usage = new DefaultUsage(100L, 50L);
80+
81+
// Test that the fields are set correctly
82+
assertEquals(100L, usage.getPromptTokens());
83+
assertEquals(50L, usage.getGenerationTokens());
84+
assertEquals(150L, usage.getTotalTokens()); // 100 + 50 = 150
85+
86+
// Test serialization
87+
String json = objectMapper.writeValueAsString(usage);
88+
assertEquals("{\"promptTokens\":100,\"generationTokens\":50,\"totalTokens\":150}", json);
89+
90+
// Test deserialization
91+
DefaultUsage deserializedUsage = objectMapper.readValue(json, DefaultUsage.class);
92+
assertEquals(100L, deserializedUsage.getPromptTokens());
93+
assertEquals(50L, deserializedUsage.getGenerationTokens());
94+
assertEquals(150L, deserializedUsage.getTotalTokens());
95+
}
96+
97+
@Test
98+
void testTwoArgumentConstructorWithNullValues() throws Exception {
99+
DefaultUsage usage = new DefaultUsage(null, null);
100+
101+
// Test that null values are converted to 0
102+
assertEquals(0L, usage.getPromptTokens());
103+
assertEquals(0L, usage.getGenerationTokens());
104+
assertEquals(0L, usage.getTotalTokens());
105+
106+
// Test serialization
107+
String json = objectMapper.writeValueAsString(usage);
108+
assertEquals("{\"promptTokens\":0,\"generationTokens\":0,\"totalTokens\":0}", json);
109+
110+
// Test deserialization
111+
DefaultUsage deserializedUsage = objectMapper.readValue(json, DefaultUsage.class);
112+
assertEquals(0L, deserializedUsage.getPromptTokens());
113+
assertEquals(0L, deserializedUsage.getGenerationTokens());
114+
assertEquals(0L, deserializedUsage.getTotalTokens());
115+
}
116+
117+
}

0 commit comments

Comments
 (0)