DiSMEC++
sparse.h
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #ifndef DISMEC_SPARSE_H
7 #define DISMEC_SPARSE_H
8 
9 #include "model.h"
10 
11 namespace dismec::model {
// Model subclass that reports sparse weight storage: has_sparse_weights() returns true and
// the weights are held in a std::vector<SparseRealVector>.
// NOTE(review): the embedded line numbers of this listing jump from 13 to 16 and from 28 to 30,
// so some declarations of the original header are not shown here (presumably the
// constructors and one further private member) -- confirm against the real sparse.h.
12  class SparseModel : public Model {
13  public:
16 
// Number of weights in each weight vector, i.e. how many features the input should have.
17  [[nodiscard]] long num_features() const override;
18 
// This model stores its weights in a sparse format, so this always returns true.
19  [[nodiscard]] bool has_sparse_weights() const override { return true; }
20 
// Unchecked version of predict_scores().
// NOTE(review): "unchecked" presumably means argument validation is left to the caller -- confirm in model.h.
21  void predict_scores_unchecked(const FeatureMatrixIn& instances, PredictionMatrixOut target) const override;
22 
// Unchecked version of get_weights_for_label(); `target` is a dense vector reference.
23  void get_weights_for_label_unchecked(label_id_t label, Eigen::Ref<DenseRealVector> target) const override;
24 
// Unchecked version of set_weights_for_label().
25  void set_weights_for_label_unchecked(label_id_t label, const WeightVectorIn& weights) override;
26 
27  private:
// Sparse weight vectors. NOTE(review): presumably one entry per label -- confirm in sparse.cpp.
28  std::vector<SparseRealVector> m_Weights;
30  };
31 }
32 
33 #endif //DISMEC_SPARSE_H
Strong typedef for an int to signify a label id.
Definition: types.h:20
A model combines a set of weights with some meta-information about these weights.
Definition: model.h:63
GenericInMatrix FeatureMatrixIn
Definition: model.h:66
long num_labels() const noexcept
How many labels are in the underlying dataset.
Definition: model.h:78
Eigen::Ref< PredictionMatrix > PredictionMatrixOut
Definition: model.h:65
GenericInVector WeightVectorIn
Definition: model.h:67
SparseModel(long num_features, long num_labels)
Definition: sparse.cpp:24
long num_features() const override
How many weights are in each weight vector, i.e. how many features the input should have.
Definition: sparse.cpp:39
void predict_scores_unchecked(const FeatureMatrixIn &instances, PredictionMatrixOut target) const override
Unchecked version of predict_scores().
Definition: sparse.cpp:79
void set_weights_for_label_unchecked(label_id_t label, const WeightVectorIn &weights) override
Unchecked version of set_weights_for_label().
Definition: sparse.cpp:113
bool has_sparse_weights() const override
Whether this model stores the weights in a sparse format or a dense format.
Definition: sparse.h:19
std::vector< SparseRealVector > m_Weights
Definition: sparse.h:28
void get_weights_for_label_unchecked(label_id_t label, Eigen::Ref< DenseRealVector > target) const override
Unchecked version of get_weights_for_label().
Definition: sparse.cpp:84
Specifies how to interpret a weight matrix for a partial model.
Definition: model.h:22