Skip to content

Commit 656ae28

Browse files
authored
Expose model's partial shape and related types (#98)
* Expose port element type and shape * Expose model's partial shape * Clean up casts
1 parent cfec608 commit 656ae28

File tree

7 files changed

+398
-3
lines changed

7 files changed

+398
-3
lines changed

crates/openvino/src/dimension.rs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
use openvino_sys::{ov_dimension_is_dynamic, ov_dimension_t};
2+
3+
/// See [`Dimension`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__dimension__c__api.html).
4+
#[derive(Copy, Clone, Debug)]
5+
#[repr(transparent)]
6+
pub struct Dimension {
7+
instance: ov_dimension_t,
8+
}
9+
10+
impl PartialEq for Dimension {
11+
fn eq(&self, other: &Self) -> bool {
12+
self.instance.min == other.instance.min && self.instance.max == other.instance.max
13+
}
14+
}
15+
16+
impl Eq for Dimension {}
17+
18+
impl Dimension {
19+
/// Get the pointer to the underlying OpenVINO dimension.
20+
#[allow(dead_code)]
21+
pub(crate) fn instance(&self) -> ov_dimension_t {
22+
self.instance
23+
}
24+
25+
/// Create a new dimension object from `ov_dimension_t`.
26+
#[allow(dead_code)]
27+
pub(crate) fn new_from_instance(instance: ov_dimension_t) -> Self {
28+
Self { instance }
29+
}
30+
31+
/// Creates a new Dimension with minimum and maximum values.
32+
pub fn new(min: i64, max: i64) -> Self {
33+
let instance = ov_dimension_t { min, max };
34+
Self { instance }
35+
}
36+
37+
/// Returns the minimum value.
38+
pub fn get_min(&self) -> i64 {
39+
self.instance.min
40+
}
41+
42+
/// Returns the maximum value.
43+
pub fn get_max(&self) -> i64 {
44+
self.instance.max
45+
}
46+
47+
/// Returns `true` if the dimension is dynamic.
48+
pub fn is_dynamic(&self) -> bool {
49+
unsafe { ov_dimension_is_dynamic(self.instance) }
50+
}
51+
}
52+
53+
#[cfg(test)]
54+
mod tests {
55+
use crate::LoadingError;
56+
57+
use super::Dimension;
58+
59+
#[test]
60+
fn test_static() {
61+
openvino_sys::library::load()
62+
.map_err(LoadingError::SystemFailure)
63+
.unwrap();
64+
65+
let dim = Dimension::new(1, 1);
66+
assert!(!dim.is_dynamic());
67+
}
68+
69+
#[test]
70+
fn test_dynamic() {
71+
openvino_sys::library::load()
72+
.map_err(LoadingError::SystemFailure)
73+
.unwrap();
74+
75+
let dim = Dimension::new(1, 2);
76+
assert!(dim.is_dynamic());
77+
}
78+
}

crates/openvino/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,29 @@
2626
)]
2727

2828
mod core;
29+
mod dimension;
2930
mod element_type;
3031
mod error;
3132
mod layout;
3233
mod model;
3334
mod node;
35+
mod partial_shape;
3436
pub mod prepostprocess;
37+
mod rank;
3538
mod request;
3639
mod shape;
3740
mod tensor;
3841
mod util;
3942

4043
pub use crate::core::Core;
44+
pub use dimension::Dimension;
4145
pub use element_type::ElementType;
4246
pub use error::{InferenceError, LoadingError, SetupError};
4347
pub use layout::Layout;
4448
pub use model::{CompiledModel, Model};
4549
pub use node::Node;
50+
pub use partial_shape::PartialShape;
51+
pub use rank::Rank;
4652
pub use request::InferRequest;
4753
pub use shape::Shape;
4854
pub use tensor::Tensor;

crates/openvino/src/model.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::{drop_using_function, try_unsafe, util::Result};
99
use openvino_sys::{
1010
ov_compiled_model_create_infer_request, ov_compiled_model_free, ov_compiled_model_t,
1111
ov_model_const_input_by_index, ov_model_const_output_by_index, ov_model_free,
12-
ov_model_inputs_size, ov_model_outputs_size, ov_model_t,
12+
ov_model_inputs_size, ov_model_is_dynamic, ov_model_outputs_size, ov_model_t,
1313
};
1414

1515
/// See [`Model`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__model__c__api.html).
@@ -78,6 +78,11 @@ impl Model {
7878
))?;
7979
Ok(Node::new(node))
8080
}
81+
82+
/// Returns `true` if the model contains dynamic shapes.
83+
pub fn is_dynamic(&self) -> bool {
84+
unsafe { ov_model_is_dynamic(self.instance) }
85+
}
8186
}
8287

8388
/// See [`CompiledModel`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__compiled__model__c__api.html).

crates/openvino/src/node.rs

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1-
use crate::{try_unsafe, util::Result};
2-
use openvino_sys::{ov_output_const_port_t, ov_port_get_any_name};
1+
use crate::{try_unsafe, util::Result, ElementType, PartialShape, Shape};
2+
use openvino_sys::{
3+
ov_const_port_get_shape, ov_output_const_port_t, ov_partial_shape_t, ov_port_get_any_name,
4+
ov_port_get_element_type, ov_port_get_partial_shape, ov_rank_t, ov_shape_t,
5+
};
6+
37
use std::ffi::CStr;
48

59
/// See [`Node`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__node__c__api.html).
@@ -25,4 +29,40 @@ impl Node {
2529
.into_owned();
2630
Ok(rust_name)
2731
}
32+
33+
/// Get the data type of elements of the port.
34+
pub fn get_element_type(&self) -> Result<u32> {
35+
let mut element_type = ElementType::Undefined as u32;
36+
try_unsafe!(ov_port_get_element_type(
37+
self.instance,
38+
std::ptr::addr_of_mut!(element_type),
39+
))?;
40+
Ok(element_type)
41+
}
42+
43+
/// Get the shape of the port.
44+
pub fn get_shape(&self) -> Result<Shape> {
45+
let mut instance = ov_shape_t {
46+
rank: 0,
47+
dims: std::ptr::null_mut(),
48+
};
49+
try_unsafe!(ov_const_port_get_shape(
50+
self.instance,
51+
std::ptr::addr_of_mut!(instance),
52+
))?;
53+
Ok(Shape::new_from_instance(instance))
54+
}
55+
56+
/// Get the partial shape of the port.
57+
pub fn get_partial_shape(&self) -> Result<PartialShape> {
58+
let mut instance = ov_partial_shape_t {
59+
rank: ov_rank_t { min: 0, max: 0 },
60+
dims: std::ptr::null_mut(),
61+
};
62+
try_unsafe!(ov_port_get_partial_shape(
63+
self.instance,
64+
std::ptr::addr_of_mut!(instance),
65+
))?;
66+
Ok(PartialShape::new_from_instance(instance))
67+
}
2868
}
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
use crate::{dimension::Dimension, try_unsafe, util::Result, Rank};
2+
use openvino_sys::{
3+
ov_dimension_t, ov_partial_shape_create, ov_partial_shape_create_dynamic,
4+
ov_partial_shape_create_static, ov_partial_shape_free, ov_partial_shape_is_dynamic,
5+
ov_partial_shape_t, ov_rank_t,
6+
};
7+
8+
use std::convert::TryInto;
9+
10+
/// See [`PartialShape`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__partial__shape__c__api.html).
11+
pub struct PartialShape {
12+
instance: ov_partial_shape_t,
13+
}
14+
15+
impl Drop for PartialShape {
16+
/// Drops the `PartialShape` instance and frees the associated memory.
17+
fn drop(&mut self) {
18+
unsafe { ov_partial_shape_free(std::ptr::addr_of_mut!(self.instance)) }
19+
}
20+
}
21+
22+
impl PartialShape {
23+
/// Get the pointer to the underlying OpenVINO partial shape.
24+
#[allow(dead_code)]
25+
pub(crate) fn instance(&self) -> ov_partial_shape_t {
26+
self.instance
27+
}
28+
29+
/// Create a new partial shape object from `ov_partial_shape_t`.
30+
pub(crate) fn new_from_instance(instance: ov_partial_shape_t) -> Self {
31+
Self { instance }
32+
}
33+
34+
/// Creates a new `PartialShape` instance with a static rank and dynamic dimensions.
35+
pub fn new(rank: i64, dimensions: &[Dimension]) -> Result<Self> {
36+
let mut partial_shape = ov_partial_shape_t {
37+
rank: ov_rank_t { min: 0, max: 0 },
38+
dims: std::ptr::null_mut(),
39+
};
40+
try_unsafe!(ov_partial_shape_create(
41+
rank,
42+
dimensions.as_ptr().cast::<ov_dimension_t>(),
43+
std::ptr::addr_of_mut!(partial_shape)
44+
))?;
45+
Ok(Self {
46+
instance: partial_shape,
47+
})
48+
}
49+
50+
/// Creates a new `PartialShape` instance with a dynamic rank and dynamic dimensions.
51+
pub fn new_dynamic(rank: Rank, dimensions: &[Dimension]) -> Result<Self> {
52+
let mut partial_shape = ov_partial_shape_t {
53+
rank: ov_rank_t { min: 0, max: 0 },
54+
dims: std::ptr::null_mut(),
55+
};
56+
try_unsafe!(ov_partial_shape_create_dynamic(
57+
rank.instance(),
58+
dimensions.as_ptr().cast::<ov_dimension_t>(),
59+
std::ptr::addr_of_mut!(partial_shape)
60+
))?;
61+
Ok(Self {
62+
instance: partial_shape,
63+
})
64+
}
65+
66+
/// Creates a new `PartialShape` instance with a static rank and static dimensions.
67+
pub fn new_static(rank: i64, dimensions: &[i64]) -> Result<Self> {
68+
let mut partial_shape = ov_partial_shape_t {
69+
rank: ov_rank_t { min: 0, max: 0 },
70+
dims: std::ptr::null_mut(),
71+
};
72+
try_unsafe!(ov_partial_shape_create_static(
73+
rank,
74+
dimensions.as_ptr(),
75+
std::ptr::addr_of_mut!(partial_shape)
76+
))?;
77+
Ok(Self {
78+
instance: partial_shape,
79+
})
80+
}
81+
82+
/// Returns the rank of the partial shape.
83+
pub fn get_rank(&self) -> Rank {
84+
let rank = self.instance.rank;
85+
Rank::new_from_instance(rank)
86+
}
87+
88+
/// Returns the dimensions of the partial shape.
89+
pub fn get_dimensions(&self) -> &[Dimension] {
90+
if self.instance.dims.is_null() {
91+
&[]
92+
} else {
93+
unsafe {
94+
std::slice::from_raw_parts(
95+
self.instance.dims.cast::<Dimension>(),
96+
self.instance.rank.max.try_into().unwrap(),
97+
)
98+
}
99+
}
100+
}
101+
102+
/// Returns `true` if the partial shape is dynamic.
103+
pub fn is_dynamic(&self) -> bool {
104+
unsafe { ov_partial_shape_is_dynamic(self.instance) }
105+
}
106+
}
107+
108+
#[cfg(test)]
109+
mod tests {
110+
use crate::LoadingError;
111+
112+
use super::*;
113+
114+
#[test]
115+
fn test_new_partial_shape() {
116+
openvino_sys::library::load()
117+
.map_err(LoadingError::SystemFailure)
118+
.unwrap();
119+
120+
let dimensions = vec![
121+
Dimension::new(0, 1),
122+
Dimension::new(1, 2),
123+
Dimension::new(2, 3),
124+
Dimension::new(3, 4),
125+
];
126+
127+
let shape = PartialShape::new(4, &dimensions).unwrap();
128+
assert_eq!(shape.get_rank().get_min(), 4);
129+
assert_eq!(shape.get_rank().get_max(), 4);
130+
assert!(shape.is_dynamic());
131+
}
132+
133+
#[test]
134+
fn test_new_dynamic_partial_shape() {
135+
openvino_sys::library::load()
136+
.map_err(LoadingError::SystemFailure)
137+
.unwrap();
138+
139+
let dimensions = vec![Dimension::new(1, 1), Dimension::new(2, 2)];
140+
141+
let shape = PartialShape::new_dynamic(Rank::new(0, 2), &dimensions).unwrap();
142+
assert!(shape.is_dynamic());
143+
}
144+
145+
#[test]
146+
fn test_new_static_partial_shape() {
147+
openvino_sys::library::load()
148+
.map_err(LoadingError::SystemFailure)
149+
.unwrap();
150+
151+
let dimensions = vec![1, 2];
152+
153+
let shape = PartialShape::new_static(2, &dimensions).unwrap();
154+
assert!(!shape.is_dynamic());
155+
}
156+
157+
#[test]
158+
fn test_get_dimensions() {
159+
openvino_sys::library::load()
160+
.map_err(LoadingError::SystemFailure)
161+
.unwrap();
162+
163+
let dimensions = vec![
164+
Dimension::new(0, 1),
165+
Dimension::new(1, 2),
166+
Dimension::new(2, 3),
167+
Dimension::new(3, 4),
168+
];
169+
170+
let shape = PartialShape::new(4, &dimensions).unwrap();
171+
172+
let dims = shape.get_dimensions();
173+
174+
assert_eq!(dims, &dimensions);
175+
}
176+
}

0 commit comments

Comments
 (0)