DiSMEC++
pybind.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #include <utility>
7 
8 #include "python/binding.h"
9 #include "pybind11/pybind11.h"
10 #include "pybind11/eigen.h"
11 #include "pybind11/stl.h"
12 #include "data/data.h"
13 #include "model/model.h"
14 
15 #include "io/xmc.h"
16 #include "io/model-io.h"
17 #include "io/prediction.h"
18 
19 #include "training/weighting.h"
20 #include "training/training.h"
21 #include "training/initializer.h"
22 
23 #include "parallel/runner.h"
24 
25 namespace py = pybind11;
26 using namespace dismec;
27 
30 
// PY_PROPERTY(type, function): for use inside a pybind11 `class_` method chain,
// e.g. `.PY_PROPERTY(PyModel, num_labels)`. Registers a read-only Python property
// named `function` whose getter calls `self->function()` on the bound `type`
// object (the wrapper type is expected to forward `->` to the wrapped C++ object).
// `def_property_readonly` is the idiomatic spelling of `def_property(name, getter, nullptr)`.
#define PY_PROPERTY(type, function) \
def_property_readonly(#function, [](const type& self) { return self->function(); })
33 
34 void register_dataset(pybind11::module_& m);
35 void register_training(pybind11::module_& m);
36 
37 PYBIND11_MODULE(pydismec, m)
38 {
41 
42 
43  // predictions
44  m.def("load_predictions", [](const std::string& file_name) {
46  });
47 
48  // model
49  py::class_<PyModel>(m, "Model")
50  .PY_PROPERTY(PyModel, num_labels)
51  .PY_PROPERTY(PyModel, num_features)
52  .PY_PROPERTY(PyModel, num_weights)
53  .PY_PROPERTY(PyModel, has_sparse_weights)
54  .PY_PROPERTY(PyModel, is_partial_model)
55  .def_property("labels_begin", [](const PyModel& pds) { return pds.access().labels_begin().to_index(); } , nullptr)
56  .def_property("labels_end", [](const PyModel& pds){ return pds.access().labels_end().to_index(); } , nullptr)
57  .def("get_weights_for_label", [](const PyModel& model, long label){
58  DenseRealVector target(model.access().num_features());
59  model.access().get_weights_for_label(label_id_t{label}, target);
60  return target;
61  })
62  .def("set_weights_for_label", [](PyModel& model, long label, const DenseRealVector& dense_weights){
63  model.access().set_weights_for_label(label_id_t{label}, model::Model::WeightVectorIn{dense_weights});
64  })
65  .def("predict_scores", [](const PyModel& model, const Eigen::Ref<const types::DenseColMajor<real_t>>& instances) {
66  PredictionMatrix target(instances.rows(), model.access().num_weights());
67  model.access().predict_scores(model::Model::FeatureMatrixIn{instances}, target);
68  return target;
69  })
70  .def("predict_scores", [](const PyModel& model, const SparseFeatures& instances) {
71  PredictionMatrix target(instances.rows(), model.access().num_weights());
72  model.access().predict_scores(model::Model::FeatureMatrixIn(instances), target);
73  return target;
74  })
75  ;
76 
77  m.def("load_model", [](const std::string& file_name) -> PyModel {
78  return io::load_model(file_name);
79  }, py::arg("file_name"), py::call_guard<py::gil_scoped_release>());
80 
81  py::class_<PySaver>(m, "ModelSaver")
82  .def(py::init([](std::string_view path, std::string_view format, int precision, double culling, bool load_partial) {
83  io::SaveOption options;
84  options.Precision = precision;
85  options.Culling = culling;
86  options.Format = io::model::parse_weights_format(format);
87  return io::PartialModelSaver(path, options, load_partial);
88  }), py::arg("path"), py::arg("format"), py::arg("precision") = 6, py::arg("culling")=0.0,
89  py::arg("load_partial")=false)
90  .PY_PROPERTY(PySaver, num_labels)
91  .def("add_model", [](PySaver& saver, const PyModel& model, std::optional<std::string> target_file) {
92  auto saved = saver.access().add_model(model.ptr(), std::move(target_file));
93  py::dict result_dict;
94  io::model::WeightFileEntry entry = saved.get();
95  result_dict["first"] = entry.First.to_index();
96  result_dict["count"] = entry.Count;
97  result_dict["file"] = entry.FileName;
98  result_dict["format"] = to_string(entry.Format);
99  saver.access().update_meta_file();
100  return result_dict;
101  }, py::arg("model"), py::arg("target_path") = std::nullopt, py::call_guard<py::gil_scoped_release>())
102  .def("add_meta", [](PySaver& saver, py::dict data) {
104  label_id_t(data["first"].cast<long>()),
105  data["count"].cast<long>(),
106  data["file"].cast<std::string>(),
107  io::model::parse_weights_format(data["format"].cast<std::string>())};
108  saver.access().insert_sub_file(entry);
109  saver.access().update_meta_file();
110  })
111  .def("get_missing_weights", [](PySaver& saver) {
112  auto interval = saver.access().get_missing_weights();
113  return std::make_pair(interval.first.to_index(), interval.second.to_index());
114  })
115  .def("any_weight_vector_for_interval", [](PySaver& saver, int begin, int end) {
116  return saver.access().any_weight_vector_for_interval(label_id_t{begin}, label_id_t{end});
117  })
118  ;
119 }
Utility class used to wrap all objects we provide to Python.
Definition: binding.h:32
T & access()
Definition: binding.h:48
const std::shared_ptr< T > & ptr() const
Definition: binding.h:62
Manage saving a model consisting of multiple partial models.
Definition: model-io.h:236
Strong typedef for an int to signify a label id.
Definition: types.h:20
constexpr T to_index() const
Explicitly convert to an integer.
Definition: opaque_int.h:32
constexpr double precision(const ConfusionMatrixBase< T > &matrix)
std::shared_ptr< Model > load_model(path source)
Definition: model-io.cpp:334
WeightFormat parse_weights_format(std::string_view name)
Gets the weight format corresponding to the given name.
Definition: model-io.cpp:84
const char * to_string(WeightFormat format)
Definition: model-io.cpp:89
std::pair< IndexMatrix, PredictionMatrix > read_sparse_prediction(std::istream &source)
Reads sparse predictions as saved by save_sparse_predictions().
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
types::DenseVector< real_t > DenseRealVector
Any dense, real-valued vector.
Definition: matrix_types.h:40
types::SparseRowMajor< real_t > SparseFeatures
Sparse Feature Matrix in Row Major format.
Definition: matrix_types.h:50
types::DenseRowMajor< real_t > PredictionMatrix
Dense matrix in Row Major format used for predictions.
Definition: matrix_types.h:75
#define PY_PROPERTY(type, function)
Definition: pybind.cpp:31
PYBIND11_MODULE(pydismec, m)
Definition: pybind.cpp:37
void register_training(pybind11::module_ &m)
Definition: py_train.cpp:123
void register_dataset(pybind11::module_ &m)
Definition: py_data.cpp:55
WeightFormat Format
Format in which the weights will be saved.
Definition: model-io.h:115
double Culling
If saving in sparse mode, threshold below which weights will be omitted.
Definition: model-io.h:113
int Precision
Precision with which the weights will be saved.
Definition: model-io.h:112
Collect the data about a weight file.
Definition: model-io.h:139