DiSMEC++
ova-primal.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #include "training/initializer.h"
7 #include "solver/newton.h"
8 #include "objective/linear.h"
9 #include "data/data.h"
10 
11 #include <spdlog/spdlog.h>
12 
13 using namespace dismec::init;
14 
15 std::shared_ptr<WeightInitializationStrategy> dismec::init::create_ova_primal_initializer(
16  const std::shared_ptr<DatasetBase>& data, RegularizerSpec regularizer, LossType loss) {
17  auto minimizer = std::make_unique<solvers::NewtonWithLineSearch>(data->num_features());
18  auto reg = std::visit([](auto&& config){ return make_regularizer(config); }, regularizer);
19  auto loss_fn = make_loss(loss, data->get_features(), std::move(reg));
20  dynamic_cast<objective::LinearClassifierBase&>(*loss_fn).get_label_ref().fill(-1);
21  //minimizer->set_epsilon(0.01 / data->num_examples());
22 
23  DenseRealVector target(data->num_features());
24  target.setZero();
25  spdlog::info("Starting to calculate OVA-Primal init vector");
26  auto result = minimizer->minimize(*loss_fn, target);
27 
28  spdlog::info("OVA-Primal init vector has been calculated in {} ms. Loss {} -> {}",
29  result.Duration.count(), result.InitialValue, result.FinalValue);
30 
31 
32  return create_constant_initializer(std::move(target));
33 }
Base class for objectives that use a linear classifier.
Definition: linear.h:27
BinaryLabelVector & get_label_ref()
Definition: linear.cpp:70
std::shared_ptr< WeightInitializationStrategy > create_constant_initializer(DenseRealVector vec)
Definition: constant.cpp:56
std::shared_ptr< WeightInitializationStrategy > create_ova_primal_initializer(const std::shared_ptr< DatasetBase > &data, RegularizerSpec regularizer, LossType loss)
Definition: ova-primal.cpp:15
std::unique_ptr< Objective > make_regularizer(const SquaredNormConfig &config)
auto visit(F &&f, Variants &&... variants)
Definition: eigen_generic.h:95
std::shared_ptr< objective::Objective > make_loss(LossType type, std::shared_ptr< const GenericFeatureMatrix > X, std::unique_ptr< objective::Objective > regularizer)
Definition: dismec.cpp:41
std::variant< objective::SquaredNormConfig, objective::HuberConfig, objective::ElasticConfig > RegularizerSpec
Definition: spec.h:143
types::DenseVector< real_t > DenseRealVector
Any dense, real-valued vector.
Definition: matrix_types.h:40
LossType
Definition: spec.h:129