Skip to content

Commit

Permalink
Merge pull request #69 from oberbichler/feature/armijo
Browse files Browse the repository at this point in the history
Add Armijo
  • Loading branch information
oberbichler authored Jun 2, 2020
2 parents 6cc1021 + 3d95668 commit 0e60404
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 4 deletions.
75 changes: 75 additions & 0 deletions include/eqlib/Armijo.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#pragma once

#include "Define.h"
#include "Log.h"
#include "Problem.h"
#include "Settings.h"
#include "Timer.h"

namespace eqlib {

class Armijo
{
private: // types
using Type = eqlib::Armijo;

private: // members
Pointer<Problem> m_problem;
double m_c;
double m_rho;
Vector m_x_init;
Vector m_x;

public: // constructor
Armijo(Pointer<Problem> problem) : m_problem(problem), m_c(0.2), m_rho(0.9), m_x_init(problem->nb_variables()), m_x(problem->nb_variables())
{
}

public: // methods
double search(Vector search_direction, double alpha_init, bool reset)
{
double alpha = alpha_init;

m_x_init = m_problem->x();
const double f_init = m_problem->f();
const double cache = m_c * m_problem->df().dot(search_direction);

m_x = m_x_init + alpha * search_direction;

m_problem->set_x(m_x);
m_problem->compute(0);
double f = m_problem->f();

while ((f - f_init) > (alpha * cache)) {
alpha *= m_rho;
m_x = m_x_init + alpha * search_direction;

m_problem->set_x(m_x);
m_problem->compute(0);
f = m_problem->f();
}

if (reset) {
m_problem->set_f(f_init);
}

m_problem->set_x(m_x_init);

return alpha;
}

public: // python
template <typename TModule>
static void register_python(TModule& m)
{
namespace py = pybind11;
using namespace pybind11::literals;

py::class_<Type>(m, "Armijo")
.def(py::init<Pointer<eqlib::Problem>>(), "problem"_a)
// methods
.def("search", &Type::search, py::call_guard<py::gil_scoped_release>(), "search_direction"_a, "alpha_init"_a = true, "reset"_a = true);
}
}; // class Armijo

} // namespace eqlib
4 changes: 3 additions & 1 deletion include/eqlib/Problem.h
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ class Problem {

Timer timer;

m_data.set_zero();
m_data.set_zero<TOrder>();

if constexpr (TParallel) {
ProblemData l_data(m_data);
Expand Down Expand Up @@ -1306,6 +1306,8 @@ class Problem {
.def_property_readonly("values", py::overload_cast<>(&Type::values))
.def_property_readonly("equation_bounds", &Type::equation_bounds)
.def_property_readonly("variable_bounds", &Type::variable_bounds)
.def_property_readonly("nb_elements_f", &Type::nb_elements_f)
.def_property_readonly("nb_elements_g", &Type::nb_elements_g)
// properties
.def_property("linear_solver", &Type::linear_solver, &Type::set_linear_solver)
.def_property("f", &Type::f, &Type::set_f)
Expand Down
19 changes: 16 additions & 3 deletions include/eqlib/ProblemData.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,23 @@ class ProblemData {
return m_assemble_time;
}

template <index TOrder>
void set_zero()
{
m_values.setZero();
m_buffer.setZero();
if constexpr(TOrder == 0) {
m_values(0) = 0;
}
if constexpr(TOrder == 1) {
m_values.head(1 + m_m + m_n).setZero();
}
if constexpr(TOrder == 2) {
m_values.setZero();
}

if constexpr(TOrder > 0) {
m_buffer.setZero();
}

m_computation_time = 0.0;
m_assemble_time = 0.0;
}
Expand Down Expand Up @@ -91,7 +104,7 @@ class ProblemData {

m_buffer.resize(std::max(index{1}, max_element_m) * max_element_n + std::max(index{1}, max_element_m) * max_element_n * max_element_n);

set_zero();
set_zero<2>();
}

double& f() noexcept
Expand Down
4 changes: 4 additions & 0 deletions src/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>

#include <eqlib/Armijo.h>
#include <eqlib/Constraint.h>
#include <eqlib/Equation.h>
#include <eqlib/LambdaConstraint.h>
Expand Down Expand Up @@ -46,6 +47,9 @@ PYBIND11_MODULE(eqlib, m)

// --- core

// Armijo
eqlib::Armijo::register_python(m);

// Equation
eqlib::Equation::register_python(m);

Expand Down
8 changes: 8 additions & 0 deletions tests/test_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ def test_init(problem):
assert_equal(len(problem.hm_values), 5)


def test_nb_elements_f(problem):
assert_equal(problem.nb_elements_f, 2)


def test_nb_elements_g(problem):
assert_equal(problem.nb_elements_g, 0)


def test_compute(problem):
problem.compute()

Expand Down

0 comments on commit 0e60404

Please sign in to comment.