diff --git a/.github/workflows/release-publish.yml b/.github/workflows/release-publish.yml deleted file mode 100644 index 9fcae12..0000000 --- a/.github/workflows/release-publish.yml +++ /dev/null @@ -1,75 +0,0 @@ -name: publish-and-release - -on: - create: - tags: - - 'v*' - push: - branches: [ "master", "dev" ] - -jobs: - publish-to-pypi: - runs-on: ubuntu-latest - steps: - - - uses: actions/checkout@master - - - name: Set up Python 3.10 - uses: actions/setup-python@v1 - with: - python-version: '3.10.13' - - - name: Install build - run: >- - pip install -r requirements.txt && - python setup.py sdist - - - name: Publish distribution 📦 to Test PyPI - if: startsWith(github.ref, 'refs/dev') - uses: pypa/gh-action-pypi-publish@master - with: - password: ${{ secrets.TEST_PYPI_API_TOKEN }} - repository_url: https://test.pypi.org/legacy/ - skip_existing: true - - - name: Publish distribution 📦 to PyPI - if: startsWith(github.ref, 'refs/tags/') - uses: pypa/gh-action-pypi-publish@master - with: - password: ${{ secrets.PYPI_API_TOKEN }} - - github-release: - name: github-release - needs: - - publish-to-pypi - runs-on: ubuntu-latest - permissions: - contents: write - id-token: write - steps: - - name: Download all the dists - uses: actions/download-artifact@v3 - with: - name: python-package-distributions - path: dist/ - - name: Sign the dists with Sigstore - uses: sigstore/gh-action-sigstore-python@v1.2.3 - with: - inputs: >- - ./dist/*.tar.gz - ./dist/*.whl - - name: Create GitHub Release - env: - GITHUB_TOKEN: ${{ github.token }} - run: >- - gh release create - '${{ github.ref_name }}' - --repo '${{ github.repository }}' - --notes "" - - name: Upload artifact signatures to GitHub Release - env: - GITHUB_TOKEN: ${{ github.token }} - run: >- - gh release upload - '${{ github.ref_name }}' dist/** - --repo '${{ github.repository }}' \ No newline at end of file diff --git a/.gitignore b/.gitignore index fd2b633..6af4b0d 100644 --- a/.gitignore +++ b/.gitignore @@ -67,3 +67,9 @@ target/ .DS_Store .venv + +# test +test/.pytest_cache/ +test/log/ + +**/allure-report diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..95f48ae --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,27 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + + - repo: https://github.com/psf/black + rev: 23.12.1 + hooks: + - id: black + args: ["-l", "120"] + language_version: python3 + + - repo: https://github.com/myint/autoflake + rev: v2.2.1 + hooks: + - id: autoflake + args: ["--in-place", "--remove-all-unused-imports"] + exclude: "__init__.py$" + + - repo: https://github.com/pre-commit/mirrors-prettier + rev: "v4.0.0-alpha.8" # Use the ref you want to point at + hooks: + - id: prettier + args: ["--write"] diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..6356387 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +prune test diff --git a/README.md b/README.md index d00ba7f..2ccd825 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ # TaskingAI-client + The TaskingAI Python client for creating and managing AI-driven applications. For more information, see the docs at [TaskingAI Documentation](https://docs.tasking.ai/) @@ -6,6 +7,7 @@ For more information, see the docs at [TaskingAI Documentation](https://docs.tas ## Installation Install the latest released version using pip: + ```shell pip3 install taskingai ``` @@ -15,10 +17,13 @@ pip3 install taskingai Here's how you can quickly start building and managing AI-driven applications using the TaskingAI client. ### Assistants + Explore the ease of creating and customizing your own AI assistants with TaskingAI to enhance user interactions. + ```python import taskingai from taskingai.assistant import * +from taskingai.assistant.memory import AssistantNaiveMemory # Initialize your API key if you haven't already set it in the environment taskingai.init(api_key="YOUR_API_KEY") @@ -26,8 +31,7 @@ taskingai.init(api_key="YOUR_API_KEY") # Create an assistant assistant = create_assistant( model_id="YOUR_MODEL_ID", - name="My Assistant", - description="An assistant that understands numbers.", + memory=AssistantNaiveMemory(), system_prompt_template=["You are a professional assistant."], ) print(f"Assistant created: {assistant.id}") @@ -49,7 +53,9 @@ print("Assistant deleted successfully.") ``` ### Retrieval + Leverage TaskingAI's retrieval capabilities to store, manage, and extract information, making your applications smarter and more responsive. + ```python import taskingai from taskingai.retrieval import * @@ -62,9 +68,10 @@ collection = create_collection( print(f"Collection created: {collection.id}") # Add a record to the collection -record = create_text_record( +record = create_record( collection_id=collection.id, - text="Example text for machine learning." + content="Example text for machine learning.", + text_splitter=TokenTextSplitter(chunk_size=200, chunk_overlap=20), ) print(f"Record added to collection: {record.id}") @@ -88,7 +95,9 @@ print("Collection deleted.") ``` ### Tools + Utilize TaskingAI's tools to create actions that enable your assistant to interact with external APIs and services, enriching the user experience. + ```python import taskingai from taskingai.tool import * @@ -121,8 +130,3 @@ print("Action deleted.") ## Contributing We welcome contributions of all kinds. Please read our [Contributing Guidelines](./CONTRIBUTING.md) for more information on how to get started. - -## Security - -For any security concerns or issues, please reach out to us directly at support@tasking.ai. - diff --git a/examples/assistant/chat_with_assistant.ipynb b/examples/assistant/chat_with_assistant.ipynb index 13d9335..ba4e54b 100644 --- a/examples/assistant/chat_with_assistant.ipynb +++ b/examples/assistant/chat_with_assistant.ipynb @@ -91,7 +91,7 @@ " }\n", "}\n", "actions: List[Action] = taskingai.tool.bulk_create_actions(\n", - " schema=NUMBERS_API_SCHEMA,\n", + " openapi_schema=NUMBERS_API_SCHEMA,\n", " authentication=ActionAuthentication(\n", " type=ActionAuthenticationType.NONE,\n", " )\n", @@ -250,6 +250,7 @@ "messages = taskingai.assistant.list_messages(\n", " assistant_id=assistant.assistant_id,\n", " chat_id=chat.chat_id,\n", + " order=\"asc\"\n", ")\n", "for message in messages:\n", " print(f\"{message.role}: {message.content.text}\")" @@ -257,7 +258,7 @@ "metadata": { "collapsed": false }, - "id": "34bae46ba56752bb" + "id": "e94e3adb0d15373b" }, { "cell_type": "code", @@ -278,19 +279,11 @@ "cell_type": "code", "execution_count": null, "outputs": [], - "source": [ - "# list messages\n", - "messages = taskingai.assistant.list_messages(\n", - " assistant_id=assistant.assistant_id,\n", - " chat_id=chat.chat_id,\n", - ")\n", - "for message in messages:\n", - " print(f\"{message.role}: {message.content.text}\")" - ], + "source": [], "metadata": { "collapsed": false }, - "id": "e94e3adb0d15373b" + "id": "9cfed1128acbdd30" } ], "metadata": { diff --git a/examples/crud/retrieval_crud.ipynb b/examples/crud/retrieval_crud.ipynb index 1762313..68450cd 100644 --- a/examples/crud/retrieval_crud.ipynb +++ b/examples/crud/retrieval_crud.ipynb @@ -10,7 +10,6 @@ "outputs": [], "source": [ "import taskingai\n", - "import time\n", "# Load TaskingAI API Key from environment variable" ] }, @@ -29,7 +28,7 @@ "execution_count": null, "outputs": [], "source": [ - "from taskingai.retrieval import Collection, Record\n", + "from taskingai.retrieval import Collection, Record, TokenTextSplitter\n", "\n", "# choose an available text_embedding model from your project\n", "embedding_model_id = \"YOUR_MODEL_ID\"" @@ -49,6 +48,19 @@ }, "id": "a6874f1ff8ec5a9c" }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "collections = taskingai.retrieval.list_collections()\n", + "print(collections)" + ], + "metadata": { + "collapsed": false + }, + "id": "81ec82280d5c8c64" + }, { "cell_type": "code", "execution_count": null, @@ -146,14 +158,6 @@ }, "id": "1b7688a3cf40c241" }, - { - "cell_type": "markdown", - "source": [], - "metadata": { - "collapsed": false - }, - "id": "dbc3aafe16758b4c" - }, { "cell_type": "code", "execution_count": null, @@ -161,10 +165,7 @@ "source": [ "# create a new collection\n", "collection: Collection = create_collection()\n", - "print(collection)\n", - "\n", - "# wait for the collection creation to finish\n", - "time.sleep(3)" + "print(collection)" ], "metadata": { "collapsed": false @@ -177,9 +178,10 @@ "outputs": [], "source": [ "# create a new text record\n", - "record: Record = taskingai.retrieval.create_text_record(\n", + "record: Record = taskingai.retrieval.create_record(\n", " collection_id=collection.collection_id,\n", - " text=\"Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data. The term \\\"machine learning\\\" was coined by Arthur Samuel in 1959. In other words, machine learning enables a system to automatically learn and improve from experience without being explicitly programmed. This is achieved by feeding the system massive amounts of data, which it uses to learn patterns and make inferences. There are three main types of machine learning: 1. Supervised Learning: This is where the model is given labeled training data and the goal of learning is to generalize from the training data to unseen situations in a principled way. 2. Unsupervised Learning: This involves training on a dataset without explicit labels. The goal might be to discover inherent groupings or patterns within the data. 3. Reinforcement Learning: In this type, an agent learns to perform actions based on reward/penalty feedback to achieve a goal. It's commonly used in robotics, gaming, and navigation. Deep learning, a subset of machine learning, uses neural networks with many layers (\\\"deep\\\" structures) and has been responsible for many recent breakthroughs in AI, including speech recognition, image recognition, and natural language processing. It's important to note that machine learning is a rapidly developing field, with new techniques and applications emerging regularly.\"\n", + " content=\"Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data. The term \\\"machine learning\\\" was coined by Arthur Samuel in 1959. In other words, machine learning enables a system to automatically learn and improve from experience without being explicitly programmed. This is achieved by feeding the system massive amounts of data, which it uses to learn patterns and make inferences. There are three main types of machine learning: 1. Supervised Learning: This is where the model is given labeled training data and the goal of learning is to generalize from the training data to unseen situations in a principled way. 2. Unsupervised Learning: This involves training on a dataset without explicit labels. The goal might be to discover inherent groupings or patterns within the data. 3. Reinforcement Learning: In this type, an agent learns to perform actions based on reward/penalty feedback to achieve a goal. It's commonly used in robotics, gaming, and navigation. Deep learning, a subset of machine learning, uses neural networks with many layers (\\\"deep\\\" structures) and has been responsible for many recent breakthroughs in AI, including speech recognition, image recognition, and natural language processing. It's important to note that machine learning is a rapidly developing field, with new techniques and applications emerging regularly.\",\n", + " text_splitter=TokenTextSplitter(chunk_size=200, chunk_overlap=20)\n", ")\n", "print(f\"created record: {record.record_id} for collection: {collection.collection_id}\\n\")" ], diff --git a/examples/crud/tool_crud.ipynb b/examples/crud/tool_crud.ipynb index f865ff2..865b426 100644 --- a/examples/crud/tool_crud.ipynb +++ b/examples/crud/tool_crud.ipynb @@ -99,7 +99,7 @@ " }\n", "}\n", "actions: List[Action] = taskingai.tool.bulk_create_actions(\n", - " schema=NUMBERS_API_SCHEMA,\n", + " openapi_schema=NUMBERS_API_SCHEMA,\n", " authentication=ActionAuthentication(\n", " type=ActionAuthenticationType.NONE,\n", " )\n", @@ -139,7 +139,7 @@ "NUMBERS_API_SCHEMA[\"paths\"][\"/{number}\"][\"get\"][\"summary\"] = \"Get fun fact about a number)\"\n", "action: Action = taskingai.tool.update_action(\n", " action_id=action_id,\n", - " schema=NUMBERS_API_SCHEMA\n", + " openapi_schema=NUMBERS_API_SCHEMA\n", ")\n", "\n", "print(f\"updated action: {action}\\n\")" @@ -197,14 +197,6 @@ "collapsed": false }, "id": "5a1a36d15055918f" - }, - { - "cell_type": "markdown", - "source": [], - "metadata": { - "collapsed": false - }, - "id": "b1736bf2e80c2dd6" } ], "metadata": { diff --git a/examples/inference/chat_completion.ipynb b/examples/inference/chat_completion.ipynb index 35ef316..4253aa5 100644 --- a/examples/inference/chat_completion.ipynb +++ b/examples/inference/chat_completion.ipynb @@ -56,8 +56,8 @@ "chat_completion_result = taskingai.inference.chat_completion(\n", " model_id=model_id,\n", " messages=[\n", - " SystemMessage(\"You are a professional assistant.\"),\n", - " UserMessage(\"Hi\"),\n", + " SystemMessage(\"You are an assistant specialized in productivity and time management strategies.\"),\n", + " UserMessage(\"I'm struggling with managing my time effectively. Do you have any tips?\"),\n", " ]\n", ")\n", "chat_completion_result" @@ -82,15 +82,16 @@ "execution_count": null, "outputs": [], "source": [ + "# Multiple Round Interaction with Specific System Prompt\n", "chat_completion_result = taskingai.inference.chat_completion(\n", " model_id=model_id,\n", " messages=[\n", - " SystemMessage(\"You are a professional assistant.\"),\n", - " UserMessage(\"Hi\"),\n", - " AssistantMessage(\"Hello! How can I assist you today?\"),\n", - " UserMessage(\"Can you tell me a joke?\"),\n", - " AssistantMessage(\"Sure, here is a joke for you: Why don't scientists trust atoms? Because they make up everything!\"),\n", - " UserMessage(\"That's funny. Can you tell me another one?\"),\n", + " SystemMessage(\"You are an assistant with extensive knowledge in literature and book recommendations.\"),\n", + " UserMessage(\"Hello, I'm looking for book recommendations.\"),\n", + " AssistantMessage(\"Certainly! What genre are you interested in?\"),\n", + " UserMessage(\"I love science fiction.\"),\n", + " AssistantMessage(\"Great choice! I recommend 'Dune' by Frank Herbert for its rich world-building and 'Neuromancer' by William Gibson for its cyberpunk influence.\"),\n", + " UserMessage(\"Thanks! Can you also suggest something in non-fiction?\"),\n", " ]\n", ")\n", "chat_completion_result" @@ -109,12 +110,12 @@ "chat_completion_result = taskingai.inference.chat_completion(\n", " model_id=model_id,\n", " messages=[\n", - " SystemMessage(\"You are a professional assistant.\"),\n", - " UserMessage(\"Hi\"),\n", - " AssistantMessage(\"Hello! How can I assist you today?\"),\n", - " UserMessage(\"Can you tell me a joke?\"),\n", - " AssistantMessage(\"Sure, here is a joke for you: Why don't scientists trust atoms? Because they make up everything!\"),\n", - " UserMessage(\"That's funny. Can you tell me another one?\"),\n", + " SystemMessage(\"You are an assistant with extensive knowledge in literature and book recommendations.\"),\n", + " UserMessage(\"Hello, I'm looking for book recommendations.\"),\n", + " AssistantMessage(\"Certainly! What genre are you interested in?\"),\n", + " UserMessage(\"I love science fiction.\"),\n", + " AssistantMessage(\"Great choice! I recommend 'Dune' by Frank Herbert for its rich world-building and 'Neuromancer' by William Gibson for its cyberpunk influence.\"),\n", + " UserMessage(\"Thanks! Can you also suggest something in non-fiction?\"),\n", " ],\n", " configs={\n", " \"max_tokens\": 5\n", @@ -198,6 +199,7 @@ " return a + b\n", "\n", "arguments = function_call_assistant_message.function_calls[0].arguments\n", + "function_id = function_call_assistant_message.function_calls[0].id\n", "function_call_result = plus_a_and_b(**arguments)\n", "print(f\"function_call_result = {function_call_result}\")" ], @@ -217,7 +219,7 @@ " messages=[\n", " UserMessage(\"What is the result of 112 plus 22?\"),\n", " function_call_assistant_message,\n", - " FunctionMessage(id=\"FUNCTION_ID\", content=str(function_call_result))\n", + " FunctionMessage(id=function_id, content=str(function_call_result))\n", " ],\n", " functions=[function]\n", ")\n", @@ -263,16 +265,6 @@ "collapsed": false }, "id": "4f3290f87635152a" - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [], - "metadata": { - "collapsed": false - }, - "id": "e98109211cacfbb1" } ], "metadata": { diff --git a/examples/inference/text_embedding.ipynb b/examples/inference/text_embedding.ipynb index fbea8a1..1943800 100644 --- a/examples/inference/text_embedding.ipynb +++ b/examples/inference/text_embedding.ipynb @@ -28,8 +28,6 @@ "execution_count": null, "outputs": [], "source": [ - "from taskingai.inference import *\n", - "import json\n", "# choose an available text embedding model from your project\n", "model_id = \"YOUR_MODEL_ID\"" ], diff --git a/examples/retrieval/semantic_search.ipynb b/examples/retrieval/semantic_search.ipynb index 4f33607..3a98a3c 100644 --- a/examples/retrieval/semantic_search.ipynb +++ b/examples/retrieval/semantic_search.ipynb @@ -11,7 +11,8 @@ "source": [ "import taskingai\n", "# Load TaskingAI API Key from environment variable\n", - "from taskingai.retrieval import Collection, Record, CollectionConfig" + "from taskingai.retrieval import Collection\n", + "from taskingai.retrieval.text_splitter import TokenTextSplitter" ] }, { @@ -57,10 +58,6 @@ " collection: Collection = taskingai.retrieval.create_collection(\n", " embedding_model_id=embedding_model_id,\n", " capacity=1000, # maximum text chunks can be stored\n", - " configs=CollectionConfig(\n", - " chunk_size=100, # maximum tokens of each chunk\n", - " chunk_overlap=0, # token overlap between chunks\n", - " )\n", " )\n", " return collection\n", "\n", @@ -72,15 +69,34 @@ }, "id": "7c7d4e2cc2f2f494" }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# Check collection status. \n", + "# Only when status is \"READY\" can you insert records and query chunks.\n", + "collection = taskingai.retrieval.get_collection(collection_id=collection.collection_id)\n", + "print(f\"collection status: {collection.status}\")" + ], + "metadata": { + "collapsed": false + }, + "id": "eb5dee18aa83c5e4" + }, { "cell_type": "code", "execution_count": null, "outputs": [], "source": [ "# create record 1 (machine learning)\n", - "taskingai.retrieval.create_text_record(\n", + "taskingai.retrieval.create_record(\n", " collection_id=collection.collection_id,\n", - " text=\"Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data. The term \\\"machine learning\\\" was coined by Arthur Samuel in 1959. In other words, machine learning enables a system to automatically learn and improve from experience without being explicitly programmed. This is achieved by feeding the system massive amounts of data, which it uses to learn patterns and make inferences. There are three main types of machine learning: 1. Supervised Learning: This is where the model is given labeled training data and the goal of learning is to generalize from the training data to unseen situations in a principled way. 2. Unsupervised Learning: This involves training on a dataset without explicit labels. The goal might be to discover inherent groupings or patterns within the data. 3. Reinforcement Learning: In this type, an agent learns to perform actions based on reward/penalty feedback to achieve a goal. It's commonly used in robotics, gaming, and navigation. Deep learning, a subset of machine learning, uses neural networks with many layers (\\\"deep\\\" structures) and has been responsible for many recent breakthroughs in AI, including speech recognition, image recognition, and natural language processing. It's important to note that machine learning is a rapidly developing field, with new techniques and applications emerging regularly.\"\n", + " content=\"Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data. The term \\\"machine learning\\\" was coined by Arthur Samuel in 1959. In other words, machine learning enables a system to automatically learn and improve from experience without being explicitly programmed. This is achieved by feeding the system massive amounts of data, which it uses to learn patterns and make inferences. There are three main types of machine learning: 1. Supervised Learning: This is where the model is given labeled training data and the goal of learning is to generalize from the training data to unseen situations in a principled way. 2. Unsupervised Learning: This involves training on a dataset without explicit labels. The goal might be to discover inherent groupings or patterns within the data. 3. Reinforcement Learning: In this type, an agent learns to perform actions based on reward/penalty feedback to achieve a goal. It's commonly used in robotics, gaming, and navigation. Deep learning, a subset of machine learning, uses neural networks with many layers (\\\"deep\\\" structures) and has been responsible for many recent breakthroughs in AI, including speech recognition, image recognition, and natural language processing. It's important to note that machine learning is a rapidly developing field, with new techniques and applications emerging regularly.\",\n", + " text_splitter=TokenTextSplitter(\n", + " chunk_size=100, # maximum tokens of each chunk\n", + " chunk_overlap=10, # token overlap between chunks\n", + " ),\n", ")" ], "metadata": { @@ -94,9 +110,13 @@ "outputs": [], "source": [ "# create record 2 (Michael Jordan)\n", - "taskingai.retrieval.create_text_record(\n", + "taskingai.retrieval.create_record(\n", " collection_id=collection.collection_id,\n", - " text=\"Michael Jordan, often referred to by his initials MJ, is considered one of the greatest players in the history of the National Basketball Association (NBA). He was known for his scoring ability, defensive prowess, competitiveness, and clutch performances. Born on February 17, 1963, Jordan played 15 seasons in the NBA, primarily with the Chicago Bulls, but also with the Washington Wizards. His professional career spanned two decades from 1984 to 2003, during which he won numerous awards and set multiple records. Here are some key highlights of his career: - Scoring: Jordan won the NBA scoring title a record 10 times. He also has the highest career scoring average in NBA history, both in the regular season (30.12 points per game) and in the playoffs (33.45 points per game). - Championships: He led the Chicago Bulls to six NBA championships and was named Finals MVP in all six of those Finals (1991-1993, 1996-1998). - MVP Awards: Jordan was named the NBA's Most Valuable Player (MVP) five times (1988, 1991, 1992, 1996, 1998). - Defensive Ability: He was named to the NBA All-Defensive First Team nine times and won the NBA Defensive Player of the Year award in 1988. - Olympics: Jordan also won two Olympic gold medals with the U.S. basketball team, in 1984 and 1992. - Retirements and Comebacks: Jordan retired twice during his career. His first retirement came in 1993, after which he briefly played minor league baseball. He returned to the NBA in 1995. He retired a second time in 1999, only to return again in 2001, this time with the Washington Wizards. He played two seasons for the Wizards before retiring for good in 2003. After his playing career, Jordan became a team owner and executive. As of my knowledge cutoff in September 2021, he is the majority owner of the Charlotte Hornets. Off the court, Jordan is known for his lucrative endorsement deals, particularly with Nike. The Air Jordan line of sneakers is one of the most popular and enduring in the world. His influence also extends to the realms of film and fashion, and he is recognized globally as a cultural icon. In 2000, he was inducted into the Basketball Hall of Fame.\"\n", + " content=\"Michael Jordan, often referred to by his initials MJ, is considered one of the greatest players in the history of the National Basketball Association (NBA). He was known for his scoring ability, defensive prowess, competitiveness, and clutch performances. Born on February 17, 1963, Jordan played 15 seasons in the NBA, primarily with the Chicago Bulls, but also with the Washington Wizards. His professional career spanned two decades from 1984 to 2003, during which he won numerous awards and set multiple records. Here are some key highlights of his career: - Scoring: Jordan won the NBA scoring title a record 10 times. He also has the highest career scoring average in NBA history, both in the regular season (30.12 points per game) and in the playoffs (33.45 points per game). - Championships: He led the Chicago Bulls to six NBA championships and was named Finals MVP in all six of those Finals (1991-1993, 1996-1998). - MVP Awards: Jordan was named the NBA's Most Valuable Player (MVP) five times (1988, 1991, 1992, 1996, 1998). - Defensive Ability: He was named to the NBA All-Defensive First Team nine times and won the NBA Defensive Player of the Year award in 1988. - Olympics: Jordan also won two Olympic gold medals with the U.S. basketball team, in 1984 and 1992. - Retirements and Comebacks: Jordan retired twice during his career. His first retirement came in 1993, after which he briefly played minor league baseball. He returned to the NBA in 1995. He retired a second time in 1999, only to return again in 2001, this time with the Washington Wizards. He played two seasons for the Wizards before retiring for good in 2003. After his playing career, Jordan became a team owner and executive. As of my knowledge cutoff in September 2021, he is the majority owner of the Charlotte Hornets. Off the court, Jordan is known for his lucrative endorsement deals, particularly with Nike. The Air Jordan line of sneakers is one of the most popular and enduring in the world. His influence also extends to the realms of film and fashion, and he is recognized globally as a cultural icon. In 2000, he was inducted into the Basketball Hall of Fame.\",\n", + " text_splitter=TokenTextSplitter(\n", + " chunk_size=100,\n", + " chunk_overlap=10,\n", + " ),\n", ")" ], "metadata": { @@ -110,9 +130,13 @@ "outputs": [], "source": [ "# create record 3 (Granite)\n", - "taskingai.retrieval.create_text_record(\n", + "taskingai.retrieval.create_record(\n", " collection_id=collection.collection_id,\n", - " text=\"Granite is a type of coarse-grained igneous rock composed primarily of quartz and feldspar, among other minerals. The term \\\"granitic\\\" means granite-like and is applied to granite and a group of intrusive igneous rocks. Description of Granite * Type: Igneous rock * Grain size: Coarse-grained * Composition: Mainly quartz, feldspar, and micas with minor amounts of amphibole minerals * Color: Typically appears in shades of white, pink, or gray, depending on their mineralogy * Crystalline Structure: Yes, due to slow cooling of magma beneath Earth's surface * Density: Approximately 2.63 to 2.75 g/cm³ * Hardness: 6-7 on the Mohs hardness scale Formation Process Granite is formed from the slow cooling of magma that is rich in silica and aluminum, deep beneath the earth's surface. Over time, the magma cools slowly, allowing large crystals to form and resulting in the coarse-grained texture that is characteristic of granite. Uses Granite is known for its durability and aesthetic appeal, making it a popular choice for construction and architectural applications. It's often used for countertops, flooring, monuments, and building materials. In addition, due to its hardness and toughness, it is used for cobblestones and in other paving applications. Geographical Distribution Granite is found worldwide, with significant deposits in regions such as the United States (especially in New Hampshire, which is also known as \\\"The Granite State\\\"), Canada, Brazil, Norway, India, and China. Varieties There are many varieties of granite, based on differences in color and mineral composition. Some examples include Bianco Romano, Black Galaxy, Blue Pearl, Santa Cecilia, and Ubatuba. Each variety has unique patterns, colors, and mineral compositions.\"\n", + " content=\"Granite is a type of coarse-grained igneous rock composed primarily of quartz and feldspar, among other minerals. The term \\\"granitic\\\" means granite-like and is applied to granite and a group of intrusive igneous rocks. Description of Granite * Type: Igneous rock * Grain size: Coarse-grained * Composition: Mainly quartz, feldspar, and micas with minor amounts of amphibole minerals * Color: Typically appears in shades of white, pink, or gray, depending on their mineralogy * Crystalline Structure: Yes, due to slow cooling of magma beneath Earth's surface * Density: Approximately 2.63 to 2.75 g/cm³ * Hardness: 6-7 on the Mohs hardness scale Formation Process Granite is formed from the slow cooling of magma that is rich in silica and aluminum, deep beneath the earth's surface. Over time, the magma cools slowly, allowing large crystals to form and resulting in the coarse-grained texture that is characteristic of granite. Uses Granite is known for its durability and aesthetic appeal, making it a popular choice for construction and architectural applications. It's often used for countertops, flooring, monuments, and building materials. In addition, due to its hardness and toughness, it is used for cobblestones and in other paving applications. Geographical Distribution Granite is found worldwide, with significant deposits in regions such as the United States (especially in New Hampshire, which is also known as \\\"The Granite State\\\"), Canada, Brazil, Norway, India, and China. Varieties There are many varieties of granite, based on differences in color and mineral composition. Some examples include Bianco Romano, Black Galaxy, Blue Pearl, Santa Cecilia, and Ubatuba. Each variety has unique patterns, colors, and mineral compositions.\",\n", + " text_splitter=TokenTextSplitter(\n", + " chunk_size=100,\n", + " chunk_overlap=10,\n", + " ),\n", ")" ], "metadata": { @@ -130,21 +154,6 @@ }, "id": "7538cb91a6439106" }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "# Check collection status. \n", - "# Only when status is \"READY\" can you query chunks.\n", - "collection = taskingai.retrieval.get_collection(collection_id=collection.collection_id)\n", - "print(f\"collection status: {collection.status}\")" - ], - "metadata": { - "collapsed": false - }, - "id": "eb5dee18aa83c5e4" - }, { "cell_type": "code", "execution_count": null, @@ -154,7 +163,7 @@ "# Only when status is \"READY\", the record chunks can appear in query results.\n", "records = taskingai.retrieval.list_records(collection_id=collection.collection_id)\n", "for record in records:\n", - " content = record.content[\"text\"][:20]\n", + " content = record.content[:20]\n", " print(f\"record {record.record_id} content ({content}...) status: {record.status}\")" ], "metadata": { @@ -173,7 +182,7 @@ " query_text=\"Basketball\",\n", " top_k=2\n", ")\n", - "chunks" + "print(chunks)" ], "metadata": { "collapsed": false @@ -191,7 +200,7 @@ " query_text=\"geology\",\n", " top_k=2\n", ")\n", - "chunks" + "print(chunks)" ], "metadata": { "collapsed": false @@ -209,7 +218,7 @@ " query_text=\"what is machine learning\",\n", " top_k=2\n", ")\n", - "chunks" + "print(chunks)" ], "metadata": { "collapsed": false diff --git a/setup.py b/setup.py index 3efae3c..8c47448 100644 --- a/setup.py +++ b/setup.py @@ -1,22 +1,8 @@ -# coding: utf-8 - -""" - TaskingAI - - OpenAPI spec version: 0.1.0 -""" - -from setuptools import setup, find_packages # noqa: H301 +from setuptools import setup, find_packages import taskingai NAME = "taskingai" VERSION = taskingai.__version__ -# To install the library, run the following -# -# python setup.py install -# -# prerequisite: setuptools -# http://pypi.python.org/pypi/setuptools REQUIRES = [ "certifi>=14.05.14", @@ -27,17 +13,19 @@ "pydantic>=2.5.0", ] +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + setup( name=NAME, version=VERSION, description="TaskingAI", author_email="support@tasking.ai", - url="http://www.tasking.ai", + url="https://www.tasking.ai", keywords=["TaskingAI", "LLM", "AI"], install_requires=REQUIRES, - packages=find_packages(), + packages=find_packages(exclude=["test", "test.*"]), include_package_data=True, - long_description="""\ - No description provided - """ + long_description=long_description, + long_description_content_type="text/markdown", ) diff --git a/taskingai/_version.py b/taskingai/_version.py index 7f2e148..cf38cb1 100644 --- a/taskingai/_version.py +++ b/taskingai/_version.py @@ -1,3 +1,2 @@ __title__ = "taskingai" -__version__ = "0.1.0" - +__version__ = "0.1.3" diff --git a/taskingai/client/api/retrieval_api.py b/taskingai/client/api/retrieval_api.py index 20f4311..cd799ee 100644 --- a/taskingai/client/api/retrieval_api.py +++ b/taskingai/client/api/retrieval_api.py @@ -17,7 +17,6 @@ class RetrievalApi(object): - def __init__(self, api_client=None): if api_client is None: api_client = SyncApiClient() @@ -38,8 +37,8 @@ def create_collection(self, body, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return self.create_collection_with_http_info(body, **kwargs) # noqa: E501 else: (data) = self.create_collection_with_http_info(body, **kwargs) # noqa: E501 @@ -61,24 +60,20 @@ def create_collection_with_http_info(self, body, **kwargs): # noqa: E501 returns the request thread. """ - all_params = ['body'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["body"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method create_collection" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method create_collection" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] # verify the required parameter 'body' is set - if ('body' not in params or - params['body'] is None): + if "body" not in params or params["body"] is None: raise ValueError("Missing the required parameter `body` when calling `create_collection`") # noqa: E501 collection_formats = {} @@ -93,33 +88,35 @@ def create_collection_with_http_info(self, body, **kwargs): # noqa: E501 local_var_files = {} body_params = None - if 'body' in params: - body_params = params['body'] + if "body" in params: + body_params = params["body"] # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # HTTP header `Content-Type` - header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501 - ['application/json']) # noqa: E501 + header_params["Content-Type"] = self.api_client.select_header_content_type( # noqa: E501 + ["application/json"] + ) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return self.api_client.call_api( - '/v1/collections', 'POST', + "/v1/collections", + "POST", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='CollectionCreateResponse', # noqa: E501 + response_type="CollectionCreateResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) def create_record(self, body, collection_id, **kwargs): # noqa: E501 """Create record # noqa: E501 @@ -127,7 +124,7 @@ def create_record(self, body, collection_id, **kwargs): # noqa: E501 Create a new record in a collection. # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True - >>> thread = api.create_text_record(body, collection_id, async_req=True) + >>> thread = api.create_record(body, collection_id, async_req=True) >>> result = thread.get() :param async_req bool @@ -137,8 +134,8 @@ def create_record(self, body, collection_id, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return self.create_record_with_http_info(body, collection_id, **kwargs) # noqa: E501 else: (data) = self.create_record_with_http_info(body, collection_id, **kwargs) # noqa: E501 @@ -161,35 +158,32 @@ def create_record_with_http_info(self, body, collection_id, **kwargs): # noqa: returns the request thread. """ - all_params = ['body', 'collection_id'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["body", "collection_id"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method create_record" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method create_record" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] # verify the required parameter 'body' is set - if ('body' not in params or - params['body'] is None): + if "body" not in params or params["body"] is None: raise ValueError("Missing the required parameter `body` when calling `create_record`") # noqa: E501 # verify the required parameter 'collection_id' is set - if ('collection_id' not in params or - params['collection_id'] is None): - raise ValueError("Missing the required parameter `collection_id` when calling `create_record`") # noqa: E501 + if "collection_id" not in params or params["collection_id"] is None: + raise ValueError( + "Missing the required parameter `collection_id` when calling `create_record`" + ) # noqa: E501 collection_formats = {} path_params = {} - if 'collection_id' in params: - path_params['collection_id'] = params['collection_id'] # noqa: E501 + if "collection_id" in params: + path_params["collection_id"] = params["collection_id"] # noqa: E501 query_params = [] @@ -199,33 +193,35 @@ def create_record_with_http_info(self, body, collection_id, **kwargs): # noqa: local_var_files = {} body_params = None - if 'body' in params: - body_params = params['body'] + if "body" in params: + body_params = params["body"] # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # HTTP header `Content-Type` - header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501 - ['application/json']) # noqa: E501 + header_params["Content-Type"] = self.api_client.select_header_content_type( # noqa: E501 + ["application/json"] + ) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return self.api_client.call_api( - '/v1/collections/{collection_id}/records', 'POST', + "/v1/collections/{collection_id}/records", + "POST", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='RecordCreateResponse', # noqa: E501 + response_type="RecordCreateResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) def delete_collection(self, collection_id, **kwargs): # noqa: E501 """Delete collection # noqa: E501 @@ -242,8 +238,8 @@ def delete_collection(self, collection_id, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return self.delete_collection_with_http_info(collection_id, **kwargs) # noqa: E501 else: (data) = self.delete_collection_with_http_info(collection_id, **kwargs) # noqa: E501 @@ -265,31 +261,29 @@ def delete_collection_with_http_info(self, collection_id, **kwargs): # noqa: E5 returns the request thread. """ - all_params = ['collection_id'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["collection_id"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method delete_collection" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method delete_collection" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] # verify the required parameter 'collection_id' is set - if ('collection_id' not in params or - params['collection_id'] is None): - raise ValueError("Missing the required parameter `collection_id` when calling `delete_collection`") # noqa: E501 + if "collection_id" not in params or params["collection_id"] is None: + raise ValueError( + "Missing the required parameter `collection_id` when calling `delete_collection`" + ) # noqa: E501 collection_formats = {} path_params = {} - if 'collection_id' in params: - path_params['collection_id'] = params['collection_id'] # noqa: E501 + if "collection_id" in params: + path_params["collection_id"] = params["collection_id"] # noqa: E501 query_params = [] @@ -300,26 +294,27 @@ def delete_collection_with_http_info(self, collection_id, **kwargs): # noqa: E5 body_params = None # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return self.api_client.call_api( - '/v1/collections/{collection_id}', 'DELETE', + "/v1/collections/{collection_id}", + "DELETE", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='DeleteCollectionResponse', # noqa: E501 + response_type="DeleteCollectionResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) def delete_record(self, collection_id, record_id, **kwargs): # noqa: E501 """Delete record # noqa: E501 @@ -337,8 +332,8 @@ def delete_record(self, collection_id, record_id, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return self.delete_record_with_http_info(collection_id, record_id, **kwargs) # noqa: E501 else: (data) = self.delete_record_with_http_info(collection_id, record_id, **kwargs) # noqa: E501 @@ -361,37 +356,34 @@ def delete_record_with_http_info(self, collection_id, record_id, **kwargs): # n returns the request thread. """ - all_params = ['collection_id', 'record_id'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["collection_id", "record_id"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method delete_record" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method delete_record" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] # verify the required parameter 'collection_id' is set - if ('collection_id' not in params or - params['collection_id'] is None): - raise ValueError("Missing the required parameter `collection_id` when calling `delete_record`") # noqa: E501 + if "collection_id" not in params or params["collection_id"] is None: + raise ValueError( + "Missing the required parameter `collection_id` when calling `delete_record`" + ) # noqa: E501 # verify the required parameter 'record_id' is set - if ('record_id' not in params or - params['record_id'] is None): + if "record_id" not in params or params["record_id"] is None: raise ValueError("Missing the required parameter `record_id` when calling `delete_record`") # noqa: E501 collection_formats = {} path_params = {} - if 'collection_id' in params: - path_params['collection_id'] = params['collection_id'] # noqa: E501 - if 'record_id' in params: - path_params['record_id'] = params['record_id'] # noqa: E501 + if "collection_id" in params: + path_params["collection_id"] = params["collection_id"] # noqa: E501 + if "record_id" in params: + path_params["record_id"] = params["record_id"] # noqa: E501 query_params = [] @@ -402,26 +394,27 @@ def delete_record_with_http_info(self, collection_id, record_id, **kwargs): # n body_params = None # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return self.api_client.call_api( - '/v1/collections/{collection_id}/records/{record_id}', 'DELETE', + "/v1/collections/{collection_id}/records/{record_id}", + "DELETE", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='RecordDeleteResponse', # noqa: E501 + response_type="RecordDeleteResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) def get_collection(self, collection_id, **kwargs): # noqa: E501 """Get collection # noqa: E501 @@ -438,8 +431,8 @@ def get_collection(self, collection_id, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return self.get_collection_with_http_info(collection_id, **kwargs) # noqa: E501 else: (data) = self.get_collection_with_http_info(collection_id, **kwargs) # noqa: E501 @@ -461,31 +454,29 @@ def get_collection_with_http_info(self, collection_id, **kwargs): # noqa: E501 returns the request thread. """ - all_params = ['collection_id'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["collection_id"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method get_collection" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method get_collection" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] # verify the required parameter 'collection_id' is set - if ('collection_id' not in params or - params['collection_id'] is None): - raise ValueError("Missing the required parameter `collection_id` when calling `get_collection`") # noqa: E501 + if "collection_id" not in params or params["collection_id"] is None: + raise ValueError( + "Missing the required parameter `collection_id` when calling `get_collection`" + ) # noqa: E501 collection_formats = {} path_params = {} - if 'collection_id' in params: - path_params['collection_id'] = params['collection_id'] # noqa: E501 + if "collection_id" in params: + path_params["collection_id"] = params["collection_id"] # noqa: E501 query_params = [] @@ -496,26 +487,27 @@ def get_collection_with_http_info(self, collection_id, **kwargs): # noqa: E501 body_params = None # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return self.api_client.call_api( - '/v1/collections/{collection_id}', 'GET', + "/v1/collections/{collection_id}", + "GET", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='CollectionGetResponse', # noqa: E501 + response_type="CollectionGetResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) def get_record(self, collection_id, record_id, **kwargs): # noqa: E501 """Get record # noqa: E501 @@ -533,8 +525,8 @@ def get_record(self, collection_id, record_id, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return self.get_record_with_http_info(collection_id, record_id, **kwargs) # noqa: E501 else: (data) = self.get_record_with_http_info(collection_id, record_id, **kwargs) # noqa: E501 @@ -557,37 +549,32 @@ def get_record_with_http_info(self, collection_id, record_id, **kwargs): # noqa returns the request thread. """ - all_params = ['collection_id', 'record_id'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["collection_id", "record_id"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method get_record" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method get_record" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] # verify the required parameter 'collection_id' is set - if ('collection_id' not in params or - params['collection_id'] is None): + if "collection_id" not in params or params["collection_id"] is None: raise ValueError("Missing the required parameter `collection_id` when calling `get_record`") # noqa: E501 # verify the required parameter 'record_id' is set - if ('record_id' not in params or - params['record_id'] is None): + if "record_id" not in params or params["record_id"] is None: raise ValueError("Missing the required parameter `record_id` when calling `get_record`") # noqa: E501 collection_formats = {} path_params = {} - if 'collection_id' in params: - path_params['collection_id'] = params['collection_id'] # noqa: E501 - if 'record_id' in params: - path_params['record_id'] = params['record_id'] # noqa: E501 + if "collection_id" in params: + path_params["collection_id"] = params["collection_id"] # noqa: E501 + if "record_id" in params: + path_params["record_id"] = params["record_id"] # noqa: E501 query_params = [] @@ -598,26 +585,27 @@ def get_record_with_http_info(self, collection_id, record_id, **kwargs): # noqa body_params = None # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return self.api_client.call_api( - '/v1/collections/{collection_id}/records/{record_id}', 'GET', + "/v1/collections/{collection_id}/records/{record_id}", + "GET", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='RecordGetResponse', # noqa: E501 + response_type="RecordGetResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) def list_collections(self, **kwargs): # noqa: E501 """List collections # noqa: E501 @@ -637,8 +625,8 @@ def list_collections(self, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return self.list_collections_with_http_info(**kwargs) # noqa: E501 else: (data) = self.list_collections_with_http_info(**kwargs) # noqa: E501 @@ -663,35 +651,32 @@ def list_collections_with_http_info(self, **kwargs): # noqa: E501 returns the request thread. """ - all_params = ['limit', 'order', 'after', 'before'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["limit", "order", "after", "before"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method list_collections" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method list_collections" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] collection_formats = {} path_params = {} query_params = [] - if 'limit' in params: - query_params.append(('limit', params['limit'])) # noqa: E501 - if 'order' in params: - query_params.append(('order', params['order'])) # noqa: E501 - if 'after' in params: - query_params.append(('after', params['after'])) # noqa: E501 - if 'before' in params: - query_params.append(('before', params['before'])) # noqa: E501 + if "limit" in params: + query_params.append(("limit", params["limit"])) # noqa: E501 + if "order" in params: + query_params.append(("order", params["order"])) # noqa: E501 + if "after" in params: + query_params.append(("after", params["after"])) # noqa: E501 + if "before" in params: + query_params.append(("before", params["before"])) # noqa: E501 header_params = {} @@ -700,26 +685,27 @@ def list_collections_with_http_info(self, **kwargs): # noqa: E501 body_params = None # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return self.api_client.call_api( - '/v1/collections', 'GET', + "/v1/collections", + "GET", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='CollectionListResponse', # noqa: E501 + response_type="CollectionListResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) def list_records(self, collection_id, **kwargs): # noqa: E501 """List records # noqa: E501 @@ -740,8 +726,8 @@ def list_records(self, collection_id, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return self.list_records_with_http_info(collection_id, **kwargs) # noqa: E501 else: (data) = self.list_records_with_http_info(collection_id, **kwargs) # noqa: E501 @@ -767,41 +753,37 @@ def list_records_with_http_info(self, collection_id, **kwargs): # noqa: E501 returns the request thread. """ - all_params = ['collection_id', 'limit', 'order', 'after', 'before'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["collection_id", "limit", "order", "after", "before"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method list_records" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method list_records" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] # verify the required parameter 'collection_id' is set - if ('collection_id' not in params or - params['collection_id'] is None): + if "collection_id" not in params or params["collection_id"] is None: raise ValueError("Missing the required parameter `collection_id` when calling `list_records`") # noqa: E501 collection_formats = {} path_params = {} - if 'collection_id' in params: - path_params['collection_id'] = params['collection_id'] # noqa: E501 + if "collection_id" in params: + path_params["collection_id"] = params["collection_id"] # noqa: E501 query_params = [] - if 'limit' in params: - query_params.append(('limit', params['limit'])) # noqa: E501 - if 'order' in params: - query_params.append(('order', params['order'])) # noqa: E501 - if 'after' in params: - query_params.append(('after', params['after'])) # noqa: E501 - if 'before' in params: - query_params.append(('before', params['before'])) # noqa: E501 + if "limit" in params: + query_params.append(("limit", params["limit"])) # noqa: E501 + if "order" in params: + query_params.append(("order", params["order"])) # noqa: E501 + if "after" in params: + query_params.append(("after", params["after"])) # noqa: E501 + if "before" in params: + query_params.append(("before", params["before"])) # noqa: E501 header_params = {} @@ -810,26 +792,27 @@ def list_records_with_http_info(self, collection_id, **kwargs): # noqa: E501 body_params = None # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return self.api_client.call_api( - '/v1/collections/{collection_id}/records', 'GET', + "/v1/collections/{collection_id}/records", + "GET", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='RecordListResponse', # noqa: E501 + response_type="RecordListResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) def query_chunks(self, body, collection_id, **kwargs): # noqa: E501 """Query chunks # noqa: E501 @@ -847,8 +830,8 @@ def query_chunks(self, body, collection_id, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return self.query_chunks_with_http_info(body, collection_id, **kwargs) # noqa: E501 else: (data) = self.query_chunks_with_http_info(body, collection_id, **kwargs) # noqa: E501 @@ -871,35 +854,30 @@ def query_chunks_with_http_info(self, body, collection_id, **kwargs): # noqa: E returns the request thread. """ - all_params = ['body', 'collection_id'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["body", "collection_id"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method query_chunks" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method query_chunks" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] # verify the required parameter 'body' is set - if ('body' not in params or - params['body'] is None): + if "body" not in params or params["body"] is None: raise ValueError("Missing the required parameter `body` when calling `query_chunks`") # noqa: E501 # verify the required parameter 'collection_id' is set - if ('collection_id' not in params or - params['collection_id'] is None): + if "collection_id" not in params or params["collection_id"] is None: raise ValueError("Missing the required parameter `collection_id` when calling `query_chunks`") # noqa: E501 collection_formats = {} path_params = {} - if 'collection_id' in params: - path_params['collection_id'] = params['collection_id'] # noqa: E501 + if "collection_id" in params: + path_params["collection_id"] = params["collection_id"] # noqa: E501 query_params = [] @@ -909,33 +887,35 @@ def query_chunks_with_http_info(self, body, collection_id, **kwargs): # noqa: E local_var_files = {} body_params = None - if 'body' in params: - body_params = params['body'] + if "body" in params: + body_params = params["body"] # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # HTTP header `Content-Type` - header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501 - ['application/json']) # noqa: E501 + header_params["Content-Type"] = self.api_client.select_header_content_type( # noqa: E501 + ["application/json"] + ) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return self.api_client.call_api( - '/v1/collections/{collection_id}/chunks/query', 'POST', + "/v1/collections/{collection_id}/chunks/query", + "POST", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='ChunkQueryResponse', # noqa: E501 + response_type="ChunkQueryResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) def update_collection(self, body, collection_id, **kwargs): # noqa: E501 """Update collection # noqa: E501 @@ -953,8 +933,8 @@ def update_collection(self, body, collection_id, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return self.update_collection_with_http_info(body, collection_id, **kwargs) # noqa: E501 else: (data) = self.update_collection_with_http_info(body, collection_id, **kwargs) # noqa: E501 @@ -977,35 +957,32 @@ def update_collection_with_http_info(self, body, collection_id, **kwargs): # no returns the request thread. """ - all_params = ['body', 'collection_id'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["body", "collection_id"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method update_collection" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method update_collection" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] # verify the required parameter 'body' is set - if ('body' not in params or - params['body'] is None): + if "body" not in params or params["body"] is None: raise ValueError("Missing the required parameter `body` when calling `update_collection`") # noqa: E501 # verify the required parameter 'collection_id' is set - if ('collection_id' not in params or - params['collection_id'] is None): - raise ValueError("Missing the required parameter `collection_id` when calling `update_collection`") # noqa: E501 + if "collection_id" not in params or params["collection_id"] is None: + raise ValueError( + "Missing the required parameter `collection_id` when calling `update_collection`" + ) # noqa: E501 collection_formats = {} path_params = {} - if 'collection_id' in params: - path_params['collection_id'] = params['collection_id'] # noqa: E501 + if "collection_id" in params: + path_params["collection_id"] = params["collection_id"] # noqa: E501 query_params = [] @@ -1015,33 +992,35 @@ def update_collection_with_http_info(self, body, collection_id, **kwargs): # no local_var_files = {} body_params = None - if 'body' in params: - body_params = params['body'] + if "body" in params: + body_params = params["body"] # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # HTTP header `Content-Type` - header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501 - ['application/json']) # noqa: E501 + header_params["Content-Type"] = self.api_client.select_header_content_type( # noqa: E501 + ["application/json"] + ) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return self.api_client.call_api( - '/v1/collections/{collection_id}', 'POST', + "/v1/collections/{collection_id}", + "POST", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='CollectionUpdateResponse', # noqa: E501 + response_type="CollectionUpdateResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) def update_record(self, body, collection_id, record_id, **kwargs): # noqa: E501 """Update record # noqa: E501 @@ -1060,8 +1039,8 @@ def update_record(self, body, collection_id, record_id, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return self.update_record_with_http_info(body, collection_id, record_id, **kwargs) # noqa: E501 else: (data) = self.update_record_with_http_info(body, collection_id, record_id, **kwargs) # noqa: E501 @@ -1085,41 +1064,37 @@ def update_record_with_http_info(self, body, collection_id, record_id, **kwargs) returns the request thread. """ - all_params = ['body', 'collection_id', 'record_id'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["body", "collection_id", "record_id"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method update_record" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method update_record" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] # verify the required parameter 'body' is set - if ('body' not in params or - params['body'] is None): + if "body" not in params or params["body"] is None: raise ValueError("Missing the required parameter `body` when calling `update_record`") # noqa: E501 # verify the required parameter 'collection_id' is set - if ('collection_id' not in params or - params['collection_id'] is None): - raise ValueError("Missing the required parameter `collection_id` when calling `update_record`") # noqa: E501 + if "collection_id" not in params or params["collection_id"] is None: + raise ValueError( + "Missing the required parameter `collection_id` when calling `update_record`" + ) # noqa: E501 # verify the required parameter 'record_id' is set - if ('record_id' not in params or - params['record_id'] is None): + if "record_id" not in params or params["record_id"] is None: raise ValueError("Missing the required parameter `record_id` when calling `update_record`") # noqa: E501 collection_formats = {} path_params = {} - if 'collection_id' in params: - path_params['collection_id'] = params['collection_id'] # noqa: E501 - if 'record_id' in params: - path_params['record_id'] = params['record_id'] # noqa: E501 + if "collection_id" in params: + path_params["collection_id"] = params["collection_id"] # noqa: E501 + if "record_id" in params: + path_params["record_id"] = params["record_id"] # noqa: E501 query_params = [] @@ -1129,30 +1104,32 @@ def update_record_with_http_info(self, body, collection_id, record_id, **kwargs) local_var_files = {} body_params = None - if 'body' in params: - body_params = params['body'] + if "body" in params: + body_params = params["body"] # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # HTTP header `Content-Type` - header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501 - ['application/json']) # noqa: E501 + header_params["Content-Type"] = self.api_client.select_header_content_type( # noqa: E501 + ["application/json"] + ) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return self.api_client.call_api( - '/v1/collections/{collection_id}/records/{record_id}', 'POST', + "/v1/collections/{collection_id}/records/{record_id}", + "POST", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='RecordUpdateResponse', # noqa: E501 + response_type="RecordUpdateResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) diff --git a/taskingai/client/api/retrieval_async_api.py b/taskingai/client/api/retrieval_async_api.py index 3e2e6d9..29ab676 100644 --- a/taskingai/client/api/retrieval_async_api.py +++ b/taskingai/client/api/retrieval_async_api.py @@ -17,7 +17,6 @@ class AsyncRetrievalApi(object): - def __init__(self, api_client=None): if api_client is None: api_client = AsyncApiClient() @@ -38,8 +37,8 @@ async def create_collection(self, body, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return await self.create_collection_with_http_info(body, **kwargs) # noqa: E501 else: (data) = await self.create_collection_with_http_info(body, **kwargs) # noqa: E501 @@ -61,24 +60,20 @@ async def create_collection_with_http_info(self, body, **kwargs): # noqa: E501 returns the request thread. """ - all_params = ['body'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["body"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method create_collection" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method create_collection" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] # verify the required parameter 'body' is set - if ('body' not in params or - params['body'] is None): + if "body" not in params or params["body"] is None: raise ValueError("Missing the required parameter `body` when calling `create_collection`") # noqa: E501 collection_formats = {} @@ -93,33 +88,35 @@ async def create_collection_with_http_info(self, body, **kwargs): # noqa: E501 local_var_files = {} body_params = None - if 'body' in params: - body_params = params['body'] + if "body" in params: + body_params = params["body"] # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # HTTP header `Content-Type` - header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501 - ['application/json']) # noqa: E501 + header_params["Content-Type"] = self.api_client.select_header_content_type( # noqa: E501 + ["application/json"] + ) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return await self.api_client.call_api( - '/v1/collections', 'POST', + "/v1/collections", + "POST", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='CollectionCreateResponse', # noqa: E501 + response_type="CollectionCreateResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) async def create_record(self, body, collection_id, **kwargs): # noqa: E501 """Create record # noqa: E501 @@ -127,7 +124,7 @@ async def create_record(self, body, collection_id, **kwargs): # noqa: E501 Create a new record in a collection. # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True - >>> thread = api.create_text_record(body, collection_id, async_req=True) + >>> thread = api.create_record(body, collection_id, async_req=True) >>> result = thread.get() :param async_req bool @@ -137,8 +134,8 @@ async def create_record(self, body, collection_id, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return await self.create_record_with_http_info(body, collection_id, **kwargs) # noqa: E501 else: (data) = await self.create_record_with_http_info(body, collection_id, **kwargs) # noqa: E501 @@ -161,35 +158,32 @@ async def create_record_with_http_info(self, body, collection_id, **kwargs): # returns the request thread. """ - all_params = ['body', 'collection_id'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["body", "collection_id"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method create_record" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method create_record" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] # verify the required parameter 'body' is set - if ('body' not in params or - params['body'] is None): + if "body" not in params or params["body"] is None: raise ValueError("Missing the required parameter `body` when calling `create_record`") # noqa: E501 # verify the required parameter 'collection_id' is set - if ('collection_id' not in params or - params['collection_id'] is None): - raise ValueError("Missing the required parameter `collection_id` when calling `create_record`") # noqa: E501 + if "collection_id" not in params or params["collection_id"] is None: + raise ValueError( + "Missing the required parameter `collection_id` when calling `create_record`" + ) # noqa: E501 collection_formats = {} path_params = {} - if 'collection_id' in params: - path_params['collection_id'] = params['collection_id'] # noqa: E501 + if "collection_id" in params: + path_params["collection_id"] = params["collection_id"] # noqa: E501 query_params = [] @@ -199,33 +193,35 @@ async def create_record_with_http_info(self, body, collection_id, **kwargs): # local_var_files = {} body_params = None - if 'body' in params: - body_params = params['body'] + if "body" in params: + body_params = params["body"] # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # HTTP header `Content-Type` - header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501 - ['application/json']) # noqa: E501 + header_params["Content-Type"] = self.api_client.select_header_content_type( # noqa: E501 + ["application/json"] + ) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return await self.api_client.call_api( - '/v1/collections/{collection_id}/records', 'POST', + "/v1/collections/{collection_id}/records", + "POST", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='RecordCreateResponse', # noqa: E501 + response_type="RecordCreateResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) async def delete_collection(self, collection_id, **kwargs): # noqa: E501 """Delete collection # noqa: E501 @@ -242,8 +238,8 @@ async def delete_collection(self, collection_id, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return await self.delete_collection_with_http_info(collection_id, **kwargs) # noqa: E501 else: (data) = await self.delete_collection_with_http_info(collection_id, **kwargs) # noqa: E501 @@ -265,31 +261,29 @@ async def delete_collection_with_http_info(self, collection_id, **kwargs): # no returns the request thread. """ - all_params = ['collection_id'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["collection_id"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method delete_collection" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method delete_collection" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] # verify the required parameter 'collection_id' is set - if ('collection_id' not in params or - params['collection_id'] is None): - raise ValueError("Missing the required parameter `collection_id` when calling `delete_collection`") # noqa: E501 + if "collection_id" not in params or params["collection_id"] is None: + raise ValueError( + "Missing the required parameter `collection_id` when calling `delete_collection`" + ) # noqa: E501 collection_formats = {} path_params = {} - if 'collection_id' in params: - path_params['collection_id'] = params['collection_id'] # noqa: E501 + if "collection_id" in params: + path_params["collection_id"] = params["collection_id"] # noqa: E501 query_params = [] @@ -300,26 +294,27 @@ async def delete_collection_with_http_info(self, collection_id, **kwargs): # no body_params = None # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return await self.api_client.call_api( - '/v1/collections/{collection_id}', 'DELETE', + "/v1/collections/{collection_id}", + "DELETE", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='DeleteCollectionResponse', # noqa: E501 + response_type="DeleteCollectionResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) async def delete_record(self, collection_id, record_id, **kwargs): # noqa: E501 """Delete record # noqa: E501 @@ -337,8 +332,8 @@ async def delete_record(self, collection_id, record_id, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return await self.delete_record_with_http_info(collection_id, record_id, **kwargs) # noqa: E501 else: (data) = await self.delete_record_with_http_info(collection_id, record_id, **kwargs) # noqa: E501 @@ -361,37 +356,34 @@ async def delete_record_with_http_info(self, collection_id, record_id, **kwargs) returns the request thread. """ - all_params = ['collection_id', 'record_id'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["collection_id", "record_id"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method delete_record" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method delete_record" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] # verify the required parameter 'collection_id' is set - if ('collection_id' not in params or - params['collection_id'] is None): - raise ValueError("Missing the required parameter `collection_id` when calling `delete_record`") # noqa: E501 + if "collection_id" not in params or params["collection_id"] is None: + raise ValueError( + "Missing the required parameter `collection_id` when calling `delete_record`" + ) # noqa: E501 # verify the required parameter 'record_id' is set - if ('record_id' not in params or - params['record_id'] is None): + if "record_id" not in params or params["record_id"] is None: raise ValueError("Missing the required parameter `record_id` when calling `delete_record`") # noqa: E501 collection_formats = {} path_params = {} - if 'collection_id' in params: - path_params['collection_id'] = params['collection_id'] # noqa: E501 - if 'record_id' in params: - path_params['record_id'] = params['record_id'] # noqa: E501 + if "collection_id" in params: + path_params["collection_id"] = params["collection_id"] # noqa: E501 + if "record_id" in params: + path_params["record_id"] = params["record_id"] # noqa: E501 query_params = [] @@ -402,26 +394,27 @@ async def delete_record_with_http_info(self, collection_id, record_id, **kwargs) body_params = None # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return await self.api_client.call_api( - '/v1/collections/{collection_id}/records/{record_id}', 'DELETE', + "/v1/collections/{collection_id}/records/{record_id}", + "DELETE", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='RecordDeleteResponse', # noqa: E501 + response_type="RecordDeleteResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) async def get_collection(self, collection_id, **kwargs): # noqa: E501 """Get collection # noqa: E501 @@ -438,8 +431,8 @@ async def get_collection(self, collection_id, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return await self.get_collection_with_http_info(collection_id, **kwargs) # noqa: E501 else: (data) = await self.get_collection_with_http_info(collection_id, **kwargs) # noqa: E501 @@ -461,31 +454,29 @@ async def get_collection_with_http_info(self, collection_id, **kwargs): # noqa: returns the request thread. """ - all_params = ['collection_id'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["collection_id"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method get_collection" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method get_collection" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] # verify the required parameter 'collection_id' is set - if ('collection_id' not in params or - params['collection_id'] is None): - raise ValueError("Missing the required parameter `collection_id` when calling `get_collection`") # noqa: E501 + if "collection_id" not in params or params["collection_id"] is None: + raise ValueError( + "Missing the required parameter `collection_id` when calling `get_collection`" + ) # noqa: E501 collection_formats = {} path_params = {} - if 'collection_id' in params: - path_params['collection_id'] = params['collection_id'] # noqa: E501 + if "collection_id" in params: + path_params["collection_id"] = params["collection_id"] # noqa: E501 query_params = [] @@ -496,26 +487,27 @@ async def get_collection_with_http_info(self, collection_id, **kwargs): # noqa: body_params = None # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return await self.api_client.call_api( - '/v1/collections/{collection_id}', 'GET', + "/v1/collections/{collection_id}", + "GET", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='CollectionGetResponse', # noqa: E501 + response_type="CollectionGetResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) async def get_record(self, collection_id, record_id, **kwargs): # noqa: E501 """Get record # noqa: E501 @@ -533,8 +525,8 @@ async def get_record(self, collection_id, record_id, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return await self.get_record_with_http_info(collection_id, record_id, **kwargs) # noqa: E501 else: (data) = await self.get_record_with_http_info(collection_id, record_id, **kwargs) # noqa: E501 @@ -557,37 +549,32 @@ async def get_record_with_http_info(self, collection_id, record_id, **kwargs): returns the request thread. """ - all_params = ['collection_id', 'record_id'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["collection_id", "record_id"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method get_record" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method get_record" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] # verify the required parameter 'collection_id' is set - if ('collection_id' not in params or - params['collection_id'] is None): + if "collection_id" not in params or params["collection_id"] is None: raise ValueError("Missing the required parameter `collection_id` when calling `get_record`") # noqa: E501 # verify the required parameter 'record_id' is set - if ('record_id' not in params or - params['record_id'] is None): + if "record_id" not in params or params["record_id"] is None: raise ValueError("Missing the required parameter `record_id` when calling `get_record`") # noqa: E501 collection_formats = {} path_params = {} - if 'collection_id' in params: - path_params['collection_id'] = params['collection_id'] # noqa: E501 - if 'record_id' in params: - path_params['record_id'] = params['record_id'] # noqa: E501 + if "collection_id" in params: + path_params["collection_id"] = params["collection_id"] # noqa: E501 + if "record_id" in params: + path_params["record_id"] = params["record_id"] # noqa: E501 query_params = [] @@ -598,26 +585,27 @@ async def get_record_with_http_info(self, collection_id, record_id, **kwargs): body_params = None # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return await self.api_client.call_api( - '/v1/collections/{collection_id}/records/{record_id}', 'GET', + "/v1/collections/{collection_id}/records/{record_id}", + "GET", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='RecordGetResponse', # noqa: E501 + response_type="RecordGetResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) async def list_collections(self, **kwargs): # noqa: E501 """List collections # noqa: E501 @@ -637,8 +625,8 @@ async def list_collections(self, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return await self.list_collections_with_http_info(**kwargs) # noqa: E501 else: (data) = await self.list_collections_with_http_info(**kwargs) # noqa: E501 @@ -663,35 +651,32 @@ async def list_collections_with_http_info(self, **kwargs): # noqa: E501 returns the request thread. """ - all_params = ['limit', 'order', 'after', 'before'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["limit", "order", "after", "before"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method list_collections" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method list_collections" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] collection_formats = {} path_params = {} query_params = [] - if 'limit' in params: - query_params.append(('limit', params['limit'])) # noqa: E501 - if 'order' in params: - query_params.append(('order', params['order'])) # noqa: E501 - if 'after' in params: - query_params.append(('after', params['after'])) # noqa: E501 - if 'before' in params: - query_params.append(('before', params['before'])) # noqa: E501 + if "limit" in params: + query_params.append(("limit", params["limit"])) # noqa: E501 + if "order" in params: + query_params.append(("order", params["order"])) # noqa: E501 + if "after" in params: + query_params.append(("after", params["after"])) # noqa: E501 + if "before" in params: + query_params.append(("before", params["before"])) # noqa: E501 header_params = {} @@ -700,26 +685,27 @@ async def list_collections_with_http_info(self, **kwargs): # noqa: E501 body_params = None # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return await self.api_client.call_api( - '/v1/collections', 'GET', + "/v1/collections", + "GET", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='CollectionListResponse', # noqa: E501 + response_type="CollectionListResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) async def list_records(self, collection_id, **kwargs): # noqa: E501 """List records # noqa: E501 @@ -740,8 +726,8 @@ async def list_records(self, collection_id, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return await self.list_records_with_http_info(collection_id, **kwargs) # noqa: E501 else: (data) = await self.list_records_with_http_info(collection_id, **kwargs) # noqa: E501 @@ -767,41 +753,37 @@ async def list_records_with_http_info(self, collection_id, **kwargs): # noqa: E returns the request thread. """ - all_params = ['collection_id', 'limit', 'order', 'after', 'before'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["collection_id", "limit", "order", "after", "before"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method list_records" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method list_records" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] # verify the required parameter 'collection_id' is set - if ('collection_id' not in params or - params['collection_id'] is None): + if "collection_id" not in params or params["collection_id"] is None: raise ValueError("Missing the required parameter `collection_id` when calling `list_records`") # noqa: E501 collection_formats = {} path_params = {} - if 'collection_id' in params: - path_params['collection_id'] = params['collection_id'] # noqa: E501 + if "collection_id" in params: + path_params["collection_id"] = params["collection_id"] # noqa: E501 query_params = [] - if 'limit' in params: - query_params.append(('limit', params['limit'])) # noqa: E501 - if 'order' in params: - query_params.append(('order', params['order'])) # noqa: E501 - if 'after' in params: - query_params.append(('after', params['after'])) # noqa: E501 - if 'before' in params: - query_params.append(('before', params['before'])) # noqa: E501 + if "limit" in params: + query_params.append(("limit", params["limit"])) # noqa: E501 + if "order" in params: + query_params.append(("order", params["order"])) # noqa: E501 + if "after" in params: + query_params.append(("after", params["after"])) # noqa: E501 + if "before" in params: + query_params.append(("before", params["before"])) # noqa: E501 header_params = {} @@ -810,26 +792,27 @@ async def list_records_with_http_info(self, collection_id, **kwargs): # noqa: E body_params = None # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return await self.api_client.call_api( - '/v1/collections/{collection_id}/records', 'GET', + "/v1/collections/{collection_id}/records", + "GET", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='RecordListResponse', # noqa: E501 + response_type="RecordListResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) async def query_chunks(self, body, collection_id, **kwargs): # noqa: E501 """Query chunks # noqa: E501 @@ -847,8 +830,8 @@ async def query_chunks(self, body, collection_id, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return await self.query_chunks_with_http_info(body, collection_id, **kwargs) # noqa: E501 else: (data) = await self.query_chunks_with_http_info(body, collection_id, **kwargs) # noqa: E501 @@ -871,35 +854,30 @@ async def query_chunks_with_http_info(self, body, collection_id, **kwargs): # n returns the request thread. """ - all_params = ['body', 'collection_id'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["body", "collection_id"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method query_chunks" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method query_chunks" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] # verify the required parameter 'body' is set - if ('body' not in params or - params['body'] is None): + if "body" not in params or params["body"] is None: raise ValueError("Missing the required parameter `body` when calling `query_chunks`") # noqa: E501 # verify the required parameter 'collection_id' is set - if ('collection_id' not in params or - params['collection_id'] is None): + if "collection_id" not in params or params["collection_id"] is None: raise ValueError("Missing the required parameter `collection_id` when calling `query_chunks`") # noqa: E501 collection_formats = {} path_params = {} - if 'collection_id' in params: - path_params['collection_id'] = params['collection_id'] # noqa: E501 + if "collection_id" in params: + path_params["collection_id"] = params["collection_id"] # noqa: E501 query_params = [] @@ -909,33 +887,35 @@ async def query_chunks_with_http_info(self, body, collection_id, **kwargs): # n local_var_files = {} body_params = None - if 'body' in params: - body_params = params['body'] + if "body" in params: + body_params = params["body"] # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # HTTP header `Content-Type` - header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501 - ['application/json']) # noqa: E501 + header_params["Content-Type"] = self.api_client.select_header_content_type( # noqa: E501 + ["application/json"] + ) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return await self.api_client.call_api( - '/v1/collections/{collection_id}/chunks/query', 'POST', + "/v1/collections/{collection_id}/chunks/query", + "POST", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='ChunkQueryResponse', # noqa: E501 + response_type="ChunkQueryResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) async def update_collection(self, body, collection_id, **kwargs): # noqa: E501 """Update collection # noqa: E501 @@ -953,8 +933,8 @@ async def update_collection(self, body, collection_id, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return await self.update_collection_with_http_info(body, collection_id, **kwargs) # noqa: E501 else: (data) = await self.update_collection_with_http_info(body, collection_id, **kwargs) # noqa: E501 @@ -977,35 +957,32 @@ async def update_collection_with_http_info(self, body, collection_id, **kwargs): returns the request thread. """ - all_params = ['body', 'collection_id'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["body", "collection_id"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method update_collection" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method update_collection" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] # verify the required parameter 'body' is set - if ('body' not in params or - params['body'] is None): + if "body" not in params or params["body"] is None: raise ValueError("Missing the required parameter `body` when calling `update_collection`") # noqa: E501 # verify the required parameter 'collection_id' is set - if ('collection_id' not in params or - params['collection_id'] is None): - raise ValueError("Missing the required parameter `collection_id` when calling `update_collection`") # noqa: E501 + if "collection_id" not in params or params["collection_id"] is None: + raise ValueError( + "Missing the required parameter `collection_id` when calling `update_collection`" + ) # noqa: E501 collection_formats = {} path_params = {} - if 'collection_id' in params: - path_params['collection_id'] = params['collection_id'] # noqa: E501 + if "collection_id" in params: + path_params["collection_id"] = params["collection_id"] # noqa: E501 query_params = [] @@ -1015,33 +992,35 @@ async def update_collection_with_http_info(self, body, collection_id, **kwargs): local_var_files = {} body_params = None - if 'body' in params: - body_params = params['body'] + if "body" in params: + body_params = params["body"] # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # HTTP header `Content-Type` - header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501 - ['application/json']) # noqa: E501 + header_params["Content-Type"] = self.api_client.select_header_content_type( # noqa: E501 + ["application/json"] + ) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return await self.api_client.call_api( - '/v1/collections/{collection_id}', 'POST', + "/v1/collections/{collection_id}", + "POST", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='CollectionUpdateResponse', # noqa: E501 + response_type="CollectionUpdateResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) async def update_record(self, body, collection_id, record_id, **kwargs): # noqa: E501 """Update record # noqa: E501 @@ -1060,8 +1039,8 @@ async def update_record(self, body, collection_id, record_id, **kwargs): # noqa If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): return await self.update_record_with_http_info(body, collection_id, record_id, **kwargs) # noqa: E501 else: (data) = await self.update_record_with_http_info(body, collection_id, record_id, **kwargs) # noqa: E501 @@ -1085,41 +1064,37 @@ async def update_record_with_http_info(self, body, collection_id, record_id, **k returns the request thread. """ - all_params = ['body', 'collection_id', 'record_id'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["body", "collection_id", "record_id"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: - raise TypeError( - "Got an unexpected keyword argument '%s'" - " to method update_record" % key - ) + raise TypeError("Got an unexpected keyword argument '%s'" " to method update_record" % key) params[key] = val - del params['kwargs'] + del params["kwargs"] # verify the required parameter 'body' is set - if ('body' not in params or - params['body'] is None): + if "body" not in params or params["body"] is None: raise ValueError("Missing the required parameter `body` when calling `update_record`") # noqa: E501 # verify the required parameter 'collection_id' is set - if ('collection_id' not in params or - params['collection_id'] is None): - raise ValueError("Missing the required parameter `collection_id` when calling `update_record`") # noqa: E501 + if "collection_id" not in params or params["collection_id"] is None: + raise ValueError( + "Missing the required parameter `collection_id` when calling `update_record`" + ) # noqa: E501 # verify the required parameter 'record_id' is set - if ('record_id' not in params or - params['record_id'] is None): + if "record_id" not in params or params["record_id"] is None: raise ValueError("Missing the required parameter `record_id` when calling `update_record`") # noqa: E501 collection_formats = {} path_params = {} - if 'collection_id' in params: - path_params['collection_id'] = params['collection_id'] # noqa: E501 - if 'record_id' in params: - path_params['record_id'] = params['record_id'] # noqa: E501 + if "collection_id" in params: + path_params["collection_id"] = params["collection_id"] # noqa: E501 + if "record_id" in params: + path_params["record_id"] = params["record_id"] # noqa: E501 query_params = [] @@ -1129,30 +1104,32 @@ async def update_record_with_http_info(self, body, collection_id, record_id, **k local_var_files = {} body_params = None - if 'body' in params: - body_params = params['body'] + if "body" in params: + body_params = params["body"] # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) # noqa: E501 # HTTP header `Content-Type` - header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501 - ['application/json']) # noqa: E501 + header_params["Content-Type"] = self.api_client.select_header_content_type( # noqa: E501 + ["application/json"] + ) # noqa: E501 # Authentication setting auth_settings = [] # noqa: E501 return await self.api_client.call_api( - '/v1/collections/{collection_id}/records/{record_id}', 'POST', + "/v1/collections/{collection_id}/records/{record_id}", + "POST", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='RecordUpdateResponse', # noqa: E501 + response_type="RecordUpdateResponse", # noqa: E501 auth_settings=auth_settings, - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) diff --git a/taskingai/client/models/__init__.py b/taskingai/client/models/__init__.py index 0f381f4..6c7e63f 100644 --- a/taskingai/client/models/__init__.py +++ b/taskingai/client/models/__init__.py @@ -22,6 +22,7 @@ from taskingai.client.models.entity.retrieval.collection import * from taskingai.client.models.entity.retrieval.record import * from taskingai.client.models.entity.retrieval.chunk import * +from taskingai.client.models.entity.retrieval.text_splitter import * from taskingai.client.models.entity.inference.text_embedding import * from taskingai.client.models.entity.inference.chat_completion import * from taskingai.client.models.rest.action_bulk_create_request import ActionBulkCreateRequest diff --git a/taskingai/client/models/entity/assistant/message.py b/taskingai/client/models/entity/assistant/message.py index 4b72746..e90d1cd 100644 --- a/taskingai/client/models/entity/assistant/message.py +++ b/taskingai/client/models/entity/assistant/message.py @@ -1,14 +1,8 @@ from typing import Optional, Dict -from pydantic import Field, validator from .._base import TaskingaiBaseModel from enum import Enum -__all__ = [ - "Message", - "MessageContent", - "MessageRole", - "MessageChunk" -] +__all__ = ["Message", "MessageContent", "MessageRole", "MessageChunk"] class MessageRole(str, Enum): @@ -17,7 +11,7 @@ class MessageRole(str, Enum): class MessageContent(TaskingaiBaseModel): - text: Optional[str] + text: str class Message(TaskingaiBaseModel): diff --git a/taskingai/client/models/entity/retrieval/chunk.py b/taskingai/client/models/entity/retrieval/chunk.py index 40a9dcb..60767d6 100644 --- a/taskingai/client/models/entity/retrieval/chunk.py +++ b/taskingai/client/models/entity/retrieval/chunk.py @@ -4,10 +4,11 @@ "Chunk", ] + class Chunk(TaskingaiBaseModel): object: str chunk_id: str record_id: str collection_id: str - text: str + content: str score: float diff --git a/taskingai/client/models/entity/retrieval/collection.py b/taskingai/client/models/entity/retrieval/collection.py index dc961e9..712fc9b 100644 --- a/taskingai/client/models/entity/retrieval/collection.py +++ b/taskingai/client/models/entity/retrieval/collection.py @@ -1,17 +1,10 @@ -from typing import Dict, Optional -from pydantic import Field +from typing import Dict from .._base import TaskingaiBaseModel __all__ = [ - "CollectionConfig", "Collection", ] -class CollectionConfig(TaskingaiBaseModel): - chunk_size: int = Field(200, ge=100, le=500) - chunk_overlap: int = Field(0, ge=0, le=100) - metric: str = Field("cosine") - class Collection(TaskingaiBaseModel): object: str @@ -21,7 +14,6 @@ class Collection(TaskingaiBaseModel): capacity: int num_records: int num_chunks: int - configs: CollectionConfig embedding_model_id: str metadata: Dict[str, str] created_timestamp: int diff --git a/taskingai/client/models/entity/retrieval/record.py b/taskingai/client/models/entity/retrieval/record.py index 68e4a25..754c782 100644 --- a/taskingai/client/models/entity/retrieval/record.py +++ b/taskingai/client/models/entity/retrieval/record.py @@ -1,18 +1,18 @@ -from typing import Dict, Any +from typing import Dict from .._base import TaskingaiBaseModel -from pydantic import Field __all__ = [ "Record", ] + class Record(TaskingaiBaseModel): object: str record_id: str collection_id: str type: str num_chunks: int - content: Dict[str, Any] + content: str metadata: Dict[str, str] created_timestamp: int status: str diff --git a/taskingai/client/models/entity/retrieval/text_splitter.py b/taskingai/client/models/entity/retrieval/text_splitter.py new file mode 100644 index 0000000..7ae6847 --- /dev/null +++ b/taskingai/client/models/entity/retrieval/text_splitter.py @@ -0,0 +1,53 @@ +from enum import Enum +from .._base import TaskingaiBaseModel +from pydantic import Field, model_validator +from typing import Optional, Dict, Any + +__all__ = [ + "TextSplitter", + "TextSplitterType", + "TokenTextSplitter", + "build_text_splitter", +] + + +class TextSplitterType(str, Enum): + """TextSplitterType enum.""" + + TOKEN = "token" + + +class TextSplitter(TaskingaiBaseModel): + type: TextSplitterType = Field(...) + + +class TokenTextSplitter(TextSplitter): + type: TextSplitterType = Field(TextSplitterType.TOKEN, Literal=TextSplitterType.TOKEN) + chunk_size: int = Field(...) + chunk_overlap: int = Field(...) + + # check chunk_overlap <= chunk_size/2 + @model_validator(mode="after") + def validate_chunk_overlap(cls, data: Any): + if data.chunk_overlap > data.chunk_size / 2: + raise ValueError("chunk_overlap must be less than or equal to chunk_size/2") + return data + + +def build_text_splitter(data: Dict) -> Optional[TextSplitter]: + if not isinstance(data, Dict): + raise ValueError("Text splitter input data must be a valid dictionary") + + splitter_type = data.get("type") + if splitter_type is None: + return None + + # Depending on the type of splitter, initialize the appropriate splitter instance + if splitter_type == TextSplitterType.TOKEN.value: + chunk_size = data.get("chunk_size") + chunk_overlap = data.get("chunk_overlap") + return TokenTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) + + else: + # If the splitter_type is unknown, return None + return None diff --git a/taskingai/client/models/entity/tool/action.py b/taskingai/client/models/entity/tool/action.py index ab042a4..e7a1e24 100644 --- a/taskingai/client/models/entity/tool/action.py +++ b/taskingai/client/models/entity/tool/action.py @@ -1,7 +1,6 @@ from typing import Dict, Any, Optional from enum import Enum from .._base import TaskingaiBaseModel -from pydantic import Field __all__ = [ @@ -29,10 +28,9 @@ class Action(TaskingaiBaseModel): action_id: str name: str description: str - schema: Dict[str, Any] = Field(..., alias='schema') + openapi_schema: Dict[str, Any] authentication: ActionAuthentication created_timestamp: int class Config: allow_population_by_field_name = True - diff --git a/taskingai/client/models/rest/action_bulk_create_request.py b/taskingai/client/models/rest/action_bulk_create_request.py index d9c998d..fdae1c7 100644 --- a/taskingai/client/models/rest/action_bulk_create_request.py +++ b/taskingai/client/models/rest/action_bulk_create_request.py @@ -11,11 +11,13 @@ import six + class ActionBulkCreateRequest(object): """NOTE: This class is auto generated by the swagger code generator program. Do not edit the class manually. """ + """ Attributes: swagger_types (dict): The key is attribute name @@ -23,49 +25,43 @@ class ActionBulkCreateRequest(object): attribute_map (dict): The key is attribute name and the value is json key in definition. """ - swagger_types = { - 'schema': 'object', - 'authentication': 'object' - } + swagger_types = {"openapi_schema": "object", "authentication": "object"} - attribute_map = { - 'schema': 'schema', - 'authentication': 'authentication' - } + attribute_map = {"openapi_schema": "openapi_schema", "authentication": "authentication"} - def __init__(self, schema=None, authentication=None): # noqa: E501 + def __init__(self, openapi_schema=None, authentication=None): # noqa: E501 """ActionBulkCreateRequest - a model defined in Swagger""" # noqa: E501 - self._schema = None + self._openapi_schema = None self._authentication = None self.discriminator = None - self.schema = schema + self.openapi_schema = openapi_schema if authentication is not None: self.authentication = authentication @property - def schema(self): - """Gets the schema of this ActionBulkCreateRequest. # noqa: E501 + def openapi_schema(self): + """Gets the openapi_schema of this ActionBulkCreateRequest. # noqa: E501 - The action schema is compliant with the OpenAPI Specification. If there are multiple paths and methods in the schema, the service will create multiple actions whose schema only has exactly one path and one method # noqa: E501 + The action openapi_schema is compliant with the OpenAPI Specification. If there are multiple paths and methods in the openapi_schema, the service will create multiple actions whose openapi_schema only has exactly one path and one method # noqa: E501 - :return: The schema of this ActionBulkCreateRequest. # noqa: E501 + :return: The openapi_schema of this ActionBulkCreateRequest. # noqa: E501 :rtype: object """ - return self._schema + return self._openapi_schema - @schema.setter - def schema(self, schema): - """Sets the schema of this ActionBulkCreateRequest. + @openapi_schema.setter + def openapi_schema(self, openapi_schema): + """Sets the openapi_schema of this ActionBulkCreateRequest. - The action schema is compliant with the OpenAPI Specification. If there are multiple paths and methods in the schema, the service will create multiple actions whose schema only has exactly one path and one method # noqa: E501 + The action openapi_schema is compliant with the OpenAPI Specification. If there are multiple paths and methods in the openapi_schema, the service will create multiple actions whose openapi_schema only has exactly one path and one method # noqa: E501 - :param schema: The schema of this ActionBulkCreateRequest. # noqa: E501 + :param openapi_schema: The openapi_schema of this ActionBulkCreateRequest. # noqa: E501 :type: object """ - if schema is None: - raise ValueError("Invalid value for `schema`, must not be `None`") # noqa: E501 + if openapi_schema is None: + raise ValueError("Invalid value for `openapi_schema`, must not be `None`") # noqa: E501 - self._schema = schema + self._openapi_schema = openapi_schema @property def authentication(self): @@ -97,18 +93,16 @@ def to_dict(self): for attr, _ in six.iteritems(self.swagger_types): value = getattr(self, attr) if isinstance(value, list): - result[attr] = list(map( - lambda x: x.to_dict() if hasattr(x, "to_dict") else x, - value - )) + result[attr] = list(map(lambda x: x.to_dict() if hasattr(x, "to_dict") else x, value)) elif hasattr(value, "to_dict"): result[attr] = value.to_dict() elif isinstance(value, dict): - result[attr] = dict(map( - lambda item: (item[0], item[1].to_dict()) - if hasattr(item[1], "to_dict") else item, - value.items() - )) + result[attr] = dict( + map( + lambda item: (item[0], item[1].to_dict()) if hasattr(item[1], "to_dict") else item, + value.items(), + ) + ) else: result[attr] = value if issubclass(ActionBulkCreateRequest, dict): diff --git a/taskingai/client/models/rest/action_update_request.py b/taskingai/client/models/rest/action_update_request.py index 5970549..afe8fb0 100644 --- a/taskingai/client/models/rest/action_update_request.py +++ b/taskingai/client/models/rest/action_update_request.py @@ -11,11 +11,13 @@ import six + class ActionUpdateRequest(object): """NOTE: This class is auto generated by the swagger code generator program. Do not edit the class manually. """ + """ Attributes: swagger_types (dict): The key is attribute name @@ -23,48 +25,42 @@ class ActionUpdateRequest(object): attribute_map (dict): The key is attribute name and the value is json key in definition. """ - swagger_types = { - 'schema': 'object', - 'authentication': 'object' - } + swagger_types = {"openapi_schema": "object", "authentication": "object"} - attribute_map = { - 'schema': 'schema', - 'authentication': 'authentication' - } + attribute_map = {"openapi_schema": "openapi_schema", "authentication": "authentication"} - def __init__(self, schema=None, authentication=None): # noqa: E501 + def __init__(self, openapi_schema=None, authentication=None): # noqa: E501 """ActionUpdateRequest - a model defined in Swagger""" # noqa: E501 - self._schema = None + self._openapi_schema = None self._authentication = None self.discriminator = None - if schema is not None: - self.schema = schema + if openapi_schema is not None: + self.openapi_schema = openapi_schema if authentication is not None: self.authentication = authentication @property - def schema(self): - """Gets the schema of this ActionUpdateRequest. # noqa: E501 + def openapi_schema(self): + """Gets the openapi_schema of this ActionUpdateRequest. # noqa: E501 - The action schema, which is compliant with the OpenAPI Specification. # noqa: E501 + The action openapi_schema, which is compliant with the OpenAPI Specification. # noqa: E501 - :return: The schema of this ActionUpdateRequest. # noqa: E501 + :return: The openapi_schema of this ActionUpdateRequest. # noqa: E501 :rtype: object """ - return self._schema + return self._openapi_schema - @schema.setter - def schema(self, schema): - """Sets the schema of this ActionUpdateRequest. + @openapi_schema.setter + def openapi_schema(self, openapi_schema): + """Sets the openapi_schema of this ActionUpdateRequest. - The action schema, which is compliant with the OpenAPI Specification. # noqa: E501 + The action openapi_schema, which is compliant with the OpenAPI Specification. # noqa: E501 - :param schema: The schema of this ActionUpdateRequest. # noqa: E501 + :param openapi_schema: The openapi_schema of this ActionUpdateRequest. # noqa: E501 :type: object """ - self._schema = schema + self._openapi_schema = openapi_schema @property def authentication(self): @@ -96,18 +92,16 @@ def to_dict(self): for attr, _ in six.iteritems(self.swagger_types): value = getattr(self, attr) if isinstance(value, list): - result[attr] = list(map( - lambda x: x.to_dict() if hasattr(x, "to_dict") else x, - value - )) + result[attr] = list(map(lambda x: x.to_dict() if hasattr(x, "to_dict") else x, value)) elif hasattr(value, "to_dict"): result[attr] = value.to_dict() elif isinstance(value, dict): - result[attr] = dict(map( - lambda item: (item[0], item[1].to_dict()) - if hasattr(item[1], "to_dict") else item, - value.items() - )) + result[attr] = dict( + map( + lambda item: (item[0], item[1].to_dict()) if hasattr(item[1], "to_dict") else item, + value.items(), + ) + ) else: result[attr] = value if issubclass(ActionUpdateRequest, dict): diff --git a/taskingai/client/models/rest/collection_create_request.py b/taskingai/client/models/rest/collection_create_request.py index 73aa8dd..3c3d206 100644 --- a/taskingai/client/models/rest/collection_create_request.py +++ b/taskingai/client/models/rest/collection_create_request.py @@ -11,11 +11,13 @@ import six + class CollectionCreateRequest(object): """NOTE: This class is auto generated by the swagger code generator program. Do not edit the class manually. """ + """ Attributes: swagger_types (dict): The key is attribute name @@ -24,30 +26,29 @@ class CollectionCreateRequest(object): and the value is json key in definition. """ swagger_types = { - 'capacity': 'object', - 'embedding_model_id': 'object', - 'name': 'object', - 'description': 'object', - 'configs': 'object', - 'metadata': 'object' + "capacity": "object", + "embedding_model_id": "object", + "name": "object", + "description": "object", + "metadata": "object", } attribute_map = { - 'capacity': 'capacity', - 'embedding_model_id': 'embedding_model_id', - 'name': 'name', - 'description': 'description', - 'configs': 'configs', - 'metadata': 'metadata' + "capacity": "capacity", + "embedding_model_id": "embedding_model_id", + "name": "name", + "description": "description", + "metadata": "metadata", } - def __init__(self, capacity=None, embedding_model_id=None, name=None, description=None, configs=None, metadata=None): # noqa: E501 + def __init__( + self, capacity=None, embedding_model_id=None, name=None, description=None, metadata=None + ): # noqa: E501 """CollectionCreateRequest - a model defined in Swagger""" # noqa: E501 self._capacity = None self._embedding_model_id = None self._name = None self._description = None - self._configs = None self._metadata = None self.discriminator = None if capacity is not None: @@ -57,8 +58,6 @@ def __init__(self, capacity=None, embedding_model_id=None, name=None, descriptio self.name = name if description is not None: self.description = description - if configs is not None: - self.configs = configs if metadata is not None: self.metadata = metadata @@ -156,29 +155,6 @@ def description(self, description): self._description = description - @property - def configs(self): - """Gets the configs of this CollectionCreateRequest. # noqa: E501 - - The collection configs indicating how the collection stores and indexes the text chunks. It cannot change after creation. # noqa: E501 - - :return: The configs of this CollectionCreateRequest. # noqa: E501 - :rtype: object - """ - return self._configs - - @configs.setter - def configs(self, configs): - """Sets the configs of this CollectionCreateRequest. - - The collection configs indicating how the collection stores and indexes the text chunks. It cannot change after creation. # noqa: E501 - - :param configs: The configs of this CollectionCreateRequest. # noqa: E501 - :type: object - """ - - self._configs = configs - @property def metadata(self): """Gets the metadata of this CollectionCreateRequest. # noqa: E501 @@ -209,18 +185,16 @@ def to_dict(self): for attr, _ in six.iteritems(self.swagger_types): value = getattr(self, attr) if isinstance(value, list): - result[attr] = list(map( - lambda x: x.to_dict() if hasattr(x, "to_dict") else x, - value - )) + result[attr] = list(map(lambda x: x.to_dict() if hasattr(x, "to_dict") else x, value)) elif hasattr(value, "to_dict"): result[attr] = value.to_dict() elif isinstance(value, dict): - result[attr] = dict(map( - lambda item: (item[0], item[1].to_dict()) - if hasattr(item[1], "to_dict") else item, - value.items() - )) + result[attr] = dict( + map( + lambda item: (item[0], item[1].to_dict()) if hasattr(item[1], "to_dict") else item, + value.items(), + ) + ) else: result[attr] = value if issubclass(CollectionCreateRequest, dict): diff --git a/taskingai/client/models/rest/record_create_request.py b/taskingai/client/models/rest/record_create_request.py index abd2429..a0a7477 100644 --- a/taskingai/client/models/rest/record_create_request.py +++ b/taskingai/client/models/rest/record_create_request.py @@ -11,11 +11,13 @@ import six + class RecordCreateRequest(object): """NOTE: This class is auto generated by the swagger code generator program. Do not edit the class manually. """ + """ Attributes: swagger_types (dict): The key is attribute name @@ -23,27 +25,20 @@ class RecordCreateRequest(object): attribute_map (dict): The key is attribute name and the value is json key in definition. """ - swagger_types = { - 'type': 'object', - 'text': 'object', - 'metadata': 'object' - } - - attribute_map = { - 'type': 'type', - 'text': 'text', - 'metadata': 'metadata' - } - - def __init__(self, type=None, text=None, metadata=None): # noqa: E501 + swagger_types = {"type": "object", "content": "object", "text_splitter": "object", "metadata": "object"} + + attribute_map = {"type": "type", "content": "content", "text_splitter": "text_splitter", "metadata": "metadata"} + + def __init__(self, type=None, content=None, text_splitter=None, metadata=None): # noqa: E501 """RecordCreateRequest - a model defined in Swagger""" # noqa: E501 self._type = None - self._text = None + self._content = None + self._text_splitter = None self._metadata = None self.discriminator = None self.type = type - if text is not None: - self.text = text + self.content = content + self.text_splitter = text_splitter if metadata is not None: self.metadata = metadata @@ -51,7 +46,7 @@ def __init__(self, type=None, text=None, metadata=None): # noqa: E501 def type(self): """Gets the type of this RecordCreateRequest. # noqa: E501 - The record type, which can be `text` or `file`. # noqa: E501 + The record type. :return: The type of this RecordCreateRequest. # noqa: E501 :rtype: object @@ -62,7 +57,7 @@ def type(self): def type(self, type): """Sets the type of this RecordCreateRequest. - The record type, which can be `text` or `file`. # noqa: E501 + The record type. :param type: The type of this RecordCreateRequest. # noqa: E501 :type: object @@ -73,27 +68,50 @@ def type(self, type): self._type = type @property - def text(self): - """Gets the text of this RecordCreateRequest. # noqa: E501 + def content(self): + """Gets the content of this RecordCreateRequest. # noqa: E501 + + The record content. + + :return: The content of this RecordCreateRequest. # noqa: E501 + :rtype: object + """ + return self._content + + @content.setter + def content(self, content): + """Sets the content of this RecordCreateRequest. + + The record content. + + :param content: The content of this RecordCreateRequest. # noqa: E501 + :type: object + """ + + self._content = content + + @property + def text_splitter(self): + """Gets the text_splitter of this RecordCreateRequest. # noqa: E501 - The record text. It's only valid when the record type is `text`. # noqa: E501 + The text splitter to split records into chunks. - :return: The text of this RecordCreateRequest. # noqa: E501 + :return: The text_splitter of this RecordCreateRequest. # noqa: E501 :rtype: object """ - return self._text + return self._text_splitter - @text.setter - def text(self, text): - """Sets the text of this RecordCreateRequest. + @text_splitter.setter + def text_splitter(self, text_splitter): + """Sets the text_splitter of this RecordCreateRequest. - The record text. It's only valid when the record type is `text`. # noqa: E501 + The text splitter to split records into chunks. - :param text: The text of this RecordCreateRequest. # noqa: E501 + :param text_splitter: The text_splitter of this RecordCreateRequest. # noqa: E501 :type: object """ - self._text = text + self._text_splitter = text_splitter @property def metadata(self): @@ -125,18 +143,16 @@ def to_dict(self): for attr, _ in six.iteritems(self.swagger_types): value = getattr(self, attr) if isinstance(value, list): - result[attr] = list(map( - lambda x: x.to_dict() if hasattr(x, "to_dict") else x, - value - )) + result[attr] = list(map(lambda x: x.to_dict() if hasattr(x, "to_dict") else x, value)) elif hasattr(value, "to_dict"): result[attr] = value.to_dict() elif isinstance(value, dict): - result[attr] = dict(map( - lambda item: (item[0], item[1].to_dict()) - if hasattr(item[1], "to_dict") else item, - value.items() - )) + result[attr] = dict( + map( + lambda item: (item[0], item[1].to_dict()) if hasattr(item[1], "to_dict") else item, + value.items(), + ) + ) else: result[attr] = value if issubclass(RecordCreateRequest, dict): diff --git a/taskingai/client/models/rest/record_update_request.py b/taskingai/client/models/rest/record_update_request.py index 7de4c05..1fa12b7 100644 --- a/taskingai/client/models/rest/record_update_request.py +++ b/taskingai/client/models/rest/record_update_request.py @@ -11,11 +11,13 @@ import six + class RecordUpdateRequest(object): """NOTE: This class is auto generated by the swagger code generator program. Do not edit the class manually. """ + """ Attributes: swagger_types (dict): The key is attribute name @@ -23,21 +25,97 @@ class RecordUpdateRequest(object): attribute_map (dict): The key is attribute name and the value is json key in definition. """ - swagger_types = { - 'metadata': 'object' - } + swagger_types = {"type": "object", "content": "object", "text_splitter": "object", "metadata": "object"} - attribute_map = { - 'metadata': 'metadata' - } + attribute_map = {"type": "type", "content": "content", "text_splitter": "text_splitter", "metadata": "metadata"} - def __init__(self, metadata=None): # noqa: E501 + def __init__(self, type=None, content=None, text_splitter=None, metadata=None): # noqa: E501 """RecordUpdateRequest - a model defined in Swagger""" # noqa: E501 + self._type = None + self._content = None + self._text_splitter = None self._metadata = None self.discriminator = None + if type is not None: + self.type = type + if content is not None: + self.content = content + if text_splitter is not None: + self.text_splitter = text_splitter if metadata is not None: self.metadata = metadata + @property + def type(self): + """Gets the type of this RecordCreateRequest. # noqa: E501 + + The record type. + + :return: The type of this RecordCreateRequest. # noqa: E501 + :rtype: object + """ + return self._type + + @type.setter + def type(self, type): + """Sets the type of this RecordCreateRequest. + + The record type. + + :param type: The type of this RecordCreateRequest. # noqa: E501 + :type: object + """ + if type is None: + raise ValueError("Invalid value for `type`, must not be `None`") # noqa: E501 + + self._type = type + + @property + def content(self): + """Gets the content of this RecordCreateRequest. # noqa: E501 + + The record content. + + :return: The content of this RecordCreateRequest. # noqa: E501 + :rtype: object + """ + return self._content + + @content.setter + def content(self, content): + """Sets the content of this RecordCreateRequest. + + The record content. + + :param content: The content of this RecordCreateRequest. # noqa: E501 + :type: object + """ + + self._content = content + + @property + def text_splitter(self): + """Gets the text_splitter of this RecordCreateRequest. # noqa: E501 + + The text splitter to split records into chunks. + + :return: The text_splitter of this RecordCreateRequest. # noqa: E501 + :rtype: object + """ + return self._text_splitter + + @text_splitter.setter + def text_splitter(self, text_splitter): + """Sets the text_splitter of this RecordCreateRequest. + + The text splitter to split records into chunks. + + :param text_splitter: The text_splitter of this RecordCreateRequest. # noqa: E501 + :type: object + """ + + self._text_splitter = text_splitter + @property def metadata(self): """Gets the metadata of this RecordUpdateRequest. # noqa: E501 @@ -66,18 +144,16 @@ def to_dict(self): for attr, _ in six.iteritems(self.swagger_types): value = getattr(self, attr) if isinstance(value, list): - result[attr] = list(map( - lambda x: x.to_dict() if hasattr(x, "to_dict") else x, - value - )) + result[attr] = list(map(lambda x: x.to_dict() if hasattr(x, "to_dict") else x, value)) elif hasattr(value, "to_dict"): result[attr] = value.to_dict() elif isinstance(value, dict): - result[attr] = dict(map( - lambda item: (item[0], item[1].to_dict()) - if hasattr(item[1], "to_dict") else item, - value.items() - )) + result[attr] = dict( + map( + lambda item: (item[0], item[1].to_dict()) if hasattr(item[1], "to_dict") else item, + value.items(), + ) + ) else: result[attr] = value if issubclass(RecordUpdateRequest, dict): diff --git a/taskingai/retrieval/__init__.py b/taskingai/retrieval/__init__.py index 60d53a8..685d834 100644 --- a/taskingai/retrieval/__init__.py +++ b/taskingai/retrieval/__init__.py @@ -1,3 +1,4 @@ from .chunk import * from .collection import * from .record import * +from .text_splitter import * diff --git a/taskingai/retrieval/collection.py b/taskingai/retrieval/collection.py index c8a78fe..230cc65 100644 --- a/taskingai/retrieval/collection.py +++ b/taskingai/retrieval/collection.py @@ -1,14 +1,18 @@ -from typing import Optional, List, Dict, Any +from typing import Optional, List, Dict from taskingai.client.utils import get_api_instance, ModuleType -from taskingai.client.models import Collection, CollectionConfig -from taskingai.client.models import CollectionCreateRequest, CollectionCreateResponse, \ - CollectionUpdateRequest, CollectionUpdateResponse, \ - CollectionGetResponse, CollectionListResponse +from taskingai.client.models import Collection +from taskingai.client.models import ( + CollectionCreateRequest, + CollectionCreateResponse, + CollectionUpdateRequest, + CollectionUpdateResponse, + CollectionGetResponse, + CollectionListResponse, +) __all__ = [ "Collection", - "CollectionConfig", "get_collection", "list_collections", "create_collection", @@ -23,10 +27,10 @@ def list_collections( - order: str = "desc", - limit: int = 20, - after: Optional[str] = None, - before: Optional[str] = None, + order: str = "desc", + limit: int = 20, + after: Optional[str] = None, + before: Optional[str] = None, ) -> List[Collection]: """ List collections. @@ -56,10 +60,10 @@ def list_collections( async def a_list_collections( - order: str = "desc", - limit: int = 20, - after: Optional[str] = None, - before: Optional[str] = None, + order: str = "desc", + limit: int = 20, + after: Optional[str] = None, + before: Optional[str] = None, ) -> List[Collection]: """ List collections. @@ -115,12 +119,11 @@ async def a_get_collection(collection_id: str) -> Collection: def create_collection( - embedding_model_id: str, - capacity: int = 1000, - name: Optional[str] = None, - description: Optional[str] = None, - configs: Optional[CollectionConfig] = None, - metadata: Optional[Dict[str, str]] = None, + embedding_model_id: str, + capacity: int = 1000, + name: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, ) -> Collection: """ Create a collection. @@ -129,7 +132,6 @@ def create_collection( :param capacity: The maximum number of embeddings that can be stored in the collection. :param name: The name of the collection. :param description: The description of the collection. - :param configs: The collection configurations. :param metadata: The collection metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. :return: The created collection object. """ @@ -141,7 +143,6 @@ def create_collection( capacity=capacity, name=name, description=description, - configs=configs, metadata=metadata, ) response: CollectionCreateResponse = api_instance.create_collection(body=body) @@ -150,12 +151,11 @@ def create_collection( async def a_create_collection( - embedding_model_id: str, - capacity: int = 1000, - name: Optional[str] = None, - description: Optional[str] = None, - configs: Optional[CollectionConfig] = None, - metadata: Optional[Dict[str, str]] = None, + embedding_model_id: str, + capacity: int = 1000, + name: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, ) -> Collection: """ Create a collection in async mode. @@ -164,7 +164,6 @@ async def a_create_collection( :param capacity: The maximum number of embeddings that can be stored in the collection. :param name: The name of the collection. :param description: The description of the collection. - :param configs: The collection configurations. :param metadata: The collection metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. :return: The created collection object. """ @@ -176,7 +175,6 @@ async def a_create_collection( capacity=capacity, name=name, description=description, - configs=configs, metadata=metadata, ) response: CollectionCreateResponse = await api_instance.create_collection(body=body) @@ -185,10 +183,10 @@ async def a_create_collection( def update_collection( - collection_id: str, - name: Optional[str] = None, - description: Optional[str] = None, - metadata: Optional[Dict[str, str]] = None, + collection_id: str, + name: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, ) -> Collection: """ Update a collection. @@ -198,26 +196,23 @@ def update_collection( :param authentication: The collection API authentication. :return: The updated collection object. """ - #todo: verify at least one parameter is not None + # todo: verify at least one parameter is not None api_instance = get_api_instance(ModuleType.RETRIEVAL) body = CollectionUpdateRequest( name=name, description=description, metadata=metadata, ) - response: CollectionUpdateResponse = api_instance.update_collection( - collection_id=collection_id, - body=body - ) + response: CollectionUpdateResponse = api_instance.update_collection(collection_id=collection_id, body=body) collection: Collection = Collection(**response.data) return collection async def a_update_collection( - collection_id: str, - name: Optional[str] = None, - description: Optional[str] = None, - metadata: Optional[Dict[str, str]] = None, + collection_id: str, + name: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, ) -> Collection: """ Update a collection in async mode. @@ -233,10 +228,7 @@ async def a_update_collection( description=description, metadata=metadata, ) - response: CollectionUpdateResponse = await api_instance.update_collection( - collection_id=collection_id, - body=body - ) + response: CollectionUpdateResponse = await api_instance.update_collection(collection_id=collection_id, body=body) collection: Collection = Collection(**response.data) return collection @@ -261,4 +253,3 @@ async def a_delete_collection(collection_id: str) -> None: api_instance = get_api_instance(ModuleType.RETRIEVAL, async_client=True) await api_instance.delete_collection(collection_id=collection_id) - diff --git a/taskingai/retrieval/record.py b/taskingai/retrieval/record.py index 0a395e9..2982ae7 100644 --- a/taskingai/retrieval/record.py +++ b/taskingai/retrieval/record.py @@ -1,32 +1,37 @@ from typing import Optional, List, Dict from taskingai.client.utils import get_api_instance, ModuleType -from taskingai.client.models import Record -from taskingai.client.models import RecordCreateRequest, RecordCreateResponse, \ - RecordUpdateRequest, RecordUpdateResponse, \ - RecordGetResponse, RecordListResponse +from taskingai.client.models import Record, TextSplitter +from taskingai.client.models import ( + RecordCreateRequest, + RecordCreateResponse, + RecordUpdateRequest, + RecordUpdateResponse, + RecordGetResponse, + RecordListResponse, +) __all__ = [ "Record", "get_record", "list_records", - "create_text_record", + "create_record", "update_record", "delete_record", "a_get_record", "a_list_records", - "a_create_text_record", + "a_create_record", "a_update_record", "a_delete_record", ] def list_records( - collection_id: str, - order: str = "desc", - limit: int = 20, - after: Optional[str] = None, - before: Optional[str] = None, + collection_id: str, + order: str = "desc", + limit: int = 20, + after: Optional[str] = None, + before: Optional[str] = None, ) -> List[Record]: """ List records. @@ -50,20 +55,17 @@ def list_records( "before": before, } params = {k: v for k, v in params.items() if v is not None} - response: RecordListResponse = api_instance.list_records( - collection_id=collection_id, - **params - ) + response: RecordListResponse = api_instance.list_records(collection_id=collection_id, **params) records: List[Record] = [Record(**item) for item in response.data] return records async def a_list_records( - collection_id: str, - order: str = "desc", - limit: int = 20, - after: Optional[str] = None, - before: Optional[str] = None, + collection_id: str, + order: str = "desc", + limit: int = 20, + after: Optional[str] = None, + before: Optional[str] = None, ) -> List[Record]: """ List records in async mode. @@ -87,15 +89,11 @@ async def a_list_records( "before": before, } params = {k: v for k, v in params.items() if v is not None} - response: RecordListResponse = await api_instance.list_records( - collection_id=collection_id, - **params - ) + response: RecordListResponse = await api_instance.list_records(collection_id=collection_id, **params) records: List[Record] = [Record(**item) for item in response.data] return records - def get_record(collection_id: str, record_id: str) -> Record: """ Get a record. @@ -130,17 +128,18 @@ async def a_get_record(collection_id: str, record_id: str) -> Record: return record -def create_text_record( +def create_record( collection_id: str, - # todo: support file - text: str, + content: str, + text_splitter: TextSplitter, metadata: Optional[Dict[str, str]] = None, ) -> Record: """ Create a record. :param collection_id: The ID of the collection. - :param text: The text content of the record. + :param content: The content of the record. + :param text_splitter: The text splitter to split records into chunks. :param metadata: The collection metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. :return: The created record object. """ @@ -148,28 +147,27 @@ def create_text_record( api_instance = get_api_instance(ModuleType.RETRIEVAL) body = RecordCreateRequest( type="text", - text=text, + content=content, + text_splitter=text_splitter, metadata=metadata, ) - response: RecordCreateResponse = api_instance.create_record( - collection_id=collection_id, - body=body - ) + response: RecordCreateResponse = api_instance.create_record(collection_id=collection_id, body=body) record: Record = Record(**response.data) return record -async def a_create_text_record( +async def a_create_record( collection_id: str, - # todo: support file - text: str, + content: str, + text_splitter: TextSplitter, metadata: Optional[Dict[str, str]] = None, ) -> Record: """ Create a record in async mode. :param collection_id: The ID of the collection. - :param text: The text content of the record. + :param content: The content of the record. + :param text_splitter: The text splitter to split records into chunks. :param metadata: The collection metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. :return: The created record object. """ @@ -177,13 +175,11 @@ async def a_create_text_record( api_instance = get_api_instance(ModuleType.RETRIEVAL, async_client=True) body = RecordCreateRequest( type="text", - text=text, + content=content, + text_splitter=text_splitter, metadata=metadata, ) - response: RecordCreateResponse = await api_instance.create_record( - collection_id=collection_id, - body=body - ) + response: RecordCreateResponse = await api_instance.create_record(collection_id=collection_id, body=body) record: Record = Record(**response.data) return record @@ -191,6 +187,8 @@ async def a_create_text_record( def update_record( collection_id: str, record_id: str, + content: Optional[str] = None, + text_splitter: Optional[TextSplitter] = None, metadata: Optional[Dict[str, str]] = None, ) -> Record: """ @@ -198,18 +196,25 @@ def update_record( :param collection_id: The ID of the collection. :param record_id: The ID of the record. - :param metadata: The collection metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. + :param content: The content of the record. + :param text_splitter: The text splitter to split records into chunks. + :param metadata: The collection metadata. It can store up to 16 key-value pairs where each key's length is less + than 64 and value's length is less than 512. :return: The collection object. """ api_instance = get_api_instance(ModuleType.RETRIEVAL) + type = None + if content and text_splitter: + type = "text" body = RecordUpdateRequest( + type=type, + content=content, + text_splitter=text_splitter, metadata=metadata, ) response: RecordUpdateResponse = api_instance.update_record( - collection_id=collection_id, - record_id=record_id, - body=body + collection_id=collection_id, record_id=record_id, body=body ) record: Record = Record(**response.data) return record @@ -218,6 +223,8 @@ def update_record( async def a_update_record( collection_id: str, record_id: str, + content: Optional[str] = None, + text_splitter: Optional[TextSplitter] = None, metadata: Optional[Dict[str, str]] = None, ) -> Record: """ @@ -225,18 +232,25 @@ async def a_update_record( :param collection_id: The ID of the collection. :param record_id: The ID of the record. - :param metadata: The collection metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. + :param content: The content of the record. + :param text_splitter: The text splitter to split records into chunks. + :param metadata: The collection metadata. It can store up to 16 key-value pairs where each key's length is less + than 64 and value's length is less than 512. :return: The collection object. """ api_instance = get_api_instance(ModuleType.RETRIEVAL, async_client=True) + type = None + if content and text_splitter: + type = "text" body = RecordUpdateRequest( + type=type, + content=content, + text_splitter=text_splitter, metadata=metadata, ) response: RecordUpdateResponse = await api_instance.update_record( - collection_id=collection_id, - record_id=record_id, - body=body + collection_id=collection_id, record_id=record_id, body=body ) record: Record = Record(**response.data) return record @@ -270,7 +284,3 @@ async def a_delete_record( api_instance = get_api_instance(ModuleType.RETRIEVAL, async_client=True) await api_instance.delete_record(collection_id=collection_id, record_id=record_id) - - - - diff --git a/taskingai/retrieval/text_splitter.py b/taskingai/retrieval/text_splitter.py new file mode 100644 index 0000000..588c60e --- /dev/null +++ b/taskingai/retrieval/text_splitter.py @@ -0,0 +1,7 @@ +from taskingai.client.models import TextSplitter, TextSplitterType, TokenTextSplitter + +__all__ = [ + "TextSplitter", + "TextSplitterType", + "TokenTextSplitter", +] diff --git a/taskingai/tool/action.py b/taskingai/tool/action.py index 09c4635..c6b5994 100644 --- a/taskingai/tool/action.py +++ b/taskingai/tool/action.py @@ -10,7 +10,7 @@ ActionGetResponse, ActionListResponse, ActionRunRequest, - ActionRunResponse + ActionRunResponse, ) __all__ = [ @@ -33,10 +33,10 @@ def list_actions( - order: str = "desc", - limit: int = 20, - after: Optional[str] = None, - before: Optional[str] = None, + order: str = "desc", + limit: int = 20, + after: Optional[str] = None, + before: Optional[str] = None, ) -> List[Action]: """ List actions. @@ -65,10 +65,10 @@ def list_actions( async def a_list_actions( - order: str = "desc", - limit: int = 20, - after: Optional[str] = None, - before: Optional[str] = None, + order: str = "desc", + limit: int = 20, + after: Optional[str] = None, + before: Optional[str] = None, ) -> List[Action]: """ List actions in async mode. @@ -96,7 +96,6 @@ async def a_list_actions( return actions - def get_action(action_id: str) -> Action: """ Get an action. @@ -122,26 +121,24 @@ async def a_get_action(action_id: str) -> Action: action: Action = Action(**response.data) return action -def bulk_create_actions( - schema: Dict, - authentication: Optional[ActionAuthentication] = None -) -> List[Action]: + +def bulk_create_actions(openapi_schema: Dict, authentication: Optional[ActionAuthentication] = None) -> List[Action]: """ Create actions from an OpenAPI schema. - :param schema: The action schema is compliant with the OpenAPI Specification. If there are multiple paths and methods in the schema, the service will create multiple actions whose schema only has exactly one path and one method + :param openapi_schema: The action schema is compliant with the OpenAPI Specification. If there are multiple paths and methods in the openapi_schema, the service will create multiple actions whose openapi_schema only has exactly one path and one method :param authentication: The action API authentication. :return: The created action object. """ - # todo verify schema + # todo verify openapi_schema api_instance = get_api_instance(ModuleType.TOOL) if authentication is None: authentication = ActionAuthentication( type=ActionAuthenticationType.NONE, ) body = ActionBulkCreateRequest( - schema=schema, + openapi_schema=openapi_schema, authentication=authentication, ) response: ActionBulkCreateResponse = api_instance.bulk_create_action(body=body) @@ -150,25 +147,24 @@ def bulk_create_actions( async def a_bulk_create_actions( - schema: Dict, - authentication: Optional[ActionAuthentication] = None + openapi_schema: Dict, authentication: Optional[ActionAuthentication] = None ) -> List[Action]: """ Create actions from an OpenAPI schema in async mode. - :param schema: The action schema is compliant with the OpenAPI Specification. If there are multiple paths and methods in the schema, the service will create multiple actions whose schema only has exactly one path and one method + :param openapi_schema: The action schema is compliant with the OpenAPI Specification. If there are multiple paths and methods in the openapi_schema, the service will create multiple actions whose openapi_schema only has exactly one path and one method :param authentication: The action API authentication. :return: The created action object. """ - # todo verify schema + # todo verify openapi_schema api_instance = get_api_instance(ModuleType.TOOL, async_client=True) if authentication is None: authentication = ActionAuthentication( type=ActionAuthenticationType.NONE, ) body = ActionBulkCreateRequest( - schema=schema, + openapi_schema=openapi_schema, authentication=authentication, ) response: ActionBulkCreateResponse = await api_instance.bulk_create_action(body=body) @@ -178,53 +174,46 @@ async def a_bulk_create_actions( def update_action( action_id: str, - schema: Optional[Dict] = None, + openapi_schema: Optional[Dict] = None, authentication: Optional[ActionAuthentication] = None, ) -> Action: """ Update an action. :param action_id: The ID of the action. - :param schema: The action schema, which is compliant with the OpenAPI Specification. It should only have exactly one path and one method. + :param openapi_schema: The action schema, which is compliant with the OpenAPI Specification. It should only have exactly one path and one method. :param authentication: The action API authentication. :return: The updated action object. """ - #todo: verify schema api_instance = get_api_instance(ModuleType.TOOL) body = ActionUpdateRequest( - schema=schema, + openapi_schema=openapi_schema, authentication=authentication, ) - response: ActionUpdateResponse = api_instance.update_action( - action_id=action_id, - body=body - ) + response: ActionUpdateResponse = api_instance.update_action(action_id=action_id, body=body) action: Action = Action(**response.data) return action async def a_update_action( action_id: str, - schema: Optional[Dict] = None, + openapi_schema: Optional[Dict] = None, authentication: Optional[ActionAuthentication] = None, ) -> Action: """ Update an action in async mode. :param action_id: The ID of the action. - :param schema: The action schema, which is compliant with the OpenAPI Specification. It should only have exactly one path and one method. + :param openapi_schema: The action schema, which is compliant with the OpenAPI Specification. It should only have exactly one path and one method. :param authentication: The action API authentication. :return: The updated action object. """ api_instance = get_api_instance(ModuleType.TOOL, async_client=True) body = ActionUpdateRequest( - schema=schema, + openapi_schema=openapi_schema, authentication=authentication, ) - response: ActionUpdateResponse = await api_instance.update_action( - action_id=action_id, - body=body - ) + response: ActionUpdateResponse = await api_instance.update_action(action_id=action_id, body=body) action: Action = Action(**response.data) return action @@ -267,10 +256,7 @@ def run_action( body = ActionRunRequest( parameters=parameters, ) - response: ActionRunResponse = api_instance.run_action( - action_id=action_id, - body=body - ) + response: ActionRunResponse = api_instance.run_action(action_id=action_id, body=body) result = response.data return result @@ -291,10 +277,6 @@ async def a_run_action( body = ActionRunRequest( parameters=parameters, ) - response: ActionRunResponse = await api_instance.run_action( - action_id=action_id, - body=body - ) + response: ActionRunResponse = await api_instance.run_action(action_id=action_id, body=body) result = response.data return result - diff --git a/test/testcase/test_async/test_async_assistant.py b/test/testcase/test_async/test_async_assistant.py index c7bf613..7557fe4 100644 --- a/test/testcase/test_async/test_async_assistant.py +++ b/test/testcase/test_async/test_async_assistant.py @@ -1,9 +1,8 @@ -import asyncio import pytest from taskingai.assistant import * from taskingai.assistant.memory import AssistantNaiveMemory -from test.config import chat_completion_model_id, sleep_time, embedding_model_id +from test.config import chat_completion_model_id from test.common.read_data import data from test.common.logger import logger from test.common.utils import list_to_dict @@ -11,44 +10,59 @@ from test.testcase.test_async.base import Base from taskingai.tool import * from taskingai.retrieval import * +import re assistant_data = data.load_yaml("test_assistant_data.yml") @pytest.mark.test_async class TestAssistant(Base): - - assistant_list = ['assistant_id', 'created_timestamp', 'description', 'metadata', 'model_id', 'name', 'object', - 'retrievals', 'system_prompt_template', 'tools', "memory"] + assistant_list = [ + "assistant_id", + "created_timestamp", + "description", + "metadata", + "model_id", + "name", + "object", + "retrievals", + "system_prompt_template", + "tools", + "memory", + ] assistant_keys = set(assistant_list) @pytest.mark.parametrize("a_create_assistant_data", assistant_data["test_success_create_assistant"]) @pytest.mark.run(order=18) @pytest.mark.asyncio async def test_a_create_assistant(self, a_create_assistant_data): - # Create an assistant. assistant_dict = list_to_dict(a_create_assistant_data) assistant_dict.update({"model_id": chat_completion_model_id}) - if ("retrievals" in assistant_dict.keys() and len(assistant_dict["retrievals"]) > 0 and - assistant_dict["retrievals"][0]["type"] == "collection"): + if ( + "retrievals" in assistant_dict.keys() + and len(assistant_dict["retrievals"]) > 0 + and assistant_dict["retrievals"][0]["type"] == "collection" + ): assistant_dict["retrievals"][0]["id"] = Base.collection_id - if ("tools" in assistant_dict.keys() and len(assistant_dict["tools"]) > 0 and assistant_dict["tools"][0]["type"] - == "action"): - logger.info(f'a_create_assistant_action_id:{Base.action_id}') + if ( + "tools" in assistant_dict.keys() + and len(assistant_dict["tools"]) > 0 + and assistant_dict["tools"][0]["type"] == "action" + ): + logger.info(f"a_create_assistant_action_id:{Base.action_id}") assistant_dict["tools"][0]["id"] = Base.action_id assistant_dict.update({"memory": AssistantNaiveMemory()}) res = await a_create_assistant(**assistant_dict) res_dict = res.to_dict() - logger.info(f'response_dict:{res_dict}, except_dict:{assistant_dict}') + logger.info(f"response_dict:{res_dict}, except_dict:{assistant_dict}") pytest.assume(res_dict.keys() == self.assistant_keys) assume_assistant(res_dict, assistant_dict) @pytest.mark.run(order=19) @pytest.mark.asyncio async def test_a_list_assistants(self): - # List assistants. nums_limit = 1 @@ -73,7 +87,6 @@ async def test_a_list_assistants(self): @pytest.mark.run(order=20) @pytest.mark.asyncio async def test_a_get_assistant(self, a_assistant_id): - # Get an assistant. if not Base.assistant_id: @@ -85,7 +98,6 @@ async def test_a_get_assistant(self, a_assistant_id): @pytest.mark.run(order=21) @pytest.mark.asyncio async def test_a_update_assistant(self): - # Update an assistant. name = "openai" @@ -99,7 +111,6 @@ async def test_a_update_assistant(self): @pytest.mark.run(order=33) @pytest.mark.asyncio async def test_a_delete_assistant(self): - # List assistants. assistants = await a_list_assistants(limit=100) @@ -122,108 +133,106 @@ async def test_a_delete_assistant(self): @pytest.mark.test_async class TestChat(Base): + chat_list = ["assistant_id", "chat_id", "created_timestamp", "metadata", "object"] + chat_keys = set(chat_list) - chat_list = ['assistant_id', 'chat_id', 'created_timestamp', 'metadata', 'object'] - chat_keys = set(chat_list) - - @pytest.mark.run(order=22) - @pytest.mark.asyncio - async def test_a_create_chat(self): - - for x in range(2): - - # Create a chat. + @pytest.mark.run(order=22) + @pytest.mark.asyncio + async def test_a_create_chat(self): + for x in range(2): + # Create a chat. - res = await a_create_chat(assistant_id=self.assistant_id) - res_dict = res.to_dict() - pytest.assume(res_dict.keys() == self.chat_keys) + res = await a_create_chat(assistant_id=self.assistant_id) + res_dict = res.to_dict() + pytest.assume(res_dict.keys() == self.chat_keys) - @pytest.mark.run(order=23) - @pytest.mark.asyncio - async def test_a_list_chats(self): + @pytest.mark.run(order=23) + @pytest.mark.asyncio + async def test_a_list_chats(self): + # List chats. - # List chats. + nums_limit = 1 + res = await a_list_chats(limit=nums_limit, assistant_id=self.assistant_id) + pytest.assume(len(res) == nums_limit) - nums_limit = 1 - res = await a_list_chats(limit=nums_limit, assistant_id=self.assistant_id) - pytest.assume(len(res) == nums_limit) + after_id = res[-1].chat_id + after_res = await a_list_chats(limit=nums_limit, after=after_id, assistant_id=self.assistant_id) + pytest.assume(len(after_res) == nums_limit) - after_id = res[-1].chat_id - after_res = await a_list_chats(limit=nums_limit, after=after_id, assistant_id=self.assistant_id) - pytest.assume(len(after_res) == nums_limit) + twice_nums_list = await a_list_chats(limit=nums_limit * 2, assistant_id=self.assistant_id) + pytest.assume(len(twice_nums_list) == nums_limit * 2) + pytest.assume(after_res[-1] == twice_nums_list[-1]) + pytest.assume(after_res[0] == twice_nums_list[nums_limit]) - twice_nums_list = await a_list_chats(limit=nums_limit * 2, assistant_id=self.assistant_id) - pytest.assume(len(twice_nums_list) == nums_limit * 2) - pytest.assume(after_res[-1] == twice_nums_list[-1]) - pytest.assume(after_res[0] == twice_nums_list[nums_limit]) + before_id = after_res[0].chat_id + before_res = await a_list_chats(limit=nums_limit, before=before_id, assistant_id=self.assistant_id) + pytest.assume(len(before_res) == nums_limit) + pytest.assume(before_res[-1] == twice_nums_list[nums_limit - 1]) + pytest.assume(before_res[0] == twice_nums_list[0]) - before_id = after_res[0].chat_id - before_res = await a_list_chats(limit=nums_limit, before=before_id, assistant_id=self.assistant_id) - pytest.assume(len(before_res) == nums_limit) - pytest.assume(before_res[-1] == twice_nums_list[nums_limit - 1]) - pytest.assume(before_res[0] == twice_nums_list[0]) + @pytest.mark.run(order=24) + @pytest.mark.asyncio + async def test_a_get_chat(self, a_chat_id): + # Get a chat. - @pytest.mark.run(order=24) - @pytest.mark.asyncio - async def test_a_get_chat(self, a_chat_id): + if not Base.chat_id: + Base.assistant_id, Base.chat_id = await a_chat_id + res = await a_get_chat(assistant_id=self.assistant_id, chat_id=self.chat_id) + res_dict = res.to_dict() + pytest.assume(res_dict.keys() == self.chat_keys) - # Get a chat. + @pytest.mark.run(order=25) + @pytest.mark.asyncio + async def test_a_update_chat(self): + # Update a chat. - if not Base.chat_id: - Base.assistant_id, Base.chat_id = await a_chat_id - res = await a_get_chat(assistant_id=self.assistant_id, chat_id=self.chat_id) - res_dict = res.to_dict() - pytest.assume(res_dict.keys() == self.chat_keys) + metadata = {"test": "test"} + res = await a_update_chat(assistant_id=self.assistant_id, chat_id=self.chat_id, metadata=metadata) + res_dict = res.to_dict() + pytest.assume(res_dict.keys() == self.chat_keys) + pytest.assume(res_dict["metadata"] == metadata) - @pytest.mark.run(order=25) - @pytest.mark.asyncio - async def test_a_update_chat(self): + @pytest.mark.run(order=32) + @pytest.mark.asyncio + async def test_a_delete_chat(self): + # List chats. - # Update a chat. + chats = await a_list_chats(assistant_id=self.assistant_id) + old_nums = len(chats) + for index, chat in enumerate(chats): + chat_id = chat.chat_id - metadata = {"test": "test"} - res = await a_update_chat(assistant_id=self.assistant_id, chat_id=self.chat_id, metadata=metadata) - res_dict = res.to_dict() - pytest.assume(res_dict.keys() == self.chat_keys) - pytest.assume(res_dict["metadata"] == metadata) + # Delete a chat. - @pytest.mark.run(order=32) - @pytest.mark.asyncio - async def test_a_delete_chat(self): + await a_delete_chat(assistant_id=self.assistant_id, chat_id=str(chat_id)) # List chats. - chats = await a_list_chats(assistant_id=self.assistant_id) - old_nums = len(chats) - for index, chat in enumerate(chats): - chat_id = chat.chat_id - - # Delete a chat. - - await a_delete_chat(assistant_id=self.assistant_id, chat_id=str(chat_id)) - - # List chats. - - new_chats = await a_list_chats(assistant_id=self.assistant_id) - chat_ids = [i.chat_id for i in new_chats] - pytest.assume(chat_id not in chat_ids) - new_nums = len(new_chats) - pytest.assume(new_nums == old_nums - 1 - index) + new_chats = await a_list_chats(assistant_id=self.assistant_id) + chat_ids = [i.chat_id for i in new_chats] + pytest.assume(chat_id not in chat_ids) + new_nums = len(new_chats) + pytest.assume(new_nums == old_nums - 1 - index) @pytest.mark.test_async class TestMessage(Base): - - message_list = ['object', 'assistant_id', 'chat_id', 'message_id', 'role', 'content', 'metadata', - 'created_timestamp'] + message_list = [ + "object", + "assistant_id", + "chat_id", + "message_id", + "role", + "content", + "metadata", + "created_timestamp", + ] message_keys = set(message_list) @pytest.mark.run(order=26) @pytest.mark.asyncio async def test_a_create_message(self): - for x in range(2): - # Create a user message. text = f"hello test{x}" @@ -237,7 +246,6 @@ async def test_a_create_message(self): @pytest.mark.run(order=27) @pytest.mark.asyncio async def test_a_list_messages(self): - # List messages. nums_limit = 1 @@ -245,19 +253,22 @@ async def test_a_list_messages(self): pytest.assume(len(res) == nums_limit) after_id = res[-1].message_id - after_res = await a_list_messages(limit=nums_limit, after=after_id, assistant_id=self.assistant_id, - chat_id=self.chat_id) + after_res = await a_list_messages( + limit=nums_limit, after=after_id, assistant_id=self.assistant_id, chat_id=self.chat_id + ) pytest.assume(len(after_res) == nums_limit) - twice_nums_list = await a_list_messages(limit=nums_limit * 2, assistant_id=self.assistant_id, - chat_id=self.chat_id) + twice_nums_list = await a_list_messages( + limit=nums_limit * 2, assistant_id=self.assistant_id, chat_id=self.chat_id + ) pytest.assume(len(twice_nums_list) == nums_limit * 2) pytest.assume(after_res[-1] == twice_nums_list[-1]) pytest.assume(after_res[0] == twice_nums_list[nums_limit]) before_id = after_res[0].message_id - before_res = await a_list_messages(limit=nums_limit, before=before_id, assistant_id=self.assistant_id, - chat_id=self.chat_id) + before_res = await a_list_messages( + limit=nums_limit, before=before_id, assistant_id=self.assistant_id, chat_id=self.chat_id + ) pytest.assume(len(before_res) == nums_limit) pytest.assume(before_res[-1] == twice_nums_list[nums_limit - 1]) pytest.assume(before_res[0] == twice_nums_list[0]) @@ -265,7 +276,6 @@ async def test_a_list_messages(self): @pytest.mark.run(order=28) @pytest.mark.asyncio async def test_a_get_message(self, a_message_id): - # Get a message. if not Base.message_id: @@ -277,12 +287,12 @@ async def test_a_get_message(self, a_message_id): @pytest.mark.run(order=29) @pytest.mark.asyncio async def test_a_update_message(self): - # Update a message. metadata = {"test": "test"} - res = await a_update_message(assistant_id=self.assistant_id, chat_id=self.chat_id, message_id=self.message_id, - metadata=metadata) + res = await a_update_message( + assistant_id=self.assistant_id, chat_id=self.chat_id, message_id=self.message_id, metadata=metadata + ) res_dict = res.to_dict() pytest.assume(res_dict.keys() == self.message_keys) pytest.assume(res_dict["metadata"] == metadata) @@ -290,11 +300,9 @@ async def test_a_update_message(self): @pytest.mark.run(order=30) @pytest.mark.asyncio async def test_a_generate_message(self): - # Generate an assistant message. - res = await a_generate_message(assistant_id=self.assistant_id, chat_id=self.chat_id, - system_prompt_variables={}) + res = await a_generate_message(assistant_id=self.assistant_id, chat_id=self.chat_id, system_prompt_variables={}) res_dict = res.to_dict() pytest.assume(res_dict.keys() == self.message_keys) pytest.assume(res_dict["role"] == "assistant") @@ -302,38 +310,37 @@ async def test_a_generate_message(self): @pytest.mark.run(order=30) @pytest.mark.asyncio async def test_a_generate_message_by_stream(self): - # create chat chat_res = await a_create_chat(assistant_id=self.assistant_id) chat_id = chat_res.chat_id - logger.info(f'chat_id:{chat_id}') + logger.info(f"chat_id:{chat_id}") # create user message user_message = await a_create_message( assistant_id=self.assistant_id, chat_id=chat_id, - text="count from 1 to 100 and separate numbers by comma.", + text="count from 1 to 10 and separate numbers by comma.", ) # Generate an assistant message by stream. - stream_res = await a_generate_message(assistant_id=self.assistant_id, chat_id=chat_id, - system_prompt_variables={}, stream=True) - except_list = [i + 1 for i in range(100)] - real_list = [] - real_str = '' + stream_res = await a_generate_message( + assistant_id=self.assistant_id, chat_id=chat_id, system_prompt_variables={}, stream=True + ) + except_list = [i + 1 for i in range(10)] + real_str = "" async for item in stream_res: if isinstance(item, MessageChunk): - logger.info(f"MessageChunk: {item.delta}") - if item.delta.isdigit(): - real_list.append(int(item.delta)) + logger.debug(f"MessageChunk: {item.delta}") real_str += item.delta elif isinstance(item, Message): - logger.info(f"Message: {item.message_id}") + logger.debug(f"Message: {item.message_id}") pytest.assume(item.content is not None) - logger.info(f"Message: {real_str}") - logger.info(f"except_list: {except_list} real_list: {real_list}") + + real_list = [int(num) for num in re.findall(r"\b\d+\b", real_str)] + logger.debug(f"Message: {real_str}") + logger.debug(f"except_list: {except_list} real_list: {real_list}") pytest.assume(set(except_list) == set(real_list)) diff --git a/test/testcase/test_async/test_async_inference.py b/test/testcase/test_async/test_async_inference.py index 6f5cd66..249c4fd 100644 --- a/test/testcase/test_async/test_async_inference.py +++ b/test/testcase/test_async/test_async_inference.py @@ -3,15 +3,14 @@ from taskingai.inference import * from test.config import embedding_model_id, chat_completion_model_id from test.common.logger import logger +import re @pytest.mark.test_async class TestChatCompletion: - @pytest.mark.run(order=4) @pytest.mark.asyncio async def test_a_chat_completion(self): - # normal chat completion. normal_res = await a_chat_completion( @@ -19,7 +18,7 @@ async def test_a_chat_completion(self): messages=[ SystemMessage("You are a professional assistant."), UserMessage("Hi"), - ] + ], ) pytest.assume(normal_res.finish_reason == "stop") pytest.assume(normal_res.message.content) @@ -35,10 +34,7 @@ async def test_a_chat_completion(self): UserMessage("Hi"), AssistantMessage("Hello! How can I assist you today?"), UserMessage("Can you tell me a joke?"), - AssistantMessage( - "Sure, here is a joke for you: Why don't scientists trust atoms? Because they make up everything!"), - UserMessage("That's funny. Can you tell me another one?"), - ] + ], ) pytest.assume(multi_round_res.finish_reason == "stop") @@ -55,13 +51,8 @@ async def test_a_chat_completion(self): UserMessage("Hi"), AssistantMessage("Hello! How can I assist you today?"), UserMessage("Can you tell me a joke?"), - AssistantMessage( - "Sure, here is a joke for you: Why don't scientists trust atoms? Because they make up everything!"), - UserMessage("That's funny. Can you tell me another one?"), ], - configs={ - "max_tokens": 10 - } + configs={"max_tokens": 10}, ) pytest.assume(max_tokens_res.finish_reason == "length") pytest.assume(max_tokens_res.message.content) @@ -70,33 +61,34 @@ async def test_a_chat_completion(self): # chat completion with stream. - stream_res = await a_chat_completion(model_id=chat_completion_model_id, - messages=[ - SystemMessage("You are a professional assistant."), - UserMessage("count from 1 to 50 and separate numbers by comma."), - ], - stream=True - ) - except_list = [i + 1 for i in range(50)] - real_list = [] + stream_res = await a_chat_completion( + model_id=chat_completion_model_id, + messages=[ + SystemMessage("You are a professional assistant."), + UserMessage("count from 1 to 10 and separate numbers by comma."), + ], + stream=True, + ) + except_list = [i + 1 for i in range(10)] + real_str = "" async for item in stream_res: if isinstance(item, ChatCompletionChunk): logger.info(f"Message: {item.delta}") - if item.delta.isdigit(): - real_list.append(int(item.delta)) + real_str += item.delta + elif isinstance(item, ChatCompletion): logger.info(f"Message: {item.finish_reason}") pytest.assume(item.finish_reason == "stop") + + real_list = [int(num) for num in re.findall(r"\b\d+\b", real_str)] pytest.assume(set(except_list) == set(real_list)) @pytest.mark.test_async class TestTextEmbedding: - @pytest.mark.run(order=0) @pytest.mark.asyncio async def test_a_text_embedding(self): - # Text embedding with str. input_str = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data." diff --git a/test/testcase/test_async/test_async_retrieval.py b/test/testcase/test_async/test_async_retrieval.py index ab52f19..045b0d0 100644 --- a/test/testcase/test_async/test_async_retrieval.py +++ b/test/testcase/test_async/test_async_retrieval.py @@ -1,46 +1,51 @@ -import asyncio import pytest -from taskingai.retrieval import a_list_collections, a_create_collection, a_get_collection, a_update_collection, a_delete_collection, a_list_records, a_create_text_record, a_get_record, a_update_record, a_delete_record, a_query_chunks -from test.config import embedding_model_id, sleep_time +from taskingai.retrieval import * +from test.config import embedding_model_id from test.common.logger import logger from test.testcase.test_async.base import Base @pytest.mark.test_async class TestCollection(Base): - - collection_list = ['object', 'collection_id', 'name', 'description', 'num_records', 'num_chunks', 'capacity', - 'embedding_model_id', 'metadata', 'configs', 'created_timestamp', "status"] + collection_list = [ + "object", + "collection_id", + "name", + "description", + "num_records", + "num_chunks", + "capacity", + "embedding_model_id", + "metadata", + "created_timestamp", + "status", + ] collection_keys = set(collection_list) - collection_configs = ["metric", "chunk_size", "chunk_overlap"] - collection_configs_keys = set(collection_configs) @pytest.mark.run(order=9) @pytest.mark.asyncio async def test_a_create_collection(self): - for x in range(2): - # Create a collection. name = f"test{x}" description = "just for test" - res = await a_create_collection(name=name, description=description, embedding_model_id=embedding_model_id, capacity=1000) + res = await a_create_collection( + name=name, description=description, embedding_model_id=embedding_model_id, capacity=1000 + ) res_dict = res.to_dict() logger.info(res_dict) pytest.assume(res_dict.keys() == self.collection_keys) - pytest.assume(res_dict["configs"].keys() == self.collection_configs_keys) pytest.assume(res_dict["name"] == name) pytest.assume(res_dict["description"] == description) pytest.assume(res_dict["embedding_model_id"] == embedding_model_id) pytest.assume(res_dict["capacity"] == 1000) - pytest.assume(res_dict["status"] == "creating") + pytest.assume((res_dict["status"] == "ready") or (res_dict["status"] == "creating")) @pytest.mark.run(order=10) @pytest.mark.asyncio async def test_a_list_collections(self): - # List collections. nums_limit = 1 @@ -62,7 +67,6 @@ async def test_a_list_collections(self): @pytest.mark.run(order=11) @pytest.mark.asyncio async def test_a_get_collection(self, a_collection_id): - # Get a collection. if not Base.collection_id: @@ -70,13 +74,11 @@ async def test_a_get_collection(self, a_collection_id): res = await a_get_collection(collection_id=self.collection_id) res_dict = res.to_dict() pytest.assume(res_dict.keys() == self.collection_keys) - pytest.assume(res_dict["configs"].keys() == self.collection_configs_keys) - pytest.assume(res_dict["status"] == "ready" or "creating") + pytest.assume(res_dict["status"] == "ready" or res_dict["status"] == "creating") @pytest.mark.run(order=12) @pytest.mark.asyncio async def test_a_update_collection(self): - # Update a collection. name = "openai" @@ -84,7 +86,6 @@ async def test_a_update_collection(self): res = await a_update_collection(collection_id=self.collection_id, name=name, description=description) res_dict = res.to_dict() pytest.assume(res_dict.keys() == self.collection_keys) - pytest.assume(res_dict["configs"].keys() == self.collection_configs_keys) pytest.assume(res_dict["name"] == name) pytest.assume(res_dict["description"] == description) pytest.assume(res_dict["status"] == "ready") @@ -93,48 +94,59 @@ async def test_a_update_collection(self): @pytest.mark.asyncio async def test_a_delete_collection(self): # List collections. - old_res = await a_list_collections(order="desc", limit=100, after=None, before=None) + old_res = await a_list_collections(order="desc", limit=100, after=None, before=None) + old_nums = len(old_res) for index, collection in enumerate(old_res): collection_id = collection.collection_id - # Delete a collection. + # Delete a collection await a_delete_collection(collection_id=collection_id) - new_collections = await a_list_collections(order="desc", limit=100, after=None, before=None) - # List collections. + new_collections = await a_list_collections(order="desc", limit=100, after=None, before=None) + + # List collections collection_ids = [c.collection_id for c in new_collections] pytest.assume(collection_id not in collection_ids) + new_nums = len(new_collections) + pytest.assume(new_nums == old_nums - 1 - index) + @pytest.mark.test_async class TestRecord(Base): - - record_list = ['record_id', 'collection_id', 'num_chunks', 'content', 'metadata', 'type', 'object', - 'created_timestamp', 'status'] + record_list = [ + "record_id", + "collection_id", + "num_chunks", + "content", + "metadata", + "type", + "object", + "created_timestamp", + "status", + ] record_keys = set(record_list) - record_content = ["text"] - record_content_keys = set(record_content) - + @pytest.mark.run(order=13) @pytest.mark.asyncio - async def test_a_create_text_record(self): - + async def test_a_create_record(self): for x in range(2): - # Create a text record. text = f"{x}Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data." - res = await a_create_text_record(collection_id=self.collection_id, text=text) + res = await a_create_record( + collection_id=self.collection_id, + content=text, + text_splitter=TokenTextSplitter(chunk_size=200, chunk_overlap=20), + ) res_dict = res.to_dict() pytest.assume(res_dict.keys() == self.record_keys) - pytest.assume(res_dict["content"].keys() == self.record_content_keys) - pytest.assume(res_dict["content"]["text"] == text) - pytest.assume(res_dict["status"] == "creating") + pytest.assume(res_dict["content"] == text) + pytest.assume((res_dict["status"] == "creating") or (res_dict["status"] == "ready")) @pytest.mark.run(order=14) @pytest.mark.asyncio async def test_a_list_records(self, a_record_id): - # List records. if not Base.record_id: @@ -161,39 +173,37 @@ async def test_a_list_records(self, a_record_id): @pytest.mark.run(order=15) @pytest.mark.asyncio async def test_a_get_record(self): - # Get a record. - res = await a_get_record(collection_id=self.collection_id, record_id=self.record_id) - logger.info(f'a_get_record:{res}') - res_dict = res.to_dict() - pytest.assume(res_dict.keys() == self.record_keys) - pytest.assume(res_dict["content"].keys() == self.record_content_keys) - pytest.assume(res_dict["status"] == "ready" or "creating") + records = await a_list_records(collection_id=self.collection_id) + for record in records: + record_id = record.record_id + res = await a_get_record(collection_id=self.collection_id, record_id=record_id) + logger.info(f"a_get_record:{res}") + res_dict = res.to_dict() + pytest.assume(res_dict.keys() == self.record_keys) + pytest.assume(res_dict["status"] == "ready" or res_dict["status"] == "creating") @pytest.mark.run(order=16) @pytest.mark.asyncio async def test_a_update_record(self): - # Update a record. metadata = {"test": "test"} res = await a_update_record(collection_id=self.collection_id, record_id=self.record_id, metadata=metadata) - logger.info(f'a_update_record:{res}') + logger.info(f"a_update_record:{res}") res_dict = res.to_dict() pytest.assume(res_dict.keys() == self.record_keys) - pytest.assume(res_dict["content"].keys() == self.record_content_keys) pytest.assume(res_dict["metadata"] == metadata) - @pytest.mark.run(order=34) @pytest.mark.asyncio async def test_a_delete_record(self): - # List records. - records = await a_list_records(collection_id=self.collection_id, order="desc", limit=20, after=None, - before=None) + records = await a_list_records( + collection_id=self.collection_id, order="desc", limit=20, after=None, before=None + ) old_nums = len(records) for index, record in enumerate(records): record_id = record.record_id @@ -204,8 +214,9 @@ async def test_a_delete_record(self): # List records. - new_records = await a_list_records(collection_id=self.collection_id, order="desc", limit=20, after=None, - before=None) + new_records = await a_list_records( + collection_id=self.collection_id, order="desc", limit=20, after=None, before=None + ) record_ids = [record.record_id for record in new_records] pytest.assume(record_id not in record_ids) new_nums = len(new_records) @@ -214,14 +225,12 @@ async def test_a_delete_record(self): @pytest.mark.test_async class TestChunk(Base): - - chunk_list = ['chunk_id', 'collection_id', 'record_id', 'object', 'text', 'score'] + chunk_list = ["chunk_id", "collection_id", "record_id", "object", "content", "score"] chunk_keys = set(chunk_list) @pytest.mark.run(order=17) @pytest.mark.asyncio async def test_a_query_chunks(self): - # Query chunks. query_text = "Machine learning" @@ -231,6 +240,5 @@ async def test_a_query_chunks(self): for chunk in res: chunk_dict = chunk.to_dict() pytest.assume(chunk_dict.keys() == self.chunk_keys) - pytest.assume(query_text in chunk_dict["text"]) + pytest.assume(query_text in chunk_dict["content"]) pytest.assume(chunk_dict["score"] >= 0) - diff --git a/test/testcase/test_async/test_async_tool.py b/test/testcase/test_async/test_async_tool.py index 497410c..604d428 100644 --- a/test/testcase/test_async/test_async_tool.py +++ b/test/testcase/test_async/test_async_tool.py @@ -1,31 +1,39 @@ import pytest -import asyncio -from taskingai.tool import a_bulk_create_actions, a_get_action, a_update_action, a_delete_action, a_run_action, a_list_actions +from taskingai.tool import ( + a_bulk_create_actions, + a_get_action, + a_update_action, + a_delete_action, + a_run_action, + a_list_actions, +) from test.common.logger import logger from test.testcase.test_async.base import Base -from test.config import sleep_time @pytest.mark.test_async class TestAction(Base): - - action_list = ['object', 'action_id', 'name', 'description', 'authentication', 'schema', 'created_timestamp'] + action_list = [ + "object", + "action_id", + "name", + "description", + "authentication", + "openapi_schema", + "created_timestamp", + ] action_keys = set(action_list) - action_schema = ['openapi', 'info', 'servers', 'paths', 'components', 'security'] + action_schema = ["openapi", "info", "servers", "paths", "components", "security"] action_schema_keys = set(action_schema) - schema = { + openapi_schema = { "openapi": "3.1.0", "info": { "title": "Get weather data", "description": "Retrieves current weather data for a location.", - "version": "v1.0.0" + "version": "v1.0.0", }, - "servers": [ - { - "url": "https://weather.example.com" - } - ], + "servers": [{"url": "https://weather.example.com"}], "paths": { "/location": { "get": { @@ -37,76 +45,71 @@ class TestAction(Base): "in": "query", "description": "The city and state to retrieve the weather for", "required": True, - "schema": { - "type": "string" - } + "schema": {"type": "string"}, } ], - "deprecated": False + "deprecated": False, }, "post": { "description": "UPDATE temperature for a specific location", "operationId": "UpdateCurrentWeather", "requestBody": { - "required": True, - "content":{ - "application/json":{ - "schema":{ - "$ref":"#/componeents/schemas/ActionCreateRequest" - } - } - } + "required": True, + "content": { + "application/json": {"schema": {"$ref": "#/componeents/schemas/ActionCreateRequest"}} + }, }, - "deprecated": False - } + "deprecated": False, + }, } }, - "components": { - "schemas": {} - }, - "security": [] - } + "components": {"schemas": {}}, + "security": [], + } @pytest.mark.run(order=4) @pytest.mark.asyncio async def test_a_bulk_create_actions(self): - # Create an action. - res = await a_bulk_create_actions(schema=self.schema) + res = await a_bulk_create_actions(openapi_schema=self.openapi_schema) for action in res: action_dict = action.to_dict() logger.info(action_dict) pytest.assume(action_dict.keys() == self.action_keys) - pytest.assume(action_dict["schema"].keys() == self.action_schema_keys) + pytest.assume(action_dict["openapi_schema"].keys() == self.action_schema_keys) - for key in action_dict["schema"].keys(): + for key in action_dict["openapi_schema"].keys(): if key == "paths": - if action_dict["schema"][key]["/location"] == "get": - pytest.assume(action_dict["schema"][key]["/location"]["get"] == self.schema["paths"]["/location"]["get"]) - elif action_dict["schema"][key]["/location"] == "post": - pytest.assume(action_dict["schema"][key]["/location"]["post"] == self.schema["paths"]["/location"]["post"]) + if action_dict["openapi_schema"][key]["/location"] == "get": + pytest.assume( + action_dict["openapi_schema"][key]["/location"]["get"] + == self.openapi_schema["paths"]["/location"]["get"] + ) + elif action_dict["openapi_schema"][key]["/location"] == "post": + pytest.assume( + action_dict["openapi_schema"][key]["/location"]["post"] + == self.openapi_schema["paths"]["/location"]["post"] + ) else: - pytest.assume(action_dict["schema"][key] == self.schema[key]) + pytest.assume(action_dict["openapi_schema"][key] == self.openapi_schema[key]) @pytest.mark.run(order=5) @pytest.mark.asyncio async def test_a_run_action(self, a_action_id): - # Run an action. if not Base.action_id: Base.action_id = await a_action_id - parameters = {"location": "beijing"} + parameters = {"location": "tokyo"} res = await a_run_action(action_id=self.action_id, parameters=parameters) - logger.info(f'async run action{res}') - pytest.assume(res['status'] == 400) + logger.info(f"async run action{res}") + pytest.assume(res["status"] != 200) pytest.assume(res["error"]) @pytest.mark.run(order=6) @pytest.mark.asyncio async def test_a_list_actions(self): - # List actions. nums_limit = 1 @@ -125,76 +128,64 @@ async def test_a_list_actions(self): before_id = after_res[0].action_id before_res = await a_list_actions(limit=nums_limit, before=before_id) pytest.assume(len(before_res) == nums_limit) - pytest.assume(before_res[-1] == twice_nums_list[nums_limit-1]) + pytest.assume(before_res[-1] == twice_nums_list[nums_limit - 1]) pytest.assume(before_res[0] == twice_nums_list[0]) @pytest.mark.run(order=7) @pytest.mark.asyncio async def test_a_get_action(self): - # Get an action. res = await a_get_action(action_id=self.action_id) res_dict = res.to_dict() pytest.assume(res_dict.keys() == self.action_keys) - pytest.assume(res_dict["schema"].keys() == self.action_schema_keys) + pytest.assume(res_dict["openapi_schema"].keys() == self.action_schema_keys) @pytest.mark.run(order=39) @pytest.mark.asyncio async def test_a_update_action(self): - # Update an action. update_schema = { - "openapi": "3.1.0", - "info": { - "title": "Get weather data", - "description": "Retrieves current weather data for a location.", - "version": "v1.0.0" - }, - "servers": [ - { - "url": "https://weather.example.com" - } - ], - "paths": { - "/location": { - "get": { - "description": "Get temperature for a specific location by get method", - "operationId": "GetCurrentWeather", - "parameters": [ - { - "name": "location", - "in": "query", - "description": "The city and state to retrieve the weather for", - "required": True, - "schema": { - "type": "string" - } - } - ], - "deprecated": False - } - - } - }, - "components": { - "schemas": {} - }, - "security": [] - } - - res = await a_update_action(action_id=self.action_id, schema=update_schema) + "openapi": "3.1.0", + "info": { + "title": "Get weather data", + "description": "Retrieves current weather data for a location.", + "version": "v1.0.0", + }, + "servers": [{"url": "https://weather.example.com"}], + "paths": { + "/location": { + "get": { + "description": "Get temperature for a specific location by get method", + "operationId": "GetCurrentWeather", + "parameters": [ + { + "name": "location", + "in": "query", + "description": "The city and state to retrieve the weather for", + "required": True, + "schema": {"type": "string"}, + } + ], + "deprecated": False, + } + } + }, + "components": {"schemas": {}}, + "security": [], + } + + res = await a_update_action(action_id=self.action_id, openapi_schema=update_schema) res_dict = res.to_dict() logger.info(res_dict) pytest.assume(res_dict.keys() == self.action_keys) - pytest.assume(res_dict["schema"].keys() == self.action_schema_keys) - pytest.assume(res_dict["schema"] == update_schema) + pytest.assume(res_dict["openapi_schema"].keys() == self.action_schema_keys) + pytest.assume(res_dict["openapi_schema"] == update_schema) @pytest.mark.run(order=40) @pytest.mark.asyncio async def test_a_delete_action(self): - # List actions. actions = await a_list_actions(limit=100) @@ -212,7 +203,3 @@ async def test_a_delete_action(self): pytest.assume(action_id not in action_ids) new_nums = len(new_actions) pytest.assume(new_nums == old_nums - 1 - index) - - - - diff --git a/test/testcase/test_sync/test_sync_assistant.py b/test/testcase/test_sync/test_sync_assistant.py index 26de66d..854617a 100644 --- a/test/testcase/test_sync/test_sync_assistant.py +++ b/test/testcase/test_sync/test_sync_assistant.py @@ -1,15 +1,15 @@ import pytest -import time from taskingai.assistant import * from taskingai.retrieval import * from taskingai.tool import * from taskingai.assistant.memory import AssistantNaiveMemory -from test.config import chat_completion_model_id, embedding_model_id, sleep_time +from test.config import chat_completion_model_id from test.common.read_data import data from test.common.logger import logger from test.common.utils import list_to_dict from test.common.utils import assume_assistant +import re assistant_data = data.load_yaml("test_assistant_data.yml") @@ -17,32 +17,49 @@ @pytest.mark.test_sync class TestAssistant: - - assistant_list = ['assistant_id', 'created_timestamp', 'description', 'metadata', 'model_id', 'name', 'object', 'retrievals', 'system_prompt_template', 'tools',"memory"] + assistant_list = [ + "assistant_id", + "created_timestamp", + "description", + "metadata", + "model_id", + "name", + "object", + "retrievals", + "system_prompt_template", + "tools", + "memory", + ] assistant_keys = set(assistant_list) @pytest.mark.parametrize("create_assistant_data", assistant_data["test_success_create_assistant"]) @pytest.mark.run(order=18) def test_create_assistant(self, collection_id, action_id, create_assistant_data): - # Create an assistant. assistant_dict = list_to_dict(create_assistant_data) assistant_dict.update({"model_id": chat_completion_model_id}) - if "retrievals" in assistant_dict.keys() and len(assistant_dict["retrievals"]) > 0 and assistant_dict["retrievals"][0]["type"] == "collection": + if ( + "retrievals" in assistant_dict.keys() + and len(assistant_dict["retrievals"]) > 0 + and assistant_dict["retrievals"][0]["type"] == "collection" + ): assistant_dict["retrievals"][0]["id"] = collection_id - if "tools" in assistant_dict.keys() and len(assistant_dict["tools"]) > 0 and assistant_dict["tools"][0]["type"] == "action": + if ( + "tools" in assistant_dict.keys() + and len(assistant_dict["tools"]) > 0 + and assistant_dict["tools"][0]["type"] == "action" + ): assistant_dict["tools"][0]["id"] = action_id assistant_dict.update({"memory": AssistantNaiveMemory()}) res = create_assistant(**assistant_dict) res_dict = res.to_dict() - logger.info(f'response_dict:{res_dict}, except_dict:{assistant_dict}') + logger.info(f"response_dict:{res_dict}, except_dict:{assistant_dict}") pytest.assume(res_dict.keys() == self.assistant_keys) assume_assistant(res_dict, assistant_dict) @pytest.mark.run(order=19) def test_list_assistants(self): - # List assistants. nums_limit = 1 @@ -66,7 +83,6 @@ def test_list_assistants(self): @pytest.mark.run(order=20) def test_get_assistant(self, assistant_id): - # Get an assistant. res = get_assistant(assistant_id=assistant_id) @@ -75,7 +91,6 @@ def test_get_assistant(self, assistant_id): @pytest.mark.run(order=21) def test_update_assistant(self, assistant_id): - # Update an assistant. name = "openai" @@ -88,7 +103,6 @@ def test_update_assistant(self, assistant_id): @pytest.mark.run(order=33) def test_delete_assistant(self): - # List assistants. assistants = list_assistants(limit=100) @@ -111,99 +125,98 @@ def test_delete_assistant(self): @pytest.mark.test_sync class TestChat: + chat_list = ["assistant_id", "chat_id", "created_timestamp", "metadata", "object"] + chat_keys = set(chat_list) - chat_list = ['assistant_id', 'chat_id', 'created_timestamp', 'metadata', 'object'] - chat_keys = set(chat_list) - - @pytest.mark.run(order=22) - def test_create_chat(self, assistant_id): - - for x in range(2): - - # Create a chat. + @pytest.mark.run(order=22) + def test_create_chat(self, assistant_id): + for x in range(2): + # Create a chat. - res = create_chat(assistant_id=assistant_id) - res_dict = res.to_dict() - pytest.assume(res_dict.keys() == self.chat_keys) + res = create_chat(assistant_id=assistant_id) + res_dict = res.to_dict() + pytest.assume(res_dict.keys() == self.chat_keys) - @pytest.mark.run(order=23) - def test_list_chats(self, assistant_id): + @pytest.mark.run(order=23) + def test_list_chats(self, assistant_id): + # List chats. - # List chats. + nums_limit = 1 + res = list_chats(limit=nums_limit, assistant_id=assistant_id) + pytest.assume(len(res) == nums_limit) - nums_limit = 1 - res = list_chats(limit=nums_limit, assistant_id=assistant_id) - pytest.assume(len(res) == nums_limit) + after_id = res[-1].chat_id + after_res = list_chats(limit=nums_limit, after=after_id, assistant_id=assistant_id) + pytest.assume(len(after_res) == nums_limit) - after_id = res[-1].chat_id - after_res = list_chats(limit=nums_limit, after=after_id, assistant_id=assistant_id) - pytest.assume(len(after_res) == nums_limit) + twice_nums_list = list_chats(limit=nums_limit * 2, assistant_id=assistant_id) + pytest.assume(len(twice_nums_list) == nums_limit * 2) + pytest.assume(after_res[-1] == twice_nums_list[-1]) + pytest.assume(after_res[0] == twice_nums_list[nums_limit]) - twice_nums_list = list_chats(limit=nums_limit * 2, assistant_id=assistant_id) - pytest.assume(len(twice_nums_list) == nums_limit * 2) - pytest.assume(after_res[-1] == twice_nums_list[-1]) - pytest.assume(after_res[0] == twice_nums_list[nums_limit]) + before_id = after_res[0].chat_id + before_res = list_chats(limit=nums_limit, before=before_id, assistant_id=assistant_id) + pytest.assume(len(before_res) == nums_limit) + pytest.assume(before_res[-1] == twice_nums_list[nums_limit - 1]) + pytest.assume(before_res[0] == twice_nums_list[0]) - before_id = after_res[0].chat_id - before_res = list_chats(limit=nums_limit, before=before_id, assistant_id=assistant_id) - pytest.assume(len(before_res) == nums_limit) - pytest.assume(before_res[-1] == twice_nums_list[nums_limit - 1]) - pytest.assume(before_res[0] == twice_nums_list[0]) + @pytest.mark.run(order=24) + def test_get_chat(self, assistant_id, chat_id): + # Get a chat. - @pytest.mark.run(order=24) - def test_get_chat(self, assistant_id, chat_id): + res = get_chat(assistant_id=assistant_id, chat_id=chat_id) + res_dict = res.to_dict() + pytest.assume(res_dict.keys() == self.chat_keys) - # Get a chat. + @pytest.mark.run(order=25) + def test_update_chat(self, assistant_id, chat_id): + # Update a chat. - res = get_chat(assistant_id=assistant_id, chat_id=chat_id) - res_dict = res.to_dict() - pytest.assume(res_dict.keys() == self.chat_keys) + metadata = {"test": "test"} + res = update_chat(assistant_id=assistant_id, chat_id=chat_id, metadata=metadata) + res_dict = res.to_dict() + pytest.assume(res_dict.keys() == self.chat_keys) + pytest.assume(res_dict["metadata"] == metadata) - @pytest.mark.run(order=25) - def test_update_chat(self, assistant_id, chat_id): + @pytest.mark.run(order=32) + def test_delete_chat(self, assistant_id): + # List chats. - # Update a chat. + chats = list_chats(assistant_id=assistant_id) + old_nums = len(chats) + for index, chat in enumerate(chats): + chat_id = chat.chat_id - metadata = {"test": "test"} - res = update_chat(assistant_id=assistant_id, chat_id=chat_id, metadata=metadata) - res_dict = res.to_dict() - pytest.assume(res_dict.keys() == self.chat_keys) - pytest.assume(res_dict["metadata"] == metadata) + # Delete a chat. - @pytest.mark.run(order=32) - def test_delete_chat(self, assistant_id): + delete_chat(assistant_id=assistant_id, chat_id=str(chat_id)) # List chats. - chats = list_chats(assistant_id=assistant_id) - old_nums = len(chats) - for index, chat in enumerate(chats): - chat_id = chat.chat_id - - # Delete a chat. - - delete_chat(assistant_id=assistant_id, chat_id=str(chat_id)) - - # List chats. - - new_chats = list_chats(assistant_id=assistant_id) - chat_ids = [i.chat_id for i in new_chats] - pytest.assume(chat_id not in chat_ids) - new_nums = len(new_chats) - pytest.assume(new_nums == old_nums - 1 - index) + new_chats = list_chats(assistant_id=assistant_id) + chat_ids = [i.chat_id for i in new_chats] + pytest.assume(chat_id not in chat_ids) + new_nums = len(new_chats) + pytest.assume(new_nums == old_nums - 1 - index) @pytest.mark.test_sync class TestMessage: - - message_list = ['object', 'assistant_id', 'chat_id', 'message_id', 'role', 'content', 'metadata', 'created_timestamp'] + message_list = [ + "object", + "assistant_id", + "chat_id", + "message_id", + "role", + "content", + "metadata", + "created_timestamp", + ] message_keys = set(message_list) @pytest.mark.run(order=26) def test_create_message(self, assistant_id, chat_id): - for x in range(2): - # Create a user message. text = "hello" @@ -216,31 +229,26 @@ def test_create_message(self, assistant_id, chat_id): @pytest.mark.run(order=27) def test_list_messages(self, assistant_id, chat_id): - # List messages. nums_limit = 1 res = list_messages(limit=nums_limit, assistant_id=assistant_id, chat_id=chat_id) pytest.assume(len(res) == nums_limit) after_id = res[-1].message_id - after_res = list_messages(limit=nums_limit, after=after_id, assistant_id=assistant_id, - chat_id=chat_id) + after_res = list_messages(limit=nums_limit, after=after_id, assistant_id=assistant_id, chat_id=chat_id) pytest.assume(len(after_res) == nums_limit) - twice_nums_list = list_messages(limit=nums_limit * 2, assistant_id=assistant_id, - chat_id=chat_id) + twice_nums_list = list_messages(limit=nums_limit * 2, assistant_id=assistant_id, chat_id=chat_id) pytest.assume(len(twice_nums_list) == nums_limit * 2) pytest.assume(after_res[-1] == twice_nums_list[-1]) pytest.assume(after_res[0] == twice_nums_list[nums_limit]) before_id = after_res[0].message_id - before_res = list_messages(limit=nums_limit, before=before_id, assistant_id=assistant_id, - chat_id=chat_id) + before_res = list_messages(limit=nums_limit, before=before_id, assistant_id=assistant_id, chat_id=chat_id) pytest.assume(len(before_res) == nums_limit) pytest.assume(before_res[-1] == twice_nums_list[nums_limit - 1]) pytest.assume(before_res[0] == twice_nums_list[0]) @pytest.mark.run(order=28) def test_get_message(self, assistant_id, chat_id, message_id): - # Get a message. res = get_message(assistant_id=assistant_id, chat_id=chat_id, message_id=message_id) @@ -249,7 +257,6 @@ def test_get_message(self, assistant_id, chat_id, message_id): @pytest.mark.run(order=29) def test_update_message(self, assistant_id, chat_id, message_id): - # Update a message. metadata = {"test": "test"} @@ -260,7 +267,6 @@ def test_update_message(self, assistant_id, chat_id, message_id): @pytest.mark.run(order=30) def test_generate_message(self, assistant_id, chat_id): - # Generate an assistant message by no stream. res = generate_message(assistant_id=assistant_id, chat_id=chat_id, system_prompt_variables={}) @@ -270,7 +276,6 @@ def test_generate_message(self, assistant_id, chat_id): @pytest.mark.run(order=30) def test_generate_message_by_stream(self): - # List assistants. assistants = list_assistants() @@ -286,23 +291,26 @@ def test_generate_message_by_stream(self): # create user message user_message: Message = create_message( - assistant_id=assistant_id, - chat_id=chat_id, - text="count from 1 to 100 and separate numbers by comma.") + assistant_id=assistant_id, chat_id=chat_id, text="count from 1 to 10 and separate numbers by comma." + ) # Generate an assistant message by stream. - stream_res = generate_message(assistant_id=assistant_id, chat_id=chat_id, system_prompt_variables={}, stream=True) - except_list = [i + 1 for i in range(100)] - real_list = [] + stream_res = generate_message( + assistant_id=assistant_id, chat_id=chat_id, system_prompt_variables={}, stream=True + ) + except_list = [i + 1 for i in range(10)] + real_str = "" for item in stream_res: if isinstance(item, MessageChunk): - logger.info(f"MessageChunk: {item.delta}") - if item.delta.isdigit(): - real_list.append(int(item.delta)) + logger.debug(f"MessageChunk: {item.delta}") + real_str += item.delta + elif isinstance(item, Message): - logger.info(f"Message: {item.message_id}") + logger.debug(f"Message: {item.message_id}") pytest.assume(item.content is not None) - logger.info(f"except_list: {except_list} real_list: {real_list}") - pytest.assume(set(except_list) == set(real_list)) + real_list = [int(num) for num in re.findall(r"\b\d+\b", real_str)] + logger.debug(f"Message: {real_str}") + logger.debug(f"except_list: {except_list} real_list: {real_list}") + pytest.assume(set(except_list) == set(real_list)) diff --git a/test/testcase/test_sync/test_sync_inference.py b/test/testcase/test_sync/test_sync_inference.py index 9b07021..55e42be 100644 --- a/test/testcase/test_sync/test_sync_inference.py +++ b/test/testcase/test_sync/test_sync_inference.py @@ -3,14 +3,13 @@ from taskingai.inference import * from test.config import embedding_model_id, chat_completion_model_id from test.common.logger import logger +import re @pytest.mark.test_sync class TestChatCompletion: - @pytest.mark.run(order=1) def test_chat_completion(self): - # normal chat completion. normal_res = chat_completion( @@ -18,7 +17,7 @@ def test_chat_completion(self): messages=[ SystemMessage("You are a professional assistant."), UserMessage("Hi"), - ] + ], ) pytest.assume(normal_res.finish_reason == "stop") pytest.assume(normal_res.message.content) @@ -34,10 +33,7 @@ def test_chat_completion(self): UserMessage("Hi"), AssistantMessage("Hello! How can I assist you today?"), UserMessage("Can you tell me a joke?"), - AssistantMessage( - "Sure, here is a joke for you: Why don't scientists trust atoms? Because they make up everything!"), - UserMessage("That's funny. Can you tell me another one?"), - ] + ], ) pytest.assume(multi_round_res.finish_reason == "stop") @@ -54,13 +50,8 @@ def test_chat_completion(self): UserMessage("Hi"), AssistantMessage("Hello! How can I assist you today?"), UserMessage("Can you tell me a joke?"), - AssistantMessage( - "Sure, here is a joke for you: Why don't scientists trust atoms? Because they make up everything!"), - UserMessage("That's funny. Can you tell me another one?"), ], - configs={ - "max_tokens": 10 - } + configs={"max_tokens": 10}, ) pytest.assume(max_tokens_res.finish_reason == "length") pytest.assume(max_tokens_res.message.content) @@ -69,33 +60,32 @@ def test_chat_completion(self): # chat completion with stream. - stream_res = chat_completion(model_id=chat_completion_model_id, - messages=[ - SystemMessage("You are a professional assistant."), - UserMessage("count from 1 to 50 and separate numbers by comma."), - ], - stream=True - ) - except_list = [i+1 for i in range(50)] - real_list = [] + stream_res = chat_completion( + model_id=chat_completion_model_id, + messages=[ + SystemMessage("You are a professional assistant."), + UserMessage("count from 1 to 10 and separate numbers by comma."), + ], + stream=True, + ) + except_list = [i + 1 for i in range(10)] + real_str = "" for item in stream_res: if isinstance(item, ChatCompletionChunk): logger.info(f"Message: {item.delta}") - if item.delta.isdigit(): - real_list.append(int(item.delta)) + real_str += item.delta elif isinstance(item, ChatCompletion): logger.info(f"Message: {item.finish_reason}") pytest.assume(item.finish_reason == "stop") - logger.info(f"except_list: {except_list} real_list: {real_list}") + + real_list = [int(num) for num in re.findall(r"\b\d+\b", real_str)] pytest.assume(set(except_list) == set(real_list)) @pytest.mark.test_sync class TestTextEmbedding: - @pytest.mark.run(order=0) def test_text_embedding(self): - # Text embedding with str. input_str = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data." diff --git a/test/testcase/test_sync/test_sync_retrieval.py b/test/testcase/test_sync/test_sync_retrieval.py index cc6035a..34fd3f6 100644 --- a/test/testcase/test_sync/test_sync_retrieval.py +++ b/test/testcase/test_sync/test_sync_retrieval.py @@ -1,42 +1,47 @@ -import time import pytest -from taskingai.retrieval import list_collections, create_collection, get_collection, update_collection, delete_collection, list_records, create_text_record, get_record, update_record, delete_record, query_chunks -from test.config import embedding_model_id, sleep_time +from taskingai.retrieval import * +from test.config import embedding_model_id from test.common.logger import logger @pytest.mark.test_sync class TestCollection: - - collection_list = ['object', 'collection_id', 'name', 'description', 'num_records', 'num_chunks', 'capacity', - 'embedding_model_id', 'metadata', 'configs', 'created_timestamp', "status"] + collection_list = [ + "object", + "collection_id", + "name", + "description", + "num_records", + "num_chunks", + "capacity", + "embedding_model_id", + "metadata", + "created_timestamp", + "status", + ] collection_keys = set(collection_list) - collection_configs = ["metric", "chunk_size", "chunk_overlap"] - collection_configs_keys = set(collection_configs) @pytest.mark.run(order=9) def test_create_collection(self): - # Create a collection. name = "test" description = "just for test" for x in range(2): - res = create_collection(name=name, description=description, embedding_model_id=embedding_model_id, capacity=1000) + res = create_collection( + name=name, description=description, embedding_model_id=embedding_model_id, capacity=1000 + ) res_dict = res.to_dict() logger.info(res_dict) pytest.assume(res_dict.keys() == self.collection_keys) - pytest.assume(res_dict["configs"].keys() == self.collection_configs_keys) pytest.assume(res_dict["name"] == name) pytest.assume(res_dict["description"] == description) pytest.assume(res_dict["embedding_model_id"] == embedding_model_id) pytest.assume(res_dict["capacity"] == 1000) - pytest.assume(res_dict["status"] == "creating") - + pytest.assume((res_dict["status"] == "ready") or (res_dict["status"] == "creating")) @pytest.mark.run(order=10) def test_list_collections(self): - # List collections. nums_limit = 1 @@ -57,18 +62,15 @@ def test_list_collections(self): @pytest.mark.run(order=11) def test_get_collection(self, collection_id): - # Get a collection. res = get_collection(collection_id=collection_id) res_dict = res.to_dict() pytest.assume(res_dict.keys() == self.collection_keys) - pytest.assume(res_dict["configs"].keys() == self.collection_configs_keys) - pytest.assume(res_dict["status"] == "ready") + pytest.assume(res_dict["status"] == "ready" or res_dict["status"] == "creating") @pytest.mark.run(order=12) def test_update_collection(self, collection_id): - # Update a collection. name = "openai" @@ -76,27 +78,24 @@ def test_update_collection(self, collection_id): res = update_collection(collection_id=collection_id, name=name, description=description) res_dict = res.to_dict() pytest.assume(res_dict.keys() == self.collection_keys) - pytest.assume(res_dict["configs"].keys() == self.collection_configs_keys) pytest.assume(res_dict["name"] == name) pytest.assume(res_dict["description"] == description) pytest.assume(res_dict["status"] == "ready") @pytest.mark.run(order=35) def test_delete_collection(self): - # List collections. - old_res = list_collections(order="desc", limit=100, after=None, before=None) + old_res = list_collections(order="desc", limit=100, after=None, before=None) old_nums = len(old_res) for index, collection in enumerate(old_res): collection_id = collection.collection_id # Delete a collection. - delete_collection(collection_id=collection_id) - new_collections = list_collections(order="desc", limit=100, after=None, before=None) + new_collections = list_collections(order="desc", limit=100, after=None, before=None) # List collections. @@ -108,30 +107,37 @@ def test_delete_collection(self): @pytest.mark.test_sync class TestRecord: - - record_list = ['record_id', 'collection_id', 'num_chunks', 'content', 'metadata', 'type', 'object', - 'created_timestamp', 'status'] + record_list = [ + "record_id", + "collection_id", + "num_chunks", + "content", + "metadata", + "type", + "object", + "created_timestamp", + "status", + ] record_keys = set(record_list) - record_content = ["text"] - record_content_keys = set(record_content) - - @pytest.mark.run(order=13) - def test_create_text_record(self, collection_id): + @pytest.mark.run(order=13) + def test_create_record(self, collection_id): # Create a text record. text = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data." for x in range(2): - res = create_text_record(collection_id=collection_id, text=text) + res = create_record( + collection_id=collection_id, + content=text, + text_splitter=TokenTextSplitter(chunk_size=200, chunk_overlap=20), + ) res_dict = res.to_dict() pytest.assume(res_dict.keys() == self.record_keys) - pytest.assume(res_dict["content"].keys() == self.record_content_keys) - pytest.assume(res_dict["content"]["text"] == text) - pytest.assume(res_dict["status"] == "creating") + pytest.assume(res_dict["content"] == text) + pytest.assume((res_dict["status"] == "creating") or (res_dict["status"] == "ready")) @pytest.mark.run(order=14) def test_list_records(self, collection_id): - # List records. nums_limit = 1 @@ -155,39 +161,32 @@ def test_list_records(self, collection_id): @pytest.mark.run(order=15) def test_get_record(self, collection_id): - # list records records = list_records(collection_id=collection_id) for record in records: record_id = record.record_id res = get_record(collection_id=collection_id, record_id=record_id) - logger.info(f'get record response: {res}') + logger.info(f"get record response: {res}") res_dict = res.to_dict() pytest.assume(res_dict.keys() == self.record_keys) - pytest.assume(res_dict["content"].keys() == self.record_content_keys) - pytest.assume(res_dict["status"] == "creating" or "ready") + pytest.assume(res_dict["status"] == "ready" or res_dict["status"] == "creating") @pytest.mark.run(order=16) def test_update_record(self, collection_id, record_id): - # Update a record. metadata = {"test": "test"} res = update_record(collection_id=collection_id, record_id=record_id, metadata=metadata) res_dict = res.to_dict() pytest.assume(res_dict.keys() == self.record_keys) - pytest.assume(res_dict["content"].keys() == self.record_content_keys) pytest.assume(res_dict["metadata"] == metadata) - @pytest.mark.run(order=34) def test_delete_record(self, collection_id): - # List records. - records = list_records(collection_id=collection_id, order="desc", limit=20, after=None, - before=None) + records = list_records(collection_id=collection_id, order="desc", limit=20, after=None, before=None) old_nums = len(records) for index, record in enumerate(records): record_id = record.record_id @@ -198,8 +197,7 @@ def test_delete_record(self, collection_id): # List records. - new_records = list_records(collection_id=collection_id, order="desc", limit=20, after=None, - before=None) + new_records = list_records(collection_id=collection_id, order="desc", limit=20, after=None, before=None) record_ids = [record.record_id for record in new_records] pytest.assume(record_id not in record_ids) new_nums = len(new_records) @@ -208,13 +206,11 @@ def test_delete_record(self, collection_id): @pytest.mark.test_sync class TestChunk: - - chunk_list = ['chunk_id', 'collection_id', 'record_id', 'object', 'text', 'score'] + chunk_list = ["chunk_id", "collection_id", "record_id", "object", "content", "score"] chunk_keys = set(chunk_list) @pytest.mark.run(order=17) def test_query_chunks(self, collection_id): - # Query chunks. query_text = "Machine learning" @@ -224,6 +220,5 @@ def test_query_chunks(self, collection_id): for chunk in res: chunk_dict = chunk.to_dict() pytest.assume(chunk_dict.keys() == self.chunk_keys) - pytest.assume(query_text in chunk_dict["text"]) + pytest.assume(query_text in chunk_dict["content"]) pytest.assume(chunk_dict["score"] >= 0) - diff --git a/test/testcase/test_sync/test_sync_tool.py b/test/testcase/test_sync/test_sync_tool.py index dce1db4..6550631 100644 --- a/test/testcase/test_sync/test_sync_tool.py +++ b/test/testcase/test_sync/test_sync_tool.py @@ -6,23 +6,26 @@ @pytest.mark.test_sync class TestAction: - - action_list = ['object', 'action_id', 'name', 'description', 'authentication', 'schema', 'created_timestamp'] + action_list = [ + "object", + "action_id", + "name", + "description", + "authentication", + "openapi_schema", + "created_timestamp", + ] action_keys = set(action_list) - action_schema = ['openapi', 'info', 'servers', 'paths', 'components', 'security'] + action_schema = ["openapi", "info", "servers", "paths", "components", "security"] action_schema_keys = set(action_schema) - schema = { + openapi_schema = { "openapi": "3.1.0", "info": { "title": "Get weather data", "description": "Retrieves current weather data for a location.", - "version": "v1.0.0" + "version": "v1.0.0", }, - "servers": [ - { - "url": "https://weather.example.com" - } - ], + "servers": [{"url": "https://weather.example.com"}], "paths": { "/location": { "get": { @@ -34,12 +37,10 @@ class TestAction: "in": "query", "description": "The city and state to retrieve the weather for", "required": True, - "schema": { - "type": "string" - } + "schema": {"type": "string"}, } ], - "deprecated": False + "deprecated": False, }, "post": { "description": "UPDATE temperature for a specific location", @@ -47,63 +48,55 @@ class TestAction: "requestBody": { "required": True, "content": { - "application/json": { - "schema": { - "$ref": "#/componeents/schemas/ActionCreateRequest" - } - } - } + "application/json": {"schema": {"$ref": "#/componeents/schemas/ActionCreateRequest"}} + }, }, - "deprecated": False - } + "deprecated": False, + }, } }, - "components": { - "schemas": {} - }, - "security": [] + "components": {"schemas": {}}, + "security": [], } @pytest.mark.run(order=4) def test_bulk_create_actions(self): - # Create an action. - res = bulk_create_actions(schema=self.schema) + res = bulk_create_actions(openapi_schema=self.openapi_schema) for action in res: action_dict = action.to_dict() logger.info(action_dict) pytest.assume(action_dict.keys() == self.action_keys) - pytest.assume(action_dict["schema"].keys() == self.action_schema_keys) + pytest.assume(action_dict["openapi_schema"].keys() == self.action_schema_keys) - for key in action_dict["schema"].keys(): + for key in action_dict["openapi_schema"].keys(): if key == "paths": - if action_dict["schema"][key]["/location"] == "get": + if action_dict["openapi_schema"][key]["/location"] == "get": pytest.assume( - action_dict["schema"][key]["/location"]["get"] == self.schema["paths"]["/location"]["get"]) - elif action_dict["schema"][key]["/location"] == "post": + action_dict["openapi_schema"][key]["/location"]["get"] + == self.openapi_schema["paths"]["/location"]["get"] + ) + elif action_dict["openapi_schema"][key]["/location"] == "post": pytest.assume( - action_dict["schema"][key]["/location"]["post"] == self.schema["paths"]["/location"][ - "post"]) + action_dict["openapi_schema"][key]["/location"]["post"] + == self.openapi_schema["paths"]["/location"]["post"] + ) else: - pytest.assume(action_dict["schema"][key] == self.schema[key]) + pytest.assume(action_dict["openapi_schema"][key] == self.openapi_schema[key]) @pytest.mark.run(order=5) def test_run_action(self, action_id): - # Run an action. - parameters = { - "parameters": {"location": "tokyo"} - } + parameters = {"location": "tokyo"} res = run_action(action_id=action_id, parameters=parameters) - logger.info(f'async run action{res}') - pytest.assume(res['status'] == 400) + logger.info(f"async run action{res}") + pytest.assume(res["status"] != 200) pytest.assume(res["error"]) @pytest.mark.run(order=6) def test_list_actions(self): - # List actions. nums_limit = 1 @@ -131,18 +124,15 @@ def test_list_actions(self): @pytest.mark.run(order=7) def test_get_action(self, action_id): - # Get an action. res = get_action(action_id=action_id) res_dict = res.to_dict() pytest.assume(res_dict.keys() == self.action_keys) - logger.info(res_dict["schema"].keys()) - pytest.assume(res_dict["schema"].keys() == self.action_schema_keys) + pytest.assume(res_dict["openapi_schema"].keys() == self.action_schema_keys) @pytest.mark.run(order=39) def test_update_action(self, action_id): - # Update an action. update_schema = { @@ -150,13 +140,9 @@ def test_update_action(self, action_id): "info": { "title": "Get weather data", "description": "Retrieves current weather data for a location.", - "version": "v1.0.0" + "version": "v1.0.0", }, - "servers": [ - { - "url": "https://weather.example.com" - } - ], + "servers": [{"url": "https://weather.example.com"}], "paths": { "/location": { "get": { @@ -168,31 +154,25 @@ def test_update_action(self, action_id): "in": "query", "description": "The city and state to retrieve the weather for", "required": True, - "schema": { - "type": "string" - } + "schema": {"type": "string"}, } ], - "deprecated": False + "deprecated": False, } - } }, - "components": { - "schemas": {} - }, - "security": [] + "components": {"schemas": {}}, + "security": [], } - res = update_action(action_id=action_id, schema=update_schema) + res = update_action(action_id=action_id, openapi_schema=update_schema) res_dict = res.to_dict() logger.info(res_dict) pytest.assume(res_dict.keys() == self.action_keys) - pytest.assume(res_dict["schema"].keys() == self.action_schema_keys) - pytest.assume(res_dict["schema"] == update_schema) + pytest.assume(res_dict["openapi_schema"].keys() == self.action_schema_keys) + pytest.assume(res_dict["openapi_schema"] == update_schema) @pytest.mark.run(order=40) def test_delete_action(self): - # List actions. actions = list_actions(limit=100) @@ -209,7 +189,3 @@ def test_delete_action(self): pytest.assume(action_id not in action_ids) new_nums = len(new_actions) pytest.assume(new_nums == old_nums - 1 - index) - - - -