DiSMEC++
hyperparams.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #include "hyperparams.h"
7 
8 using namespace dismec;
9 
10 [[nodiscard]] auto HyperParameterBase::get_hyper_parameter(const std::string& name) const -> hyper_param_t {
11  const hyper_param_ptr_t& hp = m_HyperParameters.at(name);
12  hyper_param_t result;
13  std::visit([&](auto&& current) -> void {
14  result = current.Getter(this);
15  }, hp);
16  return result;
17 }
18 
19 void HyperParameterBase::set_hyper_parameter(const std::string& name, long value) {
20  hyper_param_ptr_t& hp = m_HyperParameters.at(name);
21  std::get<HyperParamData<long>>(hp).Setter(this, value);
22 }
23 
24 void HyperParameterBase::set_hyper_parameter(const std::string& name, double value) {
25  hyper_param_ptr_t& hp = m_HyperParameters.at(name);
26  std::get<HyperParamData<double>>(hp).Setter(this, value);
27 }
28 
29 std::vector<std::string> HyperParameterBase::get_hyper_parameter_names() const {
30  std::vector<std::string> result{};
31  result.reserve(m_HyperParameters.size());
32  for(const auto& hp : m_HyperParameters) {
33  result.push_back(hp.first);
34  }
35  return result;
36 }
37 
38 void HyperParameters::set(const std::string &name, long value) {
39  m_Values[name] = value;
40 }
41 
42 void HyperParameters::set(const std::string &name, double value) {
43  m_Values[name] = value;
44 }
45 
46 auto HyperParameters::get(const std::string& name) const -> hyper_param_t {
47  return m_Values.at(name);
48 }
49 
51  for(const auto& hp : m_Values) {
52  std::visit([&](auto&& value) {
53  target.set_hyper_parameter(hp.first, value);
54  }, hp.second);
55  }
56 }
57 
58 
59 #include "doctest.h"
60 
61 namespace {
// Test fixtures used by the TEST_CASEs below.
// NOTE(review): the listing numbering skips lines 62-63, 70, 76-79 and 83. These
// presumably hold the TestObject constructor header and the NestedTestObject struct header /
// `sub` member declaration. Confirm against the real hyperparams.cpp before editing.
64  struct TestObject : public HyperParameterBase {
// Bound directly as hyper-parameter "a" (double-typed).
65  double direct_hp = 0;
// Backing field for hyper-parameter "b", which is accessed only through get_b/set_b.
66  long indirect_hp = 0;
67  void set_b(long v) { indirect_hp = v; }
68  long get_b() const { return indirect_hp; }
69 
// Registration of "a" (member pointer) and "b" (getter/setter pair).
71  declare_hyper_parameter("a", &TestObject::direct_hp);
72  declare_hyper_parameter("b", &TestObject::get_b, &TestObject::set_b);
73  }
74  };
75 
// Exposes the hyper-parameters of the `sub` object under the "so" prefix;
// the tests address them as "so.a" and "so.b".
80  declare_sub_object("so", &NestedTestObject::sub);
81  }
82 
84  };
85 }
86 
92 TEST_CASE("HyperParameterBase") {
93  TestObject object;
94  SUBCASE("get and set") {
95  object.set_hyper_parameter("a", 1.0);
96  CHECK(object.direct_hp == 1.0);
97  CHECK(std::get<double>(object.get_hyper_parameter("a")) == 1.0);
98 
99  object.set_hyper_parameter("b", 5l);
100  CHECK(object.indirect_hp == 5);
101  CHECK(std::get<long>(object.get_hyper_parameter("b")) == 5);
102  }
103 
104  SUBCASE("type mismatch") {
105  CHECK_THROWS(object.set_hyper_parameter("a", 3l));
106  CHECK_THROWS(object.set_hyper_parameter("b", 3.5));
107  }
108 
109  SUBCASE("name mismatch") {
110  CHECK_THROWS(object.set_hyper_parameter("wrong", 5l));
111  CHECK_THROWS(object.set_hyper_parameter("wrong", 2.0));
112  CHECK_THROWS(object.get_hyper_parameter("wrong"));
113  }
114 
115  SUBCASE("list hps") {
116  auto hp_names = object.get_hyper_parameter_names();
117  REQUIRE(hp_names.size() == 2);
118  CHECK(((hp_names[0] == "a" && hp_names[1] == "b") || (hp_names[0] == "b" && hp_names[1] == "a")));
119  }
120 }
121 
126 TEST_CASE("nested hyper parameter object") {
127  NestedTestObject object;
128  CHECK_THROWS(object.set_hyper_parameter("a", 1.0));
129 
130  object.set_hyper_parameter("so.a", 1.0);
131  CHECK_THROWS(object.set_hyper_parameter("so.a", 5l));
132  CHECK(object.sub.direct_hp == 1.0);
133  CHECK(std::get<double>(object.get_hyper_parameter("so.a")) == 1.0);
134 
135  object.set_hyper_parameter("so.b", 5l);
136  CHECK_THROWS(object.set_hyper_parameter("so.b", 1.0));
137  CHECK(object.sub.indirect_hp == 5);
138  CHECK(std::get<long>(object.get_hyper_parameter("so.b")) == 5);
139 }
140 
145 TEST_CASE("wrong subtype causes error") {
146  struct InvalidRegistration : public HyperParameterBase {
147  InvalidRegistration() {
148  declare_hyper_parameter("a", &TestObject::direct_hp);
149  // ^-- this is wrong, so an error should be thrown.
150  }
151  double direct_hp;
152  };
153  CHECK_THROWS(InvalidRegistration{});
154 }
155 
162 TEST_CASE("HyperParameters") {
163  HyperParameters hps;
164  // getting unknown parameter throws
165  CHECK_THROWS(hps.get("test"));
166 
167  // setting and getting round-trip
168  hps.set("b", 10l);
169  CHECK(std::get<long>(hps.get("b")) == 10);
170 
171  // applying valid hyper-parameters to target. Works even for partial HPs
172  TestObject target;
173  hps.apply(target);
174  CHECK(target.indirect_hp == 10);
175 
176  // applying invalid type hps
177  hps.set("a", 10l);
178  CHECK_THROWS(hps.apply(target));
179 
180  // updating type in HP collection works
181  hps.set("a", 0.5);
182  hps.apply(target);
183  CHECK(target.direct_hp == 0.5);
184 
185  // applying breaks with additional hps
186  hps.set("c", 0.5);
187  CHECK_THROWS(hps.apply(target));
188 }
189 
194 TEST_CASE("HyperParameters nested") {
195  HyperParameters hps;
196  hps.set("so.a", 1.0);
197  hps.set("so.b", 5l);
198  NestedTestObject object;
199  hps.apply(object);
200  CHECK(object.sub.direct_hp == 1.0);
201  CHECK(object.sub.indirect_hp == 5);
202 }
Base class for all objects that have adjustable hyper-parameters.
Definition: hyperparams.h:83
std::unordered_map< std::string, hyper_param_ptr_t > m_HyperParameters
Definition: hyperparams.h:234
std::variant< HyperParamData< long >, HyperParamData< double > > hyper_param_ptr_t
Definition: hyperparams.h:233
hyper_param_t get_hyper_parameter(const std::string &name) const
Definition: hyperparams.cpp:10
void set_hyper_parameter(const std::string &name, long value)
Definition: hyperparams.cpp:19
std::variant< long, double > hyper_param_t
Definition: hyperparams.h:85
std::vector< std::string > get_hyper_parameter_names() const
Returns a vector that lists all hyper parameter names.
Definition: hyperparams.cpp:29
This class represents a set of hyper-parameters.
Definition: hyperparams.h:241
std::unordered_map< std::string, hyper_param_t > m_Values
Definition: hyperparams.h:257
void set(const std::string &name, long value)
Sets a hyper-parameter with the given name and value.
Definition: hyperparams.cpp:38
hyper_param_t get(const std::string &name) const
Gets the hyper-parameter with the given name, or throws if it does not exist.
Definition: hyperparams.cpp:46
void apply(HyperParameterBase &target) const
Definition: hyperparams.cpp:50
HyperParameterBase::hyper_param_t hyper_param_t
Definition: hyperparams.h:243
TEST_CASE("HyperParameterBase")
Definition: hyperparams.cpp:92
auto visit(F &&f, Variants &&... variants)
Definition: eigen_generic.h:95
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15