4040import org .springframework .ai .chat .model .ChatResponse ;
4141import org .springframework .ai .chat .model .Generation ;
4242import org .springframework .ai .chat .model .StreamingChatModel ;
43- import org .springframework .ai .chat .prompt .ChatOptions ;
4443import org .springframework .ai .chat .prompt .Prompt ;
4544import org .springframework .ai .chat .prompt .PromptTemplate ;
4645import org .springframework .ai .chat .prompt .SystemPromptTemplate ;
@@ -85,13 +84,22 @@ class ZhiPuAiChatModelIT {
8584 @ Value ("classpath:/prompts/system-message.st" )
8685 private Resource systemResource ;
8786
87+ /**
88+ * Default chat options to use for the tests.
89+ * <p>
90+ * glm-4-flash is a free model, so it is used by default on the tests.
91+ */
92+ private static final ZhiPuAiChatOptions DEFAULT_CHAT_OPTIONS = ZhiPuAiChatOptions .builder ()
93+ .model (ZhiPuAiApi .ChatModel .GLM_4_Flash .getValue ())
94+ .build ();
95+
8896 @ Test
8997 void roleTest () {
9098 UserMessage userMessage = new UserMessage (
9199 "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did." );
92100 SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate (this .systemResource );
93101 Message systemMessage = systemPromptTemplate .createMessage (Map .of ("name" , "Bob" , "voice" , "pirate" ));
94- Prompt prompt = new Prompt (List .of (userMessage , systemMessage ), ChatOptions . builder (). build () );
102+ Prompt prompt = new Prompt (List .of (userMessage , systemMessage ), DEFAULT_CHAT_OPTIONS );
95103 ChatResponse response = this .chatModel .call (prompt );
96104 assertThat (response .getResults ()).hasSize (1 );
97105 assertThat (response .getResults ().get (0 ).getOutput ().getText ()).contains ("Blackbeard" );
@@ -104,7 +112,7 @@ void streamRoleTest() {
104112 "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did." );
105113 SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate (this .systemResource );
106114 Message systemMessage = systemPromptTemplate .createMessage (Map .of ("name" , "Bob" , "voice" , "pirate" ));
107- Prompt prompt = new Prompt (List .of (userMessage , systemMessage ));
115+ Prompt prompt = new Prompt (List .of (userMessage , systemMessage ), DEFAULT_CHAT_OPTIONS );
108116 Flux <ChatResponse > flux = this .streamingChatModel .stream (prompt );
109117
110118 List <ChatResponse > responses = flux .collectList ().block ();
@@ -135,7 +143,7 @@ void listOutputConverter() {
135143 .template (template )
136144 .variables (Map .of ("subject" , "ice cream flavors" , "format" , format ))
137145 .build ();
138- Prompt prompt = new Prompt (promptTemplate .createMessage (), ChatOptions . builder (). build () );
146+ Prompt prompt = new Prompt (promptTemplate .createMessage (), DEFAULT_CHAT_OPTIONS );
139147 Generation generation = this .chatModel .call (prompt ).getResult ();
140148
141149 List <String > list = outputConverter .convert (generation .getOutput ().getText ());
@@ -157,8 +165,9 @@ void mapOutputConverter() {
157165 .variables (Map .of ("subject" , "an array of numbers from 1 to 9 under they key name 'numbers'" , "format" ,
158166 format ))
159167 .build ();
160- Prompt prompt = new Prompt (promptTemplate .createMessage (), ChatOptions .builder ().build ());
161- Generation generation = this .chatModel .call (prompt ).getResult ();
168+ Prompt prompt = new Prompt (promptTemplate .createMessage (), DEFAULT_CHAT_OPTIONS );
169+ ChatResponse chatResponse = this .chatModel .call (prompt );
170+ Generation generation = chatResponse .getResult ();
162171
163172 Map <String , Object > result = outputConverter .convert (generation .getOutput ().getText ());
164173 assertThat (result .get ("numbers" )).isEqualTo (Arrays .asList (1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ));
@@ -179,7 +188,7 @@ void beanOutputConverter() {
179188 .template (template )
180189 .variables (Map .of ("format" , format ))
181190 .build ();
182- Prompt prompt = new Prompt (promptTemplate .createMessage (), ChatOptions . builder (). build () );
191+ Prompt prompt = new Prompt (promptTemplate .createMessage (), DEFAULT_CHAT_OPTIONS );
183192 Generation generation = this .chatModel .call (prompt ).getResult ();
184193
185194 ActorsFilms actorsFilms = outputConverter .convert (generation .getOutput ().getText ());
@@ -198,7 +207,7 @@ void beanOutputConverterRecords() {
198207 .template (template )
199208 .variables (Map .of ("format" , format ))
200209 .build ();
201- Prompt prompt = new Prompt (promptTemplate .createMessage (), ChatOptions . builder (). build () );
210+ Prompt prompt = new Prompt (promptTemplate .createMessage (), DEFAULT_CHAT_OPTIONS );
202211 Generation generation = this .chatModel .call (prompt ).getResult ();
203212
204213 ActorsFilmsRecord actorsFilms = outputConverter .convert (generation .getOutput ().getText ());
@@ -221,7 +230,7 @@ void beanStreamOutputConverterRecords() {
221230 .template (template )
222231 .variables (Map .of ("format" , format ))
223232 .build ();
224- Prompt prompt = new Prompt (promptTemplate .createMessage ());
233+ Prompt prompt = new Prompt (promptTemplate .createMessage (), DEFAULT_CHAT_OPTIONS );
225234
226235 String generationTextFromStream = Objects
227236 .requireNonNull (this .streamingChatModel .stream (prompt ).collectList ().block ())
@@ -253,7 +262,10 @@ void jsonObjectResponseFormatOutputConverterRecords() {
253262 .variables (Map .of ("format" , format ))
254263 .build ();
255264 Prompt prompt = new Prompt (promptTemplate .createMessage (),
256- ZhiPuAiChatOptions .builder ().responseFormat (ChatCompletionRequest .ResponseFormat .jsonObject ()).build ());
265+ ZhiPuAiChatOptions .builder ()
266+ .model (ZhiPuAiApi .ChatModel .GLM_4_Flash .getValue ())
267+ .responseFormat (ChatCompletionRequest .ResponseFormat .jsonObject ())
268+ .build ());
257269
258270 String generationTextFromStream = Objects
259271 .requireNonNull (this .streamingChatModel .stream (prompt ).collectList ().block ())
@@ -281,7 +293,7 @@ void functionCallTest() {
281293 List <Message > messages = new ArrayList <>(List .of (userMessage ));
282294
283295 var promptOptions = ZhiPuAiChatOptions .builder ()
284- .model (ZhiPuAiApi .ChatModel .GLM_4 .getValue ())
296+ .model (ZhiPuAiApi .ChatModel .GLM_4_Flash .getValue ())
285297 .toolCallbacks (List .of (FunctionToolCallback .builder ("getCurrentWeather" , new MockWeatherService ())
286298 .description ("Get the weather in location" )
287299 .inputType (MockWeatherService .Request .class )
@@ -306,7 +318,7 @@ void streamFunctionCallTest() {
306318 List <Message > messages = new ArrayList <>(List .of (userMessage ));
307319
308320 var promptOptions = ZhiPuAiChatOptions .builder ()
309- .model (ZhiPuAiApi .ChatModel .GLM_4 .getValue ())
321+ .model (ZhiPuAiApi .ChatModel .GLM_4_Flash .getValue ())
310322 .toolCallbacks (List .of (FunctionToolCallback .builder ("getCurrentWeather" , new MockWeatherService ())
311323 .description ("Get the weather in location" )
312324 .inputType (MockWeatherService .Request .class )
@@ -332,8 +344,7 @@ void streamFunctionCallTest() {
332344 @ ParameterizedTest (name = "{0} : {displayName} " )
333345 @ ValueSource (strings = { "glm-4.5-flash" })
334346 void enabledThinkingTest (String modelName ) {
335- UserMessage userMessage = new UserMessage (
336- "Are there an infinite number of prime numbers such that n mod 4 == 3?" );
347+ UserMessage userMessage = new UserMessage ("9.11 and 9.8, which is greater?" );
337348
338349 var promptOptions = ZhiPuAiChatOptions .builder ()
339350 .model (modelName )
@@ -344,14 +355,16 @@ void enabledThinkingTest(String modelName) {
344355 ChatResponse response = this .chatModel .call (new Prompt (List .of (userMessage ), promptOptions ));
345356 logger .info ("Response: {}" , response );
346357
347- for ( Generation generation : response .getResults ()) {
348- AssistantMessage message = generation .getOutput ();
358+ Generation generation = response .getResult ();
359+ AssistantMessage message = generation .getOutput ();
349360
350- assertThat (message ).isInstanceOf (ZhiPuAiAssistantMessage .class );
361+ assertThat (message ).isInstanceOf (ZhiPuAiAssistantMessage .class );
351362
352- assertThat (message .getText ()).isNotBlank ();
353- assertThat (((ZhiPuAiAssistantMessage ) message ).getReasoningContent ()).isNotBlank ();
354- }
363+ assertThat (message .getText ()).isNotBlank ();
364+ assertThat (((ZhiPuAiAssistantMessage ) message ).getReasoningContent ()).isNotBlank ();
365+
366+ ZhiPuAiApi .Usage nativeUsage = (ZhiPuAiApi .Usage ) response .getMetadata ().getUsage ().getNativeUsage ();
367+ assertThat (nativeUsage .promptTokensDetails ()).isNotNull ();
355368 }
356369
357370 @ ParameterizedTest (name = "{0} : {displayName} " )
@@ -382,8 +395,7 @@ void disabledThinkingTest(String modelName) {
382395 @ ParameterizedTest (name = "{0} : {displayName} " )
383396 @ ValueSource (strings = { "glm-4.5-flash" })
384397 void streamAndEnableThinkingTest (String modelName ) {
385- UserMessage userMessage = new UserMessage (
386- "Are there an infinite number of prime numbers such that n mod 4 == 3?" );
398+ UserMessage userMessage = new UserMessage ("9.11 and 9.8, which is greater?" );
387399
388400 var promptOptions = ZhiPuAiChatOptions .builder ()
389401 .model (modelName )
@@ -408,6 +420,7 @@ void streamAndEnableThinkingTest(String modelName) {
408420 }
409421 return message .getText ();
410422 })
423+ .filter (StringUtils ::hasText )
411424 .collect (Collectors .joining ());
412425
413426 logger .info ("reasoningContent: {}" , reasoningContent );
@@ -420,7 +433,7 @@ void streamAndEnableThinkingTest(String modelName) {
420433 }
421434
422435 @ ParameterizedTest (name = "{0} : {displayName} " )
423- @ ValueSource (strings = { "glm-4v" })
436+ @ ValueSource (strings = { "glm-4v-flash " })
424437 void multiModalityEmbeddedImage (String modelName ) throws IOException {
425438
426439 var imageData = new ClassPathResource ("/test.png" );
@@ -461,7 +474,7 @@ void reasonerMultiModalityEmbeddedImageThinkingModel(String modelName) throws IO
461474 }
462475
463476 @ ParameterizedTest (name = "{0} : {displayName} " )
464- @ ValueSource (strings = { "glm-4v" , "glm-4.1v-thinking-flash" })
477+ @ ValueSource (strings = { "glm-4v-flash " , "glm-4.1v-thinking-flash" })
465478 void multiModalityImageUrl (String modelName ) throws IOException {
466479
467480 var userMessage = UserMessage .builder ()
@@ -505,7 +518,7 @@ void reasonerMultiModalityImageUrl(String modelName) throws IOException {
505518 }
506519
507520 @ ParameterizedTest (name = "{0} : {displayName} " )
508- @ ValueSource (strings = { "glm-4v" })
521+ @ ValueSource (strings = { "glm-4v-flash " })
509522 void streamingMultiModalityImageUrl (String modelName ) throws IOException {
510523
511524 var userMessage = UserMessage .builder ()
0 commit comments