18 #include "spdlog/fmt/fmt.h"
24 auto m = root.def_submodule(
"reg",
"Regularizer configuration types");
25 py::class_<objective::SquaredNormConfig>(m,
"SquaredNormConfig")
26 .def(py::init<real_t, bool>(),
27 py::kw_only(), py::arg(
"strength"), py::arg(
"ignore_bias") =
true)
32 return fmt::format(
"SquaredNormConfig(strength={}, ignore_bias={})", a.
Strength, a.
IgnoreBias ?
"True" :
"False");
36 py::class_<objective::HuberConfig>(m,
"HuberConfig")
37 .def(py::init<real_t, real_t, bool>(),
38 py::kw_only(), py::arg(
"strength"), py::arg(
"epsilon"), py::arg(
"ignore_bias") =
true)
44 return fmt::format(
"HuberConfig(strength={}, epsilon={}, ignore_bias={})", a.
Strength, a.Epsilon, a.
IgnoreBias ?
"True" :
"False");
48 py::class_<objective::ElasticConfig>(m,
"ElasticConfig")
49 .def(py::init<real_t, real_t, real_t, bool>(),
50 py::kw_only(), py::arg(
"strength"), py::arg(
"epsilon"), py::arg(
"interpolation"), py::arg(
"ignore_bias") =
true)
57 return fmt::format(
"ElasticConfig(strength={}, epsilon={}, interpolation={}, ignore_bias={})", a.
Strength, a.Epsilon, a.Interpolation, a.
IgnoreBias ?
"True" :
"False");
72 return std::make_shared<ConstantWeighting>(pos, neg);
75 return std::make_shared<PropensityWeighting>(
PropensityModel(&data, a, b));
78 return std::make_shared<CustomWeighting>(std::move(pos), std::move(neg));
83 py::class_<WeightingScheme, std::shared_ptr<WeightingScheme>>(m,
"WeightingScheme")
87 py::kw_only(), py::arg(
"positive") = 1.0, py::arg(
"negative") = 1.0)
90 py::kw_only(), py::arg(
"a") = 0.55, py::arg(
"b") = 1.5)
92 py::kw_only(), py::arg(
"positive"), py::arg(
"negative"))
97 auto m = root.def_submodule(
"init",
"Initialization configuration types");
100 py::class_<WeightInitializationStrategy, std::shared_ptr<WeightInitializationStrategy>>(m,
"Initializer");
107 }, py::arg(
"vector"));
109 m.def(
"feature_mean", [](std::shared_ptr<DatasetBase> dataset,
real_t pos,
real_t neg){
111 }, py::kw_only(), py::arg(
"data"), py::arg(
"positive_margin")=1, py::arg(
"negative_margin")=-2);
113 m.def(
"multi_feature_mean", [](std::shared_ptr<DatasetBase> dataset,
int max_pos,
real_t pos,
real_t neg){
115 }, py::kw_only(), py::arg(
"data"), py::arg(
"max_pos"), py::arg(
"positive_margin")=1, py::arg(
"negative_margin")=-2);
120 }, py::kw_only(), py::arg(
"data"), py::arg(
"reg"), py::arg(
"loss"));
128 py::class_<DismecTrainingConfig>(m,
"TrainingConfig")
130 std::shared_ptr<postproc::PostProcessFactory> pf{};
136 return DismecTrainingConfig{std::move(weighting), std::move(init), std::move(pf),
nullptr, sparse, regularizer, loss};
137 }), py::kw_only(), py::arg(
"weighting"), py::arg(
"regularizer"), py::arg(
"init"),
138 py::arg(
"loss"), py::arg(
"culling"))
144 py::enum_<LossType>(m,
"LossType")
158 m.def(
"parallel_train", [](
const PyDataSet& data,
const py::dict& hyper_params,
160 long label_end,
long threads) -> py::dict
163 for (
auto item : hyper_params)
165 if(pybind11::isinstance<pybind11::int_>(item.second)) {
166 hps.set(item.first.cast<std::string>(), item.second.cast<long>());
168 hps.set(item.first.cast<std::string>(), item.second.cast<double>());
180 rdict[
"grad"] = result.TotalGrad;
181 rdict[
"finished"] = result.IsFinished;
182 rdict[
"model"] =
PyModel(std::move(result.Model));
184 }, py::arg(
"data"), py::arg(
"hyperparameters"), py::arg(
"spec"), py::arg(
"label_begin") = 0, py::arg(
"label_end") = -1,
185 py::arg(
"num_threads") = -1, py::call_guard<py::gil_scoped_release>());
std::shared_ptr< dismec::DatasetBase > PyDataSet
PyWrapper< dismec::model::Model > PyModel
std::shared_ptr< dismec::WeightingScheme > PyWeighting
This class represents a set of hyper-parameters.
Base class for label-based weighting schemes.
virtual double get_positive_weight(label_id_t label_id) const =0
Gets the weight to use for all examples where the label label_id is present.
virtual double get_negative_weight(label_id_t label_id) const =0
Gets the weight to use for all examples where the label label_id is absent.
Strong typedef for an int to signify a label id.
void set_logger(std::shared_ptr< spdlog::logger > logger)
Sets the logger object that is used for reporting. Set to nullptr for quiet mode.
auto get_negative_weight(const WeightingScheme &pds, long label)
PyWeighting make_constant(double pos, double neg)
PyWeighting make_propensity(const DatasetBase &data, double a, double b)
auto get_positive_weight(const WeightingScheme &pds, long label)
PyWeighting make_custom(DenseRealVector pos, DenseRealVector neg)
std::shared_ptr< WeightInitializationStrategy > create_zero_initializer()
Creates an initialization strategy that initializes all weight vectors to zero.
std::shared_ptr< WeightInitializationStrategy > create_feature_mean_initializer(std::shared_ptr< DatasetBase > data, real_t pos=1, real_t neg=-2)
Creates an initialization strategy based on the mean of positive and negative features.
std::shared_ptr< WeightInitializationStrategy > create_multi_pos_mean_strategy(std::shared_ptr< DatasetBase > data, int max_pos, real_t pos=1, real_t neg=-2)
Creates an initialization strategy based on the mean of positive and negative features.
std::shared_ptr< WeightInitializationStrategy > create_constant_initializer(DenseRealVector vec)
std::shared_ptr< WeightInitializationStrategy > create_ova_primal_initializer(const std::shared_ptr< DatasetBase > &data, RegularizerSpec regularizer, LossType loss)
FactoryPtr create_culling(real_t eps)
Main namespace in which all types, classes, and functions are defined.
TrainingResult run_training(parallel::ParallelRunner &runner, std::shared_ptr< TrainingSpec > spec, label_id_t begin_label=label_id_t{0}, label_id_t end_label=label_id_t{-1})
std::variant< objective::SquaredNormConfig, objective::HuberConfig, objective::ElasticConfig > RegularizerSpec
types::DenseVector< real_t > DenseRealVector
Any dense, real values vector.
std::shared_ptr< TrainingSpec > create_dismec_training(std::shared_ptr< const DatasetBase > data, HyperParameters params, DismecTrainingConfig config)
float real_t
The default type for floating point values.
void register_regularizers(pybind11::module_ &root)
void register_weighting(pybind11::module_ &m)
void register_training(pybind11::module_ &m)
void register_init(pybind11::module_ &root)
RegularizerSpec Regularizer
std::shared_ptr< WeightingScheme > Weighting