DiSMEC++
dense.h
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #ifndef DISMEC_DENSE_H
7 #define DISMEC_DENSE_H
8 
9 #include "model/model.h"
10 
11 namespace dismec::model {
12 
16  class DenseModel : public Model // Model implementation that stores all label weights as a single, dense matrix
17  {
18  public:
20  using weight_matrix_ptr = std::shared_ptr<WeightMatrix>; // shared handle to the weights; WeightMatrix is types::DenseColMajor<real_t> (dense.h:19, elided from this listing)
21 
29  explicit DenseModel(const weight_matrix_ptr& weights); // creates a (complete) dense model with the given weight matrix (dense.cpp:31)
30 
40 
42 
44 
46  [[nodiscard]] bool has_sparse_weights() const final { return false; } // always false: a dense model never has sparse weights; `final` forbids further overriding
47 
48  [[nodiscard]] long num_features() const override; // weights per label vector, i.e. how many features the input must have (dense.cpp:59)
49 
51  [[nodiscard]] const WeightMatrix& get_raw_weights() const { return *m_Weights; } // read-only access to the raw weight matrix; NOTE(review): dereferences m_Weights unconditionally, presumably relying on the constructor to reject null — confirm in dense.cpp
52 
53  private:
54  void get_weights_for_label_unchecked(label_id_t label, Eigen::Ref<DenseRealVector> target) const override; // unchecked version of get_weights_for_label() (dense.cpp:64)
55 
56  void set_weights_for_label_unchecked(label_id_t label, const WeightVectorIn& weights) override; // unchecked version of set_weights_for_label() (dense.cpp:69)
57 
58  void predict_scores_unchecked(const FeatureMatrixIn& instances,
59  PredictionMatrixOut target) const override; // unchecked version of predict_scores() (dense.cpp:76)
60 
66  }; // member m_Weights ("the matrix of weights", dense.h:65) is elided from this listing
67 
68 } // namespace dismec::model
69 
70 #endif //DISMEC_DENSE_H
Strong typedef for an int to signify a label id.
Definition: types.h:20
Implementation of the Model class that stores the weights as a single, dense matrix.
Definition: dense.h:17
bool has_sparse_weights() const final
A dense model doesn't have sparse weights.
Definition: dense.h:46
types::DenseColMajor< real_t > WeightMatrix
Definition: dense.h:19
weight_matrix_ptr m_Weights
The matrix of weights.
Definition: dense.h:65
const WeightMatrix & get_raw_weights() const
Provides read-only access to the raw weight matrix.
Definition: dense.h:51
void set_weights_for_label_unchecked(label_id_t label, const WeightVectorIn &weights) override
Unchecked version of set_weights_for_label().
Definition: dense.cpp:69
long num_features() const override
How many weights are in each weight vector, i.e., how many features the input should have.
Definition: dense.cpp:59
void predict_scores_unchecked(const FeatureMatrixIn &instances, PredictionMatrixOut target) const override
Unchecked version of predict_scores().
Definition: dense.cpp:76
void get_weights_for_label_unchecked(label_id_t label, Eigen::Ref< DenseRealVector > target) const override
Unchecked version of get_weights_for_label().
Definition: dense.cpp:64
std::shared_ptr< WeightMatrix > weight_matrix_ptr
Definition: dense.h:20
DenseModel(const weight_matrix_ptr &weights)
Creates a (complete) dense model with the given weight matrix.
Definition: dense.cpp:31
A model combines a set of weights with some meta-information about these weights.
Definition: model.h:63
GenericInMatrix FeatureMatrixIn
Definition: model.h:66
long num_labels() const noexcept
How many labels are in the underlying dataset.
Definition: model.h:78
Eigen::Ref< PredictionMatrix > PredictionMatrixOut
Definition: model.h:65
GenericInVector WeightVectorIn
Definition: model.h:67
outer_const< T, dense_col_major_h > DenseColMajor
Definition: type_helpers.h:46
Specifies how to interpret a weight matrix for a partial model.
Definition: model.h:22