DiSMEC++
linear.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #include "linear.h"
7 #include "utils/eigen_generic.h"
8 #include "utils/throw_error.h"
9 #include "stats/timer.h"
10 
11 using namespace dismec;
12 using namespace dismec::objective;
13 
14 namespace {
16 }
17 
18 LinearClassifierBase::LinearClassifierBase(std::shared_ptr<const GenericFeatureMatrix> X) :
19  m_FeatureMatrix( std::move(X) ),
20  m_X_times_w( m_FeatureMatrix->rows() ),
21  m_LsCache_xTd( m_FeatureMatrix->rows() ),
22  m_LsCache_xTw( m_FeatureMatrix->rows() ),
23  m_Costs( m_FeatureMatrix->rows() ),
24  m_Y( m_FeatureMatrix->rows() )
25 {
26  m_Costs.fill(1);
27  declare_stat(STAT_PERF_MATMUL, {"perf_matmul", "µs"});
28 }
29 
30 
31 long LinearClassifierBase::num_instances() const noexcept {
32  return m_FeatureMatrix->rows();
33 }
34 
35 long LinearClassifierBase::num_variables() const noexcept {
36  return m_FeatureMatrix->cols();
37 }
38 
40  return m_FeatureMatrix->dense();
41 }
42 
44  return m_FeatureMatrix->sparse();
45 }
46 
48  return *m_FeatureMatrix;
49 }
50 
52  if(w.hash() == m_Last_W) {
53  return m_X_times_w;
54  }
55  auto timer = make_timer(STAT_PERF_MATMUL);
56  visit([&](auto&& features) {
57  m_X_times_w.noalias() = features * w;
58  }, *m_FeatureMatrix);
59  m_Last_W = w.hash();
60  return m_X_times_w;
61 }
62 
64  visit([&](auto&& features) {
65  m_LsCache_xTd.noalias() = features * direction;
66  }, *m_FeatureMatrix);
67  m_LsCache_xTw = x_times_w(location);
68 }
69 
72  return m_Y;
73 }
74 
76  for(int i = 0; i < m_Costs.size(); ++i) {
77  if(m_Y.coeff(i) == 1) {
78  m_Costs.coeffRef(i) = positive;
79  } else {
80  m_Costs.coeffRef(i) = negative;
81  }
82  }
83 }
84 
86  return m_Costs;
87 }
88 
90  return m_Y;
91 }
An Eigen vector with versioning information, to implement simple caching of results.
Definition: hash_vector.h:43
VectorHash hash() const
Gets the unique id of this vector.
Definition: hash_vector.cpp:45
const SparseFeatures & sparse_features() const
Definition: linear.cpp:43
std::shared_ptr< const GenericFeatureMatrix > m_FeatureMatrix
Definition: linear.h:90
DenseRealVector m_X_times_w
cache for the last result of x_times_w() corresponding to m_Last_W.
Definition: linear.h:95
long num_instances() const noexcept
Definition: linear.cpp:31
VectorHash m_Last_W
cache for the last argument to x_times_w().
Definition: linear.h:93
long num_variables() const noexcept override
Definition: linear.cpp:35
LinearClassifierBase(std::shared_ptr< const GenericFeatureMatrix > X)
Definition: linear.cpp:18
const DenseRealVector & costs() const
Definition: linear.cpp:85
const DenseFeatures & dense_features() const
Definition: linear.cpp:39
DenseRealVector m_Costs
Label-Dependent costs.
Definition: linear.h:103
DenseRealVector m_LsCache_xTw
cache for line search implementation: feature times weights
Definition: linear.h:100
BinaryLabelVector m_Y
Label vector – use a vector of ints here. We encode label present == 1, absent == -1.
Definition: linear.h:106
void update_costs(real_t positive, real_t negative)
Definition: linear.cpp:75
void project_linear_to_line(const HashVector &location, const DenseRealVector &direction)
Prepares the cache variables for line projection.
Definition: linear.cpp:63
DenseRealVector m_LsCache_xTd
cache for line search implementation: feature times direction
Definition: linear.h:98
BinaryLabelVector & get_label_ref()
Definition: linear.cpp:70
const BinaryLabelVector & labels() const
Definition: linear.cpp:89
const GenericFeatureMatrix & generic_features() const
Definition: linear.cpp:47
const DenseRealVector & x_times_w(const HashVector &w)
Calculates the vector of feature matrix times weights w
Definition: linear.cpp:51
auto make_timer(stat_id_t id, Args... args)
Creates a new ScopeTimer using stats::record_scope_time.
Definition: tracked.h:130
void declare_stat(stat_id_t index, StatisticMetaData meta)
Declares a new statistic. This function just forwards all its arguments to the internal StatisticsCollection.
Definition: tracked.cpp:16
constexpr const dismec::stats::stat_id_t STAT_PERF_MATMUL
Definition: linear.cpp:15
auto visit(F &&f, Variants &&... variants)
Definition: eigen_generic.h:95
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
types::DenseRowMajor< real_t > DenseFeatures
Dense Feature Matrix in Row Major format.
Definition: matrix_types.h:58
types::DenseVector< std::int8_t > BinaryLabelVector
Dense vector for storing binary labels.
Definition: matrix_types.h:68
types::DenseVector< real_t > DenseRealVector
Any dense, real values vector.
Definition: matrix_types.h:40
types::SparseRowMajor< real_t > SparseFeatures
Sparse Feature Matrix in Row Major format.
Definition: matrix_types.h:50
float real_t
The default type for floating point values.
Definition: config.h:17