DiSMEC++
test.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN
7 #include "doctest.h"
8 /*
9 #include "solver/minimizer.h"
10 #include "solver/cg.h"
11 #include "hash_vector.h"
12 #include "objectives/l2_reg_sq_hinge.h"
13 
14 
15 TEST_CASE("eurlex at w=0") {
16  auto problem = read_liblinear_dataset("test_data/liblin_eurlex_train.txt");
17  auto objective = Regularized_SquaredHingeSVC(problem.get_features(), problem.get_labels(0));
18 
19  Eigen::VectorXd w = Eigen::VectorXd::Zero(objective.get_num_variables());
20  Eigen::VectorXd reference_gradient = Eigen::VectorXd::Zero(objective.get_num_variables());
21  Eigen::VectorXd reference_m = Eigen::VectorXd::Zero(objective.get_num_variables());
22  {
23  std::fstream file_g("test_data/g.txt", std::fstream::in);
24  REQUIRE(file_g.is_open());
25  for (int i = 0; i < objective.get_num_variables(); ++i) {
26  file_g >> reference_gradient.coeffRef(i);
27  }
28  REQUIRE(file_g.good());
29 
30  std::fstream file_m("test_data/m.txt", std::fstream::in);
31  REQUIRE(file_m.is_open());
32  for (int i = 0; i < objective.get_num_variables(); ++i) {
33  file_m >> reference_m.coeffRef(i);
34  }
35  REQUIRE(file_m.good());
36 
37  std::fstream weight_file("test_data/w.txt", std::fstream::in);
38  REQUIRE(weight_file.is_open());
39  for (int i = 0; i < objective.get_num_variables(); ++i) {
40  weight_file >> w.coeffRef(i);
41  }
42  REQUIRE(weight_file.good());
43  }
44 
45  SUBCASE("gradient") {
46  auto g = objective.gradient(HashVector{w});
47  for (int i = 0; i < objective.get_num_variables(); ++i) {
48  CHECK(reference_gradient.coeff(i) == doctest::Approx(g[i]).epsilon(1e-10));
49  }
50  }
51 
52  SUBCASE("hessian times direction") {
53  Eigen::VectorXd d(objective.get_num_variables());
54  std::fstream file_d("test_data/d.txt", std::fstream::in);
55  REQUIRE(file_d.is_open());
56  for (int i = 0; i < objective.get_num_variables(); ++i) {
57  file_d >> d.coeffRef(i);
58  }
59  REQUIRE(file_d.good());
60 
61  auto Hd = objective.hessian_times_direction(HashVector{w}, d);
62 
63  std::cout << "d " << d[0] << "\n";
64 
65  std::fstream reference("test_data/Hd.txt", std::fstream::in);
66  for (int i = 0; i < objective.get_num_variables(); ++i) {
67  double ground_truth;
68  reference >> ground_truth;
69  CHECK(ground_truth == doctest::Approx(Hd[i]).epsilon(1e-14));
70  }
71  }
72 
73  SUBCASE("preconditioning") {
74  CGMinimizer cg(objective.get_num_variables());
75  Eigen::VectorXd M = objective.get_diag_preconditioner(HashVector{w});
76  for (int i = 0; i < objective.get_num_variables(); ++i) {
77  CHECK(reference_m.coeff(i) == doctest::Approx(M[i]).epsilon(1e-14));
78  }
79  }
80 
81  SUBCASE("cg") {
82  CGMinimizer cg(objective.get_num_variables());
83  // regularize the preconditioner: M = (1-a)I + aM, with a = 0.01
84  Eigen::VectorXd M = (1 - 0.01) + (reference_m * 0.01).array();
85 
86  cg.minimize([&](const Eigen::VectorXd& d, Eigen::Ref<Eigen::VectorXd> o) {
87  o = objective.hessian_times_direction(HashVector{w}, d);
88  }, reference_gradient, M);
89 
90  std::fstream reference("test_data/s.txt", std::fstream::in);
91  for (int i = 0; i < objective.get_num_variables(); ++i) {
92  double ground_truth;
93  reference >> ground_truth;
94  REQUIRE(ground_truth == doctest::Approx(cg.get_solution()[i]));
95  }
96  }
97 }
98 
99 TEST_CASE("heart_scale") {
100  auto problem = read_liblinear_dataset("test_data/heart_scale");
101 
102  auto objective = Regularized_SquaredHingeSVC(problem.get_features(), problem.get_labels(0));
103  auto minimizer = NewtonWithLineSearch();
104 
105  Eigen::VectorXd w = Eigen::VectorXd::Zero(problem.num_features());
106  auto res = minimizer.minimize(objective, w);
107  REQUIRE(res.Outcome == MinimizerStatus::SUCCESS);
108 
109 
110  std::array<double, 13> liblinear_results = {
111  0.10478334183618995,
112  0.2310677614517834,
113  0.42808388718499868,
114  0.26794147077324199,
115  -0.01253218890573427,
116  -0.16433603637371605,
117  0.12597353419494489,
118  -0.26426183971657191,
119  0.12586162713774524,
120  0.070475901031940236,
121  0.16194342067257222,
122  0.44104006346498714,
123  0.25983806499730233,
124  };
125 
126  for(int i = 0; i < 13; ++i) {
127  CHECK(w[i] == doctest::Approx(liblinear_results[i]));
128  }
129 }
130 */