DiSMEC++
dismec.h
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #ifndef DISMEC_DISMEC_H
7 #define DISMEC_DISMEC_H
8 
9 
10 #include "spec.h"
11 #include "parallel/numa.h"
12 #include "utils/hyperparams.h"
13 
14 namespace dismec
15 {
/// An implementation of TrainingSpec that models the DiSMEC algorithm.
///
/// The class overrides the TrainingSpec factory hooks (objective, minimizer,
/// initializer, model, post-processor) and the two per-label update hooks.
///
/// NOTE(review): this listing omits header lines 58, 60, 70, 75 and 76. The
/// cross-references accompanying the listing place `m_NewtonSettings` (line 60),
/// `m_FeatureReplicator` (line 70) and `m_Regularizer` (line 75) in this header,
/// and mention a `get_statistics_gatherer()` override defined in dismec.cpp --
/// consult the real header before editing the class layout.
32  class DiSMECTraining : public TrainingSpec {
33  public:
    /// Creates a DiSMECTraining instance.
    ///
    /// `hyper_params` is a HyperParameters object (a set of hyper-parameters);
    /// `regularizer` is a RegularizerSpec, i.e. a std::variant over
    /// SquaredNormConfig, HuberConfig and ElasticConfig.
    /// NOTE(review): how each argument is stored or consumed is decided in
    /// dismec.cpp, which is not visible here -- presumably `weighting`, `init`,
    /// `post_proc`, `gatherer` and `use_sparse` feed the members of matching
    /// name below; confirm against the definition.
40  DiSMECTraining(std::shared_ptr<const DatasetBase> data, HyperParameters hyper_params,
41  std::shared_ptr<WeightingScheme> weighting,
42  std::shared_ptr<init::WeightInitializationStrategy> init,
43  std::shared_ptr<postproc::PostProcessFactory> post_proc,
44  std::shared_ptr<TrainingStatsGatherer> gatherer,
45  bool use_sparse,
46  RegularizerSpec regularizer, LossType loss);
47 
    /// Makes an Objective object suitable for the dataset.
48  [[nodiscard]] std::shared_ptr<objective::Objective> make_objective() const override;
    /// Makes a Minimizer object suitable for the dataset.
49  [[nodiscard]] std::unique_ptr<solvers::Minimizer> make_minimizer() const override;
    /// Makes a WeightsInitializer object.
50  [[nodiscard]] std::unique_ptr<init::WeightsInitializer> make_initializer() const override;
    /// Creates the model that will be used to store the results.
    /// `spec` specifies how to interpret a weight matrix for a partial model.
51  [[nodiscard]] std::shared_ptr<model::Model> make_model(long num_features, model::PartialModelSpec spec) const override;
52 
    /// Updates the settings of the Minimizer for handling label `label_id`.
53  void update_minimizer(solvers::Minimizer& base_minimizer, label_id_t label_id) const override;
    /// Updates the settings of the Objective for handling label `label_id`.
54  void update_objective(objective::Objective& base_objective, label_id_t label_id) const override;
55 
    /// Makes a PostProcessor object.
56  [[nodiscard]] std::unique_ptr<postproc::PostProcessor> make_post_processor(const std::shared_ptr<objective::Objective>& objective) const override;
57 
    // NOTE(review): listing line 58 is not shown here; see the class-level note.
59  private:
    // NOTE(review): listing line 60 (m_NewtonSettings, a HyperParameters) is not shown here.
    // Weighting scheme handed to the constructor (presumably -- confirm in dismec.cpp).
61  std::shared_ptr<WeightingScheme> m_Weighting;
    // Defaults to false; presumably overwritten from the constructor's `use_sparse` -- confirm.
62  bool m_UseSparseModel = false;
63 
64  // initial conditions
65  std::shared_ptr<init::WeightInitializationStrategy> m_InitStrategy;
66 
67  // post processing
68  std::shared_ptr<postproc::PostProcessFactory> m_PostProcessor;
69 
    // NOTE(review): listing line 70 (m_FeatureReplicator, a
    // parallel::NUMAReplicator<const GenericFeatureMatrix>) is not shown here.
71 
    // Statistics gatherer shared with the caller via shared_ptr.
72  std::shared_ptr<TrainingStatsGatherer> m_StatsGather;
73 
    // No in-class initializer; the value must come from the constructor.
    // NOTE(review): meaning of "base epsilon" is not visible here -- see dismec.cpp.
74  double m_BaseEpsilon;
    // NOTE(review): listing lines 75 (m_Regularizer, a RegularizerSpec) and 76 are not shown here.
77  };
78 }
79 
80 #endif //DISMEC_DISMEC_H
An implementation of TrainingSpec that models the DiSMEC algorithm.
Definition: dismec.h:32
parallel::NUMAReplicator< const GenericFeatureMatrix > m_FeatureReplicator
Definition: dismec.h:70
std::shared_ptr< objective::Objective > make_objective() const override
Makes an Objective object suitable for the dataset.
Definition: dismec.cpp:64
std::shared_ptr< WeightingScheme > m_Weighting
Definition: dismec.h:61
HyperParameters m_NewtonSettings
Definition: dismec.h:60
void update_objective(objective::Objective &base_objective, label_id_t label_id) const override
Updates the settings of the Objective for handling the label `label_id`.
Definition: dismec.cpp:122
double m_BaseEpsilon
Definition: dismec.h:74
TrainingStatsGatherer & get_statistics_gatherer() override
Definition: dismec.cpp:152
std::shared_ptr< postproc::PostProcessFactory > m_PostProcessor
Definition: dismec.h:68
std::shared_ptr< init::WeightInitializationStrategy > m_InitStrategy
Definition: dismec.h:65
DiSMECTraining(std::shared_ptr< const DatasetBase > data, HyperParameters hyper_params, std::shared_ptr< WeightingScheme > weighting, std::shared_ptr< init::WeightInitializationStrategy > init, std::shared_ptr< postproc::PostProcessFactory > post_proc, std::shared_ptr< TrainingStatsGatherer > gatherer, bool use_sparse, RegularizerSpec regularizer, LossType loss)
Creates a DiSMECTraining instance.
Definition: dismec.cpp:90
void update_minimizer(solvers::Minimizer &base_minimizer, label_id_t label_id) const override
Updates the settings of the Minimizer for handling the label `label_id`.
Definition: dismec.cpp:77
std::shared_ptr< model::Model > make_model(long num_features, model::PartialModelSpec spec) const override
Creates the model that will be used to store the results.
Definition: dismec.cpp:140
std::unique_ptr< solvers::Minimizer > make_minimizer() const override
Makes a Minimizer object suitable for the dataset.
Definition: dismec.cpp:71
std::shared_ptr< TrainingStatsGatherer > m_StatsGather
Definition: dismec.h:72
RegularizerSpec m_Regularizer
Definition: dismec.h:75
std::unique_ptr< init::WeightsInitializer > make_initializer() const override
Makes a WeightsInitializer object.
Definition: dismec.cpp:136
std::unique_ptr< postproc::PostProcessor > make_post_processor(const std::shared_ptr< objective::Objective > &objective) const override
Makes a PostProcessor object.
Definition: dismec.cpp:148
This class represents a set of hyper-parameters.
Definition: hyperparams.h:241
This class gathers the setting-specific parts of the training process.
Definition: spec.h:24
virtual long num_features() const
Definition: dismec.cpp:173
Strong typedef for an int to signify a label id.
Definition: types.h:20
Class that models an optimization objective.
Definition: objective.h:41
Helper class to ensure that each NUMA node has its own copy of some immutable data.
Definition: numa.h:72
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
std::variant< objective::SquaredNormConfig, objective::HuberConfig, objective::ElasticConfig > RegularizerSpec
Definition: spec.h:143
LossType
Definition: spec.h:129
Specifies how to interpret a weight matrix for a partial model.
Definition: model.h:22