6 #ifndef DISMEC_SRC_TRAINING_CASCADE_H
7 #define DISMEC_SRC_TRAINING_CASCADE_H
16 std::shared_ptr<const GenericFeatureMatrix> dense_data,
18 std::shared_ptr<init::WeightInitializationStrategy> dense_init,
20 std::shared_ptr<init::WeightInitializationStrategy> sparse_init,
22 std::shared_ptr<postproc::PostProcessFactory> post_proc,
23 std::shared_ptr<TrainingStatsGatherer> gatherer,
24 std::shared_ptr<
const std::vector<std::vector<long>>> shortlist =
nullptr);
/// Makes an Objective object suitable for the dataset.
/// The result is returned as a shared_ptr and is marked [[nodiscard]].
28 [[nodiscard]] std::shared_ptr<objective::Objective>
make_objective()
const override;
/// Makes a Minimizer object suitable for the dataset.
/// Ownership of the new minimizer is handed to the caller via unique_ptr.
30 [[nodiscard]] std::unique_ptr<solvers::Minimizer>
make_minimizer()
const override;
/// Makes a WeightsInitializer object.
/// Ownership of the new initializer is handed to the caller via unique_ptr.
32 [[nodiscard]] std::unique_ptr<init::WeightsInitializer>
make_initializer()
const override;
34 [[nodiscard]] std::shared_ptr<model::Model>
41 [[nodiscard]] std::unique_ptr<postproc::PostProcessor>
/// Shared, read-only handle to a list of lists of `long` values.
/// NOTE(review): presumably set from the constructor's `shortlist` argument,
/// which defaults to nullptr -- so this may be null; confirm in the .cpp.
52 std::shared_ptr<const std::vector<std::vector<long>>>
m_Shortlist;
long num_features() const override
std::shared_ptr< objective::Objective > make_objective() const override
Makes an Objective object suitable for the dataset.
HyperParameters m_NewtonSettings
std::unique_ptr< solvers::Minimizer > make_minimizer() const override
Makes a Minimizer object suitable for the dataset.
std::shared_ptr< init::WeightInitializationStrategy > m_DenseInitStrategy
std::shared_ptr< model::Model > make_model(long num_features, model::PartialModelSpec spec) const override
Creates the model that will be used to store the results.
std::unique_ptr< init::WeightsInitializer > make_initializer() const override
Makes a WeightsInitializer object.
std::shared_ptr< const std::vector< std::vector< long > > > m_Shortlist
CascadeTraining(std::shared_ptr< const DatasetBase > tfidf_data, std::shared_ptr< const GenericFeatureMatrix > dense_data, HyperParameters hyper_params, std::shared_ptr< init::WeightInitializationStrategy > dense_init, real_t dense_reg, std::shared_ptr< init::WeightInitializationStrategy > sparse_init, real_t sparse_reg, std::shared_ptr< postproc::PostProcessFactory > post_proc, std::shared_ptr< TrainingStatsGatherer > gatherer, std::shared_ptr< const std::vector< std::vector< long >>> shortlist=nullptr)
TrainingStatsGatherer & get_statistics_gatherer() override
parallel::NUMAReplicator< const GenericFeatureMatrix > m_DenseReplicator
std::unique_ptr< postproc::PostProcessor > make_post_processor(const std::shared_ptr< objective::Objective > &objective) const override
Makes a PostProcessor object.
std::shared_ptr< postproc::PostProcessFactory > m_PostProcessor
std::shared_ptr< TrainingStatsGatherer > m_StatsGather
void update_minimizer(solvers::Minimizer &base_minimizer, label_id_t label_id) const override
Updates the setting of the Minimizer for handling label label_id.
void update_objective(objective::Objective &base_objective, label_id_t label_id) const override
Updates the setting of the Objective for handling label label_id.
parallel::NUMAReplicator< const GenericFeatureMatrix > m_SparseReplicator
std::shared_ptr< init::WeightInitializationStrategy > m_SparseInitStrategy
This class represents a set of hyper-parameters.
This class gathers the setting-specific parts of the training process.
Strong typedef for an int to signify a label id.
Class that models an optimization objective.
Helper class to ensure that each NUMA node has its own copy of some immutable data.
Main namespace in which all types, classes, and functions are defined.
float real_t
The default type for floating point values.
Specifies how to interpret a weight matrix for a partial model.