Add KerasHub Support for Local Model Inference #2390

Status: Open (4 commits to be merged into base: main)

divyashreepathihalli commented Aug 6, 2025

Overview
This PR integrates KerasHub into Google ADK, adding local model inference alongside the existing LiteLLM remote model support. Users can run AI agents on local hardware with no external API calls during inference.
Key Features

  • Local Model Support
  • 50+ models supported across multiple families (GPT-2, OPT, BLOOM, LLaMA, Gemma, BERT, T5, etc.)
  • No external API calls during inference - runs entirely on local hardware
  • Privacy-focused - all computation happens locally
  • Cost-effective - no per-token charges

Technical Implementation

  • New KerasLlm wrapper class that subclasses BaseLlm
  • Automatic model loading using KerasHub's CausalLM.from_preset()
  • Configurable sampling strategies (greedy, top_k, top_p, temperature)
  • Conversation flattening to text prompts for local models
  • Async generation with thread-based execution
  • Integration with ADK's LLMRegistry for string-based model resolution
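As a rough illustration of the conversation flattening and thread-based async generation listed above, here is a minimal sketch. It is not the code from this PR: the `(role, text)` message shape, `flatten_conversation`, `blocking_generate`, and `generate_async` are hypothetical names, and the prompt format is an assumption.

```python
import asyncio
import time


def flatten_conversation(messages):
    """Join (role, text) pairs into a single text prompt (hypothetical format)."""
    lines = [f"{role.capitalize()}: {text.strip()}" for role, text in messages]
    # End with an open assistant turn so a causal LM continues from there.
    lines.append("Assistant:")
    return "\n".join(lines)


def blocking_generate(prompt):
    """Stand-in for a synchronous local model call (not a real model)."""
    time.sleep(0.01)  # simulate compute
    return f"[{len(prompt)} prompt characters received]"


async def generate_async(messages):
    prompt = flatten_conversation(messages)
    # asyncio.to_thread runs the blocking call in a worker thread,
    # so the event loop stays free for other tasks.
    return await asyncio.to_thread(blocking_generate, prompt)
```

Calling `asyncio.run(generate_async([("user", "Hi there")]))` runs the stand-in generator off the event loop. A real wrapper would replace `blocking_generate` with the model's own generate call.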

User Experience

  • Consistent Interface: Same API as existing LiteLlm wrapper
  • String-based Resolution: Agent(model="keras/gpt2_base_en", ...)
  • Flexible Configuration: Multiple models and sampling strategies
  • Environment Support: Respects KERAS_BACKEND environment variable
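The string-based resolution and the `KERAS_BACKEND` handling above could work roughly as sketched below. This is an assumption-laden illustration, not ADK's registry: `LocalModel`, `register`, and `resolve` are hypothetical names, and the `"tensorflow"` fallback is only a placeholder default for the sketch.

```python
import os
import re

# Hypothetical registry: (compiled pattern, wrapper class) pairs.
_REGISTRY = []


class LocalModel:
    """Stand-in for a local-model wrapper (not the PR's KerasLlm)."""

    def __init__(self, model):
        # Strip the routing prefix so only the preset name remains.
        self.preset = model.removeprefix("keras/")
        # Read the backend choice from the environment; the fallback
        # value here is an assumption made for this sketch.
        self.backend = os.environ.get("KERAS_BACKEND", "tensorflow")


def register(pattern, cls):
    """Associate a regex over model strings with a wrapper class."""
    _REGISTRY.append((re.compile(pattern), cls))


def resolve(model):
    """Return a wrapper instance for the first pattern that fully matches."""
    for pattern, cls in _REGISTRY:
        if pattern.fullmatch(model):
            return cls(model)
    raise ValueError(f"No wrapper registered for model string {model!r}")


register(r"keras/.+", LocalModel)
```

With this in place, `resolve("keras/gpt2_base_en")` yields a wrapper whose `preset` is `gpt2_base_en`. Note that Keras itself reads `KERAS_BACKEND` when it is imported, so the variable has to be set before the first `import keras`.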

Examples

from google.adk.agents import Agent
from google.adk.models.keras_llm import KerasLlm

# Option 1: pass a KerasLlm instance directly.
agent = Agent(
    model=KerasLlm(model="gpt2_base_en"),
    name="local_agent",
    instruction="You are a helpful assistant."
)

# Option 2: resolve the model from a "keras/"-prefixed string.
agent = Agent(
    model="keras/gpt2_base_en",
    name="local_agent",
    instruction="You are a helpful assistant."
)

# Option 3: configure generation length and sampling.
agent = Agent(
    model=KerasLlm(
        model="llama_2_7b_en",
        max_length=200,
        temperature=0.7,
        sampler="top_p",
        top_p=0.9
    ),
    name="high_quality_agent",
    instruction="You are a highly capable AI assistant."
)
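
For readers unfamiliar with the `temperature` and `top_p` settings used in the last example, the following is a minimal pure-Python sketch of temperature scaling followed by nucleus (top-p) filtering. It is a generic illustration of the technique, not KerasHub's sampler implementation, and `sample_top_p` is a hypothetical helper name.

```python
import math
import random


def sample_top_p(logits, temperature=0.7, top_p=0.9, rng=None):
    """Sample one token index with temperature scaling and top-p filtering."""
    rng = rng or random.Random()
    # Temperature scaling: values below 1 sharpen the distribution.
    scaled = [x / temperature for x in logits]
    # Numerically stable softmax.
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the most probable tokens until their cumulative mass reaches top_p.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in ranked:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    # Draw among the kept tokens, weighted by probability
    # (random.choices normalizes the weights).
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]
```

Lower `top_p` values restrict sampling to fewer high-probability tokens, while lower `temperature` values concentrate probability on the top token before the filter is applied.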

Issue: #2391

adk-bot (Collaborator) commented Aug 6, 2025

Response from ADK Triaging Agent

Hello @divyashreepathihalli, thank you for your contribution! This is an exciting new feature.

To help us review this PR, could you please address the following points from our contribution guidelines:

  • Single Commit: Please squash the multiple commits into a single one.
  • Associated Issue: For a feature of this size, an associated GitHub issue is required. If one doesn't exist, please create one and link it to this PR.
  • Testing Plan: Could you please add a testing plan section to your PR description to detail how you've tested these changes?

You can find more details in our CONTRIBUTING.md.

Thank you!

adk-bot added the labels "bot triaged" ([Bot] This issue is triaged by ADK bot) and "models" ([Component] Issues related to model support) on Aug 6, 2025.
adk-bot requested a review from genquan9 on August 6, 2025, 20:09.

zeroasterisk commented Aug 7, 2025

LGTM from my perspective. Need official reviewer. After merge, will need to add to docs.

Also need to decide whether we are porting this to other languages or not. It makes parity across the feature matrix difficult to maintain.
