9 #include "pybind11/pybind11.h"
10 #include "pybind11/eigen.h"
11 #include "pybind11/stl.h"
25 namespace py = pybind11;
31 #define PY_PROPERTY(type, function) \
32 def_property(#function, [](const type& pds){ return pds->function(); } , nullptr)
44 m.def(
"load_predictions", [](
const std::string& file_name) {
49 py::class_<PyModel>(m,
"Model")
50 .PY_PROPERTY(
PyModel, num_labels)
51 .PY_PROPERTY(
PyModel, num_features)
52 .PY_PROPERTY(
PyModel, num_weights)
53 .PY_PROPERTY(
PyModel, has_sparse_weights)
54 .PY_PROPERTY(
PyModel, is_partial_model)
55 .def_property(
"labels_begin", [](
const PyModel& pds) {
return pds.
access().labels_begin().to_index(); } ,
nullptr)
56 .def_property(
"labels_end", [](
const PyModel& pds){
return pds.
access().labels_end().to_index(); } ,
nullptr)
57 .def(
"get_weights_for_label", [](
const PyModel& model,
long label){
65 .def(
"predict_scores", [](
const PyModel& model,
const Eigen::Ref<
const types::DenseColMajor<real_t>>& instances) {
77 m.def(
"load_model", [](
const std::string& file_name) ->
PyModel {
79 }, py::arg(
"file_name"), py::call_guard<py::gil_scoped_release>());
81 py::class_<PySaver>(m,
"ModelSaver")
82 .def(py::init([](std::string_view path, std::string_view format,
int precision,
double culling,
bool load_partial) {
88 }), py::arg(
"path"), py::arg(
"format"), py::arg(
"precision") = 6, py::arg(
"culling")=0.0,
89 py::arg(
"load_partial")=
false)
91 .def(
"add_model", [](
PySaver& saver,
const PyModel& model, std::optional<std::string> target_file) {
92 auto saved = saver.
access().add_model(model.
ptr(), std::move(target_file));
96 result_dict[
"count"] = entry.
Count;
97 result_dict[
"file"] = entry.
FileName;
99 saver.
access().update_meta_file();
101 }, py::arg(
"model"), py::arg(
"target_path") = std::nullopt, py::call_guard<py::gil_scoped_release>())
102 .def(
"add_meta", [](
PySaver& saver, py::dict data) {
105 data[
"count"].cast<
long>(),
106 data[
"file"].cast<std::string>(),
108 saver.
access().insert_sub_file(entry);
109 saver.
access().update_meta_file();
111 .def(
"get_missing_weights", [](
PySaver& saver) {
112 auto interval = saver.
access().get_missing_weights();
113 return std::make_pair(interval.first.to_index(), interval.second.to_index());
115 .def(
"any_weight_vector_for_interval", [](
PySaver& saver,
int begin,
int end) {
Utility class used to wrap all objects we provide to python.
const std::shared_ptr< T > & ptr() const
Manage saving a model consisting of multiple partial models.
Strong typedef for an int to signify a label id.
constexpr T to_index() const
! Explicitly convert to an integer.
constexpr double precision(const ConfusionMatrixBase< T > &matrix)
std::shared_ptr< Model > load_model(path source)
WeightFormat parse_weights_format(std::string_view name)
Gets the eighs.
const char * to_string(WeightFormat format)
std::pair< IndexMatrix, PredictionMatrix > read_sparse_prediction(std::istream &source)
Reads sparse predictions as saved by save_sparse_predictions().
Main namespace in which all types, classes, and functions are defined.
types::DenseVector< real_t > DenseRealVector
Any dense, real values vector.
types::SparseRowMajor< real_t > SparseFeatures
Sparse Feature Matrix in Row Major format.
types::DenseRowMajor< real_t > PredictionMatrix
Dense matrix in Row Major format used for predictions.
#define PY_PROPERTY(type, function)
PYBIND11_MODULE(pydismec, m)
void register_training(pybind11::module_ &m)
void register_dataset(pybind11::module_ &m)
WeightFormat Format
Format in which the weights will be saved.
double Culling
If saving in sparse mode, threshold below which weights will be omitted.
int Precision
Precision with which the labels will be saved.
Collect the data about a weight file.