DiSMEC++
dismec::CascadeTraining Class Reference

#include <cascade.h>

Inheritance diagram for dismec::CascadeTraining:
dismec::TrainingSpec

Public Member Functions

 CascadeTraining (std::shared_ptr< const DatasetBase > tfidf_data, std::shared_ptr< const GenericFeatureMatrix > dense_data, HyperParameters hyper_params, std::shared_ptr< init::WeightInitializationStrategy > dense_init, real_t dense_reg, std::shared_ptr< init::WeightInitializationStrategy > sparse_init, real_t sparse_reg, std::shared_ptr< postproc::PostProcessFactory > post_proc, std::shared_ptr< TrainingStatsGatherer > gatherer, std::shared_ptr< const std::vector< std::vector< long >>> shortlist=nullptr)
 
long num_features () const override
 
std::shared_ptr< objective::Objective > make_objective () const override
 Makes an Objective object suitable for the dataset.
 
std::unique_ptr< solvers::Minimizer > make_minimizer () const override
 Makes a Minimizer object suitable for the dataset.
 
std::unique_ptr< init::WeightsInitializer > make_initializer () const override
 Makes a WeightsInitializer object.
 
std::shared_ptr< model::Model > make_model (long num_features, model::PartialModelSpec spec) const override
 Creates the model that will be used to store the results.
 
void update_minimizer (solvers::Minimizer &base_minimizer, label_id_t label_id) const override
 Updates the setting of the Minimizer for handling label label_id.
 
void update_objective (objective::Objective &base_objective, label_id_t label_id) const override
 Updates the setting of the Objective for handling label label_id.
 
std::unique_ptr< postproc::PostProcessor > make_post_processor (const std::shared_ptr< objective::Objective > &objective) const override
 Makes a PostProcessor object.
 
TrainingStatsGatherer & get_statistics_gatherer () override
 
Public Member Functions inherited from dismec::TrainingSpec
 TrainingSpec (std::shared_ptr< const DatasetBase > data)
 
virtual ~TrainingSpec ()=default
 
const DatasetBase & get_data () const
 
const std::shared_ptr< spdlog::logger > & get_logger () const
 
void set_logger (std::shared_ptr< spdlog::logger > l)
 

Private Attributes

HyperParameters m_NewtonSettings
 
parallel::NUMAReplicator< const GenericFeatureMatrix > m_SparseReplicator
 
parallel::NUMAReplicator< const GenericFeatureMatrix > m_DenseReplicator
 
std::shared_ptr< const std::vector< std::vector< long > > > m_Shortlist
 
std::shared_ptr< postproc::PostProcessFactory > m_PostProcessor
 
std::shared_ptr< init::WeightInitializationStrategy > m_DenseInitStrategy
 
std::shared_ptr< init::WeightInitializationStrategy > m_SparseInitStrategy
 
std::shared_ptr< TrainingStatsGatherer > m_StatsGather
 
long m_NumFeatures
 
double m_BaseEpsilon
 
real_t m_DenseReg
 
real_t m_SparseReg
 

Detailed Description

Definition at line 13 of file cascade.h.

Constructor & Destructor Documentation

◆ CascadeTraining()

CascadeTraining::CascadeTraining ( std::shared_ptr< const DatasetBase >  tfidf_data,
std::shared_ptr< const GenericFeatureMatrix >  dense_data,
HyperParameters  hyper_params,
std::shared_ptr< init::WeightInitializationStrategy >  dense_init,
real_t  dense_reg,
std::shared_ptr< init::WeightInitializationStrategy >  sparse_init,
real_t  sparse_reg,
std::shared_ptr< postproc::PostProcessFactory >  post_proc,
std::shared_ptr< TrainingStatsGatherer >  gatherer,
std::shared_ptr< const std::vector< std::vector< long >>>  shortlist = nullptr 
)

Definition at line 132 of file cascade.cpp.

References dismec::HyperParameters::get(), m_BaseEpsilon, and m_NewtonSettings.

Member Function Documentation

◆ get_statistics_gatherer()

TrainingStatsGatherer & CascadeTraining::get_statistics_gatherer ( )
override virtual

Implements dismec::TrainingSpec.

Definition at line 128 of file cascade.cpp.

References m_StatsGather.

◆ make_initializer()

std::unique_ptr< init::WeightsInitializer > CascadeTraining::make_initializer ( ) const
override virtual

Makes a WeightsInitializer object.

This is called before the actual work of the training threads starts, so that we can pre-allocate all the necessary buffers. It is called in the thread where the returned initializer will be used, so that the default NUMA-local strategy should be reasonable.

Implements dismec::TrainingSpec.

Definition at line 109 of file cascade.cpp.

References m_DenseInitStrategy, m_DenseReplicator, m_SparseInitStrategy, and m_SparseReplicator.

◆ make_minimizer()

std::unique_ptr< solvers::Minimizer > CascadeTraining::make_minimizer ( ) const
override virtual

Makes a Minimizer object suitable for the dataset.

This is called before the actual work of the training threads starts, so that we can pre-allocate all the necessary buffers. It is called in the thread that will run the minimizer, so that the default NUMA-local strategy should be reasonable.

Implements dismec::TrainingSpec.

Definition at line 46 of file cascade.cpp.

References dismec::HyperParameters::apply(), m_NewtonSettings, and m_NumFeatures.
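
The calling convention described above (one make_minimizer() call in each worker thread, followed by one update_minimizer() call per label) can be sketched as follows. This is a minimal illustration, not the DiSMEC++ training loop: Minimizer, MockSpec, run_workers and the round-robin label assignment are all hypothetical stand-ins, and the epsilon formula is a placeholder.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <thread>
#include <vector>

// Hypothetical stand-in for solvers::Minimizer.
struct Minimizer {
    double epsilon = 0.0;
    std::vector<double> scratch;  // buffer pre-allocated by the factory
};

// Hypothetical stand-in for a TrainingSpec implementation.
struct MockSpec {
    long num_features;

    // Called once per worker, inside that worker thread, so the buffer is
    // allocated and first touched by the thread that will use it.
    std::unique_ptr<Minimizer> make_minimizer() const {
        auto m = std::make_unique<Minimizer>();
        m->scratch.assign(static_cast<std::size_t>(num_features), 0.0);
        return m;
    }

    // Called once per label; only touches the minimizer passed in.
    void update_minimizer(Minimizer& m, long label_id) const {
        m.epsilon = 0.01 / static_cast<double>(1 + label_id);  // placeholder rule
    }
};

// Distributes labels round-robin over the workers and records the epsilon
// that was in effect for each label.
std::vector<double> run_workers(const MockSpec& spec, long num_labels, int num_threads) {
    assert(num_labels >= 0 && num_threads > 0);
    std::vector<double> used(static_cast<std::size_t>(num_labels), 0.0);
    std::vector<std::thread> workers;
    for (int t = 0; t < num_threads; ++t) {
        workers.emplace_back([&, t] {
            auto minimizer = spec.make_minimizer();  // once per thread
            for (long label = t; label < num_labels; label += num_threads) {
                spec.update_minimizer(*minimizer, label);  // once per label
                // Each label index is written by exactly one thread.
                used[static_cast<std::size_t>(label)] = minimizer->epsilon;
            }
        });
    }
    for (auto& w : workers) w.join();
    return used;
}
```

Because every thread owns its minimizer, the const member functions of the spec can be called concurrently without locking.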

◆ make_model()

std::shared_ptr< model::Model > CascadeTraining::make_model ( long  num_features,
model::PartialModelSpec  spec 
) const
override virtual

Creates the model that will be used to store the results.

This extension point gives the TrainingSpec a way to decide whether the model storage should be a sparse or a dense model, or perhaps a wrapper pointing to external memory. TODO: why is this a shared_ptr and not a unique_ptr?

Parameters
num_features  Number of input features for the model.
spec          Partial model specification for the created model.

Implements dismec::TrainingSpec.

Definition at line 119 of file cascade.cpp.

References num_features().

◆ make_objective()

std::shared_ptr< objective::Objective > CascadeTraining::make_objective ( ) const
override virtual

Makes an Objective object suitable for the dataset.

This is called before the actual work of the training threads starts, so that we can pre-allocate all the necessary buffers.

Implements dismec::TrainingSpec.

Definition at line 38 of file cascade.cpp.

References m_DenseReg, m_DenseReplicator, m_SparseReg, m_SparseReplicator, and dismec::objective::make_sp_dense_squared_hinge().

◆ make_post_processor()

std::unique_ptr< postproc::PostProcessor > CascadeTraining::make_post_processor ( const std::shared_ptr< objective::Objective > &  objective) const
override virtual

Makes a PostProcessor object.

This is called before the actual work of the training threads starts, so that we can pre-allocate all the necessary buffers. It is called in the thread where the returned post processor will be used, so that the default NUMA-local strategy should be reasonable. The PostProcessor can be adapted to the thread_local objective that is supplied here.

Implements dismec::TrainingSpec.

Definition at line 124 of file cascade.cpp.

References m_PostProcessor.

◆ num_features()

long dismec::CascadeTraining::num_features ( ) const
inline override virtual

Reimplemented from dismec::TrainingSpec.

Definition at line 26 of file cascade.h.

References m_NumFeatures.

Referenced by make_model().

◆ update_minimizer()

void CascadeTraining::update_minimizer ( solvers::Minimizer &  minimizer,
label_id_t  label_id 
) const
override virtual

Updates the setting of the Minimizer for handling label label_id.

This is needed, e.g., to set a stopping criterion that depends on the number of positive labels. This function will be called concurrently from different threads, but each thread will make calls with a different minimizer parameter.

Parameters
minimizer  A Minimizer. This is assumed to have been created using make_minimizer(), so in particular it should be dynamic_cast-able to the actual Minimizer type used by this TrainingSpec.
label_id   The id of the label inside the dataset for which we update the minimizer.

Implements dismec::TrainingSpec.

Definition at line 53 of file cascade.cpp.

References dismec::TrainingSpec::get_data(), dismec::DatasetBase::get_labels(), m_BaseEpsilon, m_Shortlist, dismec::DatasetBase::num_examples(), dismec::DatasetBase::num_positives(), and dismec::opaque_int_type< Tag, T >::to_index().
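
The references above suggest that the stopping tolerance is derived from m_BaseEpsilon, num_positives() and num_examples(); the exact formula is not shown on this page. A minimal sketch of one common choice, assuming the LIBLINEAR-style rule that scales the base tolerance by the minority-class fraction. The helper scaled_epsilon is hypothetical, and how m_Shortlist enters the real computation is not modelled here.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical helper, not part of DiSMEC++.
// Scales a base tolerance by the fraction of examples in the rarer class,
// clamped so that a label without positives still gets a non-zero tolerance.
double scaled_epsilon(double base_epsilon, long num_positives, long num_examples) {
    assert(num_examples > 0);
    assert(num_positives >= 0 && num_positives <= num_examples);
    const long num_negatives = num_examples - num_positives;
    const long minority = std::max(std::min(num_positives, num_negatives), 1L);
    return base_epsilon * static_cast<double>(minority)
                        / static_cast<double>(num_examples);
}
```

Under this rule, labels with very few positives get a tighter tolerance, so the solver is not stopped before the rare positive examples have influenced the weights.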

◆ update_objective()

void CascadeTraining::update_objective ( objective::Objective &  objective,
label_id_t  label_id 
) const
override virtual

Updates the setting of the Objective for handling label label_id.

This will, e.g., extract the corresponding label vector from the dataset and supply it to the objective. This function will be called concurrently from different threads, but each thread will call with a different objective parameter.

Parameters
objective  An Objective. This is assumed to have been created using make_objective(), so in particular it should be dynamic_cast-able to the actual Objective type used by this TrainingSpec.
label_id   The id of the label inside the dataset for which we update the objective.

Implements dismec::TrainingSpec.

Definition at line 80 of file cascade.cpp.

References dismec::TrainingSpec::get_data(), dismec::DatasetBase::get_labels(), m_DenseReplicator, m_Shortlist, m_SparseReplicator, dismec::shortlist_features(), dismec::ssize(), and dismec::opaque_int_type< Tag, T >::to_index().
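
To illustrate what extracting a label vector with an optional shortlist can look like, here is a minimal sketch. It assumes that a label is stored as a list of positive example indices and that a shortlist is a list of example indices to train on. The helper gather_labels and the +1/-1 encoding are hypothetical and do not reproduce dismec::shortlist_features() or the actual DiSMEC++ data layout.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of DiSMEC++.
// Builds a dense +1/-1 label vector from the positive example indices.
// If a shortlist is given, only the shortlisted examples are kept, in
// shortlist order.
std::vector<float> gather_labels(const std::vector<long>& positive_examples,
                                 long num_examples,
                                 const std::vector<long>* shortlist = nullptr) {
    assert(num_examples >= 0);
    std::vector<float> dense(static_cast<std::size_t>(num_examples), -1.0f);
    for (long idx : positive_examples) {
        assert(idx >= 0 && idx < num_examples);
        dense[static_cast<std::size_t>(idx)] = 1.0f;
    }
    if (shortlist == nullptr) {
        return dense;  // no shortlist: train on all examples
    }
    std::vector<float> subset;
    subset.reserve(shortlist->size());
    for (long idx : *shortlist) {
        assert(idx >= 0 && idx < num_examples);
        subset.push_back(dense[static_cast<std::size_t>(idx)]);
    }
    return subset;
}
```

In a shortlisted setup, the same example indices would also have to select the matching rows of the sparse and dense feature matrices, so that labels and features stay aligned.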

Member Data Documentation

◆ m_BaseEpsilon

double dismec::CascadeTraining::m_BaseEpsilon
private

Definition at line 64 of file cascade.h.

Referenced by CascadeTraining(), and update_minimizer().

◆ m_DenseInitStrategy

std::shared_ptr<init::WeightInitializationStrategy> dismec::CascadeTraining::m_DenseInitStrategy
private

Definition at line 58 of file cascade.h.

Referenced by make_initializer().

◆ m_DenseReg

real_t dismec::CascadeTraining::m_DenseReg
private

Definition at line 66 of file cascade.h.

Referenced by make_objective().

◆ m_DenseReplicator

parallel::NUMAReplicator<const GenericFeatureMatrix> dismec::CascadeTraining::m_DenseReplicator
private

Definition at line 50 of file cascade.h.

Referenced by make_initializer(), make_objective(), and update_objective().

◆ m_NewtonSettings

HyperParameters dismec::CascadeTraining::m_NewtonSettings
private

Definition at line 47 of file cascade.h.

Referenced by CascadeTraining(), and make_minimizer().

◆ m_NumFeatures

long dismec::CascadeTraining::m_NumFeatures
private

Definition at line 63 of file cascade.h.

Referenced by make_minimizer(), and num_features().

◆ m_PostProcessor

std::shared_ptr<postproc::PostProcessFactory> dismec::CascadeTraining::m_PostProcessor
private

Definition at line 55 of file cascade.h.

Referenced by make_post_processor().

◆ m_Shortlist

std::shared_ptr<const std::vector<std::vector<long> > > dismec::CascadeTraining::m_Shortlist
private

Definition at line 52 of file cascade.h.

Referenced by update_minimizer(), and update_objective().

◆ m_SparseInitStrategy

std::shared_ptr<init::WeightInitializationStrategy> dismec::CascadeTraining::m_SparseInitStrategy
private

Definition at line 59 of file cascade.h.

Referenced by make_initializer().

◆ m_SparseReg

real_t dismec::CascadeTraining::m_SparseReg
private

Definition at line 67 of file cascade.h.

Referenced by make_objective().

◆ m_SparseReplicator

parallel::NUMAReplicator<const GenericFeatureMatrix> dismec::CascadeTraining::m_SparseReplicator
private

Definition at line 49 of file cascade.h.

Referenced by make_initializer(), make_objective(), and update_objective().

◆ m_StatsGather

std::shared_ptr<TrainingStatsGatherer> dismec::CascadeTraining::m_StatsGather
private

Definition at line 61 of file cascade.h.

Referenced by get_statistics_gatherer().


The documentation for this class was generated from the following files: