DiSMEC++
data.h
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #ifndef DISMEC_DATA_H
7 #define DISMEC_DATA_H
8 
9 #include <memory>
10 #include "matrix_types.h"
11 #include "data/types.h"
12 #include "utils/eigen_generic.h"
13 
14 namespace dismec {
15  class DatasetBase {
16  public:
17  virtual ~DatasetBase() = default;
18  DatasetBase(const DatasetBase&) = default;
19  DatasetBase(DatasetBase&&) = default;
21  DatasetBase& operator=(const DatasetBase&) = default;
22 
24  [[nodiscard]] std::shared_ptr<const GenericFeatureMatrix> get_features() const;
25 
27  [[nodiscard]] std::shared_ptr<GenericFeatureMatrix> edit_features();
28 
30  [[nodiscard]] long num_features() const noexcept;
31 
33  [[nodiscard]] long num_examples() const noexcept;
34 
37  [[nodiscard]] virtual long num_labels() const noexcept = 0;
38 
41  [[nodiscard]] virtual long num_positives(label_id_t id) const;
42 
45  [[nodiscard]] virtual long num_negatives(label_id_t id) const;
46 
49  [[nodiscard]] std::shared_ptr<const BinaryLabelVector> get_labels(label_id_t id) const;
50 
54  virtual void get_labels(label_id_t id, Eigen::Ref<BinaryLabelVector> target) const = 0;
55  protected:
56  explicit DatasetBase(SparseFeatures x);
57  explicit DatasetBase(DenseFeatures x);
58 
59  // features
60  std::shared_ptr<GenericFeatureMatrix> m_Features;
61  };
62 
69  class BinaryData : public DatasetBase {
70  public:
71  BinaryData(SparseFeatures x, std::shared_ptr<BinaryLabelVector> y) :
72  DatasetBase(std::move(x)), m_Labels(std::move(y))
73  {
74 
75  }
76 
77  [[nodiscard]] long num_labels() const noexcept override;
78  void get_labels(label_id_t i, Eigen::Ref<BinaryLabelVector> target) const override;
79  private:
80 
81  // targets
82  std::shared_ptr<BinaryLabelVector> m_Labels;
83  };
84 
85 
86  class MultiLabelData : public DatasetBase {
87  public:
88  MultiLabelData(SparseFeatures x, std::vector<std::vector<long>> y) :
89  DatasetBase(x.markAsRValue()), m_Labels(std::move(y)) {
90 
91  }
92 
93  MultiLabelData(DenseFeatures x, std::vector<std::vector<long>> y) :
94  DatasetBase(std::move(x)), m_Labels(std::move(y)) {
95  }
96 
97  [[nodiscard]] long num_labels() const noexcept override;
98  void get_labels(label_id_t label, Eigen::Ref<BinaryLabelVector> target) const override;
99 
100  // these are faster than the default implementation
101  [[nodiscard]] long num_positives(label_id_t id) const override;
102  [[nodiscard]] long num_negatives(label_id_t id) const override;
103 
104  [[nodiscard]] const std::vector<long>& get_label_instances(label_id_t label) const;
105 
106  void select_labels(label_id_t start, label_id_t end);
107 
108  [[nodiscard]] const std::vector<std::vector<long>>& all_labels() const { return m_Labels; }
109  private:
110  // targets: vector of vectors of example ids: if label i is present in example j, `j \in then m_Labels[i]`
111  std::vector<std::vector<long>> m_Labels;
112  };
113 }
114 
115 #endif //DISMEC_DATA_H
Collects the data related to a single optimization problem.
Definition: data.h:69
BinaryData(SparseFeatures x, std::shared_ptr< BinaryLabelVector > y)
Definition: data.h:71
virtual long num_negatives(label_id_t id) const
Definition: data.cpp:17
long num_examples() const noexcept
Get the total number of instances, i.e. the number of rows in the feature matrix.
Definition: data.cpp:52
DatasetBase & operator=(const DatasetBase &)=default
DatasetBase(DatasetBase &&)=default
DatasetBase(const DatasetBase &)=default
std::shared_ptr< const BinaryLabelVector > get_labels(label_id_t id) const
Definition: data.cpp:21
virtual long num_positives(label_id_t id) const
Definition: data.cpp:13
std::shared_ptr< const GenericFeatureMatrix > get_features() const
Get a shared pointer to the (immutable) feature data.
Definition: data.cpp:39
std::shared_ptr< GenericFeatureMatrix > m_Features
Definition: data.h:60
DatasetBase & operator=(DatasetBase &&)=default
virtual long num_labels() const noexcept=0
long num_features() const noexcept
Get the total number of features, i.e. the number of columns in the feature matrix.
Definition: data.cpp:48
virtual ~DatasetBase()=default
std::shared_ptr< GenericFeatureMatrix > edit_features()
Get a shared pointer to mutable feature data. Use with care.
Definition: data.cpp:43
std::vector< std::vector< long > > m_Labels
Definition: data.h:111
MultiLabelData(SparseFeatures x, std::vector< std::vector< long >> y)
Definition: data.h:88
MultiLabelData(DenseFeatures x, std::vector< std::vector< long >> y)
Definition: data.h:93
Strong typedef for an int to signify a label id.
Definition: types.h:20
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
types::DenseRowMajor< real_t > DenseFeatures
Dense Feature Matrix in Row Major format.
Definition: matrix_types.h:58
types::DenseVector< std::int8_t > BinaryLabelVector
Dense vector for storing binary labels.
Definition: matrix_types.h:68
types::SparseRowMajor< real_t > SparseFeatures
Sparse Feature Matrix in Row Major format.
Definition: matrix_types.h:50