DiSMEC++
This class implements a squared norm (L2) regularizer, i.e. f(x) = 0.5 ||x||^2.
#include <regularizers_imp.h>
Public Member Functions
SquaredNormRegularizer (real_t scale=1, bool ignore_bias=false)
real_t | value_unchecked (const HashVector &location) override
void | project_to_line_unchecked (const HashVector &location, const DenseRealVector &direction) override
real_t | lookup_on_line (real_t a) override
Looks up the value of the objective on the line defined by the last call to project_to_line().
Public Member Functions inherited from dismec::objective::PointWiseRegularizer< SquaredNormRegularizer >
PointWiseRegularizer (real_t scale=1, bool ignore_bias=false)
long | num_variables () const noexcept final
The pointwise regularizer can act on arbitrarily sized vectors, so num_variables() == -1.
real_t | value_unchecked (const HashVector &location) override
void | hessian_times_direction_unchecked (const HashVector &location, const DenseRealVector &direction, Eigen::Ref< DenseRealVector > target) override
void | gradient_unchecked (const HashVector &location, Eigen::Ref< DenseRealVector > target) override
void | gradient_at_zero_unchecked (Eigen::Ref< DenseRealVector > target) override
void | diag_preconditioner_unchecked (const HashVector &location, Eigen::Ref< DenseRealVector > target) override
void | project_to_line_unchecked (const HashVector &location, const DenseRealVector &direction) override
real_t | lookup_on_line (real_t a) override
Looks up the value of the objective on the line defined by the last call to project_to_line().
bool | dont_regularize_bias () const
real_t | scale () const
Returns the common scale factor for the entire regularizer.
Public Member Functions inherited from dismec::objective::Objective
Objective ()
virtual | ~Objective () noexcept=default
real_t | value (const HashVector &location)
Evaluate the objective at the given location.
void | diag_preconditioner (const HashVector &location, Eigen::Ref< DenseRealVector > target)
Get the preconditioner to be used in CG optimization.
void | project_to_line (const HashVector &location, const DenseRealVector &direction)
Creates a function g such that g(a) = objective(location + a * direction). Use lookup_on_line() to evaluate g.
virtual void | declare_vector_on_last_line (const HashVector &location, real_t t)
State that the given vector corresponds to a certain position on the line of the last line search.
void | gradient_at_zero (Eigen::Ref< DenseRealVector > target)
Gets the gradient at location zero.
void | gradient (const HashVector &location, Eigen::Ref< DenseRealVector > target)
Evaluate the gradient at location.
void | hessian_times_direction (const HashVector &location, const DenseRealVector &direction, Eigen::Ref< DenseRealVector > target)
Calculates the product of the Hessian matrix at location with direction.
void | gradient_and_pre_conditioner (const HashVector &location, Eigen::Ref< DenseRealVector > gradient, Eigen::Ref< DenseRealVector > pre)
Combines the calculation of gradient and preconditioner, which may be more efficient in some cases.
Public Member Functions inherited from dismec::stats::Tracked
Tracked ()
Default constructor, creates the internal stats::StatisticsCollection.
void | register_stat (const std::string &name, std::unique_ptr< Statistics > stat)
Registers a tracker for the statistic called name.
std::shared_ptr< StatisticsCollection > | get_stats () const
Gets an ownership-sharing reference to the StatisticsCollection.
Static Public Member Functions
static real_t | point_wise_value (real_t x)
static real_t | point_wise_grad (real_t x)
static real_t | point_wise_quad (real_t x)
Private Attributes
real_t | m_LsCache_w02
real_t | m_LsCache_d2
real_t | m_LsCache_dTw
Additional Inherited Members
Protected Member Functions inherited from dismec::stats::Tracked
~Tracked ()
Non-virtual destructor. Declared protected, so we don't accidentally try to do a polymorphic delete.
template<class T >
void | record (stat_id_t stat, T &&value)
Record statistics. This function just forwards all its arguments to the internal StatisticsCollection.
void | declare_stat (stat_id_t index, StatisticMetaData meta)
Declares a new statistic. This function just forwards all its arguments to the internal StatisticsCollection.
void | declare_tag (tag_id_t index, std::string name)
Declares a new tag. This function just forwards all its arguments to the internal StatisticsCollection.
template<class... Args>
void | set_tag (tag_id_t tag, long value)
Sets the value of a tag. This function just forwards all its arguments to the internal StatisticsCollection.
template<class... Args>
auto | make_timer (stat_id_t id, Args... args)
Creates a new ScopeTimer using stats::record_scope_time.
Detailed Description
This class implements a squared norm (L2) regularizer, i.e. f(x) = 0.5 ||x||^2.
Since this is a quadratic function with a diagonal Hessian, the implementation is mostly trivial. The only interesting code is that of project_to_line() and lookup_on_line(): project_to_line() precomputes a few scalar quantities so that lookup_on_line() can be implemented in O(1). The regularizer admits an additional scale parameter by which its value (and thus its gradient, Hessian, etc.) is scaled.
Definition at line 30 of file regularizers_imp.h.
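Because the Hessian is diagonal, the gradient and Hessian-vector products reduce to elementwise scaling. The following C++ sketch illustrates this under stated assumptions; it is not the DiSMEC++ implementation. It uses std::vector<double> in place of HashVector and DenseRealVector, the struct SquaredNormDerivatives is hypothetical, and it assumes that ignore_bias excludes the last component (taken to be the bias) from regularization.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-alone sketch; not the DiSMEC++ class.
// Assumption: when ignore_bias is set, the bias is the last component.
struct SquaredNormDerivatives {
    double scale = 1.0;
    bool ignore_bias = false;

    // Number of leading components that are regularized.
    std::size_t regularized(std::size_t n) const {
        return (ignore_bias && n > 0) ? n - 1 : n;
    }

    // grad f(x)_i = scale * x_i on regularized components, 0 for the bias.
    std::vector<double> gradient(const std::vector<double>& x) const {
        std::vector<double> g(x.size(), 0.0);
        for (std::size_t i = 0; i < regularized(x.size()); ++i) {
            g[i] = scale * x[i];
        }
        return g;
    }

    // H = scale * I (bias entry zeroed), so H d is elementwise and needs no matrix.
    // x only fixes the dimension; the Hessian of a quadratic does not depend on it.
    std::vector<double> hessian_times_direction(const std::vector<double>& x,
                                                const std::vector<double>& d) const {
        assert(x.size() == d.size());
        std::vector<double> out(d.size(), 0.0);
        for (std::size_t i = 0; i < regularized(d.size()); ++i) {
            out[i] = scale * d[i];
        }
        return out;
    }
};
```

With scale = 2 and x = (3, 4), the gradient is (6, 8); with ignore_bias set, the last entry becomes 0.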
SquaredNormRegularizer::SquaredNormRegularizer (real_t scale=1, bool ignore_bias=false) | explicit
Definition at line 16 of file regularizers_imp.cpp.
real_t SquaredNormRegularizer::lookup_on_line (real_t a) | override virtual
Looks up the value of the objective on the line defined by the last call to project_to_line().
Parameters: position | The position on the line at which the objective is evaluated.
Returns: objective(location + position * direction), where location and direction are the vectors passed to the last call of project_to_line().
This function uses the values cached by project_to_line(), so it has to be called after a call to that function. A new call to project_to_line() will change the line which is evaluated.
Implements dismec::objective::Objective.
Definition at line 39 of file regularizers_imp.cpp.
References m_LsCache_d2, m_LsCache_dTw, m_LsCache_w02, and dismec::objective::PointWiseRegularizer< SquaredNormRegularizer >::scale().
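A minimal sketch of the O(1) lookup, assuming the three cached scalars hold ||x||^2, <x, d> and ||d||^2 (as described for project_to_line_unchecked()) and that the scale factor multiplies 0.5 ||x + a d||^2. The struct LineCache and the free function lookup_on_line are hypothetical stand-ins for the private members m_LsCache_w02, m_LsCache_dTw and m_LsCache_d2 and the member function.

```cpp
#include <cassert>

// Hypothetical mirror of m_LsCache_w02, m_LsCache_dTw and m_LsCache_d2.
struct LineCache {
    double w02;  // ||x||^2
    double dTw;  // <x, d>
    double d2;   // ||d||^2
};

// g(a) = scale * 0.5 * ||x + a d||^2
//      = scale * 0.5 * (||x||^2 + 2 a <x, d> + a^2 ||d||^2).
// A constant number of operations, independent of the vector dimension.
double lookup_on_line(const LineCache& c, double scale, double a) {
    assert(c.w02 >= 0.0 && c.d2 >= 0.0);  // squared norms are never negative
    return scale * 0.5 * (c.w02 + 2.0 * a * c.dTw + a * a * c.d2);
}
```

For example, x = (1, 2) and d = (1, -1) give the cache (5, -1, 2); with scale 1, a = 2 then yields 0.5 * (5 - 4 + 8) = 4.5, which matches 0.5 * ||(3, 0)||^2 computed directly.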
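real_t SquaredNormRegularizer::point_wise_grad (real_t x) | static
Definition at line 52 of file regularizers_imp.cpp.
real_t SquaredNormRegularizer::point_wise_quad (real_t x) | static
Definition at line 56 of file regularizers_imp.cpp.
real_t SquaredNormRegularizer::point_wise_value (real_t x) | static
Definition at line 48 of file regularizers_imp.cpp.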
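The three static functions are per-coordinate kernels from which the PointWiseRegularizer base presumably assembles the value, gradient and Hessian products. The sketch below shows such a composition via CRTP. PointWiseSketch and SquaredNormKernels are hypothetical names, and treating point_wise_quad as the second derivative of 0.5 x^2 (hence the constant 1) is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CRTP base: Derived supplies static per-coordinate kernels.
template <class Derived>
struct PointWiseSketch {
    double scale = 1.0;

    // value = scale * sum_i point_wise_value(x_i)
    double value(const std::vector<double>& x) const {
        double sum = 0.0;
        for (double xi : x) sum += Derived::point_wise_value(xi);
        return scale * sum;
    }

    // gradient_i = scale * point_wise_grad(x_i)
    std::vector<double> gradient(const std::vector<double>& x) const {
        std::vector<double> g(x.size());
        for (std::size_t i = 0; i < x.size(); ++i) {
            g[i] = scale * Derived::point_wise_grad(x[i]);
        }
        return g;
    }

    // (H d)_i = scale * point_wise_quad(x_i) * d_i, since H is diagonal.
    std::vector<double> hessian_times_direction(const std::vector<double>& x,
                                                const std::vector<double>& d) const {
        assert(x.size() == d.size());
        std::vector<double> out(x.size());
        for (std::size_t i = 0; i < x.size(); ++i) {
            out[i] = scale * Derived::point_wise_quad(x[i]) * d[i];
        }
        return out;
    }
};

// Kernels for f(x) = 0.5 x^2 (point_wise_quad assumed to be f''(x) = 1).
struct SquaredNormKernels : PointWiseSketch<SquaredNormKernels> {
    static double point_wise_value(double x) { return 0.5 * x * x; }
    static double point_wise_grad(double x) { return x; }
    static double point_wise_quad(double /*x*/) { return 1.0; }
};
```

The CRTP design lets the base call the kernels as static functions without virtual dispatch per coordinate.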
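void SquaredNormRegularizer::project_to_line_unchecked (const HashVector &location, const DenseRealVector &direction) | override virtual
This function calculates helper values to facilitate a fast implementation of lookup_on_line(). It is based on the decomposition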
\[ \|x + t d\|^2 = \|x\|^2 + 2 t \langle x, d \rangle + t^2 \|d\|^2. \]
Therefore, we calculate \( \|x\|^2\), \(\langle x, d \rangle\) and \(\|d\|^2 \) here.
Implements dismec::objective::Objective.
Definition at line 21 of file regularizers_imp.cpp.
References dismec::objective::PointWiseRegularizer< SquaredNormRegularizer >::dont_regularize_bias(), m_LsCache_d2, m_LsCache_dTw, and m_LsCache_w02.
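The caching step can be sketched as a single O(n) pass over both vectors. Assumptions: std::vector<double> replaces HashVector and DenseRealVector, LineCache is a hypothetical stand-in for the three private members, and, because this function references dont_regularize_bias(), we assume the bias is the last component and is skipped in all three sums when it is not regularized.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical mirror of m_LsCache_w02, m_LsCache_dTw and m_LsCache_d2.
struct LineCache {
    double w02 = 0.0;  // ||x||^2
    double dTw = 0.0;  // <x, d>
    double d2 = 0.0;   // ||d||^2
};

// Accumulate the three terms of ||x + t d||^2 = ||x||^2 + 2 t <x, d> + t^2 ||d||^2.
// Assumption: with skip_bias set, the last (bias) component is left out.
LineCache project_to_line(const std::vector<double>& x,
                          const std::vector<double>& d, bool skip_bias) {
    assert(x.size() == d.size());
    std::size_t n = (skip_bias && !x.empty()) ? x.size() - 1 : x.size();
    LineCache c;
    for (std::size_t i = 0; i < n; ++i) {
        c.w02 += x[i] * x[i];
        c.dTw += x[i] * d[i];
        c.d2 += d[i] * d[i];
    }
    return c;
}
```

With x = (1, 2, 5) and d = (1, -1, 3), the full cache is (30, 14, 11), and 30 + 2 * 2 * 14 + 4 * 11 = 130 = ||x + 2d||^2 = ||(3, 0, 11)||^2.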
real_t SquaredNormRegularizer::value_unchecked (const HashVector &location) | override virtual
The function that does the actual value computation. It is called in value() after the argument has been validated.
Implements dismec::objective::Objective.
Definition at line 44 of file regularizers_imp.cpp.
References dismec::objective::PointWiseRegularizer< CRTP >::value_unchecked().
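The split between value() and value_unchecked() follows the non-virtual-interface pattern: the public, non-virtual function validates its argument and then dispatches to the virtual worker. The sketch below assumes the validation is a dimension check that is skipped when num_variables() returns -1, as it does for pointwise regularizers. The class names, access levels and exception type are hypothetical.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

// Hypothetical sketch of a checked/unchecked split (non-virtual interface).
class ObjectiveSketch {
public:
    virtual ~ObjectiveSketch() = default;

    // -1 means "any size", as for the pointwise regularizers.
    virtual long num_variables() const noexcept = 0;

    // Public entry point: validate, then dispatch to the virtual worker.
    double value(const std::vector<double>& location) {
        long n = num_variables();
        if (n != -1 && static_cast<long>(location.size()) != n) {
            throw std::invalid_argument("location has the wrong dimension");
        }
        return value_unchecked(location);
    }

protected:
    // Only ever called with an argument that passed the check above.
    virtual double value_unchecked(const std::vector<double>& location) = 0;
};

class SquaredNormSketch final : public ObjectiveSketch {
public:
    explicit SquaredNormSketch(double scale = 1.0) : m_Scale(scale) {}
    long num_variables() const noexcept override { return -1; }

protected:
    // scale * 0.5 * ||x||^2
    double value_unchecked(const std::vector<double>& location) override {
        double sum = 0.0;
        for (double xi : location) sum += xi * xi;
        return m_Scale * 0.5 * sum;
    }

private:
    double m_Scale;
};
```

Keeping the check in one non-virtual function means every derived objective gets the same validation without repeating it.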
real_t SquaredNormRegularizer::m_LsCache_d2 | private
Definition at line 54 of file regularizers_imp.h.
Referenced by lookup_on_line(), and project_to_line_unchecked().
real_t SquaredNormRegularizer::m_LsCache_dTw | private
Definition at line 55 of file regularizers_imp.h.
Referenced by lookup_on_line(), and project_to_line_unchecked().
real_t SquaredNormRegularizer::m_LsCache_w02 | private
Definition at line 53 of file regularizers_imp.h.
Referenced by lookup_on_line(), and project_to_line_unchecked().