diff --git a/extension/llm/runner/CMakeLists.txt b/extension/llm/runner/CMakeLists.txt index f5933e82e32..ef98f41bd23 100644 --- a/extension/llm/runner/CMakeLists.txt +++ b/extension/llm/runner/CMakeLists.txt @@ -39,7 +39,18 @@ list(TRANSFORM _extension_llm_runner__srcs PREPEND "${EXECUTORCH_ROOT}/") add_library(extension_llm_runner STATIC ${_extension_llm_runner__srcs}) -set(runner_deps executorch_core extension_module extension_tensor tokenizers) +set(runner_deps executorch_core extension_module extension_tensor + tokenizers::tokenizers +) + +# depend on arange_utils +if(NOT TARGET kernels_util_all_deps) + add_subdirectory( + ${EXECUTORCH_ROOT}/kernels/portable/cpu/util + ${CMAKE_CURRENT_BINARY_DIR}/kernels_util + ) +endif() +list(APPEND runner_deps kernels_util_all_deps) target_link_libraries(extension_llm_runner PUBLIC ${runner_deps}) set_target_properties( diff --git a/extension/llm/runner/multimodal_decoder_runner.h b/extension/llm/runner/multimodal_decoder_runner.h new file mode 100644 index 00000000000..2f3ab401e03 --- /dev/null +++ b/extension/llm/runner/multimodal_decoder_runner.h @@ -0,0 +1,105 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace executorch::extension::llm { + +class ET_EXPERIMENTAL MultimodalDecoderRunner + : public executorch::extension::llm::TextDecoderRunner { + public: + explicit MultimodalDecoderRunner(Module* module, IOManager* io_manager) + : TextDecoderRunner(module, io_manager) {} + + /** + * Step the LLM Decoder with the given tokens and start position. + * @param tokens The tokens to the LLM. + * @param start_pos The start position of the tokens. + * @return The logits tensor. + */ + inline executorch::runtime::Result step( + executorch::extension::TensorPtr& tokens, + int64_t start_pos) override { + // run token embedding + auto token_embedding_outputs = + ET_UNWRAP(module_->execute(kTokenEmbeddingMethod, tokens)); + + // Return the logits tensor + return decode(token_embedding_outputs[0], start_pos); + } + + /** + * Decode the embeddings to logits. + * @param embeddings The embeddings tensor. + * @param start_pos The start position of the embeddings. + * @return The logits tensor. + */ + inline executorch::runtime::Result decode( + const runtime::EValue& embeddings, + int64_t start_pos) { + auto start_pos_tensor = ::executorch::extension::from_blob( + &start_pos, {1}, executorch::aten::ScalarType::Long); + // run text model + auto outputs_res = ET_UNWRAP( + module_->execute(kTextModelMethod, {start_pos_tensor, embeddings})); + + ET_CHECK_MSG( + outputs_res.size() == 1, + "More then one output returned from executing LLM."); + ET_CHECK_MSG( + outputs_res[0].isTensor(), + "Non Tensor Output returned from executing LLM"); + + // Return the logits tensor + return outputs_res[0].toTensor(); + } + + /** + * Load the Module for text decode purpose. + * @return The error code. + */ + inline executorch::runtime::Error load() override { + if (is_method_loaded()) { + return executorch::runtime::Error::Ok; + } + ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTokenEmbeddingMethod)); + ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTextModelMethod)); + return executorch::runtime::Error::Ok; + } + + /** + * Check if the required methods in the Module is loaded. + * @return True if the Module is loaded, false otherwise. + */ + inline bool is_method_loaded() override { + executorch::runtime::Result> methods_res = + module_->method_names(); + if (methods_res.error() != executorch::runtime::Error::Ok) { + ET_CHECK_MSG(false, "Failed to get method names"); + } + std::unordered_set methods = methods_res.get(); + bool methods_exist = methods.find(kTokenEmbeddingMethod) != methods.end() && + methods.find(kTextModelMethod) != methods.end(); + if (!methods_exist) { + for (const auto& method : methods) { + ET_LOG(Error, "Method: %s", method.c_str()); + } + ET_CHECK_MSG( + methods_exist, + "Missing required methods (%s, %s) in the model", + kTokenEmbeddingMethod, + kTextModelMethod); + } + bool methods_loaded = module_->is_method_loaded(kTokenEmbeddingMethod) && + module_->is_method_loaded(kTextModelMethod); + return methods_loaded; + } +}; + +} // namespace executorch::extension::llm diff --git a/extension/llm/runner/multimodal_input.h b/extension/llm/runner/multimodal_input.h new file mode 100644 index 00000000000..8633def75bf --- /dev/null +++ b/extension/llm/runner/multimodal_input.h @@ -0,0 +1,186 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// @lint-ignore-every CLANGTIDY facebook-hte-Deprecated +// A generic multimodal input class that can hold either image or text data. + +#pragma once + +#include +#include +#include +#include + +namespace executorch { +namespace extension { +namespace llm { + +/** + * A generic class to hold either image or text data for multimodal inputs. + * This allows the generate() API to take a std::vector of these objects + * instead of separate image and text parameters. + */ +class ET_EXPERIMENTAL MultimodalInput { + public: + enum class Type { TEXT, IMAGE }; + + // Constructors + explicit MultimodalInput(const std::string& text) : data_(text) {} + explicit MultimodalInput(std::string&& text) : data_(std::move(text)) {} + explicit MultimodalInput(const Image& image) : data_(image) {} + explicit MultimodalInput(Image&& image) : data_(std::move(image)) {} + + // Copy constructor and assignment + MultimodalInput(const MultimodalInput& other) = default; + MultimodalInput& operator=(const MultimodalInput& other) = default; + + // Move constructor and assignment + MultimodalInput(MultimodalInput&& other) noexcept = default; + MultimodalInput& operator=(MultimodalInput&& other) noexcept = default; + + // Destructor + ~MultimodalInput() = default; + + /** + * Check if this input contains text data. + * @return true if this input contains text, false otherwise. + */ + bool is_text() const noexcept { + return std::holds_alternative(data_); + } + + /** + * Check if this input contains image data. + * @return true if this input contains an image, false otherwise. + */ + bool is_image() const noexcept { + return std::holds_alternative(data_); + } + + /** + * Get the type of data stored in this input. + * @return Type::TEXT if text data, Type::IMAGE if image data. + */ + Type get_type() const noexcept { + return is_text() ? Type::TEXT : Type::IMAGE; + } + + /** + * Get the text data from this input. + * @return Reference to the stored text string. + * @throws std::bad_variant_access if this input doesn't contain text. + */ + const std::string& get_text() const& { + return std::get(data_); + } + + /** + * Get the text data from this input (mutable version). + * @return Mutable reference to the stored text string. + * @throws std::bad_variant_access if this input doesn't contain text. + */ + std::string& get_text() & { + return std::get(data_); + } + + /** + * Get the text data from this input (rvalue version). + * @return Rvalue reference to the stored text string for efficient moves. + * @throws std::bad_variant_access if this input doesn't contain text. + */ + std::string&& get_text() && { + return std::get(std::move(data_)); + } + + /** + * Get the image data from this input. + * @return Reference to the stored Image object. + * @throws std::bad_variant_access if this input doesn't contain an image. + */ + const Image& get_image() const& { + return std::get(data_); + } + + /** + * Get the image data from this input (mutable version). + * @return Mutable reference to the stored Image object. + * @throws std::bad_variant_access if this input doesn't contain an image. + */ + Image& get_image() & { + return std::get(data_); + } + + /** + * Get the image data from this input (rvalue version). + * @return Rvalue reference to the stored Image object for efficient moves. + * @throws std::bad_variant_access if this input doesn't contain an image. + */ + Image&& get_image() && { + return std::get(std::move(data_)); + } + + /** + * Try to get the text data from this input safely. + * @return Pointer to the text string if this input contains text, nullptr + * otherwise. + */ + const std::string* try_get_text() const noexcept { + return std::get_if(&data_); + } + + /** + * Try to get the text data from this input safely (mutable version). + * @return Pointer to the text string if this input contains text, nullptr + * otherwise. + */ + std::string* try_get_text() noexcept { + return std::get_if(&data_); + } + + /** + * Try to get the image data from this input safely. + * @return Pointer to the Image object if this input contains an image, + * nullptr otherwise. + */ + const Image* try_get_image() const noexcept { + return std::get_if(&data_); + } + + /** + * Try to get the image data from this input safely (mutable version). + * @return Pointer to the Image object if this input contains an image, + * nullptr otherwise. + */ + Image* try_get_image() noexcept { + return std::get_if(&data_); + } + + private: + std::variant data_; +}; + +// Convenience factory functions +inline MultimodalInput make_text_input(const std::string& text) noexcept { + return MultimodalInput(text); +} + +inline MultimodalInput make_text_input(std::string&& text) noexcept { + return MultimodalInput(std::move(text)); +} + +inline MultimodalInput make_image_input(const Image& image) noexcept { + return MultimodalInput(image); +} + +inline MultimodalInput make_image_input(Image&& image) noexcept { + return MultimodalInput(std::move(image)); +} + +} // namespace llm +} // namespace extension +} // namespace executorch diff --git a/extension/llm/runner/multimodal_prefiller.cpp b/extension/llm/runner/multimodal_prefiller.cpp new file mode 100644 index 00000000000..7f69041551f --- /dev/null +++ b/extension/llm/runner/multimodal_prefiller.cpp @@ -0,0 +1,132 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Generic encoder prefiller that handles multimodal inputs (text, image and +// audio (to be implemented)) to prefill the KV cache of a multimodal LLM. +// @lint-ignore-every CLANGTIDY facebook-hte-Deprecated + +#include +#include +#include +#include + +namespace executorch::extension::llm { + +MultimodalPrefiller::MultimodalPrefiller( + Module* module, + MultimodalDecoderRunner* decoder_runner, + Tokenizer* tokenizer, + IOManager* io_manager) + : module_(module), + text_decoder_runner_(decoder_runner), + tokenizer_(tokenizer), + io_manager_(io_manager) {} + +/** + * Prefill an LLM Module with the given multimodal input. + * @param input The multimodal input (text, image or audio) to the multimodal + * LLM. + * @param start_pos The starting position in KV cache of the input in the LLM + * @return logits of the prefill. + */ +Result MultimodalPrefiller::prefill( + const MultimodalInput& input, + int64_t& start_pos) { + // Check if input is image + ::executorch::runtime::EValue encoder_output; + if (input.is_image()) { + Image image = input.get_image(); + auto image_tensor = executorch::extension::from_blob( + image.data.data(), + {3, image.height, image.width}, + ::executorch::aten::ScalarType::Byte); + + // Run image encoder + auto image_encoder_outputs = + ET_UNWRAP(module_->execute(kImageEncoderMethod, image_tensor)); + + encoder_output = image_encoder_outputs[0]; + } else if (input.is_text()) { + // For text input, we don't need to run the image encoder. + // Instead, we run the text encoder to get the encoder output. + auto& text = input.get_text(); + std::vector tokens = + ET_UNWRAP_TOKENIZER(tokenizer_->encode(text)); + auto text_tensor = executorch::extension::from_blob( + tokens.data(), + {1, static_cast(tokens.size())}, + ::executorch::aten::ScalarType::Long); + + // Run token embedding + auto token_embedding_outputs = + ET_UNWRAP(module_->execute(kTokenEmbeddingMethod, text_tensor)); + + encoder_output = token_embedding_outputs[0]; + } else { + ET_LOG(Error, "Unsupported input type"); + // For all other input types (e.g., audio), return error + return ::executorch::runtime::Error::NotSupported; + } + + auto outputs_res = + ET_UNWRAP(text_decoder_runner_->decode(encoder_output, start_pos)); + + // Update the start_pos, which is only available inside this function. + // outputs_res can have only one logits. + start_pos += encoder_output.toTensor().size(1); + + return static_cast( + text_decoder_runner_->logits_to_token(outputs_res)); +} + +/** + * Load the Module for encoder prefill purpose. + * @return The error code. + */ +::executorch::runtime::Error MultimodalPrefiller::load() { + if (is_method_loaded()) { + return ::executorch::runtime::Error::Ok; + } + // token_embeddings and text_model have to show up in method names. + ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTokenEmbeddingMethod)); + ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTextModelMethod)); + + std::unordered_set methods = + ET_UNWRAP(module_->method_names(), "Failed to get method names"); + + // Load image_encoder method if exists. + if (methods.find(kImageEncoderMethod) != methods.end()) { + ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kImageEncoderMethod)); + } + return ::executorch::runtime::Error::Ok; +} + +/** + * Check if the required methods in the Module is loaded. + * @return True if the Module is loaded, false otherwise. + */ +bool MultimodalPrefiller::is_method_loaded() { + ::executorch::runtime::Result> methods_res = + module_->method_names(); + if (!module_->is_method_loaded(kTokenEmbeddingMethod)) { + return false; + } + if (!module_->is_method_loaded(kTextModelMethod)) { + return false; + } + if (methods_res.error() != ::executorch::runtime::Error::Ok) { + ET_CHECK_MSG(false, "Failed to get method names"); + } + std::unordered_set methods = methods_res.get(); + if (methods.find(kImageEncoderMethod) != methods.end()) { + return module_->is_method_loaded(kImageEncoderMethod); + } + return true; +} + +} // namespace executorch::extension::llm diff --git a/extension/llm/runner/multimodal_prefiller.h b/extension/llm/runner/multimodal_prefiller.h new file mode 100644 index 00000000000..dbfa2ec7ca3 --- /dev/null +++ b/extension/llm/runner/multimodal_prefiller.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Generic encoder prefiller that handles multimodal inputs (image and audio) +// to prefill the KV cache of a multimodal LLM. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace executorch::extension::llm { + +using runtime::Error; +using runtime::Result; +using tokenizers::Tokenizer; + +// Assuming kv cache and parallel prefill are enabled. +// This prefiller supports both image and audio inputs +class ET_EXPERIMENTAL MultimodalPrefiller { + public: + explicit MultimodalPrefiller( + Module* module, + MultimodalDecoderRunner* decoder_runner, + Tokenizer* tokenizer, + IOManager* io_manager); + + /** + * Prefill an LLM Module with the given multimodal input. + * @param input The multimodal input (image or audio) to the multimodal LLM. + * @param start_pos The starting position in KV cache of the input in the LLM. + * It's passed as reference and will be updated inside this function. + * @return The next token of the LLM Module after prefill. + */ + virtual Result prefill( + const MultimodalInput& input, + int64_t& start_pos); + + virtual Error load(); + virtual bool is_method_loaded(); + + virtual ~MultimodalPrefiller() = default; + + protected: + Module* module_; + MultimodalDecoderRunner* text_decoder_runner_; + Tokenizer* tokenizer_; + IOManager* io_manager_; +}; + +} // namespace executorch::extension::llm diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl index d25b1f6696a..5bbb12ab5ab 100644 --- a/extension/llm/runner/targets.bzl +++ b/extension/llm/runner/targets.bzl @@ -90,13 +90,33 @@ def define_common_targets(): exported_deps = [ ":constants", "//executorch/extension/module:module" + aten_suffix, + "//executorch/extension/tensor:tensor" + aten_suffix, + "//executorch/extension/llm/sampler:sampler" + aten_suffix, ], ) runtime.cxx_library( - name = "runner_lib" + aten_suffix, + name = "multimodal_runner_lib" + aten_suffix, exported_headers = [ + "multimodal_input.h", "multimodal_runner.h", + "multimodal_prefiller.h", + "multimodal_decoder_runner.h", + ], + srcs = [ + "multimodal_prefiller.cpp", + ], + exported_deps = [ + ":text_decoder_runner" + aten_suffix, + ":text_prefiller" + aten_suffix, + ":image_prefiller" + aten_suffix, + ":text_token_generator" + aten_suffix, + ], + ) + + runtime.cxx_library( + name = "runner_lib" + aten_suffix, + exported_headers = [ "text_llm_runner.h", "llm_runner_helper.h", "constants.h", @@ -114,6 +134,7 @@ def define_common_targets(): exported_deps = [ ":image_prefiller" + aten_suffix, ":irunner", + ":multimodal_runner_lib" + aten_suffix, ":text_decoder_runner" + aten_suffix, ":text_prefiller" + aten_suffix, ":text_token_generator" + aten_suffix, diff --git a/extension/llm/runner/test/CMakeLists.txt b/extension/llm/runner/test/CMakeLists.txt index 78dcb25bcc5..2aa18000831 100644 --- a/extension/llm/runner/test/CMakeLists.txt +++ b/extension/llm/runner/test/CMakeLists.txt @@ -17,10 +17,23 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake) -set(_test_srcs test_generation_config.cpp test_text_llm_runner.cpp - test_text_prefiller.cpp test_text_decoder_runner.cpp +set(_test_srcs + test_generation_config.cpp test_text_llm_runner.cpp test_text_prefiller.cpp + test_text_decoder_runner.cpp test_multimodal_input.cpp ) +# Add LSan stub for Apple platforms +if(APPLE) + list(APPEND _test_srcs lsan_stub.cpp) +endif() + et_cxx_test( test_runner SOURCES ${_test_srcs} EXTRA_LIBS executorch extension_llm_runner ) + +# Override sanitizer to this issue: +# https://github.com/abseil/abseil-cpp/issues/841 Root issue: +# https://github.com/llvm/llvm-project/issues/16778 +if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + target_link_options(test_runner PUBLIC --rtlib=compiler-rt) +endif() diff --git a/extension/llm/runner/test/lsan_stub.cpp b/extension/llm/runner/test/lsan_stub.cpp new file mode 100644 index 00000000000..4a8c3aa9b2c --- /dev/null +++ b/extension/llm/runner/test/lsan_stub.cpp @@ -0,0 +1,16 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +// lsan_stub.cpp - Fix for macOS LSan linking issue +#if defined(__APPLE__) && defined(__arm64__) +extern "C" { +// Provide stub for LSan symbol that macOS doesn't implement +int __lsan_is_turned_off() { + return 1; +} +} +#endif \ No newline at end of file diff --git a/extension/llm/runner/test/targets.bzl b/extension/llm/runner/test/targets.bzl index 8bc3d4cc100..3339b3b8584 100644 --- a/extension/llm/runner/test/targets.bzl +++ b/extension/llm/runner/test/targets.bzl @@ -36,3 +36,11 @@ def define_common_targets(): "//executorch/runtime/core/exec_aten/testing_util:tensor_util", ], ) + + runtime.cxx_test( + name = "test_multimodal_input", + srcs = ["test_multimodal_input.cpp"], + deps = [ + "//executorch/extension/llm/runner:multimodal_runner_lib", + ], + ) diff --git a/extension/llm/runner/test/test_multimodal_input.cpp b/extension/llm/runner/test/test_multimodal_input.cpp new file mode 100644 index 00000000000..5c6d4c1b8f4 --- /dev/null +++ b/extension/llm/runner/test/test_multimodal_input.cpp @@ -0,0 +1,432 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +// @lint-ignore-every CLANGTIDY facebook-hte-Deprecated + +#include +#include + +using namespace ::testing; +using executorch::extension::llm::Image; +using executorch::extension::llm::make_image_input; +using executorch::extension::llm::make_text_input; +using executorch::extension::llm::MultimodalInput; + +class MultimodalInputTest : public Test { + protected: + std::string createTestText() { + return "Hello, world!"; + } + + std::string createTestTextLong() { + return "This is a longer test string with multiple words and punctuation."; + } + + Image createTestImage() { + Image img; + img.width = 224; + img.height = 224; + img.channels = 3; + img.data = std::vector(224 * 224 * 3, 128); // Fill with gray + return img; + } + + Image createTestImageSmall() { + Image img; + img.width = 32; + img.height = 32; + img.channels = 1; + img.data = std::vector(32 * 32, 255); // Fill with white + return img; + } +}; + +// Test text constructors +TEST_F(MultimodalInputTest, TextConstructorFromString) { + std::string text = createTestText(); + MultimodalInput input(text); + + EXPECT_TRUE(input.is_text()); + EXPECT_FALSE(input.is_image()); + EXPECT_EQ(input.get_type(), MultimodalInput::Type::TEXT); + EXPECT_EQ(input.get_text(), text); +} + +TEST_F(MultimodalInputTest, TextConstructorFromRvalueString) { + std::string text = createTestText(); + std::string original_text = text; + MultimodalInput input(std::move(text)); + + EXPECT_TRUE(input.is_text()); + EXPECT_FALSE(input.is_image()); + EXPECT_EQ(input.get_type(), MultimodalInput::Type::TEXT); + EXPECT_EQ(input.get_text(), original_text); +} + +// Test image constructors +TEST_F(MultimodalInputTest, ImageConstructorFromImage) { + Image img = createTestImage(); + MultimodalInput input(img); + + EXPECT_FALSE(input.is_text()); + EXPECT_TRUE(input.is_image()); + EXPECT_EQ(input.get_type(), MultimodalInput::Type::IMAGE); + EXPECT_EQ(input.get_image().width, 224); + EXPECT_EQ(input.get_image().height, 224); + EXPECT_EQ(input.get_image().channels, 3); + EXPECT_EQ(input.get_image().data.size(), 224 * 224 * 3); +} + +TEST_F(MultimodalInputTest, ImageConstructorFromRvalueImage) { + Image img = createTestImage(); + int width = img.width; + int height = img.height; + int channels = img.channels; + size_t data_size = img.data.size(); + + MultimodalInput input(std::move(img)); + + EXPECT_FALSE(input.is_text()); + EXPECT_TRUE(input.is_image()); + EXPECT_EQ(input.get_type(), MultimodalInput::Type::IMAGE); + EXPECT_EQ(input.get_image().width, width); + EXPECT_EQ(input.get_image().height, height); + EXPECT_EQ(input.get_image().channels, channels); + EXPECT_EQ(input.get_image().data.size(), data_size); +} + +// Test copy constructor and assignment +TEST_F(MultimodalInputTest, CopyConstructorText) { + std::string text = createTestText(); + MultimodalInput original(text); + MultimodalInput copy(original); + + EXPECT_TRUE(copy.is_text()); + EXPECT_EQ(copy.get_text(), text); + EXPECT_EQ(original.get_text(), text); // Original should be unchanged +} + +TEST_F(MultimodalInputTest, CopyAssignmentText) { + std::string text = createTestText(); + MultimodalInput original(text); + MultimodalInput copy(createTestImage()); // Start with different type + + copy = original; + + EXPECT_TRUE(copy.is_text()); + EXPECT_EQ(copy.get_text(), text); + EXPECT_EQ(original.get_text(), text); // Original should be unchanged +} + +TEST_F(MultimodalInputTest, CopyConstructorImage) { + Image img = createTestImage(); + MultimodalInput original(img); + MultimodalInput copy(original); + + EXPECT_TRUE(copy.is_image()); + EXPECT_EQ(copy.get_image().width, 224); + EXPECT_EQ(copy.get_image().height, 224); + EXPECT_EQ(copy.get_image().channels, 3); + EXPECT_EQ(original.get_image().width, 224); // Original should be unchanged +} + +TEST_F(MultimodalInputTest, CopyAssignmentImage) { + Image img = createTestImage(); + MultimodalInput original(img); + MultimodalInput copy(createTestText()); // Start with different type + + copy = original; + + EXPECT_TRUE(copy.is_image()); + EXPECT_EQ(copy.get_image().width, 224); + EXPECT_EQ(copy.get_image().height, 224); + EXPECT_EQ(copy.get_image().channels, 3); + EXPECT_EQ(original.get_image().width, 224); // Original should be unchanged +} + +// Test move constructor and assignment +TEST_F(MultimodalInputTest, MoveConstructorText) { + std::string text = createTestText(); + std::string original_text = text; + MultimodalInput original(std::move(text)); + MultimodalInput moved(std::move(original)); + + EXPECT_TRUE(moved.is_text()); + EXPECT_EQ(moved.get_text(), original_text); +} + +TEST_F(MultimodalInputTest, MoveAssignmentText) { + std::string text = createTestText(); + std::string original_text = text; + MultimodalInput original(std::move(text)); + MultimodalInput moved(createTestImage()); // Start with different type + + moved = std::move(original); + + EXPECT_TRUE(moved.is_text()); + EXPECT_EQ(moved.get_text(), original_text); +} + +TEST_F(MultimodalInputTest, MoveConstructorImage) { + Image img = createTestImage(); + int width = img.width; + int height = img.height; + int channels = img.channels; + MultimodalInput original(std::move(img)); + MultimodalInput moved(std::move(original)); + + EXPECT_TRUE(moved.is_image()); + EXPECT_EQ(moved.get_image().width, width); + EXPECT_EQ(moved.get_image().height, height); + EXPECT_EQ(moved.get_image().channels, channels); +} + +TEST_F(MultimodalInputTest, MoveAssignmentImage) { + Image img = createTestImage(); + int width = img.width; + int height = img.height; + int channels = img.channels; + MultimodalInput original(std::move(img)); + MultimodalInput moved(createTestText()); // Start with different type + + moved = std::move(original); + + EXPECT_TRUE(moved.is_image()); + EXPECT_EQ(moved.get_image().width, width); + EXPECT_EQ(moved.get_image().height, height); + EXPECT_EQ(moved.get_image().channels, channels); +} + +// Test getter methods with correct types +TEST_F(MultimodalInputTest, GetTextWithTextInput) { + std::string text = createTestText(); + MultimodalInput input(text); + + // Test const lvalue reference version + const MultimodalInput& const_input = input; + EXPECT_EQ(const_input.get_text(), text); + + // Test mutable lvalue reference version + std::string& mutable_text = input.get_text(); + mutable_text += " Modified"; + EXPECT_EQ(input.get_text(), text + " Modified"); + + // Test rvalue reference version + std::string moved_text = std::move(input).get_text(); + EXPECT_EQ(moved_text, text + " Modified"); +} + +TEST_F(MultimodalInputTest, GetImageWithImageInput) { + Image img = createTestImage(); + MultimodalInput input(img); + + // Test const lvalue reference version + const MultimodalInput& const_input = input; + EXPECT_EQ(const_input.get_image().width, 224); + + // Test mutable lvalue reference version + Image& mutable_image = input.get_image(); + mutable_image.width = 448; + EXPECT_EQ(input.get_image().width, 448); + + // Test rvalue reference version + Image moved_image = std::move(input).get_image(); + EXPECT_EQ(moved_image.width, 448); +} + +// Test getter methods with wrong types (should throw) +TEST_F(MultimodalInputTest, GetTextWithImageInputThrows) { + Image img = createTestImage(); + MultimodalInput input(img); + + EXPECT_THROW(input.get_text(), std::bad_variant_access); + EXPECT_THROW(std::move(input).get_text(), std::bad_variant_access); +} + +TEST_F(MultimodalInputTest, GetImageWithTextInputThrows) { + std::string text = createTestText(); + MultimodalInput input(text); + + EXPECT_THROW(input.get_image(), std::bad_variant_access); + EXPECT_THROW(std::move(input).get_image(), std::bad_variant_access); +} + +// Test safe getter methods (try_get_*) +TEST_F(MultimodalInputTest, TryGetTextWithTextInput) { + std::string text = createTestText(); + MultimodalInput input(text); + + // Test const version + const MultimodalInput& const_input = input; + const std::string* text_ptr = const_input.try_get_text(); + ASSERT_NE(text_ptr, nullptr); + EXPECT_EQ(*text_ptr, text); + + // Test mutable version + std::string* mutable_text_ptr = input.try_get_text(); + ASSERT_NE(mutable_text_ptr, nullptr); + EXPECT_EQ(*mutable_text_ptr, text); + + // Modify through pointer + *mutable_text_ptr += " Modified"; + EXPECT_EQ(input.get_text(), text + " Modified"); +} + +TEST_F(MultimodalInputTest, TryGetTextWithImageInput) { + Image img = createTestImage(); + MultimodalInput input(img); + + // Should return nullptr for wrong type + EXPECT_EQ(input.try_get_text(), nullptr); + + const MultimodalInput& const_input = input; + EXPECT_EQ(const_input.try_get_text(), nullptr); +} + +TEST_F(MultimodalInputTest, TryGetImageWithImageInput) { + Image img = createTestImage(); + MultimodalInput input(img); + + // Test const version + const MultimodalInput& const_input = input; + const Image* image_ptr = const_input.try_get_image(); + ASSERT_NE(image_ptr, nullptr); + EXPECT_EQ(image_ptr->width, 224); + EXPECT_EQ(image_ptr->height, 224); + EXPECT_EQ(image_ptr->channels, 3); + + // Test mutable version + Image* mutable_image_ptr = input.try_get_image(); + ASSERT_NE(mutable_image_ptr, nullptr); + EXPECT_EQ(mutable_image_ptr->width, 224); + + // Modify through pointer + mutable_image_ptr->width = 448; + EXPECT_EQ(input.get_image().width, 448); +} + +TEST_F(MultimodalInputTest, TryGetImageWithTextInput) { + std::string text = createTestText(); + MultimodalInput input(text); + + // Should return nullptr for wrong type + EXPECT_EQ(input.try_get_image(), nullptr); + + const MultimodalInput& const_input = input; + EXPECT_EQ(const_input.try_get_image(), nullptr); +} + +// Test convenience factory functions +TEST_F(MultimodalInputTest, MakeTextInputFromString) { + std::string text = createTestText(); + MultimodalInput input = make_text_input(text); + + EXPECT_TRUE(input.is_text()); + EXPECT_EQ(input.get_text(), text); +} + +TEST_F(MultimodalInputTest, MakeTextInputFromRvalueString) { + std::string text = createTestText(); + std::string original_text = text; + MultimodalInput input = make_text_input(std::move(text)); + + EXPECT_TRUE(input.is_text()); + EXPECT_EQ(input.get_text(), original_text); +} + +TEST_F(MultimodalInputTest, MakeImageInputFromImage) { + Image img = createTestImage(); + MultimodalInput input = make_image_input(img); + + EXPECT_TRUE(input.is_image()); + EXPECT_EQ(input.get_image().width, 224); + EXPECT_EQ(input.get_image().height, 224); + EXPECT_EQ(input.get_image().channels, 3); +} + +TEST_F(MultimodalInputTest, MakeImageInputFromRvalueImage) { + Image img = createTestImage(); + int width = img.width; + int height = img.height; + int channels = img.channels; + MultimodalInput input = make_image_input(std::move(img)); + + EXPECT_TRUE(input.is_image()); + EXPECT_EQ(input.get_image().width, width); + EXPECT_EQ(input.get_image().height, height); + EXPECT_EQ(input.get_image().channels, channels); +} + +// Test with different image sizes +TEST_F(MultimodalInputTest, DifferentImageSizes) { + Image small_img = createTestImageSmall(); + MultimodalInput input(small_img); + + EXPECT_TRUE(input.is_image()); + EXPECT_EQ(input.get_image().width, 32); + EXPECT_EQ(input.get_image().height, 32); + EXPECT_EQ(input.get_image().channels, 1); + EXPECT_EQ(input.get_image().data.size(), 32 * 32); +} + +// Test with empty text +TEST_F(MultimodalInputTest, EmptyText) { + std::string empty_text = ""; + MultimodalInput input(empty_text); + + EXPECT_TRUE(input.is_text()); + EXPECT_EQ(input.get_text(), ""); + EXPECT_EQ(input.get_text().size(), 0); +} + +// Test with long text +TEST_F(MultimodalInputTest, LongText) { + std::string long_text = createTestTextLong(); + MultimodalInput input(long_text); + + EXPECT_TRUE(input.is_text()); + EXPECT_EQ(input.get_text(), long_text); + EXPECT_GT(input.get_text().size(), 50); +} + +// Test type consistency +TEST_F(MultimodalInputTest, TypeConsistency) { + std::string text = createTestText(); + Image img = createTestImage(); + + MultimodalInput text_input(text); + MultimodalInput image_input(img); + + // Text input should consistently report as text + EXPECT_TRUE(text_input.is_text()); + EXPECT_FALSE(text_input.is_image()); + EXPECT_EQ(text_input.get_type(), MultimodalInput::Type::TEXT); + + // Image input should consistently report as image + EXPECT_FALSE(image_input.is_text()); + EXPECT_TRUE(image_input.is_image()); + EXPECT_EQ(image_input.get_type(), MultimodalInput::Type::IMAGE); +} + +// Test assignment between different types +TEST_F(MultimodalInputTest, AssignmentBetweenTypes) { + std::string text = createTestText(); + Image img = createTestImage(); + + MultimodalInput input(text); + EXPECT_TRUE(input.is_text()); + + // Assign image to text input + input = MultimodalInput(img); + EXPECT_TRUE(input.is_image()); + EXPECT_EQ(input.get_image().width, 224); + + // Assign text back to image input + input = MultimodalInput(text); + EXPECT_TRUE(input.is_text()); + EXPECT_EQ(input.get_text(), text); +} diff --git a/extension/llm/runner/test/test_text_llm_runner.cpp b/extension/llm/runner/test/test_text_llm_runner.cpp index b5302faebf4..4e4a4670361 100644 --- a/extension/llm/runner/test/test_text_llm_runner.cpp +++ b/extension/llm/runner/test/test_text_llm_runner.cpp @@ -195,16 +195,20 @@ TEST_F(RunnerTest, GenerateCallsCallbackExactlyMaxNewTokensTimes) { auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get()); // Set up expectations for the tokenizer encode method - EXPECT_CALL(*tokenizer, encode(_, _, _)) - .WillOnce(Return(::tokenizers::Result>( - std::vector{1, 2, 3}))); + ON_CALL(*tokenizer, encode(_, _, _)) + .WillByDefault([&](const std::string&, int8_t, int8_t) { + return ::tokenizers::Result>( + std::vector{1, 2, 3}); + }); // Set up expectations for the text prefiller - EXPECT_CALL(*text_prefiller, prefill(_, _)) - .WillOnce(Return(Result(4))); + ON_CALL(*text_prefiller, prefill(_, _)) + .WillByDefault([&](std::vector&, int64_t&) { + return (Result(4)); + }); // Set up expectations for load methods - EXPECT_CALL(*text_prefiller, is_loaded()).WillRepeatedly(Return(true)); + ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true)); std::unique_ptr stats = std::make_unique(); @@ -256,15 +260,20 @@ TEST_F(RunnerTest, WarmupCallsGenerateWithWarmingFlag) { auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get()); // Set up expectations for the tokenizer encode method - EXPECT_CALL(*tokenizer, encode(_, _, _)) - .WillOnce(Return(::tokenizers::Result>( - std::vector{1, 2, 3}))); + ON_CALL(*tokenizer, encode(_, _, _)) + .WillByDefault([&](const std::string&, int8_t, int8_t) { + return ::tokenizers::Result>( + std::vector{1, 2, 3}); + }); // Set up expectations for the text prefiller - EXPECT_CALL(*text_prefiller, prefill(_, _)) - .WillOnce(Return(Result(4))); + ON_CALL(*text_prefiller, prefill(_, _)) + .WillByDefault([&](std::vector&, int64_t&) { + return (Result(4)); + }); - EXPECT_CALL(*text_prefiller, is_loaded()).WillRepeatedly(Return(true)); + // Set up expectations for load methods + ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true)); std::unique_ptr stats = std::make_unique(); @@ -334,12 +343,14 @@ TEST_F(RunnerTest, GenerateFromPosErrorsWithNegativeMaxNewTokens) { auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get()); // Set up expectations for the tokenizer encode method - EXPECT_CALL(*tokenizer, encode(_, _, _)) - .WillOnce(Return(::tokenizers::Result>( - std::vector{1, 2, 3}))); + ON_CALL(*tokenizer, encode(_, _, _)) + .WillByDefault([&](const std::string&, int8_t, int8_t) { + return ::tokenizers::Result>( + std::vector{1, 2, 3}); + }); // Set up expectations for load methods - EXPECT_CALL(*text_prefiller, is_loaded()).WillRepeatedly(Return(true)); + ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true)); std::unique_ptr stats = std::make_unique(); diff --git a/extension/llm/runner/test/test_text_prefiller.cpp b/extension/llm/runner/test/test_text_prefiller.cpp index 2e02fc2a406..3c80f4b57af 100644 --- a/extension/llm/runner/test/test_text_prefiller.cpp +++ b/extension/llm/runner/test/test_text_prefiller.cpp @@ -286,9 +286,10 @@ TEST_F(TextPrefillerTest, PrefillChunkWorksWithParallelPrefill) { auto prefiller = createTextPrefiller(10, true, true); // Set up expectations for the text decoder runner - EXPECT_CALL(text_decoder_runner_, step(_, _)) - .Times(1) - .WillOnce(Return(Result(tensor))); + ON_CALL(text_decoder_runner_, step(_, _)) + .WillByDefault([&](executorch::extension::TensorPtr&, int64_t) { + return Result(tensor); + }); // Create prompt tokens std::vector prompt_tokens = {1, 2, 3}; diff --git a/shim_et/xplat/executorch/build/build_variables.bzl b/shim_et/xplat/executorch/build/build_variables.bzl index aa8ad0d4003..c0477e04569 100644 --- a/shim_et/xplat/executorch/build/build_variables.bzl +++ b/shim_et/xplat/executorch/build/build_variables.bzl @@ -352,6 +352,7 @@ EXTENSION_RUNNER_UTIL_SRCS = [ EXTENSION_LLM_RUNNER_SRCS = [ "extension/llm/runner/llm_runner_helper.cpp", + "extension/llm/runner/multimodal_prefiller.cpp", "extension/llm/runner/text_decoder_runner.cpp", "extension/llm/runner/text_llm_runner.cpp", "extension/llm/runner/text_prefiller.cpp", diff --git a/test/run_oss_cpp_tests.sh b/test/run_oss_cpp_tests.sh index 4b35324f22e..1648f2ba434 100755 --- a/test/run_oss_cpp_tests.sh +++ b/test/run_oss_cpp_tests.sh @@ -32,7 +32,6 @@ build_executorch() { if [ -x "$(command -v glslc)" ]; then BUILD_VULKAN="ON" fi - # -DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \ TODO(larryliu0820): Fix the name collision between Abseil and XNNPACK and turn this on. cmake . \ -DCMAKE_INSTALL_PREFIX=cmake-out \ -DEXECUTORCH_USE_CPP_CODE_COVERAGE=ON \ @@ -42,6 +41,8 @@ build_executorch() { -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_LLM=ON \ + -DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \ -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_BUILD_DEVTOOLS=ON \