DiSMEC++
line_search.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #include "line_search.h"
7 
8 #include "utils/hash_vector.h"
9 #include "spdlog/spdlog.h"
10 
11 using namespace dismec::solvers;
12 
18 }
19 
21  const std::function<double(double)>& projected_objective, double gTs, double f_init) const
22 {
23  double step = m_StepSize;
24 
25  double f_old = f_init;
26  double f_new = f_init;
27 
28  for(int num_linesearch=0; num_linesearch < m_MaxSteps; ++num_linesearch)
29  {
30  f_new = projected_objective(step);
31  if (f_new - f_old <= m_Eta * step * gTs) {
32  return {f_new, step, num_linesearch};
33  }
34  step *= m_Alpha;
35  }
36 
37  spdlog::warn("Line search failed. Final step size: {:.3}, df = {:.3}",
38  step, f_old - f_new);
39  return {f_old, 0.0, (int)m_MaxSteps};
40 }
41 
43  if(s <= 0) {
44  throw std::invalid_argument("step size must be positive");
45  }
46  m_StepSize = s;
47 }
48 
50  if(n <= 0) {
51  throw std::invalid_argument("max num steps must be positive");
52  }
53  m_MaxSteps = n;
54 }
55 
57  if(a <= 0) {
58  throw std::invalid_argument("alpha must be positive");
59  }
60  if (a >= 1) {
61  throw std::invalid_argument("alpha must be less than 1");
62  }
63  m_Alpha = a;
64 }
65 
67  if(e <= 0) {
68  throw std::invalid_argument("eta must be positive");
69  }
70  if (e >= 1) {
71  throw std::invalid_argument("eta must be less than 1");
72  }
73  m_Eta = e;
74 }
75 
76 #include "doctest.h"
77 
78 TEST_CASE("test_get_set") {
79  BacktrackingLineSearch searcher{};
80  searcher.set_alpha(0.4);
81  CHECK(searcher.get_alpha() == 0.4);
82  searcher.set_initial_step(1.8);
83  CHECK(searcher.get_initial_step() == 1.8);
84  searcher.set_max_steps(5);
85  CHECK(searcher.get_max_steps() == 5);
86  searcher.set_eta(0.8);
87  CHECK(searcher.get_eta() == 0.8);
88 
89  CHECK_THROWS(searcher.set_alpha(-0.1));
90  CHECK_THROWS(searcher.set_alpha(0.0));
91  CHECK_THROWS(searcher.set_alpha(1.0));
92  CHECK_THROWS(searcher.set_eta(0.0));
93  CHECK_THROWS(searcher.set_eta(-.1));
94  CHECK_THROWS(searcher.set_eta(1.0));
95  CHECK_THROWS(searcher.set_initial_step(0.0));
96  CHECK_THROWS(searcher.set_initial_step(-.1));
97  CHECK_THROWS(searcher.set_max_steps(0));
98  CHECK_THROWS(searcher.set_max_steps(-1));
99 }
100 
101 
102 TEST_CASE("backtracking line search") {
103 
104  auto quad_objective = [](double x_0, double d) {
105  auto fun = [=](double a) {
106  a *= d;
107  return (a + x_0) * (a + x_0);
108  };
109 
110  return fun;
111  };
112 
113  BacktrackingLineSearch searcher{};
114 
115  // looking in the wrong direction. Search fails
116  SUBCASE("x^2 wrong direction") {
117  auto objective = quad_objective(1.0, 1.0);
118  auto result = searcher.search(objective, 2.0, objective(0));
119  CHECK(result.StepSize == 0.0);
120  CHECK(result.Value == 1.0);
121  }
122 
123  // looking in the right direction -- first step fulfills our condition
124  // note that inverting the direction also means inverting the gradient g := <df, s>
125  SUBCASE("x^2 right direction") {
126  auto objective = quad_objective(1.0, -1.0);
127  auto result = searcher.search(objective, -2.0, objective(0));
128  CHECK(result.StepSize == 1.0);
129  CHECK(result.Value == 0.0);
130  }
131 
132 
133  // looking in the right direction -- we have to backtrack until we reach the minimum
134  SUBCASE("x^2 right direction too large") {
135  auto objective = quad_objective(1.0, -8.0);
136  auto result = searcher.search(objective, -2.0*8, objective(0));
137  CHECK(result.StepSize == 1.0/8.0);
138  CHECK(result.Value == 0.0);
139  }
140 }
void declare_hyper_parameter(std::string name, U S::*pointer)
Definition: hyperparams.h:117
Backtracking line search using the Armijo rule.
Definition: line_search.h:35
sLineSearchResult search(const std::function< double(double)> &projected_objective, double gTs, double f_init) const
Definition: line_search.cpp:20
void set_eta(double e)
sets the eta parameter. Throws std::invalid_argument if e is not in (0, 1)
Definition: line_search.cpp:66
void set_alpha(double a)
sets the alpha parameter. Throws std::invalid_argument if a is not in (0, 1)
Definition: line_search.cpp:56
void set_max_steps(long n)
sets the maximum number of line search steps. Throws std::invalid_argument if n is not positive
Definition: line_search.cpp:49
void set_initial_step(double s)
sets the initial step multiplier (the first step size tried by the search). Throws std::invalid_argument if s is not positive.
Definition: line_search.cpp:42
TEST_CASE("test_get_set")
Definition: line_search.cpp:78
Result of a Line Search operation.
Definition: line_search.h:17