DiSMEC++
sparsify.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
#include <cmath>
#include <utility>

#include "training/postproc.h"
#include "data/types.h"
#include "solver/minimizer.h"
#include "utils/hash_vector.h"
#include "stats/collection.h"
#include "stats/timer.h"
15 
16 namespace {
18  constexpr stat_id_t STAT_CUTOFF{0};
19  constexpr stat_id_t STAT_NNZ{1};
22  constexpr stat_id_t STAT_DURATION{4};
23 };
24 
25 
26 namespace dismec::postproc {
27  class Sparsify : public PostProcessor {
28  public:
29  Sparsify(std::shared_ptr<objective::Objective> objective, real_t tolerance) :
30  m_Objective(std::move(objective)),
31  m_Tolerance(tolerance),
32  m_WorkingVector(DenseRealVector(m_Objective->num_variables())) {
33 
34  declare_stat(STAT_CUTOFF, {"cutoff", {}});
35  declare_stat(STAT_NNZ, {"nnz", "%"});
36  declare_stat(STAT_BINARY_SEARCH_STEPS, {"binary_search_steps", {}});
37  declare_stat(STAT_INITIAL_STEPS, {"initial_steps", {}});
38  declare_stat(STAT_DURATION, {"duration", "µs"});
39  }
40  private:
41  void process(label_id_t label_id, Eigen::Ref<DenseRealVector> weight_vector, solvers::MinimizationResult& result) override;
42 
43  std::shared_ptr<objective::Objective> m_Objective;
46 
47  static int make_sparse(Eigen::Ref<DenseRealVector> target, const Eigen::Ref<const DenseRealVector>& source, real_t cutoff) {
48  int nnz = 0;
49  for(int i = 0; i < target.size(); ++i) {
50  auto w_i = source.coeff(i);
51  bool is_small = abs(w_i) < cutoff;
52  target.coeffRef(i) = is_small ? 0 : w_i;
53  if(!is_small) ++nnz;
54  }
55  return nnz;
56  }
57 
58  struct BoundData {
60  long NNZ;
62  };
63 
67  };
68 
69 
70 
71  UpperBoundResult find_initial_bounds(Eigen::Ref<DenseRealVector> weight_vector, real_t tolerance, real_t initial_lower);
72 
74  real_t m_SumLogVal = std::log(0.02);
75  real_t m_SumSqrLog = std::log(0.02) * std::log(0.02);
76  };
77 
78 
79 
80  void Sparsify::process(label_id_t label_id, Eigen::Ref<DenseRealVector> weight_vector, solvers::MinimizationResult& result) {
81  auto timer = make_timer(STAT_DURATION);
82  m_WorkingVector = weight_vector;
83  real_t tolerance = (1 + m_Tolerance) * result.FinalValue + real_t{1e-5};
84 
85  auto [lower, upper] = find_initial_bounds(weight_vector, tolerance, result.FinalValue);
86 
87  // now we can do a binary search
88  int count = 0;
89  while( (lower.NNZ - upper.NNZ) > upper.NNZ / 10 + 1 ) {
90  real_t middle = (upper.Cutoff + lower.Cutoff) / 2;
91  int nnz = make_sparse(m_WorkingVector.modify(), weight_vector, middle);
92  auto new_score = m_Objective->value(m_WorkingVector);
93  if(new_score > tolerance) {
94  upper.Cutoff = middle;
95  upper.NNZ = nnz;
96  upper.Loss = new_score;
97  } else {
98  lower.Cutoff = middle;
99  lower.NNZ = nnz;
100  lower.Loss = new_score;
101  }
102  ++count;
103  }
104  record(STAT_BINARY_SEARCH_STEPS, count);
105 
106  // finally, apply the culling to the actual weight vector
107  int nnz = make_sparse(weight_vector, weight_vector, lower.Cutoff);
108 
109  m_NumValues += 1;
110  real_t log_cutoff = std::log(lower.Cutoff);
111  m_SumLogVal += log_cutoff;
112  m_SumSqrLog += log_cutoff*log_cutoff;
113 
114  record(STAT_CUTOFF, lower.Cutoff);
115  record(STAT_NNZ, float(100 * nnz) / weight_vector.size());
116  }
117 
118  Sparsify::UpperBoundResult Sparsify::find_initial_bounds(Eigen::Ref<DenseRealVector> weight_vector, real_t tolerance, real_t initial_lower)
119  {
120  real_t mean_log = m_SumLogVal / m_NumValues;
121  real_t std_log = std::sqrt(m_SumSqrLog / m_NumValues - mean_log*mean_log + real_t{1e-5});
122 
123  int step_count = 0;
124 
125  auto check_bound = [&](real_t log_cutoff) {
126  real_t cutoff = std::exp(log_cutoff);
127  int nnz = make_sparse(m_WorkingVector.modify(), weight_vector, cutoff);
128  auto score = m_Objective->value(m_WorkingVector);
129  ++step_count;
130  return BoundData{cutoff, nnz, score};
131  };
132 
133  // we assume that [exp(mean_log - 2std_var), exp(mean_log + 2std_var)] is a good interval
134  auto at_mean = check_bound( mean_log );
135  if(at_mean.Loss > tolerance) {
136  // ok, mean is an upper bound
137  // let's try the lower bound then
138  BoundData minus_std = check_bound(mean_log - std_log);
139  if(minus_std.Loss > tolerance) {
140  record(STAT_INITIAL_STEPS, step_count);
141  return {{0, weight_vector.size(), initial_lower}, minus_std};
142  }
143  record(STAT_INITIAL_STEPS, step_count);
144  return {minus_std, at_mean};
145  }
146 
147  // ok, mean is a lower bound
148  BoundData plus_std = check_bound(mean_log + std_log);
149  if(plus_std.Loss > tolerance) {
150  record(STAT_INITIAL_STEPS, step_count);
151  return {at_mean, plus_std};
152  }
153 
154  // one more naive trial:
155  BoundData plus_3_std = check_bound(mean_log + 3 * std_log);
156  if(plus_3_std.Loss > tolerance) {
157  record(STAT_INITIAL_STEPS, step_count);
158  return {plus_std, plus_3_std};
159  }
160 
161  BoundData at_max = check_bound( std::log(weight_vector.maxCoeff()) );
162  record(STAT_INITIAL_STEPS, step_count);
163  return {plus_3_std, at_max};
164  }
165 }
166 
167 std::shared_ptr<dismec::postproc::PostProcessFactory> dismec::postproc::create_sparsify(real_t tolerance) {
168  return std::make_shared<GenericPostProcFactory<Sparsify, real_t>>(tolerance);
169 }
An Eigen vector with versioning information, to implement simple caching of results.
Definition: hash_vector.h:43
Strong typedef for an int to signify a label id.
Definition: types.h:20
void process(label_id_t label_id, Eigen::Ref< DenseRealVector > weight_vector, solvers::MinimizationResult &result) override
Apply post-processing for the weight_vector corresponding to the label label_id.
Definition: sparsify.cpp:80
UpperBoundResult find_initial_bounds(Eigen::Ref< DenseRealVector > weight_vector, real_t tolerance, real_t initial_lower)
Definition: sparsify.cpp:118
std::shared_ptr< objective::Objective > m_Objective
Definition: sparsify.cpp:43
static int make_sparse(Eigen::Ref< DenseRealVector > target, const Eigen::Ref< const DenseRealVector > &source, real_t cutoff)
Definition: sparsify.cpp:47
Sparsify(std::shared_ptr< objective::Objective > objective, real_t tolerance)
Definition: sparsify.cpp:29
auto make_timer(stat_id_t id, Args... args)
Creates a new ScopeTimer using stats::record_scope_time.
Definition: tracked.h:130
void declare_stat(stat_id_t index, StatisticMetaData meta)
Declares a new statistic. This function just forwards all its arguments to the internal StatisticsCollection.
Definition: tracked.cpp:16
constexpr stat_id_t STAT_DURATION
Definition: sparsify.cpp:22
constexpr stat_id_t STAT_CUTOFF
Definition: sparsify.cpp:18
constexpr stat_id_t STAT_BINARY_SEARCH_STEPS
Definition: sparsify.cpp:20
constexpr stat_id_t STAT_NNZ
Definition: sparsify.cpp:19
constexpr stat_id_t STAT_INITIAL_STEPS
Definition: sparsify.cpp:21
FactoryPtr create_sparsify(real_t tolerance)
Definition: sparsify.cpp:167
opaque_int_type< detail::stat_id_tag > stat_id_t
An opaque int-like type that is used to identify a statistic in a StatisticsCollection.
Definition: stat_id.h:24
types::DenseVector< real_t > DenseRealVector
Any dense, real-valued vector.
Definition: matrix_types.h:40
float real_t
The default type for floating point values.
Definition: config.h:17
float real_t
Definition: regularizers.h:11