DiSMEC++
subset.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #include "subset.h"
7 #include "stats/collection.h"
8 #include "stats/timer.h"
9 #include "data/types.h"
10 #include "data/data.h"
11 #include "data/transform.h"
12 #include "objective/objective.h"
13 #include "utils/hash_vector.h"
14 #include <limits>
15 
16 using namespace dismec::init;
17 
19  std::shared_ptr<const DatasetBase> data,
20  const DenseRealVector& mean_of_all,
21  std::shared_ptr<const GenericFeatureMatrix> local_features,
22  real_t pos, real_t neg) :
23  m_DataSet(std::move(data)), m_LocalFeatures(std::move(local_features)),
24  m_MeanOfAll(DenseRealVector::Zero(1)), m_PosTarget(pos), m_NegTarget(neg)
25 {
26  if(!m_DataSet) {
27  throw std::logic_error("dataset is <null>");
28  }
29  if(!m_LocalFeatures) {
30  throw std::logic_error("local features are <null>");
31  }
32 
33  m_LabelBuffer.resize(m_DataSet->num_examples());
34  m_MeanOfAll = mean_of_all;
35 
36  m_MeanAllNormSquared = m_MeanOfAll.squaredNorm();
37 
38  declare_stat(STAT_DURATION, {"duration", "µs"});
39 }
40 
41 
42 
44  label_id_t label_id,
45  const Eigen::Ref<DenseRealVector>& mean_of_positives)
46 {
47 
48  real_t num_pos = m_DataSet->num_positives(label_id);
49  real_t PP = mean_of_positives.squaredNorm();
50  real_t PA = mean_of_positives.dot(m_MeanOfAll);
51  real_t p = num_pos / m_DataSet->num_examples();
52 
53  real_t divide = PA*PA - PP * m_MeanAllNormSquared;
54  // TODO spend some more time thinking about numerical stability here
55  if(std::abs(PA) < std::numeric_limits<real_t>::epsilon() ) {
56  if(std::abs(PA) < std::numeric_limits<real_t>::epsilon() ) {
57  return {real_t{0}, real_t{-1.f}};
58  }
59  return {m_PosTarget / PP, 0};
60  }
61 
62  // not sure under which situations this may happen, so we're just going with a simple heuristic here
63  if(std::abs(divide) < std::numeric_limits<real_t>::epsilon()) {
64  spdlog::warn("Cannot use initialization procedure, mean vectors are not linearly independent.");
65  return {real_t{0}, real_t{-1.f}};
66  }
67 
68  // otherwise, do a real calculation
70  real_t u = (f * PA - m_PosTarget * m_MeanAllNormSquared) / divide;
71  real_t v = (m_PosTarget - u * PP) / PA;
72 
73  return {u, v};
74 }
75 
/**
 * @brief Constructs the feature-mean initialization strategy.
 * @param data            Dataset the strategy works on; must not be null.
 * @param positive_target Stored as m_PositiveTarget.
 * @param negative_target Stored as m_NegativeTarget.
 * @throws std::logic_error if `data` is null.
 *
 * NOTE(review): the generated documentation lists this constructor as
 * (data, negative_target, positive_target), presumably taken from the declaration in subset.h.
 * This definition takes (data, positive_target, negative_target). Parameter names are not checked
 * between declaration and definition, so the order below is what takes effect. Callers reading the
 * header could pass the two targets swapped. Align the header with this definition (or the reverse,
 * after checking every call site).
 */
76 SubsetFeatureMeanStrategy::SubsetFeatureMeanStrategy(std::shared_ptr<const DatasetBase> data, real_t positive_target,
77  real_t negative_target) :
78  m_DataSet(std::move(data)),
79  m_NegativeTarget(negative_target),
80  m_PositiveTarget(positive_target) {
81  if(!m_DataSet) {
82  throw std::logic_error("dataset is <null>");
83  }
// NOTE(review): source line 84 is missing from this listing. The cross-references for this file
// mention m_MeanOfAllInstances and get_mean_feature, which appear nowhere else in the visible code,
// so that line presumably initializes m_MeanOfAllInstances via get_mean_feature -- confirm against
// the real subset.cpp.
85 }
std::shared_ptr< const DatasetBase > m_DataSet
Definition: subset.h:19
std::shared_ptr< const GenericFeatureMatrix > m_LocalFeatures
Definition: subset.h:20
std::pair< real_t, real_t > calculate_factors(label_id_t label_id, const Eigen::Ref< DenseRealVector > &mean_of_positives)
Definition: subset.cpp:43
static constexpr stats::stat_id_t STAT_DURATION
Definition: subset.h:29
SubsetFeatureMeanInitializer(std::shared_ptr< const DatasetBase > data, const DenseRealVector &mean_of_all, std::shared_ptr< const GenericFeatureMatrix > local_features, real_t pos, real_t neg)
Definition: subset.cpp:18
SubsetFeatureMeanStrategy(std::shared_ptr< const DatasetBase > data, real_t negative_target, real_t positive_target)
Definition: subset.cpp:76
std::shared_ptr< const DatasetBase > m_DataSet
Definition: subset.h:42
DenseRealVector m_MeanOfAllInstances
Definition: subset.h:43
Strong typedef for an int to signify a label id.
Definition: types.h:20
void declare_stat(stat_id_t index, StatisticMetaData meta)
Declares a new statistic. This function just forwards all its arguments to the internal StatisticsCo...
Definition: tracked.cpp:16
types::DenseVector< real_t > DenseRealVector
Any dense, real-valued vector.
Definition: matrix_types.h:40
DenseRealVector get_mean_feature(const GenericFeatureMatrix &features)
Definition: transform.cpp:52
float real_t
The default type for floating point values.
Definition: config.h:17