DiSMEC++
dismec::objective::LinearClassifierImpBase< Derived > Class Template Reference

Implementation helper for linear classifier derived classes.

#include <linear.h>

Inheritance diagram for dismec::objective::LinearClassifierImpBase< Derived >:
Base classes: dismec::objective::LinearClassifierBase, dismec::objective::Objective, dismec::stats::Tracked

Public Member Functions

 LinearClassifierImpBase (std::shared_ptr< const GenericFeatureMatrix > X, std::unique_ptr< Objective > regularizer)
 
- Public Member Functions inherited from dismec::objective::LinearClassifierBase
 LinearClassifierBase (std::shared_ptr< const GenericFeatureMatrix > X)
 
long num_instances () const noexcept
 
long num_variables () const noexcept override
 
BinaryLabelVector & get_label_ref ()
 
void update_costs (real_t positive, real_t negative)
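update_costs() takes one cost value for positive and one for negative instances. The following is a minimal sketch of how such per-instance costs could be derived from binary labels. It assumes that labels are encoded as +1 / -1 in an int8_t vector and that costs are stored as one real value per instance. The helper name make_costs is hypothetical and not part of DiSMEC++.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using real_t = float;  // same underlying type as DiSMEC++'s real_t

// Hypothetical helper: one cost per instance, selected by the sign of its label.
// Assumes labels are encoded as +1 (positive) and -1 (negative).
std::vector<real_t> make_costs(const std::vector<std::int8_t>& labels,
                               real_t positive, real_t negative) {
    std::vector<real_t> costs(labels.size());
    for (std::size_t i = 0; i < labels.size(); ++i) {
        costs[i] = labels[i] > 0 ? positive : negative;
    }
    return costs;
}
```

Asymmetric costs like these let the rarer class be weighted more heavily in the loss.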
 
- Public Member Functions inherited from dismec::objective::Objective
 Objective ()
 
virtual ~Objective () noexcept=default
 
real_t value (const HashVector &location)
 Evaluates the objective at the given location.
 
void diag_preconditioner (const HashVector &location, Eigen::Ref< DenseRealVector > target)
 Gets the preconditioner to be used in CG optimization.
 
void project_to_line (const HashVector &location, const DenseRealVector &direction)
 Creates a function g such that g(a) = objective(location + a * direction). Use lookup_on_line() to evaluate g.
 
void gradient_at_zero (Eigen::Ref< DenseRealVector > target)
 Gets the gradient at location zero.
 
void gradient (const HashVector &location, Eigen::Ref< DenseRealVector > target)
 Evaluates the gradient at location.
 
void hessian_times_direction (const HashVector &location, const DenseRealVector &direction, Eigen::Ref< DenseRealVector > target)
 Calculates the product of the Hessian matrix at location with direction.
 
void gradient_and_pre_conditioner (const HashVector &location, Eigen::Ref< DenseRealVector > gradient, Eigen::Ref< DenseRealVector > pre)
 Combines the calculation of gradient and pre-conditioner, which may be more efficient in some cases.
 
- Public Member Functions inherited from dismec::stats::Tracked
 Tracked ()
 Default constructor, creates the internal stats::StatisticsCollection.
 
void register_stat (const std::string &name, std::unique_ptr< Statistics > stat)
 Registers a tracker for the statistic called name.
 
std::shared_ptr< StatisticsCollection > get_stats () const
 Gets an ownership-sharing reference to the StatisticsCollection.
 

Protected Member Functions

const Derived & derived () const
 
Derived & derived ()
 
real_t value_unchecked (const HashVector &location) override
 
real_t lookup_on_line (real_t position) override
 Looks up the value of the objective on the line defined by the last call to project_to_line().
 
void project_to_line_unchecked (const HashVector &location, const DenseRealVector &direction) override
 
void gradient_unchecked (const HashVector &location, Eigen::Ref< DenseRealVector > target) override
 
void gradient_at_zero_unchecked (Eigen::Ref< DenseRealVector > target) override
 
void hessian_times_direction_unchecked (const HashVector &location, const DenseRealVector &direction, Eigen::Ref< DenseRealVector > target) override
 
void diag_preconditioner_unchecked (const HashVector &location, Eigen::Ref< DenseRealVector > target) override
 
void gradient_and_pre_conditioner_unchecked (const HashVector &location, Eigen::Ref< DenseRealVector > gradient, Eigen::Ref< DenseRealVector > pre) override
 
- Protected Member Functions inherited from dismec::objective::LinearClassifierBase
const DenseRealVector & x_times_w (const HashVector &w)
 Calculates the product of the feature matrix with the weight vector w.
 
template<class Derived >
void update_xtw_cache (const HashVector &new_weight, const Eigen::MatrixBase< Derived > &new_result)
 Updates the cached value for x_times_w.
 
void project_linear_to_line (const HashVector &location, const DenseRealVector &direction)
 Prepares the cache variables for line projection.
 
auto line_interpolation (real_t t) const
 
void declare_vector_on_last_line (const HashVector &location, real_t t) override
 States that the given vector corresponds to a certain position on the line of the last line search.
 
const GenericFeatureMatrix & generic_features () const

const DenseFeatures & dense_features () const

const SparseFeatures & sparse_features () const

const DenseRealVector & costs () const

const BinaryLabelVector & labels () const
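x_times_w() and update_xtw_cache() indicate that the product of the feature matrix with the weight vector is cached and reused when the same weights are passed again. The following is a minimal sketch of that idea. It assumes that the weight vector carries an id that changes whenever the vector is modified, which appears to be the role HashVector plays here. VersionedVector and XtwCache are hypothetical names, and a dense row-major matrix stands in for GenericFeatureMatrix.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for HashVector: the data plus an id that changes on every write.
class VersionedVector {
public:
    explicit VersionedVector(std::vector<double> data) : m_Data(std::move(data)) {}
    const std::vector<double>& get() const { return m_Data; }
    std::uint64_t version() const { return m_Version; }
    void set(std::size_t i, double value) {
        assert(i < m_Data.size());
        m_Data[i] = value;
        m_Version = next_version();  // any modification invalidates cached products
    }
private:
    static std::uint64_t next_version() {
        static std::uint64_t counter = 0;  // ids are unique across all instances
        return ++counter;
    }
    std::vector<double> m_Data;
    std::uint64_t m_Version = next_version();
};

// Hypothetical cache for X * w, recomputed only when the weight id changes.
class XtwCache {
public:
    explicit XtwCache(std::vector<std::vector<double>> rows) : m_Rows(std::move(rows)) {}

    const std::vector<double>& x_times_w(const VersionedVector& w) {
        if (!m_Valid || m_CachedVersion != w.version()) {
            const std::vector<double>& wd = w.get();
            m_Result.assign(m_Rows.size(), 0.0);
            for (std::size_t i = 0; i < m_Rows.size(); ++i) {
                assert(m_Rows[i].size() == wd.size());
                for (std::size_t j = 0; j < wd.size(); ++j) m_Result[i] += m_Rows[i][j] * wd[j];
            }
            m_CachedVersion = w.version();
            m_Valid = true;
            ++m_Recomputations;
        }
        return m_Result;
    }
    int recomputations() const { return m_Recomputations; }
private:
    std::vector<std::vector<double>> m_Rows;
    std::vector<double> m_Result;
    std::uint64_t m_CachedVersion = 0;
    bool m_Valid = false;
    int m_Recomputations = 0;
};
```

Comparing an id is O(1), whereas comparing the vectors element by element would cost as much as a pass over the weights.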
 
- Protected Member Functions inherited from dismec::stats::Tracked
 ~Tracked ()
 Non-virtual destructor. Declared protected, so we don't accidentally try to do a polymorphic delete.
 
template<class T >
void record (stat_id_t stat, T &&value)
 Record statistics. This function just forwards all its arguments to the internal StatisticsCollection.
 
void declare_stat (stat_id_t index, StatisticMetaData meta)
 Declares a new statistic. This function just forwards all its arguments to the internal StatisticsCollection.
 
void declare_tag (tag_id_t index, std::string name)
 Declares a new tag. This function just forwards all its arguments to the internal StatisticsCollection.
 
template<class... Args>
void set_tag (tag_id_t tag, long value)
 Sets the value of a tag. This function just forwards all its arguments to the internal StatisticsCollection.
 
template<class... Args>
auto make_timer (stat_id_t id, Args... args)
 Creates a new ScopeTimer using stats::record_scope_time.
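make_timer() returns a ScopeTimer, and ~Tracked() is a protected, non-virtual destructor. The following is a minimal sketch of both ideas. It assumes that a scope timer measures the time between its construction and its destruction and passes the duration to a callback. ScopeTimer, TrackedSketch and Worker are hypothetical stand-ins, not the DiSMEC++ classes, and the real make_timer() additionally takes a stat_id_t.

```cpp
#include <cassert>
#include <chrono>
#include <functional>
#include <utility>

// Hypothetical RAII timer: measures how long the enclosing scope lives and
// passes the elapsed time to a callback when it is destroyed.
class ScopeTimer {
public:
    using Clock = std::chrono::steady_clock;
    explicit ScopeTimer(std::function<void(Clock::duration)> report)
        : m_Report(std::move(report)), m_Start(Clock::now()) {}
    ScopeTimer(const ScopeTimer&) = delete;             // no copies: report exactly once
    ScopeTimer& operator=(const ScopeTimer&) = delete;
    ~ScopeTimer() { m_Report(Clock::now() - m_Start); }
private:
    std::function<void(Clock::duration)> m_Report;
    Clock::time_point m_Start;
};

// Hypothetical stand-in for Tracked: the destructor is protected and non-virtual,
// so `delete` through a TrackedSketch* fails to compile instead of misbehaving.
class TrackedSketch {
public:
    int timed_scopes() const { return m_Count; }
    ScopeTimer::Clock::duration total_time() const { return m_Total; }
protected:
    ~TrackedSketch() = default;
    // Requires C++17: the returned prvalue is constructed directly in the caller.
    ScopeTimer make_timer() {
        return ScopeTimer([this](ScopeTimer::Clock::duration d) { ++m_Count; m_Total += d; });
    }
private:
    int m_Count = 0;
    ScopeTimer::Clock::duration m_Total{0};
};

class Worker : public TrackedSketch {
public:
    void work() {
        auto timer = make_timer();  // reports when `timer` leaves this scope
        // ... the actual work would go here ...
    }
};
```

Making the destructor protected instead of virtual avoids a vtable entry while still ruling out deletion through a base pointer.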
 

Private Attributes

std::unique_ptr< Objective > m_Regularizer
 Pointer to the regularizer.
 

Detailed Description

template<class Derived>
class dismec::objective::LinearClassifierImpBase< Derived >

Implementation helper for linear classifier derived classes.

Template Parameters
Derived: The derived class, which is expected to provide the following functions:
template<typename Derived>
real_t value_from_xTw(const DenseRealVector& cost, const BinaryLabelVector& labels, const Eigen::DenseBase<Derived>& xTw);
void gradient_and_diag();

Definition at line 124 of file linear.h.
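The Derived template parameter follows the curiously recurring template pattern (CRTP). The base class implements the virtual Objective interface once and forwards the loss-specific part to Derived through derived(), which is resolved at compile time. The following is a minimal sketch of that structure. It assumes, as the references listed for lookup_on_line() suggest, that the objective value is the data term from value_from_xTw() plus the value of the regularizer. ObjectiveSketch, ImpBaseSketch, SquaredLoss and L2Regularizer are hypothetical names. Plain std::vector replaces the Eigen types, and value_from_xTw() takes only xTw instead of (cost, labels, xTw).

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

using Vec = std::vector<double>;

// Minimal stand-in for the virtual Objective interface.
class ObjectiveSketch {
public:
    virtual ~ObjectiveSketch() = default;
    virtual double value(const Vec& w) = 0;
};

class L2Regularizer : public ObjectiveSketch {
public:
    explicit L2Regularizer(double scale) : m_Scale(scale) {}
    double value(const Vec& w) override {
        double s = 0.0;
        for (double x : w) s += x * x;
        return 0.5 * m_Scale * s;
    }
private:
    double m_Scale;
};

// CRTP helper: implements value() once and asks Derived for the loss on X w.
template<class Derived>
class ImpBaseSketch : public ObjectiveSketch {
public:
    ImpBaseSketch(std::vector<Vec> X, std::unique_ptr<ObjectiveSketch> regularizer)
        : m_X(std::move(X)), m_Regularizer(std::move(regularizer)) {}

    double value(const Vec& w) override {
        Vec xTw(m_X.size(), 0.0);
        for (std::size_t i = 0; i < m_X.size(); ++i)
            for (std::size_t j = 0; j < w.size(); ++j) xTw[i] += m_X[i][j] * w[j];
        // Static dispatch to the derived class; no virtual call for the loss itself.
        return derived().value_from_xTw(xTw) + m_Regularizer->value(w);
    }
protected:
    Derived& derived() { return static_cast<Derived&>(*this); }
    const Derived& derived() const { return static_cast<const Derived&>(*this); }
private:
    std::vector<Vec> m_X;
    std::unique_ptr<ObjectiveSketch> m_Regularizer;
};

// A concrete loss only has to supply value_from_xTw().
class SquaredLoss : public ImpBaseSketch<SquaredLoss> {
public:
    SquaredLoss(std::vector<Vec> X, Vec targets, std::unique_ptr<ObjectiveSketch> reg)
        : ImpBaseSketch(std::move(X), std::move(reg)), m_Targets(std::move(targets)) {}
    double value_from_xTw(const Vec& xTw) const {
        double s = 0.0;
        for (std::size_t i = 0; i < xTw.size(); ++i)
            s += (xTw[i] - m_Targets[i]) * (xTw[i] - m_Targets[i]);
        return s;
    }
private:
    Vec m_Targets;
};
```

With this design, the matrix product, caching and regularizer bookkeeping live in one place, and each loss contributes only a function of the margins.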

Constructor & Destructor Documentation

◆ LinearClassifierImpBase()

template<class Derived >
dismec::objective::LinearClassifierImpBase< Derived >::LinearClassifierImpBase ( std::shared_ptr< const GenericFeatureMatrix > X,
std::unique_ptr< Objective > regularizer
)
inline

Definition at line 126 of file linear.h.

Member Function Documentation

◆ derived() [1/2]

template<class Derived >
Derived& dismec::objective::LinearClassifierImpBase< Derived >::derived ( )
inline, protected

Definition at line 133 of file linear.h.

◆ derived() [2/2]

◆ diag_preconditioner_unchecked()

template<class Derived >
void dismec::objective::LinearClassifierImpBase< Derived >::diag_preconditioner_unchecked ( const HashVector & location,
Eigen::Ref< DenseRealVector > target
)
inline, override, protected, virtual

The function that does the actual computation. This is called in diag_preconditioner() after the arguments have been validated. The default implementation returns a vector of ones.

Reimplemented from dismec::objective::Objective.

Definition at line 170 of file linear.h.

References dismec::objective::LinearClassifierImpBase< Derived >::derived(), and dismec::objective::LinearClassifierImpBase< Derived >::m_Regularizer.
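diag_preconditioner() is documented as providing the preconditioner for CG optimization. The following is a minimal sketch of how such a diagonal (Jacobi) preconditioner is typically consumed by preconditioned conjugate gradients, which solve H x = b using only Hessian-vector products like those from hessian_times_direction(). The function preconditioned_cg and its parameters are hypothetical, and the source does not say which CG variant DiSMEC++'s optimizer uses. With the all-ones default preconditioner, the method reduces to plain CG.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

using Vec = std::vector<double>;

double dot(const Vec& a, const Vec& b) {
    double s = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) s += a[i] * b[i];
    return s;
}

// Solves H x = b for symmetric positive definite H, given only the product
// H * v (hess_times) and a diagonal preconditioner `diag` approximating diag(H).
Vec preconditioned_cg(const std::function<Vec(const Vec&)>& hess_times, const Vec& b,
                      const Vec& diag, int max_iter = 100, double tol = 1e-10) {
    const std::size_t n = b.size();
    assert(diag.size() == n);
    Vec x(n, 0.0), r = b, z(n), p(n);                             // x0 = 0, so r0 = b
    for (std::size_t i = 0; i < n; ++i) z[i] = r[i] / diag[i];    // z = M^{-1} r
    p = z;
    double rz = dot(r, z);
    for (int k = 0; k < max_iter && std::sqrt(dot(r, r)) > tol; ++k) {
        Vec Hp = hess_times(p);
        double alpha = rz / dot(p, Hp);
        for (std::size_t i = 0; i < n; ++i) { x[i] += alpha * p[i]; r[i] -= alpha * Hp[i]; }
        for (std::size_t i = 0; i < n; ++i) z[i] = r[i] / diag[i];
        double rz_new = dot(r, z);
        double beta = rz_new / rz;
        rz = rz_new;
        for (std::size_t i = 0; i < n; ++i) p[i] = z[i] + beta * p[i];
    }
    return x;
}
```

A good diagonal reduces the condition number of the preconditioned system, which typically lowers the number of Hessian-vector products CG needs.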

◆ gradient_and_pre_conditioner_unchecked()

template<class Derived >
void dismec::objective::LinearClassifierImpBase< Derived >::gradient_and_pre_conditioner_unchecked ( const HashVector & location,
Eigen::Ref< DenseRealVector > gradient,
Eigen::Ref< DenseRealVector > pre
)
inline, override, protected, virtual

The function that does the actual computation. This is called in gradient_and_pre_conditioner() after the arguments have been validated. The default implementation successively calls gradient() and diag_preconditioner().

Reimplemented from dismec::objective::Objective.

Definition at line 175 of file linear.h.

References dismec::objective::LinearClassifierImpBase< Derived >::derived(), dismec::objective::Objective::gradient(), and dismec::objective::LinearClassifierImpBase< Derived >::m_Regularizer.
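Calling gradient() and diag_preconditioner() one after the other computes the margins x_i^T w twice. A combined implementation can share them. The following is a minimal sketch for a cost-weighted squared hinge loss sum_i c_i max(0, 1 - y_i x_i^T w)^2, where both the gradient and the diagonal of the (generalized) Hessian depend on the same set of active instances. The choice of loss, the dense row-major matrix and the function name are assumptions for illustration. The regularizer's contribution, which the class presumably adds via m_Regularizer, is omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;
using Matrix = std::vector<Vec>;  // row i = features of instance i

// Gradient and Hessian diagonal of sum_i c_i * max(0, 1 - y_i * x_i^T w)^2,
// computed in a single pass: each margin z_i is evaluated only once.
void gradient_and_pre_conditioner(const Matrix& X, const std::vector<int>& y, const Vec& c,
                                  const Vec& w, Vec& gradient, Vec& pre) {
    gradient.assign(w.size(), 0.0);
    pre.assign(w.size(), 0.0);
    for (std::size_t i = 0; i < X.size(); ++i) {
        double z = 0.0;
        for (std::size_t j = 0; j < w.size(); ++j) z += X[i][j] * w[j];
        double slack = 1.0 - y[i] * z;
        if (slack <= 0.0) continue;  // inactive instance: contributes to neither output
        for (std::size_t j = 0; j < w.size(); ++j) {
            gradient[j] += -2.0 * c[i] * y[i] * slack * X[i][j];
            pre[j] += 2.0 * c[i] * X[i][j] * X[i][j];
        }
    }
}
```

Features that only occur in inactive instances get a zero diagonal entry here. A regularizer term added on top would keep the preconditioner invertible.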

◆ gradient_at_zero_unchecked()

template<class Derived >
void dismec::objective::LinearClassifierImpBase< Derived >::gradient_at_zero_unchecked ( Eigen::Ref< DenseRealVector > target )
inline, override, protected, virtual

The function that does the actual computation. This is called in gradient_at_zero() after the argument has been validated. The default implementation is rather inefficient and creates a new temporary zero vector.

Reimplemented from dismec::objective::Objective.

Definition at line 158 of file linear.h.

References dismec::objective::LinearClassifierImpBase< Derived >::derived(), and dismec::objective::LinearClassifierImpBase< Derived >::m_Regularizer.
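The default gradient_at_zero() builds a temporary zero vector and evaluates the general gradient there. For a linear model, w = 0 makes every margin x_i^T w vanish, so the product X w need not be computed at all. The following is a minimal sketch of this idea for a cost-weighted squared hinge loss with labels +1 / -1, where at w = 0 every instance is active with slack 1. The loss and all function names are assumptions for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;
using Matrix = std::vector<Vec>;  // row i = features of instance i

// General gradient of sum_i c_i * max(0, 1 - y_i * x_i^T w)^2 at an arbitrary w.
Vec gradient(const Matrix& X, const std::vector<int>& y, const Vec& c, const Vec& w) {
    Vec g(w.size(), 0.0);
    for (std::size_t i = 0; i < X.size(); ++i) {
        double z = 0.0;
        for (std::size_t j = 0; j < w.size(); ++j) z += X[i][j] * w[j];
        double slack = 1.0 - y[i] * z;
        if (slack <= 0.0) continue;
        for (std::size_t j = 0; j < w.size(); ++j) g[j] += -2.0 * c[i] * y[i] * slack * X[i][j];
    }
    return g;
}

// Default-style fallback: allocate a temporary zero vector and reuse gradient().
Vec gradient_at_zero_default(const Matrix& X, const std::vector<int>& y, const Vec& c,
                             std::size_t n) {
    return gradient(X, y, c, Vec(n, 0.0));
}

// Specialised version: at w = 0 all margins are 0 and all slacks are 1,
// so the gradient is simply -2 * sum_i c_i * y_i * x_i.
Vec gradient_at_zero_fast(const Matrix& X, const std::vector<int>& y, const Vec& c,
                          std::size_t n) {
    Vec g(n, 0.0);
    for (std::size_t i = 0; i < X.size(); ++i)
        for (std::size_t j = 0; j < n; ++j) g[j] += -2.0 * c[i] * y[i] * X[i][j];
    return g;
}
```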

◆ gradient_unchecked()

template<class Derived >
void dismec::objective::LinearClassifierImpBase< Derived >::gradient_unchecked ( const HashVector & location,
Eigen::Ref< DenseRealVector > target
)
inline, override, protected, virtual

The function that does the actual gradient computation. This is called in gradient() after the arguments have been validated.

Implements dismec::objective::Objective.

Definition at line 152 of file linear.h.

References dismec::objective::LinearClassifierImpBase< Derived >::derived(), and dismec::objective::LinearClassifierImpBase< Derived >::m_Regularizer.

◆ hessian_times_direction_unchecked()

template<class Derived >
void dismec::objective::LinearClassifierImpBase< Derived >::hessian_times_direction_unchecked ( const HashVector & location,
const DenseRealVector & direction,
Eigen::Ref< DenseRealVector > target
)
inline, override, protected, virtual

The function that does the actual computation. This is called in hessian_times_direction() after the arguments have been validated.

Implements dismec::objective::Objective.

Definition at line 163 of file linear.h.

References dismec::objective::LinearClassifierImpBase< Derived >::derived(), and dismec::objective::LinearClassifierImpBase< Derived >::m_Regularizer.
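For a linear model, the Hessian of the data term has the form X^T D X with a diagonal matrix D. The product with a direction d can therefore be formed as X^T (D (X d)) without ever materializing the n x n Hessian. The following is a minimal sketch for a cost-weighted squared hinge loss (generalized Hessian, D_ii = 2 c_i on active instances) plus an assumed L2 regularizer 0.5 * lambda * ||w||^2, whose Hessian is lambda * I. The loss, lambda and the function name are hypothetical. Splitting the result into a loss part and a regularizer part only mirrors the m_Regularizer reference above by assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;
using Matrix = std::vector<Vec>;  // row i = features of instance i

// H d = X^T D X d + lambda * d, with D_ii = 2 c_i if instance i is active
// (1 - y_i * x_i^T w > 0) and 0 otherwise. Never forms the n x n matrix.
Vec hessian_times_direction(const Matrix& X, const std::vector<int>& y, const Vec& c,
                            double lambda, const Vec& w, const Vec& d) {
    Vec out(d.size());
    for (std::size_t j = 0; j < d.size(); ++j) out[j] = lambda * d[j];  // regularizer part
    for (std::size_t i = 0; i < X.size(); ++i) {
        double z = 0.0, xd = 0.0;
        for (std::size_t j = 0; j < w.size(); ++j) { z += X[i][j] * w[j]; xd += X[i][j] * d[j]; }
        if (1.0 - y[i] * z <= 0.0) continue;  // inactive instance: D_ii = 0
        for (std::size_t j = 0; j < d.size(); ++j) out[j] += 2.0 * c[i] * xd * X[i][j];
    }
    return out;
}
```

Each product costs two passes over the nonzeros of X, which is what makes Hessian-free Newton and CG methods practical for large feature spaces.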

◆ lookup_on_line()

template<class Derived >
real_t dismec::objective::LinearClassifierImpBase< Derived >::lookup_on_line ( real_t position )
inline, override, protected, virtual

Looks up the value of the objective on the line defined by the last call to project_to_line().

Parameters
position: The step along the line at which the objective is evaluated.
Returns
The value of objective(location + position * direction), where location and direction are the vectors passed to the last call of project_to_line().
Attention
This function may use results cached by project_to_line(), so it must only be called after a call to that function. A new call to project_to_line() changes the line that is evaluated.

Implements dismec::objective::Objective.

Definition at line 142 of file linear.h.

References dismec::objective::LinearClassifierBase::costs(), dismec::objective::LinearClassifierBase::labels(), dismec::objective::LinearClassifierBase::line_interpolation(), dismec::objective::LinearClassifierImpBase< Derived >::m_Regularizer, and dismec::l2_reg_sq_hinge_detail::value_from_xTw().
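project_to_line() and lookup_on_line() together let a line search evaluate g(a) = objective(location + a * direction) cheaply. The loss of a linear model depends on w only through X w, and X (location + a * direction) = X location + a * X direction. So after caching the two products once, each g(a) costs a vector update of length num_instances instead of a matrix product. The following is a minimal sketch of this approach. The references above point to l2_reg_sq_hinge_detail::value_from_xTw(), so the sketch uses a cost-weighted squared hinge loss. LineProjection, mat_vec and the dense matrix layout are hypothetical, and the regularizer term is omitted.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

using Matrix = std::vector<std::vector<double>>;  // row i = features of instance i

std::vector<double> mat_vec(const Matrix& X, const std::vector<double>& v) {
    std::vector<double> out(X.size(), 0.0);
    for (std::size_t i = 0; i < X.size(); ++i) {
        assert(X[i].size() == v.size());
        for (std::size_t j = 0; j < v.size(); ++j) out[i] += X[i][j] * v[j];
    }
    return out;
}

// Cost-weighted squared hinge loss evaluated from the margins xTw.
double value_from_xTw(const std::vector<double>& cost, const std::vector<int>& labels,
                      const std::vector<double>& xTw) {
    double sum = 0.0;
    for (std::size_t i = 0; i < xTw.size(); ++i) {
        double slack = std::max(0.0, 1.0 - labels[i] * xTw[i]);
        sum += cost[i] * slack * slack;
    }
    return sum;
}

class LineProjection {
public:
    LineProjection(Matrix X, std::vector<int> labels, std::vector<double> cost)
        : m_X(std::move(X)), m_Labels(std::move(labels)), m_Cost(std::move(cost)) {}

    // Caches X * location and X * direction; this is the only matrix work.
    void project_to_line(const std::vector<double>& location, const std::vector<double>& direction) {
        m_XtLoc = mat_vec(m_X, location);
        m_XtDir = mat_vec(m_X, direction);
    }

    // g(a) = loss(X location + a * X direction), using only the cached vectors.
    double lookup_on_line(double a) const {
        std::vector<double> z(m_XtLoc.size());
        for (std::size_t i = 0; i < z.size(); ++i) z[i] = m_XtLoc[i] + a * m_XtDir[i];
        return value_from_xTw(m_Cost, m_Labels, z);
    }
private:
    Matrix m_X;
    std::vector<int> m_Labels;
    std::vector<double> m_Cost, m_XtLoc, m_XtDir;
};
```

As the Attention note above states, each lookup is only meaningful relative to the most recent project_to_line() call, since that call overwrites the cached products.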

◆ project_to_line_unchecked()

template<class Derived >
void dismec::objective::LinearClassifierImpBase< Derived >::project_to_line_unchecked ( const HashVector & location,
const DenseRealVector & direction
)
inline, override, protected, virtual

The function that does the actual computation. This is called in project_to_line() after the arguments have been validated.

Implements dismec::objective::Objective.

Definition at line 147 of file linear.h.

References dismec::objective::LinearClassifierImpBase< Derived >::m_Regularizer, and dismec::objective::LinearClassifierBase::project_linear_to_line().

◆ value_unchecked()

template<class Derived >
real_t dismec::objective::LinearClassifierImpBase< Derived >::value_unchecked ( const HashVector & location )
inline, override, protected, virtual

Member Data Documentation

◆ m_Regularizer


The documentation for this class was generated from the following file: