|
1 | | -"""Retriever model for faiss: https://github.com/facebookresearch/faiss. |
2 | | -Author: Jagane Sundar: https://github.com/jagane. |
3 | | -""" |
4 | | - |
5 | | -import logging |
6 | | -from typing import Optional, Union |
7 | | - |
8 | | -import numpy as np |
9 | | - |
10 | | -import dspy |
11 | | -from dspy.dsp.modules.sentence_vectorizer import SentenceTransformersVectorizer |
12 | | -from dspy.dsp.utils import dotdict |
13 | | - |
14 | | -try: |
15 | | - import faiss |
16 | | -except ImportError: |
17 | | - faiss = None |
18 | | - |
19 | | -if faiss is None: |
20 | | - raise ImportError( |
21 | | - """ |
22 | | - The faiss package is required. Install it using `pip install dspy-ai[faiss-cpu]` |
23 | | - """, |
24 | | - ) |
25 | | - |
26 | | - |
27 | | -logger = logging.getLogger(__name__) |
28 | | -class FaissRM(dspy.Retrieve): |
29 | | - """A retrieval module that uses an in-memory Faiss to return the top passages for a given query. |
30 | | -
|
31 | | - Args: |
32 | | - document_chunks: the input text chunks |
33 | | - vectorizer: an object that is a subclass of BaseSentenceVectorizer |
34 | | - k (int, optional): The number of top passages to retrieve. Defaults to 3. |
35 | | -
|
36 | | - Returns: |
37 | | - dspy.Prediction: An object containing the retrieved passages. |
38 | | -
|
39 | | - Examples: |
40 | | - Below is a code snippet that shows how to use this as the default retriver: |
41 | | - ```python |
42 | | - import dspy |
43 | | - from dspy.retrieve import faiss_rm |
44 | | -
|
45 | | - document_chunks = [ |
46 | | - "The superbowl this year was played between the San Francisco 49ers and the Kanasas City Chiefs", |
47 | | - "Pop corn is often served in a bowl", |
48 | | - "The Rice Bowl is a Chinese Restaurant located in the city of Tucson, Arizona", |
49 | | - "Mars is the fourth planet in the Solar System", |
50 | | - "An aquarium is a place where children can learn about marine life", |
51 | | - "The capital of the United States is Washington, D.C", |
52 | | - "Rock and Roll musicians are honored by being inducted in the Rock and Roll Hall of Fame", |
53 | | - "Music albums were published on Long Play Records in the 70s and 80s", |
54 | | - "Sichuan cuisine is a spicy cuisine from central China", |
55 | | - "The interest rates for mortgages is considered to be very high in 2024", |
56 | | - ] |
57 | | -
|
58 | | - frm = faiss_rm.FaissRM(document_chunks) |
59 | | - turbo = dspy.OpenAI(model="gpt-3.5-turbo") |
60 | | - dspy.settings.configure(lm=turbo, rm=frm) |
61 | | - print(frm(["I am in the mood for Chinese food"])) |
62 | | - ``` |
63 | | -
|
64 | | - Below is a code snippet that shows how to use this in the forward() function of a module |
65 | | - ```python |
66 | | - self.retrieve = FaissRM(k=num_passages) |
67 | | - ``` |
68 | | - """ |
69 | | - |
70 | | - def __init__(self, document_chunks, vectorizer=None, k: int = 3): |
71 | | - """Inits the faiss retriever. |
72 | | -
|
73 | | - Args: |
74 | | - document_chunks: a list of input strings. |
75 | | - vectorizer: an object that is a subclass of BaseTransformersVectorizer. |
76 | | - k: number of matches to return. |
77 | | - """ |
78 | | - if vectorizer: |
79 | | - self._vectorizer = vectorizer |
80 | | - else: |
81 | | - self._vectorizer = SentenceTransformersVectorizer() |
82 | | - embeddings = self._vectorizer(document_chunks) |
83 | | - xb = np.array(embeddings) |
84 | | - d = len(xb[0]) |
85 | | - logger.info(f"FaissRM: embedding size={d}") |
86 | | - if len(xb) < 100: |
87 | | - self._faiss_index = faiss.IndexFlatL2(d) |
88 | | - self._faiss_index.add(xb) |
89 | | - else: |
90 | | - # if we have at least 100 vectors, we use Voronoi cells |
91 | | - nlist = 100 |
92 | | - quantizer = faiss.IndexFlatL2(d) |
93 | | - self._faiss_index = faiss.IndexIVFFlat(quantizer, d, nlist) |
94 | | - self._faiss_index.train(xb) |
95 | | - self._faiss_index.add(xb) |
96 | | - |
97 | | - logger.info(f"{self._faiss_index.ntotal} vectors in faiss index") |
98 | | - self._document_chunks = document_chunks # save the input document chunks |
99 | | - |
100 | | - super().__init__(k=k) |
101 | | - |
102 | | - def _dump_raw_results(self, queries, index_list, distance_list) -> None: |
103 | | - for i in range(len(queries)): |
104 | | - indices = index_list[i] |
105 | | - distances = distance_list[i] |
106 | | - logger.debug(f"Query: {queries[i]}") |
107 | | - for j in range(len(indices)): |
108 | | - logger.debug(f" Hit {j} = {indices[j]}/{distances[j]}: {self._document_chunks[indices[j]]}") |
109 | | - return |
110 | | - |
111 | | - def forward(self, query_or_queries: Union[str, list[str]], k: Optional[int] = None, **kwargs) -> dspy.Prediction: |
112 | | - """Search the faiss index for k or self.k top passages for query. |
113 | | -
|
114 | | - Args: |
115 | | - query_or_queries (Union[str, List[str]]): The query or queries to search for. |
116 | | -
|
117 | | - Returns: |
118 | | - dspy.Prediction: An object containing the retrieved passages. |
119 | | - """ |
120 | | - queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries |
121 | | - queries = [q for q in queries if q] # Filter empty queries |
122 | | - embeddings = self._vectorizer(queries) |
123 | | - emb_npa = np.array(embeddings) |
124 | | - # For single query, just look up the top k passages |
125 | | - if len(queries) == 1: |
126 | | - distance_list, index_list = self._faiss_index.search(emb_npa, k or self.k) |
127 | | - # self._dump_raw_results(queries, index_list, distance_list) |
128 | | - passages = [(self._document_chunks[ind], ind) for ind in index_list[0]] |
129 | | - return [dotdict({"long_text": passage[0], "index": passage[1]}) for passage in passages] |
130 | | - |
131 | | - distance_list, index_list = self._faiss_index.search(emb_npa, (k or self.k) * 3, **kwargs) |
132 | | - # self._dump_raw_results(queries, index_list, distance_list) |
133 | | - passage_scores = {} |
134 | | - for emb in range(len(embeddings)): |
135 | | - indices = index_list[emb] # indices of neighbors for embeddings[emb] - this is an array of k*3 integers |
136 | | - distances = distance_list[ |
137 | | - emb |
138 | | - ] # distances of neighbors for embeddings[emb] - this is an array of k*3 floating point numbers |
139 | | - for res in range((k or self.k) * 3): |
140 | | - neighbor = indices[res] |
141 | | - distance = distances[res] |
142 | | - if neighbor in passage_scores: |
143 | | - passage_scores[neighbor].append(distance) |
144 | | - else: |
145 | | - passage_scores[neighbor] = [distance] |
146 | | - # Note re. sorting: |
147 | | - # first degree sort: number of queries that got a hit with any particular document chunk. More |
148 | | - # is a better match. This is len(queries)-len(x[1]) |
149 | | - # second degree sort: sum of the distances of each hit returned by faiss. Smaller distance is a better match |
150 | | - sorted_passages = sorted(passage_scores.items(), key=lambda x: (len(queries) - len(x[1]), sum(x[1])))[ |
151 | | - : k or self.k |
152 | | - ] |
153 | | - return [ |
154 | | - dotdict({"long_text": self._document_chunks[passage_index], "index": passage_index}) |
155 | | - for passage_index, _ in sorted_passages |
156 | | - ] |
| 1 | +# """Retriever model for faiss: https://github.com/facebookresearch/faiss. |
| 2 | +# Author: Jagane Sundar: https://github.com/jagane. |
| 3 | +# """ |
| 4 | + |
| 5 | +# import logging |
| 6 | +# from typing import Optional, Union |
| 7 | + |
| 8 | +# import numpy as np |
| 9 | + |
| 10 | +# import dspy |
| 11 | +# from dspy.dsp.modules.sentence_vectorizer import SentenceTransformersVectorizer |
| 12 | +# from dspy.dsp.utils import dotdict |
| 13 | + |
| 14 | +# try: |
| 15 | +# import faiss |
| 16 | +# except ImportError: |
| 17 | +# faiss = None |
| 18 | + |
| 19 | +# if faiss is None: |
| 20 | +# raise ImportError( |
| 21 | +# """ |
| 22 | +# The faiss package is required. Install it using `pip install dspy-ai[faiss-cpu]` |
| 23 | +# """, |
| 24 | +# ) |
| 25 | + |
| 26 | + |
| 27 | +# logger = logging.getLogger(__name__) |
| 28 | +# class FaissRM(dspy.Retrieve): |
| 29 | +# """A retrieval module that uses an in-memory Faiss to return the top passages for a given query. |
| 30 | + |
| 31 | +# Args: |
| 32 | +# document_chunks: the input text chunks |
| 33 | +# vectorizer: an object that is a subclass of BaseSentenceVectorizer |
| 34 | +# k (int, optional): The number of top passages to retrieve. Defaults to 3. |
| 35 | + |
| 36 | +# Returns: |
| 37 | +# dspy.Prediction: An object containing the retrieved passages. |
| 38 | + |
| 39 | +# Examples: |
| 40 | +# Below is a code snippet that shows how to use this as the default retriver: |
| 41 | +# ```python |
| 42 | +# import dspy |
| 43 | +# from dspy.retrieve import faiss_rm |
| 44 | + |
| 45 | +# document_chunks = [ |
| 46 | +# "The superbowl this year was played between the San Francisco 49ers and the Kanasas City Chiefs", |
| 47 | +# "Pop corn is often served in a bowl", |
| 48 | +# "The Rice Bowl is a Chinese Restaurant located in the city of Tucson, Arizona", |
| 49 | +# "Mars is the fourth planet in the Solar System", |
| 50 | +# "An aquarium is a place where children can learn about marine life", |
| 51 | +# "The capital of the United States is Washington, D.C", |
| 52 | +# "Rock and Roll musicians are honored by being inducted in the Rock and Roll Hall of Fame", |
| 53 | +# "Music albums were published on Long Play Records in the 70s and 80s", |
| 54 | +# "Sichuan cuisine is a spicy cuisine from central China", |
| 55 | +# "The interest rates for mortgages is considered to be very high in 2024", |
| 56 | +# ] |
| 57 | + |
| 58 | +# frm = faiss_rm.FaissRM(document_chunks) |
| 59 | +# turbo = dspy.OpenAI(model="gpt-3.5-turbo") |
| 60 | +# dspy.settings.configure(lm=turbo, rm=frm) |
| 61 | +# print(frm(["I am in the mood for Chinese food"])) |
| 62 | +# ``` |
| 63 | + |
| 64 | +# Below is a code snippet that shows how to use this in the forward() function of a module |
| 65 | +# ```python |
| 66 | +# self.retrieve = FaissRM(k=num_passages) |
| 67 | +# ``` |
| 68 | +# """ |
| 69 | + |
| 70 | +# def __init__(self, document_chunks, vectorizer=None, k: int = 3): |
| 71 | +# """Inits the faiss retriever. |
| 72 | + |
| 73 | +# Args: |
| 74 | +# document_chunks: a list of input strings. |
| 75 | +# vectorizer: an object that is a subclass of BaseTransformersVectorizer. |
| 76 | +# k: number of matches to return. |
| 77 | +# """ |
| 78 | +# if vectorizer: |
| 79 | +# self._vectorizer = vectorizer |
| 80 | +# else: |
| 81 | +# self._vectorizer = SentenceTransformersVectorizer() |
| 82 | +# embeddings = self._vectorizer(document_chunks) |
| 83 | +# xb = np.array(embeddings) |
| 84 | +# d = len(xb[0]) |
| 85 | +# logger.info(f"FaissRM: embedding size={d}") |
| 86 | +# if len(xb) < 100: |
| 87 | +# self._faiss_index = faiss.IndexFlatL2(d) |
| 88 | +# self._faiss_index.add(xb) |
| 89 | +# else: |
| 90 | +# # if we have at least 100 vectors, we use Voronoi cells |
| 91 | +# nlist = 100 |
| 92 | +# quantizer = faiss.IndexFlatL2(d) |
| 93 | +# self._faiss_index = faiss.IndexIVFFlat(quantizer, d, nlist) |
| 94 | +# self._faiss_index.train(xb) |
| 95 | +# self._faiss_index.add(xb) |
| 96 | + |
| 97 | +# logger.info(f"{self._faiss_index.ntotal} vectors in faiss index") |
| 98 | +# self._document_chunks = document_chunks # save the input document chunks |
| 99 | + |
| 100 | +# super().__init__(k=k) |
| 101 | + |
| 102 | +# def _dump_raw_results(self, queries, index_list, distance_list) -> None: |
| 103 | +# for i in range(len(queries)): |
| 104 | +# indices = index_list[i] |
| 105 | +# distances = distance_list[i] |
| 106 | +# logger.debug(f"Query: {queries[i]}") |
| 107 | +# for j in range(len(indices)): |
| 108 | +# logger.debug(f" Hit {j} = {indices[j]}/{distances[j]}: {self._document_chunks[indices[j]]}") |
| 109 | +# return |
| 110 | + |
| 111 | +# def forward(self, query_or_queries: Union[str, list[str]], k: Optional[int] = None, **kwargs) -> dspy.Prediction: |
| 112 | +# """Search the faiss index for k or self.k top passages for query. |
| 113 | + |
| 114 | +# Args: |
| 115 | +# query_or_queries (Union[str, List[str]]): The query or queries to search for. |
| 116 | + |
| 117 | +# Returns: |
| 118 | +# dspy.Prediction: An object containing the retrieved passages. |
| 119 | +# """ |
| 120 | +# queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries |
| 121 | +# queries = [q for q in queries if q] # Filter empty queries |
| 122 | +# embeddings = self._vectorizer(queries) |
| 123 | +# emb_npa = np.array(embeddings) |
| 124 | +# # For single query, just look up the top k passages |
| 125 | +# if len(queries) == 1: |
| 126 | +# distance_list, index_list = self._faiss_index.search(emb_npa, k or self.k) |
| 127 | +# # self._dump_raw_results(queries, index_list, distance_list) |
| 128 | +# passages = [(self._document_chunks[ind], ind) for ind in index_list[0]] |
| 129 | +# return [dotdict({"long_text": passage[0], "index": passage[1]}) for passage in passages] |
| 130 | + |
| 131 | +# distance_list, index_list = self._faiss_index.search(emb_npa, (k or self.k) * 3, **kwargs) |
| 132 | +# # self._dump_raw_results(queries, index_list, distance_list) |
| 133 | +# passage_scores = {} |
| 134 | +# for emb in range(len(embeddings)): |
| 135 | +# indices = index_list[emb] # indices of neighbors for embeddings[emb] - this is an array of k*3 integers |
| 136 | +# distances = distance_list[ |
| 137 | +# emb |
| 138 | +# ] # distances of neighbors for embeddings[emb] - this is an array of k*3 floating point numbers |
| 139 | +# for res in range((k or self.k) * 3): |
| 140 | +# neighbor = indices[res] |
| 141 | +# distance = distances[res] |
| 142 | +# if neighbor in passage_scores: |
| 143 | +# passage_scores[neighbor].append(distance) |
| 144 | +# else: |
| 145 | +# passage_scores[neighbor] = [distance] |
| 146 | +# # Note re. sorting: |
| 147 | +# # first degree sort: number of queries that got a hit with any particular document chunk. More |
| 148 | +# # is a better match. This is len(queries)-len(x[1]) |
| 149 | +# # second degree sort: sum of the distances of each hit returned by faiss. Smaller distance is a better match |
| 150 | +# sorted_passages = sorted(passage_scores.items(), key=lambda x: (len(queries) - len(x[1]), sum(x[1])))[ |
| 151 | +# : k or self.k |
| 152 | +# ] |
| 153 | +# return [ |
| 154 | +# dotdict({"long_text": self._document_chunks[passage_index], "index": passage_index}) |
| 155 | +# for passage_index, _ in sorted_passages |
| 156 | +# ] |
0 commit comments