DiSMEC++
|
Implementation helper for linear classifier derived classes. More...
#include <linear.h>
Public Member Functions | |
LinearClassifierImpBase (std::shared_ptr< const GenericFeatureMatrix > X, std::unique_ptr< Objective > regularizer) | |
![]() | |
LinearClassifierBase (std::shared_ptr< const GenericFeatureMatrix > X) | |
long | num_instances () const noexcept |
long | num_variables () const noexcept override |
BinaryLabelVector & | get_label_ref () |
void | update_costs (real_t positive, real_t negative) |
![]() | |
Objective () | |
virtual | ~Objective () noexcept=default |
real_t | value (const HashVector &location) |
Evaluate the objective at the given location. More... | |
void | diag_preconditioner (const HashVector &location, Eigen::Ref< DenseRealVector > target) |
Get the preconditioner to be used in CG optimization. More... | |
void | project_to_line (const HashVector &location, const DenseRealVector &direction) |
Creates a function g such that g(a) = objective(location + a * direction). Use lookup_on_line() to evaluate g. More... | |
void | gradient_at_zero (Eigen::Ref< DenseRealVector > target) |
Gets the gradient for location zero. More... | |
void | gradient (const HashVector &location, Eigen::Ref< DenseRealVector > target) |
Evaluate the gradient at location. More... | |
void | hessian_times_direction (const HashVector &location, const DenseRealVector &direction, Eigen::Ref< DenseRealVector > target) |
Calculates the product of the Hessian matrix at location with direction. More... | |
void | gradient_and_pre_conditioner (const HashVector &location, Eigen::Ref< DenseRealVector > gradient, Eigen::Ref< DenseRealVector > pre) |
Combines the calculation of gradient and pre-conditioner, which may be more efficient in some cases. More... | |
![]() | |
Tracked () | |
Default constructor; creates the internal stats::StatisticsCollection. More... | |
void | register_stat (const std::string &name, std::unique_ptr< Statistics > stat) |
Registers a tracker for the statistics name. More... | |
std::shared_ptr< StatisticsCollection > | get_stats () const |
Gets an ownership-sharing reference to the StatisticsCollection. More... | |
Protected Member Functions | |
const Derived & | derived () const |
Derived & | derived () |
real_t | value_unchecked (const HashVector &location) override |
real_t | lookup_on_line (real_t position) override |
Looks up the value of the objective on the line defined by the last call to project_to_line(). More... | |
void | project_to_line_unchecked (const HashVector &location, const DenseRealVector &direction) override |
void | gradient_unchecked (const HashVector &location, Eigen::Ref< DenseRealVector > target) override |
void | gradient_at_zero_unchecked (Eigen::Ref< DenseRealVector > target) override |
void | hessian_times_direction_unchecked (const HashVector &location, const DenseRealVector &direction, Eigen::Ref< DenseRealVector > target) override |
void | diag_preconditioner_unchecked (const HashVector &location, Eigen::Ref< DenseRealVector > target) override |
void | gradient_and_pre_conditioner_unchecked (const HashVector &location, Eigen::Ref< DenseRealVector > gradient, Eigen::Ref< DenseRealVector > pre) override |
![]() | |
const DenseRealVector & | x_times_w (const HashVector &w) |
Calculates the product of the feature matrix and the weight vector w. More... | |
template<class Derived > | |
void | update_xtw_cache (const HashVector &new_weight, const Eigen::MatrixBase< Derived > &new_result) |
Updates the cached value for x_times_w. More... | |
void | project_linear_to_line (const HashVector &location, const DenseRealVector &direction) |
Prepares the cache variables for line projection. More... | |
auto | line_interpolation (real_t t) const |
void | declare_vector_on_last_line (const HashVector &location, real_t t) override |
State that the given vector corresponds to a certain position on the line of the last line search. More... | |
const GenericFeatureMatrix & | generic_features () const |
const DenseFeatures & | dense_features () const |
const SparseFeatures & | sparse_features () const |
const DenseRealVector & | costs () const |
const BinaryLabelVector & | labels () const |
![]() | |
~Tracked () | |
Non-virtual destructor. Declared protected, so we don't accidentally try to do a polymorphic delete. More... | |
template<class T > | |
void | record (stat_id_t stat, T &&value) |
Record statistics. This function just forwards all its arguments to the internal StatisticsCollection. More... | |
void | declare_stat (stat_id_t index, StatisticMetaData meta) |
Declares a new statistic. This function just forwards all its arguments to the internal StatisticsCollection. More... | |
void | declare_tag (tag_id_t index, std::string name) |
Declares a new tag. This function just forwards all its arguments to the internal StatisticsCollection. More... | |
template<class... Args> | |
void | set_tag (tag_id_t tag, long value) |
Set the value of a tag. This function just forwards all its arguments to the internal StatisticsCollection. More... | |
template<class... Args> | |
auto | make_timer (stat_id_t id, Args... args) |
Creates a new ScopeTimer using stats::record_scope_time. More... | |
Private Attributes | |
std::unique_ptr< Objective > | m_Regularizer |
Pointer to the regularizer. More... | |
Implementation helper for linear classifier derived classes.
Template Parameters
Derived | The derived class, which is expected to provide the following functions:
template<typename Derived>
real_t value_from_xTw(const DenseRealVector& cost, const BinaryLabelVector& labels, const Eigen::DenseBase<Derived>& xTw);
void gradient_and_diag()
const BinaryLabelVector & labels() const — Definition: linear.cpp:89
real_t value_from_xTw(const DenseRealVector &cost, const BinaryLabelVector &labels, const Eigen::DenseBase< Derived > &xTw) — Definition: reg_sq_hinge_detail.h:83
types::DenseVector< std::int8_t > BinaryLabelVector — Dense vector for storing binary labels. Definition: matrix_types.h:68
types::DenseVector< real_t > DenseRealVector — Any dense, real-valued vector. Definition: matrix_types.h:40 |
|
inline |
|
inline protected
|
inline protected
Definition at line 129 of file linear.h.
Referenced by dismec::objective::LinearClassifierImpBase< Derived >::diag_preconditioner_unchecked(), dismec::objective::LinearClassifierImpBase< Derived >::gradient_and_pre_conditioner_unchecked(), dismec::objective::LinearClassifierImpBase< Derived >::gradient_at_zero_unchecked(), dismec::objective::LinearClassifierImpBase< Derived >::gradient_unchecked(), dismec::objective::LinearClassifierImpBase< Derived >::hessian_times_direction_unchecked(), and dismec::objective::LinearClassifierImpBase< Derived >::value_unchecked().
|
inline override protected virtual
The function that does the actual computation. This is called in diag_preconditioner() after the arguments have been validated. The default implementation returns ones.
Reimplemented from dismec::objective::Objective.
Definition at line 170 of file linear.h.
References dismec::objective::LinearClassifierImpBase< Derived >::derived(), and dismec::objective::LinearClassifierImpBase< Derived >::m_Regularizer.
|
inline override protected virtual
The function that does the actual computation. This is called in gradient_and_pre_conditioner() after the arguments have been validated. The default implementation successively calls gradient() and diag_preconditioner().
Reimplemented from dismec::objective::Objective.
Definition at line 175 of file linear.h.
References dismec::objective::LinearClassifierImpBase< Derived >::derived(), dismec::objective::Objective::gradient(), and dismec::objective::LinearClassifierImpBase< Derived >::m_Regularizer.
|
inline override protected virtual
The function that does the actual computation. This is called in gradient_at_zero() after the argument has been validated. The default implementation is rather inefficient and creates a new temporary zero vector.
Reimplemented from dismec::objective::Objective.
Definition at line 158 of file linear.h.
References dismec::objective::LinearClassifierImpBase< Derived >::derived(), and dismec::objective::LinearClassifierImpBase< Derived >::m_Regularizer.
|
inline override protected virtual
The function that does the actual gradient computation. This is called in gradient() after the arguments have been validated.
Implements dismec::objective::Objective.
Definition at line 152 of file linear.h.
References dismec::objective::LinearClassifierImpBase< Derived >::derived(), and dismec::objective::LinearClassifierImpBase< Derived >::m_Regularizer.
|
inline override protected virtual
The function that does the actual computation. This is called in hessian_times_direction() after the arguments have been validated.
Implements dismec::objective::Objective.
Definition at line 163 of file linear.h.
References dismec::objective::LinearClassifierImpBase< Derived >::derived(), and dismec::objective::LinearClassifierImpBase< Derived >::m_Regularizer.
|
inline override protected virtual
Looks up the value of the objective on the line defined by the last call to project_to_line().
Parameters
position | The location where the objective is calculated. |
Returns
objective(location + position * direction), where location and direction are the vectors passed to the last call of project_to_line().
This function depends on the line set up by project_to_line(), so it has to be called after a call to that function. A new call to project_to_line() will change the line which is evaluated.
Implements dismec::objective::Objective.
Definition at line 142 of file linear.h.
References dismec::objective::LinearClassifierBase::costs(), dismec::objective::LinearClassifierBase::labels(), dismec::objective::LinearClassifierBase::line_interpolation(), dismec::objective::LinearClassifierImpBase< Derived >::m_Regularizer, and dismec::l2_reg_sq_hinge_detail::value_from_xTw().
|
inline override protected virtual
The function that does the actual computation. This is called in project_to_line() after the arguments have been validated.
Implements dismec::objective::Objective.
Definition at line 147 of file linear.h.
References dismec::objective::LinearClassifierImpBase< Derived >::m_Regularizer, and dismec::objective::LinearClassifierBase::project_linear_to_line().
|
inline override protected virtual
The function that does the actual value computation. This is called in value() after the argument has been validated.
Implements dismec::objective::Objective.
Definition at line 137 of file linear.h.
References dismec::objective::LinearClassifierBase::costs(), dismec::objective::LinearClassifierImpBase< Derived >::derived(), dismec::objective::LinearClassifierBase::labels(), dismec::objective::LinearClassifierImpBase< Derived >::m_Regularizer, and dismec::objective::LinearClassifierBase::x_times_w().
|
private |
Pointer to the regularizer.
Definition at line 182 of file linear.h.
Referenced by dismec::objective::LinearClassifierImpBase< Derived >::diag_preconditioner_unchecked(), dismec::objective::LinearClassifierImpBase< Derived >::gradient_and_pre_conditioner_unchecked(), dismec::objective::LinearClassifierImpBase< Derived >::gradient_at_zero_unchecked(), dismec::objective::LinearClassifierImpBase< Derived >::gradient_unchecked(), dismec::objective::LinearClassifierImpBase< Derived >::hessian_times_direction_unchecked(), dismec::objective::LinearClassifierImpBase< Derived >::lookup_on_line(), dismec::objective::LinearClassifierImpBase< Derived >::project_to_line_unchecked(), and dismec::objective::LinearClassifierImpBase< Derived >::value_unchecked().