Commit ca278e1

committed
New RAG example
Signed-off-by: Ed Snible <[email protected]>
1 parent 4d051d2 commit ca278e1

File tree: 9 files changed, +269 −3 lines changed
.gitignore

Lines changed: 3 additions & 0 deletions
```diff
@@ -151,6 +151,9 @@ pdl-live/package-lock.json
 *_result.yaml
 *_trace.json
 
+# Demo files
+pdl-rag-demo.db
+
 # Built docs
 _site
 
```
examples/rag/README.md

Lines changed: 29 additions & 3 deletions
````diff
@@ -1,4 +1,30 @@
-This example requires you to install:
+This example uses [Ollama](../../tutorial/#using-ollama-models). Fetch the models used in this example with
+
+```bash
+ollama pull mxbai-embed-large
+ollama pull granite-code:8b
 ```
-pip install scikit-learn
-```
+
+This example requires you to install pypdf, langchain, langchain-community, and pymilvus.
+
+```bash
+pip install pypdf pymilvus langchain langchain-community
+```
+
+To run the demo, first load a PDF document into the vector database:
+
+```bash
+pdl examples/rag/pdf_index.pdl
+```
+
+After the data has loaded, the program prints "Success!"
+
+Next, query the vector database for relevant text and use that text in a query to an LLM:
+
+```bash
+pdl examples/rag/pdf_query.pdl
+```
+
+This PDL program computes a data structure containing all questions and answers. It is printed at the end.
+
+To clean up, run `rm pdl-rag-demo.db`.
````

examples/rag/pdf_index.pdl

Lines changed: 21 additions & 0 deletions
```yaml
# Load PDF document into vector database

description: Load document into vector database
text:
- include: rag_library1.pdl
- call: ${ pdf_parse }
  args:
    filename: "docs/assets/pdl_quick_reference.pdf"
    chunk_size: 400
    chunk_overlap: 100
  def: input_data
  contribute: []
- call: ${ rag_index }
  args:
    inp: ${ input_data }
    encoder_model: "ollama/mxbai-embed-large"
    embed_dimension: 1024
    database_name: "./pdl-rag-demo.db"
    collection_name: "pdl_rag_collection"
  contribute: []
- "Success!"
```
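The `chunk_size: 400` and `chunk_overlap: 100` arguments control how the PDF text is split before embedding. As a rough illustration only (langchain's `RecursiveCharacterTextSplitter` additionally prefers to break at separators such as newlines), a fixed-window splitter with overlap could look like this:

```python
def chunk_text(text: str, chunk_size: int, chunk_overlap: int) -> list[str]:
    # Successive chunks start (chunk_size - chunk_overlap) characters apart,
    # so each pair of neighboring chunks shares chunk_overlap characters.
    step = chunk_size - chunk_overlap
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]

text = "".join(chr(ord("A") + i % 26) for i in range(1000))
chunks = chunk_text(text, chunk_size=400, chunk_overlap=100)
print(len(chunks), len(chunks[0]))          # 4 400
print(chunks[0][-100:] == chunks[1][:100])  # True (the 100-character overlap)
```

The overlap keeps a sentence that straddles a chunk boundary fully visible in at least one chunk, at the cost of indexing some text twice.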

examples/rag/pdf_query.pdl

Lines changed: 52 additions & 0 deletions
```yaml
# Query vector database for relevant passages; use passages to query LLM.

defs:
  QUESTIONS:
    data: [
      "Does PDL have a contribute keyword?",
      "Is Brooklyn the capital of New York?"
    ]
  CONCLUSIONS:
    lang: python
    code: "result = {}"
text:
- include: rag_library1.pdl
- lastOf:
  - for:
      question: ${ QUESTIONS }
    repeat:
      array:
      # Define MATCHING_PASSAGES as the text retrieved from the vector DB
      - def: MATCHING_PASSAGES
        call: ${ rag_retrieve }
        args:
          # I am passing the client in implicitly. NOT WHAT I WANT
          inp: ${ question }
          encoder_model: "ollama/mxbai-embed-large"
          limit: 3
          collection_name: "pdl_rag_collection"
          database_name: "./pdl-rag-demo.db"
      # - lang: python
      #   code: |
      #     print(f"MATCHING_PASSAGES='{MATCHING_PASSAGES}'")
      #     result = None
      - model: ollama/granite-code:8b
        def: CONCLUSION
        input: >
          Here is some information:
          ${ MATCHING_PASSAGES }
          Question: ${ question }
          Answer:
        parameters:
          # I couldn't get this working
          stop_sequences: ['Yes', 'No']
          temperature: 0
      - lang: python
        code: |
          # split()[0] needed because of stop_sequences not working
          # print(f"CONCLUSION={CONCLUSION}")
          CONCLUSIONS[question] = CONCLUSION.split()[0]
          result = "dummy"
    contribute: []
  - text:
    - "${ CONCLUSIONS | tojson }\n"
```
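Because `stop_sequences` did not take effect, the Python step above keeps only the first whitespace-delimited token of the model's reply as the yes/no verdict. A minimal sketch of that workaround (note that it keeps any punctuation attached to the first word):

```python
def first_token(reply: str) -> str:
    # Equivalent to the CONCLUSION.split()[0] step in pdf_query.pdl:
    # keep only the first whitespace-delimited token of the model reply.
    return reply.split()[0]

print(first_token("Yes PDL has a contribute keyword"))  # Yes
print(first_token("No, Albany is the capital."))        # No,
```

If the model answers "Yes," with a trailing comma, the comma survives; stripping punctuation (e.g. `.rstrip(".,")`) would make the verdict cleaner but is not what the committed program does.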

examples/rag/rag.py

Lines changed: 118 additions & 0 deletions
```python
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from litellm import embedding
from pymilvus import MilvusClient
from pymilvus.exceptions import MilvusException


def parse(filename: str, chunk_size: int, chunk_overlap: int) -> list[str]:
    loader = PyPDFLoader(filename)

    docs = loader.load()
    # 'docs' will be a list[langchain_core.documents.base.Document],
    # one entry per page. We don't want to return this, because PDL only
    # wants types that work in JSON schemas.

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        is_separator_regex=False,
    )

    split_docs = text_splitter.split_documents(docs)

    # Note that this throws away the metadata.
    return [doc.page_content for doc in split_docs]


def rag_index(
    inp: list[str],
    encoder_model: str,
    embed_dimension: int,
    database_name: str,
    collection_name: str,
):
    # Have LiteLLM embed the passages
    response = embedding(
        model=encoder_model,
        input=inp,
    )

    client = MilvusClient(
        database_name
    )  # Use URL if talking to remote Milvus (non-Lite)

    if client.has_collection(collection_name=collection_name):
        client.drop_collection(collection_name=collection_name)
    client.create_collection(
        collection_name=collection_name, dimension=embed_dimension, overwrite=True
    )

    mid = 0  # There is also an auto-id feature in Milvus, which we are not using
    for text in inp:
        vector = response.data[mid]["embedding"]  # type: ignore
        client.insert(
            collection_name=collection_name,
            data=[
                {
                    "id": mid,
                    "text": text,
                    "vector": vector,
                    # We SHOULD set "source" and "url" based on the metadata we threw away in parse()
                }
            ],
        )
        mid = mid + 1

    return True


# Global cache of database clients.
# (We do this so the PDL programmer doesn't need to explicitly maintain the client connection)
DATABASE_CLIENTS: dict[str, MilvusClient] = {}


def get_or_create_client(database_name: str):
    if database_name in DATABASE_CLIENTS:
        return DATABASE_CLIENTS[database_name]

    client = MilvusClient(
        database_name
    )  # Use URL if talking to remote Milvus (non-Lite)
    DATABASE_CLIENTS[database_name] = client
    return client


# Search vector database collection for input.
# The output is 'limit' vectors, as strings, concatenated together
def rag_retrieve(
    inp: str, encoder_model: str, limit: int, database_name: str, collection_name: str
) -> str:
    # Embed the question as a vector
    try:
        response = embedding(
            model=encoder_model,
            input=[inp],
        )
    except MilvusException:  # This is usually an APIConnectionError
        # Retry because of https://github.com/BerriAI/litellm/issues/7667
        response = embedding(
            model=encoder_model,
            input=[inp],
        )

    data = response.data[0]["embedding"]

    milvus_client = get_or_create_client(database_name)
    search_res = milvus_client.search(
        collection_name=collection_name,
        data=[data],
        limit=limit,  # Return top n results
        search_params={"metric_type": "COSINE", "params": {}},
        output_fields=["text"],  # Return the text field
    )

    # Note that this throws away document metadata (if any)
    return "\n".join([res["entity"]["text"] for res in search_res[0]])
```
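The COSINE search that `rag_retrieve` delegates to Milvus can be mimicked in plain Python, which is handy for checking retrieval behavior without a database. This is an illustrative sketch with made-up three-dimensional vectors, not the Milvus implementation:

```python
import math

def cosine(a: list[float], b: list[float]) -> float:
    # Cosine similarity: dot product of the vectors over the product of their norms.
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def top_k(query: list[float], index: dict[str, list[float]], limit: int) -> list[str]:
    # Rank stored passages by cosine similarity to the query vector, best first.
    ranked = sorted(index.items(), key=lambda kv: cosine(query, kv[1]), reverse=True)
    return [text for text, _ in ranked[:limit]]

index = {
    "PDL has a contribute keyword.": [1.0, 0.0, 0.1],
    "Albany is the capital of New York.": [0.0, 1.0, 0.0],
    "PDL programs are YAML files.": [0.9, 0.1, 0.0],
}
print(top_k([1.0, 0.0, 0.0], index, limit=2))  # the two PDL passages, best match first
```

A real embedding model produces vectors with hundreds of dimensions (1024 for mxbai-embed-large), and Milvus uses an approximate index rather than this exhaustive scan, but the ranking idea is the same.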

examples/rag/rag_library1.pdl

Lines changed: 38 additions & 0 deletions
```yaml
# This module can be included from a PDL program to bring in Python functions.

description: RAG library for PDL
text:
- def: pdf_parse
  function:
    filename: str
    chunk_size: int
    chunk_overlap: int
  return:
    lang: python
    code: |
      from examples.rag.rag import parse
      result = parse(filename, chunk_size, chunk_overlap)
- def: rag_index
  function:
    inp: list  # This is a list[str], but PDL doesn't allow that type
    encoder_model: str
    embed_dimension: int
    database_name: str  # optional, could also be URL?
    collection_name: str
  return:
    lang: python
    code: |
      from examples.rag.rag import rag_index
      result = rag_index(inp, encoder_model, embed_dimension, database_name, collection_name)
- def: rag_retrieve
  function:
    inp: str
    encoder_model: str
    limit: int
    collection_name: str
    database_name: str  # optional, could also be URL?
  return:
    lang: python
    code: |
      from examples.rag.rag import rag_retrieve
      result = rag_retrieve(inp, encoder_model, limit, database_name, collection_name)
```
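Each `return:` block above is a Python snippet that the PDL interpreter runs with the function's arguments bound as variables, reading the variable `result` back as the function's value. Roughly, as a simplified model of that mechanism (the scope contents and code string here are made up for illustration):

```python
# Arguments the caller passed, bound into the snippet's scope by the interpreter.
scope = {"filename": "doc.pdf", "chunk_size": 400, "chunk_overlap": 100}
code = "result = f'{filename} split into {chunk_size}-char chunks'"

exec(code, scope)          # run the snippet with the arguments in scope
print(scope["result"])     # doc.pdf split into 400-char chunks
```

This is why each snippet must assign to `result`, and why the imports live inside the snippet: every call executes in a fresh scope.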

examples/tfidf_rag/README.md

Lines changed: 4 additions & 0 deletions
````
This example requires you to install:
```
pip install scikit-learn
```
````
File renamed without changes.

pyproject.toml

Lines changed: 4 additions & 0 deletions
```diff
@@ -39,6 +39,10 @@ dev = [
     "pydantic~=2.9"
 ]
 examples = [
+    "pymilvus~=2.5",
+    "langchain~=0.3",
+    "langchain-community~=0.3",
+    "pypdf~=5.2",
     "wikipedia~=1.0",
     "textdistance~=4.0",
     "datasets>2,<4",
```
