DiSMEC++
dismec::objective::LinearClassifierBase Class Reference (abstract)

Base class for objectives that use a linear classifier.

#include <linear.h>

Inheritance diagram for dismec::objective::LinearClassifierBase:
Base classes: dismec::objective::Objective, dismec::stats::Tracked.
Derived classes (direct and indirect): dismec::objective::LinearClassifierImpBase< Derived >, dismec::objective::LinearClassifierImpBase< Regularized_SquaredHingeSVC >, dismec::objective::GenericLinearClassifier, dismec::objective::Regularized_SquaredHingeSVC, dismec::objective::GenericMarginClassifier< MarginFunction >.

Public Member Functions

 LinearClassifierBase (std::shared_ptr< const GenericFeatureMatrix > X)
 
long num_instances () const noexcept
 
long num_variables () const noexcept override
 
BinaryLabelVector & get_label_ref ()
 
void update_costs (real_t positive, real_t negative)
 
- Public Member Functions inherited from dismec::objective::Objective
 Objective ()
 
virtual ~Objective () noexcept=default
 
real_t value (const HashVector &location)
 Evaluate the objective at the given location.
 
void diag_preconditioner (const HashVector &location, Eigen::Ref< DenseRealVector > target)
 Gets the preconditioner to be used in CG optimization.
 
void project_to_line (const HashVector &location, const DenseRealVector &direction)
 Creates a function g such that g(a) = objective(location + a * direction). Use lookup_on_line() to evaluate g.
 
virtual real_t lookup_on_line (real_t position)=0
 Looks up the value of the objective on the line defined by the last call to project_to_line().
 
void gradient_at_zero (Eigen::Ref< DenseRealVector > target)
 Gets the gradient for location zero.
 
void gradient (const HashVector &location, Eigen::Ref< DenseRealVector > target)
 Evaluate the gradient at location.
 
void hessian_times_direction (const HashVector &location, const DenseRealVector &direction, Eigen::Ref< DenseRealVector > target)
 Calculates the product of the Hessian matrix at location with direction.
 
void gradient_and_pre_conditioner (const HashVector &location, Eigen::Ref< DenseRealVector > gradient, Eigen::Ref< DenseRealVector > pre)
 Combines the calculation of gradient and pre-conditioner, which may be more efficient in some cases.
 
- Public Member Functions inherited from dismec::stats::Tracked
 Tracked ()
 Default constructor, creates the internal stats::StatisticsCollection.
 
void register_stat (const std::string &name, std::unique_ptr< Statistics > stat)
 Registers a tracker for the statistics name.
 
std::shared_ptr< StatisticsCollection > get_stats () const
 Gets an ownership-sharing reference to the StatisticsCollection.
 

Protected Member Functions

const DenseRealVector & x_times_w (const HashVector &w)
 Calculates the product of the feature matrix with the weight vector w.
 
template<class Derived >
void update_xtw_cache (const HashVector &new_weight, const Eigen::MatrixBase< Derived > &new_result)
 Updates the cached value for x_times_w.
 
void project_linear_to_line (const HashVector &location, const DenseRealVector &direction)
 Prepares the cache variables for line projection.
 
auto line_interpolation (real_t t) const
 
void declare_vector_on_last_line (const HashVector &location, real_t t) override
 State that the given vector corresponds to a certain position on the line of the last line search.
 
const GenericFeatureMatrix & generic_features () const
 
const DenseFeatures & dense_features () const
 
const SparseFeatures & sparse_features () const
 
const DenseRealVector & costs () const
 
const BinaryLabelVector & labels () const
 
- Protected Member Functions inherited from dismec::stats::Tracked
 ~Tracked ()
 Non-virtual destructor. Declared protected, so we don't accidentally try to do a polymorphic delete.
 
template<class T >
void record (stat_id_t stat, T &&value)
 Record statistics. This function just forwards all its arguments to the internal StatisticsCollection.
 
void declare_stat (stat_id_t index, StatisticMetaData meta)
 Declares a new statistic. This function just forwards all its arguments to the internal StatisticsCollection.
 
void declare_tag (tag_id_t index, std::string name)
 Declares a new tag. This function just forwards all its arguments to the internal StatisticsCollection.
 
template<class... Args>
void set_tag (tag_id_t tag, long value)
 Set value of tag. This function just forwards all its arguments to the internal StatisticsCollection.
 
template<class... Args>
auto make_timer (stat_id_t id, Args... args)
 Creates a new ScopeTimer using stats::record_scope_time.
 

Private Member Functions

virtual void invalidate_labels ()=0
 

Private Attributes

std::shared_ptr< const GenericFeatureMatrix > m_FeatureMatrix
 
VectorHash m_Last_W {}
 cache for the last argument to x_times_w().
 
DenseRealVector m_X_times_w
 cache for the last result of x_times_w() corresponding to m_Last_W.
 
DenseRealVector m_LsCache_xTd
 cache for line search implementation: feature times direction
 
DenseRealVector m_LsCache_xTw
 cache for line search implementation: feature times weights
 
DenseRealVector m_Costs
 Label-dependent costs.
 
BinaryLabelVector m_Y
 Label vector – use a vector of ints here. We encode label present == 1, absent == -1.
 

Detailed Description

Base class for objectives that use a linear classifier.

This base provides functions for caching the matrix multiplication \( x^T w \) that can be used by derived classes. This class does not assume a specific type for the feature matrix, but stores it as a GenericFeatureMatrix. It also provides caching for the line search computations.

This base class also stores the per-instance cost vector and the ground-truth label vector.

Attention
Derived classes may want to cache certain results based on the predicted scores and the labels. For the location input parameter, it is easy to check whether a cached result can be reused, but this check breaks if the labels change. Therefore, such classes must implement the pure virtual function invalidate_labels() so that cached results are invalidated. The base class calls this function whenever the label vector is modified.
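
The invalidation contract described in this note can be illustrated with a minimal, self-contained sketch. It does not use the DiSMEC++ types: LabelOwnerSketch and CountingObjectiveSketch are hypothetical stand-ins, labels are a plain std::vector<int>, and the cached quantity (the number of positive labels) is chosen only for illustration. The assumption that the hook fires before the mutable reference is handed out is ours.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-in for the base class (not the DiSMEC++ implementation).
// It owns the labels and notifies derived classes when they may change.
class LabelOwnerSketch {
public:
    explicit LabelOwnerSketch(std::vector<int> labels) : m_Y(std::move(labels)) {
        for (int y : m_Y) assert(y == 1 || y == -1);  // present == 1, absent == -1
    }
    virtual ~LabelOwnerSketch() = default;

    // Handing out a mutable reference means the labels may be modified,
    // so the invalidation hook is called first (assumed ordering).
    std::vector<int>& get_label_ref() {
        invalidate_labels();
        return m_Y;
    }

protected:
    const std::vector<int>& labels() const { return m_Y; }

private:
    virtual void invalidate_labels() = 0;
    std::vector<int> m_Y;
};

// Hypothetical derived class that caches a label-dependent result.
class CountingObjectiveSketch : public LabelOwnerSketch {
public:
    using LabelOwnerSketch::LabelOwnerSketch;

    long num_positive() {
        if (!m_CacheValid) {
            m_NumPositive = 0;
            for (int y : labels()) {
                if (y == 1) ++m_NumPositive;
            }
            m_CacheValid = true;
            ++m_Recomputations;
        }
        return m_NumPositive;
    }

    int recomputations() const { return m_Recomputations; }

private:
    void invalidate_labels() override { m_CacheValid = false; }

    bool m_CacheValid = false;
    long m_NumPositive = 0;
    int m_Recomputations = 0;
};
```

In this sketch, a caller that keeps the returned reference and modifies the labels later, after a cached value has been recomputed, would bypass the invalidation, so the reference should be used immediately and not stored.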

Definition at line 27 of file linear.h.

Constructor & Destructor Documentation

◆ LinearClassifierBase()

LinearClassifierBase::LinearClassifierBase ( std::shared_ptr< const GenericFeatureMatrix > X)

Member Function Documentation

◆ costs()

◆ declare_vector_on_last_line()

void dismec::objective::LinearClassifierBase::declare_vector_on_last_line ( const HashVector & location,
real_t  t 
)
inline override protected virtual

State that the given vector corresponds to a certain position on the line of the last line search.

This function is a pure optimization hint. It is used in the following scenario: if several computations need the product of the weight vector w and the feature matrix X, we can compute this product only once and use a cached value for all later invocations. This can be done by comparing the vector hashes. However, as soon as a vector is modified, these hashes are invalidated. To do an efficient line search over w' = w + t d, we also cache the value of X d, so that X w' = X w + t X d. This function then declares that the vector given in location corresponds to w + t d, where w and d are the arguments passed to the last call of project_to_line().
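
The identity behind this hint is just linearity of the matrix product. Written out, with X the feature matrix, w and d the arguments of the last project_to_line() call, and t the line-search position:

\[ X w' = X (w + t\,d) = X w + t\, X d \]

Since X w and X d are each computed once per line search, evaluating the product at a further position t needs only one scaled vector addition over the instances, and not another pass over the feature matrix.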

Todo:
Improve this interface, together with project_to_line(), to be less error-prone.

Reimplemented from dismec::objective::Objective.

Definition at line 76 of file linear.h.

References m_LsCache_xTd, m_LsCache_xTw, and update_xtw_cache().

◆ dense_features()

const DenseFeatures & LinearClassifierBase::dense_features ( ) const
protected

Definition at line 39 of file linear.cpp.

References m_FeatureMatrix.

◆ generic_features()

const GenericFeatureMatrix & LinearClassifierBase::generic_features ( ) const
protected

Definition at line 47 of file linear.cpp.

References m_FeatureMatrix.

◆ get_label_ref()

BinaryLabelVector & LinearClassifierBase::get_label_ref ( )

Definition at line 70 of file linear.cpp.

References invalidate_labels(), and m_Y.

Referenced by dismec::init::create_ova_primal_initializer().

◆ invalidate_labels()

virtual void dismec::objective::LinearClassifierBase::invalidate_labels ( )
private pure virtual

This function will be called whenever m_Y changes so that derived classes can invalidate their caches.

Implemented in dismec::objective::Regularized_SquaredHingeSVC, and dismec::objective::GenericLinearClassifier.

Referenced by get_label_ref().

◆ labels()

◆ line_interpolation()

auto dismec::objective::LinearClassifierBase::line_interpolation ( real_t  t) const
inline protected

◆ num_instances()

long LinearClassifierBase::num_instances ( ) const
noexcept

Definition at line 31 of file linear.cpp.

References m_FeatureMatrix.

◆ num_variables()

long LinearClassifierBase::num_variables ( ) const
override virtual noexcept

Gets the number of variables this objective expects. May return -1 if the objective is agnostic to the number of variables, e.g. for regularizers.

Implements dismec::objective::Objective.

Definition at line 35 of file linear.cpp.

References m_FeatureMatrix.

◆ project_linear_to_line()

void LinearClassifierBase::project_linear_to_line ( const HashVector & location,
const DenseRealVector & direction 
)
protected

Prepares the cache variables for line projection.

This function precomputes \( x^T d \) and \( x^T w \), so that for a line search parameter t we can get \( x^T (w + td) \) by a simple linear combination, i.e., we can skip the matrix multiplication. The line search lookup can then be implemented by calling the line_interpolation() function.

Parameters
location   The origin point on the line.
direction  The direction of the line.
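
As a rough illustration of the precompute-then-interpolate idea described for this function, the following sketch uses plain std::vector containers in place of the library's feature-matrix and Eigen types. The names Vector, Matrix, mat_vec, and LineCacheSketch are hypothetical and not part of DiSMEC++; only the arithmetic mirrors the description.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vector = std::vector<double>;
using Matrix = std::vector<Vector>;  // dense, one row per instance (assumption)

// Plain dense matrix-vector product; stands in for the feature-matrix multiply.
Vector mat_vec(const Matrix& X, const Vector& v) {
    Vector out(X.size(), 0.0);
    for (std::size_t i = 0; i < X.size(); ++i) {
        assert(X[i].size() == v.size());
        for (std::size_t j = 0; j < v.size(); ++j) {
            out[i] += X[i][j] * v[j];
        }
    }
    return out;
}

// Hypothetical line-search cache holding the two precomputed products.
struct LineCacheSketch {
    Vector xTw;  // feature matrix times the origin w
    Vector xTd;  // feature matrix times the direction d

    // Two matrix-vector products, done once per line search.
    void project(const Matrix& X, const Vector& w, const Vector& d) {
        xTw = mat_vec(X, w);
        xTd = mat_vec(X, d);
    }

    // Product of the features with w + t d, without touching the matrix.
    Vector interpolate(double t) const {
        Vector out(xTw.size());
        for (std::size_t i = 0; i < out.size(); ++i) {
            out[i] = xTw[i] + t * xTd[i];
        }
        return out;
    }
};
```

After a single call to project(), each call to interpolate(t) costs one pass over the instances, independent of the number of features.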

Definition at line 63 of file linear.cpp.

References m_FeatureMatrix, m_LsCache_xTd, m_LsCache_xTw, dismec::types::visit(), and x_times_w().

Referenced by dismec::objective::GenericLinearClassifier::project_to_line_unchecked(), and dismec::objective::LinearClassifierImpBase< Derived >::project_to_line_unchecked().

◆ sparse_features()

const SparseFeatures & LinearClassifierBase::sparse_features ( ) const
protected

Definition at line 43 of file linear.cpp.

References m_FeatureMatrix.

Referenced by dismec::objective::Regularized_SquaredHingeSVC::features().

◆ update_costs()

void LinearClassifierBase::update_costs ( real_t  positive,
real_t  negative 
)

Definition at line 75 of file linear.cpp.

References m_Costs, and m_Y.
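
No description is given for update_costs() on this page. One plausible reading, given that it takes a positive and a negative value and touches both m_Costs and m_Y, is that each instance receives the cost matching its label. The sketch below shows that reading only; it is an assumption, and costs_from_labels is a hypothetical free function, not the library's code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Assumed semantics: instances with label +1 get the `positive` cost,
// instances with label -1 get the `negative` cost.
std::vector<double> costs_from_labels(const std::vector<int>& y,
                                      double positive, double negative) {
    std::vector<double> costs(y.size());
    for (std::size_t i = 0; i < y.size(); ++i) {
        assert(y[i] == 1 || y[i] == -1);  // present == 1, absent == -1
        costs[i] = (y[i] == 1) ? positive : negative;
    }
    return costs;
}
```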

◆ update_xtw_cache()

template<class Derived >
void dismec::objective::LinearClassifierBase::update_xtw_cache ( const HashVector & new_weight,
const Eigen::MatrixBase< Derived > &  new_result 
)
inline protected

Updates the cached value for x_times_w.

Parameters
new_weight  The new value of w.
new_result  The value of x^T w.

Definition at line 53 of file linear.h.

References dismec::HashVector::hash(), m_Last_W, and m_X_times_w.

Referenced by declare_vector_on_last_line().

◆ x_times_w()

const DenseRealVector & LinearClassifierBase::x_times_w ( const HashVector & w)
protected

Calculates the product of the feature matrix with the weight vector w.

Parameters
w  The weight vector with which to multiply.

Consecutive calls to this function with the same argument w will return a reference to a cached result. However, calling with another value in between will invalidate the cache.

Returns
A reference to the cached result vector.
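
The caching behaviour described for x_times_w() can be sketched as a single-entry cache keyed on an identifier of the argument. The sketch below is an illustration under assumptions: TaggedVectorSketch is a hypothetical stand-in for a hashed vector (it simply draws a fresh id on every mutable access), ProductCacheSketch is likewise hypothetical, and the real HashVector and VectorHash types may work differently.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for a hashed vector: each mutable access gets a new id.
class TaggedVectorSketch {
public:
    explicit TaggedVectorSketch(std::vector<double> data)
        : m_Data(std::move(data)), m_Id(next_id()) {}

    const std::vector<double>& get() const { return m_Data; }
    std::vector<double>& modify() {
        m_Id = next_id();  // content may change, so the old id is retired
        return m_Data;
    }
    std::uint64_t id() const { return m_Id; }

private:
    static std::uint64_t next_id() {
        static std::uint64_t counter = 0;
        return ++counter;  // ids start at 1, so 0 can mean "nothing cached"
    }
    std::vector<double> m_Data;
    std::uint64_t m_Id;
};

// Single-entry cache for the product of a fixed matrix with the last argument.
class ProductCacheSketch {
public:
    explicit ProductCacheSketch(std::vector<std::vector<double>> X)
        : m_X(std::move(X)) {}

    const std::vector<double>& x_times_w(const TaggedVectorSketch& w) {
        if (w.id() == m_LastId) {
            return m_Result;  // same argument as last time: reuse the result
        }
        const std::vector<double>& v = w.get();
        m_Result.assign(m_X.size(), 0.0);
        for (std::size_t i = 0; i < m_X.size(); ++i) {
            assert(m_X[i].size() == v.size());
            for (std::size_t j = 0; j < v.size(); ++j) {
                m_Result[i] += m_X[i][j] * v[j];
            }
        }
        m_LastId = w.id();
        ++m_Multiplications;
        return m_Result;
    }

    int multiplications() const { return m_Multiplications; }

private:
    std::vector<std::vector<double>> m_X;
    std::vector<double> m_Result;
    std::uint64_t m_LastId = 0;
    int m_Multiplications = 0;
};
```

Because only one entry is kept, alternating between two different weight vectors would recompute the product on every call, which matches the note that calling with another value in between invalidates the cache.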

Definition at line 51 of file linear.cpp.

References dismec::HashVector::hash(), m_FeatureMatrix, m_Last_W, m_X_times_w, dismec::stats::Tracked::make_timer(), anonymous_namespace{linear.cpp}::STAT_PERF_MATMUL, and dismec::types::visit().

Referenced by dismec::objective::Regularized_SquaredHingeSVC::margin_error(), project_linear_to_line(), and dismec::objective::LinearClassifierImpBase< Derived >::value_unchecked().

Member Data Documentation

◆ m_Costs

DenseRealVector dismec::objective::LinearClassifierBase::m_Costs
private

Label-dependent costs.

Definition at line 103 of file linear.h.

Referenced by costs(), LinearClassifierBase(), and update_costs().

◆ m_FeatureMatrix

std::shared_ptr<const GenericFeatureMatrix> dismec::objective::LinearClassifierBase::m_FeatureMatrix
private

We keep a refcounted pointer to the training features. This supports shared-memory parallelization of multilabel training. Derived classes may form a pointer to the concrete type of m_FeatureMatrix.

Definition at line 90 of file linear.h.

Referenced by dense_features(), generic_features(), num_instances(), num_variables(), project_linear_to_line(), sparse_features(), and x_times_w().

◆ m_Last_W

VectorHash dismec::objective::LinearClassifierBase::m_Last_W {}
private

cache for the last argument to x_times_w().

Definition at line 93 of file linear.h.

Referenced by update_xtw_cache(), and x_times_w().

◆ m_LsCache_xTd

DenseRealVector dismec::objective::LinearClassifierBase::m_LsCache_xTd
private

cache for line search implementation: feature times direction

Definition at line 98 of file linear.h.

Referenced by declare_vector_on_last_line(), line_interpolation(), and project_linear_to_line().

◆ m_LsCache_xTw

DenseRealVector dismec::objective::LinearClassifierBase::m_LsCache_xTw
private

cache for line search implementation: feature times weights

Definition at line 100 of file linear.h.

Referenced by declare_vector_on_last_line(), line_interpolation(), and project_linear_to_line().

◆ m_X_times_w

DenseRealVector dismec::objective::LinearClassifierBase::m_X_times_w
private

cache for the last result of x_times_w() corresponding to m_Last_W.

Definition at line 95 of file linear.h.

Referenced by update_xtw_cache(), and x_times_w().

◆ m_Y

BinaryLabelVector dismec::objective::LinearClassifierBase::m_Y
private

Label vector – use a vector of ints here. We encode label present == 1, absent == -1.

Definition at line 106 of file linear.h.

Referenced by get_label_ref(), labels(), and update_costs().


The documentation for this class was generated from the following files: