diff --git a/cli/src/errors.rs b/cli/src/errors.rs index 9af1b5b8be..055b51271e 100644 --- a/cli/src/errors.rs +++ b/cli/src/errors.rs @@ -1,6 +1,5 @@ #![allow(deprecated)] -use tract_core::prelude::*; use tract_core::ndarray; use crate::model::Model; @@ -25,7 +24,7 @@ error_chain! { } errors { - ModelBuilding(partial: Box, inner: TractError) { + ModelBuilding(partial: Box) { } } } diff --git a/cli/src/main.rs b/cli/src/main.rs index decd4b0b8c..a9a37099bf 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -434,23 +434,27 @@ fn handle(matches: clap::ArgMatches, probe: Option<&Probe>) -> CliResult<()> { return Ok(()); } - let params = match Parameters::from_clap(&matches, probe) { + let builder_result = Parameters::from_clap(&matches, probe); + let params = match builder_result { Ok(params) => params, - Err(CliError(CliErrorKind::ModelBuilding(mut broken_model, e), _)) => { - let annotations = crate::annotations::Annotations::from_model(&*broken_model)?; - let display_params = if let ("dump", Some(sm)) = matches.subcommand() { - display_params_from_clap(&matches, &sm)? - } else { - crate::display_params::DisplayParams::default() - }; - - if broken_model.output_outlets().len() == 0 { - broken_model.auto_outputs()?; + Err(e) => { + if let CliError(CliErrorKind::ModelBuilding(ref broken_model), _) = e { + let mut broken_model:Box = tract_hir::tract_core::dyn_clone::clone(broken_model); + let annotations = + crate::annotations::Annotations::from_model(broken_model.as_ref())?; + let display_params = if let ("dump", Some(sm)) = matches.subcommand() { + display_params_from_clap(&matches, &sm)? + } else { + crate::display_params::DisplayParams::default() + }; + + if broken_model.output_outlets().len() == 0 { + broken_model.auto_outputs()?; + } + terminal::render(broken_model.as_ref(), &annotations, &display_params)?; } - terminal::render(&*broken_model, &annotations, &display_params)?; Err(e)? } - Err(e) => Err(e)?, }; let mut need_optimisations = false; diff --git a/cli/src/model.rs b/cli/src/model.rs index 385a8758e9..c291838131 100644 --- a/cli/src/model.rs +++ b/cli/src/model.rs @@ -91,6 +91,7 @@ pub trait Model: downcast_rs::Downcast + std::fmt::Debug + dyn_clone::DynClone + } downcast_rs::impl_downcast!(Model); +dyn_clone::clone_trait_object!(Model); impl Model for Graph where diff --git a/cli/src/params.rs b/cli/src/params.rs index 139c224b54..fed35bcb27 100644 --- a/cli/src/params.rs +++ b/cli/src/params.rs @@ -3,7 +3,6 @@ use std::str::FromStr; use tract_itertools::Itertools; use tract_core::internal::*; -use tract_onnx::prelude::*; use tract_core::model::TypedModel; use tract_hir::internal::*; #[cfg(feature = "tf")] @@ -126,19 +125,17 @@ impl Parameters { if need_graph { ( SomeGraphDef::Nnef(proto_model.clone()), - Box::new( - nnef.translate(&proto_model) - .map_err(|e| CliErrorKind::ModelBuilding(Box::new(e.0), e.1))?, - ), + Box::new(nnef.translate(&proto_model).map_err(|(g, e)| { + CliError::from(e).chain_err(|| CliErrorKind::ModelBuilding(Box::new(g))) + })?), Option::::None, ) } else { ( SomeGraphDef::NoGraphDef, - Box::new( - nnef.translate(&proto_model) - .map_err(|e| CliErrorKind::ModelBuilding(Box::new(e.0), e.1))?, - ), + Box::new(nnef.translate(&proto_model).map_err(|(g, e)| { + CliError::from(e).chain_err(|| CliErrorKind::ModelBuilding(Box::new(g))) + })?), Option::::None, ) } @@ -453,7 +450,7 @@ impl Parameters { info!(concat!("Running '", $name, "'")); let mut last_model: Option> = if keep_last { Some(Box::new(from.as_ref().clone())) } else { None }; - let block: &dyn Fn(_) -> TractResult<_> = &$block; + let block: &dyn Fn(_) -> CliResult<_> = &$block; let owned_model = Arc::try_unwrap(from).unwrap_or_else(|from| from.as_ref().clone()); match block(owned_model) { @@ -462,11 +459,10 @@ impl Parameters { } Err(e) => { if let Some(last_model) = last_model.take() { - return Err( - CliErrorKind::ModelBuilding(last_model, e.into()).into() - ); + return Err(CliError::from(e) + .chain_err(|| CliErrorKind::ModelBuilding(last_model))); } else { - Err(e)? + return Err(e); } } } @@ -487,33 +483,39 @@ impl Parameters { }; stage!("analyse", inference_model -> inference_model, - |mut m:InferenceModel| { m.analyse(matches.is_present("analyse_fail_fast"))?; TractResult::Ok(m) }); + |mut m:InferenceModel| { + let result = m.analyse(matches.is_present("analyse_fail_fast")); + match result { + Ok(_) => Ok(m), + Err(e) => Err( + CliError::from(e).chain_err(|| CliErrorKind::ModelBuilding(Box::new(m)))) + }}); if let Some(ext) = tf_model_extensions { #[cfg(feature = "tf")] - stage!("tf-preproc", inference_model -> inference_model, |m:InferenceModel| ext.preproc(m)); + stage!("tf-preproc", inference_model -> inference_model, |m:InferenceModel| Ok(ext.preproc(m)?)); } - stage!("incorporate", inference_model -> inference_model, |m:InferenceModel| { m.incorporate()}); - stage!("type", inference_model -> typed_model, |m:InferenceModel| m.into_typed()); - stage!("declutter", typed_model -> typed_model, |m:TypedModel| m.declutter()); + stage!("incorporate", inference_model -> inference_model, |m:InferenceModel| { Ok(m.incorporate()?)}); + stage!("type", inference_model -> typed_model, |m:InferenceModel| Ok(m.into_typed()?)); + stage!("declutter", typed_model -> typed_model, |m:TypedModel| Ok(m.declutter()?)); if let Some(dim) = concretize_stream_dim { - stage!("concretize-stream-dim", typed_model -> typed_model, |m:TypedModel| m.concretize_stream_dim(dim) ); - stage!("concretize-stream-dim-declutter", typed_model -> typed_model, |m:TypedModel| m.declutter()); + stage!("concretize-stream-dim", typed_model -> typed_model, |m:TypedModel| Ok(m.concretize_stream_dim(dim)?) ); + stage!("concretize-stream-dim-declutter", typed_model -> typed_model, |m:TypedModel| Ok(m.declutter()?)); } else if let Some(pulse) = pulse { - stage!("pulse", typed_model -> pulsed_model, |m:TypedModel| ::tract_core::pulse::PulsedModel::new(&m, pulse)); - stage!("pulse-to-type", pulsed_model -> typed_model, |m:PulsedModel| m.into_typed()); - stage!("pulse-declutter", typed_model -> typed_model, |m:TypedModel| m.declutter()); + stage!("pulse", typed_model -> pulsed_model, |m:TypedModel| Ok(::tract_core::pulse::PulsedModel::new(&m, pulse)?)); + stage!("pulse-to-type", pulsed_model -> typed_model, |m:PulsedModel| Ok(m.into_typed()?)); + stage!("pulse-declutter", typed_model -> typed_model, |m:TypedModel| Ok(m.declutter()?)); } if nnef_cycle { stage!("nnef-cycle", typed_model -> typed_model, |m:TypedModel| { let nnef = super::nnef(&matches); let mut vec = vec!(); nnef.write(&m, &mut vec)?; - nnef.model_for_read(&mut &*vec) + Ok(nnef.model_for_read(&mut &*vec)?) }); - stage!("nnef-declutter", typed_model -> typed_model, |m:TypedModel| m.declutter()); + stage!("nnef-declutter", typed_model -> typed_model, |m:TypedModel| Ok(m.declutter()?)); } stage!("before-optimize", typed_model -> typed_model, |m:TypedModel| Ok(m)); - stage!("optimize", typed_model -> typed_model, |m:TypedModel| m.optimize()); + stage!("optimize", typed_model -> typed_model, |m:TypedModel| Ok(m.optimize()?)); Ok((typed_model.clone().unwrap(), typed_model, pulsed_model, reference_model)) } diff --git a/hir/src/ops/scan.rs b/hir/src/ops/scan.rs index fb05c811b9..6aed453dc4 100644 --- a/hir/src/ops/scan.rs +++ b/hir/src/ops/scan.rs @@ -300,7 +300,7 @@ impl InferenceOp for InferenceScan { if self .body .analyse(false) - .map_err(|e| format!("analysing inner model: {}\n{:#?}", e, self.body))? + .map_err(|e| format!("analysing inner model: {}", e))? { changed = true; }