6 #ifndef DISMEC_HYPERPARAMS_H
7 #define DISMEC_HYPERPARAMS_H
11 #include <unordered_map>
12 #include <type_traits>
116 template<
class U,
class S>
123 static_assert(std::is_base_of_v<HyperParameterBase, S>,
"HyperParameterBase is not base class of member pointer");
125 if(
dynamic_cast<S*
>(
this) ==
nullptr) {
126 throw std::logic_error(
"Cannot cast this pointer to class `S`");
133 return dynamic_cast<const S*
>(
self)->*pointer;
136 dynamic_cast<S*
>(
self)->*pointer = value;
151 template<
class U,
class S>
154 static_assert(std::is_base_of_v<HyperParameterBase, S>,
"T is not base class of member pointer");
156 if(
dynamic_cast<S*
>(
this) ==
nullptr) {
157 throw std::logic_error(
"Cannot cast this pointer to class `S`");
161 return (
dynamic_cast<const S*
>(
self)->*getter)();
164 (
dynamic_cast<S*
>(
self)->*setter)(value);
178 template<
class T,
class S>
180 static_assert(std::is_base_of_v<HyperParameterBase, S>,
"S is not base class of member pointer");
181 static_assert(std::is_base_of_v<HyperParameterBase, T>,
"T is not base class of member pointer");
183 if(
dynamic_cast<S*
>(
this) ==
nullptr) {
184 throw std::logic_error(
"Cannot cast this pointer to class `S`");
195 using value_t =
typename std::decay_t<decltype(hp_data)>::ValueType;
199 auto self_as_s =
dynamic_cast<const S*
>(
self);
200 return inner_get(&(self_as_s->*
object));
202 auto setter_ = [object, inner_set=hp_data.Setter](
HyperParameterBase*
self,
const value_t& value) {
203 auto self_as_s =
dynamic_cast<S*
>(
self);
204 inner_set(&(self_as_s->*
object), value);
227 auto result =
m_HyperParameters.insert( std::make_pair(std::move(name), std::move(data)) );
229 throw std::logic_error(
"Trying to re-register hyper-parameter " + result.first->first);
246 void set(
const std::string& name,
long value);
247 void set(
const std::string& name,
double value);
257 std::unordered_map<std::string, hyper_param_t>
m_Values;
Base class for all objects that have adjustable hyper-parameters.
HyperParameterBase()=default
virtual ~HyperParameterBase()=default
HyperParameterBase(HyperParameterBase &&)=default
std::unordered_map< std::string, hyper_param_ptr_t > m_HyperParameters
HyperParameterBase & operator=(const HyperParameterBase &)=default
HyperParameterBase & operator=(HyperParameterBase &&)=default
std::variant< HyperParamData< long >, HyperParamData< double > > hyper_param_ptr_t
void declare_hyper_parameter(std::string name, U S::*pointer)
hyper_param_t get_hyper_parameter(const std::string &name) const
void set_hyper_parameter(const std::string &name, long value)
void declare_hyper_parameter(std::string name, U(S::*getter)() const, void(S::*setter)(U))
Declares a constrained hyper-parameter with explicit getter and setter functions.
std::variant< long, double > hyper_param_t
std::vector< std::string > get_hyper_parameter_names() const
Returns a vector that lists all hyper-parameter names.
void declare_sub_object(const std::string &name, T S::*object)
Declares a sub-object that also contains hyper-parameters.
HyperParameterBase(const HyperParameterBase &)=default
void declare_hyper_parameter(std::string name, HyperParamData< D > data)
This class represents a set of hyper-parameters.
std::unordered_map< std::string, hyper_param_t > m_Values
void set(const std::string &name, long value)
Sets a hyper-parameter with the given name and value.
hyper_param_t get(const std::string &name) const
Gets the hyper-parameter with the given name, or throws if it does not exist.
void apply(HyperParameterBase &target) const
HyperParameterBase::hyper_param_t hyper_param_t
auto visit(F &&f, Variants &&... variants)
Main namespace in which all types, classes, and functions are defined.
This structure collects the Getter and Setter functions of a hyper-parameter. Instances of it (for `long` or `double`) are what the `hyper_param_ptr_t` variant stores.
std::function< void(HyperParameterBase *, D)> Setter
std::function< D(const HyperParameterBase *)> Getter