DiSMEC++
metrics.h
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #ifndef DISMEC_POSTPROCESSING_H
7 #define DISMEC_POSTPROCESSING_H
8 
9 #include <vector>
10 #include <atomic>
11 #include "matrix_types.h"
12 #include "data/types.h"
13 #include "evaluate.h"
14 #include "utils/sum.h"
15 #include "utils/confusion_matrix.h"
16 
17 namespace dismec::prediction {
19 
29  public:
30  using gt_info_vec = std::vector<sTrueLabelInfo>;
31  using pd_info_vec = std::vector<sPredLabelInfo>;
32 
36  virtual ~MetricCollectionInterface() = default;
38  [[nodiscard]] long num_labels() const { return m_NumLabels; }
39 
40  virtual void update(const pd_info_vec& prediction, const gt_info_vec& labels) = 0;
41  virtual void reduce(const MetricCollectionInterface& other) = 0;
42  [[nodiscard]] virtual std::unique_ptr<MetricCollectionInterface> clone() const = 0;
43  private:
45  };
46 
48  public:
49  ConfusionMatrixRecorder(long num_labels, long k);
50  void update(const pd_info_vec& prediction, const gt_info_vec& labels) override;
51  void reduce(const MetricCollectionInterface& other) override;
52  [[nodiscard]] std::unique_ptr<MetricCollectionInterface> clone() const override;
53 
54  [[nodiscard]] long get_k() const { return m_K; }
55  [[nodiscard]] ConfusionMatrix get_confusion_matrix(label_id_t label) const;
56  private:
57  long m_K;
58  long m_InstanceCount = 0;
59  std::vector<ConfusionMatrix> m_Confusion;
60  };
61 
63  public:
64  explicit InstanceAveragedMetric(long num_labels);
65  void reduce(const MetricCollectionInterface& other) override;
66 
67  [[nodiscard]] double value() const {
68  if(m_NumSamples == 0) return 0.0;
69  return m_Accumulator.value() / static_cast<double>(m_NumSamples);
70  }
71  protected:
72  void accumulate(double value);
73  private:
75  long m_NumSamples = 0;
76  };
77 
78 
80  public:
81  InstanceRankedPositives(long num_labels, long k, bool normalize=false);
82  InstanceRankedPositives(long num_labels, long k, bool normalize, std::vector<double> weights);
83  void update(const pd_info_vec& prediction, const gt_info_vec& labels) override;
84  [[nodiscard]] std::unique_ptr<MetricCollectionInterface> clone() const override;
85  private:
86  long m_K;
88  std::vector<double> m_Weights;
89  std::vector<double> m_Cumulative;
90  };
91 
93  public:
94  explicit AbandonmentAtK(long num_labels, long k);
95  void update(const pd_info_vec& prediction, const gt_info_vec& labels) override;
96  [[nodiscard]] std::unique_ptr<MetricCollectionInterface> clone() const override;
97  private:
98  long m_K;
99  };
100 
102  public:
103  virtual ~MetricReportInterface() = default;
104 
105  using metric_t = std::pair<std::string, double>;
106  [[nodiscard]] virtual std::vector<metric_t> get_values() const = 0;
107  };
108 
110  public:
111  InstanceWiseMetricReporter(std::string name, const InstanceAveragedMetric* metric);
112  [[nodiscard]] std::vector<metric_t> get_values() const override;
113  private:
114  std::string m_Name;
116  };
117 
119  public:
120  explicit MacroMetricReporter(const ConfusionMatrixRecorder* confusion);
121  [[nodiscard]] std::vector<metric_t> get_values() const override;
122 
125  };
126 
127  void add_coverage(double threshold, std::string name={});
128  void add_precision(ReductionType reduction = MACRO, std::string name={});
129  void add_accuracy(ReductionType reduction = MACRO, std::string name={});
130  void add_specificity(ReductionType reduction = MACRO, std::string name={});
131  void add_balanced_accuracy(ReductionType reduction = MACRO, std::string name={});
132  void add_informedness(ReductionType reduction = MACRO, std::string name={});
133  void add_markedness(ReductionType reduction = MACRO, std::string name={});
134  void add_recall(ReductionType reduction = MACRO, std::string name={});
135  void add_fowlkes_mallows(ReductionType reduction = MACRO, std::string name={});
136  void add_negative_predictive_value(ReductionType reduction = MACRO, std::string name={});
137  void add_matthews(ReductionType reduction = MACRO, std::string name={});
138  void add_positive_likelihood_ratio(ReductionType reduction = MACRO, std::string name={});
139  void add_negative_likelihood_ratio(ReductionType reduction = MACRO, std::string name={});
140  void add_diagnostic_odds_ratio(ReductionType reduction = MACRO, std::string name={});
141  void add_f_measure(ReductionType reduction = MACRO, double beta = 1.0, std::string name={});
142  void add_confusion_matrix();
143 
144  void add_reduction(std::string name, ReductionType type, std::function<double(const ConfusionMatrix&)>);
145  private:
146  void add_reduction_helper(std::string name, const char* pattern, ReductionType type,
147  std::function<double(const ConfusionMatrix&)> fn);
148  using reduction_fn = std::function<double(const ConfusionMatrix&)>;
149  std::vector<std::pair<std::string, reduction_fn>> m_MacroReductions;
150  std::vector<std::pair<std::string, reduction_fn>> m_MicroReductions;
152  };
153 }
154 
155 #endif //DISMEC_POSTPROCESSING_H
Float value() const
Definition: sum.h:22
Strong typedef for an int to signify a label id.
Definition: types.h:20
void update(const pd_info_vec &prediction, const gt_info_vec &labels) override
Definition: metrics.cpp:150
std::unique_ptr< MetricCollectionInterface > clone() const override
Definition: metrics.cpp:162
AbandonmentAtK(long num_labels, long k)
Definition: metrics.cpp:147
void reduce(const MetricCollectionInterface &other) override
Definition: metrics.cpp:46
ConfusionMatrixRecorder(long num_labels, long k)
Definition: metrics.cpp:25
std::unique_ptr< MetricCollectionInterface > clone() const override
Definition: metrics.cpp:65
std::vector< ConfusionMatrix > m_Confusion
Definition: metrics.h:59
void update(const pd_info_vec &prediction, const gt_info_vec &labels) override
Definition: metrics.cpp:29
ConfusionMatrix get_confusion_matrix(label_id_t label) const
Definition: metrics.cpp:58
KahanAccumulator< double > m_Accumulator
Definition: metrics.h:74
void reduce(const MetricCollectionInterface &other) override
Definition: metrics.cpp:79
std::unique_ptr< MetricCollectionInterface > clone() const override
Definition: metrics.cpp:139
void update(const pd_info_vec &prediction, const gt_info_vec &labels) override
Definition: metrics.cpp:119
std::vector< double > m_Cumulative
Definition: metrics.h:89
InstanceRankedPositives(long num_labels, long k, bool normalize=false)
Definition: metrics.cpp:100
std::vector< metric_t > get_values() const override
Definition: metrics.cpp:175
InstanceWiseMetricReporter(std::string name, const InstanceAveragedMetric *metric)
Definition: metrics.cpp:170
const InstanceAveragedMetric * m_Metric
Definition: metrics.h:115
std::function< double(const ConfusionMatrix &)> reduction_fn
Definition: metrics.h:148
void add_coverage(double threshold, std::string name={})
Definition: metrics.cpp:179
void add_negative_predictive_value(ReductionType reduction=MACRO, std::string name={})
MacroMetricReporter(const ConfusionMatrixRecorder *confusion)
Definition: metrics.cpp:264
void add_fowlkes_mallows(ReductionType reduction=MACRO, std::string name={})
void add_specificity(ReductionType reduction=MACRO, std::string name={})
void add_f_measure(ReductionType reduction=MACRO, double beta=1.0, std::string name={})
Definition: metrics.cpp:225
void add_informedness(ReductionType reduction=MACRO, std::string name={})
void add_accuracy(ReductionType reduction=MACRO, std::string name={})
void add_negative_likelihood_ratio(ReductionType reduction=MACRO, std::string name={})
void add_recall(ReductionType reduction=MACRO, std::string name={})
std::vector< metric_t > get_values() const override
Definition: metrics.cpp:270
void add_diagnostic_odds_ratio(ReductionType reduction=MACRO, std::string name={})
void add_markedness(ReductionType reduction=MACRO, std::string name={})
void add_positive_likelihood_ratio(ReductionType reduction=MACRO, std::string name={})
const ConfusionMatrixRecorder * m_ConfusionMatrix
Definition: metrics.h:151
void add_matthews(ReductionType reduction=MACRO, std::string name={})
void add_reduction(std::string name, ReductionType type, std::function< double(const ConfusionMatrix &)>)
Definition: metrics.cpp:256
std::vector< std::pair< std::string, reduction_fn > > m_MacroReductions
Definition: metrics.h:149
std::vector< std::pair< std::string, reduction_fn > > m_MicroReductions
Definition: metrics.h:150
void add_precision(ReductionType reduction=MACRO, std::string name={})
void add_reduction_helper(std::string name, const char *pattern, ReductionType type, std::function< double(const ConfusionMatrix &)> fn)
Definition: metrics.cpp:248
void add_balanced_accuracy(ReductionType reduction=MACRO, std::string name={})
Base class for all metrics that can be calculated during the evaluation phase.
Definition: metrics.h:28
long num_labels() const
Gets the number of labels.
Definition: metrics.h:38
virtual std::unique_ptr< MetricCollectionInterface > clone() const =0
virtual void update(const pd_info_vec &prediction, const gt_info_vec &labels)=0
std::vector< sPredLabelInfo > pd_info_vec
Definition: metrics.h:31
std::vector< sTrueLabelInfo > gt_info_vec
Definition: metrics.h:30
virtual void reduce(const MetricCollectionInterface &other)=0
virtual std::vector< metric_t > get_values() const =0
std::pair< std::string, double > metric_t
Definition: metrics.h:105
ConfusionMatrixBase< long > ConfusionMatrix
Definition: metrics.h:18