DiSMEC++
combine.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #include "training/postproc.h"
7 #include "data/types.h"
8 
9 namespace dismec::postproc {
11  public:
12  explicit CombinePostProcessor(std::vector<std::unique_ptr<PostProcessor>> children);
13  void process(label_id_t label_id, Eigen::Ref<DenseRealVector> weight_vector, solvers::MinimizationResult& result) override;
14  private:
15  std::vector<std::unique_ptr<PostProcessor>> m_Children;
16  };
17 
18  CombinePostProcessor::CombinePostProcessor(std::vector<std::unique_ptr<PostProcessor>> children) :
19  m_Children(std::move(children)) {
20 
21  }
22 
24  Eigen::Ref<DenseRealVector> weight_vector,
26  for(auto& child : m_Children) {
27  child->process(label_id, weight_vector, result);
28  }
29  }
30 
31 
33  public:
34  explicit CombinedFactory(std::vector<std::shared_ptr<PostProcessFactory>> children) :
35  m_Children(std::move(children)) {
36 
37  }
38 
39  [[nodiscard]] std::unique_ptr<PostProcessor>
40  make_processor(const std::shared_ptr<objective::Objective>& objective) const override {
41  std::vector<std::unique_ptr<PostProcessor>> children;
42  children.reserve(m_Children.size());
43  std::transform(begin(m_Children), end(m_Children), std::back_inserter(children),
44  [&](auto&& factory) {
45  return factory->make_processor(objective);
46  });
47  return std::make_unique<CombinePostProcessor>(std::move(children));
48  }
49 
50  std::vector<std::shared_ptr<PostProcessFactory>> m_Children;
51  };
52 }
53 
54 std::shared_ptr<dismec::postproc::PostProcessFactory> dismec::postproc::create_combined(std::vector<std::shared_ptr<PostProcessFactory>> children) {
55  return std::make_shared<CombinedFactory>( std::move(children) );
56 }
label_id_t: Strong typedef for an int to signify a label id.
Definition: types.h:20
CombinePostProcessor(std::vector< std::unique_ptr< PostProcessor >> children)
Definition: combine.cpp:18
std::vector< std::unique_ptr< PostProcessor > > m_Children
Definition: combine.cpp:15
void process(label_id_t label_id, Eigen::Ref< DenseRealVector > weight_vector, solvers::MinimizationResult &result) override
Apply post-processing for the weight_vector corresponding to the label label_id.
Definition: combine.cpp:23
std::unique_ptr< PostProcessor > make_processor(const std::shared_ptr< objective::Objective > &objective) const override
Definition: combine.cpp:40
CombinedFactory(std::vector< std::shared_ptr< PostProcessFactory >> children)
Definition: combine.cpp:34
std::vector< std::shared_ptr< PostProcessFactory > > m_Children
Definition: combine.cpp:50
FactoryPtr create_combined(std::vector< FactoryPtr > processor)