DiSMEC++
spec.h
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #ifndef DISMEC_TRAINING_SPEC_H
7 #define DISMEC_TRAINING_SPEC_H
8 
9 #include <memory>
10 #include "fwd.h"
11 #include "matrix_types.h"
12 #include "spdlog/fwd.h"
13 #include "objective/regularizers.h"
14 
15 namespace dismec
16 {
/// Abstract interface that gathers the setting-specific parts of the training process.
///
/// Concrete specs implement the pure-virtual factory and update hooks below. The base
/// class itself only stores the shared dataset handle and an optional logger.
24  class TrainingSpec {
25  public:
    /// Stores the given dataset handle; `data` is taken by value and moved into `m_Data`.
26  explicit TrainingSpec(std::shared_ptr<const DatasetBase> data) :
27  m_Data(std::move(data)) {
28  }
    /// Virtual destructor so that derived specs can be destroyed through a base pointer.
29  virtual ~TrainingSpec() = default;
30 
    /// Returns a reference to the dataset held by this spec.
    /// NOTE(review): dereferences `m_Data` without a null check -- assumes a non-null
    /// pointer was passed to the constructor; confirm at the call sites.
31  [[nodiscard]] const DatasetBase& get_data() const { return *m_Data; }
32 
    /// Number of features. Virtual but not pure: a default definition exists outside
    /// this header, and derived specs may override it.
33  [[nodiscard]] virtual long num_features() const;
34 
    /// Makes an Objective object suitable for the dataset.
40  [[nodiscard]] virtual std::shared_ptr<objective::Objective> make_objective() const = 0;
41 
    /// Makes a Minimizer object suitable for the dataset.
48  [[nodiscard]] virtual std::unique_ptr<solvers::Minimizer> make_minimizer() const = 0;
49 
    /// Makes a WeightsInitializer object.
56  [[nodiscard]] virtual std::unique_ptr<init::WeightsInitializer> make_initializer() const = 0;
57 
    /// Makes a PostProcessor object.
    /// `objective` is presumably the objective whose results get post-processed -- TODO confirm.
65  [[nodiscard]] virtual std::unique_ptr<postproc::PostProcessor> make_post_processor(const std::shared_ptr<objective::Objective>& objective) const = 0;
66 
    /// Creates the model that will be used to store the results.
    /// `spec` specifies how to interpret a weight matrix for a partial model.
76  [[nodiscard]] virtual std::shared_ptr<model::Model> make_model(long num_features, model::PartialModelSpec spec) const = 0;
77 
    /// Updates the setting of the Minimizer for handling label `label_id`.
88  virtual void update_minimizer(solvers::Minimizer& minimizer, label_id_t label_id) const = 0;
89 
    /// Updates the setting of the Objective for handling label `label_id`.
100  virtual void update_objective(objective::Objective& objective, label_id_t label_id) const = 0;
101 
    /// Gives mutable access to the statistics gatherer of this spec. Unlike the other
    /// hooks, this member function is not const.
102  [[nodiscard]] virtual TrainingStatsGatherer& get_statistics_gatherer() = 0;
103 
104  // logger
    /// Returns the currently stored logger handle. The handle is empty (null) until
    /// `set_logger` has been called, because `m_Logger` has no initializer.
105  [[nodiscard]] const std::shared_ptr<spdlog::logger>& get_logger() const {
106  return m_Logger;
107  }
108 
    /// Replaces the stored logger; `l` is taken by value and moved into `m_Logger`.
109  void set_logger(std::shared_ptr<spdlog::logger> l) {
110  m_Logger = std::move(l);
111  }
112 
113  private:
    /// Shared, read-only handle to the dataset this spec operates on.
114  std::shared_ptr<const DatasetBase> m_Data;
115 
    /// Logger to be used for info logging; default-constructed, i.e. empty, until set.
117  std::shared_ptr<spdlog::logger> m_Logger;
118  };
119 
/// Scoped enumeration of regularizer kinds.
/// NOTE(review): judging by the names these select L2, L1 and Huber regularization;
/// the code that interprets them is not part of this header -- confirm there.
120  enum class RegularizerType {
121  REG_L2,
122  REG_L1,
124  REG_HUBER,
127  };
128 
/// Scoped enumeration of loss kinds; this is the selector type accepted by `make_loss`.
/// NOTE(review): judging by the names these are logistic, Huber-smoothed hinge and plain
/// hinge loss -- confirm against the definition of `make_loss`.
129  enum class LossType {
131  LOGISTIC,
132  HUBER_HINGE,
133  HINGE
134  };
135 
/// Alias for the floating point type used for real values (`float`).
/// NOTE(review): the included headers appear to declare the same alias as well;
/// repeating an identical alias is legal C++, but check whether this one is needed.
136  using real_t = float;
137 
/// Declaration only; the definition is not part of this header.
/// Takes the loss kind, a shared read-only feature matrix `X`, and a `regularizer`
/// objective whose ownership is handed over (passed as `std::unique_ptr` by value).
/// Returns a shared Objective.
/// NOTE(review): presumably the result combines the selected loss with the given
/// regularizer -- confirm against the definition.
138  std::shared_ptr<objective::Objective> make_loss(
139  LossType type,
140  std::shared_ptr<const GenericFeatureMatrix> X,
141  std::unique_ptr<objective::Objective> regularizer);
142 
/// Regularizer configuration: holds exactly one of the squared-norm, Huber, or
/// elastic config types from the `objective` namespace.
143  using RegularizerSpec = std::variant<objective::SquaredNormConfig, objective::HuberConfig, objective::ElasticConfig>;
144 
146  std::shared_ptr<WeightingScheme> Weighting;
147  std::shared_ptr<init::WeightInitializationStrategy> Init;
148  std::shared_ptr<postproc::PostProcessFactory> PostProcessing;
149  std::shared_ptr<TrainingStatsGatherer> StatsGatherer;
150  bool Sparse;
153  };
154 
156  std::shared_ptr<init::WeightInitializationStrategy> DenseInit;
157  std::shared_ptr<init::WeightInitializationStrategy> SparseInit;
158  std::shared_ptr<postproc::PostProcessFactory> PostProcessing;
159  std::shared_ptr<TrainingStatsGatherer> StatsGatherer;
160 
163  };
164 
/// Declaration only; the definition is not part of this header.
/// Builds a TrainingSpec from a shared read-only dataset, a set of hyper-parameters,
/// and a DismecTrainingConfig. All three arguments are taken by value.
165  std::shared_ptr<TrainingSpec> create_dismec_training(std::shared_ptr<const DatasetBase> data,
166  HyperParameters params,
167  DismecTrainingConfig config);
168 
/// Declaration only; the definition is not part of this header.
/// Builds a TrainingSpec from a shared read-only dataset, an additional shared feature
/// matrix `dense`, a shared `shortlist` (a vector of vectors of `long`), a set of
/// hyper-parameters, and a CascadeTrainingConfig.
/// NOTE(review): the meaning of the `shortlist` entries (e.g. per-instance label
/// indices) cannot be determined from this header -- confirm against the definition.
169  std::shared_ptr<TrainingSpec> create_cascade_training(std::shared_ptr<const DatasetBase> data,
170  std::shared_ptr<const GenericFeatureMatrix> dense,
171  std::shared_ptr<const std::vector<std::vector<long>>> shortlist,
172  HyperParameters params,
173  CascadeTrainingConfig config);
174 }
175 
176 #endif //DISMEC_TRAINING_SPEC_H
This class represents a set of hyper-parameters.
Definition: hyperparams.h:241
This class gathers the setting-specific parts of the training process.
Definition: spec.h:24
virtual std::shared_ptr< model::Model > make_model(long num_features, model::PartialModelSpec spec) const =0
Creates the model that will be used to store the results.
virtual std::unique_ptr< solvers::Minimizer > make_minimizer() const =0
Makes a Minimizer object suitable for the dataset.
void set_logger(std::shared_ptr< spdlog::logger > l)
Definition: spec.h:109
virtual ~TrainingSpec()=default
std::shared_ptr< const DatasetBase > m_Data
Definition: spec.h:114
virtual std::unique_ptr< init::WeightsInitializer > make_initializer() const =0
Makes a WeightsInitializer object.
const std::shared_ptr< spdlog::logger > & get_logger() const
Definition: spec.h:105
virtual TrainingStatsGatherer & get_statistics_gatherer()=0
virtual void update_objective(objective::Objective &objective, label_id_t label_id) const =0
Updates the setting of the Objective for handling label label_id.
const DatasetBase & get_data() const
Definition: spec.h:31
TrainingSpec(std::shared_ptr< const DatasetBase > data)
Definition: spec.h:26
std::shared_ptr< spdlog::logger > m_Logger
logger to be used for info logging
Definition: spec.h:117
virtual std::unique_ptr< postproc::PostProcessor > make_post_processor(const std::shared_ptr< objective::Objective > &objective) const =0
Makes a PostProcessor object.
virtual long num_features() const
Definition: dismec.cpp:173
virtual void update_minimizer(solvers::Minimizer &minimizer, label_id_t label_id) const =0
Updates the setting of the Minimizer for handling label label_id.
virtual std::shared_ptr< objective::Objective > make_objective() const =0
Makes an Objective object suitable for the dataset.
Strong typedef for an int to signify a label id.
Definition: types.h:20
Class that models an optimization objective.
Definition: objective.h:41
Forward-declares types.
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
std::shared_ptr< objective::Objective > make_loss(LossType type, std::shared_ptr< const GenericFeatureMatrix > X, std::unique_ptr< objective::Objective > regularizer)
Definition: dismec.cpp:41
std::variant< objective::SquaredNormConfig, objective::HuberConfig, objective::ElasticConfig > RegularizerSpec
Definition: spec.h:143
RegularizerType
Definition: spec.h:120
LossType
Definition: spec.h:129
std::shared_ptr< TrainingSpec > create_dismec_training(std::shared_ptr< const DatasetBase > data, HyperParameters params, DismecTrainingConfig config)
Definition: dismec.cpp:157
std::shared_ptr< TrainingSpec > create_cascade_training(std::shared_ptr< const DatasetBase > data, std::shared_ptr< const GenericFeatureMatrix > dense, std::shared_ptr< const std::vector< std::vector< long >>> shortlist, HyperParameters params, CascadeTrainingConfig config)
Definition: cascade.cpp:161
float real_t
The default type for floating point values.
Definition: config.h:17
float real_t
Definition: regularizers.h:11
std::shared_ptr< TrainingStatsGatherer > StatsGatherer
Definition: spec.h:159
std::shared_ptr< init::WeightInitializationStrategy > DenseInit
Definition: spec.h:156
std::shared_ptr< init::WeightInitializationStrategy > SparseInit
Definition: spec.h:157
std::shared_ptr< postproc::PostProcessFactory > PostProcessing
Definition: spec.h:158
std::shared_ptr< postproc::PostProcessFactory > PostProcessing
Definition: spec.h:148
RegularizerSpec Regularizer
Definition: spec.h:151
std::shared_ptr< init::WeightInitializationStrategy > Init
Definition: spec.h:147
std::shared_ptr< WeightingScheme > Weighting
Definition: spec.h:146
std::shared_ptr< TrainingStatsGatherer > StatsGatherer
Definition: spec.h:149
Specifies how to interpret a weight matrix for a partial model.
Definition: model.h:22