DiSMEC++
dismec::TrainingSpec Class Reference [abstract]

This class gathers the setting-specific parts of the training process. More...

#include <spec.h>

Inheritance diagram for dismec::TrainingSpec:
dismec::CascadeTraining, dismec::DiSMECTraining

Public Member Functions

 TrainingSpec (std::shared_ptr< const DatasetBase > data)
 
virtual ~TrainingSpec ()=default
 
const DatasetBase & get_data () const
 
virtual long num_features () const
 
virtual std::shared_ptr< objective::Objective > make_objective () const =0
 Makes an Objective object suitable for the dataset. More...
 
virtual std::unique_ptr< solvers::Minimizer > make_minimizer () const =0
 Makes a Minimizer object suitable for the dataset. More...
 
virtual std::unique_ptr< init::WeightsInitializer > make_initializer () const =0
 Makes a WeightsInitializer object. More...
 
virtual std::unique_ptr< postproc::PostProcessor > make_post_processor (const std::shared_ptr< objective::Objective > &objective) const =0
 Makes a PostProcessor object. More...
 
virtual std::shared_ptr< model::Model > make_model (long num_features, model::PartialModelSpec spec) const =0
 Creates the model that will be used to store the results. More...
 
virtual void update_minimizer (solvers::Minimizer &minimizer, label_id_t label_id) const =0
 Updates the setting of the Minimizer for handling label label_id. More...
 
virtual void update_objective (objective::Objective &objective, label_id_t label_id) const =0
 Updates the setting of the Objective for handling label label_id. More...
 
virtual TrainingStatsGatherer & get_statistics_gatherer ()=0
 
const std::shared_ptr< spdlog::logger > & get_logger () const
 
void set_logger (std::shared_ptr< spdlog::logger > l)
 

Private Attributes

std::shared_ptr< const DatasetBase > m_Data
 
std::shared_ptr< spdlog::logger > m_Logger
 logger to be used for info logging More...
 

Detailed Description

This class gathers the setting-specific parts of the training process.

The TrainingSpec class is responsible for generating and updating the Minimizer and Objective that will be used in the TrainingTaskGenerator.

Todo:
Should we give the dataset to each operation, or just set it once at the beginning? Maybe at some point convert all this dataset handling to use shared_ptr everywhere.

Definition at line 24 of file spec.h.
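
The following sketch illustrates, in rough terms, how a per-thread training loop might consume these extension points. The driver function train_label_range and the run_minimization call are hypothetical stand-ins for the actual TrainingTaskGenerator machinery, and iterating a label_id_t with ++ is an assumption made for brevity.

// Hypothetical per-thread driver; run_minimization stands in for the
// solver invocation performed by the actual TrainingTaskGenerator.
void train_label_range(dismec::TrainingSpec& spec,
                       dismec::label_id_t first, dismec::label_id_t last) {
    // Created once per worker thread, before the actual work starts,
    // so that buffers end up NUMA-local to the thread that uses them.
    auto objective   = spec.make_objective();
    auto minimizer   = spec.make_minimizer();
    auto initializer = spec.make_initializer();
    auto post_proc   = spec.make_post_processor(objective);

    for (auto label = first; label != last; ++label) {
        // Per-label adaptation; safe to call concurrently because each
        // thread passes its own objective and minimizer instances.
        spec.update_objective(*objective, label);
        spec.update_minimizer(*minimizer, label);
        run_minimization(*minimizer, *objective, *initializer, *post_proc);
    }
}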

Constructor & Destructor Documentation

◆ TrainingSpec()

dismec::TrainingSpec::TrainingSpec ( std::shared_ptr< const DatasetBase >  data )
inline explicit

Definition at line 26 of file spec.h.

◆ ~TrainingSpec()

virtual dismec::TrainingSpec::~TrainingSpec ( )
virtual default

Member Function Documentation

◆ get_data()

const DatasetBase & dismec::TrainingSpec::get_data ( ) const

References m_Data.

◆ get_logger()

const std::shared_ptr<spdlog::logger>& dismec::TrainingSpec::get_logger ( ) const
inline

Definition at line 105 of file spec.h.

References m_Logger.

◆ get_statistics_gatherer()

virtual TrainingStatsGatherer& dismec::TrainingSpec::get_statistics_gatherer ( )
pure virtual

◆ make_initializer()

virtual std::unique_ptr<init::WeightsInitializer> dismec::TrainingSpec::make_initializer ( ) const
pure virtual

Makes a WeightsInitializer object.

This is called before the actual work of the training threads starts, so that we can pre-allocate all the necessary buffers. It is called in the thread where the returned initializer will be used, so that the default NUMA-local allocation strategy should be reasonable.

Implemented in dismec::DiSMECTraining, and dismec::CascadeTraining.

◆ make_minimizer()

virtual std::unique_ptr<solvers::Minimizer> dismec::TrainingSpec::make_minimizer ( ) const
pure virtual

Makes a Minimizer object suitable for the dataset.

This is called before the actual work of the training threads starts, so that we can pre-allocate all the necessary buffers. It is called in the thread that will run the minimizer, so that the default NUMA-local allocation strategy should be reasonable.

Implemented in dismec::DiSMECTraining, and dismec::CascadeTraining.
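
As an illustration, an override in a derived spec might look like the following sketch; MyTraining and MyMinimizer are invented placeholder names, not types from the library.

// Hypothetical override in a TrainingSpec subclass. Because this runs in
// the thread that will use the minimizer, buffers allocated inside the
// MyMinimizer constructor are NUMA-local by default.
std::unique_ptr<solvers::Minimizer> MyTraining::make_minimizer() const {
    return std::make_unique<MyMinimizer>(num_features());
}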

◆ make_model()

virtual std::shared_ptr<model::Model> dismec::TrainingSpec::make_model ( long  num_features,
model::PartialModelSpec  spec 
) const
pure virtual

Creates the model that will be used to store the results.

This extension point gives the TrainingSpec a way to decide whether the model storage shall be a sparse model, a dense model, or possibly a wrapper pointing to external memory.

Todo:
why is this a shared_ptr and not a unique_ptr?

Parameters
num_features	Number of input features for the model.
spec	Partial model specification for the created model.

Implemented in dismec::DiSMECTraining, and dismec::CascadeTraining.
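
For illustration, a derived spec could select the storage type roughly as in the sketch below; SparseModel, DenseModel, and the m_UseSparseStorage flag are invented names.

// Hypothetical override choosing between sparse and dense model storage.
std::shared_ptr<model::Model> MyTraining::make_model(
        long num_features, model::PartialModelSpec spec) const {
    if (m_UseSparseStorage)
        return std::make_shared<SparseModel>(num_features, spec);
    return std::make_shared<DenseModel>(num_features, spec);
}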

◆ make_objective()

virtual std::shared_ptr<objective::Objective> dismec::TrainingSpec::make_objective ( ) const
pure virtual

Makes an Objective object suitable for the dataset.

This is called before the actual work of the training threads starts, so that we can pre-allocate all the necessary buffers.

Implemented in dismec::DiSMECTraining, and dismec::CascadeTraining.

◆ make_post_processor()

virtual std::unique_ptr<postproc::PostProcessor> dismec::TrainingSpec::make_post_processor ( const std::shared_ptr< objective::Objective > &  objective) const
pure virtual

Makes a PostProcessor object.

This is called before the actual work of the training threads starts, so that we can pre-allocate all the necessary buffers. It is called in the thread where the returned post-processor will be used, so that the default NUMA-local allocation strategy should be reasonable. The PostProcessor can be adapted to the thread-local objective that is supplied here.

Implemented in dismec::DiSMECTraining, and dismec::CascadeTraining.
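
A derived override might adapt the post-processor to the supplied thread-local objective roughly as follows; MyPostProcessor is an invented placeholder type.

// Hypothetical override: the post-processor is constructed against the
// thread-local objective instance it will later work with.
std::unique_ptr<postproc::PostProcessor> MyTraining::make_post_processor(
        const std::shared_ptr<objective::Objective>& objective) const {
    return std::make_unique<MyPostProcessor>(objective);
}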

◆ num_features()

long TrainingSpec::num_features ( ) const
virtual

◆ set_logger()

void dismec::TrainingSpec::set_logger ( std::shared_ptr< spdlog::logger >  l)
inline

Definition at line 109 of file spec.h.

References m_Logger.

◆ update_minimizer()

virtual void dismec::TrainingSpec::update_minimizer ( solvers::Minimizer &  minimizer,
label_id_t  label_id 
) const
pure virtual

Updates the setting of the Minimizer for handling label label_id.

This is needed e.g. to set a stopping criterion that depends on the number of positive labels. This function will be called concurrently from different threads, but each thread will pass a different minimizer argument.

Parameters
minimizer	A Minimizer. This is assumed to have been created using make_minimizer(), so in particular it should be dynamic_cast-able to the actual Minimizer type used by this TrainingSpec.
label_id	The id of the label inside the dataset for which we update the minimizer.

Implemented in dismec::DiSMECTraining, and dismec::CascadeTraining.
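
The dynamic_cast contract described above might be used in a derived implementation along these lines; MyMinimizer, m_Epsilon, and num_positives are invented for illustration.

// Hypothetical derived implementation of update_minimizer.
void MyTraining::update_minimizer(solvers::Minimizer& minimizer,
                                  label_id_t label_id) const {
    // Safe: the framework only passes minimizers created by this spec's
    // own make_minimizer(), so the concrete type is known.
    auto& concrete = dynamic_cast<MyMinimizer&>(minimizer);
    // e.g. a stopping tolerance that scales with the number of positive
    // instances for this label.
    concrete.set_epsilon(m_Epsilon * num_positives(label_id));
}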

◆ update_objective()

virtual void dismec::TrainingSpec::update_objective ( objective::Objective &  objective,
label_id_t  label_id 
) const
pure virtual

Updates the setting of the Objective for handling label label_id.

This will e.g. extract the corresponding label vector from the dataset and supply it to the objective. This function will be called concurrently from different threads, but each thread will pass a different objective argument.

Parameters
objective	An Objective. This is assumed to have been created using make_objective(), so in particular it should be dynamic_cast-able to the actual Objective type used by this TrainingSpec.
label_id	The id of the label inside the dataset for which we update the objective.

Implemented in dismec::DiSMECTraining, and dismec::CascadeTraining.
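
Analogously, a derived update_objective might extract the label vector for label_id from the dataset and hand it to the concrete objective, as in this sketch; MyObjective, get_label_vector, and set_targets are invented names.

// Hypothetical derived implementation of update_objective.
void MyTraining::update_objective(objective::Objective& objective,
                                  label_id_t label_id) const {
    auto& concrete = dynamic_cast<MyObjective&>(objective);
    // Fetch the target vector for this label from the training data.
    concrete.set_targets(get_label_vector(get_data(), label_id));
}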

Member Data Documentation

◆ m_Data

std::shared_ptr<const DatasetBase> dismec::TrainingSpec::m_Data
private

Definition at line 114 of file spec.h.

Referenced by get_data().

◆ m_Logger

std::shared_ptr<spdlog::logger> dismec::TrainingSpec::m_Logger
private

logger to be used for info logging

Definition at line 117 of file spec.h.

Referenced by get_logger(), and set_logger().

