Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,9 @@

package org.springframework.ai.ollama;

import java.time.Duration;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import io.micrometer.observation.ObservationRegistry;

Expand Down Expand Up @@ -178,8 +175,8 @@ OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(EmbeddingRequest embeddingReq
}

return new OllamaApi.EmbeddingsRequest(requestOptions.getModel(), embeddingRequest.getInstructions(),
DurationParser.parse(requestOptions.getKeepAlive()),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The subclass DurationParser can also be removed.

OllamaEmbeddingOptions.filterNonSupportedFields(requestOptions.toMap()), requestOptions.getTruncate());
requestOptions.getKeepAlive(), OllamaEmbeddingOptions.filterNonSupportedFields(requestOptions.toMap()),
requestOptions.getTruncate());
}

/**
Expand All @@ -200,37 +197,6 @@ public void setObservationConvention(EmbeddingModelObservationConvention observa
this.observationConvention = observationConvention;
}

public static class DurationParser {

private static final Pattern PATTERN = Pattern.compile("(-?\\d+)(ms|s|m|h)");

public static Duration parse(String input) {

if (!StringUtils.hasText(input)) {
return null;
}

Matcher matcher = PATTERN.matcher(input);

if (matcher.matches()) {
long value = Long.parseLong(matcher.group(1));
String unit = matcher.group(2);

return switch (unit) {
case "ms" -> Duration.ofMillis(value);
case "s" -> Duration.ofSeconds(value);
case "m" -> Duration.ofMinutes(value);
case "h" -> Duration.ofHours(value);
default -> throw new IllegalArgumentException("Unsupported time unit: " + unit);
};
}
else {
throw new IllegalArgumentException("Invalid duration format: " + input);
}
}

}

public static final class Builder {

private OllamaApi ollamaApi;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ public Duration getEvalDuration() {
public record EmbeddingsRequest(
@JsonProperty("model") String model,
@JsonProperty("input") List<String> input,
@JsonProperty("keep_alive") Duration keepAlive,
@JsonProperty("keep_alive") String keepAlive,
@JsonProperty("options") Map<String, Object> options,
@JsonProperty("truncate") Boolean truncate) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package org.springframework.ai.ollama;

import java.time.Duration;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -46,7 +45,7 @@
* @since 1.0.0
*/
@ExtendWith(MockitoExtension.class)
public class OllamaEmbeddingModelTests {
class OllamaEmbeddingModelTests {

@Mock
OllamaApi ollamaApi;
Expand All @@ -55,7 +54,7 @@ public class OllamaEmbeddingModelTests {
ArgumentCaptor<EmbeddingsRequest> embeddingsRequestCaptor;

@Test
public void options() {
void options() {

given(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture()))
.willReturn(new EmbeddingsResponse("RESPONSE_MODEL_NAME",
Expand Down Expand Up @@ -109,7 +108,7 @@ public void options() {
}

@Test
public void singleInputEmbedding() {
void singleInputEmbedding() {
given(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture()))
.willReturn(new EmbeddingsResponse("TEST_MODEL", List.of(new float[] { 0.1f, 0.2f, 0.3f }), 10L, 5L, 1));

Expand All @@ -131,7 +130,7 @@ public void singleInputEmbedding() {
}

@Test
public void embeddingWithNullOptions() {
void embeddingWithNullOptions() {
given(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture()))
.willReturn(new EmbeddingsResponse("NULL_OPTIONS_MODEL", List.of(new float[] { 0.5f }), 5L, 2L, 1));

Expand All @@ -150,7 +149,7 @@ public void embeddingWithNullOptions() {
}

@Test
public void embeddingWithMultipleLargeInputs() {
void embeddingWithMultipleLargeInputs() {
List<String> largeInputs = List.of(
"This is a very long text input that might be used for document embedding scenarios",
"Another substantial piece of text content that could represent a paragraph or section",
Expand Down Expand Up @@ -179,7 +178,7 @@ public void embeddingWithMultipleLargeInputs() {
}

@Test
public void embeddingWithCustomKeepAliveFormats() {
void embeddingWithCustomKeepAliveFormats() {
given(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture()))
.willReturn(new EmbeddingsResponse("KEEPALIVE_MODEL", List.of(new float[] { 1.0f }), 5L, 2L, 1));

Expand All @@ -192,17 +191,17 @@ public void embeddingWithCustomKeepAliveFormats() {
var secondsOptions = OllamaEmbeddingOptions.builder().model("KEEPALIVE_MODEL").keepAlive("300s").build();

embeddingModel.call(new EmbeddingRequest(List.of("Keep alive seconds"), secondsOptions));
assertThat(this.embeddingsRequestCaptor.getValue().keepAlive()).isEqualTo(Duration.ofSeconds(300));
assertThat(this.embeddingsRequestCaptor.getValue().keepAlive()).isEqualTo("300s");

// Test with hours format
var hoursOptions = OllamaEmbeddingOptions.builder().model("KEEPALIVE_MODEL").keepAlive("2h").build();

embeddingModel.call(new EmbeddingRequest(List.of("Keep alive hours"), hoursOptions));
assertThat(this.embeddingsRequestCaptor.getValue().keepAlive()).isEqualTo(Duration.ofHours(2));
assertThat(this.embeddingsRequestCaptor.getValue().keepAlive()).isEqualTo("2h");
}

@Test
public void embeddingResponseMetadata() {
void embeddingResponseMetadata() {
given(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture()))
.willReturn(new EmbeddingsResponse("METADATA_MODEL", List.of(new float[] { 0.1f, 0.2f }), 100L, 50L, 25));

Expand All @@ -220,7 +219,7 @@ public void embeddingResponseMetadata() {
}

@Test
public void embeddingWithZeroLengthVectors() {
void embeddingWithZeroLengthVectors() {
given(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture()))
.willReturn(new EmbeddingsResponse("ZERO_MODEL", List.of(new float[] {}), 0L, 0L, 1));

Expand All @@ -237,7 +236,7 @@ public void embeddingWithZeroLengthVectors() {
}

@Test
public void builderValidation() {
void builderValidation() {
// Test that builder requires ollamaApi
assertThatThrownBy(() -> OllamaEmbeddingModel.builder().build()).isInstanceOf(IllegalArgumentException.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package org.springframework.ai.ollama;

import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
Expand All @@ -35,12 +34,12 @@
* @author Thomas Vitale
* @author Jonghoon Park
*/
public class OllamaEmbeddingRequestTests {
class OllamaEmbeddingRequestTests {

private OllamaEmbeddingModel embeddingModel;

@BeforeEach
public void setUp() {
void setUp() {
this.embeddingModel = OllamaEmbeddingModel.builder()
.ollamaApi(OllamaApi.builder().build())
.defaultOptions(
Expand All @@ -49,7 +48,7 @@ public void setUp() {
}

@Test
public void ollamaEmbeddingRequestDefaultOptions() {
void ollamaEmbeddingRequestDefaultOptions() {
var embeddingRequest = this.embeddingModel.buildEmbeddingRequest(new EmbeddingRequest(List.of("Hello"), null));
var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(embeddingRequest);

Expand All @@ -58,7 +57,7 @@ public void ollamaEmbeddingRequestDefaultOptions() {
}

@Test
public void ollamaEmbeddingRequestRequestOptions() {
void ollamaEmbeddingRequestRequestOptions() {
var promptOptions = OllamaEmbeddingOptions.builder()//
.model("PROMPT_MODEL")//
.build();
Expand All @@ -72,18 +71,18 @@ public void ollamaEmbeddingRequestRequestOptions() {
}

@Test
public void ollamaEmbeddingRequestWithNegativeKeepAlive() {
void ollamaEmbeddingRequestWithNegativeKeepAlive() {
var promptOptions = OllamaEmbeddingOptions.builder().model("PROMPT_MODEL").keepAlive("-1m").build();

var embeddingRequest = this.embeddingModel
.buildEmbeddingRequest(new EmbeddingRequest(List.of("Hello"), promptOptions));
var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(embeddingRequest);

assertThat(ollamaRequest.keepAlive()).isEqualTo(Duration.ofMinutes(-1));
assertThat(ollamaRequest.keepAlive()).isEqualTo("-1m");
}

@Test
public void ollamaEmbeddingRequestWithEmptyInput() {
void ollamaEmbeddingRequestWithEmptyInput() {
var embeddingRequest = this.embeddingModel
.buildEmbeddingRequest(new EmbeddingRequest(Collections.emptyList(), null));
var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(embeddingRequest);
Expand All @@ -93,7 +92,7 @@ public void ollamaEmbeddingRequestWithEmptyInput() {
}

@Test
public void ollamaEmbeddingRequestWithMultipleInputs() {
void ollamaEmbeddingRequestWithMultipleInputs() {
List<String> inputs = Arrays.asList("Hello", "World", "How are you?");
var embeddingRequest = this.embeddingModel.buildEmbeddingRequest(new EmbeddingRequest(inputs, null));
var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(embeddingRequest);
Expand All @@ -103,7 +102,7 @@ public void ollamaEmbeddingRequestWithMultipleInputs() {
}

@Test
public void ollamaEmbeddingRequestOptionsOverrideDefaults() {
void ollamaEmbeddingRequestOptionsOverrideDefaults() {
var requestOptions = OllamaEmbeddingOptions.builder().model("OVERRIDE_MODEL").build();

var embeddingRequest = this.embeddingModel
Expand All @@ -115,24 +114,24 @@ public void ollamaEmbeddingRequestOptionsOverrideDefaults() {
}

@Test
public void ollamaEmbeddingRequestWithDifferentKeepAliveFormats() {
void ollamaEmbeddingRequestWithDifferentKeepAliveFormats() {
// Test seconds format
var optionsSeconds = OllamaEmbeddingOptions.builder().keepAlive("30s").build();
var requestSeconds = this.embeddingModel
.buildEmbeddingRequest(new EmbeddingRequest(List.of("Test"), optionsSeconds));
var ollamaRequestSeconds = this.embeddingModel.ollamaEmbeddingRequest(requestSeconds);
assertThat(ollamaRequestSeconds.keepAlive()).isEqualTo(Duration.ofSeconds(30));
assertThat(ollamaRequestSeconds.keepAlive()).isEqualTo("30s");

// Test hours format
var optionsHours = OllamaEmbeddingOptions.builder().keepAlive("2h").build();
var requestHours = this.embeddingModel
.buildEmbeddingRequest(new EmbeddingRequest(List.of("Test"), optionsHours));
var ollamaRequestHours = this.embeddingModel.ollamaEmbeddingRequest(requestHours);
assertThat(ollamaRequestHours.keepAlive()).isEqualTo(Duration.ofHours(2));
assertThat(ollamaRequestHours.keepAlive()).isEqualTo("2h");
}

@Test
public void ollamaEmbeddingRequestWithMinimalDefaults() {
void ollamaEmbeddingRequestWithMinimalDefaults() {
// Create model with minimal defaults
var minimalModel = OllamaEmbeddingModel.builder()
.ollamaApi(OllamaApi.builder().build())
Expand All @@ -151,7 +150,7 @@ public void ollamaEmbeddingRequestWithMinimalDefaults() {
}

@Test
public void ollamaEmbeddingRequestPreservesInputOrder() {
void ollamaEmbeddingRequestPreservesInputOrder() {
List<String> orderedInputs = Arrays.asList("First", "Second", "Third", "Fourth");
var embeddingRequest = this.embeddingModel.buildEmbeddingRequest(new EmbeddingRequest(orderedInputs, null));
var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(embeddingRequest);
Expand All @@ -160,7 +159,7 @@ public void ollamaEmbeddingRequestPreservesInputOrder() {
}

@Test
public void ollamaEmbeddingRequestWithWhitespaceInputs() {
void ollamaEmbeddingRequestWithWhitespaceInputs() {
List<String> inputs = Arrays.asList("", " ", "\t\n", "normal text", " spaced ");
var embeddingRequest = this.embeddingModel.buildEmbeddingRequest(new EmbeddingRequest(inputs, null));
var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(embeddingRequest);
Expand All @@ -170,7 +169,7 @@ public void ollamaEmbeddingRequestWithWhitespaceInputs() {
}

@Test
public void ollamaEmbeddingRequestWithNullInput() {
void ollamaEmbeddingRequestWithNullInput() {
// Test behavior when input list contains null values
List<String> inputsWithNull = Arrays.asList("Hello", null, "World");
var embeddingRequest = this.embeddingModel.buildEmbeddingRequest(new EmbeddingRequest(inputsWithNull, null));
Expand All @@ -181,7 +180,7 @@ public void ollamaEmbeddingRequestWithNullInput() {
}

@Test
public void ollamaEmbeddingRequestPartialOptionsOverride() {
void ollamaEmbeddingRequestPartialOptionsOverride() {
// Test that only specified options are overridden, others remain default
var requestOptions = OllamaEmbeddingOptions.builder()
.model("PARTIAL_OVERRIDE_MODEL")
Expand All @@ -199,7 +198,7 @@ public void ollamaEmbeddingRequestPartialOptionsOverride() {
}

@Test
public void ollamaEmbeddingRequestWithEmptyStringInput() {
void ollamaEmbeddingRequestWithEmptyStringInput() {
// Test with list containing only empty string
var embeddingRequest = this.embeddingModel.buildEmbeddingRequest(new EmbeddingRequest(List.of(""), null));
var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(embeddingRequest);
Expand Down