6 #ifndef DISMEC_DISMEC_H
7 #define DISMEC_DISMEC_H
41 std::shared_ptr<WeightingScheme> weighting,
42 std::shared_ptr<init::WeightInitializationStrategy> init,
43 std::shared_ptr<postproc::PostProcessFactory> post_proc,
44 std::shared_ptr<TrainingStatsGatherer> gatherer,
/// \brief Makes an Objective object suitable for the dataset.
///
/// Declared `const override`, so it overrides a virtual of the base training
/// specification (presumably `TrainingSpec`, per the class brief — TODO confirm).
/// The objective is returned through a `std::shared_ptr`, so ownership may be
/// shared with other consumers (e.g. it can later be handed to
/// `make_post_processor`). Per-label adjustments are presumably applied
/// afterwards via `update_objective` — verify against the implementation.
/// \return A shared handle to the newly created objective; `[[nodiscard]]`
///         because discarding it would make the call pointless.
48 [[nodiscard]] std::shared_ptr<objective::Objective>
make_objective()
const override;
/// \brief Makes a Minimizer object suitable for the dataset.
///
/// The result is returned as a `std::unique_ptr`, so the caller becomes its
/// sole owner. Settings for an individual label are presumably applied later
/// through `update_minimizer(base_minimizer, label_id)` — TODO confirm.
/// \return An owning pointer to the newly created minimizer.
49 [[nodiscard]] std::unique_ptr<solvers::Minimizer>
make_minimizer()
const override;
/// \brief Makes a WeightsInitializer object.
///
/// NOTE(review): presumably built from the stored initialization strategy
/// (`m_InitStrategy`) — confirm in the implementation.
/// \return An owning pointer to the newly created weights initializer.
50 [[nodiscard]] std::unique_ptr<init::WeightsInitializer>
make_initializer()
const override;
/// \brief Makes a PostProcessor object.
///
/// NOTE(review): presumably delegates to the stored post-processing factory
/// (`m_PostProcessor`) — confirm in the implementation.
/// \param objective Objective the post-processor may use; taken by const
///        reference to `std::shared_ptr`, so no ownership is transferred by the
///        call itself (the callee may still copy it). Presumably the object
///        obtained from `make_objective()` — verify against callers.
/// \return An owning pointer to the newly created post-processor.
56 [[nodiscard]] std::unique_ptr<postproc::PostProcessor>
make_post_processor(
const std::shared_ptr<objective::Objective>&
objective)
const override;
An implementation of TrainingSpec that models the DiSMEC algorithm.
parallel::NUMAReplicator< const GenericFeatureMatrix > m_FeatureReplicator
std::shared_ptr< objective::Objective > make_objective() const override
Makes an Objective object suitable for the dataset.
std::shared_ptr< WeightingScheme > m_Weighting
HyperParameters m_NewtonSettings
void update_objective(objective::Objective &base_objective, label_id_t label_id) const override
Updates the setting of the Objective for handling label label_id.
TrainingStatsGatherer & get_statistics_gatherer() override
std::shared_ptr< postproc::PostProcessFactory > m_PostProcessor
std::shared_ptr< init::WeightInitializationStrategy > m_InitStrategy
DiSMECTraining(std::shared_ptr< const DatasetBase > data, HyperParameters hyper_params, std::shared_ptr< WeightingScheme > weighting, std::shared_ptr< init::WeightInitializationStrategy > init, std::shared_ptr< postproc::PostProcessFactory > post_proc, std::shared_ptr< TrainingStatsGatherer > gatherer, bool use_sparse, RegularizerSpec regularizer, LossType loss)
Creates a DiSMECTraining instance.
void update_minimizer(solvers::Minimizer &base_minimizer, label_id_t label_id) const override
Updates the setting of the Minimizer for handling label label_id.
std::shared_ptr< model::Model > make_model(long num_features, model::PartialModelSpec spec) const override
Creates the model that will be used to store the results.
std::unique_ptr< solvers::Minimizer > make_minimizer() const override
Makes a Minimizer object suitable for the dataset.
std::shared_ptr< TrainingStatsGatherer > m_StatsGather
RegularizerSpec m_Regularizer
std::unique_ptr< init::WeightsInitializer > make_initializer() const override
Makes a WeightsInitializer object.
std::unique_ptr< postproc::PostProcessor > make_post_processor(const std::shared_ptr< objective::Objective > &objective) const override
Makes a PostProcessor object.
This class represents a set of hyper-parameters.
This class gathers the setting-specific parts of the training process.
virtual long num_features() const
Strong typedef for an int to signify a label id.
Class that models an optimization objective.
Helper class to ensure that each NUMA node has its own copy of some immutable data.
Main namespace in which all types, classes, and functions are defined.
std::variant< objective::SquaredNormConfig, objective::HuberConfig, objective::ElasticConfig > RegularizerSpec
Specifies how to interpret a weight matrix for a partial model.