DiSMEC++
cascade.h
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #ifndef DISMEC_SRC_TRAINING_CASCADE_H
7 #define DISMEC_SRC_TRAINING_CASCADE_H
8 
9 #include "training.h"
10 #include "parallel/numa.h"
11 
12 namespace dismec {
// CascadeTraining: a TrainingSpec, i.e. an instance of the class that gathers the
// setting-specific parts of the training process. Its constructor receives both a
// sparse dataset (`tfidf_data`) and a dense feature matrix (`dense_data`). Each of the
// two parts has its own weight-initialization strategy and regularization value.
// NOTE(review): how the dense and sparse parts are actually combined (the "cascade") is
// implemented in cascade.cpp, which is not visible here -- confirm there.
// NOTE(review): this is a Doxygen source listing in which header lines 44, 47, 49, 50,
// 63, 66 and 67 are missing; see the per-member notes below.
13  class CascadeTraining : public TrainingSpec {
14  public:
// Constructor (defined at cascade.cpp:132).
//   tfidf_data   - sparse dataset (name suggests tf-idf features -- confirm)
//   dense_data   - dense feature matrix
//   hyper_params - hyper-parameter set; presumably kept as m_NewtonSettings -- confirm
//   dense_init   - weight-initialization strategy for the dense part (by name)
//   dense_reg    - regularization value for the dense part (by name)
//   sparse_init  - weight-initialization strategy for the sparse part (by name)
//   sparse_reg   - regularization value for the sparse part (by name)
//   post_proc    - factory used to create post-processors
//   gatherer     - training statistics gatherer
//   shortlist    - optional, defaults to nullptr; semantics defined in cascade.cpp
15  CascadeTraining(std::shared_ptr<const DatasetBase> tfidf_data,
16  std::shared_ptr<const GenericFeatureMatrix> dense_data,
17  HyperParameters hyper_params,
18  std::shared_ptr<init::WeightInitializationStrategy> dense_init,
19  real_t dense_reg,
20  std::shared_ptr<init::WeightInitializationStrategy> sparse_init,
21  real_t sparse_reg,
22  std::shared_ptr<postproc::PostProcessFactory> post_proc,
23  std::shared_ptr<TrainingStatsGatherer> gatherer,
24  std::shared_ptr<const std::vector<std::vector<long>>> shortlist = nullptr);
25 
// Returns m_NumFeatures.
// NOTE(review): m_NumFeatures is not declared on any visible line; it presumably lives
// on one of the missing header lines 63/66/67 -- confirm against the real header.
26  long num_features() const override { return m_NumFeatures; }
27 
// Makes an Objective object suitable for the dataset (cascade.cpp:38).
28  [[nodiscard]] std::shared_ptr<objective::Objective> make_objective() const override;
29 
// Makes a Minimizer object suitable for the dataset (cascade.cpp:46).
30  [[nodiscard]] std::unique_ptr<solvers::Minimizer> make_minimizer() const override;
31 
// Makes a WeightsInitializer object (cascade.cpp:109).
32  [[nodiscard]] std::unique_ptr<init::WeightsInitializer> make_initializer() const override;
33 
// Creates the model that will be used to store the results (cascade.cpp:119).
34  [[nodiscard]] std::shared_ptr<model::Model>
35  make_model(long num_features, model::PartialModelSpec spec) const override;
36 
// Updates the setting of the Minimizer for handling label `label_id` (cascade.cpp:53).
37  void update_minimizer(solvers::Minimizer& base_minimizer, label_id_t label_id) const override;
38 
// Updates the setting of the Objective for handling label `label_id` (cascade.cpp:80).
39  void update_objective(objective::Objective& base_objective, label_id_t label_id) const override;
40 
// Makes a PostProcessor object for the given objective (cascade.cpp:124).
41  [[nodiscard]] std::unique_ptr<postproc::PostProcessor>
42  make_post_processor(const std::shared_ptr<objective::Objective>& objective) const override;
43 
// NOTE(review): header line 44 is missing from this listing. The cross-references list
// `TrainingStatsGatherer& get_statistics_gatherer() override` (defined at
// cascade.cpp:128), which is presumably the declaration that belongs here.
45 
46  private:
// NOTE(review): header line 47 is missing; per the cross-references it declares
// `HyperParameters m_NewtonSettings;`.
48 
// NOTE(review): header lines 49-50 are missing; per the cross-references they declare
// `parallel::NUMAReplicator<const GenericFeatureMatrix> m_SparseReplicator;` (line 49)
// and `parallel::NUMAReplicator<const GenericFeatureMatrix> m_DenseReplicator;` (line 50).
// NUMAReplicator is the helper that gives each NUMA node its own copy of immutable data.
51 
// Optional shortlist passed to the constructor (nullptr when not given).
// NOTE(review): presumably per-label index lists -- confirm usage in cascade.cpp.
52  std::shared_ptr<const std::vector<std::vector<long>>> m_Shortlist;
53 
54  // post processing
55  std::shared_ptr<postproc::PostProcessFactory> m_PostProcessor;
56 
57  // initial conditions
58  std::shared_ptr<init::WeightInitializationStrategy> m_DenseInitStrategy;
59  std::shared_ptr<init::WeightInitializationStrategy> m_SparseInitStrategy;
60 
// Statistics gatherer supplied to the constructor.
61  std::shared_ptr<TrainingStatsGatherer> m_StatsGather;
62 
// NOTE(review): header line 63 is missing from this listing; its content is unknown.
// NOTE(review): m_BaseEpsilon's purpose is not documented here; the name suggests a base
// tolerance value -- confirm in cascade.cpp.
64  double m_BaseEpsilon;
65 
// NOTE(review): header lines 66-67 are missing; their content is unknown. They possibly
// hold m_NumFeatures and/or the stored dense/sparse regularization values -- confirm.
68  };
69 }
70 
71 #endif //DISMEC_SRC_TRAINING_CASCADE_H
long num_features() const override
Definition: cascade.h:26
std::shared_ptr< objective::Objective > make_objective() const override
Makes an Objective object suitable for the dataset.
Definition: cascade.cpp:38
HyperParameters m_NewtonSettings
Definition: cascade.h:47
std::unique_ptr< solvers::Minimizer > make_minimizer() const override
Makes a Minimizer object suitable for the dataset.
Definition: cascade.cpp:46
std::shared_ptr< init::WeightInitializationStrategy > m_DenseInitStrategy
Definition: cascade.h:58
std::shared_ptr< model::Model > make_model(long num_features, model::PartialModelSpec spec) const override
Creates the model that will be used to store the results.
Definition: cascade.cpp:119
std::unique_ptr< init::WeightsInitializer > make_initializer() const override
Makes a WeightsInitializer object.
Definition: cascade.cpp:109
std::shared_ptr< const std::vector< std::vector< long > > > m_Shortlist
Definition: cascade.h:52
CascadeTraining(std::shared_ptr< const DatasetBase > tfidf_data, std::shared_ptr< const GenericFeatureMatrix > dense_data, HyperParameters hyper_params, std::shared_ptr< init::WeightInitializationStrategy > dense_init, real_t dense_reg, std::shared_ptr< init::WeightInitializationStrategy > sparse_init, real_t sparse_reg, std::shared_ptr< postproc::PostProcessFactory > post_proc, std::shared_ptr< TrainingStatsGatherer > gatherer, std::shared_ptr< const std::vector< std::vector< long >>> shortlist=nullptr)
Definition: cascade.cpp:132
TrainingStatsGatherer & get_statistics_gatherer() override
Definition: cascade.cpp:128
parallel::NUMAReplicator< const GenericFeatureMatrix > m_DenseReplicator
Definition: cascade.h:50
std::unique_ptr< postproc::PostProcessor > make_post_processor(const std::shared_ptr< objective::Objective > &objective) const override
Makes a PostProcessor object.
Definition: cascade.cpp:124
std::shared_ptr< postproc::PostProcessFactory > m_PostProcessor
Definition: cascade.h:55
std::shared_ptr< TrainingStatsGatherer > m_StatsGather
Definition: cascade.h:61
void update_minimizer(solvers::Minimizer &base_minimizer, label_id_t label_id) const override
Updates the setting of the Minimizer for handling label label_id.
Definition: cascade.cpp:53
void update_objective(objective::Objective &base_objective, label_id_t label_id) const override
Updates the setting of the Objective for handling label label_id.
Definition: cascade.cpp:80
parallel::NUMAReplicator< const GenericFeatureMatrix > m_SparseReplicator
Definition: cascade.h:49
std::shared_ptr< init::WeightInitializationStrategy > m_SparseInitStrategy
Definition: cascade.h:59
This class represents a set of hyper-parameters.
Definition: hyperparams.h:241
This class gathers the setting-specific parts of the training process.
Definition: spec.h:24
Strong typedef for an int to signify a label id.
Definition: types.h:20
Class that models an optimization objective.
Definition: objective.h:41
Helper class to ensure that each NUMA node has its own copy of some immutable data.
Definition: numa.h:72
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
float real_t
The default type for floating point values.
Definition: config.h:17
Specifies how to interpret a weight matrix for a partial model.
Definition: model.h:22