6 #ifndef DISMEC_TRAINING_SPEC_H
7 #define DISMEC_TRAINING_SPEC_H
12 #include "spdlog/fwd.h"
/// Makes an Objective object suitable for the dataset.
///
/// Pure virtual: every concrete training specification has to provide it.
/// The result is handed out as a `std::shared_ptr`, and is marked
/// `[[nodiscard]]` so that a call which drops the new object is diagnosed.
///
/// @return Shared pointer to the created objective::Objective.
40 [[nodiscard]]
virtual std::shared_ptr<objective::Objective>
make_objective()
const = 0;
/// Makes a Minimizer object suitable for the dataset.
///
/// Pure virtual: every concrete training specification has to provide it.
/// Ownership of the new minimizer is transferred to the caller through the
/// returned `std::unique_ptr`.
///
/// @return Unique pointer to the created solvers::Minimizer.
48 [[nodiscard]]
virtual std::unique_ptr<solvers::Minimizer>
make_minimizer()
const = 0;
/// Makes a WeightsInitializer object.
///
/// Pure virtual: every concrete training specification has to provide it.
/// Ownership of the new initializer is transferred to the caller through the
/// returned `std::unique_ptr`.
///
/// @return Unique pointer to the created init::WeightsInitializer.
56 [[nodiscard]]
virtual std::unique_ptr<init::WeightsInitializer>
make_initializer()
const = 0;
/// Makes a PostProcessor object.
///
/// Pure virtual: every concrete training specification has to provide it.
/// Ownership of the new post-processor is transferred to the caller through
/// the returned `std::unique_ptr`.
///
/// @param objective Shared handle to an Objective, taken by const reference.
///        NOTE(review): presumably the objective the post-processing will be
///        applied for (e.g. one obtained from make_objective()) -- confirm
///        against the implementations.
/// @return Unique pointer to the created postproc::PostProcessor.
65 [[nodiscard]]
virtual std::unique_ptr<postproc::PostProcessor>
make_post_processor(
const std::shared_ptr<objective::Objective>&
objective)
const = 0;
105 [[nodiscard]]
const std::shared_ptr<spdlog::logger>&
get_logger()
const {
/// Shared handle to the dataset; the pointee is `const`, so it cannot be
/// modified through this member.
/// NOTE(review): presumably initialised from the constructor's `data`
/// argument and exposed through get_data() -- confirm in the class body.
114 std::shared_ptr<const DatasetBase>
m_Data;
/// Declaration of make_loss, which returns an objective::Objective by shared
/// pointer.
///
/// NOTE(review): the cross-reference list for this file shows the signature
/// as `make_loss(LossType type, X, regularizer)`. The leading `LossType type`
/// parameter (original line 139) is not present in this listing and looks to
/// have been lost -- verify against the real header.
///
/// @param X Shared handle to an immutable GenericFeatureMatrix.
///        NOTE(review): presumably the feature matrix the loss is evaluated
///        on -- confirm.
/// @param regularizer Objective passed in as a `std::unique_ptr`, i.e. the
///        callee takes ownership of it.
/// @return Shared pointer to an objective::Objective.
138 std::shared_ptr<objective::Objective>
make_loss(
140 std::shared_ptr<const GenericFeatureMatrix> X,
141 std::unique_ptr<objective::Objective> regularizer);
/// Regularizer configuration, stored as exactly one of the three alternatives
/// objective::SquaredNormConfig, objective::HuberConfig or
/// objective::ElasticConfig.
143 using RegularizerSpec = std::variant<objective::SquaredNormConfig, objective::HuberConfig, objective::ElasticConfig>;
/// Shared handle to a weight initialization strategy.
/// NOTE(review): the enclosing struct is not visible in this listing;
/// presumably this is the strategy used to initialise the weights for this
/// training configuration -- confirm.
147 std::shared_ptr<init::WeightInitializationStrategy>
Init;
/// Shared handle to a weight initialization strategy.
/// NOTE(review): going by the name, presumably applies to the weights of the
/// dense part of the model -- confirm against the enclosing struct, which is
/// not visible in this listing.
156 std::shared_ptr<init::WeightInitializationStrategy>
DenseInit;
/// Shared handle to a weight initialization strategy.
/// NOTE(review): going by the name, presumably applies to the weights of the
/// sparse part of the model -- confirm against the enclosing struct.
157 std::shared_ptr<init::WeightInitializationStrategy>
SparseInit;
170 std::shared_ptr<const GenericFeatureMatrix> dense,
171 std::shared_ptr<
const std::vector<std::vector<long>>> shortlist,
This class represents a set of hyper-parameters.
This class gathers the setting-specific parts of the training process.
virtual std::shared_ptr< model::Model > make_model(long num_features, model::PartialModelSpec spec) const =0
Creates the model that will be used to store the results.
virtual std::unique_ptr< solvers::Minimizer > make_minimizer() const =0
Makes a Minimizer object suitable for the dataset.
void set_logger(std::shared_ptr< spdlog::logger > l)
virtual ~TrainingSpec()=default
std::shared_ptr< const DatasetBase > m_Data
virtual std::unique_ptr< init::WeightsInitializer > make_initializer() const =0
Makes a WeightsInitializer object.
const std::shared_ptr< spdlog::logger > & get_logger() const
virtual TrainingStatsGatherer & get_statistics_gatherer()=0
virtual void update_objective(objective::Objective &objective, label_id_t label_id) const =0
Updates the setting of the Objective for handling label label_id.
const DatasetBase & get_data() const
TrainingSpec(std::shared_ptr< const DatasetBase > data)
std::shared_ptr< spdlog::logger > m_Logger
logger to be used for info logging
virtual std::unique_ptr< postproc::PostProcessor > make_post_processor(const std::shared_ptr< objective::Objective > &objective) const =0
Makes a PostProcessor object.
virtual long num_features() const
virtual void update_minimizer(solvers::Minimizer &minimizer, label_id_t label_id) const =0
Updates the setting of the Minimizer for handling label label_id.
virtual std::shared_ptr< objective::Objective > make_objective() const =0
Makes an Objective object suitable for the dataset.
Strong typedef for an int to signify a label id.
Class that models an optimization objective.
Main namespace in which all types, classes, and functions are defined.
std::shared_ptr< objective::Objective > make_loss(LossType type, std::shared_ptr< const GenericFeatureMatrix > X, std::unique_ptr< objective::Objective > regularizer)
std::variant< objective::SquaredNormConfig, objective::HuberConfig, objective::ElasticConfig > RegularizerSpec
std::shared_ptr< TrainingSpec > create_dismec_training(std::shared_ptr< const DatasetBase > data, HyperParameters params, DismecTrainingConfig config)
std::shared_ptr< TrainingSpec > create_cascade_training(std::shared_ptr< const DatasetBase > data, std::shared_ptr< const GenericFeatureMatrix > dense, std::shared_ptr< const std::vector< std::vector< long >>> shortlist, HyperParameters params, CascadeTrainingConfig config)
float real_t
The default type for floating point values.
std::shared_ptr< TrainingStatsGatherer > StatsGatherer
std::shared_ptr< init::WeightInitializationStrategy > DenseInit
std::shared_ptr< init::WeightInitializationStrategy > SparseInit
std::shared_ptr< postproc::PostProcessFactory > PostProcessing
std::shared_ptr< postproc::PostProcessFactory > PostProcessing
RegularizerSpec Regularizer
std::shared_ptr< init::WeightInitializationStrategy > Init
std::shared_ptr< WeightingScheme > Weighting
std::shared_ptr< TrainingStatsGatherer > StatsGatherer
Specifies how to interpret a weight matrix for a partial model.