DiSMEC++
py_train.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #include "binding.h"
7 
8 #include "data/data.h"
9 
10 #include "training/weighting.h"
11 #include "training/training.h"
12 #include "training/initializer.h"
13 #include "training/postproc.h"
14 
15 #include "parallel/runner.h"
16 #include "objective/regularizers.h"
17 
18 #include "spdlog/fmt/fmt.h"
19 
20 using namespace dismec;
21 using PyWeighting = std::shared_ptr<WeightingScheme>;
22 
23 void register_regularizers(pybind11::module_& root) {
24  auto m = root.def_submodule("reg", "Regularizer configuration types");
25  py::class_<objective::SquaredNormConfig>(m, "SquaredNormConfig")
26  .def(py::init<real_t, bool>(),
27  py::kw_only(), py::arg("strength"), py::arg("ignore_bias") = true)
28  .def_readwrite("strength", &objective::SquaredNormConfig::Strength)
29  .def_readwrite("ignore_bias", &objective::SquaredNormConfig::IgnoreBias)
30  .def("__repr__",
31  [](const objective::SquaredNormConfig &a) {
32  return fmt::format("SquaredNormConfig(strength={}, ignore_bias={})", a.Strength, a.IgnoreBias ? "True" : "False");
33  }
34  );
35 
36  py::class_<objective::HuberConfig>(m, "HuberConfig")
37  .def(py::init<real_t, real_t, bool>(),
38  py::kw_only(), py::arg("strength"), py::arg("epsilon"), py::arg("ignore_bias") = true)
39  .def_readwrite("strength", &objective::HuberConfig::Strength)
40  .def_readwrite("epsilon", &objective::HuberConfig::Epsilon)
41  .def_readwrite("ignore_bias", &objective::HuberConfig::IgnoreBias)
42  .def("__repr__",
43  [](const objective::HuberConfig &a) {
44  return fmt::format("HuberConfig(strength={}, epsilon={}, ignore_bias={})", a.Strength, a.Epsilon, a.IgnoreBias ? "True" : "False");
45  }
46  );
47 
48  py::class_<objective::ElasticConfig>(m, "ElasticConfig")
49  .def(py::init<real_t, real_t, real_t, bool>(),
50  py::kw_only(), py::arg("strength"), py::arg("epsilon"), py::arg("interpolation"), py::arg("ignore_bias") = true)
51  .def_readwrite("strength", &objective::ElasticConfig::Strength)
52  .def_readwrite("epsilon", &objective::ElasticConfig::Epsilon)
53  .def_readwrite("interpolation", &objective::ElasticConfig::Interpolation)
54  .def_readwrite("ignore_bias", &objective::ElasticConfig::IgnoreBias)
55  .def("__repr__",
56  [](const objective::ElasticConfig &a) {
57  return fmt::format("ElasticConfig(strength={}, epsilon={}, interpolation={}, ignore_bias={})", a.Strength, a.Epsilon, a.Interpolation, a.IgnoreBias ? "True" : "False");
58  }
59  );
60 }
61 
62 
63 namespace {
64  auto get_positive_weight(const WeightingScheme& pds, long label) {
65  return pds.get_positive_weight(label_id_t{label});
66  }
67  auto get_negative_weight(const WeightingScheme& pds, long label) {
68  return pds.get_negative_weight(label_id_t{label});
69  }
70 
71  PyWeighting make_constant(double pos, double neg) {
72  return std::make_shared<ConstantWeighting>(pos, neg);
73  }
74  PyWeighting make_propensity(const DatasetBase& data, double a, double b) {
75  return std::make_shared<PropensityWeighting>(PropensityModel(&data, a, b));
76  }
78  return std::make_shared<CustomWeighting>(std::move(pos), std::move(neg));
79  }
80 }
81 
82 void register_weighting(pybind11::module_& m) {
83  py::class_<WeightingScheme, std::shared_ptr<WeightingScheme>>(m, "WeightingScheme")
84  .def("positive_weight", get_positive_weight, py::arg("label"))
85  .def("negative_weight", get_negative_weight, py::arg("label"))
86  .def_static("Constant", make_constant,
87  py::kw_only(), py::arg("positive") = 1.0, py::arg("negative") = 1.0)
88  .def_static("Propensity", make_propensity,
89  py::arg("dataset"),
90  py::kw_only(), py::arg("a") = 0.55, py::arg("b") = 1.5)
91  .def_static("Custom", make_custom,
92  py::kw_only(), py::arg("positive"), py::arg("negative"))
93  ;
94 }
95 
96 void register_init(pybind11::module_& root) {
97  auto m = root.def_submodule("init", "Initialization configuration types");
98  using namespace init;
99 
100  py::class_<WeightInitializationStrategy, std::shared_ptr<WeightInitializationStrategy>>(m, "Initializer");
101  m.def("zero", [](){
102  return create_zero_initializer();
103  });
104 
105  m.def("constant", [](const DenseRealVector& vec){
106  return create_constant_initializer(vec);
107  }, py::arg("vector"));
108 
109  m.def("feature_mean", [](std::shared_ptr<DatasetBase> dataset, real_t pos, real_t neg){
110  return create_feature_mean_initializer(dataset, pos, neg);
111  }, py::kw_only(), py::arg("data"), py::arg("positive_margin")=1, py::arg("negative_margin")=-2);
112 
113  m.def("multi_feature_mean", [](std::shared_ptr<DatasetBase> dataset, int max_pos, real_t pos, real_t neg){
114  return create_multi_pos_mean_strategy(dataset, max_pos, pos, neg);
115  }, py::kw_only(), py::arg("data"), py::arg("max_pos"), py::arg("positive_margin")=1, py::arg("negative_margin")=-2);
116 
117 
118  m.def("ova_primal", [](std::shared_ptr<DatasetBase> dataset, RegularizerSpec reg, LossType loss){
119  return create_ova_primal_initializer(dataset, reg, loss);
120  }, py::kw_only(), py::arg("data"), py::arg("reg"), py::arg("loss"));
121 }
122 
123 void register_training(pybind11::module_& m) {
126  register_init(m);
127 
128  py::class_<DismecTrainingConfig>(m, "TrainingConfig")
129  .def(py::init([](PyWeighting weighting, RegularizerSpec regularizer, std::shared_ptr<init::WeightInitializationStrategy> init, LossType loss, real_t culling) {
130  std::shared_ptr<postproc::PostProcessFactory> pf{};
131  bool sparse = false;
132  if(culling > 0) {
133  pf = postproc::create_culling(culling);
134  sparse = true;
135  }
136  return DismecTrainingConfig{std::move(weighting), std::move(init), std::move(pf), nullptr, sparse, regularizer, loss};
137  }), py::kw_only(), py::arg("weighting"), py::arg("regularizer"), py::arg("init"),
138  py::arg("loss"), py::arg("culling"))
139  .def_readwrite("regularizer", &DismecTrainingConfig::Regularizer)
140  .def_readwrite("sparse_model", &DismecTrainingConfig::Sparse)
141  .def_readwrite("weighting", &DismecTrainingConfig::Weighting)
142  .def_readwrite("loss", &DismecTrainingConfig::Loss);
143 
144  py::enum_<LossType>(m, "LossType")
145  .value("SquaredHinge", LossType::SQUARED_HINGE)
146  .value("Hinge", LossType::HINGE)
147  .value("Logistic", LossType::LOGISTIC)
148  .value("HuberHinge", LossType::HUBER_HINGE);
149 
150 
151  /*
152  std::shared_ptr<postproc::PostProcessFactory> PostProcessing;
153  std::shared_ptr<TrainingStatsGatherer> StatsGatherer;
154  */
155 
156 
157 
158  m.def("parallel_train", [](const PyDataSet& data, const py::dict& hyper_params,
159  const DismecTrainingConfig& config, long label_begin,
160  long label_end, long threads) -> py::dict
161  {
162  HyperParameters hps;
163  for (auto item : hyper_params)
164  {
165  if(pybind11::isinstance<pybind11::int_>(item.second)) {
166  hps.set(item.first.cast<std::string>(), item.second.cast<long>());
167  } else {
168  hps.set(item.first.cast<std::string>(), item.second.cast<double>());
169  }
170  }
171 
172  auto spec = create_dismec_training(data, hps, config);
173 
174  parallel::ParallelRunner runner(threads);
175  runner.set_logger(spdlog::default_logger());
176  // TODO give more detailled result
177  auto result = run_training(runner, spec, label_id_t{label_begin}, label_id_t{label_end});
178  py::dict rdict;
179  rdict["loss"] = result.TotalLoss;
180  rdict["grad"] = result.TotalGrad;
181  rdict["finished"] = result.IsFinished;
182  rdict["model"] = PyModel(std::move(result.Model));
183  return rdict;
184  }, py::arg("data"), py::arg("hyperparameters"), py::arg("spec"), py::arg("label_begin") = 0, py::arg("label_end") = -1,
185  py::arg("num_threads") = -1, py::call_guard<py::gil_scoped_release>());
186  // TODO check constness and lifetime of returns
187 }
std::shared_ptr< dismec::DatasetBase > PyDataSet
Definition: binding.h:81
PyWrapper< dismec::model::Model > PyModel
Definition: binding.h:83
std::shared_ptr< dismec::WeightingScheme > PyWeighting
Definition: binding.h:82
This class represents a set of hyper-parameters.
Definition: hyperparams.h:241
Base class for label-based weighting schemes.
Definition: weighting.h:32
virtual double get_positive_weight(label_id_t label_id) const =0
Gets the weight to use for all examples where the label label_id is present.
virtual double get_negative_weight(label_id_t label_id) const =0
Gets the weight to use for all examples where the label label_id is absent.
Strong typedef for an int to signify a label id.
Definition: types.h:20
void set_logger(std::shared_ptr< spdlog::logger > logger)
Sets the logger object used for reporting. Pass nullptr for quiet mode.
Definition: runner.cpp:28
auto get_negative_weight(const WeightingScheme &pds, long label)
Definition: py_train.cpp:67
PyWeighting make_constant(double pos, double neg)
Definition: py_train.cpp:71
PyWeighting make_propensity(const DatasetBase &data, double a, double b)
Definition: py_train.cpp:74
auto get_positive_weight(const WeightingScheme &pds, long label)
Definition: py_train.cpp:64
PyWeighting make_custom(DenseRealVector pos, DenseRealVector neg)
Definition: py_train.cpp:77
std::shared_ptr< WeightInitializationStrategy > create_zero_initializer()
Creates an initialization strategy that initializes all weight vectors to zero.
Definition: zero.cpp:33
std::shared_ptr< WeightInitializationStrategy > create_feature_mean_initializer(std::shared_ptr< DatasetBase > data, real_t pos=1, real_t neg=-2)
Creates an initialization strategy based on the mean of positive and negative features.
Definition: msi.cpp:90
std::shared_ptr< WeightInitializationStrategy > create_multi_pos_mean_strategy(std::shared_ptr< DatasetBase > data, int max_pos, real_t pos=1, real_t neg=-2)
Creates an initialization strategy based on the mean of positive and negative features.
Definition: multi_pos.cpp:212
std::shared_ptr< WeightInitializationStrategy > create_constant_initializer(DenseRealVector vec)
Definition: constant.cpp:56
std::shared_ptr< WeightInitializationStrategy > create_ova_primal_initializer(const std::shared_ptr< DatasetBase > &data, RegularizerSpec regularizer, LossType loss)
Definition: ova-primal.cpp:15
FactoryPtr create_culling(real_t eps)
Definition: postproc.cpp:54
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
TrainingResult run_training(parallel::ParallelRunner &runner, std::shared_ptr< TrainingSpec > spec, label_id_t begin_label=label_id_t{0}, label_id_t end_label=label_id_t{-1})
Definition: training.cpp:122
std::variant< objective::SquaredNormConfig, objective::HuberConfig, objective::ElasticConfig > RegularizerSpec
Definition: spec.h:143
types::DenseVector< real_t > DenseRealVector
Any dense, real values vector.
Definition: matrix_types.h:40
LossType
Definition: spec.h:129
std::shared_ptr< TrainingSpec > create_dismec_training(std::shared_ptr< const DatasetBase > data, HyperParameters params, DismecTrainingConfig config)
Definition: dismec.cpp:157
float real_t
The default type for floating point values.
Definition: config.h:17
void register_regularizers(pybind11::module_ &root)
Definition: py_train.cpp:23
void register_weighting(pybind11::module_ &m)
Definition: py_train.cpp:82
void register_training(pybind11::module_ &m)
Definition: py_train.cpp:123
void register_init(pybind11::module_ &root)
Definition: py_train.cpp:96
RegularizerSpec Regularizer
Definition: spec.h:151
std::shared_ptr< WeightingScheme > Weighting
Definition: spec.h:146