DiSMEC++
subset.h
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #ifndef DISMEC_SUBSET_H
7 #define DISMEC_SUBSET_H
8 
9 #include "training/initializer.h"
10 
11 namespace dismec::init {
13  public:
14  SubsetFeatureMeanInitializer(std::shared_ptr<const DatasetBase> data,
15  const DenseRealVector& mean_of_all,
16  std::shared_ptr<const GenericFeatureMatrix> local_features, real_t pos, real_t neg);
17 
18  protected:
19  std::shared_ptr<const DatasetBase> m_DataSet;
20  std::shared_ptr<const GenericFeatureMatrix> m_LocalFeatures;
21 
25 
28 
29  static constexpr stats::stat_id_t STAT_DURATION{0};
30 
31  std::pair<real_t, real_t> calculate_factors(
32  label_id_t label_id,
33  const Eigen::Ref<DenseRealVector>& mean_of_positives);
34  };
35 
37  public:
38  SubsetFeatureMeanStrategy(std::shared_ptr<const DatasetBase> data, real_t negative_target,
39  real_t positive_target);
40 
41  protected:
42  std::shared_ptr<const DatasetBase> m_DataSet;
46  };
47 }
48 
49 #endif //DISMEC_SUBSET_H
std::shared_ptr< const DatasetBase > m_DataSet
Definition: subset.h:19
std::shared_ptr< const GenericFeatureMatrix > m_LocalFeatures
Definition: subset.h:20
std::pair< real_t, real_t > calculate_factors(label_id_t label_id, const Eigen::Ref< DenseRealVector > &mean_of_positives)
Definition: subset.cpp:43
static constexpr stats::stat_id_t STAT_DURATION
Definition: subset.h:29
SubsetFeatureMeanInitializer(std::shared_ptr< const DatasetBase > data, const DenseRealVector &mean_of_all, std::shared_ptr< const GenericFeatureMatrix > local_features, real_t pos, real_t neg)
Definition: subset.cpp:18
SubsetFeatureMeanStrategy(std::shared_ptr< const DatasetBase > data, real_t negative_target, real_t positive_target)
Definition: subset.cpp:76
std::shared_ptr< const DatasetBase > m_DataSet
Definition: subset.h:42
DenseRealVector m_MeanOfAllInstances
Definition: subset.h:43
Base class for all weight initialization strategies.
Definition: initializer.h:53
Base class for all weight initializers.
Definition: initializer.h:30
Strong typedef for an int to signify a label id.
Definition: types.h:20
types::DenseVector< std::int8_t > BinaryLabelVector
Dense vector for storing binary labels.
Definition: matrix_types.h:68
types::DenseVector< real_t > DenseRealVector
Any dense, real-valued vector.
Definition: matrix_types.h:40
float real_t
The default type for floating-point values.
Definition: config.h:17