DiSMEC++
line_search.h
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #ifndef DISMEC_LINE_SEARCH_H
7 #define DISMEC_LINE_SEARCH_H
8 
9 #include <functional>
10 #include "utils/hyperparams.h"
11 
12 namespace dismec::solvers
13 {
18  double Value;
19  double StepSize;
20  int NumIters;
21  };
22 
36  public:
38 
39  // get and set the parameters
40  [[nodiscard]] double get_initial_step() const { return m_StepSize; }
42  void set_initial_step(double s);
43 
44  [[nodiscard]] double get_alpha() const { return m_Alpha; }
46  void set_alpha(double a);
47 
48  [[nodiscard]] double get_eta() const { return m_Eta; }
50  void set_eta(double e);
51 
52  [[nodiscard]] long get_max_steps() const { return m_MaxSteps; }
54  void set_max_steps(long n);
55 
62  sLineSearchResult search(const std::function<double(double)>& projected_objective, double gTs, double f_init) const;
63 
64  private:
65  double m_StepSize = 1.0;
66 
67  // scale factor for the step size
68  double m_Alpha = 0.5;
69  // required reduction
70  double m_Eta = 0.01;
71  // maximum number of steps to perform
72  long m_MaxSteps = 20;
73  };
74 }
75 #endif //DISMEC_LINE_SEARCH_H
Base class for all objects that have adjustable hyper-parameters.
Definition: hyperparams.h:83
Backtracking line search using the Armijo rule.
Definition: line_search.h:35
sLineSearchResult search(const std::function< double(double)> &projected_objective, double gTs, double f_init) const
Definition: line_search.cpp:20
void set_eta(double e)
sets the eta parameter. Throws std::invalid_argument if e is not in (0, 1)
Definition: line_search.cpp:66
void set_alpha(double a)
sets the alpha parameter. Throws std::invalid_argument if a is not in (0, 1)
Definition: line_search.cpp:56
void set_max_steps(long n)
sets the maximum number of steps. Throws std::invalid_argument if n is not positive.
Definition: line_search.cpp:49
void set_initial_step(double s)
sets the initial step multiplier. Throws std::invalid_argument if s is not positive.
Definition: line_search.cpp:42
Result of a Line Search operation.
Definition: line_search.h:17
double Value
The function value at the accepted position.
Definition: line_search.h:18
double StepSize
The step size used to reach that position.
Definition: line_search.h:19