Skip to content
27 changes: 23 additions & 4 deletions libs/langchain/langchain/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,10 +415,29 @@ def _init_chat_model_helper(

return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
if model_provider == "huggingface":
_check_pkg("langchain_huggingface")
from langchain_huggingface import ChatHuggingFace
try:
from langchain_huggingface.chat_models import ChatHuggingFace
from langchain_huggingface.llms import HuggingFacePipeline
except ImportError as e:
import_error_msg = "Please install langchain-huggingface to use HuggingFace models."
raise ImportError(import_error_msg) from e

# The 'task' kwarg is required by from_model_id but not the base constructor.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is only passed to HuggingFacePipeline.from_model_id here, no?

# We pop it from kwargs to avoid the Pydantic 'extra_forbidden' error.
task = kwargs.pop("task", None)
if not task:
task_error_msg = "The 'task' keyword argument is required for HuggingFace models."
raise ValueError(task_error_msg)

# Initialize the base LLM pipeline with the model and arguments
llm = HuggingFacePipeline.from_model_id(
model_id=model,
task=task,
**kwargs, # Pass remaining kwargs like `device`
)

return ChatHuggingFace(model_id=model, **kwargs)
# Pass the initialized LLM to the chat wrapper
return ChatHuggingFace(llm=llm)
if model_provider == "groq":
_check_pkg("langchain_groq")
from langchain_groq import ChatGroq
Expand Down Expand Up @@ -957,4 +976,4 @@ def with_structured_output(
schema: Union[dict, type[BaseModel]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
return self.__getattr__("with_structured_output")(schema, **kwargs)
return self.__getattr__("with_structured_output")(schema, **kwargs)
14 changes: 14 additions & 0 deletions libs/langchain/tests/unit_tests/chat_models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig, RunnableSequence
from langchain_huggingface.chat_models import ChatHuggingFace
from pydantic import SecretStr

from langchain.chat_models.base import __all__, init_chat_model
Expand Down Expand Up @@ -289,3 +290,16 @@ def test_configurable_with_default() -> None:
prompt = ChatPromptTemplate.from_messages([("system", "foo")])
chain = prompt | model_with_config
assert isinstance(chain, RunnableSequence)


def test_init_chat_model_huggingface() -> None:
"""Test that init_chat_model works with huggingface."""
model_name = "google-bert/bert-base-uncased"

llm = init_chat_model(
model=model_name,
model_provider="huggingface",
task="text-generation",
)
assert isinstance(llm, ChatHuggingFace)
assert llm.llm.model_id == model_name