DiSMEC++
submodel.h
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #ifndef DISMEC_SUBMODEL_H
7 #define DISMEC_SUBMODEL_H
8 
9 #include "model.h"
10 
11 namespace dismec::model {
12  template<class T>
13  class SubModelWrapper : public Model {
14  using model_t = std::remove_reference_t<decltype(*std::declval<T>())>;
15  static_assert(std::is_convertible_v<std::remove_cv_t<model_t>&, Model&>, "T should be like a pointer to Model");
16  public:
17  SubModelWrapper(T original, label_id_t begin, label_id_t end) :
18  Model(PartialModelSpec{begin, end - begin, original->num_labels()}), m_Original(original)
19  {
20  }
21 
22  [[nodiscard]] long num_features() const override { return m_Original->num_features(); }
23  [[nodiscard]] bool has_sparse_weights() const override { return m_Original->has_sparse_weights(); }
24 
25  void get_weights_for_label_unchecked(label_id_t label, Eigen::Ref<DenseRealVector> target) const override {
26  // we cannot directly call the _unchecked method, so we have to undo the label correction.
27  return m_Original->get_weights_for_label(labels_begin() + label.to_index(), target);
28  }
29  void set_weights_for_label_unchecked(label_id_t label, const GenericInVector& weights) override {
30  // we cannot directly call the _unchecked method, so we have to undo the label correction.
31  if constexpr (std::is_const_v<model_t>) {
32  throw std::logic_error("Cannot set weights for constant sub-model");
33  } else {
34  m_Original->set_weights_for_label(labels_begin() + label.to_index(), weights);
35  }
36  }
37 
38  void predict_scores_unchecked(const GenericInMatrix& instances, PredictionMatrixOut target) const override {
39  throw std::logic_error("Cannot predict from model view");
40  }
41  private:
43  };
44 
47 }
48 
49 #endif //DISMEC_SUBMODEL_H
Strong typedef for an int to signify a label id.
Definition: types.h:20
A model combines a set of weights with some meta-information about these weights.
Definition: model.h:63
long num_labels() const noexcept
How many labels are in the underlying dataset.
Definition: model.h:78
Eigen::Ref< PredictionMatrix > PredictionMatrixOut
Definition: model.h:65
label_id_t labels_begin() const noexcept
Definition: model.h:98
void get_weights_for_label_unchecked(label_id_t label, Eigen::Ref< DenseRealVector > target) const override
Unchecked version of get_weights_for_label().
Definition: submodel.h:25
bool has_sparse_weights() const override
Whether this model stores the weights in a sparse format or a dense format.
Definition: submodel.h:23
void predict_scores_unchecked(const GenericInMatrix &instances, PredictionMatrixOut target) const override
Unchecked version of predict_scores().
Definition: submodel.h:38
SubModelWrapper(T original, label_id_t begin, label_id_t end)
Definition: submodel.h:17
std::remove_reference_t< decltype(*std::declval< T >())> model_t
Definition: submodel.h:14
void set_weights_for_label_unchecked(label_id_t label, const GenericInVector &weights) override
Unchecked version of set_weights_for_label().
Definition: submodel.h:29
long num_features() const override
How many weights are in each weight vector, i.e. how many features the input should have.
Definition: submodel.h:22
constexpr T to_index() const
Explicitly convert to an integer.
Definition: opaque_int.h:32
Specifies how to interpret a weight matrix for a partial model.
Definition: model.h:22