DiSMEC++
model.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #include "model/model.h"
7 #include "spdlog/fmt/fmt.h"
8 #include "utils/eigen_generic.h"
9 
10 using namespace dismec;
11 using namespace dismec::model;
12 
// Body of the Model(PartialModelSpec) constructor: the first/one-past-last label ids and the
// total label count are initialised from `spec`, then validated.
// NOTE(review): the embedded listing numbers skip 13, 20 and 23, so the constructor header,
// the condition guarding the second throw, and that throw's format arguments are not shown.
14  m_LabelsBegin(spec.first_label), m_LabelsEnd(spec.first_label + spec.label_count), m_NumLabels(spec.total_labels)
15 {
 // The declared total number of labels has to be strictly positive.
16  if(m_NumLabels <= 0) {
17  throw std::invalid_argument( fmt::format("Total number of labels must be positive! Got {}.", m_NumLabels) );
18  }
19 
 // NOTE(review): the two adjacent string literals below are concatenated without a separating
 // space, so the emitted message contains `labelswas`. A trailing space on the first literal
 // looks intended -- confirm and fix in the original source.
21  throw std::invalid_argument( fmt::format("Invalid label range [{}, {}) specified. Total number of labels"
22  "was declared as {}.",
24  }
25 }
26 
27 
// Body of Model::adjust_label(label_id_t label) const (the header line is missing from this listing).
// Converts `label` into an id relative to labels_begin().
// Throws std::out_of_range if `label` is outside the half-open range [labels_begin(), labels_end()).
29  if(label < labels_begin() || label >= labels_end()) {
30  throw std::out_of_range(
31  fmt::format("label index {} is invalid. Labels must be in [{}, {})",
32  label.to_index(), labels_begin().to_index(), labels_end().to_index()));
33  }
 // Offset from the first label of this model, wrapped back into a label_id_t.
34  return label_id_t{label - labels_begin()};
35 }
36 
39 }
40 
// Checked entry point for reading the weights of `label` into `target`.
// Throws std::invalid_argument if `target` does not have exactly num_features() entries.
41 void Model::get_weights_for_label(label_id_t label, Eigen::Ref<DenseRealVector> target) const {
 // The output buffer must have one entry per feature.
42  if(target.size() != num_features()) {
43  throw std::invalid_argument(
44  fmt::format("target size {} does not match number of features {}.",
45  target.size(), num_features()));
46  }
47 
 // NOTE(review): the embedded listing numbers skip 48, so the statement that actually fills
 // `target` is not shown; presumably it forwards to get_weights_for_label_unchecked() -- confirm.
49 }
50 
// Body of Model::set_weights_for_label(label_id_t label, const WeightVectorIn& weights)
// (the header line is missing from this listing).
// Throws std::invalid_argument if `weights` does not have exactly num_features() entries.
52  if(weights.size() != num_features()) {
53  throw std::invalid_argument(
54  fmt::format("weight size {} does not match number of features {}.",
55  weights.size(), num_features()));
56  }
 // NOTE(review): the embedded listing numbers skip 57, so the statement that stores the weights
 // is not shown; presumably it forwards to set_weights_for_label_unchecked() -- confirm.
58 }
59 
60 void Model::predict_scores(const FeatureMatrixIn& instances, PredictionMatrixOut target) const {
61  // check number of instances
62  if(instances.rows() != target.rows()) {
63  throw std::logic_error(fmt::format("Mismatch in number of rows between instances ({}) and target ({})",
64  instances.rows(), target.rows()));
65  }
66 
67  // check number of labels
68  if(target.cols() != num_weights()) {
69  throw std::logic_error(
70  fmt::format("Wrong number of columns in target ({}). Expect one column for each of the {} labels.",
71  target.cols(), num_weights()));
72  }
73 
74  if(instances.cols() != num_features()) {
75  throw std::logic_error(
76  fmt::format("Wrong number of columns in instances ({}). Expect one column for each of the {} features.",
77  instances.cols(), num_features()));
78  }
79  predict_scores_unchecked(instances, target);
80 }
81 
Strong typedef for an int to signify a label id.
Definition: types.h:20
label_id_t labels_end() const noexcept
Definition: model.h:102
virtual long num_features() const =0
How many weights are in each weight vector, i.e. how many features should the input have.
Model(PartialModelSpec spec)
Definition: model.cpp:13
long num_labels() const noexcept
How many labels are in the underlying dataset.
Definition: model.h:78
virtual void get_weights_for_label_unchecked(label_id_t label, Eigen::Ref< DenseRealVector > target) const =0
Unchecked version of get_weights_for_label().
Eigen::Ref< PredictionMatrix > PredictionMatrixOut
Definition: model.h:65
virtual void predict_scores_unchecked(const FeatureMatrixIn &instances, PredictionMatrixOut target) const =0
Unchecked version of predict_scores().
void set_weights_for_label(label_id_t label, const WeightVectorIn &weights)
Sets the weights for a label.
Definition: model.cpp:51
bool is_partial_model() const
Returns true if this instance only stores part of the weights of an entire model.
Definition: model.cpp:37
label_id_t m_LabelsEnd
Definition: model.h:177
long num_weights() const noexcept
How many weight vectors are in this model.
Definition: model.h:87
label_id_t m_LabelsBegin
Definition: model.h:176
virtual void set_weights_for_label_unchecked(label_id_t label, const WeightVectorIn &weights)=0
Unchecked version of set_weights_for_label().
long m_NumLabels
Total number of labels of the complete model.
Definition: model.h:185
void get_weights_for_label(label_id_t label, Eigen::Ref< DenseRealVector > target) const
Gets the weights for the given label as a dense vector.
Definition: model.cpp:41
label_id_t labels_begin() const noexcept
Definition: model.h:98
void predict_scores(const FeatureMatrixIn &instances, PredictionMatrixOut target) const
Calculates the scores for all examples and all labels in this model.
Definition: model.cpp:60
label_id_t adjust_label(label_id_t label) const
Definition: model.cpp:28
constexpr T to_index() const
Explicitly convert to an integer.
Definition: opaque_int.h:32
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
Specifies how to interpret a weight matrix for a partial model.
Definition: model.h:22