6 #ifndef DISMEC_POSTPROCESSING_H
7 #define DISMEC_POSTPROCESSING_H
/// Pure-virtual factory-style member: returns an owning pointer to a
/// `MetricCollectionInterface`. Being `= 0`, it has no implementation here and
/// must be provided by concrete subclasses.
///
/// The member is `const`, so it can be called on a const object, and it is
/// marked `[[nodiscard]]`, so silently dropping the returned pointer is
/// diagnosed by the compiler.
///
/// @return A `std::unique_ptr<MetricCollectionInterface>` owned by the caller.
///
/// NOTE(review): judging by the name, the result is presumably a copy of
/// `*this` -- confirm against the implementations.
42 [[nodiscard]]
virtual std::unique_ptr<MetricCollectionInterface>
clone()
const = 0;
/// Override of the virtual `clone()` declared by a base class (the enclosing
/// class is not visible in this excerpt). Declaration only; the definition
/// lives elsewhere.
///
/// @return A `std::unique_ptr<MetricCollectionInterface>` owned by the caller;
///         `[[nodiscard]]` flags call sites that ignore it.
///
/// NOTE(review): presumably returns a copy of this object -- confirm in the
/// corresponding source file.
52 [[nodiscard]] std::unique_ptr<MetricCollectionInterface>
clone()
const override;
54 [[nodiscard]]
long get_k()
const {
return m_K; }
67 [[nodiscard]]
double value()
const {
/// Override of the base-class virtual `clone()` for the enclosing class (not
/// visible in this excerpt). Only declared here.
///
/// @return An owning `std::unique_ptr<MetricCollectionInterface>`; the
///         `[[nodiscard]]` attribute makes discarding it a diagnostic.
///
/// NOTE(review): the returned object is presumably a copy of `*this` --
/// verify against the definition.
84 [[nodiscard]] std::unique_ptr<MetricCollectionInterface>
clone()
const override;
/// `clone()` override for the enclosing class (class header outside this
/// excerpt). This is a declaration; the body is defined elsewhere.
///
/// @return An owning `std::unique_ptr<MetricCollectionInterface>`. Marked
///         `[[nodiscard]]`, so the caller is expected to take the pointer.
///
/// NOTE(review): presumably yields a copy of this object -- confirm in the
/// implementation.
96 [[nodiscard]] std::unique_ptr<MetricCollectionInterface>
clone()
const override;
/// Pure-virtual accessor returning a list of `metric_t` entries. Declared
/// `= 0`, so concrete subclasses must implement it.
///
/// The member is `const` and `[[nodiscard]]`: it can be called on a const
/// object, and ignoring the returned vector is diagnosed.
///
/// @return A `std::vector<metric_t>` returned by value.
///
/// NOTE(review): elsewhere in this listing `metric_t` appears as
/// `std::pair<std::string, double>`, i.e. presumably (metric name, value)
/// pairs -- confirm against the typedef.
106 [[nodiscard]]
virtual std::vector<metric_t>
get_values()
const = 0;
/// Override of the base-class virtual `get_values()` for the enclosing class
/// (not visible in this excerpt). Declaration only.
///
/// @return A `std::vector<metric_t>` by value; `[[nodiscard]]` flags callers
///         that drop it.
///
/// NOTE(review): which entries are produced, and in what order, is determined
/// by the definition, which is not shown here.
112 [[nodiscard]] std::vector<metric_t>
get_values()
const override;
/// `get_values()` override for the enclosing class (class header outside this
/// excerpt). Only declared here; the body is defined elsewhere.
///
/// @return A `std::vector<metric_t>` by value. The `[[nodiscard]]` attribute
///         makes ignoring the result a diagnostic.
///
/// NOTE(review): the contents of the returned vector depend on the
/// implementation, which is not visible here.
121 [[nodiscard]] std::vector<metric_t>
get_values()
const override;
/// Declares `add_coverage`, a member that returns nothing and takes a numeric
/// threshold plus an optional name.
///
/// @param threshold A `double` threshold.
///        NOTE(review): its exact meaning is not visible in this excerpt --
///        presumably the cut-off used when computing coverage; confirm.
/// @param name Passed by value; defaults to an empty string (`{}`) when
///        omitted. NOTE(review): presumably the implementation substitutes a
///        generated name when this is empty -- verify against the definition.
127 void add_coverage(
double threshold, std::string name={});
Strong typedef for an int to signify a label id.
void update(const pd_info_vec &prediction, const gt_info_vec &labels) override
std::unique_ptr< MetricCollectionInterface > clone() const override
AbandonmentAtK(long num_labels, long k)
void reduce(const MetricCollectionInterface &other) override
ConfusionMatrixRecorder(long num_labels, long k)
std::unique_ptr< MetricCollectionInterface > clone() const override
std::vector< ConfusionMatrix > m_Confusion
void update(const pd_info_vec &prediction, const gt_info_vec &labels) override
ConfusionMatrix get_confusion_matrix(label_id_t label) const
InstanceAveragedMetric(long num_labels)
KahanAccumulator< double > m_Accumulator
void reduce(const MetricCollectionInterface &other) override
void accumulate(double value)
std::unique_ptr< MetricCollectionInterface > clone() const override
void update(const pd_info_vec &prediction, const gt_info_vec &labels) override
std::vector< double > m_Cumulative
InstanceRankedPositives(long num_labels, long k, bool normalize=false)
std::vector< double > m_Weights
std::vector< metric_t > get_values() const override
InstanceWiseMetricReporter(std::string name, const InstanceAveragedMetric *metric)
const InstanceAveragedMetric * m_Metric
std::function< double(const ConfusionMatrix &)> reduction_fn
void add_coverage(double threshold, std::string name={})
void add_negative_predictive_value(ReductionType reduction=MACRO, std::string name={})
MacroMetricReporter(const ConfusionMatrixRecorder *confusion)
void add_fowlkes_mallows(ReductionType reduction=MACRO, std::string name={})
void add_specificity(ReductionType reduction=MACRO, std::string name={})
void add_f_measure(ReductionType reduction=MACRO, double beta=1.0, std::string name={})
void add_informedness(ReductionType reduction=MACRO, std::string name={})
void add_accuracy(ReductionType reduction=MACRO, std::string name={})
void add_negative_likelihood_ratio(ReductionType reduction=MACRO, std::string name={})
void add_confusion_matrix()
void add_recall(ReductionType reduction=MACRO, std::string name={})
std::vector< metric_t > get_values() const override
void add_diagnostic_odds_ratio(ReductionType reduction=MACRO, std::string name={})
void add_markedness(ReductionType reduction=MACRO, std::string name={})
void add_positive_likelihood_ratio(ReductionType reduction=MACRO, std::string name={})
const ConfusionMatrixRecorder * m_ConfusionMatrix
void add_matthews(ReductionType reduction=MACRO, std::string name={})
void add_reduction(std::string name, ReductionType type, std::function< double(const ConfusionMatrix &)>)
std::vector< std::pair< std::string, reduction_fn > > m_MacroReductions
std::vector< std::pair< std::string, reduction_fn > > m_MicroReductions
void add_precision(ReductionType reduction=MACRO, std::string name={})
void add_reduction_helper(std::string name, const char *pattern, ReductionType type, std::function< double(const ConfusionMatrix &)> fn)
void add_balanced_accuracy(ReductionType reduction=MACRO, std::string name={})
Base class for all metrics that can be calculated during the evaluation phase.
virtual ~MetricCollectionInterface()=default
long num_labels() const
Gets the number of labels.
virtual std::unique_ptr< MetricCollectionInterface > clone() const =0
virtual void update(const pd_info_vec &prediction, const gt_info_vec &labels)=0
std::vector< sPredLabelInfo > pd_info_vec
std::vector< sTrueLabelInfo > gt_info_vec
virtual void reduce(const MetricCollectionInterface &other)=0
MetricCollectionInterface(long num_labels)
virtual std::vector< metric_t > get_values() const =0
virtual ~MetricReportInterface()=default
std::pair< std::string, double > metric_t
ConfusionMatrixBase< long > ConfusionMatrix