DiSMEC++
|
#include <cascade.h>
Public Member Functions | |
CascadeTraining (std::shared_ptr< const DatasetBase > tfidf_data, std::shared_ptr< const GenericFeatureMatrix > dense_data, HyperParameters hyper_params, std::shared_ptr< init::WeightInitializationStrategy > dense_init, real_t dense_reg, std::shared_ptr< init::WeightInitializationStrategy > sparse_init, real_t sparse_reg, std::shared_ptr< postproc::PostProcessFactory > post_proc, std::shared_ptr< TrainingStatsGatherer > gatherer, std::shared_ptr< const std::vector< std::vector< long >>> shortlist=nullptr) | |
long | num_features () const override |
std::shared_ptr< objective::Objective > | make_objective () const override |
Makes an Objective object suitable for the dataset. More... | |
std::unique_ptr< solvers::Minimizer > | make_minimizer () const override |
Makes a Minimizer object suitable for the dataset. More... | |
std::unique_ptr< init::WeightsInitializer > | make_initializer () const override |
Makes a WeightsInitializer object. More... | |
std::shared_ptr< model::Model > | make_model (long num_features, model::PartialModelSpec spec) const override |
Creates the model that will be used to store the results. More... | |
void | update_minimizer (solvers::Minimizer &base_minimizer, label_id_t label_id) const override |
Updates the setting of the Minimizer for handling label label_id. More... | |
void | update_objective (objective::Objective &base_objective, label_id_t label_id) const override |
Updates the setting of the Objective for handling label label_id. More... | |
std::unique_ptr< postproc::PostProcessor > | make_post_processor (const std::shared_ptr< objective::Objective > &objective) const override |
Makes a PostProcessor object. More... | |
TrainingStatsGatherer & | get_statistics_gatherer () override |
Public Member Functions inherited from dismec::TrainingSpec | |
TrainingSpec (std::shared_ptr< const DatasetBase > data) | |
virtual | ~TrainingSpec ()=default |
const DatasetBase & | get_data () const |
const std::shared_ptr< spdlog::logger > & | get_logger () const |
void | set_logger (std::shared_ptr< spdlog::logger > l) |
Private Attributes | |
HyperParameters | m_NewtonSettings |
parallel::NUMAReplicator< const GenericFeatureMatrix > | m_SparseReplicator |
parallel::NUMAReplicator< const GenericFeatureMatrix > | m_DenseReplicator |
std::shared_ptr< const std::vector< std::vector< long > > > | m_Shortlist |
std::shared_ptr< postproc::PostProcessFactory > | m_PostProcessor |
std::shared_ptr< init::WeightInitializationStrategy > | m_DenseInitStrategy |
std::shared_ptr< init::WeightInitializationStrategy > | m_SparseInitStrategy |
std::shared_ptr< TrainingStatsGatherer > | m_StatsGather |
long | m_NumFeatures |
double | m_BaseEpsilon |
real_t | m_DenseReg |
real_t | m_SparseReg |
CascadeTraining::CascadeTraining | ( | std::shared_ptr< const DatasetBase > | tfidf_data, |
std::shared_ptr< const GenericFeatureMatrix > | dense_data, | ||
HyperParameters | hyper_params, | ||
std::shared_ptr< init::WeightInitializationStrategy > | dense_init, | ||
real_t | dense_reg, | ||
std::shared_ptr< init::WeightInitializationStrategy > | sparse_init, | ||
real_t | sparse_reg, | ||
std::shared_ptr< postproc::PostProcessFactory > | post_proc, | ||
std::shared_ptr< TrainingStatsGatherer > | gatherer, | ||
std::shared_ptr< const std::vector< std::vector< long >>> | shortlist = nullptr |
||
) |
Definition at line 132 of file cascade.cpp.
References dismec::HyperParameters::get(), m_BaseEpsilon, and m_NewtonSettings.
|
overridevirtual |
Implements dismec::TrainingSpec.
Definition at line 128 of file cascade.cpp.
References m_StatsGather.
|
overridevirtual |
Makes a WeightsInitializer object.
This is called before the actual work of the training threads starts, so that we can pre-allocate all the necessary buffers. It is called in the thread where the returned initializer will be used, so that the default NUMA-local strategy should be reasonable.
Implements dismec::TrainingSpec.
Definition at line 109 of file cascade.cpp.
References m_DenseInitStrategy, m_DenseReplicator, m_SparseInitStrategy, and m_SparseReplicator.
|
overridevirtual |
Makes a Minimizer object suitable for the dataset.
This is called before the actual work of the training threads starts, so that we can pre-allocate all the necessary buffers. It is called in the thread that will run the minimizer, so that the default NUMA-local strategy should be reasonable.
Implements dismec::TrainingSpec.
Definition at line 46 of file cascade.cpp.
References dismec::HyperParameters::apply(), m_NewtonSettings, and m_NumFeatures.
|
overridevirtual |
Creates the model that will be used to store the results.
This extension point gives the TrainingSpec a way to decide whether the model storage used shall be a sparse or a dense model, or maybe some wrapper pointing to external memory. TODO: why is this a shared_ptr and not a unique_ptr?
num_features | Number of input features for the model. |
spec | Partial model specification for the created model. |
Implements dismec::TrainingSpec.
Definition at line 119 of file cascade.cpp.
References num_features().
|
overridevirtual |
Makes an Objective object suitable for the dataset.
This is called before the actual work of the training threads starts, so that we can pre-allocate all the necessary buffers.
Implements dismec::TrainingSpec.
Definition at line 38 of file cascade.cpp.
References m_DenseReg, m_DenseReplicator, m_SparseReg, m_SparseReplicator, and dismec::objective::make_sp_dense_squared_hinge().
|
overridevirtual |
Makes a PostProcessor object.
This is called before the actual work of the training threads starts, so that we can pre-allocate all the necessary buffers. It is called in the thread where the returned post processor will be used, so that the default NUMA-local strategy should be reasonable. The PostProcessor can be adapted to the thread_local objective that is supplied here.
Implements dismec::TrainingSpec.
Definition at line 124 of file cascade.cpp.
References m_PostProcessor.
|
inlineoverridevirtual |
Reimplemented from dismec::TrainingSpec.
Definition at line 26 of file cascade.h.
References m_NumFeatures.
Referenced by make_model().
|
overridevirtual |
Updates the setting of the Minimizer for handling label label_id.
This is needed e.g. to set a stopping criterion that depends on the number of positive labels. This function will be called concurrently from different threads, but each thread will make calls with a different minimizer parameter.
minimizer | A Minimizer. This is assumed to have been created using make_minimizer(), so in particular it should be dynamic_cast-able to the actual Minimizer type used by this TrainingSpec. |
label_id | The id of the label inside the dataset for which we update the minimizer. |
Implements dismec::TrainingSpec.
Definition at line 53 of file cascade.cpp.
References dismec::TrainingSpec::get_data(), dismec::DatasetBase::get_labels(), m_BaseEpsilon, m_Shortlist, dismec::DatasetBase::num_examples(), dismec::DatasetBase::num_positives(), and dismec::opaque_int_type< Tag, T >::to_index().
|
overridevirtual |
Updates the setting of the Objective for handling label label_id.
This will e.g. extract the corresponding label vector from the dataset and supply it to the objective. This function will be called concurrently from different threads, but each thread will call with a different objective parameter.
objective | An Objective. This is assumed to have been created using make_objective(), so in particular it should be dynamic_cast-able to the actual Objective type used by this TrainingSpec. |
label_id | The id of the label inside the dataset for which we update the objective. |
Implements dismec::TrainingSpec.
Definition at line 80 of file cascade.cpp.
References dismec::TrainingSpec::get_data(), dismec::DatasetBase::get_labels(), m_DenseReplicator, m_Shortlist, m_SparseReplicator, dismec::shortlist_features(), dismec::ssize(), and dismec::opaque_int_type< Tag, T >::to_index().
|
private |
Definition at line 64 of file cascade.h.
Referenced by CascadeTraining(), and update_minimizer().
|
private |
Definition at line 58 of file cascade.h.
Referenced by make_initializer().
|
private |
Definition at line 66 of file cascade.h.
Referenced by make_objective().
|
private |
Definition at line 50 of file cascade.h.
Referenced by make_initializer(), make_objective(), and update_objective().
|
private |
Definition at line 47 of file cascade.h.
Referenced by CascadeTraining(), and make_minimizer().
|
private |
Definition at line 63 of file cascade.h.
Referenced by make_minimizer(), and num_features().
|
private |
Definition at line 55 of file cascade.h.
Referenced by make_post_processor().
|
private |
Definition at line 52 of file cascade.h.
Referenced by update_minimizer(), and update_objective().
|
private |
Definition at line 59 of file cascade.h.
Referenced by make_initializer().
|
private |
Definition at line 67 of file cascade.h.
Referenced by make_objective().
|
private |
Definition at line 49 of file cascade.h.
Referenced by make_initializer(), make_objective(), and update_objective().
|
private |
Definition at line 61 of file cascade.h.
Referenced by get_statistics_gatherer().