6 #ifndef DISMEC_TRAINING_SPEC_H 
    7 #define DISMEC_TRAINING_SPEC_H 
   12 #include "spdlog/fwd.h" 
   40         [[nodiscard]] 
virtual std::shared_ptr<objective::Objective> 
make_objective() 
const = 0;
 
   48         [[nodiscard]] 
virtual std::unique_ptr<solvers::Minimizer> 
make_minimizer() 
const = 0;
 
   56         [[nodiscard]] 
virtual std::unique_ptr<init::WeightsInitializer> 
make_initializer() 
const = 0;
 
   65         [[nodiscard]] 
virtual std::unique_ptr<postproc::PostProcessor> 
make_post_processor(
const std::shared_ptr<objective::Objective>& 
objective) 
const = 0;
 
  105         [[nodiscard]] 
const std::shared_ptr<spdlog::logger>& 
get_logger()
 const {
 
  114         std::shared_ptr<const DatasetBase> 
m_Data;
 
  138     std::shared_ptr<objective::Objective> 
make_loss(
 
  140             std::shared_ptr<const GenericFeatureMatrix> X,
 
  141             std::unique_ptr<objective::Objective> regularizer);
 
  143     using RegularizerSpec = std::variant<objective::SquaredNormConfig, objective::HuberConfig, objective::ElasticConfig>;
 
  147         std::shared_ptr<init::WeightInitializationStrategy> 
Init;
 
  156         std::shared_ptr<init::WeightInitializationStrategy> 
DenseInit;
 
  157         std::shared_ptr<init::WeightInitializationStrategy> 
SparseInit;
 
  170                                                           std::shared_ptr<const GenericFeatureMatrix> dense,
 
  171                                                           std::shared_ptr<
const std::vector<std::vector<long>>> shortlist,
 
This class represents a set of hyper-parameters.
This class gathers the setting-specific parts of the training process.
virtual std::shared_ptr< model::Model > make_model(long num_features, model::PartialModelSpec spec) const =0
Creates the model that will be used to store the results.
virtual std::unique_ptr< solvers::Minimizer > make_minimizer() const =0
Makes a Minimizer object suitable for the dataset.
void set_logger(std::shared_ptr< spdlog::logger > l)
virtual ~TrainingSpec()=default
std::shared_ptr< const DatasetBase > m_Data
virtual std::unique_ptr< init::WeightsInitializer > make_initializer() const =0
Makes a WeightsInitializer object.
const std::shared_ptr< spdlog::logger > & get_logger() const
virtual TrainingStatsGatherer & get_statistics_gatherer()=0
virtual void update_objective(objective::Objective &objective, label_id_t label_id) const =0
Updates the setting of the Objective for handling label label_id.
const DatasetBase & get_data() const
TrainingSpec(std::shared_ptr< const DatasetBase > data)
std::shared_ptr< spdlog::logger > m_Logger
logger to be used for info logging
virtual std::unique_ptr< postproc::PostProcessor > make_post_processor(const std::shared_ptr< objective::Objective > &objective) const =0
Makes a PostProcessor object.
virtual long num_features() const
virtual void update_minimizer(solvers::Minimizer &minimizer, label_id_t label_id) const =0
Updates the setting of the Minimizer for handling label label_id.
virtual std::shared_ptr< objective::Objective > make_objective() const =0
Makes an Objective object suitable for the dataset.
Strong typedef for an int to signify a label id.
Class that models an optimization objective.
Main namespace in which all types, classes, and functions are defined.
std::shared_ptr< objective::Objective > make_loss(LossType type, std::shared_ptr< const GenericFeatureMatrix > X, std::unique_ptr< objective::Objective > regularizer)
std::variant< objective::SquaredNormConfig, objective::HuberConfig, objective::ElasticConfig > RegularizerSpec
std::shared_ptr< TrainingSpec > create_dismec_training(std::shared_ptr< const DatasetBase > data, HyperParameters params, DismecTrainingConfig config)
std::shared_ptr< TrainingSpec > create_cascade_training(std::shared_ptr< const DatasetBase > data, std::shared_ptr< const GenericFeatureMatrix > dense, std::shared_ptr< const std::vector< std::vector< long >>> shortlist, HyperParameters params, CascadeTrainingConfig config)
float real_t
The default type for floating point values.
std::shared_ptr< TrainingStatsGatherer > StatsGatherer
std::shared_ptr< init::WeightInitializationStrategy > DenseInit
std::shared_ptr< init::WeightInitializationStrategy > SparseInit
std::shared_ptr< postproc::PostProcessFactory > PostProcessing
std::shared_ptr< postproc::PostProcessFactory > PostProcessing
RegularizerSpec Regularizer
std::shared_ptr< init::WeightInitializationStrategy > Init
std::shared_ptr< WeightingScheme > Weighting
std::shared_ptr< TrainingStatsGatherer > StatsGatherer
Specifies how to interpret a weight matrix for a partial model.