9 #include "spdlog/spdlog.h"
21 const std::function<
double(
double)>& projected_objective,
double gTs,
double f_init)
const
25 double f_old = f_init;
26 double f_new = f_init;
28 for(
int num_linesearch=0; num_linesearch <
m_MaxSteps; ++num_linesearch)
30 f_new = projected_objective(step);
31 if (f_new - f_old <=
m_Eta * step * gTs) {
32 return {f_new, step, num_linesearch};
37 spdlog::warn(
"Line search failed. Final step size: {:.3}, df = {:.3}",
44 throw std::invalid_argument(
"step size must be positive");
51 throw std::invalid_argument(
"max num steps must be positive");
58 throw std::invalid_argument(
"alpha must be positive");
61 throw std::invalid_argument(
"alpha must be less than 1");
68 throw std::invalid_argument(
"eta must be positive");
71 throw std::invalid_argument(
"eta must be less than 1");
81 CHECK(searcher.get_alpha() == 0.4);
82 searcher.set_initial_step(1.8);
83 CHECK(searcher.get_initial_step() == 1.8);
84 searcher.set_max_steps(5);
85 CHECK(searcher.get_max_steps() == 5);
86 searcher.set_eta(0.8);
87 CHECK(searcher.get_eta() == 0.8);
89 CHECK_THROWS(searcher.set_alpha(-0.1));
90 CHECK_THROWS(searcher.set_alpha(0.0));
91 CHECK_THROWS(searcher.set_alpha(1.0));
92 CHECK_THROWS(searcher.set_eta(0.0));
93 CHECK_THROWS(searcher.set_eta(-.1));
94 CHECK_THROWS(searcher.set_eta(1.0));
95 CHECK_THROWS(searcher.set_initial_step(0.0));
96 CHECK_THROWS(searcher.set_initial_step(-.1));
97 CHECK_THROWS(searcher.set_max_steps(0));
98 CHECK_THROWS(searcher.set_max_steps(-1));
104 auto quad_objective = [](
double x_0,
double d) {
105 auto fun = [=](
double a) {
107 return (a + x_0) * (a + x_0);
116 SUBCASE(
"x^2 wrong direction") {
117 auto objective = quad_objective(1.0, 1.0);
119 CHECK(result.StepSize == 0.0);
120 CHECK(result.Value == 1.0);
125 SUBCASE(
"x^2 right direction") {
126 auto objective = quad_objective(1.0, -1.0);
128 CHECK(result.StepSize == 1.0);
129 CHECK(result.Value == 0.0);
134 SUBCASE(
"x^2 right direction too large") {
135 auto objective = quad_objective(1.0, -8.0);
137 CHECK(result.StepSize == 1.0/8.0);
138 CHECK(result.Value == 0.0);
void declare_hyper_parameter(std::string name, U S::*pointer)
Backtracking line search using the Armijo rule.
sLineSearchResult search(const std::function< double(double)> &projected_objective, double gTs, double f_init) const
double get_initial_step() const
void set_eta(double e)
sets the eta parameter. Throws std::invalid_argument if e is not in (0, 1)
void set_alpha(double a)
sets the alpha parameter. Throws std::invalid_argument if a is not in (0, 1)
void set_max_steps(long n)
sets the maximum number of line-search steps. Throws std::invalid_argument if n is not positive.
long get_max_steps() const
void set_initial_step(double s)
sets the initial step size. Throws std::invalid_argument if s is not positive.
TEST_CASE("test_get_set")
Result of a Line Search operation.