From 71fbdd53e72da8cc48996e614263a68978c25c4b Mon Sep 17 00:00:00 2001 From: Levi Armstrong Date: Tue, 28 Jun 2022 23:07:49 -0500 Subject: [PATCH] Simplify PlotCallback --- trajopt/include/trajopt/plot_callback.hpp | 3 ++- trajopt/src/plot_callback.cpp | 15 ++++++++------- trajopt/src/problem_description.cpp | 2 +- trajopt/test/cast_cost_attached_unit.cpp | 4 ++-- trajopt/test/cast_cost_octomap_unit.cpp | 2 +- trajopt/test/cast_cost_unit.cpp | 2 +- trajopt/test/cast_cost_world_unit.cpp | 2 +- trajopt/test/joint_costs_unit.cpp | 16 ++++++++-------- trajopt/test/simple_collision_unit.cpp | 2 +- 9 files changed, 25 insertions(+), 23 deletions(-) diff --git a/trajopt/include/trajopt/plot_callback.hpp b/trajopt/include/trajopt/plot_callback.hpp index db695da2..4a12bd8f 100644 --- a/trajopt/include/trajopt/plot_callback.hpp +++ b/trajopt/include/trajopt/plot_callback.hpp @@ -12,7 +12,8 @@ Returns a callback function suitable for an Optimizer. This callback will plot the trajectory (with translucent copies of the robot) as well as all of the Cost and Constraint functions with plot methods */ -sco::Optimizer::Callback PlotCallback(TrajOptProb& prob, const tesseract_visualization::Visualization::Ptr& plotter); +sco::Optimizer::Callback PlotCallback(const tesseract_visualization::Visualization::Ptr& plotter); + /** * @brief Returns a callback suitable for an optimizer but does not require the problem * @param plotter diff --git a/trajopt/src/plot_callback.cpp b/trajopt/src/plot_callback.cpp index b0e4d167..61ad47cd 100644 --- a/trajopt/src/plot_callback.cpp +++ b/trajopt/src/plot_callback.cpp @@ -47,16 +47,17 @@ void PlotCosts(const tesseract_visualization::Visualization::Ptr& plotter, plotter->waitForInput(); } -sco::Optimizer::Callback PlotCallback(TrajOptProb& prob, const tesseract_visualization::Visualization::Ptr& plotter) +sco::Optimizer::Callback PlotCallback(const tesseract_visualization::Visualization::Ptr& plotter) { - return [&prob, plotter](sco::OptProb*, sco::OptResults& results) { - auto state_solver = prob.GetEnv()->getStateSolver(); + return [plotter](sco::OptProb* prob, sco::OptResults& results) { + auto& trajopt_prob = dynamic_cast(*prob); + auto state_solver = trajopt_prob.GetEnv()->getStateSolver(); PlotCosts(plotter, *state_solver, - prob.GetKin()->getJointNames(), - std::ref(prob.getCosts()), - prob.getConstraints(), - std::ref(prob.GetVars()), + trajopt_prob.GetKin()->getJointNames(), + std::ref(trajopt_prob.getCosts()), + trajopt_prob.getConstraints(), + std::ref(trajopt_prob.GetVars()), results); }; } diff --git a/trajopt/src/problem_description.cpp b/trajopt/src/problem_description.cpp index 9426bcee..e6e1b022 100644 --- a/trajopt/src/problem_description.cpp +++ b/trajopt/src/problem_description.cpp @@ -394,7 +394,7 @@ TrajOptResult::Ptr OptimizeProblem(const TrajOptProb::Ptr& prob, param.improve_ratio_threshold = .2; param.initial_merit_error_coeff = 20; if (plotter) - opt.addCallback(PlotCallback(*prob, plotter)); + opt.addCallback(PlotCallback(plotter)); opt.initialize(trajToDblVec(prob->GetInitTraj())); opt.optimize(); return std::make_shared(opt.results(), *prob); diff --git a/trajopt/test/cast_cost_attached_unit.cpp b/trajopt/test/cast_cost_attached_unit.cpp index dfa53b6c..66dbcd99 100644 --- a/trajopt/test/cast_cost_attached_unit.cpp +++ b/trajopt/test/cast_cost_attached_unit.cpp @@ -126,7 +126,7 @@ TEST_F(CastAttachedTest, LinkWithGeom) // NOLINT sco::BasicTrustRegionSQP opt(prob); if (plotting) - opt.addCallback(PlotCallback(*prob, plotter_)); + opt.addCallback(PlotCallback(plotter_)); opt.initialize(trajToDblVec(prob->GetInitTraj())); opt.optimize(); @@ -176,7 +176,7 @@ TEST_F(CastAttachedTest, LinkWithoutGeom) // NOLINT sco::BasicTrustRegionSQP opt(prob); if (plotting) - opt.addCallback(PlotCallback(*prob, plotter_)); + opt.addCallback(PlotCallback(plotter_)); opt.initialize(trajToDblVec(prob->GetInitTraj())); opt.optimize(); diff --git a/trajopt/test/cast_cost_octomap_unit.cpp b/trajopt/test/cast_cost_octomap_unit.cpp index 62db2d03..7b830a2e 100644 --- a/trajopt/test/cast_cost_octomap_unit.cpp +++ b/trajopt/test/cast_cost_octomap_unit.cpp @@ -125,7 +125,7 @@ TEST_F(CastOctomapTest, boxes) // NOLINT sco::BasicTrustRegionSQP opt(prob); if (plotting) - opt.addCallback(PlotCallback(*prob, plotter_)); + opt.addCallback(PlotCallback(plotter_)); opt.initialize(trajToDblVec(prob->GetInitTraj())); opt.optimize(); diff --git a/trajopt/test/cast_cost_unit.cpp b/trajopt/test/cast_cost_unit.cpp index 5ad93c74..51f85ce6 100644 --- a/trajopt/test/cast_cost_unit.cpp +++ b/trajopt/test/cast_cost_unit.cpp @@ -86,7 +86,7 @@ TEST_F(CastTest, boxes) // NOLINT sco::BasicTrustRegionSQP opt(prob); if (plotting) - opt.addCallback(PlotCallback(*prob, plotter_)); + opt.addCallback(PlotCallback(plotter_)); opt.initialize(trajToDblVec(prob->GetInitTraj())); opt.optimize(); diff --git a/trajopt/test/cast_cost_world_unit.cpp b/trajopt/test/cast_cost_world_unit.cpp index c240e604..142c15ce 100644 --- a/trajopt/test/cast_cost_world_unit.cpp +++ b/trajopt/test/cast_cost_world_unit.cpp @@ -109,7 +109,7 @@ TEST_F(CastWorldTest, boxes) // NOLINT sco::BasicTrustRegionSQP opt(prob); if (plotting) - opt.addCallback(PlotCallback(*prob, plotter_)); + opt.addCallback(PlotCallback(plotter_)); opt.initialize(trajToDblVec(prob->GetInitTraj())); opt.optimize(); diff --git a/trajopt/test/joint_costs_unit.cpp b/trajopt/test/joint_costs_unit.cpp index 2390936a..fd5a8350 100644 --- a/trajopt/test/joint_costs_unit.cpp +++ b/trajopt/test/joint_costs_unit.cpp @@ -107,7 +107,7 @@ TEST_F(CostsTest, equality_jointPos) // NOLINT sco::BasicTrustRegionSQP opt(prob); if (plotting) { - opt.addCallback(PlotCallback(*prob, plotter_)); + opt.addCallback(PlotCallback(plotter_)); } opt.initialize(trajToDblVec(prob->GetInitTraj())); @@ -214,7 +214,7 @@ TEST_F(CostsTest, inequality_jointPos) // NOLINT sco::BasicTrustRegionSQP opt(prob); if (plotting) { - opt.addCallback(PlotCallback(*prob, plotter_)); + opt.addCallback(PlotCallback(plotter_)); } opt.initialize(trajToDblVec(prob->GetInitTraj())); @@ -309,7 +309,7 @@ TEST_F(CostsTest, equality_jointVel) // NOLINT sco::BasicTrustRegionSQP opt(prob); if (plotting) { - opt.addCallback(PlotCallback(*prob, plotter_)); + opt.addCallback(PlotCallback(plotter_)); } opt.initialize(trajToDblVec(prob->GetInitTraj())); @@ -416,7 +416,7 @@ TEST_F(CostsTest, inequality_jointVel) // NOLINT sco::BasicTrustRegionSQP opt(prob); if (plotting) { - opt.addCallback(PlotCallback(*prob, plotter_)); + opt.addCallback(PlotCallback(plotter_)); } opt.initialize(trajToDblVec(prob->GetInitTraj())); @@ -514,7 +514,7 @@ TEST_F(CostsTest, equality_jointVel_time) // NOLINT sco::BasicTrustRegionSQP opt(prob); if (plotting) { - opt.addCallback(PlotCallback(*prob, plotter_)); + opt.addCallback(PlotCallback(plotter_)); } opt.initialize(trajToDblVec(prob->GetInitTraj())); @@ -627,7 +627,7 @@ TEST_F(CostsTest, inequality_jointVel_time) // NOLINT sco::BasicTrustRegionSQP opt(prob); if (plotting) { - opt.addCallback(PlotCallback(*prob, plotter_)); + opt.addCallback(PlotCallback(plotter_)); } opt.initialize(trajToDblVec(prob->GetInitTraj())); @@ -722,7 +722,7 @@ TEST_F(CostsTest, equality_jointAcc) // NOLINT sco::BasicTrustRegionSQP opt(prob); if (plotting) { - opt.addCallback(PlotCallback(*prob, plotter_)); + opt.addCallback(PlotCallback(plotter_)); } opt.initialize(trajToDblVec(prob->GetInitTraj())); @@ -831,7 +831,7 @@ TEST_F(CostsTest, inequality_jointAcc) // NOLINT sco::BasicTrustRegionSQP opt(prob); if (plotting) { - opt.addCallback(PlotCallback(*prob, plotter_)); + opt.addCallback(PlotCallback(plotter_)); } opt.initialize(trajToDblVec(prob->GetInitTraj())); diff --git a/trajopt/test/simple_collision_unit.cpp b/trajopt/test/simple_collision_unit.cpp index 739745c4..708239ac 100644 --- a/trajopt/test/simple_collision_unit.cpp +++ b/trajopt/test/simple_collision_unit.cpp @@ -86,7 +86,7 @@ TEST_F(SimpleCollisionTest, spheres) // NOLINT sco::BasicTrustRegionSQP opt(prob); if (plotting) - opt.addCallback(PlotCallback(*prob, plotter_)); + opt.addCallback(PlotCallback(plotter_)); opt.initialize(trajToDblVec(prob->GetInitTraj())); opt.optimize();