43 std::shared_ptr<const GenericFeatureMatrix> X,
44 std::unique_ptr<objective::Objective> reg) {
48 return std::make_shared<objective::Regularized_SquaredHingeSVC>(X, std::move(reg));
72 auto minimizer = std::make_unique<solvers::NewtonWithLineSearch>(
num_features());
81 throw std::logic_error(
"Could not cast minimizer to <NewtonWithLineSearch>");
85 double small_count =
static_cast<double>(std::min(num_pos,
get_data().num_examples() - num_pos));
92 std::shared_ptr<WeightingScheme> weighting,
93 std::shared_ptr<init::WeightInitializationStrategy> init,
94 std::shared_ptr<postproc::PostProcessFactory> post_proc,
95 std::shared_ptr<TrainingStatsGatherer> gatherer,
100 m_NewtonSettings( std::move(hyper_params) ),
101 m_Weighting( std::move(weighting) ),
102 m_UseSparseModel( use_sparse ),
103 m_InitStrategy( std::move(init) ),
104 m_PostProcessor( std::move(post_proc) ),
106 m_StatsGather( std::move(gatherer) ),
107 m_Regularizer( regularizer ),
111 throw std::invalid_argument(
"Missing weight initialization strategy");
115 throw std::invalid_argument(
"Missing weight post processor");
125 throw std::logic_error(
"Could not cast objective to <LinearClassifierBase>");
142 return std::make_shared<model::SparseModel>(
num_features, spec);
144 return std::make_shared<model::DenseModel>(
num_features, spec);
164 return std::make_shared<DiSMECTraining>(std::move(data), std::move(params), std::move(config.
Weighting),
165 std::move(config.
Init),
long num_examples() const noexcept
Get the total number of instances, i.e. the number of rows in the feature matrix.
std::shared_ptr< const BinaryLabelVector > get_labels(label_id_t id) const
virtual long num_positives(label_id_t id) const
long num_features() const noexcept
Get the total number of features, i.e. the number of columns in the feature matrix.
parallel::NUMAReplicator< const GenericFeatureMatrix > m_FeatureReplicator
std::shared_ptr< objective::Objective > make_objective() const override
Makes an Objective object suitable for the dataset.
std::shared_ptr< WeightingScheme > m_Weighting
HyperParameters m_NewtonSettings
void update_objective(objective::Objective &base_objective, label_id_t label_id) const override
Updates the setting of the Objective for handling label label_id.
TrainingStatsGatherer & get_statistics_gatherer() override
std::shared_ptr< postproc::PostProcessFactory > m_PostProcessor
std::shared_ptr< init::WeightInitializationStrategy > m_InitStrategy
DiSMECTraining(std::shared_ptr< const DatasetBase > data, HyperParameters hyper_params, std::shared_ptr< WeightingScheme > weighting, std::shared_ptr< init::WeightInitializationStrategy > init, std::shared_ptr< postproc::PostProcessFactory > post_proc, std::shared_ptr< TrainingStatsGatherer > gatherer, bool use_sparse, RegularizerSpec regularizer, LossType loss)
Creates a DiSMECTraining instance.
void update_minimizer(solvers::Minimizer &base_minimizer, label_id_t label_id) const override
Updates the setting of the Minimizer for handling label label_id.
std::shared_ptr< model::Model > make_model(long num_features, model::PartialModelSpec spec) const override
Creates the model that will be used to store the results.
std::unique_ptr< solvers::Minimizer > make_minimizer() const override
Makes a Minimizer object suitable for the dataset.
std::shared_ptr< TrainingStatsGatherer > m_StatsGather
RegularizerSpec m_Regularizer
std::unique_ptr< init::WeightsInitializer > make_initializer() const override
Makes a WeightsInitializer object.
std::unique_ptr< postproc::PostProcessor > make_post_processor(const std::shared_ptr< objective::Objective > &objective) const override
Makes a PostProcessor object.
This class represents a set of hyper-parameters.
hyper_param_t get(const std::string &name) const
Gets the hyper-parameter with the given name, or throws if it does not exist.
void apply(HyperParameterBase &target) const
This class gathers the setting-specific parts of the training process.
const DatasetBase & get_data() const
virtual long num_features() const
Strong typedef for an int to signify a label id.
Base class for objectives that use a linear classifier.
Class that models an optimization objective.
auto get_features(const DatasetBase &ds)
std::shared_ptr< WeightInitializationStrategy > create_zero_initializer()
Creates an initialization strategy that initializes all weight vectors to zero.
std::unique_ptr< Objective > make_regularizer(const SquaredNormConfig &config)
std::unique_ptr< GenericLinearClassifier > make_huber_hinge(std::shared_ptr< const GenericFeatureMatrix > X, std::unique_ptr< Objective > regularizer, real_t epsilon)
std::unique_ptr< GenericLinearClassifier > make_logistic_loss(std::shared_ptr< const GenericFeatureMatrix > X, std::unique_ptr< Objective > regularizer)
std::unique_ptr< GenericLinearClassifier > make_squared_hinge(std::shared_ptr< const GenericFeatureMatrix > X, std::unique_ptr< Objective > regularizer)
FactoryPtr create_identity()
auto visit(F &&f, Variants &&... variants)
Main namespace in which all types, classes, and functions are defined.
std::shared_ptr< objective::Objective > make_loss(LossType type, std::shared_ptr< const GenericFeatureMatrix > X, std::unique_ptr< objective::Objective > regularizer)
std::variant< objective::SquaredNormConfig, objective::HuberConfig, objective::ElasticConfig > RegularizerSpec
std::shared_ptr< TrainingSpec > create_dismec_training(std::shared_ptr< const DatasetBase > data, HyperParameters params, DismecTrainingConfig config)
std::shared_ptr< postproc::PostProcessFactory > PostProcessing
RegularizerSpec Regularizer
std::shared_ptr< init::WeightInitializationStrategy > Init
std::shared_ptr< WeightingScheme > Weighting
std::shared_ptr< TrainingStatsGatherer > StatsGatherer
Specifies how to interpret a weight matrix for a partial model.
#define THROW_EXCEPTION(exception_type,...)