@@ -80,34 +80,34 @@ def create_table_if_not_exists(self) -> None:
8080 extend_existing = True ,
8181 )
8282 with self .engine .connect () as conn :
83- # Create the table
84- Base .metadata .create_all (conn )
85-
86- # Check if the index exists
87- index_name = f"{ self .collection_name } _embedding_idx"
88- index_query = text (
89- f"""
90- SELECT 1
91- FROM pg_indexes
92- WHERE indexname = '{ index_name } ';
93- """
94- )
95- result = conn .execute (index_query ).scalar ()
83+ with conn .begin ():
84+ # Create the table
85+ Base .metadata .create_all (conn )
9686
97- # Create the index if it doesn't exist
98- if not result :
99- index_statement = text (
87+ # Check if the index exists
88+ index_name = f" { self . collection_name } _embedding_idx"
89+ index_query = text (
10090 f"""
101- CREATE INDEX { index_name }
102- ON { self .collection_name } USING ann(embedding)
103- WITH (
104- "dim" = { self .embedding_dimension } ,
105- "hnsw_m" = 100
106- );
91+ SELECT 1
92+ FROM pg_indexes
93+ WHERE indexname = '{ index_name } ';
10794 """
10895 )
109- conn .execute (index_statement )
110- conn .commit ()
96+ result = conn .execute (index_query ).scalar ()
97+
98+ # Create the index if it doesn't exist
99+ if not result :
100+ index_statement = text (
101+ f"""
102+ CREATE INDEX { index_name }
103+ ON { self .collection_name } USING ann(embedding)
104+ WITH (
105+ "dim" = { self .embedding_dimension } ,
106+ "hnsw_m" = 100
107+ );
108+ """
109+ )
110+ conn .execute (index_statement )
111111
112112 def create_collection (self ) -> None :
113113 if self .pre_delete_collection :
@@ -118,8 +118,8 @@ def delete_collection(self) -> None:
118118 self .logger .debug ("Trying to delete collection" )
119119 drop_statement = text (f"DROP TABLE IF EXISTS { self .collection_name } ;" )
120120 with self .engine .connect () as conn :
121- conn .execute ( drop_statement )
122- conn .commit ( )
121+ with conn .begin ():
122+ conn .execute ( drop_statement )
123123
124124 def add_texts (
125125 self ,
@@ -160,30 +160,28 @@ def add_texts(
160160
161161 chunks_table_data = []
162162 with self .engine .connect () as conn :
163- for document , metadata , chunk_id , embedding in zip (
164- texts , metadatas , ids , embeddings
165- ):
166- chunks_table_data .append (
167- {
168- "id" : chunk_id ,
169- "embedding" : embedding ,
170- "document" : document ,
171- "metadata" : metadata ,
172- }
173- )
174-
175- # Execute the batch insert when the batch size is reached
176- if len (chunks_table_data ) == batch_size :
163+ with conn .begin ():
164+ for document , metadata , chunk_id , embedding in zip (
165+ texts , metadatas , ids , embeddings
166+ ):
167+ chunks_table_data .append (
168+ {
169+ "id" : chunk_id ,
170+ "embedding" : embedding ,
171+ "document" : document ,
172+ "metadata" : metadata ,
173+ }
174+ )
175+
176+ # Execute the batch insert when the batch size is reached
177+ if len (chunks_table_data ) == batch_size :
178+ conn .execute (insert (chunks_table ).values (chunks_table_data ))
179+ # Clear the chunks_table_data list for the next batch
180+ chunks_table_data .clear ()
181+
182+ # Insert any remaining records that didn't make up a full batch
183+ if chunks_table_data :
177184 conn .execute (insert (chunks_table ).values (chunks_table_data ))
178- # Clear the chunks_table_data list for the next batch
179- chunks_table_data .clear ()
180-
181- # Insert any remaining records that didn't make up a full batch
182- if chunks_table_data :
183- conn .execute (insert (chunks_table ).values (chunks_table_data ))
184-
185- # Commit the transaction only once after all records have been inserted
186- conn .commit ()
187185
188186 return ids
189187
@@ -333,9 +331,9 @@ def from_texts(
333331 ) -> AnalyticDB :
334332 """
335333 Return VectorStore initialized from texts and embeddings.
336- Postgres connection string is required
334+ Postgres Connection string is required
337335 Either pass it as a parameter
338- or set the PGVECTOR_CONNECTION_STRING environment variable.
336+ or set the PG_CONNECTION_STRING environment variable.
339337 """
340338
341339 connection_string = cls .get_connection_string (kwargs )
@@ -363,7 +361,7 @@ def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
363361 raise ValueError (
364362 "Postgres connection string is required"
365363 "Either pass it as a parameter"
366- "or set the PGVECTOR_CONNECTION_STRING environment variable."
364+ "or set the PG_CONNECTION_STRING environment variable."
367365 )
368366
369367 return connection_string
@@ -381,9 +379,9 @@ def from_documents(
381379 ) -> AnalyticDB :
382380 """
383381 Return VectorStore initialized from documents and embeddings.
384- Postgres connection string is required
382+ Postgres Connection string is required
385383 Either pass it as a parameter
386- or set the PGVECTOR_CONNECTION_STRING environment variable.
384+ or set the PG_CONNECTION_STRING environment variable.
387385 """
388386
389387 texts = [d .page_content for d in documents ]
0 commit comments