Skip to content

Commit

Permalink
Simplify PlotCallback
Browse files Browse the repository at this point in the history
  • Loading branch information
Levi-Armstrong committed Jul 1, 2022
1 parent 08e709c commit 71fbdd5
Show file tree
Hide file tree
Showing 9 changed files with 25 additions and 23 deletions.
3 changes: 2 additions & 1 deletion trajopt/include/trajopt/plot_callback.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ Returns a callback function suitable for an Optimizer.
This callback will plot the trajectory (with translucent copies of the robot) as
well as all of the Cost and Constraint functions with plot methods
*/
sco::Optimizer::Callback PlotCallback(TrajOptProb& prob, const tesseract_visualization::Visualization::Ptr& plotter);
sco::Optimizer::Callback PlotCallback(const tesseract_visualization::Visualization::Ptr& plotter);

/**
* @brief Returns a callback suitable for an optimizer but does not require the problem
* @param plotter
Expand Down
15 changes: 8 additions & 7 deletions trajopt/src/plot_callback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,17 @@ void PlotCosts(const tesseract_visualization::Visualization::Ptr& plotter,
plotter->waitForInput();
}

sco::Optimizer::Callback PlotCallback(TrajOptProb& prob, const tesseract_visualization::Visualization::Ptr& plotter)
sco::Optimizer::Callback PlotCallback(const tesseract_visualization::Visualization::Ptr& plotter)
{
return [&prob, plotter](sco::OptProb*, sco::OptResults& results) {
auto state_solver = prob.GetEnv()->getStateSolver();
return [plotter](sco::OptProb* prob, sco::OptResults& results) {
auto& trajopt_prob = dynamic_cast<TrajOptProb&>(*prob);
auto state_solver = trajopt_prob.GetEnv()->getStateSolver();
PlotCosts(plotter,
*state_solver,
prob.GetKin()->getJointNames(),
std::ref(prob.getCosts()),
prob.getConstraints(),
std::ref(prob.GetVars()),
trajopt_prob.GetKin()->getJointNames(),
std::ref(trajopt_prob.getCosts()),
trajopt_prob.getConstraints(),
std::ref(trajopt_prob.GetVars()),
results);
};
}
Expand Down
2 changes: 1 addition & 1 deletion trajopt/src/problem_description.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ TrajOptResult::Ptr OptimizeProblem(const TrajOptProb::Ptr& prob,
param.improve_ratio_threshold = .2;
param.initial_merit_error_coeff = 20;
if (plotter)
opt.addCallback(PlotCallback(*prob, plotter));
opt.addCallback(PlotCallback(plotter));
opt.initialize(trajToDblVec(prob->GetInitTraj()));
opt.optimize();
return std::make_shared<TrajOptResult>(opt.results(), *prob);
Expand Down
4 changes: 2 additions & 2 deletions trajopt/test/cast_cost_attached_unit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ TEST_F(CastAttachedTest, LinkWithGeom) // NOLINT

sco::BasicTrustRegionSQP opt(prob);
if (plotting)
opt.addCallback(PlotCallback(*prob, plotter_));
opt.addCallback(PlotCallback(plotter_));
opt.initialize(trajToDblVec(prob->GetInitTraj()));
opt.optimize();

Expand Down Expand Up @@ -176,7 +176,7 @@ TEST_F(CastAttachedTest, LinkWithoutGeom) // NOLINT

sco::BasicTrustRegionSQP opt(prob);
if (plotting)
opt.addCallback(PlotCallback(*prob, plotter_));
opt.addCallback(PlotCallback(plotter_));
opt.initialize(trajToDblVec(prob->GetInitTraj()));
opt.optimize();

Expand Down
2 changes: 1 addition & 1 deletion trajopt/test/cast_cost_octomap_unit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ TEST_F(CastOctomapTest, boxes) // NOLINT

sco::BasicTrustRegionSQP opt(prob);
if (plotting)
opt.addCallback(PlotCallback(*prob, plotter_));
opt.addCallback(PlotCallback(plotter_));
opt.initialize(trajToDblVec(prob->GetInitTraj()));
opt.optimize();

Expand Down
2 changes: 1 addition & 1 deletion trajopt/test/cast_cost_unit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ TEST_F(CastTest, boxes) // NOLINT

sco::BasicTrustRegionSQP opt(prob);
if (plotting)
opt.addCallback(PlotCallback(*prob, plotter_));
opt.addCallback(PlotCallback(plotter_));
opt.initialize(trajToDblVec(prob->GetInitTraj()));
opt.optimize();

Expand Down
2 changes: 1 addition & 1 deletion trajopt/test/cast_cost_world_unit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ TEST_F(CastWorldTest, boxes) // NOLINT

sco::BasicTrustRegionSQP opt(prob);
if (plotting)
opt.addCallback(PlotCallback(*prob, plotter_));
opt.addCallback(PlotCallback(plotter_));
opt.initialize(trajToDblVec(prob->GetInitTraj()));
opt.optimize();

Expand Down
16 changes: 8 additions & 8 deletions trajopt/test/joint_costs_unit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ TEST_F(CostsTest, equality_jointPos) // NOLINT
sco::BasicTrustRegionSQP opt(prob);
if (plotting)
{
opt.addCallback(PlotCallback(*prob, plotter_));
opt.addCallback(PlotCallback(plotter_));
}

opt.initialize(trajToDblVec(prob->GetInitTraj()));
Expand Down Expand Up @@ -214,7 +214,7 @@ TEST_F(CostsTest, inequality_jointPos) // NOLINT
sco::BasicTrustRegionSQP opt(prob);
if (plotting)
{
opt.addCallback(PlotCallback(*prob, plotter_));
opt.addCallback(PlotCallback(plotter_));
}

opt.initialize(trajToDblVec(prob->GetInitTraj()));
Expand Down Expand Up @@ -309,7 +309,7 @@ TEST_F(CostsTest, equality_jointVel) // NOLINT
sco::BasicTrustRegionSQP opt(prob);
if (plotting)
{
opt.addCallback(PlotCallback(*prob, plotter_));
opt.addCallback(PlotCallback(plotter_));
}

opt.initialize(trajToDblVec(prob->GetInitTraj()));
Expand Down Expand Up @@ -416,7 +416,7 @@ TEST_F(CostsTest, inequality_jointVel) // NOLINT
sco::BasicTrustRegionSQP opt(prob);
if (plotting)
{
opt.addCallback(PlotCallback(*prob, plotter_));
opt.addCallback(PlotCallback(plotter_));
}

opt.initialize(trajToDblVec(prob->GetInitTraj()));
Expand Down Expand Up @@ -514,7 +514,7 @@ TEST_F(CostsTest, equality_jointVel_time) // NOLINT
sco::BasicTrustRegionSQP opt(prob);
if (plotting)
{
opt.addCallback(PlotCallback(*prob, plotter_));
opt.addCallback(PlotCallback(plotter_));
}

opt.initialize(trajToDblVec(prob->GetInitTraj()));
Expand Down Expand Up @@ -627,7 +627,7 @@ TEST_F(CostsTest, inequality_jointVel_time) // NOLINT
sco::BasicTrustRegionSQP opt(prob);
if (plotting)
{
opt.addCallback(PlotCallback(*prob, plotter_));
opt.addCallback(PlotCallback(plotter_));
}

opt.initialize(trajToDblVec(prob->GetInitTraj()));
Expand Down Expand Up @@ -722,7 +722,7 @@ TEST_F(CostsTest, equality_jointAcc) // NOLINT
sco::BasicTrustRegionSQP opt(prob);
if (plotting)
{
opt.addCallback(PlotCallback(*prob, plotter_));
opt.addCallback(PlotCallback(plotter_));
}

opt.initialize(trajToDblVec(prob->GetInitTraj()));
Expand Down Expand Up @@ -831,7 +831,7 @@ TEST_F(CostsTest, inequality_jointAcc) // NOLINT
sco::BasicTrustRegionSQP opt(prob);
if (plotting)
{
opt.addCallback(PlotCallback(*prob, plotter_));
opt.addCallback(PlotCallback(plotter_));
}

opt.initialize(trajToDblVec(prob->GetInitTraj()));
Expand Down
2 changes: 1 addition & 1 deletion trajopt/test/simple_collision_unit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ TEST_F(SimpleCollisionTest, spheres) // NOLINT

sco::BasicTrustRegionSQP opt(prob);
if (plotting)
opt.addCallback(PlotCallback(*prob, plotter_));
opt.addCallback(PlotCallback(plotter_));
opt.initialize(trajToDblVec(prob->GetInitTraj()));
opt.optimize();

Expand Down

0 comments on commit 71fbdd5

Please sign in to comment.