DiSMEC++
dismec::objective::PointWiseRegularizer< CRTP > Class Template Reference

Base class for pointwise regularization functions.

#include <pointwise.h>

Inheritance diagram for dismec::objective::PointWiseRegularizer< CRTP > (base classes): dismec::objective::Objective, dismec::stats::Tracked

Public Member Functions

 PointWiseRegularizer (real_t scale=1, bool ignore_bias=false)
 
long num_variables () const noexcept final
 The pointwise regularizer can act on arbitrarily sized vectors, so num_variables() == -1.
 
real_t value_unchecked (const HashVector &location) override
 
void hessian_times_direction_unchecked (const HashVector &location, const DenseRealVector &direction, Eigen::Ref< DenseRealVector > target) override
 
void gradient_unchecked (const HashVector &location, Eigen::Ref< DenseRealVector > target) override
 
void gradient_at_zero_unchecked (Eigen::Ref< DenseRealVector > target) override
 
void diag_preconditioner_unchecked (const HashVector &location, Eigen::Ref< DenseRealVector > target) override
 
void project_to_line_unchecked (const HashVector &location, const DenseRealVector &direction) override
 
real_t lookup_on_line (real_t a) override
 Looks up the value of the objective on the line defined by the last call to project_to_line().
 
bool dont_regularize_bias () const
 
real_t scale () const
 Returns the common scale factor for the entire regularizer.
 
- Public Member Functions inherited from dismec::objective::Objective
 Objective ()
 
virtual ~Objective () noexcept=default
 
real_t value (const HashVector &location)
 Evaluate the objective at the given location.
 
void diag_preconditioner (const HashVector &location, Eigen::Ref< DenseRealVector > target)
 Gets the preconditioner to be used in CG optimization.
 
void project_to_line (const HashVector &location, const DenseRealVector &direction)
 Creates a function g such that g(a) = objective(location + a * direction). Use lookup_on_line() to evaluate g.
 
virtual void declare_vector_on_last_line (const HashVector &location, real_t t)
 State that the given vector corresponds to a certain position on the line of the last line search.
 
void gradient_at_zero (Eigen::Ref< DenseRealVector > target)
 Gets the gradient at location zero.
 
void gradient (const HashVector &location, Eigen::Ref< DenseRealVector > target)
 Evaluate the gradient at location.
 
void hessian_times_direction (const HashVector &location, const DenseRealVector &direction, Eigen::Ref< DenseRealVector > target)
 Calculates the product of the Hessian matrix at location with direction.
 
void gradient_and_pre_conditioner (const HashVector &location, Eigen::Ref< DenseRealVector > gradient, Eigen::Ref< DenseRealVector > pre)
 Combines the calculation of gradient and pre-conditioner, which may be more efficient in some cases.
 
- Public Member Functions inherited from dismec::stats::Tracked
 Tracked ()
 Default constructor, creates the internal stats::StatisticsCollection.
 
void register_stat (const std::string &name, std::unique_ptr< Statistics > stat)
 Registers a tracker for the statistics name.
 
std::shared_ptr< StatisticsCollection > get_stats () const
 Gets an ownership-sharing reference to the StatisticsCollection.
 

Private Member Functions

real_t point_wise_value_ (real_t x) const
 Calls point_wise_value() of the implementing class.
 
real_t point_wise_grad_ (real_t x) const
 Calls point_wise_grad() of the implementing class.
 
real_t point_wise_quad_ (real_t x) const
 Calls point_wise_quad() of the implementing class.
 
long get_loop_bound (const HashVector &location) const
 

Private Attributes

bool m_LastWeightIsBias = false
 
real_t m_Scale = 1.0
 
DenseRealVector m_LineStart
 
DenseRealVector m_LineDirection
 

Additional Inherited Members

- Protected Member Functions inherited from dismec::stats::Tracked
 ~Tracked ()
 Non-virtual destructor. Declared protected, so we don't accidentally try to do a polymorphic delete.
 
template<class T >
void record (stat_id_t stat, T &&value)
 Record statistics. This function just forwards all its arguments to the internal StatisticsCollection.
 
void declare_stat (stat_id_t index, StatisticMetaData meta)
 Declares a new statistic. This function just forwards all its arguments to the internal StatisticsCollection.
 
void declare_tag (tag_id_t index, std::string name)
 Declares a new tag. This function just forwards all its arguments to the internal StatisticsCollection.
 
template<class... Args>
void set_tag (tag_id_t tag, long value)
 Set value of tag. This function just forwards all its arguments to the internal StatisticsCollection.
 
template<class... Args>
auto make_timer (stat_id_t id, Args... args)
 Creates a new ScopeTimer using stats::record_scope_time.
 

Detailed Description

template<class CRTP>
class dismec::objective::PointWiseRegularizer< CRTP >

Base class for pointwise regularization functions.

Template Parameters
CRTP: This class is implemented as a CRTP (https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern) template.

This class provides default implementations for the objective, as long as the CRTP-derived class defines three scalar functions:

  • point_wise_value
  • point_wise_grad
  • point_wise_quad

The regularizer is defined via \( R(w) = \sum_{i=1}^n f(w_i) \), and the three functions define the value, first derivative and second derivative of f. More strictly speaking, f is expected to be a scalar function, and the function g_x defined by point_wise_value, point_wise_grad and point_wise_quad is expected to be a lower-bounded upper bound on f. Denote the return values of the three functions at x by a, b and c, so that \( g_x(t) = a + b (t - x) + \frac{c}{2} (t - x)^2 \). Then, for some constant C, g_x fulfills \( C \leq g_x(t) \) and \( f(t) \leq g_x(t) \) for all t, with \( g_x(x) = f(x) \).
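The CRTP dispatch described above can be sketched as follows. This is a minimal, hypothetical illustration rather than the DiSMEC++ code: PointWiseSketch and HalfSquaredNorm are invented names, std::vector<double> stands in for HashVector and DenseRealVector, and scaling and bias handling are left out. The base class reaches the derived class's point_wise_value and point_wise_grad through a static_cast, so no virtual call is made per vector entry.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical, simplified stand-in for PointWiseRegularizer. It works on
// std::vector<double> instead of HashVector / DenseRealVector.
template<class Derived>
class PointWiseSketch {
public:
    // R(w) = sum_i f(w_i), where f is supplied by the derived class.
    double value(const std::vector<double>& w) const {
        double sum = 0.0;
        for (double x : w) {
            sum += point_wise_value_(x);
        }
        return sum;
    }

    // dR/dw_i = f'(w_i), written into a caller-provided buffer.
    void gradient(const std::vector<double>& w, std::vector<double>& target) const {
        assert(target.size() == w.size());
        for (std::size_t i = 0; i < w.size(); ++i) {
            target[i] = point_wise_grad_(w[i]);
        }
    }

private:
    // Compile-time dispatch to the derived class: no virtual call per entry.
    const Derived& derived() const { return static_cast<const Derived&>(*this); }
    double point_wise_value_(double x) const { return derived().point_wise_value(x); }
    double point_wise_grad_(double x) const { return derived().point_wise_grad(x); }
};

// Example f(x) = x^2 / 2, so R(w) is half the squared Euclidean norm.
class HalfSquaredNorm : public PointWiseSketch<HalfSquaredNorm> {
public:
    double point_wise_value(double x) const { return 0.5 * x * x; }
    double point_wise_grad(double x) const { return x; }
    double point_wise_quad(double /*x*/) const { return 1.0; }
};
```

For this f, the quadratic model built from the three return values coincides with f itself, so the upper-bound condition holds with equality.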

This base class can also exclude the last weight from all calculations (ignore_bias), and it multiplies the entire regularizer by a common scale factor.
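A minimal sketch of the bias exclusion and scaling, assuming f(x) = x^2 / 2 and std::vector<double> in place of the DiSMEC++ vector types. ScaledRegularizerSketch and loop_bound are hypothetical names (loop_bound only mirrors the role the name get_loop_bound() suggests), and giving an ignored bias entry a zero gradient is an assumption of this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch of bias and scale handling, using f(x) = x^2 / 2.
struct ScaledRegularizerSketch {
    double scale = 1.0;        // common prefactor for the whole regularizer
    bool ignore_bias = false;  // treat the last weight as an unregularized bias

    // Number of leading entries that take part in the regularization.
    std::size_t loop_bound(const std::vector<double>& w) const {
        return (ignore_bias && !w.empty()) ? w.size() - 1 : w.size();
    }

    // scale * sum_{i < bound} w_i^2 / 2
    double value(const std::vector<double>& w) const {
        double sum = 0.0;
        for (std::size_t i = 0; i < loop_bound(w); ++i) {
            sum += 0.5 * w[i] * w[i];
        }
        return scale * sum;
    }

    // Gradient; an ignored bias entry gets a zero derivative.
    void gradient(const std::vector<double>& w, std::vector<double>& target) const {
        assert(target.size() == w.size());
        const std::size_t bound = loop_bound(w);
        for (std::size_t i = 0; i < bound; ++i) {
            target[i] = scale * w[i];
        }
        for (std::size_t i = bound; i < w.size(); ++i) {
            target[i] = 0.0;
        }
    }
};
```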

The line projection algorithm supplied by the default implementation caches the starting point and the direction vector. Where more efficient algorithms are available, a derived class can override project_to_line_unchecked() and lookup_on_line(). In that case, m_LineStart and m_LineDirection are never set and do not consume unneeded memory.
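The default caching strategy can be sketched as below, again assuming f(x) = x^2 / 2 and std::vector<double> storage; LineSearchSketch is a hypothetical class, not part of DiSMEC++. project_to_line() copies both vectors, and every lookup_on_line(a) re-evaluates the regularizer at start + a * direction.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch of the default line projection: cache start and
// direction, then evaluate g(a) = R(start + a * direction) on each lookup.
class LineSearchSketch {
public:
    explicit LineSearchSketch(double scale) : m_Scale(scale) {}

    void project_to_line(const std::vector<double>& location,
                         const std::vector<double>& direction) {
        assert(location.size() == direction.size());
        m_LineStart = location;      // full copies: this is the memory cost
        m_LineDirection = direction;
    }

    // Must only be called after project_to_line().
    double lookup_on_line(double a) const {
        double sum = 0.0;
        for (std::size_t i = 0; i < m_LineStart.size(); ++i) {
            const double x = m_LineStart[i] + a * m_LineDirection[i];
            sum += 0.5 * x * x;
        }
        return m_Scale * sum;
    }

private:
    double m_Scale;
    std::vector<double> m_LineStart;
    std::vector<double> m_LineDirection;
};
```

Because this particular f is quadratic, with s = start, d = direction and scale \( \sigma \), \( g(a) = \sigma \left( \tfrac{1}{2} \sum_i s_i^2 + a \sum_i s_i d_i + \tfrac{a^2}{2} \sum_i d_i^2 \right) \), so an override could cache these three sums instead of two full vectors.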

Definition at line 39 of file pointwise.h.

Constructor & Destructor Documentation

◆ PointWiseRegularizer()

template<class T >
dismec::objective::PointWiseRegularizer< T >::PointWiseRegularizer ( real_t  scale = 1,
bool  ignore_bias = false 
)
explicit

Constructor for the regularizer. scale defines the prefactor by which the entire regularizer will be scaled. This value has to be larger than zero. ignore_bias declares whether the last entry in the weight vector should be considered a bias term, and be ignored in the calculations.

Definition at line 103 of file pointwise.h.

References dismec::objective::PointWiseRegularizer< CRTP >::m_Scale, and THROW_EXCEPTION.
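A sketch of the constructor contract, assuming std::invalid_argument as the error type; DiSMEC++ itself raises the error through its THROW_EXCEPTION macro, whose exception type is not shown on this page. RegularizerConfigSketch is a hypothetical name.

```cpp
#include <stdexcept>

// Hypothetical sketch: store scale and bias flag, reject scales that are
// not larger than zero. std::invalid_argument is an assumption.
class RegularizerConfigSketch {
public:
    explicit RegularizerConfigSketch(double scale = 1, bool ignore_bias = false)
        : m_Scale(scale), m_LastWeightIsBias(ignore_bias) {
        if (!(scale > 0)) {  // written this way so that NaN is rejected too
            throw std::invalid_argument("scale must be larger than zero");
        }
    }

    double scale() const { return m_Scale; }
    bool dont_regularize_bias() const { return m_LastWeightIsBias; }

private:
    double m_Scale;
    bool m_LastWeightIsBias;
};
```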

Member Function Documentation

◆ diag_preconditioner_unchecked()

template<class T >
void dismec::objective::PointWiseRegularizer< T >::diag_preconditioner_unchecked ( const HashVector & location,
Eigen::Ref< DenseRealVector > target 
)
override virtual

The function that does the actual computation. This is called in diag_preconditioner() after the arguments have been validated. The default implementation in Objective returns ones; PointWiseRegularizer reimplements it.

Reimplemented from dismec::objective::Objective.

Definition at line 169 of file pointwise.h.

◆ dont_regularize_bias()

template<class CRTP >
bool dismec::objective::PointWiseRegularizer< CRTP >::dont_regularize_bias ( ) const
inline

Returns whether the last entry in the weight vector should be treated as a bias, and be ignored in the regularization.

Definition at line 66 of file pointwise.h.

References dismec::objective::PointWiseRegularizer< CRTP >::m_LastWeightIsBias.

Referenced by dismec::objective::PointWiseRegularizer< CRTP >::get_loop_bound().

◆ get_loop_bound()

template<class CRTP >
long dismec::objective::PointWiseRegularizer< CRTP >::get_loop_bound ( const HashVector & location ) const
inline private

◆ gradient_at_zero_unchecked()

template<class T >
void dismec::objective::PointWiseRegularizer< T >::gradient_at_zero_unchecked ( Eigen::Ref< DenseRealVector > target )
override virtual

The function that does the actual computation. This is called in gradient_at_zero() after the argument has been validated. The default implementation in Objective is rather inefficient, because it creates a new temporary zero vector; PointWiseRegularizer reimplements it.

Reimplemented from dismec::objective::Objective.

Definition at line 157 of file pointwise.h.
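For a separable regularizer, every regularized component of the gradient at w = 0 equals scale * f'(0), so no temporary zero vector is needed. The hypothetical function below sketches that idea; it is not the DiSMEC++ implementation, grad plays the role of point_wise_grad, and the zero entry for an ignored bias is an assumption.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical sketch: fill the gradient at w = 0 directly with
// scale * f'(0); an ignored bias entry is set to zero.
template<class PointwiseGrad>
void gradient_at_zero_sketch(PointwiseGrad grad, double scale, bool ignore_bias,
                             std::vector<double>& target) {
    const double g0 = scale * grad(0.0);  // identical for every entry
    const std::size_t bound =
        (ignore_bias && !target.empty()) ? target.size() - 1 : target.size();
    for (std::size_t i = 0; i < bound; ++i) {
        target[i] = g0;
    }
    for (std::size_t i = bound; i < target.size(); ++i) {
        target[i] = 0.0;
    }
}
```

For symmetric, differentiable penalties such as f(x) = x^2 / 2, f'(0) = 0 and the whole gradient at zero vanishes.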

◆ gradient_unchecked()

template<class T >
void dismec::objective::PointWiseRegularizer< T >::gradient_unchecked ( const HashVector & location,
Eigen::Ref< DenseRealVector > target 
)
override virtual

The function that does the actual gradient computation. This is called in gradient() after the arguments have been validated.

Implements dismec::objective::Objective.

Definition at line 143 of file pointwise.h.

◆ hessian_times_direction_unchecked()

template<class T >
void dismec::objective::PointWiseRegularizer< T >::hessian_times_direction_unchecked ( const HashVector & location,
const DenseRealVector & direction,
Eigen::Ref< DenseRealVector > target 
)
override virtual

The function that does the actual computation. This is called in hessian_times_direction() after the arguments have been validated.

Implements dismec::objective::Objective.

Definition at line 125 of file pointwise.h.
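Because \( R(w) = \sum_i f(w_i) \) is separable, its Hessian is diagonal with entries f''(w_i), so the Hessian-vector product reduces to an element-wise multiplication and no matrix is ever formed. The sketch below is a hypothetical free function on std::vector<double>; quad plays the role of point_wise_quad, and the zero bias entry is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch: (H d)_i = scale * f''(w_i) * d_i for a separable R.
template<class PointwiseQuad>
void hessian_times_direction_sketch(PointwiseQuad quad, double scale, bool ignore_bias,
                                    const std::vector<double>& w,
                                    const std::vector<double>& d,
                                    std::vector<double>& target) {
    assert(w.size() == d.size() && w.size() == target.size());
    const std::size_t bound = (ignore_bias && !w.empty()) ? w.size() - 1 : w.size();
    for (std::size_t i = 0; i < bound; ++i) {
        target[i] = scale * quad(w[i]) * d[i];
    }
    for (std::size_t i = bound; i < w.size(); ++i) {
        target[i] = 0.0;  // an ignored bias does not contribute to R
    }
}
```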

◆ lookup_on_line()

template<class T >
real_t dismec::objective::PointWiseRegularizer< T >::lookup_on_line ( real_t  position)
override virtual

Looks up the value of the objective on the line defined by the last call to project_to_line().

Parameters
position: The position along the line at which the objective is evaluated.
Returns
The value of objective(location + position * direction), where location and direction are the vectors passed to the last call of project_to_line().
Attention
This function may use results cached by project_to_line(), so it must only be called after a call to that function. A subsequent call to project_to_line() changes the line that is evaluated.

Implements dismec::objective::Objective.

Definition at line 190 of file pointwise.h.

◆ num_variables()

template<class CRTP >
long dismec::objective::PointWiseRegularizer< CRTP >::num_variables ( ) const
inline final virtual noexcept

The pointwise regularizer can act on arbitrarily sized vectors, so num_variables() == -1.

Implements dismec::objective::Objective.

Definition at line 47 of file pointwise.h.
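A caller that needs to check dimensions might treat -1 as a wildcard, as in this hypothetical helper; accepts_size is not a DiSMEC++ function, and how Objective actually validates sizes is not shown on this page.

```cpp
// Hypothetical helper: an objective reporting num_variables() == -1
// accepts vectors of any size; otherwise the sizes must match exactly.
constexpr bool accepts_size(long num_variables, long vector_size) {
    return num_variables == -1 || num_variables == vector_size;
}
```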

◆ point_wise_grad_()

template<class CRTP >
real_t dismec::objective::PointWiseRegularizer< CRTP >::point_wise_grad_ ( real_t  x) const
inline private

calls point_wise_grad() of the implementing class

Definition at line 88 of file pointwise.h.

◆ point_wise_quad_()

template<class CRTP >
real_t dismec::objective::PointWiseRegularizer< CRTP >::point_wise_quad_ ( real_t  x) const
inline private

calls point_wise_quad() of the implementing class

Definition at line 93 of file pointwise.h.

◆ point_wise_value_()

template<class CRTP >
real_t dismec::objective::PointWiseRegularizer< CRTP >::point_wise_value_ ( real_t  x) const
inline private

calls point_wise_value() of the implementing class

Definition at line 83 of file pointwise.h.

◆ project_to_line_unchecked()

template<class T >
void dismec::objective::PointWiseRegularizer< T >::project_to_line_unchecked ( const HashVector & location,
const DenseRealVector & direction 
)
override virtual

The function that does the actual computation. This is called in project_to_line() after the arguments have been validated.

Implements dismec::objective::Objective.

Definition at line 183 of file pointwise.h.

References dismec::HashVector::get().

◆ scale()

template<class CRTP >
real_t dismec::objective::PointWiseRegularizer< CRTP >::scale ( ) const
inline

Returns the common scale factor for the entire regularizer.

Definition at line 69 of file pointwise.h.

References dismec::objective::PointWiseRegularizer< CRTP >::m_Scale.

◆ value_unchecked()

template<class T >
real_t dismec::objective::PointWiseRegularizer< T >::value_unchecked ( const HashVector & location )
override virtual

The function that does the actual value computation. This is called in value() after the argument has been validated.

Implements dismec::objective::Objective.

Definition at line 112 of file pointwise.h.

Referenced by dismec::objective::SquaredNormRegularizer::value_unchecked().

Member Data Documentation

◆ m_LastWeightIsBias

template<class CRTP >
bool dismec::objective::PointWiseRegularizer< CRTP >::m_LastWeightIsBias = false
private

◆ m_LineDirection

template<class CRTP >
DenseRealVector dismec::objective::PointWiseRegularizer< CRTP >::m_LineDirection
private

This variable will cache the direction for tracking a line projection. It is set in project_to_line().

Definition at line 80 of file pointwise.h.

◆ m_LineStart

template<class CRTP >
DenseRealVector dismec::objective::PointWiseRegularizer< CRTP >::m_LineStart
private

This variable will cache the starting position for tracking a line projection. It is set in project_to_line().

Definition at line 77 of file pointwise.h.

◆ m_Scale


The documentation for this class was generated from the following file: