DiSMEC++
This class gathers the setting-specific parts of the training process.
#include <spec.h>
Public Member Functions
  TrainingSpec (std::shared_ptr<const DatasetBase> data)
  virtual ~TrainingSpec ()=default
  const DatasetBase & get_data () const
  virtual long num_features () const
  virtual std::shared_ptr<objective::Objective> make_objective () const =0
      Makes an Objective object suitable for the dataset.
  virtual std::unique_ptr<solvers::Minimizer> make_minimizer () const =0
      Makes a Minimizer object suitable for the dataset.
  virtual std::unique_ptr<init::WeightsInitializer> make_initializer () const =0
      Makes a WeightsInitializer object.
  virtual std::unique_ptr<postproc::PostProcessor> make_post_processor (const std::shared_ptr<objective::Objective> &objective) const =0
      Makes a PostProcessor object.
  virtual std::shared_ptr<model::Model> make_model (long num_features, model::PartialModelSpec spec) const =0
      Creates the model that will be used to store the results.
  virtual void update_minimizer (solvers::Minimizer &minimizer, label_id_t label_id) const =0
      Updates the setting of the Minimizer for handling label label_id.
  virtual void update_objective (objective::Objective &objective, label_id_t label_id) const =0
      Updates the setting of the Objective for handling label label_id.
  virtual TrainingStatsGatherer & get_statistics_gatherer ()=0
  const std::shared_ptr<spdlog::logger> & get_logger () const
  void set_logger (std::shared_ptr<spdlog::logger> l)
Private Attributes
  std::shared_ptr<const DatasetBase> m_Data
  std::shared_ptr<spdlog::logger> m_Logger
      Logger to be used for info logging.
Detailed Description
The TrainingSpec class is responsible for generating and updating the Minimizer and Objective that will be used in the TrainingTaskGenerator.
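The division of labour described above can be sketched with a heavily simplified, self-contained model of the interface. Everything in the `sketch` namespace (`ToySpec`, `train_labels`, the stub `Objective`/`Minimizer` types and `Minimizer::minimize`) is hypothetical and not part of DiSMEC++, and the real TrainingTaskGenerator is not reproduced. The sketch only shows the call order: the factory methods run once, and the update hooks re-target the same objects for every label.

```cpp
#include <cassert>
#include <memory>
#include <vector>

namespace sketch {
using label_id_t = long;  // stand-in; DiSMEC++ defines its own label_id_t

// Minimal stand-ins for objective::Objective and solvers::Minimizer.
struct Objective {
    virtual ~Objective() = default;
    label_id_t current_label = -1;  // which label the objective is set up for
};
struct Minimizer {
    virtual ~Minimizer() = default;
    int runs = 0;
    void minimize(Objective&) { ++runs; }  // placeholder for the actual optimization
};

// Reduced TrainingSpec: only the factory and update hooks used by the loop.
class TrainingSpec {
public:
    virtual ~TrainingSpec() = default;
    virtual std::shared_ptr<Objective> make_objective() const = 0;
    virtual std::unique_ptr<Minimizer> make_minimizer() const = 0;
    virtual void update_minimizer(Minimizer& minimizer, label_id_t label_id) const = 0;
    virtual void update_objective(Objective& objective, label_id_t label_id) const = 0;
};

// Hypothetical concrete spec that only records the label in the objective.
class ToySpec : public TrainingSpec {
public:
    std::shared_ptr<Objective> make_objective() const override { return std::make_shared<Objective>(); }
    std::unique_ptr<Minimizer> make_minimizer() const override { return std::make_unique<Minimizer>(); }
    void update_minimizer(Minimizer&, label_id_t) const override {}
    void update_objective(Objective& objective, label_id_t label_id) const override {
        objective.current_label = label_id;
    }
};

// Create the objects once, then re-target them for each label in [first, last).
// Returns the label the objective was configured for in each run.
std::vector<label_id_t> train_labels(const TrainingSpec& spec, label_id_t first, label_id_t last) {
    assert(first <= last);
    std::shared_ptr<Objective> objective = spec.make_objective();
    std::unique_ptr<Minimizer> minimizer = spec.make_minimizer();
    std::vector<label_id_t> seen;
    for (label_id_t label = first; label < last; ++label) {
        spec.update_objective(*objective, label);
        spec.update_minimizer(*minimizer, label);
        minimizer->minimize(*objective);
        seen.push_back(objective->current_label);
    }
    return seen;
}
}  // namespace sketch
```

For example, `sketch::train_labels(sketch::ToySpec{}, 3, 6)` reuses one objective/minimizer pair for labels 3, 4 and 5 in turn. No new objects are allocated inside the loop.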
TrainingSpec::TrainingSpec (std::shared_ptr<const DatasetBase> data) [inline], [explicit]
TrainingSpec::~TrainingSpec () [virtual], [default]
const DatasetBase & TrainingSpec::get_data () const [inline]
Definition at line 31 of file spec.h.
References m_Data.
Referenced by num_features(), dismec::CascadeTraining::update_minimizer(), dismec::DiSMECTraining::update_minimizer(), dismec::CascadeTraining::update_objective(), and dismec::DiSMECTraining::update_objective().
const std::shared_ptr<spdlog::logger> & TrainingSpec::get_logger () const [inline]
virtual TrainingStatsGatherer & TrainingSpec::get_statistics_gatherer () [pure virtual]
Implemented in dismec::DiSMECTraining, and dismec::CascadeTraining.
virtual std::unique_ptr<init::WeightsInitializer> TrainingSpec::make_initializer () const [pure virtual]
Makes a WeightsInitializer object.
This is called before the actual work of the training threads starts, so that we can pre-allocate all the necessary buffers. It is called in the thread where the returned initializer will be used, so the default NUMA-local strategy should be reasonable.
Implemented in dismec::DiSMECTraining, and dismec::CascadeTraining.
virtual std::unique_ptr<solvers::Minimizer> TrainingSpec::make_minimizer () const [pure virtual]
Makes a Minimizer object suitable for the dataset.
This is called before the actual work of the training threads starts, so that we can pre-allocate all the necessary buffers. It is called in the thread that will run the minimizer, so the default NUMA-local strategy should be reasonable.
Implemented in dismec::DiSMECTraining, and dismec::CascadeTraining.
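The reason for calling `make_minimizer()` inside the worker thread can be illustrated as follows. Under a first-touch page placement policy, which is the default on Linux, memory that a thread allocates and initializes tends to end up on that thread's NUMA node. In this sketch, the `Minimizer` with its scratch buffer, `Spec`, and `run_workers` are hypothetical stand-ins. Each worker constructs its own minimizer and then processes a disjoint, strided subset of labels.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <thread>
#include <vector>

namespace sketch {
// Hypothetical minimizer with a scratch buffer. The vector constructor writes
// every element, so the pages are first touched by the constructing thread.
struct Minimizer {
    explicit Minimizer(std::size_t n) : buffer(n, 0.0) {}
    std::vector<double> buffer;
};

// Hypothetical spec exposing only the factory used below.
struct Spec {
    std::size_t num_features = 0;
    std::unique_ptr<Minimizer> make_minimizer() const {
        return std::make_unique<Minimizer>(num_features);
    }
};

// Each worker builds its own minimizer inside its thread and then handles the
// labels t, t + num_threads, t + 2 * num_threads, ...
// Returns, for every label, the index of the worker that processed it.
std::vector<int> run_workers(const Spec& spec, int num_labels, int num_threads) {
    assert(num_labels >= 0 && num_threads > 0);
    std::vector<int> handled_by(static_cast<std::size_t>(num_labels), -1);
    std::vector<std::thread> workers;
    for (int t = 0; t < num_threads; ++t) {
        workers.emplace_back([&spec, &handled_by, num_labels, num_threads, t] {
            std::unique_ptr<Minimizer> minimizer = spec.make_minimizer();  // created in this thread
            for (int label = t; label < num_labels; label += num_threads) {
                // Reset and reuse the pre-allocated buffer for the next label.
                std::fill(minimizer->buffer.begin(), minimizer->buffer.end(), 0.0);
                handled_by[static_cast<std::size_t>(label)] = t;  // each thread writes distinct elements
            }
        });
    }
    for (std::thread& worker : workers) {
        worker.join();
    }
    return handled_by;
}
}  // namespace sketch
```

Because every thread owns its minimizer and writes only to its own labels' slots, the loop needs no synchronization beyond the final `join()`.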
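virtual std::shared_ptr<model::Model> TrainingSpec::make_model (long num_features, model::PartialModelSpec spec) const [pure virtual]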
Creates the model that will be used to store the results.
This extension point gives the TrainingSpec a way to decide whether the model storage shall be a sparse or a dense model, or perhaps some wrapper pointing to external memory. TODO: Why is this a shared_ptr and not a unique_ptr?
Parameters
  num_features  Number of input features for the model.
  spec  Partial model specification for the created model.
Implemented in dismec::DiSMECTraining, and dismec::CascadeTraining.
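A concrete spec might implement this extension point by switching between storage types. The sketch below assumes a hypothetical configuration flag and hypothetical `DenseModel`/`SparseModel` classes. The fields of `PartialModelSpec` shown here are also invented for illustration and need not match `model::PartialModelSpec`.

```cpp
#include <cassert>
#include <memory>

namespace sketch {
// Hypothetical stand-in; the real model::PartialModelSpec may look different.
struct PartialModelSpec {
    long first_label;
    long label_count;
};

// Hypothetical model hierarchy: a common base with two storage strategies.
struct Model {
    Model(long features, PartialModelSpec s) : num_features(features), spec(s) {}
    virtual ~Model() = default;
    virtual bool is_sparse() const = 0;
    long num_features;
    PartialModelSpec spec;
};
struct DenseModel : Model {
    using Model::Model;
    bool is_sparse() const override { return false; }
};
struct SparseModel : Model {
    using Model::Model;
    bool is_sparse() const override { return true; }
};

class Spec {
public:
    explicit Spec(bool sparse_weights) : m_SparseWeights(sparse_weights) {}

    // The storage type is decided here; callers only see the Model interface.
    std::shared_ptr<Model> make_model(long num_features, PartialModelSpec spec) const {
        assert(num_features > 0);
        if (m_SparseWeights) {
            return std::make_shared<SparseModel>(num_features, spec);
        }
        return std::make_shared<DenseModel>(num_features, spec);
    }

private:
    bool m_SparseWeights;  // hypothetical configuration switch
};
}  // namespace sketch
```

The training code that calls `make_model()` stays unchanged whichever storage is chosen, which is the point of routing the decision through the TrainingSpec.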
virtual std::shared_ptr<objective::Objective> TrainingSpec::make_objective () const [pure virtual]
Makes an Objective object suitable for the dataset.
This is called before the actual work of the training threads starts, so that we can pre-allocate all the necessary buffers.
Implemented in dismec::DiSMECTraining, and dismec::CascadeTraining.
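virtual std::unique_ptr<postproc::PostProcessor> TrainingSpec::make_post_processor (const std::shared_ptr<objective::Objective> &objective) const [pure virtual]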
Makes a PostProcessor object.
This is called before the actual work of the training threads starts, so that we can pre-allocate all the necessary buffers. It is called in the thread where the returned post processor will be used, so the default NUMA-local strategy should be reasonable. The PostProcessor can be adapted to the thread_local objective that is supplied here.
Implemented in dismec::DiSMECTraining, and dismec::CascadeTraining.
long TrainingSpec::num_features () const [virtual]
Reimplemented in dismec::CascadeTraining.
Definition at line 173 of file dismec.cpp.
References get_data(), and dismec::DatasetBase::num_features().
Referenced by dismec::DiSMECTraining::make_minimizer(), and dismec::DiSMECTraining::make_model().
void TrainingSpec::set_logger (std::shared_ptr<spdlog::logger> l) [inline]
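virtual void TrainingSpec::update_minimizer (solvers::Minimizer &minimizer, label_id_t label_id) const [pure virtual]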
Updates the setting of the Minimizer for handling label label_id.
This is needed, e.g., to set a stopping criterion that depends on the number of positive labels. This function will be called concurrently from different threads, but each thread will pass a different minimizer argument.
Parameters
  minimizer  A Minimizer. This is assumed to have been created using make_minimizer(), so in particular it should be dynamic_cast-able to the actual Minimizer type used by this TrainingSpec.
  label_id  The id of the label inside the dataset for which we update the minimizer.
Implemented in dismec::DiSMECTraining, and dismec::CascadeTraining.
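An implementation of `update_minimizer()` can recover the concrete minimizer type with `dynamic_cast` and then adjust it for the label. In the sketch below, `NewtonMinimizer`, its `epsilon` field, the per-label positive counts, and the tolerance rule (scale by the size of the smaller class, with a floor of one) are all hypothetical. In this sketch the spec is only read and each thread passes its own minimizer, so the function needs no locking.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

namespace sketch {
using label_id_t = long;  // stand-in for DiSMEC++'s own label_id_t

struct Minimizer {
    virtual ~Minimizer() = default;
};
// Hypothetical concrete minimizer with a tunable stopping tolerance.
struct NewtonMinimizer : Minimizer {
    double epsilon = 0.01;
};

class Spec {
public:
    Spec(std::vector<long> positives_per_label, long num_examples, double base_epsilon)
        : m_Positives(std::move(positives_per_label)), m_NumExamples(num_examples),
          m_BaseEpsilon(base_epsilon) {
        assert(m_NumExamples > 0);
    }

    // May be called concurrently: it only reads the spec's members and writes
    // to the minimizer passed in by the calling thread.
    void update_minimizer(Minimizer& base, label_id_t label_id) const {
        auto* minimizer = dynamic_cast<NewtonMinimizer*>(&base);
        if (minimizer == nullptr) {
            throw std::logic_error("minimizer was not created by this spec");
        }
        long pos = m_Positives.at(static_cast<std::size_t>(label_id));
        long neg = m_NumExamples - pos;
        // Hypothetical rule: scale by the smaller class size (at least 1).
        long smaller = std::max(std::min(pos, neg), 1L);
        minimizer->epsilon = m_BaseEpsilon * static_cast<double>(smaller) / static_cast<double>(m_NumExamples);
    }

private:
    std::vector<long> m_Positives;
    long m_NumExamples;
    double m_BaseEpsilon;
};
}  // namespace sketch
```

With 100 examples and a base tolerance of 0.1, a label with 5 positives gets a tolerance of 0.1 * 5 / 100. A minimizer of any other type is rejected with `std::logic_error`.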
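virtual void TrainingSpec::update_objective (objective::Objective &objective, label_id_t label_id) const [pure virtual]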
Updates the setting of the Objective for handling label label_id.
This will, e.g., extract the corresponding label vector from the dataset and supply it to the objective. This function will be called concurrently from different threads, but each thread will pass a different objective argument.
Parameters
  objective  An Objective. This is assumed to have been created using make_objective(), so in particular it should be dynamic_cast-able to the actual Objective type used by this TrainingSpec.
  label_id  The id of the label inside the dataset for which we update the objective.
Implemented in dismec::DiSMECTraining, and dismec::CascadeTraining.
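Similarly, `update_objective()` can be sketched as extracting the label's column from the dataset and writing it into the objective as a +1/-1 target vector. The `Dataset` layout (a list of positive example indices per label), `BinaryObjective`, and its `targets` member are hypothetical. The reference form of `dynamic_cast` is used so that an objective of the wrong type raises `std::bad_cast`.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <typeinfo>
#include <utility>
#include <vector>

namespace sketch {
using label_id_t = long;  // stand-in for DiSMEC++'s own label_id_t

// Hypothetical dataset layout: for each label, the indices of its positive examples.
struct Dataset {
    long num_examples;
    std::vector<std::vector<long>> positives;
};

struct Objective {
    virtual ~Objective() = default;
};
// Hypothetical binary objective that needs a +1/-1 target for every example.
struct BinaryObjective : Objective {
    std::vector<signed char> targets;
};

class Spec {
public:
    explicit Spec(std::shared_ptr<const Dataset> data) : m_Data(std::move(data)) {
        assert(m_Data != nullptr);
    }

    // Throws std::bad_cast if the objective was not created by this spec.
    void update_objective(Objective& base, label_id_t label_id) const {
        auto& objective = dynamic_cast<BinaryObjective&>(base);
        // Mark every example negative, then flip the positives of this label.
        objective.targets.assign(static_cast<std::size_t>(m_Data->num_examples), -1);
        for (long example : m_Data->positives.at(static_cast<std::size_t>(label_id))) {
            objective.targets.at(static_cast<std::size_t>(example)) = 1;
        }
    }

private:
    std::shared_ptr<const Dataset> m_Data;  // mirrors the shared, read-only m_Data
};
}  // namespace sketch
```

Holding the dataset as `std::shared_ptr<const ...>` mirrors `m_Data`. Concurrent calls only read the shared dataset, while each thread writes into its own objective.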
std::shared_ptr<const DatasetBase> TrainingSpec::m_Data [private]
Definition at line 114 of file spec.h.
Referenced by get_data().
std::shared_ptr<spdlog::logger> TrainingSpec::m_Logger [private]
Logger to be used for info logging.
Definition at line 117 of file spec.h.
Referenced by get_logger(), and set_logger().