DiSMEC++
dismec::DiSMECTraining Class Reference

An implementation of TrainingSpec that models the DiSMEC algorithm. More...

#include <dismec.h>

Inheritance diagram for dismec::DiSMECTraining:
dismec::TrainingSpec

Public Member Functions

 DiSMECTraining (std::shared_ptr< const DatasetBase > data, HyperParameters hyper_params, std::shared_ptr< WeightingScheme > weighting, std::shared_ptr< init::WeightInitializationStrategy > init, std::shared_ptr< postproc::PostProcessFactory > post_proc, std::shared_ptr< TrainingStatsGatherer > gatherer, bool use_sparse, RegularizerSpec regularizer, LossType loss)
 Creates a DiSMECTraining instance. More...
 
std::shared_ptr< objective::Objective > make_objective () const override
 Makes an Objective object suitable for the dataset. More...
 
std::unique_ptr< solvers::Minimizer > make_minimizer () const override
 Makes a Minimizer object suitable for the dataset. More...
 
std::unique_ptr< init::WeightsInitializer > make_initializer () const override
 Makes a WeightsInitializer object. More...
 
std::shared_ptr< model::Model > make_model (long num_features, model::PartialModelSpec spec) const override
 Creates the model that will be used to store the results. More...
 
void update_minimizer (solvers::Minimizer &base_minimizer, label_id_t label_id) const override
 Updates the setting of the Minimizer for handling label label_id. More...
 
void update_objective (objective::Objective &base_objective, label_id_t label_id) const override
 Updates the setting of the Objective for handling label label_id. More...
 
std::unique_ptr< postproc::PostProcessor > make_post_processor (const std::shared_ptr< objective::Objective > &objective) const override
 Makes a PostProcessor object. More...
 
TrainingStatsGatherer & get_statistics_gatherer () override
 
- Public Member Functions inherited from dismec::TrainingSpec
 TrainingSpec (std::shared_ptr< const DatasetBase > data)
 
virtual ~TrainingSpec ()=default
 
const DatasetBase & get_data () const
 
virtual long num_features () const
 
const std::shared_ptr< spdlog::logger > & get_logger () const
 
void set_logger (std::shared_ptr< spdlog::logger > l)
 

Private Attributes

HyperParameters m_NewtonSettings
 
std::shared_ptr< WeightingScheme > m_Weighting
 
bool m_UseSparseModel = false
 
std::shared_ptr< init::WeightInitializationStrategy > m_InitStrategy
 
std::shared_ptr< postproc::PostProcessFactory > m_PostProcessor
 
parallel::NUMAReplicator< const GenericFeatureMatrix > m_FeatureReplicator
 
std::shared_ptr< TrainingStatsGatherer > m_StatsGather
 
double m_BaseEpsilon
 
RegularizerSpec m_Regularizer
 
LossType m_Loss
 

Detailed Description

An implementation of TrainingSpec that models the DiSMEC algorithm.

The algorithm runs the NewtonWithLineSearch optimizer on a Regularized_SquaredHingeSVC objective. The minimization can be influenced by providing a HyperParameters object that sets, e.g., the stopping criterion and the number of steps. The squared hinge loss can be influenced by supplying a custom WeightingScheme, e.g. for constant or propensity-based weighting.

The stopping criterion epsilon of the NewtonWithLineSearch optimizer is adjusted from the given base value according to the number of positive/negative instances of each label. If eps is the value given in hyper_params, and a given label id has p positive and n negative instances, then the epsilon used will be \(\text{epsilon} = \text{eps} \cdot \min(p, n, 1) / (p+n) \).

Todo:
Figure out why we do this and put a reference/explanation here.
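
In code, the adjustment described above amounts to the following (a hypothetical helper for illustration, not part of the library; p and n are the per-label positive and negative instance counts):

```cpp
#include <algorithm>

// Hypothetical helper (not actual library code) implementing the per-label
// stopping-criterion adjustment described above:
//   epsilon = eps * min(p, n, 1) / (p + n)
double adjusted_epsilon(double base_eps, long positives, long negatives) {
    double smallest = std::min({static_cast<double>(positives),
                                static_cast<double>(negatives), 1.0});
    return base_eps * smallest / static_cast<double>(positives + negatives);
}
```

Note that for labels with at least one positive and one negative instance the factor reduces to 1 / (p+n), so rarer labels do not automatically get a looser tolerance than frequent ones.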

Definition at line 32 of file dismec.h.

Constructor & Destructor Documentation

◆ DiSMECTraining()

DiSMECTraining::DiSMECTraining ( std::shared_ptr< const DatasetBase > data,
HyperParameters  hyper_params,
std::shared_ptr< WeightingScheme > weighting,
std::shared_ptr< init::WeightInitializationStrategy > init,
std::shared_ptr< postproc::PostProcessFactory > post_proc,
std::shared_ptr< TrainingStatsGatherer > gatherer,
bool  use_sparse,
RegularizerSpec  regularizer,
LossType  loss 
)

Creates a DiSMECTraining instance.

Parameters
data	The dataset on which to train.
hyper_params	Hyper parameters that will be applied to the NewtonWithLineSearch optimizer.
weighting	Positive/negative label weighting that will be used for the Regularized_SquaredHingeSVC objective.

Definition at line 90 of file dismec.cpp.

References dismec::HyperParameters::get(), m_BaseEpsilon, m_InitStrategy, m_NewtonSettings, and m_PostProcessor.

Member Function Documentation

◆ get_statistics_gatherer()

TrainingStatsGatherer & DiSMECTraining::get_statistics_gatherer ( )
overridevirtual

Implements dismec::TrainingSpec.

Definition at line 152 of file dismec.cpp.

References m_StatsGather.

◆ make_initializer()

std::unique_ptr< init::WeightsInitializer > DiSMECTraining::make_initializer ( ) const
overridevirtual

Makes a WeightsInitializer object.

This is called before the actual work of the training threads starts, so that we can pre-allocate all the necessary buffers. It is called in the thread where the returned initializer will be used, so that the default NUMA-local strategy should be reasonable.

Implements dismec::TrainingSpec.

Definition at line 136 of file dismec.cpp.

References m_FeatureReplicator, and m_InitStrategy.

◆ make_minimizer()

std::unique_ptr< solvers::Minimizer > DiSMECTraining::make_minimizer ( ) const
overridevirtual

Makes a Minimizer object suitable for the dataset.

This is called before the actual work of the training threads starts, so that we can pre-allocate all the necessary buffers. It is called in the thread that will run the minimizer, so that the default NUMA-local strategy should be reasonable.

Implements dismec::TrainingSpec.

Definition at line 71 of file dismec.cpp.

References dismec::HyperParameters::apply(), m_NewtonSettings, and dismec::TrainingSpec::num_features().

◆ make_model()

std::shared_ptr< model::Model > DiSMECTraining::make_model ( long  num_features,
model::PartialModelSpec  spec 
) const
overridevirtual

Creates the model that will be used to store the results.

This extension point gives the TrainingSpec a way to decide whether the model storage used shall be a sparse or a dense model, or maybe some wrapper pointing to external memory.

Todo:
Why is this a shared_ptr and not a unique_ptr?

Parameters
num_features	Number of input features for the model.
spec	Partial model specification for the created model.

Implements dismec::TrainingSpec.

Definition at line 140 of file dismec.cpp.

References m_UseSparseModel, and dismec::TrainingSpec::num_features().

◆ make_objective()

std::shared_ptr< objective::Objective > DiSMECTraining::make_objective ( ) const
overridevirtual

Makes an Objective object suitable for the dataset.

This is called before the actual work of the training threads starts, so that we can pre-allocate all the necessary buffers.

Implements dismec::TrainingSpec.

Definition at line 64 of file dismec.cpp.

References m_FeatureReplicator, m_Loss, m_Regularizer, dismec::make_loss(), dismec::objective::make_regularizer(), and dismec::types::visit().

◆ make_post_processor()

std::unique_ptr< postproc::PostProcessor > DiSMECTraining::make_post_processor ( const std::shared_ptr< objective::Objective > &  objective) const
overridevirtual

Makes a PostProcessor object.

This is called before the actual work of the training threads starts, so that we can pre-allocate all the necessary buffers. It is called in the thread where the returned post processor will be used, so that the default NUMA-local strategy should be reasonable. The PostProcessor can be adapted to the thread_local objective that is supplied here.

Implements dismec::TrainingSpec.

Definition at line 148 of file dismec.cpp.

References m_PostProcessor.

◆ update_minimizer()

void DiSMECTraining::update_minimizer ( solvers::Minimizer &  minimizer,
label_id_t  label_id 
) const
overridevirtual

Updates the setting of the Minimizer for handling label label_id.

This is needed e.g. to set a stopping criterion that depends on the number of positive labels. This function will be called concurrently from different threads, but each thread will make calls with a different minimizer parameter.

Parameters
minimizer	A Minimizer. This is assumed to have been created using make_minimizer(), so in particular it should be dynamic_cast-able to the actual Minimizer type used by this TrainingSpec.
label_id	The id of the label inside the dataset for which we update the minimizer.

Implements dismec::TrainingSpec.

Definition at line 77 of file dismec.cpp.

References dismec::TrainingSpec::get_data(), m_BaseEpsilon, dismec::DatasetBase::num_examples(), and dismec::DatasetBase::num_positives().
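
The per-thread calling pattern implied above can be sketched as follows. FakeMinimizer and FakeSpec are minimal stand-ins for illustration only, not the real dismec types:

```cpp
#include <memory>

// Hedged sketch of the intended calling pattern: every worker thread creates
// its own Minimizer once via make_minimizer(), then re-configures that same
// object for each label with update_minimizer().
struct FakeMinimizer {
    double epsilon = 0.0;   // stand-in for the per-label stopping criterion
};

struct FakeSpec {
    double base_epsilon = 0.01;
    std::unique_ptr<FakeMinimizer> make_minimizer() const {
        return std::make_unique<FakeMinimizer>();
    }
    void update_minimizer(FakeMinimizer& m, int label_id) const {
        // The real implementation derives epsilon from the label's
        // positive/negative counts; here we only record a per-label value.
        m.epsilon = base_epsilon / (label_id + 1);
    }
};

// Per-thread usage: one minimizer, reconfigured for each assigned label.
inline double train_labels(const FakeSpec& spec, int num_labels) {
    auto minimizer = spec.make_minimizer();        // thread-local object
    for (int label = 0; label < num_labels; ++label) {
        spec.update_minimizer(*minimizer, label);  // adjust per-label settings
        // ... run the minimizer for this label ...
    }
    return minimizer->epsilon;                     // last configured value
}
```

Because each thread owns its minimizer, update_minimizer never mutates shared state, which is what makes the concurrent calls from different threads safe.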

◆ update_objective()

void DiSMECTraining::update_objective ( objective::Objective &  objective,
label_id_t  label_id 
) const
overridevirtual

Updates the setting of the Objective for handling label label_id.

This will e.g. extract the corresponding label vector from the dataset and supply it to the objective. This function will be called concurrently from different threads, but each thread will call with a different objective parameter.

Parameters
objective	An Objective. This is assumed to have been created using make_objective(), so in particular it should be dynamic_cast-able to the actual Objective type used by this TrainingSpec.
label_id	The id of the label inside the dataset for which we update the objective.

Implements dismec::TrainingSpec.

Definition at line 122 of file dismec.cpp.

References dismec::TrainingSpec::get_data(), dismec::DatasetBase::get_labels(), and m_Weighting.

Member Data Documentation

◆ m_BaseEpsilon

double dismec::DiSMECTraining::m_BaseEpsilon
private

Definition at line 74 of file dismec.h.

Referenced by DiSMECTraining(), and update_minimizer().

◆ m_FeatureReplicator

parallel::NUMAReplicator<const GenericFeatureMatrix> dismec::DiSMECTraining::m_FeatureReplicator
private

Definition at line 70 of file dismec.h.

Referenced by make_initializer(), and make_objective().

◆ m_InitStrategy

std::shared_ptr<init::WeightInitializationStrategy> dismec::DiSMECTraining::m_InitStrategy
private

Definition at line 65 of file dismec.h.

Referenced by DiSMECTraining(), and make_initializer().

◆ m_Loss

LossType dismec::DiSMECTraining::m_Loss
private

Definition at line 76 of file dismec.h.

Referenced by make_objective().

◆ m_NewtonSettings

HyperParameters dismec::DiSMECTraining::m_NewtonSettings
private

Definition at line 60 of file dismec.h.

Referenced by DiSMECTraining(), and make_minimizer().

◆ m_PostProcessor

std::shared_ptr<postproc::PostProcessFactory> dismec::DiSMECTraining::m_PostProcessor
private

Definition at line 68 of file dismec.h.

Referenced by DiSMECTraining(), and make_post_processor().

◆ m_Regularizer

RegularizerSpec dismec::DiSMECTraining::m_Regularizer
private

Definition at line 75 of file dismec.h.

Referenced by make_objective().

◆ m_StatsGather

std::shared_ptr<TrainingStatsGatherer> dismec::DiSMECTraining::m_StatsGather
private

Definition at line 72 of file dismec.h.

Referenced by get_statistics_gatherer().

◆ m_UseSparseModel

bool dismec::DiSMECTraining::m_UseSparseModel = false
private

Definition at line 62 of file dismec.h.

Referenced by make_model().

◆ m_Weighting

std::shared_ptr<WeightingScheme> dismec::DiSMECTraining::m_Weighting
private

Definition at line 61 of file dismec.h.

Referenced by update_objective().


The documentation for this class was generated from the following files:

dismec.h
dismec.cpp