DiSMEC++
hyperparams.h
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #ifndef DISMEC_HYPERPARAMS_H
7 #define DISMEC_HYPERPARAMS_H
8 
#include <functional>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <variant>
#include <vector>
15 
16 namespace dismec
17 {
84  public:
85  using hyper_param_t = std::variant<long, double>;
86 
87  HyperParameterBase() = default;
88  virtual ~HyperParameterBase() = default;
93 
96  void set_hyper_parameter(const std::string& name, long value);
97  void set_hyper_parameter(const std::string& name, double value);
98 
101  [[nodiscard]] hyper_param_t get_hyper_parameter(const std::string& name) const;
102 
104  std::vector<std::string> get_hyper_parameter_names() const;
105 
106  protected:
116  template<class U, class S>
117  void declare_hyper_parameter(std::string name, U S::* pointer)
118  {
123  static_assert(std::is_base_of_v<HyperParameterBase, S>, "HyperParameterBase is not base class of member pointer");
124 
125  if(dynamic_cast<S*>(this) == nullptr) {
126  throw std::logic_error("Cannot cast this pointer to class `S`");
127  }
128 
132  auto getter = [pointer](const HyperParameterBase* self) {
133  return dynamic_cast<const S*>(self)->*pointer;
134  };
135  auto setter = [pointer](HyperParameterBase* self, const U& value) {
136  dynamic_cast<S*>(self)->*pointer = value;
137  };
138 
140  declare_hyper_parameter(std::move(name), HyperParamData<U>{setter, getter});
141  }
151  template<class U, class S>
152  void declare_hyper_parameter(std::string name, U(S::*getter)() const, void(S::*setter)(U))
153  {
154  static_assert(std::is_base_of_v<HyperParameterBase, S>, "T is not base class of member pointer");
155 
156  if(dynamic_cast<S*>(this) == nullptr) {
157  throw std::logic_error("Cannot cast this pointer to class `S`");
158  }
159 
160  auto getter_ = [getter](const HyperParameterBase* self) -> std::decay_t<U> {
161  return (dynamic_cast<const S*>(self)->*getter)();
162  };
163  auto setter_ = [setter](HyperParameterBase* self, const U& value) {
164  (dynamic_cast<S*>(self)->*setter)(value);
165  };
166 
167  declare_hyper_parameter(std::move(name), HyperParamData<U>{setter_, getter_});
168  }
169 
178  template<class T, class S>
179  void declare_sub_object(const std::string& name, T S::*object) {
180  static_assert(std::is_base_of_v<HyperParameterBase, S>, "S is not base class of member pointer");
181  static_assert(std::is_base_of_v<HyperParameterBase, T>, "T is not base class of member pointer");
182 
183  if(dynamic_cast<S*>(this) == nullptr) {
184  throw std::logic_error("Cannot cast this pointer to class `S`");
185  }
186 
187  // we need to get one instance of the sub-object, so we can iterate its hyper-parameters
188  HyperParameterBase& sub_object = dynamic_cast<S*>(this)->*object;
189 
190  for(const auto& hp : sub_object.m_HyperParameters)
191  {
192  std::visit([&](const auto& hp_data)
193  {
194  // get the type that is used in the inner object
195  using value_t = typename std::decay_t<decltype(hp_data)>::ValueType;
196  // make getter and setter functions that look up the actual inner objects in the
197  // submitted `self`, and then call the inner getters and setters
198  auto getter_ = [object, inner_get=hp_data.Getter](const HyperParameterBase* self) {
199  auto self_as_s = dynamic_cast<const S*>(self);
200  return inner_get(&(self_as_s->*object));
201  };
202  auto setter_ = [object, inner_set=hp_data.Setter](HyperParameterBase* self, const value_t& value) {
203  auto self_as_s = dynamic_cast<S*>(self);
204  inner_set(&(self_as_s->*object), value);
205  };
206  // register with . separated name
207  declare_hyper_parameter(name + "." + hp.first, HyperParamData<value_t>{setter_, getter_});
208  }, hp.second);
209  }
210  }
211 
212  private:
213 
215  template<class D>
216  struct HyperParamData {
217  std::function<void(HyperParameterBase*, D)> Setter;
218  std::function<D(const HyperParameterBase*)> Getter;
219  using ValueType = D;
220  };
221 
225  template<class D>
226  void declare_hyper_parameter(std::string name, HyperParamData<D> data) {
227  auto result = m_HyperParameters.insert( std::make_pair(std::move(name), std::move(data)) );
228  if(!result.second) {
229  throw std::logic_error("Trying to re-register hyper-parameter " + result.first->first);
230  }
231  }
232 
/// Type-erased entry for one hyper-parameter: the accessor pair for either a `long` or a `double` value.
233  using hyper_param_ptr_t = std::variant<HyperParamData<long>, HyperParamData<double>>;
/// All declared hyper-parameters, keyed by name (sub-object entries use "prefix.inner" names).
/// Filled by the declare_* functions; registering a name twice throws std::logic_error.
234  std::unordered_map<std::string, hyper_param_ptr_t> m_HyperParameters;
235  };
236 
237 
242  public:
244 
/// Sets a hyper-parameter with the given name and value (integral overload).
246  void set(const std::string& name, long value);
/// Sets a hyper-parameter with the given name and value (real-valued overload).
247  void set(const std::string& name, double value);
248 
/// Gets the hyper-parameter with the given name, or throws if it does not exist.
250  [[nodiscard]] hyper_param_t get(const std::string& name) const;
251 
/// Implemented in hyperparams.cpp.
/// NOTE(review): presumably writes every stored value into `target` via set_hyper_parameter - confirm,
/// including what happens for names that `target` has not declared.
254  void apply(HyperParameterBase& target) const;
255 
256  private:
/// Values recorded through `set`, keyed by hyper-parameter name.
257  std::unordered_map<std::string, hyper_param_t> m_Values;
258  };
259 }
260 
261 #endif //DISMEC_HYPERPARAMS_H
Base class for all objects that have adjustable hyper-parameters.
Definition: hyperparams.h:83
virtual ~HyperParameterBase()=default
HyperParameterBase(HyperParameterBase &&)=default
std::unordered_map< std::string, hyper_param_ptr_t > m_HyperParameters
Definition: hyperparams.h:234
HyperParameterBase & operator=(const HyperParameterBase &)=default
HyperParameterBase & operator=(HyperParameterBase &&)=default
std::variant< HyperParamData< long >, HyperParamData< double > > hyper_param_ptr_t
Definition: hyperparams.h:233
void declare_hyper_parameter(std::string name, U S::*pointer)
Definition: hyperparams.h:117
hyper_param_t get_hyper_parameter(const std::string &name) const
Definition: hyperparams.cpp:10
void set_hyper_parameter(const std::string &name, long value)
Definition: hyperparams.cpp:19
void declare_hyper_parameter(std::string name, U(S::*getter)() const, void(S::*setter)(U))
Declares a constrained hyper-parameter with explicit getter and setter functions.
Definition: hyperparams.h:152
std::variant< long, double > hyper_param_t
Definition: hyperparams.h:85
std::vector< std::string > get_hyper_parameter_names() const
Returns a vector that lists all hyper parameter names.
Definition: hyperparams.cpp:29
void declare_sub_object(const std::string &name, T S::*object)
Declares a sub-object that also contains hyper-parameters.
Definition: hyperparams.h:179
HyperParameterBase(const HyperParameterBase &)=default
void declare_hyper_parameter(std::string name, HyperParamData< D > data)
Definition: hyperparams.h:226
This class represents a set of hyper-parameters.
Definition: hyperparams.h:241
std::unordered_map< std::string, hyper_param_t > m_Values
Definition: hyperparams.h:257
void set(const std::string &name, long value)
Sets a hyper-parameter with the given name and value.
Definition: hyperparams.cpp:38
hyper_param_t get(const std::string &name) const
Gets the hyper-parameter with the given name, or throws if it does not exist.
Definition: hyperparams.cpp:46
void apply(HyperParameterBase &target) const
Definition: hyperparams.cpp:50
HyperParameterBase::hyper_param_t hyper_param_t
Definition: hyperparams.h:243
auto visit(F &&f, Variants &&... variants)
Definition: eigen_generic.h:95
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
This structure collects the Getter and Setter functions. This is what we store in the variant.
Definition: hyperparams.h:216
std::function< void(HyperParameterBase *, D)> Setter
Definition: hyperparams.h:217
std::function< D(const HyperParameterBase *)> Getter
Definition: hyperparams.h:218