Skip to content

Fix/module error with openai package #1102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ on:
pull_request:
branches:
- "release/*"
- "patch/*"
- "main"

jobs:
Expand Down
28 changes: 15 additions & 13 deletions langtest/augmentation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from langtest.utils.custom_types.predictions import NERPrediction, SequenceLabel
from langtest.utils.custom_types.sample import NERSample
from langtest.tasks import TaskManager
from ..utils.lib_manager import try_import_lib
from ..errors import Errors


Expand Down Expand Up @@ -358,6 +357,9 @@ def __init__(
# Extend the existing templates list

self.__templates.extend(generated_templates[:num_extra_templates])
except ModuleNotFoundError:
raise ImportError(Errors.E097())

except Exception as e_msg:
raise Errors.E095(e=e_msg)

Expand Down Expand Up @@ -606,19 +608,19 @@ def __generate_templates(
num_extra_templates: int,
model_config: Union[OpenAIConfig, AzureOpenAIConfig] = None,
) -> List[str]:
if try_import_lib("openai"):
from langtest.augmentation.utils import (
generate_templates_azoi, # azoi means Azure OpenAI
generate_templates_openai,
)
"""This method is used to generate extra templates from a given template."""
from langtest.augmentation.utils import (
generate_templates_azoi, # azoi means Azure OpenAI
generate_templates_openai,
)

params = model_config.copy() if model_config else {}
params = model_config.copy() if model_config else {}

if model_config and model_config.get("provider") == "openai":
return generate_templates_openai(template, num_extra_templates, params)
if model_config and model_config.get("provider") == "openai":
return generate_templates_openai(template, num_extra_templates, params)

elif model_config and model_config.get("provider") == "azure":
return generate_templates_azoi(template, num_extra_templates, params)
elif model_config and model_config.get("provider") == "azure":
return generate_templates_azoi(template, num_extra_templates, params)

else:
return generate_templates_openai(template, num_extra_templates)
else:
return generate_templates_openai(template, num_extra_templates)
6 changes: 3 additions & 3 deletions langtest/augmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,13 @@ class OpenAIConfig(TypedDict):
class AzureOpenAIConfig(TypedDict):
"""Azure OpenAI Configuration for API Key and Provider."""

from openai.lib.azure import AzureADTokenProvider

azure_endpoint: str
api_version: str
api_key: str
provider: str
azure_deployment: Union[str, None] = None
azure_ad_token: Union[str, None] = (None,)
azure_ad_token_provider: Union[AzureADTokenProvider, None] = (None,)
azure_ad_token_provider = (None,)
organization: Union[str, None] = (None,)


Expand Down Expand Up @@ -76,6 +74,7 @@ def generate_templates_azoi(
template: str, num_extra_templates: int, model_config: AzureOpenAIConfig
):
"""Generate new templates based on the provided template using Azure OpenAI API."""

import openai

if "provider" in model_config:
Expand Down Expand Up @@ -139,6 +138,7 @@ def generate_templates_openai(
template: str, num_extra_templates: int, model_config: OpenAIConfig = OpenAIConfig()
):
"""Generate new templates based on the provided template using OpenAI API."""

import openai

if "provider" in model_config:
Expand Down
1 change: 1 addition & 0 deletions langtest/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ class Errors(metaclass=ErrorsWithCodes):
E094 = ("Unsupported category: '{category}'. Supported categories: {supported_category}")
E095 = ("Failed to make API request: {e}")
E096 = ("Failed to generate the templates in Augmentation: {msg}")
E097 = ("Failed to load openai. Please install it using `pip install openai`")


class ColumnNameError(Exception):
Expand Down
Loading