Skip to content

Commit 07fc7f8

Browse files
larryliu0820facebook-github-bot
authored andcommitted
Add a default image prefiller implementation (#13310)
Summary: As titled. I need to create an interface `IModule` for `Module` class to override, to make it test-able. Reviewed By: jackzhxng Differential Revision: D80063769
1 parent 8e208ad commit 07fc7f8

File tree

9 files changed

+501
-11
lines changed

9 files changed

+501
-11
lines changed
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/llm/runner/constants.h>
10+
#include <executorch/extension/llm/runner/text_decoder_runner.h>
11+
12+
namespace executorch::extension::llm {
13+
14+
class ET_EXPERIMENTAL MultimodalDecoderRunner
15+
: public executorch::extension::llm::TextDecoderRunner {
16+
public:
17+
explicit MultimodalDecoderRunner(Module* module, IOManager* io_manager)
18+
: TextDecoderRunner(module, io_manager) {}
19+
20+
/**
21+
* Step the LLM Decoder with the given tokens and start position.
22+
* @param tokens The tokens to the LLM.
23+
* @param start_pos The start position of the tokens.
24+
* @return The logits tensor.
25+
*/
26+
inline executorch::runtime::Result<executorch::aten::Tensor> step(
27+
executorch::extension::TensorPtr& tokens,
28+
int64_t start_pos) override {
29+
// run token embedding
30+
auto token_embedding_outputs =
31+
ET_UNWRAP(module_->execute(kTokenEmbeddingMethod, tokens));
32+
33+
// Return the logits tensor
34+
return decode(token_embedding_outputs[0], start_pos);
35+
}
36+
37+
/**
38+
* Decode the embeddings to logits.
39+
* @param embeddings The embeddings tensor.
40+
* @param start_pos The start position of the embeddings.
41+
* @return The logits tensor.
42+
*/
43+
inline executorch::runtime::Result<executorch::aten::Tensor> decode(
44+
const runtime::EValue& embeddings,
45+
int64_t start_pos) {
46+
auto start_pos_tensor = ::executorch::extension::from_blob(
47+
&start_pos, {1}, executorch::aten::ScalarType::Long);
48+
// run text model
49+
auto outputs_res = ET_UNWRAP(
50+
module_->execute(kTextModelMethod, {start_pos_tensor, embeddings}));
51+
52+
ET_CHECK_MSG(
53+
outputs_res.size() == 1,
54+
"More then one output returned from executing LLM.");
55+
ET_CHECK_MSG(
56+
outputs_res[0].isTensor(),
57+
"Non Tensor Output returned from executing LLM");
58+
59+
// Return the logits tensor
60+
return outputs_res[0].toTensor();
61+
}
62+
63+
/**
64+
* Load the Module for text decode purpose.
65+
* @return The error code.
66+
*/
67+
inline executorch::runtime::Error load() override {
68+
if (is_method_loaded()) {
69+
return executorch::runtime::Error::Ok;
70+
}
71+
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTokenEmbeddingMethod));
72+
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTextModelMethod));
73+
return executorch::runtime::Error::Ok;
74+
}
75+
76+
/**
77+
* Check if the required methods in the Module is loaded.
78+
* @return True if the Module is loaded, false otherwise.
79+
*/
80+
inline bool is_method_loaded() override {
81+
executorch::runtime::Result<std::unordered_set<std::string>> methods_res =
82+
module_->method_names();
83+
if (methods_res.error() != executorch::runtime::Error::Ok) {
84+
ET_CHECK_MSG(false, "Failed to get method names");
85+
}
86+
std::unordered_set<std::string> methods = methods_res.get();
87+
bool methods_exist = methods.find(kTokenEmbeddingMethod) != methods.end() &&
88+
methods.find(kTextModelMethod) != methods.end();
89+
if (!methods_exist) {
90+
for (const auto& method : methods) {
91+
ET_LOG(Error, "Method: %s", method.c_str());
92+
}
93+
ET_CHECK_MSG(
94+
methods_exist,
95+
"Missing required methods (%s, %s) in the model",
96+
kTokenEmbeddingMethod,
97+
kTextModelMethod);
98+
}
99+
bool methods_loaded = module_->is_method_loaded(kTokenEmbeddingMethod) &&
100+
module_->is_method_loaded(kTextModelMethod);
101+
return methods_loaded;
102+
}
103+
};
104+
105+
} // namespace executorch::extension::llm
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// A generic multimodal input class that can hold either image or text data.
10+
11+
#pragma once
12+
13+
#include <executorch/extension/llm/runner/image.h>
14+
#include <executorch/runtime/platform/compiler.h>
15+
#include <string>
16+
#include <variant>
17+
18+
namespace executorch {
19+
namespace extension {
20+
namespace llm {
21+
22+
/**
23+
* A generic class to hold either image or text data for multimodal inputs.
24+
* This allows the generate() API to take a std::vector of these objects
25+
* instead of separate image and text parameters.
26+
*/
27+
class ET_EXPERIMENTAL MultimodalInput {
28+
public:
29+
enum class Type { TEXT, IMAGE };
30+
31+
// Constructors
32+
explicit MultimodalInput(const std::string& text) : data_(text) {}
33+
explicit MultimodalInput(std::string&& text) : data_(std::move(text)) {}
34+
explicit MultimodalInput(const Image& image) : data_(image) {}
35+
explicit MultimodalInput(Image&& image) : data_(std::move(image)) {}
36+
37+
// Copy constructor and assignment
38+
MultimodalInput(const MultimodalInput& other) = default;
39+
MultimodalInput& operator=(const MultimodalInput& other) = default;
40+
41+
// Move constructor and assignment
42+
MultimodalInput(MultimodalInput&& other) noexcept = default;
43+
MultimodalInput& operator=(MultimodalInput&& other) noexcept = default;
44+
45+
// Destructor
46+
~MultimodalInput() = default;
47+
48+
/**
49+
* Check if this input contains text data.
50+
* @return true if this input contains text, false otherwise.
51+
*/
52+
bool is_text() const {
53+
return std::holds_alternative<std::string>(data_);
54+
}
55+
56+
/**
57+
* Check if this input contains image data.
58+
* @return true if this input contains an image, false otherwise.
59+
*/
60+
bool is_image() const {
61+
return std::holds_alternative<Image>(data_);
62+
}
63+
64+
/**
65+
* Get the type of data stored in this input.
66+
* @return Type::TEXT if text data, Type::IMAGE if image data.
67+
*/
68+
Type get_type() const {
69+
return is_text() ? Type::TEXT : Type::IMAGE;
70+
}
71+
72+
/**
73+
* Get the text data from this input.
74+
* @return Reference to the stored text string.
75+
* @throws std::bad_variant_access if this input doesn't contain text.
76+
*/
77+
const std::string& get_text() const {
78+
return std::get<std::string>(data_);
79+
}
80+
81+
/**
82+
* Get the text data from this input (mutable version).
83+
* @return Mutable reference to the stored text string.
84+
* @throws std::bad_variant_access if this input doesn't contain text.
85+
*/
86+
std::string& get_text() {
87+
return std::get<std::string>(data_);
88+
}
89+
90+
/**
91+
* Get the image data from this input.
92+
* @return Reference to the stored Image object.
93+
* @throws std::bad_variant_access if this input doesn't contain an image.
94+
*/
95+
const Image& get_image() const {
96+
return std::get<Image>(data_);
97+
}
98+
99+
/**
100+
* Get the image data from this input (mutable version).
101+
* @return Mutable reference to the stored Image object.
102+
* @throws std::bad_variant_access if this input doesn't contain an image.
103+
*/
104+
Image& get_image() {
105+
return std::get<Image>(data_);
106+
}
107+
108+
/**
109+
* Try to get the text data from this input safely.
110+
* @return Pointer to the text string if this input contains text, nullptr
111+
* otherwise.
112+
*/
113+
const std::string* try_get_text() const {
114+
return std::get_if<std::string>(&data_);
115+
}
116+
117+
/**
118+
* Try to get the text data from this input safely (mutable version).
119+
* @return Pointer to the text string if this input contains text, nullptr
120+
* otherwise.
121+
*/
122+
std::string* try_get_text() {
123+
return std::get_if<std::string>(&data_);
124+
}
125+
126+
/**
127+
* Try to get the image data from this input safely.
128+
* @return Pointer to the Image object if this input contains an image,
129+
* nullptr otherwise.
130+
*/
131+
const Image* try_get_image() const {
132+
return std::get_if<Image>(&data_);
133+
}
134+
135+
/**
136+
* Try to get the image data from this input safely (mutable version).
137+
* @return Pointer to the Image object if this input contains an image,
138+
* nullptr otherwise.
139+
*/
140+
Image* try_get_image() {
141+
return std::get_if<Image>(&data_);
142+
}
143+
144+
private:
145+
std::variant<std::string, Image> data_;
146+
};
147+
148+
// Convenience factory functions
149+
inline MultimodalInput make_text_input(const std::string& text) {
150+
return MultimodalInput(text);
151+
}
152+
153+
inline MultimodalInput make_text_input(std::string&& text) {
154+
return MultimodalInput(std::move(text));
155+
}
156+
157+
inline MultimodalInput make_image_input(const Image& image) {
158+
return MultimodalInput(image);
159+
}
160+
161+
inline MultimodalInput make_image_input(Image&& image) {
162+
return MultimodalInput(std::move(image));
163+
}
164+
165+
} // namespace llm
166+
} // namespace extension
167+
} // namespace executorch

0 commit comments

Comments
 (0)