Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion extension/llm/runner/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,18 @@ list(TRANSFORM _extension_llm_runner__srcs PREPEND "${EXECUTORCH_ROOT}/")

add_library(extension_llm_runner STATIC ${_extension_llm_runner__srcs})

set(runner_deps executorch_core extension_module extension_tensor tokenizers)
set(runner_deps executorch_core extension_module extension_tensor
tokenizers::tokenizers
)

# depend on arange_utils
if(NOT TARGET kernels_util_all_deps)
add_subdirectory(
${EXECUTORCH_ROOT}/kernels/portable/cpu/util
${CMAKE_CURRENT_BINARY_DIR}/kernels_util
)
endif()
list(APPEND runner_deps kernels_util_all_deps)

target_link_libraries(extension_llm_runner PUBLIC ${runner_deps})
set_target_properties(
Expand Down
105 changes: 105 additions & 0 deletions extension/llm/runner/multimodal_decoder_runner.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/extension/llm/runner/constants.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>

namespace executorch::extension::llm {

class ET_EXPERIMENTAL MultimodalDecoderRunner
: public executorch::extension::llm::TextDecoderRunner {
public:
explicit MultimodalDecoderRunner(Module* module, IOManager* io_manager)
: TextDecoderRunner(module, io_manager) {}

/**
* Step the LLM Decoder with the given tokens and start position.
* @param tokens The tokens to the LLM.
* @param start_pos The start position of the tokens.
* @return The logits tensor.
*/
inline executorch::runtime::Result<executorch::aten::Tensor> step(
executorch::extension::TensorPtr& tokens,
int64_t start_pos) override {
// run token embedding
auto token_embedding_outputs =
ET_UNWRAP(module_->execute(kTokenEmbeddingMethod, tokens));

// Return the logits tensor
return decode(token_embedding_outputs[0], start_pos);
}

/**
* Decode the embeddings to logits.
* @param embeddings The embeddings tensor.
* @param start_pos The start position of the embeddings.
* @return The logits tensor.
*/
inline executorch::runtime::Result<executorch::aten::Tensor> decode(
const runtime::EValue& embeddings,
int64_t start_pos) {
auto start_pos_tensor = ::executorch::extension::from_blob(
&start_pos, {1}, executorch::aten::ScalarType::Long);
// run text model
auto outputs_res = ET_UNWRAP(
module_->execute(kTextModelMethod, {start_pos_tensor, embeddings}));

ET_CHECK_MSG(
outputs_res.size() == 1,
"More then one output returned from executing LLM.");
ET_CHECK_MSG(
outputs_res[0].isTensor(),
"Non Tensor Output returned from executing LLM");

// Return the logits tensor
return outputs_res[0].toTensor();
}

/**
* Load the Module for text decode purpose.
* @return The error code.
*/
inline executorch::runtime::Error load() override {
if (is_method_loaded()) {
return executorch::runtime::Error::Ok;
}
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTokenEmbeddingMethod));
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTextModelMethod));
return executorch::runtime::Error::Ok;
}

/**
* Check if the required methods in the Module is loaded.
* @return True if the Module is loaded, false otherwise.
*/
inline bool is_method_loaded() override {
executorch::runtime::Result<std::unordered_set<std::string>> methods_res =
module_->method_names();
if (methods_res.error() != executorch::runtime::Error::Ok) {
ET_CHECK_MSG(false, "Failed to get method names");
}
std::unordered_set<std::string> methods = methods_res.get();
bool methods_exist = methods.find(kTokenEmbeddingMethod) != methods.end() &&
methods.find(kTextModelMethod) != methods.end();
if (!methods_exist) {
for (const auto& method : methods) {
ET_LOG(Error, "Method: %s", method.c_str());
}
ET_CHECK_MSG(
methods_exist,
"Missing required methods (%s, %s) in the model",
kTokenEmbeddingMethod,
kTextModelMethod);
}
bool methods_loaded = module_->is_method_loaded(kTokenEmbeddingMethod) &&
module_->is_method_loaded(kTextModelMethod);
return methods_loaded;
}
};

} // namespace executorch::extension::llm
186 changes: 186 additions & 0 deletions extension/llm/runner/multimodal_input.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

// @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
// A generic multimodal input class that can hold either image or text data.

#pragma once

#include <executorch/extension/llm/runner/image.h>
#include <executorch/runtime/platform/compiler.h>
#include <string>
#include <variant>

namespace executorch {
namespace extension {
namespace llm {

/**
* A generic class to hold either image or text data for multimodal inputs.
* This allows the generate() API to take a std::vector of these objects
* instead of separate image and text parameters.
*/
class ET_EXPERIMENTAL MultimodalInput {
public:
enum class Type { TEXT, IMAGE };

// Constructors
explicit MultimodalInput(const std::string& text) : data_(text) {}
explicit MultimodalInput(std::string&& text) : data_(std::move(text)) {}
explicit MultimodalInput(const Image& image) : data_(image) {}
explicit MultimodalInput(Image&& image) : data_(std::move(image)) {}

// Copy constructor and assignment
MultimodalInput(const MultimodalInput& other) = default;
MultimodalInput& operator=(const MultimodalInput& other) = default;

// Move constructor and assignment
MultimodalInput(MultimodalInput&& other) noexcept = default;
MultimodalInput& operator=(MultimodalInput&& other) noexcept = default;

// Destructor
~MultimodalInput() = default;

/**
* Check if this input contains text data.
* @return true if this input contains text, false otherwise.
*/
bool is_text() const noexcept {
return std::holds_alternative<std::string>(data_);
}

/**
* Check if this input contains image data.
* @return true if this input contains an image, false otherwise.
*/
bool is_image() const noexcept {
return std::holds_alternative<Image>(data_);
}

/**
* Get the type of data stored in this input.
* @return Type::TEXT if text data, Type::IMAGE if image data.
*/
Type get_type() const noexcept {
return is_text() ? Type::TEXT : Type::IMAGE;
}

/**
* Get the text data from this input.
* @return Reference to the stored text string.
* @throws std::bad_variant_access if this input doesn't contain text.
*/
const std::string& get_text() const& {
return std::get<std::string>(data_);
}

/**
* Get the text data from this input (mutable version).
* @return Mutable reference to the stored text string.
* @throws std::bad_variant_access if this input doesn't contain text.
*/
std::string& get_text() & {
return std::get<std::string>(data_);
}

/**
* Get the text data from this input (rvalue version).
* @return Rvalue reference to the stored text string for efficient moves.
* @throws std::bad_variant_access if this input doesn't contain text.
*/
std::string&& get_text() && {
return std::get<std::string>(std::move(data_));
}

/**
* Get the image data from this input.
* @return Reference to the stored Image object.
* @throws std::bad_variant_access if this input doesn't contain an image.
*/
const Image& get_image() const& {
return std::get<Image>(data_);
}

/**
* Get the image data from this input (mutable version).
* @return Mutable reference to the stored Image object.
* @throws std::bad_variant_access if this input doesn't contain an image.
*/
Image& get_image() & {
return std::get<Image>(data_);
}

/**
* Get the image data from this input (rvalue version).
* @return Rvalue reference to the stored Image object for efficient moves.
* @throws std::bad_variant_access if this input doesn't contain an image.
*/
Image&& get_image() && {
return std::get<Image>(std::move(data_));
}

/**
* Try to get the text data from this input safely.
* @return Pointer to the text string if this input contains text, nullptr
* otherwise.
*/
const std::string* try_get_text() const noexcept {
return std::get_if<std::string>(&data_);
}

/**
* Try to get the text data from this input safely (mutable version).
* @return Pointer to the text string if this input contains text, nullptr
* otherwise.
*/
std::string* try_get_text() noexcept {
return std::get_if<std::string>(&data_);
}

/**
* Try to get the image data from this input safely.
* @return Pointer to the Image object if this input contains an image,
* nullptr otherwise.
*/
const Image* try_get_image() const noexcept {
return std::get_if<Image>(&data_);
}

/**
* Try to get the image data from this input safely (mutable version).
* @return Pointer to the Image object if this input contains an image,
* nullptr otherwise.
*/
Image* try_get_image() noexcept {
return std::get_if<Image>(&data_);
}

private:
std::variant<std::string, Image> data_;
};

// Convenience factory functions
inline MultimodalInput make_text_input(const std::string& text) noexcept {
return MultimodalInput(text);
}

inline MultimodalInput make_text_input(std::string&& text) noexcept {
return MultimodalInput(std::move(text));
}

inline MultimodalInput make_image_input(const Image& image) noexcept {
return MultimodalInput(image);
}

inline MultimodalInput make_image_input(Image&& image) noexcept {
return MultimodalInput(std::move(image));
}

} // namespace llm
} // namespace extension
} // namespace executorch
Loading
Loading