DiSMEC++
dense.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 // We are deliberately doing things wrong here, and apparently that triggers a gcc warning.
7 #include "model/dense.h"
8 #include "spdlog/fmt/fmt.h"
9 #include "utils/eigen_generic.h"
10 
11 using namespace dismec;
12 using namespace dismec::model;
13 
14 namespace {
/// Validates that `v` is strictly positive.
/// \param v The value to check.
/// \param error_msg Message used for the exception if the check fails.
/// \return `v` unchanged, so the call can be used inline, e.g. in constructor initializers.
/// \throws std::invalid_argument (with `error_msg`) if `v <= 0`.
long check_positive(long v, const char* error_msg) {
    // Reject the invalid case first; the happy path simply passes the value through.
    if(v <= 0) {
        throw std::invalid_argument(error_msg);
    }
    return v;
}
29 }
30 
32  DenseModel(weights, {label_id_t{0}, weights->cols(), weights->cols()}) {
35 }
36 
38  Model(partial), m_Weights(std::move(weights))
39 {
40  if(m_Weights->cols() != partial.label_count) {
41  throw std::invalid_argument(fmt::format("Declared {} weights, but got matrix with {} columns",
42  partial.label_count, m_Weights->cols()));
43  }
44 }
45 
46 DenseModel::DenseModel(long num_features, long num_labels):
47  DenseModel(num_features, PartialModelSpec{label_id_t{0}, num_labels, num_labels})
48 {
49 }
50 
51 DenseModel::DenseModel(long num_features, PartialModelSpec partial) :
52  DenseModel(std::make_shared<WeightMatrix>(
53  check_positive(num_features, "Number of features must be positive!"),
54  check_positive(partial.label_count, "Number of weight must be positive!")),
55  partial)
56 {
57 }
58 
60  return m_Weights->rows();
61 }
62 
63 
64 void DenseModel::get_weights_for_label_unchecked(label_id_t label, Eigen::Ref<DenseRealVector> target) const
65 {
66  target = m_Weights->col(label.to_index());
67 }
68 
70 {
71  visit([this, label](auto&& v){
72  m_Weights->col(label.to_index()) = v;
73  }, weights);
74 }
75 
77  visit([&, this](const auto& features) {
78  target.noalias() = features * (*m_Weights);
79  }, instances);
80 }
Strong typedef for an int to signify a label id.
Definition: types.h:20
Implementation of the Model class that stores the weights as a single, dense matrix.
Definition: dense.h:17
types::DenseColMajor< real_t > WeightMatrix
Definition: dense.h:19
weight_matrix_ptr m_Weights
The matrix of weights.
Definition: dense.h:65
void set_weights_for_label_unchecked(label_id_t label, const WeightVectorIn &weights) override
Unchecked version of set_weights_for_label().
Definition: dense.cpp:69
long num_features() const override
How many weights are in each weight vector, i.e. how many features should the input have.
Definition: dense.cpp:59
void predict_scores_unchecked(const FeatureMatrixIn &instances, PredictionMatrixOut target) const override
Unchecked version of predict_scores().
Definition: dense.cpp:76
void get_weights_for_label_unchecked(label_id_t label, Eigen::Ref< DenseRealVector > target) const override
Unchecked version of get_weights_for_label().
Definition: dense.cpp:64
std::shared_ptr< WeightMatrix > weight_matrix_ptr
Definition: dense.h:20
DenseModel(const weight_matrix_ptr &weights)
Creates a (complete) dense model with the given weight matrix.
Definition: dense.cpp:31
A model combines a set of weights with some meta-information about these weights.
Definition: model.h:63
Eigen::Ref< PredictionMatrix > PredictionMatrixOut
Definition: model.h:65
constexpr T to_index() const
Explicitly convert to an integer.
Definition: opaque_int.h:32
long check_positive(long v, const char *error_msg)
Checks that v is positive and returns v.
Definition: dense.cpp:23
auto visit(F &&f, Variants &&... variants)
Definition: eigen_generic.h:95
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
Specifies how to interpret a weight matrix for a partial model.
Definition: model.h:22
long label_count
Number of labels in the partial model.
Definition: model.h:24