14 result = current.Getter(
this);
21 std::get<HyperParamData<long>>(hp).Setter(
this, value);
26 std::get<HyperParamData<double>>(hp).Setter(
this, value);
30 std::vector<std::string> result{};
33 result.push_back(hp.first);
47 return m_Values.at(name);
67 void set_b(
long v) { indirect_hp = v; }
68 long get_b()
const {
return indirect_hp; }
71 declare_hyper_parameter(
"a", &TestObject::direct_hp);
72 declare_hyper_parameter(
"b", &TestObject::get_b, &TestObject::set_b);
80 declare_sub_object(
"so", &NestedTestObject::sub);
94 SUBCASE(
"get and set") {
95 object.set_hyper_parameter(
"a", 1.0);
96 CHECK(
object.direct_hp == 1.0);
97 CHECK(std::get<double>(
object.get_hyper_parameter(
"a")) == 1.0);
99 object.set_hyper_parameter(
"b", 5l);
100 CHECK(
object.indirect_hp == 5);
101 CHECK(std::get<long>(
object.get_hyper_parameter(
"b")) == 5);
104 SUBCASE(
"type mismatch") {
105 CHECK_THROWS(
object.set_hyper_parameter(
"a", 3l));
106 CHECK_THROWS(
object.set_hyper_parameter(
"b", 3.5));
109 SUBCASE(
"name mismatch") {
110 CHECK_THROWS(
object.set_hyper_parameter(
"wrong", 5l));
111 CHECK_THROWS(
object.set_hyper_parameter(
"wrong", 2.0));
112 CHECK_THROWS(
object.get_hyper_parameter(
"wrong"));
115 SUBCASE(
"list hps") {
116 auto hp_names =
object.get_hyper_parameter_names();
117 REQUIRE(hp_names.size() == 2);
118 CHECK(((hp_names[0] ==
"a" && hp_names[1] ==
"b") || (hp_names[0] ==
"b" && hp_names[1] ==
"a")));
127 NestedTestObject object;
128 CHECK_THROWS(
object.set_hyper_parameter(
"a", 1.0));
130 object.set_hyper_parameter(
"so.a", 1.0);
131 CHECK_THROWS(
object.set_hyper_parameter(
"so.a", 5l));
132 CHECK(
object.sub.direct_hp == 1.0);
133 CHECK(std::get<double>(
object.get_hyper_parameter(
"so.a")) == 1.0);
135 object.set_hyper_parameter(
"so.b", 5l);
136 CHECK_THROWS(
object.set_hyper_parameter(
"so.b", 1.0));
137 CHECK(
object.sub.indirect_hp == 5);
138 CHECK(std::get<long>(
object.get_hyper_parameter(
"so.b")) == 5);
147 InvalidRegistration() {
148 declare_hyper_parameter(
"a", &TestObject::direct_hp);
153 CHECK_THROWS(InvalidRegistration{});
165 CHECK_THROWS(hps.
get(
"test"));
169 CHECK(std::get<long>(hps.
get(
"b")) == 10);
174 CHECK(target.indirect_hp == 10);
178 CHECK_THROWS(hps.
apply(target));
183 CHECK(target.direct_hp == 0.5);
187 CHECK_THROWS(hps.
apply(target));
196 hps.
set(
"so.a", 1.0);
198 NestedTestObject object;
200 CHECK(
object.sub.direct_hp == 1.0);
201 CHECK(
object.sub.indirect_hp == 5);
Base class for all objects that have adjustable hyper-parameters.
std::unordered_map< std::string, hyper_param_ptr_t > m_HyperParameters
std::variant< HyperParamData< long >, HyperParamData< double > > hyper_param_ptr_t
hyper_param_t get_hyper_parameter(const std::string &name) const
void set_hyper_parameter(const std::string &name, long value)
std::variant< long, double > hyper_param_t
std::vector< std::string > get_hyper_parameter_names() const
Returns a vector that lists all hyper parameter names.
This class represents a set of hyper-parameters.
std::unordered_map< std::string, hyper_param_t > m_Values
void set(const std::string &name, long value)
Sets a hyper-parameter with the given name and value.
hyper_param_t get(const std::string &name) const
Gets the hyper-parameter with the given name, or throws if it does not exist.
void apply(HyperParameterBase &target) const
HyperParameterBase::hyper_param_t hyper_param_t
TEST_CASE("HyperParameterBase")
auto visit(F &&f, Variants &&... variants)
Main namespace in which all types, classes, and functions are defined.