DiSMEC++
linear.h
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #ifndef DISMEC_LINEAR_H
7 #define DISMEC_LINEAR_H
8 
#include <stdexcept>

#include "objective.h"
#include "utils/hash_vector.h"
11 
12 namespace dismec::objective {
28  public:
29  LinearClassifierBase(std::shared_ptr<const GenericFeatureMatrix> X);
30 
31  [[nodiscard]] long num_instances() const noexcept;
32  [[nodiscard]] long num_variables() const noexcept override;
33 
34  [[nodiscard]] BinaryLabelVector& get_label_ref();
35  void update_costs(real_t positive, real_t negative);
36  protected:
45  const DenseRealVector& x_times_w(const HashVector& w);
46 
52  template<class Derived>
53  void update_xtw_cache(const HashVector& new_weight, const Eigen::MatrixBase<Derived>& new_result) {
54  // update the cached result to the new value
55  m_X_times_w.noalias() = new_result;
56  // and set the hash so that we can identify calls using the new weights
57  m_Last_W = new_weight.hash();
58  }
59 
70  void project_linear_to_line(const HashVector& location, const DenseRealVector& direction);
71 
72  [[nodiscard]] auto line_interpolation(real_t t) const {
73  return m_LsCache_xTw + t * m_LsCache_xTd;
74  }
75 
76  void declare_vector_on_last_line(const HashVector& location, real_t t) override {
78  }
79 
80  [[nodiscard]] const GenericFeatureMatrix& generic_features() const;
81  [[nodiscard]] const DenseFeatures& dense_features() const;
82  [[nodiscard]] const SparseFeatures& sparse_features() const;
83 
84  [[nodiscard]] const DenseRealVector& costs() const;
85  [[nodiscard]] const BinaryLabelVector& labels() const;
86  private:
90  std::shared_ptr<const GenericFeatureMatrix> m_FeatureMatrix;
91 
96 
101 
104 
107 
110  virtual void invalidate_labels() = 0;
111  };
112 
123  template<class Derived>
125  public:
126  LinearClassifierImpBase(std::shared_ptr<const GenericFeatureMatrix> X, std::unique_ptr<Objective> regularizer) :
127  LinearClassifierBase( std::move(X) ), m_Regularizer( std::move(regularizer) ) {};
128  protected:
129  const Derived& derived() const {
130  return static_cast<const Derived&>(*this);
131  }
132 
133  Derived& derived() {
134  return static_cast<Derived&>(*this);
135  }
136 
137  real_t value_unchecked(const HashVector& location) override {
138  const DenseRealVector& xTw = x_times_w(location);
139  return derived().value_from_xTw(costs(), labels(), xTw) + m_Regularizer->value(location);
140  }
141 
142  real_t lookup_on_line(real_t position) override {
144  return f + m_Regularizer->lookup_on_line(position);
145  }
146 
147  void project_to_line_unchecked(const HashVector& location, const DenseRealVector& direction) override {
148  project_linear_to_line(location, direction);
149  m_Regularizer->project_to_line(location, direction);
150  }
151 
152  void gradient_unchecked(const HashVector& location, Eigen::Ref<DenseRealVector> target) override {
153  m_Regularizer->gradient(location, target);
154  derived().gradient_imp(location, target);
155  }
156 
157 
158  void gradient_at_zero_unchecked(Eigen::Ref<DenseRealVector> target) override {
159  m_Regularizer->gradient_at_zero(target);
160  derived().gradient_at_zero_imp(target);
161  }
162 
164  const DenseRealVector& direction,
165  Eigen::Ref<DenseRealVector> target) override {
166  m_Regularizer->hessian_times_direction(location, direction, target);
167  derived().hessian_times_direction_imp(location, direction, target);
168  }
169 
170  void diag_preconditioner_unchecked(const HashVector& location, Eigen::Ref<DenseRealVector> target) override {
171  m_Regularizer->diag_preconditioner(location, target);
172  derived().diag_preconditioner_imp(location, target);
173  }
174 
175  void gradient_and_pre_conditioner_unchecked(const HashVector& location, Eigen::Ref<DenseRealVector> gradient,
176  Eigen::Ref<DenseRealVector> pre) override {
177  m_Regularizer->gradient_and_pre_conditioner(location, gradient, pre);
178  derived().gradient_and_pre_conditioner_imp(location, gradient, pre);
179  }
180  private:
182  std::unique_ptr<Objective> m_Regularizer;
183  };
184 }
185 
186 #endif //DISMEC_LINEAR_H
An Eigen vector with versioning information, to implement simple caching of results.
Definition: hash_vector.h:43
VectorHash hash() const
Gets the unique id of this vector.
Definition: hash_vector.cpp:45
A unique identifier for a HashVector.
Definition: hash_vector.h:118
Base class for objectives that use a linear classifier.
Definition: linear.h:27
const SparseFeatures & sparse_features() const
Definition: linear.cpp:43
std::shared_ptr< const GenericFeatureMatrix > m_FeatureMatrix
Definition: linear.h:90
DenseRealVector m_X_times_w
cache for the last result of x_times_w() corresponding to m_Last_W.
Definition: linear.h:95
long num_instances() const noexcept
Definition: linear.cpp:31
VectorHash m_Last_W
Cache for the hash of the last argument passed to x_times_w().
Definition: linear.h:93
long num_variables() const noexcept override
Definition: linear.cpp:35
LinearClassifierBase(std::shared_ptr< const GenericFeatureMatrix > X)
Definition: linear.cpp:18
const DenseRealVector & costs() const
Definition: linear.cpp:85
const DenseFeatures & dense_features() const
Definition: linear.cpp:39
void update_xtw_cache(const HashVector &new_weight, const Eigen::MatrixBase< Derived > &new_result)
Updates the cached value for x_times_w.
Definition: linear.h:53
DenseRealVector m_Costs
Label-Dependent costs.
Definition: linear.h:103
void declare_vector_on_last_line(const HashVector &location, real_t t) override
State that the given vector corresponds to a certain position on the line of the last line search.
Definition: linear.h:76
DenseRealVector m_LsCache_xTw
Cache for the line-search implementation: features times weights (X · w).
Definition: linear.h:100
BinaryLabelVector m_Y
Label vector, stored as small integers (std::int8_t). A present label is encoded as 1, an absent label as -1.
Definition: linear.h:106
void update_costs(real_t positive, real_t negative)
Definition: linear.cpp:75
auto line_interpolation(real_t t) const
Definition: linear.h:72
void project_linear_to_line(const HashVector &location, const DenseRealVector &direction)
Prepares the cache variables for line projection.
Definition: linear.cpp:63
DenseRealVector m_LsCache_xTd
Cache for the line-search implementation: features times direction (X · d).
Definition: linear.h:98
BinaryLabelVector & get_label_ref()
Definition: linear.cpp:70
const BinaryLabelVector & labels() const
Definition: linear.cpp:89
const GenericFeatureMatrix & generic_features() const
Definition: linear.cpp:47
const DenseRealVector & x_times_w(const HashVector &w)
Calculates the vector of feature matrix times weights w
Definition: linear.cpp:51
Implementation helper for linear classifier derived classes.
Definition: linear.h:124
void gradient_unchecked(const HashVector &location, Eigen::Ref< DenseRealVector > target) override
Definition: linear.h:152
void diag_preconditioner_unchecked(const HashVector &location, Eigen::Ref< DenseRealVector > target) override
Definition: linear.h:170
real_t value_unchecked(const HashVector &location) override
Definition: linear.h:137
void hessian_times_direction_unchecked(const HashVector &location, const DenseRealVector &direction, Eigen::Ref< DenseRealVector > target) override
Definition: linear.h:163
std::unique_ptr< Objective > m_Regularizer
Pointer to the regularizer.
Definition: linear.h:182
LinearClassifierImpBase(std::shared_ptr< const GenericFeatureMatrix > X, std::unique_ptr< Objective > regularizer)
Definition: linear.h:126
void gradient_and_pre_conditioner_unchecked(const HashVector &location, Eigen::Ref< DenseRealVector > gradient, Eigen::Ref< DenseRealVector > pre) override
Definition: linear.h:175
void project_to_line_unchecked(const HashVector &location, const DenseRealVector &direction) override
Definition: linear.h:147
void gradient_at_zero_unchecked(Eigen::Ref< DenseRealVector > target) override
Definition: linear.h:158
real_t lookup_on_line(real_t position) override
Looks up the value of the objective on the line defined by the last call to project_to_line().
Definition: linear.h:142
const Derived & derived() const
Definition: linear.h:129
Class that models an optimization objective.
Definition: objective.h:41
void gradient(const HashVector &location, Eigen::Ref< DenseRealVector > target)
Evaluate the gradient at location.
Definition: objective.cpp:96
real_t value_from_xTw(const DenseRealVector &cost, const BinaryLabelVector &labels, const Eigen::DenseBase< Derived > &xTw)
types::DenseRowMajor< real_t > DenseFeatures
Dense Feature Matrix in Row Major format.
Definition: matrix_types.h:58
types::DenseVector< std::int8_t > BinaryLabelVector
Dense vector for storing binary labels.
Definition: matrix_types.h:68
types::DenseVector< real_t > DenseRealVector
Any dense, real values vector.
Definition: matrix_types.h:40
types::SparseRowMajor< real_t > SparseFeatures
Sparse Feature Matrix in Row Major format.
Definition: matrix_types.h:50
float real_t
The default type for floating point values.
Definition: config.h:17