DiSMEC++
dense_test.cpp
Go to the documentation of this file.
1 // Copyright (c) 2022, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #include "utils/macros.h"
7 // GCC 11 emits `maybe-uninitialized` from inside Eigen with one of the test cases.
8 DIAGNOSTIC_IGNORE_GCC("-Wmaybe-uninitialized")
9 
10 #include "doctest.h"
11 #include "dense.h"
12 #include "utils/eigen_generic.h"
13 
14 
15 using namespace dismec;
16 using namespace dismec::model;
17 
22 TEST_CASE("get dense weights errors")
23 {
24  auto test_mat = std::make_shared<DenseModel::WeightMatrix>(DenseModel::WeightMatrix::Zero(4, 3));
25  DenseModel model{test_mat};
26 
27  DenseRealVector target(10);
28  CHECK_THROWS(model.get_weights_for_label(label_id_t{0}, target));
29 
30  // OK, is the size matches we're good to go.
31  target = DenseRealVector::Ones(4);
32  REQUIRE_NOTHROW(model.get_weights_for_label(label_id_t{0}, target));
33 
34  // check error on wrong label
35  CHECK_THROWS(model.get_weights_for_label(label_id_t{-1}, target));
36  CHECK_THROWS(model.get_weights_for_label(label_id_t{3}, target));
37 }
38 
43 TEST_CASE("set dense weights errors") {
44  auto test_mat = std::make_shared<DenseModel::WeightMatrix>(DenseModel::WeightMatrix::Zero(4, 3));
45  DenseModel model{test_mat};
46 
47  DenseRealVector source(10);
48  CHECK_THROWS(model.set_weights_for_label(label_id_t{0}, Model::WeightVectorIn{source}));
49 
50  // OK, is the size matches we're good to go.
51  source = DenseRealVector::Ones(4);
52  REQUIRE_NOTHROW(model.set_weights_for_label(label_id_t{0}, Model::WeightVectorIn{source}));
53 
54  // check error on wrong label
55  CHECK_THROWS(model.set_weights_for_label(label_id_t{-1}, Model::WeightVectorIn{source}));
56  CHECK_THROWS(model.set_weights_for_label(label_id_t{3}, Model::WeightVectorIn{source}));
57 }
58 
61 TEST_CASE("get/set dense weights round-trip") {
62  auto test_mat = std::make_shared<DenseModel::WeightMatrix>(DenseModel::WeightMatrix::Zero(4, 3));
63  DenseModel model{test_mat};
64 
65  DenseRealVector source = DenseRealVector::Ones(4);
66  source.coeffRef(2) = 2.0;
67  model.set_weights_for_label(label_id_t{1}, DenseModel::WeightVectorIn{source});
68 
69  DenseRealVector target(4);
70  model.get_weights_for_label(label_id_t{1}, target);
71 
72  for(int i = 0; i < 4; ++i) {
73  CHECK(source[i] == target[i]);
74  }
75 }
76 
79 TEST_CASE("predict_scores checks") {
80  DenseModel model{4, 3};
81 
82  PredictionMatrix t1(3, 7);
83  PredictionMatrix t2(3, 6);
84  PredictionMatrix t3(4, 6);
85 
86  CHECK_THROWS(model.predict_scores(GenericInMatrix::DenseRowMajorRef(Eigen::MatrixXf(4, 6)), t1)); // mismatched rows
87  CHECK_THROWS(model.predict_scores(GenericInMatrix::DenseColMajorRef(Eigen::MatrixXf(3, 6)), t2)); // wrong number of features
88  CHECK_THROWS(model.predict_scores(GenericInMatrix::DenseColMajorRef(Eigen::MatrixXf(4, 6)), t3)); // wrong number of labels
89 }
90 
94 TEST_CASE("partial model") {
95  DenseModel full(4, 5);
96  CHECK_FALSE(full.is_partial_model());
97  CHECK(full.labels_begin() == label_id_t{0});
98  CHECK(full.labels_end() == label_id_t{5});
99  CHECK(full.num_labels() == 5);
100  CHECK(full.num_weights() == 5);
101 
102  DenseModel partial(4, PartialModelSpec{label_id_t{1}, 3, 5});
103  CHECK(partial.is_partial_model());
104  CHECK(partial.labels_begin() == label_id_t{1});
105  CHECK(partial.labels_end() == label_id_t{4});
106  CHECK(partial.num_labels() == 5);
107  CHECK(partial.num_weights() == 3);
108 }
109 
113 TEST_CASE("DenseModel ctor consistency") {
114  auto build = [](long first_label, long label_count, long total_labels) {
115  DenseModel model(4, PartialModelSpec{label_id_t{first_label}, label_count, total_labels});
116  };
117  // label range exceeding total number of labels
118  SUBCASE("invalid range") {
119  CHECK_THROWS(build(4, 3, 5));
120  CHECK_THROWS(build(1, 5, 5));
121  }
122 
123  // negative numbers or zero are invalid
124  SUBCASE("non positive") {
125  // first label can be positive
126  CHECK_THROWS(build(-1, 5, 5));
127  CHECK_THROWS(build(3, 0, 5));
128  CHECK_THROWS(build(3, -1, 5));
129  CHECK_THROWS(build(3, 2, 0));
130  CHECK_THROWS(build(3, 2, -1));
131 
132  CHECK_THROWS(DenseModel(0, 5));
133  CHECK_THROWS(DenseModel(-1, 5));
134  }
135 
136  SUBCASE("data mismatch") {
137  auto matrix = std::make_shared<DenseModel::WeightMatrix>(4, 3);
138  // claim four labels, but matrix only has 3
139  CHECK_THROWS(DenseModel(matrix, PartialModelSpec{label_id_t{0}, 4, 5}));
140  }
141 }
Strong typedef for an int to signify a label id.
Definition: types.h:20
Implementation of the Model class that stores the weights as a single, dense matrix.
Definition: dense.h:17
label_id_t labels_end() const noexcept
Definition: model.h:102
long num_labels() const noexcept
How many labels are in the underlying dataset.
Definition: model.h:78
bool is_partial_model() const
returns true if this instance only stores part of the weights of an entire model
Definition: model.cpp:37
long num_weights() const noexcept
How many weights vectors are in this model.
Definition: model.h:87
GenericInVector WeightVectorIn
Definition: model.h:67
label_id_t labels_begin() const noexcept
Definition: model.h:98
Eigen::Ref< DenseColMajor< T > > DenseColMajorRef
Eigen::Ref< DenseRowMajor< T > > DenseRowMajorRef
TEST_CASE("get dense weights errors")
Definition: dense_test.cpp:22
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
types::DenseVector< real_t > DenseRealVector
Any dense, real values vector.
Definition: matrix_types.h:40
types::DenseRowMajor< real_t > PredictionMatrix
Dense matrix in Row Major format used for predictions.
Definition: matrix_types.h:75
Specifies how to interpret a weight matrix for a partial model.
Definition: model.h:22