
Commit ec8247e

Fixed bug in AnalyticDB Vector Store caused by SQLAlchemy version upgrade (#6736)

1 parent: d84a3bc
File tree

1 file changed: +52 -54 lines

langchain/vectorstores/analyticdb.py

Lines changed: 52 additions & 54 deletions
@@ -80,34 +80,34 @@ def create_table_if_not_exists(self) -> None:
             extend_existing=True,
         )
         with self.engine.connect() as conn:
-            # Create the table
-            Base.metadata.create_all(conn)
-
-            # Check if the index exists
-            index_name = f"{self.collection_name}_embedding_idx"
-            index_query = text(
-                f"""
-                SELECT 1
-                FROM pg_indexes
-                WHERE indexname = '{index_name}';
-                """
-            )
-            result = conn.execute(index_query).scalar()
+            with conn.begin():
+                # Create the table
+                Base.metadata.create_all(conn)
 
-            # Create the index if it doesn't exist
-            if not result:
-                index_statement = text(
+                # Check if the index exists
+                index_name = f"{self.collection_name}_embedding_idx"
+                index_query = text(
                     f"""
-                    CREATE INDEX {index_name}
-                    ON {self.collection_name} USING ann(embedding)
-                    WITH (
-                        "dim" = {self.embedding_dimension},
-                        "hnsw_m" = 100
-                    );
+                    SELECT 1
+                    FROM pg_indexes
+                    WHERE indexname = '{index_name}';
                     """
                 )
-                conn.execute(index_statement)
-                conn.commit()
+                result = conn.execute(index_query).scalar()
+
+                # Create the index if it doesn't exist
+                if not result:
+                    index_statement = text(
+                        f"""
+                        CREATE INDEX {index_name}
+                        ON {self.collection_name} USING ann(embedding)
+                        WITH (
+                            "dim" = {self.embedding_dimension},
+                            "hnsw_m" = 100
+                        );
+                        """
+                    )
+                    conn.execute(index_statement)
 
     def create_collection(self) -> None:
         if self.pre_delete_collection:
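
The pattern behind the fix is the same in every hunk: each manual `conn.commit()` call is replaced by wrapping the work in a `conn.begin()` block. `Connection.commit()` belongs to SQLAlchemy's newer 2.0-style API, while the `begin()` context manager is supported across 1.x and 2.x and commits on a clean exit or rolls back on an exception. A minimal runnable sketch of that pattern, with SQLite standing in for the AnalyticDB/Postgres engine (an illustrative assumption, not part of the commit):

from sqlalchemy import create_engine, text

# SQLite stands in for the real AnalyticDB/Postgres engine so the sketch runs anywhere.
engine = create_engine("sqlite://")

with engine.connect() as conn:
    # begin() opens an explicit transaction that commits when the block exits
    # cleanly and rolls back if it raises -- no manual conn.commit() needed.
    with conn.begin():
        conn.execute(text("CREATE TABLE t (x INTEGER)"))
        conn.execute(text("INSERT INTO t VALUES (1)"))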
@@ -118,8 +118,8 @@ def delete_collection(self) -> None:
         self.logger.debug("Trying to delete collection")
         drop_statement = text(f"DROP TABLE IF EXISTS {self.collection_name};")
         with self.engine.connect() as conn:
-            conn.execute(drop_statement)
-            conn.commit()
+            with conn.begin():
+                conn.execute(drop_statement)
 
     def add_texts(
         self,
@@ -160,30 +160,28 @@ def add_texts(
 
         chunks_table_data = []
         with self.engine.connect() as conn:
-            for document, metadata, chunk_id, embedding in zip(
-                texts, metadatas, ids, embeddings
-            ):
-                chunks_table_data.append(
-                    {
-                        "id": chunk_id,
-                        "embedding": embedding,
-                        "document": document,
-                        "metadata": metadata,
-                    }
-                )
-
-                # Execute the batch insert when the batch size is reached
-                if len(chunks_table_data) == batch_size:
+            with conn.begin():
+                for document, metadata, chunk_id, embedding in zip(
+                    texts, metadatas, ids, embeddings
+                ):
+                    chunks_table_data.append(
+                        {
+                            "id": chunk_id,
+                            "embedding": embedding,
+                            "document": document,
+                            "metadata": metadata,
+                        }
+                    )
+
+                    # Execute the batch insert when the batch size is reached
+                    if len(chunks_table_data) == batch_size:
+                        conn.execute(insert(chunks_table).values(chunks_table_data))
+                        # Clear the chunks_table_data list for the next batch
+                        chunks_table_data.clear()
+
+                # Insert any remaining records that didn't make up a full batch
+                if chunks_table_data:
                     conn.execute(insert(chunks_table).values(chunks_table_data))
-                # Clear the chunks_table_data list for the next batch
-                chunks_table_data.clear()
-
-            # Insert any remaining records that didn't make up a full batch
-            if chunks_table_data:
-                conn.execute(insert(chunks_table).values(chunks_table_data))
-
-            # Commit the transaction only once after all records have been inserted
-            conn.commit()
 
         return ids

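The same idea applies in `add_texts`: the batching loop moves inside a single `conn.begin()` block, so every batch insert commits (or rolls back) together and the trailing `conn.commit()` disappears. A self-contained sketch of that batch-flush pattern, using a throwaway SQLite table and illustrative column names rather than the real chunks schema:

from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert

engine = create_engine("sqlite://")
metadata = MetaData()
chunks = Table(
    "chunks",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("document", String),
)
metadata.create_all(engine)

batch_size = 2
rows = [{"id": i, "document": f"doc-{i}"} for i in range(5)]
buffer = []
with engine.connect() as conn:
    with conn.begin():  # one transaction covering every batch
        for row in rows:
            buffer.append(row)
            if len(buffer) == batch_size:  # flush a full batch
                conn.execute(insert(chunks).values(buffer))
                buffer.clear()
        if buffer:  # flush the partial final batch
            conn.execute(insert(chunks).values(buffer))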
@@ -333,9 +331,9 @@ def from_texts(
     ) -> AnalyticDB:
         """
         Return VectorStore initialized from texts and embeddings.
-        Postgres connection string is required
+        Postgres Connection string is required
         Either pass it as a parameter
-        or set the PGVECTOR_CONNECTION_STRING environment variable.
+        or set the PG_CONNECTION_STRING environment variable.
         """
 
         connection_string = cls.get_connection_string(kwargs)
@@ -363,7 +361,7 @@ def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
             raise ValueError(
                 "Postgres connection string is required"
                 "Either pass it as a parameter"
-                "or set the PGVECTOR_CONNECTION_STRING environment variable."
+                "or set the PG_CONNECTION_STRING environment variable."
             )
 
         return connection_string
@@ -381,9 +379,9 @@ def from_documents(
     ) -> AnalyticDB:
         """
         Return VectorStore initialized from documents and embeddings.
-        Postgres connection string is required
+        Postgres Connection string is required
         Either pass it as a parameter
-        or set the PGVECTOR_CONNECTION_STRING environment variable.
+        or set the PG_CONNECTION_STRING environment variable.
         """
 
         texts = [d.page_content for d in documents]
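
The docstring and error-message hunks make the class point consistently at the PG_CONNECTION_STRING environment variable. A hedged usage sketch of the two ways to supply the DSN; the `connection_string` keyword and `FakeEmbeddings` stand-in are assumptions based on the surrounding langchain API, not shown in this diff, and the DSN values are placeholders:

import os
from langchain.embeddings import FakeEmbeddings  # stand-in Embeddings implementation
from langchain.vectorstores import AnalyticDB

# Option 1: set the environment variable the updated docstrings reference.
os.environ["PG_CONNECTION_STRING"] = "postgresql+psycopg2://user:pass@host:5432/db"  # placeholder DSN

# Option 2: pass the DSN explicitly (kwarg name assumed from get_connection_string).
store = AnalyticDB.from_texts(
    texts=["hello world"],
    embedding=FakeEmbeddings(size=1536),
    connection_string="postgresql+psycopg2://user:pass@host:5432/db",  # placeholder DSN
)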
