Skip to content
This repository was archived by the owner on Apr 24, 2025. It is now read-only.
Open
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ function(get_linux_lsb_release_information)
set(LSB_RELEASE_VERSION "${LSB_RELEASE_VERSION}" PARENT_SCOPE)
endfunction()


set(OV_VERSION_SHORT "2024.4")
set(OV_VERSION "2024.4.0.16579.c3152d32c9c_x86_64")
set(OV_STORAGE_URL "https://storage.openvinotoolkit.org/repositories/openvino/packages")
set(OV_NIGHTLY_COMMIT "2024.3.0-15502-66093834e38")
set(OV_NIGHTLY_COMMIT "2024.4.0-16039-620d2a20c8c")

if (WIN32)
if(NOT OV_LIBRARY_URL)
Expand Down
3 changes: 1 addition & 2 deletions examples/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})

FetchContent_Declare(
intel_npu_acceleration_library
GIT_REPOSITORY "https://github.com/intel/intel-npu-acceleration-library"
GIT_TAG "main"
SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../"
)
FetchContent_MakeAvailable(intel_npu_acceleration_library)

Expand Down
33 changes: 15 additions & 18 deletions examples/cpp/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using namespace intel_npu_acceleration_library;
#include <iostream>

int main() {
const size_t batch = 128, inC = 256, outC = 512, N = 100000;
const size_t batch = 128, inC = 256, outC = 512, N = 10000;

std::cout << "Create a ModelFactory" << std::endl;
auto factory = std::make_shared<ModelFactory>("NPU");
Expand All @@ -28,19 +28,19 @@ int main() {
factory->compile();

// Save OV model
std::cout << "Saving model to matmul.xml" << std::endl;
factory->saveModel("matmul.xml");
// std::cout << "Saving model to matmul.xml" << std::endl;
// factory->saveModel("matmul.xml");

// Here you can create float16 buffers and run inference by using
half_ptr input_buffer = new uint16_t[batch * inC];
half_ptr weights_buffer = new uint16_t[outC * inC];
half_ptr bias_buffer = new uint16_t[outC];
half_ptr output_buffer = new uint16_t[batch * outC];
std::cout << "Creating a remote tensor" << std::endl;
auto input_buffer = factory->createRemoteInputTensor(0);
auto weights_buffer = factory->createRemoteInputTensor(1);
auto bias_buffer = factory->createRemoteInputTensor(2);
auto output_buffer = factory->createRemoteOutputTensor(0);

memset(input_buffer, 0, batch * inC * sizeof(uint16_t));
memset(weights_buffer, 0, outC * inC * sizeof(uint16_t));
memset(output_buffer, 0, batch * outC * sizeof(uint16_t));
memset(bias_buffer, 0, outC * sizeof(uint16_t));
std::memset(input_buffer.get(), 0, input_buffer.get_byte_size());
std::memset(weights_buffer.get(), 0, weights_buffer.get_byte_size());
std::memset(bias_buffer.get(), 0, bias_buffer.get_byte_size());
std::memset(output_buffer.get(), 0, output_buffer.get_byte_size());

factory->setInputTensor(input_buffer, 0);
factory->setInputTensor(weights_buffer, 1);
Expand All @@ -49,13 +49,10 @@ int main() {

// Run inference
std::cout << "Run inference on " << N << " workloads" << std::endl;
for (auto idx = 0; idx < N; idx++)
for (auto idx = 0; idx < N; idx++) {
factory->run();
std::cout << "Inference done" << std::endl;
}

delete[] input_buffer;
delete[] weights_buffer;
delete[] bias_buffer;
delete[] output_buffer;
std::cout << "Inference done" << std::endl;
return 0;
}
7 changes: 7 additions & 0 deletions include/intel_npu_acceleration_library/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "openvino/opsets/opset7.hpp"
#include "openvino/opsets/opset8.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/runtime/intel_npu/level_zero/level_zero.hpp"
#include "openvino/runtime/intel_npu/properties.hpp"

#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__)
Expand All @@ -23,6 +24,12 @@

namespace intel_npu_acceleration_library {

/**
* @brief OpenVINO core object
*
*/
ov::Core core;

static constexpr ov::Property<std::string> npu_compiler_type{"NPU_COMPILER_TYPE"};
static constexpr ov::Property<std::string> npu_parameters{"NPU_COMPILATION_MODE_PARAMS"};

Expand Down
73 changes: 65 additions & 8 deletions include/intel_npu_acceleration_library/inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,10 @@
#include <vector>
#include "intel_npu_acceleration_library/common.h"
#include "intel_npu_acceleration_library/parameters.h"
#include "intel_npu_acceleration_library/tensor.h"

namespace intel_npu_acceleration_library {

/**
* @brief OpenVINO core object
*
*/
static ov::Core core;

/**
* @brief Create a remote tensor
*
Expand Down Expand Up @@ -95,8 +90,6 @@ class OVInferenceModel {
compiled_model = core.compile_model(model, device);
// Create inference request
infer_request = compiled_model.create_infer_request();
// First inference
infer_request.infer();
}

/**
Expand Down Expand Up @@ -126,6 +119,14 @@ class OVInferenceModel {
wt_thread.join();
}

/**
* @brief Get the remote context
*
*/
auto get_context() {
return core.get_default_context(device).as<ov::intel_npu::level_zero::ZeroContext>();
}

/**
* @brief Save the model to a local path
*
Expand Down Expand Up @@ -167,6 +168,42 @@ class OVInferenceModel {
}
}

/**
* @brief Create a Remote Tensor object
*
* @param type element type
* @param shape element shape
* @param tensor_type element tensor type: INPUT, OUTPUT, BIND
* @return auto
*/
auto createRemoteTensor(const ov::element::Type type, const ov::Shape& shape,
const ov::intel_npu::TensorType tensor_type) {
ov::intel_npu::level_zero::ZeroContext context = get_context();
return context.create_l0_host_tensor(type, shape, tensor_type);
}

/**
* @brief Create a Remote Tensor object
*
* @param idx index of the input tensor
* @return auto
*/
auto createRemoteInputTensor(size_t idx) {
auto tensor = infer_request.get_input_tensor(idx);
return createRemoteTensor(tensor.get_element_type(), tensor.get_shape(), ov::intel_npu::TensorType::INPUT);
}

/**
* @brief Create a Remote Tensor object
*
* @param idx index of the output tensor
* @return auto
*/
auto createRemoteOutputTensor(size_t idx) {
auto tensor = infer_request.get_output_tensor(idx);
return createRemoteTensor(tensor.get_element_type(), tensor.get_shape(), ov::intel_npu::TensorType::OUTPUT);
}

/**
* @brief Get model input tensor
*
Expand Down Expand Up @@ -201,6 +238,16 @@ class OVInferenceModel {
infer_request.set_input_tensor(idx, X);
}

/**
* @brief Set the input activations
*
* @param _X reference to a zero buffer tensor
* @param idx input tensor index
*/
void setInputTensor(ov::intel_npu::level_zero::ZeroBufferTensor& _X, size_t idx) {
infer_request.set_input_tensor(idx, _X);
}

/**
* @brief Set the output activations
*
Expand All @@ -213,6 +260,16 @@ class OVInferenceModel {
infer_request.set_output_tensor(idx, X);
}

/**
* @brief Set the output activations
*
* @param _X reference to a zero buffer tensor
* @param idx output tensor index
*/
void setOutputTensor(ov::intel_npu::level_zero::ZeroBufferTensor& _X, size_t idx) {
infer_request.set_output_tensor(idx, _X);
}

/**
* @brief Set the input and output activations
*
Expand Down
52 changes: 52 additions & 0 deletions include/intel_npu_acceleration_library/tensor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//
// Copyright © 2024 Intel Corporation
// SPDX-License-Identifier: Apache 2.0
//

#include "intel_npu_acceleration_library/common.h"

namespace intel_npu_acceleration_library {

/**
* @brief Class representing a NPU tensor
*
*/
class Tensor {
private:
ov::intel_npu::level_zero::ZeroBufferTensor _remote_tensor;
void* data_ptr;

public:
/**
* @brief Construct a new Tensor object
*
* @param dtype tensor datatype
* @param shape tensor shape
* @param data pointer to tensor data
* @param tensor_type tensor type. Choices between INPUT, OUTPUT, BINDED
* @param device target device for the tensor
*/
Tensor(ov::element::Type_t dtype, ov::Shape shape, void* data,
ov::intel_npu::TensorType tensor_type = ov::intel_npu::TensorType::INPUT, std::string device = "NPU") {
if (!_isNPUAvailable(core)) {
// Cannot create NPU remote tensor... use the same pointer as before
data_ptr = data;
} else {
auto context = core.get_default_context(device).as<ov::intel_npu::level_zero::ZeroContext>();
_remote_tensor = context.create_l0_host_tensor(dtype, shape, tensor_type);
data_ptr = _remote_tensor.get();
std::memcpy(data_ptr, data, _remote_tensor.get_byte_size());
}
}

/**
* @brief Get the data pointer
*
* @return void*
*/
void* data() {
return data_ptr;
}
};

} // namespace intel_npu_acceleration_library
9 changes: 9 additions & 0 deletions intel_npu_acceleration_library/backend/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,15 @@ def init_common(lib: ctypes.CDLL):

lib.compressToI4.argtypes = [c_i8_array, c_u8_array, ctypes.c_int]

# Remote tensors
lib.to_npu.argtypes = [ctypes.c_int, c_u32_array, ctypes.c_char_p, ctypes.c_void_p]
lib.to_npu.restype = handler

lib.remote_tensor_data.argtypes = [handler]
lib.remote_tensor_data.restype = ctypes.c_void_p

lib.del_remote_tensor.argtypes = [handler]


def init_network_factory(lib: ctypes.CDLL):
"""Initialize Netowrk factory bindings.
Expand Down
28 changes: 2 additions & 26 deletions intel_npu_acceleration_library/backend/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from intel_npu_acceleration_library.backend.ops import get_supported_ops
from intel_npu_acceleration_library.backend.bindings import lib as backend_lib
from intel_npu_acceleration_library.backend.tensor import Tensor
from intel_npu_acceleration_library.dtypes import int4, bfloat16
from intel_npu_acceleration_library.dtypes import get_backend_dtype
from typing import Optional, Tuple, Any, Union, Sequence, TypeVar, Callable, cast, List
from functools import partial
import numpy.typing as npt
Expand Down Expand Up @@ -115,34 +115,10 @@ def get_backend_dtype(self, dtype) -> ctypes.c_char_p:
Args:
dtype: numpy dtype

Raises:
RuntimeError: Unsupported datatype

Returns:
ctypes.c_char_p: string representation of the dtype
"""
if dtype in [np.int8, torch.int8]:
str_dtype = "int8"
elif dtype == np.uint8 or dtype == int4:
# u8 represents packed i4 dtypes
str_dtype = "int4"
elif dtype in [np.int16, torch.int16]:
str_dtype = "int16"
elif dtype in [np.int32, torch.int32]:
str_dtype = "int32"
elif dtype in [np.int64, torch.int64]:
str_dtype = "int64"
elif dtype in [np.float16, torch.float16]:
str_dtype = "float16"
elif dtype in [np.float32, torch.float32]:
str_dtype = "float32"
elif dtype in [np.float64, torch.float64]:
str_dtype = "float64"
elif dtype in [bfloat16, torch.bfloat16]:
str_dtype = "bfloat16"
else:
raise RuntimeError(f"DType is not supported {dtype}")
return ctypes.c_char_p(str_dtype.encode())
return get_backend_dtype(dtype)

@return_tensor
def parameter(
Expand Down
Loading
Loading