Skip to content

Commit e446bcf

Browse files
committed
prototype
1 parent d8a00e6 commit e446bcf

File tree

7 files changed

+286
-136
lines changed

7 files changed

+286
-136
lines changed

examples/models/phi-3-mini/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ project(phi_3_mini_runner)
1818

1919
set(CMAKE_CXX_STANDARD 17)
2020
set(CMAKE_CXX_STANDARD_REQUIRED True)
21-
set(CMAKE_BUILD_TYPE Release)
21+
set(CMAKE_BUILD_TYPE Debug)
2222

2323
# Set options for executorch build.
2424
option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "" ON)
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import argparse
7+
8+
import torch.nn
9+
10+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
11+
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
12+
from executorch.exir import ExecutorchBackendConfig, to_edge
13+
14+
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
15+
16+
17+
class ExampleModel(torch.nn.Module):
18+
19+
def __init__(self):
20+
super().__init__()
21+
22+
def forward(
23+
self,
24+
input_token: torch.LongTensor = None,
25+
input_pos: torch.LongTensor = None,
26+
kv_cache: torch.LongTensor = None,
27+
) -> torch.LongTensor:
28+
pos = input_pos[-1].item()
29+
torch._check_is_size(pos)
30+
torch._check(pos < kv_cache.shape[1])
31+
narrowed_kv_cache = kv_cache.narrow(1, pos, 1)
32+
narrowed_kv_cache.copy_(input_token)
33+
return narrowed_kv_cache
34+
35+
36+
def main() -> None:
37+
torch.manual_seed(0)
38+
with torch.no_grad():
39+
model = ExampleModel()
40+
example_inputs = (
41+
torch.tensor([[3]], dtype=torch.long),
42+
torch.tensor([0], dtype=torch.long),
43+
torch.tensor([[1, 2]], dtype=torch.long),
44+
)
45+
dynamic_shapes = {
46+
"input_token": {
47+
0: 1,
48+
1: 1,
49+
},
50+
"input_pos": {0: 1},
51+
"kv_cache": {1: torch.export.Dim("sequence_length", min=1, max=128)},
52+
}
53+
54+
model = torch.export.export(
55+
model, example_inputs, dynamic_shapes=dynamic_shapes
56+
)
57+
edge_manager = to_edge(model, compile_config=get_xnnpack_edge_compile_config())
58+
edge_manager = edge_manager.to_backend(XnnpackPartitioner())
59+
et_program = edge_manager.to_executorch(
60+
config=ExecutorchBackendConfig(
61+
sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass()
62+
)
63+
)
64+
65+
with open("example.pte", "wb") as file:
66+
file.write(et_program.buffer)
67+
68+
69+
def main2():
70+
kv_cache = torch.zeros((1, 10), dtype=torch.long)
71+
model = ExampleModel()
72+
for i in range(10):
73+
print(
74+
model.forward(
75+
input_token=torch.tensor([[i + 1]]),
76+
input_pos=torch.tensor([i]),
77+
kv_cache=kv_cache,
78+
)
79+
)
80+
print(kv_cache)
81+
82+
83+
if __name__ == "__main__":
84+
parser = argparse.ArgumentParser()
85+
parser.add_argument(
86+
"-e",
87+
"--export",
88+
default=False,
89+
action="store_true",
90+
help="Whether or not to export",
91+
)
92+
args = parser.parse_args()
93+
if args.export:
94+
main()
95+
else:
96+
main2()
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import argparse
7+
from pprint import pprint
8+
9+
import torch.nn
10+
11+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
12+
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
13+
from executorch.exir import to_edge, ExecutorchBackendConfig
14+
15+
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
16+
17+
18+
class ExampleModel(torch.nn.Module):
19+
20+
def __init__(self):
21+
super().__init__()
22+
23+
def forward(
24+
self,
25+
x: torch.LongTensor,
26+
y: torch.LongTensor,
27+
):
28+
x.copy_(y)
29+
30+
31+
def main() -> None:
32+
torch.manual_seed(0)
33+
with torch.no_grad():
34+
model = ExampleModel()
35+
example_inputs = (
36+
torch.zeros((1, 10), dtype=torch.long),
37+
torch.ones((1, 10), dtype=torch.long)
38+
)
39+
40+
model = torch.export.export(
41+
model, example_inputs, strict=False
42+
)
43+
print(model)
44+
edge_manager = to_edge(model, compile_config=get_xnnpack_edge_compile_config())
45+
print("Graph:")
46+
print(edge_manager.exported_program().graph_module.graph)
47+
print("Graph signature:")
48+
pprint(edge_manager.exported_program().graph_signature)
49+
edge_manager = edge_manager.to_backend(XnnpackPartitioner())
50+
et_program = edge_manager.to_executorch(
51+
config=ExecutorchBackendConfig(
52+
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False)
53+
)
54+
)
55+
print("ExecuTorch program:")
56+
pprint(et_program.executorch_program)
57+
print("Graph:")
58+
print(et_program.exported_program().graph_module.graph)
59+
print("Graph signature:")
60+
pprint(et_program.exported_program().graph_signature)
61+
62+
63+
with open("example2.pte", "wb") as file:
64+
file.write(et_program.buffer)
65+
66+
67+
def main2():
68+
x = torch.zeros((1, 10), dtype=torch.long)
69+
y = torch.ones((1, 10), dtype=torch.long)
70+
model = ExampleModel()
71+
model.forward(x, y)
72+
print(f"x: {x}")
73+
print(f"y: {y}")
74+
75+
76+
if __name__ == "__main__":
77+
parser = argparse.ArgumentParser()
78+
parser.add_argument(
79+
"-e",
80+
"--export",
81+
default=True,
82+
action="store_true",
83+
help="Whether or not to export",
84+
)
85+
args = parser.parse_args()
86+
if args.export:
87+
main()
88+
else:
89+
main2()

examples/models/phi-3-mini/main.cpp

Lines changed: 4 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,45 +6,15 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <gflags/gflags.h>
10-
119
#include <executorch/examples/models/phi-3-mini/runner.h>
1210

13-
DEFINE_string(
14-
model_path,
15-
"phi-3-mini.pte",
16-
"File path for model serialized in flatbuffer format.");
17-
18-
DEFINE_string(tokenizer_path, "tokenizer.bin", "File path for tokenizer.");
19-
20-
DEFINE_string(prompt, "Tell me a story", "Prompt.");
21-
22-
DEFINE_double(
23-
temperature,
24-
0.8f,
25-
"Temperature; Default is 0.8f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic");
26-
27-
DEFINE_int32(
28-
seq_len,
29-
128,
30-
"Total number of tokens to generate (prompt + output).");
31-
3211
int main(int32_t argc, char** argv) {
33-
gflags::ParseCommandLineFlags(&argc, &argv, true);
34-
35-
const char* model_path = FLAGS_model_path.c_str();
36-
37-
const char* tokenizer_path = FLAGS_tokenizer_path.c_str();
38-
39-
const char* prompt = FLAGS_prompt.c_str();
40-
41-
double temperature = FLAGS_temperature;
42-
43-
int32_t seq_len = FLAGS_seq_len;
12+
const char* model_path =
13+
"/home/lunwenh/executorch/examples/models/phi-3-mini/example2.pte";
4414

45-
::torch::executor::Runner runner(model_path, tokenizer_path, temperature);
15+
::torch::executor::Runner runner(model_path);
4616

47-
runner.generate(prompt, seq_len);
17+
runner.test_example();
4818

4919
return 0;
5020
}

0 commit comments

Comments
 (0)