Skip to content

Commit 8e5716e

Browse files
committed
New endpoint/function rerank with example and test
1 parent 0c0b4bf commit 8e5716e

File tree

5 files changed

+172
-60
lines changed

5 files changed

+172
-60
lines changed

pinecone-client/src/main/scala/io/cequence/pineconescala/JsonFormats.scala

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,8 @@ import io.cequence.pineconescala.domain.response._
55
import io.cequence.pineconescala.domain.settings.{EmbeddingsInputType, EmbeddingsTruncate}
66
import io.cequence.pineconescala.domain.settings.EmbeddingsInputType.{Passage, Query}
77
import io.cequence.pineconescala.domain.{Metric, PVector, PodType, SparseVector, response}
8-
import io.cequence.wsclient.JsonUtil.enumFormat
8+
import io.cequence.wsclient.JsonUtil
9+
import io.cequence.wsclient.JsonUtil.{JsonOps, enumFormat, toJson}
910
import play.api.libs.json._
1011
import play.api.libs.functional.syntax._
1112

@@ -89,7 +90,8 @@ object JsonFormats {
8990
implicit lazy val embeddingUsageInfoReads: Reads[EmbeddingsUsageInfo] =
9091
Json.reads[EmbeddingsUsageInfo]
9192
implicit lazy val embeddingInfoReads: Reads[EmbeddingsInfo] = Json.reads[EmbeddingsInfo]
92-
implicit lazy val embeddingValuesReads: Reads[EmbeddingsValues] = Json.reads[EmbeddingsValues]
93+
implicit lazy val embeddingValuesReads: Reads[EmbeddingsValues] =
94+
Json.reads[EmbeddingsValues]
9395
implicit lazy val embeddingResponseReads: Reads[GenerateEmbeddingsResponse] =
9496
Json.reads[GenerateEmbeddingsResponse]
9597

@@ -176,4 +178,13 @@ object JsonFormats {
176178
implicit lazy val chatCompletionChoiceFormat: Format[Choice] = Json.format[Choice]
177179
implicit lazy val chatCompletionModelFormat: Format[ChatCompletionResponse] =
178180
Json.format[ChatCompletionResponse]
181+
182+
// rerank
183+
implicit lazy val rerankUsageFormat: Format[RerankUsage] = Json.format[RerankUsage]
184+
implicit lazy val rerankedDocumentFormat: Format[RerankedDocument] = {
185+
implicit lazy val stringAnyMapFormat: Format[Map[String, Any]] =
186+
JsonUtil.StringAnyMapFormat
187+
Json.format[RerankedDocument]
188+
}
189+
implicit lazy val rerankResponseFormat: Format[RerankResponse] = Json.format[RerankResponse]
179190
}

pinecone-client/src/main/scala/io/cequence/pineconescala/service/EndPoint.scala

Lines changed: 6 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@ object EndPoint {
2020
case object collections extends EndPoint
2121
case object databases extends EndPoint
2222
case object indexes extends EndPoint
23+
case object rerank extends EndPoint
2324
}
2425

2526
// TODO: rename to Param
@@ -63,50 +64,9 @@ object Tag {
6364
case object metadata extends Tag
6465
case object messages extends Tag
6566
case object file extends Tag
66-
67-
// TODO: move elsewhere
68-
def fromCreatePodBasedIndexSettings(
69-
name: String,
70-
dimension: Int,
71-
settings: CreatePodBasedIndexSettings
72-
): Seq[(Tag, Option[Any])] = {
73-
Seq(
74-
Tag.name -> Some(name),
75-
Tag.dimension -> Some(dimension),
76-
Tag.metric -> Some(settings.metric.toString),
77-
Tag.spec -> Some(
78-
Map(
79-
"pod" -> Map(
80-
Tag.pods.toString -> Some(settings.pods),
81-
Tag.replicas.toString -> Some(settings.replicas),
82-
Tag.pod_type.toString -> Some(settings.podType.toString),
83-
Tag.shards.toString -> Some(settings.shards),
84-
Tag.metadata_config.toString ->
85-
(if (settings.metadataConfig.nonEmpty) Some(settings.metadataConfig) else None),
86-
Tag.source_collection.toString -> settings.sourceCollection
87-
)
88-
)
89-
)
90-
)
91-
}
92-
93-
def fromCreateServerlessIndexSettings(
94-
name: String,
95-
dimension: Int,
96-
settings: CreateServerlessIndexSettings
97-
): Seq[(Tag, Option[Any])] = {
98-
Seq(
99-
Tag.name -> Some(name),
100-
Tag.dimension -> Some(dimension),
101-
Tag.metric -> Some(settings.metric.toString),
102-
Tag.spec -> Some(
103-
Map(
104-
"serverless" -> Map(
105-
Tag.cloud.toString -> settings.cloud.toString,
106-
Tag.region.toString -> settings.region.toString
107-
)
108-
)
109-
)
110-
)
111-
}
67+
case object query extends Tag
68+
case object documents extends Tag
69+
case object top_n extends Tag
70+
case object return_documents extends Tag
71+
case object rank_fields extends Tag
11272
}

pinecone-client/src/main/scala/io/cequence/pineconescala/service/PineconeIndexServiceImpl.scala

Lines changed: 52 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,10 @@ import com.typesafe.config.{Config, ConfigFactory}
55
import io.cequence.pineconescala.JsonFormats._
66
import io.cequence.pineconescala.PineconeScalaClientException
77
import io.cequence.pineconescala.domain.response._
8-
import io.cequence.pineconescala.domain.settings.IndexSettings.{CreatePodBasedIndexSettings, CreateServerlessIndexSettings}
8+
import io.cequence.pineconescala.domain.settings.IndexSettings.{
9+
CreatePodBasedIndexSettings,
10+
CreateServerlessIndexSettings
11+
}
912
import io.cequence.pineconescala.domain.settings._
1013
import io.cequence.pineconescala.domain.PodType
1114
import io.cequence.wsclient.JsonUtil.JsonOps
@@ -59,12 +62,32 @@ private final class ServerlessIndexServiceImpl(
5962
indexesEndpoint,
6063
bodyParams = {
6164
jsonBodyParams(
62-
Tag.fromCreateServerlessIndexSettings(name, dimension, settings): _*
65+
fromCreateServerlessIndexSettings(name, dimension, settings): _*
6366
)
6467
},
6568
acceptableStatusCodes = Nil // don't parse response at all
6669
).map(handleCreateResponse)
6770

71+
private def fromCreateServerlessIndexSettings(
72+
name: String,
73+
dimension: Int,
74+
settings: CreateServerlessIndexSettings
75+
): Seq[(Tag, Option[Any])] = {
76+
Seq(
77+
Tag.name -> Some(name),
78+
Tag.dimension -> Some(dimension),
79+
Tag.metric -> Some(settings.metric.toString),
80+
Tag.spec -> Some(
81+
Map(
82+
"serverless" -> Map(
83+
Tag.cloud.toString -> settings.cloud.toString,
84+
Tag.region.toString -> settings.region.toString
85+
)
86+
)
87+
)
88+
)
89+
}
90+
6891
override def describeIndexResponse(json: JsValue): IndexInfo =
6992
json.asSafe[ServerlessIndexInfo]
7093

@@ -117,11 +140,36 @@ private final class PineconePodPineconeBasedImpl(
117140
execPOSTRich(
118141
indexesEndpoint,
119142
bodyParams = jsonBodyParams(
120-
Tag.fromCreatePodBasedIndexSettings(name, dimension, settings): _*
143+
fromCreatePodBasedIndexSettings(name, dimension, settings): _*
121144
),
122145
acceptableStatusCodes = Nil // don't parse response at all
123146
).map(handleCreateResponse)
124147

148+
private def fromCreatePodBasedIndexSettings(
149+
name: String,
150+
dimension: Int,
151+
settings: CreatePodBasedIndexSettings
152+
): Seq[(Tag, Option[Any])] = {
153+
Seq(
154+
Tag.name -> Some(name),
155+
Tag.dimension -> Some(dimension),
156+
Tag.metric -> Some(settings.metric.toString),
157+
Tag.spec -> Some(
158+
Map(
159+
"pod" -> Map(
160+
Tag.pods.toString -> Some(settings.pods),
161+
Tag.replicas.toString -> Some(settings.replicas),
162+
Tag.pod_type.toString -> Some(settings.podType.toString),
163+
Tag.shards.toString -> Some(settings.shards),
164+
Tag.metadata_config.toString ->
165+
(if (settings.metadataConfig.nonEmpty) Some(settings.metadataConfig) else None),
166+
Tag.source_collection.toString -> settings.sourceCollection
167+
)
168+
)
169+
)
170+
)
171+
}
172+
125173
override def configureIndex(
126174
indexName: String,
127175
replicas: Option[Int],
@@ -206,7 +254,7 @@ abstract class PineconeIndexServiceImpl[S <: IndexSettings](
206254
// we use play-ws backend
207255
override protected val engine: WSClientEngine = PlayWSClientEngine(
208256
coreUrl,
209-
requestContext = WsRequestContext(
257+
requestContext = WsRequestContext(
210258
authHeaders = Seq(("Api-Key", apiKey)),
211259
explTimeouts = explicitTimeouts
212260
)

pinecone-client/src/main/scala/io/cequence/pineconescala/service/PineconeInferenceServiceImpl.scala

Lines changed: 42 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,8 @@
11
package io.cequence.pineconescala.service
22

33
import akka.stream.Materializer
4-
import com.typesafe.config.Config
5-
import io.cequence.pineconescala.domain.response.GenerateEmbeddingsResponse
6-
import io.cequence.pineconescala.domain.settings.GenerateEmbeddingsSettings
4+
import io.cequence.pineconescala.domain.response.{GenerateEmbeddingsResponse, RerankResponse}
5+
import io.cequence.pineconescala.domain.settings.{GenerateEmbeddingsSettings, RerankSettings}
76
import io.cequence.wsclient.ResponseImplicits._
87
import io.cequence.wsclient.service.ws.{PlayWSClientEngine, Timeouts}
98
import io.cequence.pineconescala.JsonFormats._
@@ -29,10 +28,10 @@ private class PineconeInferenceServiceImpl(
2928
// we use play-ws backend
3029
override protected val engine: WSClientEngine = PlayWSClientEngine(
3130
coreUrl = "https://api.pinecone.io/",
32-
requestContext = WsRequestContext(
31+
requestContext = WsRequestContext(
3332
authHeaders = Seq(
3433
("Api-Key", apiKey),
35-
("X-Pinecone-API-Version", "2024-07")
34+
("X-Pinecone-API-Version", "2024-10")
3635
),
3736
explTimeouts = explicitTimeouts
3837
)
@@ -69,6 +68,44 @@ private class PineconeInferenceServiceImpl(
6968
_.asSafeJson[GenerateEmbeddingsResponse]
7069
)
7170

71+
/**
72+
* Reranks a list of documents against a query using a reranker model.
73+
*
74+
* @param query
75+
* The query to rerank documents against (required)
76+
* @param documents
77+
* The documents to rerank (required)
78+
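* @param settings Rerank settings (model, top_n, return_documents, rank_fields, parameters); defaults to DefaultSettings.Rerank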
79+
* @return The response body parsed as a RerankResponse
80+
*
81+
* @see
82+
* <a href="https://docs.pinecone.io/reference/api/2024-10/inference/rerank">Pinecone
83+
* Doc</a>
84+
*/
85+
override def rerank(
86+
query: String,
87+
documents: Seq[Map[String, Any]],
88+
settings: RerankSettings = DefaultSettings.Rerank
89+
): Future[RerankResponse] =
90+
execPOST(
91+
EndPoint.rerank,
92+
bodyParams = jsonBodyParams(
93+
Tag.query -> Some(query),
94+
Tag.documents -> Some(documents),
95+
Tag.model -> Some(settings.model),
96+
Tag.top_n -> settings.top_n,
97+
Tag.return_documents -> Some(settings.return_documents),
98+
Tag.rank_fields -> (
99+
if (settings.rank_fields.nonEmpty) Some(settings.rank_fields) else None
100+
),
101+
Tag.parameters -> (
102+
if (settings.parameters.nonEmpty) Some(settings.parameters) else None
103+
)
104+
)
105+
).map(
106+
_.asSafeJson[RerankResponse]
107+
)
108+
72109
override protected def handleErrorCodes(
73110
httpCode: Int,
74111
message: String

pinecone-client/src/test/scala/io/cequence/pineconescala/service/ServerlessPineconeInferenceServiceImplSpec.scala

Lines changed: 59 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,9 @@ package io.cequence.pineconescala.service
33
import akka.actor.ActorSystem
44
import akka.stream.Materializer
55
import com.typesafe.config.{Config, ConfigFactory}
6+
import io.cequence.pineconescala.PineconeScalaClientException
7+
import io.cequence.pineconescala.domain.RerankModelId
8+
import io.cequence.pineconescala.domain.settings.RerankSettings
69
import org.scalatest.GivenWhenThen
710
import org.scalatest.matchers.must.Matchers
811
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
@@ -13,7 +16,9 @@ import scala.concurrent.ExecutionContext
1316
class ServerlessPineconeInferenceServiceImplSpec
1417
extends AsyncWordSpec
1518
with GivenWhenThen
16-
with ServerlessFixtures with Matchers with PineconeServiceConsts{
19+
with ServerlessFixtures
20+
with Matchers
21+
with PineconeServiceConsts {
1722

1823
implicit val ec: ExecutionContext = ExecutionContext.global
1924
implicit val materializer: Materializer = Materializer(ActorSystem())
@@ -28,15 +33,66 @@ class ServerlessPineconeInferenceServiceImplSpec
2833
"create embeddings should provide embeddings for input data" in {
2934
val service = inferenceServiceBuilder
3035
for {
31-
embeddings <- service.createEmbeddings(Seq("The quick brown fox jumped over the lazy dog"),
32-
settings = DefaultSettings.GenerateEmbeddings.withPassageInputType.withEndTruncate)
36+
embeddings <- service.createEmbeddings(
37+
Seq("The quick brown fox jumped over the lazy dog"),
38+
settings = DefaultSettings.GenerateEmbeddings.withPassageInputType.withEndTruncate
39+
)
3340
} yield {
3441
embeddings.data.size should be(1)
3542
embeddings.data(0).values should not be empty
3643
embeddings.usage.total_tokens should be(16)
3744
}
3845
}
3946

47+
"rerank documents" in {
48+
val service = inferenceServiceBuilder
49+
50+
val documents = Seq(
51+
Map(
52+
"id" -> "vec1",
53+
"my_field" -> "Apple is a popular fruit known for its sweetness and crisp texture."
54+
),
55+
Map(
56+
"id" -> "vec2",
57+
"my_field" -> "Many people enjoy eating apples as a healthy snack."
58+
),
59+
Map(
60+
"id" -> "vec3",
61+
"my_field" -> "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces."
62+
),
63+
Map(
64+
"id" -> "vec4",
65+
"my_field" -> "An apple a day keeps the doctor away, as the saying goes."
66+
)
67+
)
68+
69+
for {
70+
rerankResponse <- service.rerank(
71+
query =
72+
"The tech company Apple is known for its innovative products like the iPhone.",
73+
documents = documents,
74+
settings = RerankSettings(
75+
model = RerankModelId.bge_reranker_v2_m3,
76+
top_n = Some(4),
77+
return_documents = true,
78+
rank_fields = Seq("my_field")
79+
)
80+
)
81+
} yield {
82+
rerankResponse.data.size should be(4)
83+
rerankResponse.usage.rerank_units should be(1)
84+
rerankResponse.data.map(_.index) should be(Seq(2, 0, 3, 1))
85+
86+
def docEq(origIndex: Int, responseIndex: Int) =
87+
rerankResponse.data(origIndex).document.getOrElse(
88+
throw new PineconeScalaClientException("Document missing")
89+
) should be(documents(responseIndex))
4090

91+
docEq(0, 2)
92+
docEq(1, 0)
93+
docEq(2, 3)
94+
docEq(3, 1)
95+
}
96+
}
4197
}
4298
}

0 commit comments

Comments
 (0)