6 #ifndef DISMEC_SRC_PREDICTION_EVALUATE_H
7 #define DISMEC_SRC_PREDICTION_EVALUATE_H
36 using LabelList = std::vector<std::vector<label_id_t>>;
41 void prepare(
long num_threads,
long chunk_size)
override;
45 void add_dcg_at_k(
long k,
bool normalize, std::string name = {});
49 [[nodiscard]] std::vector<std::pair<std::string, double>>
get_metrics()
const;
56 [[nodiscard]]
long num_tasks()
const override;
58 using prediction_t = Eigen::Ref<const Eigen::Matrix<long, 1, Eigen::Dynamic>>;
60 std::vector<sTrueLabelInfo>& proc_labels, std::vector<sPredLabelInfo>& proc_pred);
67 std::vector<std::vector<std::unique_ptr<MetricCollectionInterface>>>
m_Collectors;
68 std::vector<std::unique_ptr<MetricReportInterface>>
m_Metrics;
Strong typedef for an int to signify a label id.
Base class for all parallelized operations.
dismec::parallel::thread_id_t thread_id_t
Strong typedef for an int to signify a thread id.
This TaskGenerator enables the calculation of evaluation metrics on top-k style sparse predictions.
const IndexMatrix * m_Predictions
void add_dcg_at_k(long k, bool normalize, std::string name={})
MacroMetricReporter * add_macro_at_k(long k)
std::vector< std::vector< sTrueLabelInfo > > m_ThreadLocalTrueLabels
void prepare(long num_threads, long chunk_size) override
Called to notify the TaskGenerator about the number of threads.
~EvaluateMetrics() override
void add_precision_at_k(long k, std::string name={})
EvaluateMetrics(const LabelList *sparse_labels, const IndexMatrix *sparse_predictions, long num_labels)
void finalize() override
Called after all threads have finished their tasks.
long num_tasks() const override
std::vector< std::vector< std::unique_ptr< MetricCollectionInterface > > > m_Collectors
Eigen::Ref< const Eigen::Matrix< long, 1, Eigen::Dynamic > > prediction_t
void run_tasks(long begin, long end, thread_id_t thread_id) override
std::vector< std::unique_ptr< MetricReportInterface > > m_Metrics
std::vector< std::pair< std::string, double > > get_metrics() const
const LabelList * m_Labels
void add_abandonment_at_k(long k, std::string name={})
void init_thread(thread_id_t thread_id) override
Called once a thread has spun up, but before it runs its first task.
static void process_prediction(const std::vector< label_id_t > &raw_labels, const prediction_t &raw_prediction, std::vector< sTrueLabelInfo > &proc_labels, std::vector< sPredLabelInfo > &proc_pred)
std::vector< std::vector< sPredLabelInfo > > m_ThreadLocalPredictedLabels
std::vector< std::vector< label_id_t > > LabelList
void run_task(long task_id, thread_id_t thread_id)
Base class for all metrics that can be calculated during the evaluation phase.
types::DenseRowMajor< long > IndexMatrix
Matrix used for indices in sparse predictions.