DiSMEC++
postproc.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #include "postproc.h"
7 #include "data/types.h"
8 #include "objective/objective.h"
9 #include "utils/hash_vector.h"
10 #include "spdlog/spdlog.h"
11 #include "postproc/generic.h"
12 
13 namespace dismec::postproc {
14  struct IdentityPostProc : public PostProcessor {
15  explicit IdentityPostProc(const std::shared_ptr<objective::Objective>&) {}
16  void process(label_id_t label_id,
17  Eigen::Ref<DenseRealVector> weight_vector,
18  solvers::MinimizationResult& result) override {};
19  };
20 
22  public:
23  CullingPostProcessor(const std::shared_ptr<objective::Objective>& objective, real_t eps);
24  void process(label_id_t label_id, Eigen::Ref<DenseRealVector> weight_vector, solvers::MinimizationResult& result) override;
25  private:
27  };
28 
30  Eigen::Ref<DenseRealVector> weight_vector,
32  for(long i = 0; i < weight_vector.size(); ++i) {
33  real_t& w = weight_vector.coeffRef(i);
34  if(abs(w) <= m_Epsilon) {
35  w = real_t{0};
36  }
37  }
38  }
39 
40  CullingPostProcessor::CullingPostProcessor(const std::shared_ptr<objective::Objective>& objective,
41  real_t eps) : m_Epsilon(eps) {
42  if(eps < 0) {
43  throw std::invalid_argument("Epsilon has to be positive");
44  }
45  }
46 }
47 
49 
50 std::shared_ptr<PostProcessFactory> dismec::postproc::create_identity() {
51  return std::make_shared<GenericPostProcFactory<IdentityPostProc>>();
52 }
53 
54 std::shared_ptr<PostProcessFactory> dismec::postproc::create_culling(real_t eps) {
55  return std::make_shared<GenericPostProcFactory<CullingPostProcessor, real_t>>( eps );
56 }
57 
Strong typedef for an int to signify a label id.
Definition: types.h:20
void process(label_id_t label_id, Eigen::Ref< DenseRealVector > weight_vector, solvers::MinimizationResult &result) override
Apply post-processing for the weight_vector corresponding to the label label_id.
Definition: postproc.cpp:29
CullingPostProcessor(const std::shared_ptr< objective::Objective > &objective, real_t eps)
Definition: postproc.cpp:40
FactoryPtr create_identity()
Definition: postproc.cpp:50
FactoryPtr create_culling(real_t eps)
Definition: postproc.cpp:54
float real_t
The default type for floating point values.
Definition: config.h:17
IdentityPostProc(const std::shared_ptr< objective::Objective > &)
Definition: postproc.cpp:15
void process(label_id_t label_id, Eigen::Ref< DenseRealVector > weight_vector, solvers::MinimizationResult &result) override
Apply post-processing for the weight_vector corresponding to the label label_id.
Definition: postproc.cpp:16