DiSMEC++
cg.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #include "cg.h"
7 #include <spdlog/spdlog.h>
8 
9 using namespace dismec;
11 
12 CGMinimizer::CGMinimizer(long num_vars) : m_Size(num_vars) { // Records the problem dimension and allocates the four length-num_vars work vectors used by do_minimize.
13  m_A_times_d = DenseRealVector(num_vars); // holds A * (current search direction)
14  m_S = DenseRealVector(num_vars); // current solution estimate s (presumably what get_solution() returns — TODO confirm in cg.h)
15  m_Residual = DenseRealVector(num_vars); // residual r_k
16  m_Conjugate = DenseRealVector(num_vars); // search direction p_k
17  // NOTE(review): listing line 18 is not part of this excerpt; the cross-reference list mentions declare_hyper_parameter, which it presumably calls — confirm against cg.cpp.
19 }
20 
22  long result = do_minimize(A, b, M); // Public entry point for solving A x + b = 0: delegates the whole solve to do_minimize (the signature, listing line 21, is not shown in this excerpt).
23  return result; // number of CG iterations performed
24 }
25 
27  // Preconditioned conjugate gradient for A x + b = 0 (i.e. A x = -b), started at x_0 = 0; returns the number of iterations performed. In comments, z denotes the preconditioned residual Residual/M (element-wise).
28  m_S.setZero(); // start from x_0 = 0, so the initial residual is r_0 = -b - A*0 = -b
29  m_Residual = -b; // note: we solve Ax + b = 0, whereas textbook CG is written for Ax = b; hence r_0 = -b.
30  m_Conjugate = m_Residual.array() / M.array(); // p_0 = z_0 = r_0 / M: M acts as an element-wise (diagonal) preconditioner, so its entries must be non-zero (presumably positive — confirm with callers)
31 
32  real_t Q = 0; // value of the quadratic model at the current iterate; 0 at s = 0
33 
34  auto zT_dot_r = m_Conjugate.dot(m_Residual); // z_0^T r_0 (at this point m_Conjugate == z)
35  real_t gMinv_norm = std::sqrt(zT_dot_r); // = sqrt(b^T M^{-1} b): since r_0 = -b, z_0^T r_0 = (-b/M)^T (-b) = sum_i b_i^2 / M_i
36  real_t cg_tol = std::min(m_Epsilon, std::sqrt(gMinv_norm)); // tolerance for the relative-decrease test below: the m_Epsilon hyperparameter, tightened when sqrt(gMinv_norm) = (b^T M^{-1} b)^{1/4} is smaller
37 
38  long max_cg_iter = std::max(m_Size, CG_MIN_ITER_BOUND); // at least the problem dimension, and never fewer than CG_MIN_ITER_BOUND
39  for(long cg_iter = 1; cg_iter <= max_cg_iter; ++cg_iter) { // iterations are counted from 1; every early exit returns the current count
41  real_t dAd = m_Conjugate.dot(m_A_times_d); // curvature p^T A p. NOTE(review): listing line 40, which presumably fills m_A_times_d via A(m_Conjugate, m_A_times_d), is missing from this excerpt — confirm against cg.cpp
42  if(dAd < 1e-16) { // (numerically) zero or negative curvature — also the case b == 0, where p_0 = 0
43  return cg_iter; // stop before taking a step, so m_S keeps the previous iterate
44  }
45 
46  real_t alpha = zT_dot_r / dAd; // exact line-search step length along p_k
47  m_S += alpha * m_Conjugate; // s_{k+1} = s_k + alpha p_k
48  m_Residual -= alpha * m_A_times_d; // r_{k+1} = r_k - alpha A p_k, maintaining r = -b - A s
49 
50  // Use the value of the quadratic model as the CG stopping criterion
51  real_t newQ = -real_t{0.5}*(m_S.dot(m_Residual - b)); // with r = -b - A s this equals q(s) = 0.5 s^T A s + b^T s, whose minimizer solves A s + b = 0
52  real_t Qdiff = newQ - Q; // change of q in this iteration (expected to be <= 0)
53  if (newQ <= 0 && Qdiff <= 0) // normal case: model value non-positive and non-increasing
54  {
55  if (cg_iter * Qdiff >= cg_tol * newQ) { // both sides <= 0, i.e. cg_iter * |Qdiff| <= cg_tol * |newQ|: the latest decrease is small relative to the total decrease
56  return cg_iter; // success
57  }
58  }
59  else // q became positive or increased: numerical trouble (possibly A not positive definite — TODO confirm)
60  {
61  spdlog::warn("quadratic approximation > 0 or increasing in {}th CG iteration. Old Q: {}, New Q: {}",
62  cg_iter, Q, newQ);
63  return cg_iter; // NOTE: this iteration's step has already been added to m_S
64  }
65  Q = newQ;
66 
67  // z_{k+1} == r.array() / M.array() (computed inline in the next two places)
68  real_t znewTrnew = (m_Residual.array() / M.array()).matrix().dot(m_Residual); // z_{k+1}^T r_{k+1}
69  real_t beta = znewTrnew / zT_dot_r; // beta_k = z_{k+1}^T r_{k+1} / z_k^T r_k
70  m_Conjugate = m_Conjugate * beta + (m_Residual.array() / M.array()).matrix(); // p_{k+1} = z_{k+1} + beta_k p_k
71  zT_dot_r = znewTrnew; // carry z^T r into the next iteration
72  }
73 
74  spdlog::warn("reached maximum number of CG steps ({}). Remaining error is {}", max_cg_iter, Q); // NOTE(review): the logged "error" is the quadratic model value Q, not a residual norm
75 
76  return max_cg_iter; // iteration budget exhausted; m_S holds the last iterate
77 }
78 
79 #include "doctest.h"
80 
81 TEST_CASE("conjugate gradient") { // Solves a small random symmetric system with CG and checks that the residual A x + b vanishes.
82  const int TEST_SIZE = 5;
83  auto minimizer = CGMinimizer(TEST_SIZE);
84  minimizer.set_epsilon(0.001); // CG stopping tolerance
85  types::DenseColMajor<real_t> A = types::DenseColMajor<real_t>::Random(TEST_SIZE, TEST_SIZE); // NOTE(review): no explicit random seed is set in this test
86  A = (A*A.transpose()).eval(); // A A^T is symmetric positive semi-definite. NOTE(review): CG needs positive definiteness, and A A^T of a random matrix may be ill-conditioned — consider adding a multiple of the identity
87  DenseRealVector b = DenseRealVector::Random(TEST_SIZE);
88  DenseRealVector m = DenseRealVector::Ones(TEST_SIZE); // all-ones preconditioner, i.e. unpreconditioned CG
89 
90  minimizer.minimize([&](const DenseRealVector& d, Eigen::Ref<DenseRealVector> out){ // matrix-vector product callback
91  out = A * d;
92  }, b, m); // the returned iteration count is not checked
93 
94  DenseRealVector solution = minimizer.get_solution();
95  DenseRealVector sol = A * solution + b; // residual of the computed solution (should be ~0)
96  CHECK(sol.norm() == doctest::Approx(0.0)); // NOTE(review): uses doctest's default Approx tolerance, which is tight given real_t is float and the CG tolerance is 0.001 — confirm this is not flaky
97 }
TEST_CASE("conjugate gradient")
Definition: cg.cpp:81
void declare_hyper_parameter(std::string name, U S::*pointer)
Definition: hyperparams.h:117
Approximately solve a linear equation Ax + b = 0.
Definition: cg.h:23
DenseRealVector m_A_times_d
Definition: cg.h:50
std::function< void(const DenseRealVector &, Eigen::Ref< DenseRealVector >)> MatrixVectorProductFn
Definition: cg.h:27
DenseRealVector m_Residual
r_k from the CG algorithm
Definition: cg.h:52
long minimize(const MatrixVectorProductFn &A, const DenseRealVector &b, const DenseRealVector &M)
Solves Ax + b = 0. Returns the number of iterations.
Definition: cg.cpp:21
void set_epsilon(double v)
Sets the value of the tolerance hyperparameter.
Definition: cg.h:41
DenseRealVector m_S
s from the CG algorithm
Definition: cg.h:51
double get_epsilon() const
Gets the value of the tolerance hyperparameter.
Definition: cg.h:38
DenseRealVector m_Conjugate
p_k from the CG algorithm
Definition: cg.h:53
long do_minimize(const MatrixVectorProductFn &A, const DenseRealVector &b, const DenseRealVector &M)
Definition: cg.cpp:26
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
constexpr const long CG_MIN_ITER_BOUND
Definition: config.h:22
types::DenseVector< real_t > DenseRealVector
Any dense, real values vector.
Definition: matrix_types.h:40
float real_t
The default type for floating point values.
Definition: config.h:17