DiSMEC++
test.cpp
Go to the documentation of this file.
1
// Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2
// All rights reserved.
3
//
4
// SPDX-License-Identifier: MIT
5
6
#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN
7
#include "doctest.h"
8
/*
9
#include "solver/minimizer.h"
10
#include "solver/cg.h"
11
#include "hash_vector.h"
12
#include "objectives/l2_reg_sq_hinge.h"
13
14
15
// NOTE(review): in this file the test case sits inside a block comment, so it is
// not compiled at present. Only line comments are used below so that the
// enclosing block comment stays intact.
//
// Compares the gradient, the Hessian-vector product, the diagonal preconditioner
// and the CG solution for Regularized_SquaredHingeSVC against reference values
// stored as plain-text vectors under test_data/.
TEST_CASE("eurlex at w=0") {
    auto problem = read_liblinear_dataset("test_data/liblin_eurlex_train.txt");
    auto objective = Regularized_SquaredHingeSVC(problem.get_features(), problem.get_labels(0));

    // Every vector handled in this test has one entry per optimization variable.
    const Eigen::Index num_vars = objective.get_num_variables();

    // Reads `num_vars` whitespace-separated numbers from `path`.
    // Aborts the test case if the file cannot be opened or an extraction fails.
    auto read_vector = [num_vars](const char* path) {
        std::fstream file(path, std::fstream::in);
        REQUIRE(file.is_open());
        Eigen::VectorXd values = Eigen::VectorXd::Zero(num_vars);
        for (Eigen::Index i = 0; i < num_vars; ++i) {
            file >> values.coeffRef(i);
        }
        // fail() rather than good(): good() is also false when the last number
        // ends exactly at end-of-file, which is not an error.
        REQUIRE_FALSE(file.fail());
        return values;
    };

    Eigen::VectorXd reference_gradient = read_vector("test_data/g.txt");
    Eigen::VectorXd reference_m = read_vector("test_data/m.txt");
    Eigen::VectorXd w = read_vector("test_data/w.txt");

    SUBCASE("gradient") {
        auto g = objective.gradient(HashVector{w});
        for (Eigen::Index i = 0; i < num_vars; ++i) {
            CHECK(reference_gradient.coeff(i) == doctest::Approx(g[i]).epsilon(1e-10));
        }
    }

    SUBCASE("hessian times direction") {
        Eigen::VectorXd d = read_vector("test_data/d.txt");
        // Load the expected values through the checked reader so that a missing
        // Hd.txt fails loudly instead of comparing against garbage.
        Eigen::VectorXd expected_Hd = read_vector("test_data/Hd.txt");

        auto Hd = objective.hessian_times_direction(HashVector{w}, d);

        for (Eigen::Index i = 0; i < num_vars; ++i) {
            CHECK(expected_Hd.coeff(i) == doctest::Approx(Hd[i]).epsilon(1e-14));
        }
    }

    SUBCASE("preconditioning") {
        Eigen::VectorXd M = objective.get_diag_preconditioner(HashVector{w});
        for (Eigen::Index i = 0; i < num_vars; ++i) {
            CHECK(reference_m.coeff(i) == doctest::Approx(M[i]).epsilon(1e-14));
        }
    }

    SUBCASE("cg") {
        CGMinimizer cg(num_vars);
        // Regularize the preconditioner elementwise: M <- a + (1 - a) * M, with a = 0.99.
        Eigen::VectorXd M = (1 - 0.01) + (reference_m * 0.01).array();

        cg.minimize([&](const Eigen::VectorXd& d, Eigen::Ref<Eigen::VectorXd> o) {
            o = objective.hessian_times_direction(HashVector{w}, d);
        }, reference_gradient, M);

        Eigen::VectorXd expected_solution = read_vector("test_data/s.txt");
        const auto& solution = cg.get_solution();
        for (Eigen::Index i = 0; i < num_vars; ++i) {
            // REQUIRE (not CHECK): stop at the first mismatching component.
            REQUIRE(expected_solution.coeff(i) == doctest::Approx(solution[i]));
        }
    }
}
98
99
// NOTE(review): in this file the test case sits inside a block comment, so it is
// not compiled at present. Only line comments are used below so that the
// enclosing block comment stays intact.
//
// Runs NewtonWithLineSearch on the heart_scale data starting from w = 0 and
// compares the leading weights with hard-coded reference values
// (presumably produced by liblinear, going by the variable name -- TODO confirm).
TEST_CASE("heart_scale") {
    auto problem = read_liblinear_dataset("test_data/heart_scale");

    auto objective = Regularized_SquaredHingeSVC(problem.get_features(), problem.get_labels(0));
    auto minimizer = NewtonWithLineSearch();

    Eigen::VectorXd w = Eigen::VectorXd::Zero(problem.num_features());
    auto res = minimizer.minimize(objective, w);
    REQUIRE(res.Outcome == MinimizerStatus::SUCCESS);

    const std::array<double, 13> liblinear_results = {
        0.10478334183618995,
        0.2310677614517834,
        0.42808388718499868,
        0.26794147077324199,
        -0.01253218890573427,
        -0.16433603637371605,
        0.12597353419494489,
        -0.26426183971657191,
        0.12586162713774524,
        0.070475901031940236,
        0.16194342067257222,
        0.44104006346498714,
        0.25983806499730233,
    };

    // Guard the indexing below: w must hold at least as many entries as there
    // are reference values.
    REQUIRE(w.size() >= static_cast<Eigen::Index>(liblinear_results.size()));

    // Bound the loop by the array itself instead of repeating its length.
    for (std::size_t i = 0; i < liblinear_results.size(); ++i) {
        CHECK(w[static_cast<Eigen::Index>(i)] == doctest::Approx(liblinear_results[i]));
    }
}
130
*/
src
test.cpp
Generated by
1.9.1