DiSMEC++
generic_linear.h
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #ifndef DISMEC_GENERIC_LINEAR_H
7 #define DISMEC_GENERIC_LINEAR_H
8 
9 #include "linear.h"
10 
11 namespace dismec::objective {
26  public:
27  GenericLinearClassifier(std::shared_ptr<const GenericFeatureMatrix> X, std::unique_ptr<Objective> regularizer);
28  private:
29  // declaration of the "unchecked" methods that need to be implemented for an objective.
31  real_t value_unchecked(const HashVector& location) override;
32  real_t lookup_on_line(real_t position) override;
33  void gradient_unchecked(const HashVector& location, Eigen::Ref<DenseRealVector> target) override;
34  void gradient_at_zero_unchecked(Eigen::Ref<DenseRealVector> target) override;
36  const HashVector& location,
37  const DenseRealVector& direction,
38  Eigen::Ref<DenseRealVector> target) override;
40  const HashVector& location,
41  Eigen::Ref<DenseRealVector> target) override;
43  const HashVector& location,
44  Eigen::Ref<DenseRealVector> gradient,
45  Eigen::Ref<DenseRealVector> pre) override;
46  void project_to_line_unchecked(const HashVector& location, const DenseRealVector& direction) override;
48 
50 
51  // virtual methods to be implemented in derived classes
53 
59  virtual void calculate_loss(const DenseRealVector& scores,
61  DenseRealVector& out) const = 0;
68  virtual void calculate_derivative(const DenseRealVector& scores,
70  DenseRealVector& out) const = 0;
71 
78  virtual void calculate_2nd_derivative(const DenseRealVector& scores,
80  DenseRealVector& out) const = 0;
81 
88  const DenseRealVector& cached_derivative(const HashVector& location);
95  const DenseRealVector& cached_2nd_derivative(const HashVector& location);
97 
98  void invalidate_labels() override;
99 
104 
107 
109  std::unique_ptr<Objective> m_Regularizer;
110  };
111 
119  template<class MarginFunction>
121  GenericMarginClassifier(std::shared_ptr<const GenericFeatureMatrix> X,
122  std::unique_ptr<Objective> regularizer,
123  MarginFunction phi) : GenericLinearClassifier( std::move(X), std::move(regularizer) ),
124  Phi(std::move(phi)) {
125 
126  }
127 
128  void calculate_loss(const DenseRealVector& scores,
129  const BinaryLabelVector& labels,
130  DenseRealVector& out) const override {
131  assert(scores.size() == labels.size());
132  for(int i = 0; i < scores.size(); ++i) {
133  real_t margin = scores.coeff(i) * real_t(labels.coeff(i));
134  out.coeffRef(i) = Phi.value(margin);
135  }
136  }
137 
139  const BinaryLabelVector& labels,
140  DenseRealVector& out) const override {
141  assert(scores.size() == labels.size());
142  for(int i = 0; i < scores.size(); ++i) {
143  real_t label = labels.coeff(i);
144  real_t margin = scores.coeff(i) * label;
145  out.coeffRef(i) = Phi.grad(margin) * label;
146  }
147  }
148 
150  const BinaryLabelVector& labels,
151  DenseRealVector& out) const override {
152  assert(scores.size() == labels.size());
153  for(int i = 0; i < scores.size(); ++i) {
154  real_t margin = scores.coeff(i) * real_t(labels.coeff(i));
155  out.coeffRef(i) = Phi.quad(margin);
156  }
157  }
158 
159  MarginFunction Phi;
160  };
161 
162 
163  std::unique_ptr<GenericLinearClassifier> make_squared_hinge(std::shared_ptr<const GenericFeatureMatrix> X,
164  std::unique_ptr<Objective> regularizer);
165 
166  std::unique_ptr<GenericLinearClassifier> make_logistic_loss(std::shared_ptr<const GenericFeatureMatrix> X,
167  std::unique_ptr<Objective> regularizer);
168 
169  std::unique_ptr<GenericLinearClassifier> make_huber_hinge(std::shared_ptr<const GenericFeatureMatrix> X,
170  std::unique_ptr<Objective> regularizer, real_t epsilon);
171 
172 }
173 
174 #endif //DISMEC_GENERIC_LINEAR_H
An Eigen vector with versioning information, to implement simple caching of results.
Definition: hash_vector.h:43
This is a non-templated, runtime-polymorphic generic implementation of the linear classifier objectiv...
std::unique_ptr< Objective > m_Regularizer
Pointer to the regularizer.
void project_to_line_unchecked(const HashVector &location, const DenseRealVector &direction) override
void gradient_unchecked(const HashVector &location, Eigen::Ref< DenseRealVector > target) override
virtual void calculate_loss(const DenseRealVector &scores, const BinaryLabelVector &labels, DenseRealVector &out) const =0
Calculates the loss for each instance.
void hessian_times_direction_unchecked(const HashVector &location, const DenseRealVector &direction, Eigen::Ref< DenseRealVector > target) override
CacheHelper m_DerivativeBuffer
Cached value of the last calculation of the loss derivative. Needs to be invalidated when the labels c...
const DenseRealVector & cached_derivative(const HashVector &location)
Gets the derivative vector for the current location.
GenericLinearClassifier(std::shared_ptr< const GenericFeatureMatrix > X, std::unique_ptr< Objective > regularizer)
real_t value_from_xTw(const DenseRealVector &xTw)
void diag_preconditioner_unchecked(const HashVector &location, Eigen::Ref< DenseRealVector > target) override
real_t lookup_on_line(real_t position) override
Looks up the value of the objective on the line defined by the last call to project_to_line().
virtual void calculate_2nd_derivative(const DenseRealVector &scores, const BinaryLabelVector &labels, DenseRealVector &out) const =0
Calculates the 2nd derivative of the loss with respect to the scores for each instance.
real_t value_unchecked(const HashVector &location) override
void gradient_at_zero_unchecked(Eigen::Ref< DenseRealVector > target) override
virtual void calculate_derivative(const DenseRealVector &scores, const BinaryLabelVector &labels, DenseRealVector &out) const =0
Calculates the derivative of the loss with respect to the scores for each instance.
CacheHelper m_SecondDerivativeBuffer
Cached value of the last calculation of the 2nd derivative. Needs to be invalidated when the labels ...
void gradient_and_pre_conditioner_unchecked(const HashVector &location, Eigen::Ref< DenseRealVector > gradient, Eigen::Ref< DenseRealVector > pre) override
const DenseRealVector & cached_2nd_derivative(const HashVector &location)
Gets the 2nd derivative vector for the current location.
Base class for objectives that use a linear classifier.
Definition: linear.h:27
const BinaryLabelVector & labels() const
Definition: linear.cpp:89
void gradient(const HashVector &location, Eigen::Ref< DenseRealVector > target)
Evaluate the gradient at location.
Definition: objective.cpp:96
std::unique_ptr< GenericLinearClassifier > make_huber_hinge(std::shared_ptr< const GenericFeatureMatrix > X, std::unique_ptr< Objective > regularizer, real_t epsilon)
std::unique_ptr< GenericLinearClassifier > make_logistic_loss(std::shared_ptr< const GenericFeatureMatrix > X, std::unique_ptr< Objective > regularizer)
std::unique_ptr< GenericLinearClassifier > make_squared_hinge(std::shared_ptr< const GenericFeatureMatrix > X, std::unique_ptr< Objective > regularizer)
types::DenseVector< std::int8_t > BinaryLabelVector
Dense vector for storing binary labels.
Definition: matrix_types.h:68
types::DenseVector< real_t > DenseRealVector
Any dense, real values vector.
Definition: matrix_types.h:40
float real_t
The default type for floating point values.
Definition: config.h:17
A utility class template that, when instantiated with a MarginFunction, produces the corresponding lin...
void calculate_loss(const DenseRealVector &scores, const BinaryLabelVector &labels, DenseRealVector &out) const override
Calculates the loss for each instance.
void calculate_2nd_derivative(const DenseRealVector &scores, const BinaryLabelVector &labels, DenseRealVector &out) const override
Calculates the 2nd derivative of the loss with respect to the scores for each instance.
GenericMarginClassifier(std::shared_ptr< const GenericFeatureMatrix > X, std::unique_ptr< Objective > regularizer, MarginFunction phi)
void calculate_derivative(const DenseRealVector &scores, const BinaryLabelVector &labels, DenseRealVector &out) const override
Calculates the derivative of the loss with respect to the scores for each instance.