DiSMEC++
pretrained.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #include "training/initializer.h"
7 #include "data/types.h"
8 #include "model/model.h"
9 
10 using namespace dismec::init;
11 
// Declarations for initialising weights from an already trained model.
// NOTE(review): this listing has no rows for inner lines 13, 22 and 31, so the two class heads and
// the first row of the get_initial_weight signature are not visible here. The index at the bottom
// of the page places the PreTrainedInitializer constructor at :15, get_initial_weight at :22 and
// the PreTrainedInitializationStrategy constructor at :40.
12 namespace dismec::init {
// --- PreTrainedInitializer (class head on the missing inner line 13) ---
// Holds a shared, read-only handle to a model and uses it to produce per-label start vectors.
14  public:
// Takes the model handle by value and moves it into the member.
// Throws std::logic_error if the handle is null.
15  explicit PreTrainedInitializer(std::shared_ptr<const model::Model> pre_trained) :
16  m_PreTrainedWeights(std::move(pre_trained))
17  {
18  if(!m_PreTrainedWeights) {  // checked once here, so the dereference in get_initial_weight needs no guard
19  throw std::logic_error("pre trained model is <null>");
20  }
21  }
// get_initial_weight(label_id_t label_id, ...) override -- the first signature row (inner line 22)
// is missing from this listing. The body forwards label_id and target to
// Model::get_weights_for_label; presumably that call writes the stored weights for the label
// into target -- TODO confirm against model/model.h. The objective argument is not used.
23  Eigen::Ref<DenseRealVector> target,
24  objective::Objective& objective) override {
25  m_PreTrainedWeights->get_weights_for_label(label_id, target);
26  }
27  private:
28  std::shared_ptr<const model::Model> m_PreTrainedWeights;  // never null after construction (constructor throws otherwise)
29  };
30 
// --- PreTrainedInitializationStrategy (class head on the missing inner line 31) ---
// Stores a model handle and hands out initializers that share it.
32  public:
// Not marked explicit and performs no null check (see the out-of-class definition).
33  PreTrainedInitializationStrategy(std::shared_ptr<const model::Model> pre_trained);
// Returns a newly allocated initializer; the result must not be discarded.
34  [[nodiscard]] std::unique_ptr<WeightsInitializer> make_initializer(const std::shared_ptr<const GenericFeatureMatrix>& features) const override;
35  private:
36  std::shared_ptr<const model::Model> m_PreTrained;  // may be null; only validated when an initializer is built
37  };
38 }
39 
// Stores the given model handle in m_PreTrained by moving the by-value argument.
// No validation happens here: a null handle is accepted and is only rejected later, when
// make_initializer constructs a PreTrainedInitializer (whose constructor throws std::logic_error).
40 PreTrainedInitializationStrategy::PreTrainedInitializationStrategy(std::shared_ptr<const model::Model> pre_trained) :
41  m_PreTrained(std::move(pre_trained)) {
42 
43 }
// Builds a new PreTrainedInitializer that shares ownership of m_PreTrained and returns it
// through the WeightsInitializer base pointer.
// The features argument is not used by this strategy.
// Throws std::logic_error (from the PreTrainedInitializer constructor) if m_PreTrained is null.
44 std::unique_ptr<WeightsInitializer> PreTrainedInitializationStrategy::make_initializer(const std::shared_ptr<const GenericFeatureMatrix>& features) const {
45  return std::make_unique<PreTrainedInitializer>(m_PreTrained);  // copies the shared_ptr; the strategy keeps its own handle
46 }
47 
// Factory: wraps the given model in a PreTrainedInitializationStrategy and returns it as a
// shared pointer to the WeightInitializationStrategy base.
// The parameter is a pointer to a non-const Model; it is converted to a pointer-to-const when
// passed to the strategy's constructor, so the strategy only has read access.
// The model handle is not checked for null here.
48 std::shared_ptr<WeightInitializationStrategy> dismec::init::create_pretrained_initializer(std::shared_ptr<model::Model> model) {
49  return std::make_shared<PreTrainedInitializationStrategy>(std::move(model));
50 }
PreTrainedInitializationStrategy(std::shared_ptr< const model::Model > pre_trained)
Definition: pretrained.cpp:40
std::shared_ptr< const model::Model > m_PreTrained
Definition: pretrained.cpp:36
std::unique_ptr< WeightsInitializer > make_initializer(const std::shared_ptr< const GenericFeatureMatrix > &features) const override
Creates a new, thread-local WeightsInitializer.
Definition: pretrained.cpp:44
PreTrainedInitializer(std::shared_ptr< const model::Model > pre_trained)
Definition: pretrained.cpp:15
std::shared_ptr< const model::Model > m_PreTrainedWeights
Definition: pretrained.cpp:28
void get_initial_weight(label_id_t label_id, Eigen::Ref< DenseRealVector > target, objective::Objective &objective) override
Generate an initial vector for the given label. The result should be placed in target.
Definition: pretrained.cpp:22
Base class for all weight init strategies.
Definition: initializer.h:53
Base class for all weight initializers.
Definition: initializer.h:30
Strong typedef for an int to signify a label id.
Definition: types.h:20
Class that models an optimization objective.
Definition: objective.h:41
std::shared_ptr< WeightInitializationStrategy > create_pretrained_initializer(std::shared_ptr< model::Model > model)
Creates an initialization strategy that uses an already trained model to set the initial weights.
Definition: pretrained.cpp:48