DiSMEC++
py_data.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
#include "binding.h"

#include <stdexcept>

#include "data/data.h"
#include "io/xmc.h"
#include "io/slice.h"

using namespace dismec;
using PyDataSet = std::shared_ptr<DatasetBase>;
14 
// Adds a read-only Python property to a py::class_<DatasetBase> binding. The property is
// named after `function` and its getter calls the const, argument-less member
// `DatasetBase::function()`. Used as a member call in a binding chain, e.g.
// `.PY_PROPERTY(num_features)`.
#define PY_PROPERTY(function) \
    def_property_readonly(#function, [](const DatasetBase& pds) { return pds.function(); })
17 
18 namespace {
19  auto num_positives(const DatasetBase& ds, long label) {
20  return ds.num_positives(label_id_t{label});
21  }
22  auto num_negatives(const DatasetBase& ds, long label) {
23  return ds.num_negatives(label_id_t{label});
24  }
25  auto get_labels(const DatasetBase& ds, long id) {
26  return *ds.get_labels(label_id_t{id});
27  }
28  auto get_features(const DatasetBase& ds) {
29  return ds.get_features()->unpack_variant();
30  }
32  (*ds.edit_features()) = GenericFeatureMatrix(std::move(features));
33  }
35  (*ds.edit_features()) = GenericFeatureMatrix(std::move(features));
36  }
37 
38  PyDataSet load_xmc(const std::filesystem::path& source_file, bool one_based_indexing) {
39  if(one_based_indexing) {
41  } else {
43  }
44  }
45 
46  void save_xmc(const std::filesystem::path& target_file, const DatasetBase& ds, int precision) {
47  io::save_xmc_dataset(target_file, dynamic_cast<const MultiLabelData&>(ds), precision);
48  }
49 
50  PyDataSet load_slice(const std::filesystem::path& features_file, const std::filesystem::path& labels_file) {
51  return wrap_shared(io::read_slice_dataset(features_file, labels_file));
52  }
53 }
54 
55 void register_dataset(pybind11::module_& m) {
56  // data set
57  py::class_<DatasetBase, PyDataSet>(m, "DataSet")
58  // we need to distinguish these two overloads by kwarg name I think, because otherwise we get implicit conversions between dense and sparse matrices
59  // by having these two signatures, we seem to prevent automatic conversions e.g. from double ndarray to float, and instead get an error message.
60  // same later on for set_features
61  .def(py::init([](SparseFeatures features, std::vector<std::vector<long>> positives) -> PyDataSet
62  { return std::make_shared<MultiLabelData>(std::move(features), std::move(positives)); }),
63  py::kw_only(), py::arg("sparse_features"), py::arg("positives")
64  )
65  .def(py::init([](DenseFeatures features, std::vector<std::vector<long>> positives) -> PyDataSet
66  { return std::make_shared<MultiLabelData>(std::move(features), std::move(positives)); }),
67  py::kw_only(), py::arg("dense_features"), py::arg("positives")
68  )
69  .PY_PROPERTY(num_features)
70  .PY_PROPERTY(num_examples)
71  .PY_PROPERTY(num_labels)
72  .def("num_positives", num_positives, py::arg("label_id"))
73  .def("num_negatives", num_negatives, py::arg("label_id"))
74  .def("get_labels", get_labels, py::arg("label_id"))
75  .def("get_features", get_features)
76  .def("set_features", set_features_sparse,
77  py::kw_only(),
78  py::arg("sparse_features"))
79  .def("set_features", set_features_dense,
80  py::kw_only(),
81  py::arg("dense_features"));
82 
83  // dataset io functions
84  m.def("load_xmc", load_xmc,
85  py::arg("source_file"), py::kw_only(),
86  py::arg("one_based_index") = false,
87  py::call_guard<py::gil_scoped_release>());
88 
89  m.def("save_xmc", save_xmc,
90  py::arg("file_name"), py::arg("dataset"),
91  py::kw_only(), py::arg("precision") = 4,
92  py::call_guard<py::gil_scoped_release>()
93  );
94 
95  m.def("load_slice", load_slice,
96  py::kw_only(),py::arg("features"), py::arg("labels"),
97  py::call_guard<py::gil_scoped_release>()
98  );
99 }
std::shared_ptr< T > wrap_shared(T &&source)
Definition: binding.h:75
std::shared_ptr< dismec::DatasetBase > PyDataSet
Definition: binding.h:81
virtual long num_negatives(label_id_t id) const
Definition: data.cpp:17
std::shared_ptr< const BinaryLabelVector > get_labels(label_id_t id) const
Definition: data.cpp:21
virtual long num_positives(label_id_t id) const
Definition: data.cpp:13
std::shared_ptr< const GenericFeatureMatrix > get_features() const
get a shared pointer to the (immutable) feature data
Definition: data.cpp:39
std::shared_ptr< GenericFeatureMatrix > edit_features()
get a shared pointer to mutable feature data. Use with care.
Definition: data.cpp:43
Strong typedef for an int to signify a label id.
Definition: types.h:20
auto num_positives(const DatasetBase &ds, long label)
Definition: py_data.cpp:19
auto get_features(const DatasetBase &ds)
Definition: py_data.cpp:28
auto get_labels(const DatasetBase &ds, long id)
Definition: py_data.cpp:25
auto set_features_sparse(DatasetBase &ds, SparseFeatures features)
Definition: py_data.cpp:31
void save_xmc(const std::filesystem::path &target_file, const DatasetBase &ds, int precision)
Definition: py_data.cpp:46
PyDataSet load_xmc(const std::filesystem::path &source_file, bool one_based_indexing)
Definition: py_data.cpp:38
auto num_negatives(const DatasetBase &ds, long label)
Definition: py_data.cpp:22
PyDataSet load_slice(const std::filesystem::path &features_file, const std::filesystem::path &labels_file)
Definition: py_data.cpp:50
auto set_features_dense(DatasetBase &ds, DenseFeatures features)
Definition: py_data.cpp:34
constexpr double precision(const ConfusionMatrixBase< T > &matrix)
constexpr T positives(const ConfusionMatrixBase< T > &matrix)
MultiLabelData read_xmc_dataset(const std::filesystem::path &source, IndexMode mode=IndexMode::ZERO_BASED)
Reads a dataset given in the extreme multilabel classification format.
Definition: xmc.cpp:216
MultiLabelData read_slice_dataset(std::istream &features, std::istream &labels)
reads a dataset given in slice format.
Definition: slice.cpp:36
void save_xmc_dataset(std::ostream &target, const MultiLabelData &data)
Saves the given dataset in XMC format.
Definition: xmc.cpp:294
IndexMode::ONE_BASED
labels and feature indices are 1, 2, ..., num
IndexMode::ZERO_BASED
labels and feature indices are 0, 1, ..., num - 1
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
types::DenseRowMajor< real_t > DenseFeatures
Dense Feature Matrix in Row Major format.
Definition: matrix_types.h:58
types::SparseRowMajor< real_t > SparseFeatures
Sparse Feature Matrix in Row Major format.
Definition: matrix_types.h:50
types::GenericMatrix< DenseFeatures, SparseFeatures > GenericFeatureMatrix
Definition: matrix_types.h:60
#define PY_PROPERTY(function)
Definition: py_data.cpp:15
void register_dataset(pybind11::module_ &m)
Definition: py_data.cpp:55