6 #ifndef DISMEC_PREDICTION_H
7 #define DISMEC_PREDICTION_H
75 void prepare(
long num_threads,
long chunk_size)
override;
77 [[nodiscard]]
long num_tasks()
const override;
92 [[nodiscard]]
long num_tasks()
const override;
93 void prepare(
long num_threads,
long chunk_size)
override;
Helper class to ensure that each NUMA node has its own copy of some immutable data.
Base class for all parallelized operations.
Strong typedef for an int to signify a thread id.
long num_tasks() const override
void run_tasks(long begin, long end, thread_id_t thread_id) override
PredictionMatrix m_Predictions
const PredictionMatrix & get_predictions() const
FullPredictionTaskGenerator(const DatasetBase *data, std::shared_ptr< const Model > model)
void prepare(long num_threads, long chunk_size) override
Called to notify the TaskGenerator about the number of threads.
Base class for handling predictions.
void do_prediction(long begin, long end, thread_id_t thread_id, Eigen::Ref< PredictionMatrix > target)
Predicts the scores for a subset of the instances given by the half-open interval [begin,...
void init_thread(thread_id_t thread_id) final
Called once a thread has spun up, but before it runs its first task.
PredictionBase(const DatasetBase *data, std::shared_ptr< const Model > model)
Constructor, checks that data and model are compatible.
std::vector< std::shared_ptr< const GenericFeatureMatrix > > m_ThreadLocalFeatures
std::shared_ptr< const Model > m_Model
Model (possibly partial) for which prediction is run.
const DatasetBase * m_Data
Data on which the prediction is run.
parallel::NUMAReplicator< const GenericFeatureMatrix > m_FeatureReplicator
The NUMAReplicator that generates NUMA-local copies for the feature matrices.
void make_thread_local_features(long num_threads)
static constexpr const int TRUE_POSITIVES
const IndexMatrix & get_top_k_indices() const
IndexMatrix m_TopKIndices
void finalize() override
Called after all threads have finished their tasks.
std::vector< PredictionMatrix > m_ThreadLocalTopKValues
std::vector< PredictionMatrix > m_ThreadLocalPredictionCache
void update_model(std::shared_ptr< const Model > model)
const PredictionMatrix & get_top_k_values() const
std::vector< std::vector< long > > m_GroundTruth
std::vector< IndexMatrix > m_ThreadLocalTopKIndices
static constexpr const int TRUE_NEGATIVES
TopKPredictionTaskGenerator(const DatasetBase *data, std::shared_ptr< const Model > model, long K)
static constexpr const int FALSE_POSITIVES
const std::array< std::int64_t, 4 > & get_confusion_matrix() const
static constexpr const int FALSE_NEGATIVES
void run_tasks(long begin, long end, thread_id_t thread_id) override
long num_tasks() const override
void prepare(long num_threads, long chunk_size) override
Called to notify the TaskGenerator about the number of threads.
std::array< std::int64_t, 4 > m_ConfusionMatrix
PredictionMatrix m_TopKValues
std::vector< std::array< std::int64_t, 4 > > m_ThreadLocalConfusionMatrix
types::DenseRowMajor< long > IndexMatrix
Matrix used for indices in sparse predictions.
types::DenseRowMajor< real_t > PredictionMatrix
Dense matrix in Row Major format used for predictions.