-
Notifications
You must be signed in to change notification settings - Fork 216
/
main.rs
35 lines (31 loc) · 1.18 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
use tract_tensorflow::prelude::*;
/// Runs a frozen MobileNet v2 TensorFlow model on `grace_hopper.jpg` and prints the
/// highest-scoring output together with its 1-based position.
///
/// Both file names are relative paths, so the model (`mobilenet_v2_1.4_224_frozen.pb`) and the
/// image are looked up in the current working directory.
///
/// # Errors
///
/// Returns an error if the model cannot be loaded, optimized or run, or if the image cannot be
/// opened or decoded.
///
/// # Panics
///
/// Panics if the model output contains a NaN score, since the scores are then not comparable.
fn main() -> TractResult<()> {
    let model = tract_tensorflow::tensorflow()
        // load the model
        .model_for_path("mobilenet_v2_1.4_224_frozen.pb")?
        // specify input type and shape
        .with_input_fact(0, f32::fact([1, 224, 224, 3]).into())?
        // optimize the model
        .into_optimized()?
        // make the model runnable and fix its inputs and outputs
        .into_runnable()?;

    // open image, resize it and make a Tensor out of it
    // A missing or corrupt image is an ordinary runtime failure: report it through the
    // returned error like every other fallible step instead of panicking.
    let image = image::open("grace_hopper.jpg")?.to_rgb8();
    let resized =
        image::imageops::resize(&image, 224, 224, ::image::imageops::FilterType::Triangle);
    // The tensor is indexed (batch, row, column, channel) while the image buffer is indexed
    // (column, row), hence the swapped `x`/`y`; each channel value is divided by 255.0.
    let image: Tensor = tract_ndarray::Array4::from_shape_fn((1, 224, 224, 3), |(_, y, x, c)| {
        resized[(x as _, y as _)][c] as f32 / 255.0
    })
    .into();

    // run the model on the input
    let result = model.run(tvec!(image.into()))?;

    // find and display the max value with its index
    // Scores are numbered starting from 1 (not 0); `best` is `None` only for an empty output.
    let best = result[0]
        .to_array_view::<f32>()?
        .iter()
        .cloned()
        .zip(1..)
        .max_by(|a, b| a.0.partial_cmp(&b.0).expect("model scores must not be NaN"));
    println!("result: {best:?}");
    Ok(())
}