DiSMEC++
constant.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #include "training/initializer.h"
7 #include "data/types.h"
8 
9 using namespace dismec::init;
10 
11 namespace dismec::init {
13  public:
14  explicit ConstantInitializer(std::shared_ptr<const DenseRealVector> vec) : m_InitVector(std::move(vec)) {
15  if (!m_InitVector) {
16  throw std::logic_error("Initial vector is <null>");
17  }
18  }
19 
20  void get_initial_weight(label_id_t label_id, Eigen::Ref <DenseRealVector> target,
21  objective::Objective &objective) override {
22  target = *m_InitVector;
23  }
24 
25  private:
26  std::shared_ptr<const DenseRealVector> m_InitVector;
27  };
28 
35  public:
37 
38  [[nodiscard]] std::unique_ptr <WeightsInitializer>
39  make_initializer(const std::shared_ptr<const GenericFeatureMatrix>& features) const override;
40 
41  private:
43  };
44 
45 }
46 
48  m_InitVector(std::make_shared<DenseRealVector>(std::move(vec))) {
49 }
50 
51 std::unique_ptr<WeightsInitializer> ConstantInitializationStrategy::make_initializer(
52  const std::shared_ptr<const GenericFeatureMatrix>& features) const {
53  return std::make_unique<ConstantInitializer>(m_InitVector.get_local() );
54 }
55 
56 std::shared_ptr<WeightInitializationStrategy> dismec::init::create_constant_initializer(DenseRealVector vec) {
57  return std::make_shared<ConstantInitializationStrategy>(std::move(vec));
58 }
An initialization strategy that sets the weight vector to a given constant.
Definition: constant.cpp:34
ConstantInitializationStrategy(DenseRealVector vec)
Definition: constant.cpp:47
parallel::NUMAReplicator< DenseRealVector > m_InitVector
Definition: constant.cpp:42
std::unique_ptr< WeightsInitializer > make_initializer(const std::shared_ptr< const GenericFeatureMatrix > &features) const override
Creates a new, thread-local WeightsInitializer.
Definition: constant.cpp:51
std::shared_ptr< const DenseRealVector > m_InitVector
Definition: constant.cpp:26
void get_initial_weight(label_id_t label_id, Eigen::Ref< DenseRealVector > target, objective::Objective &objective) override
Generate an initial vector for the given label. The result should be placed in target.
Definition: constant.cpp:20
ConstantInitializer(std::shared_ptr< const DenseRealVector > vec)
Definition: constant.cpp:14
Base class for all weight init strategies.
Definition: initializer.h:53
Base class for all weight initializers.
Definition: initializer.h:30
Strong typedef for an int to signify a label id.
Definition: types.h:20
Class that models an optimization objective.
Definition: objective.h:41
std::shared_ptr< const T > get_local() const
Definition: numa.h:82
std::shared_ptr< WeightInitializationStrategy > create_constant_initializer(DenseRealVector vec)
Definition: constant.cpp:56
types::DenseVector< real_t > DenseRealVector
Any dense, real-valued vector.
Definition: matrix_types.h:40