DiSMEC++
dismec::prediction::TopKPredictionTaskGenerator Class Reference

#include <prediction.h>

Inheritance diagram for dismec::prediction::TopKPredictionTaskGenerator:
dismec::prediction::TopKPredictionTaskGenerator derives from dismec::prediction::PredictionBase, which derives from dismec::parallel::TaskGenerator.

Public Member Functions

 TopKPredictionTaskGenerator (const DatasetBase *data, std::shared_ptr< const Model > model, long K)
 
void update_model (std::shared_ptr< const Model > model)
 
void run_tasks (long begin, long end, thread_id_t thread_id) override
 
long num_tasks () const override
 
void prepare (long num_threads, long chunk_size) override
 Called to notify the TaskGenerator about the number of threads.
 
void finalize () override
 Called after all threads have finished their tasks.
 
const PredictionMatrix & get_top_k_values () const
 
const IndexMatrix & get_top_k_indices () const
 
const std::array< std::int64_t, 4 > & get_confusion_matrix () const
 
- Public Member Functions inherited from dismec::prediction::PredictionBase
 PredictionBase (const DatasetBase *data, std::shared_ptr< const Model > model)
 Constructor, checks that data and model are compatible.
 
- Public Member Functions inherited from dismec::parallel::TaskGenerator
virtual ~TaskGenerator ()=default
 

Static Public Attributes

static constexpr const int TRUE_POSITIVES = 0
 
static constexpr const int FALSE_POSITIVES = 1
 
static constexpr const int TRUE_NEGATIVES = 2
 
static constexpr const int FALSE_NEGATIVES = 3
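These four constants are indices into the std::array<std::int64_t, 4> that get_confusion_matrix() returns. The sketch below assumes only that layout and derives the standard precision, recall, and accuracy ratios from it. This page does not say how run_tasks() decides what counts as a true or false positive for top-K predictions. The ConfusionMatrix alias and the precision(), recall(), and accuracy() functions are hypothetical helpers and are not part of DiSMEC++.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical local copies of the index constants; in DiSMEC++ they are
// TopKPredictionTaskGenerator::TRUE_POSITIVES and so on.
constexpr int TRUE_POSITIVES = 0;
constexpr int FALSE_POSITIVES = 1;
constexpr int TRUE_NEGATIVES = 2;
constexpr int FALSE_NEGATIVES = 3;

// Same shape as the value returned by get_confusion_matrix().
using ConfusionMatrix = std::array<std::int64_t, 4>;

// Precision = TP / (TP + FP); defined as 0 when nothing was predicted positive.
double precision(const ConfusionMatrix& cm) {
    const std::int64_t predicted = cm[TRUE_POSITIVES] + cm[FALSE_POSITIVES];
    return predicted == 0 ? 0.0
                          : static_cast<double>(cm[TRUE_POSITIVES]) / static_cast<double>(predicted);
}

// Recall = TP / (TP + FN); defined as 0 when there are no actual positives.
double recall(const ConfusionMatrix& cm) {
    const std::int64_t actual = cm[TRUE_POSITIVES] + cm[FALSE_NEGATIVES];
    return actual == 0 ? 0.0
                       : static_cast<double>(cm[TRUE_POSITIVES]) / static_cast<double>(actual);
}

// Accuracy = (TP + TN) / (TP + FP + TN + FN); requires at least one count.
double accuracy(const ConfusionMatrix& cm) {
    std::int64_t total = 0;
    for (std::int64_t count : cm) {
        total += count;
    }
    assert(total > 0 && "confusion matrix holds no counts");
    return static_cast<double>(cm[TRUE_POSITIVES] + cm[TRUE_NEGATIVES]) / static_cast<double>(total);
}
```

For example, with the counts {30, 10, 950, 10}, both precision and recall are 30 / 40 = 0.75, and accuracy is 980 / 1000 = 0.98.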
 

Private Attributes

long m_K
 
PredictionMatrix m_TopKValues
 
IndexMatrix m_TopKIndices
 
std::vector< PredictionMatrix > m_ThreadLocalPredictionCache
 
std::vector< PredictionMatrix > m_ThreadLocalTopKValues
 
std::vector< IndexMatrix > m_ThreadLocalTopKIndices
 
std::vector< std::array< std::int64_t, 4 > > m_ThreadLocalConfusionMatrix
 
std::vector< std::vector< long > > m_GroundTruth
 
std::array< std::int64_t, 4 > m_ConfusionMatrix
 

Additional Inherited Members

- Public Types inherited from dismec::parallel::TaskGenerator
using thread_id_t = dismec::parallel::thread_id_t
 
- Protected Member Functions inherited from dismec::prediction::PredictionBase
void make_thread_local_features (long num_threads)
 
void init_thread (thread_id_t thread_id) final
 Called once a thread has spun up, but before it runs its first task.
 
void do_prediction (long begin, long end, thread_id_t thread_id, Eigen::Ref< PredictionMatrix > target)
 Predicts the scores for a subset of the instances given by the half-open interval [begin, end).
 
- Protected Attributes inherited from dismec::prediction::PredictionBase
const DatasetBase * m_Data
 Data on which the prediction is run.
 
std::shared_ptr< const Model > m_Model
 Model (possibly partial) for which the prediction is run.
 

Detailed Description

Definition at line 84 of file prediction.h.

Constructor & Destructor Documentation

◆ TopKPredictionTaskGenerator()

TopKPredictionTaskGenerator::TopKPredictionTaskGenerator ( const DatasetBase * data,
std::shared_ptr< const Model > model,
long K 
)

Member Function Documentation

◆ finalize()

void TopKPredictionTaskGenerator::finalize ( )
override virtual

Called after all threads have finished their tasks.

This function is called from the main thread after all worker threads have finished their work. It can be used to perform single-threaded reductions or to clean up per-thread buffers.

Reimplemented from dismec::parallel::TaskGenerator.

Definition at line 124 of file prediction.cpp.

References m_ConfusionMatrix, m_ThreadLocalConfusionMatrix, and m_ThreadLocalPredictionCache.
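The references listed above suggest that finalize() combines m_ThreadLocalConfusionMatrix into m_ConfusionMatrix and also touches m_ThreadLocalPredictionCache. The following is a minimal sketch of that kind of single-threaded reduction. It assumes an element-wise sum and that the per-thread caches are simply released, and this page confirms neither detail. reduce_confusion() and release_thread_buffers() are hypothetical names.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using ConfusionMatrix = std::array<std::int64_t, 4>;

// Element-wise sum of each thread's counts into one global confusion matrix.
// Meant to run on the main thread after all workers have joined, so no
// synchronization is needed.
ConfusionMatrix reduce_confusion(const std::vector<ConfusionMatrix>& per_thread_counts) {
    ConfusionMatrix total{};  // value-initialized to all zeros
    for (const ConfusionMatrix& local : per_thread_counts) {
        for (std::size_t i = 0; i < total.size(); ++i) {
            assert(local[i] >= 0 && "counts must be non-negative");
            total[i] += local[i];
        }
    }
    return total;
}

// Assumption: per-thread scratch buffers are no longer needed after the
// reduction, so swapping with an empty vector returns their memory.
template <class Buffer>
void release_thread_buffers(std::vector<Buffer>& buffers) {
    std::vector<Buffer>().swap(buffers);
}
```

Because each worker writes only to its own slot during run_tasks(), the sum can be computed without locks once all threads have finished.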

◆ get_confusion_matrix()

const std::array<std::int64_t, 4>& dismec::prediction::TopKPredictionTaskGenerator::get_confusion_matrix ( ) const
inline

Definition at line 99 of file prediction.h.

References m_ConfusionMatrix.

◆ get_top_k_indices()

const IndexMatrix& dismec::prediction::TopKPredictionTaskGenerator::get_top_k_indices ( ) const
inline

Definition at line 97 of file prediction.h.

References m_TopKIndices.

◆ get_top_k_values()

const PredictionMatrix& dismec::prediction::TopKPredictionTaskGenerator::get_top_k_values ( ) const
inline

Definition at line 96 of file prediction.h.

References m_TopKValues.
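Neither getter documents how the top-K entries are chosen. The sketch below shows one common approach for a single instance's score vector: std::partial_sort applied over the label indices. The TopK struct and top_k() function are hypothetical. They use std::vector in place of the Eigen-based PredictionMatrix and IndexMatrix, break ties in favour of the lower label index, and assume the scores contain no NaN values.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical per-instance result: the k largest scores in descending order
// and the label indices they belong to.
struct TopK {
    std::vector<float> values;
    std::vector<long> indices;
};

// Selects the k highest-scoring labels of one row of scores.
TopK top_k(const std::vector<float>& scores, long k) {
    assert(k >= 0 && static_cast<std::size_t>(k) <= scores.size());

    // Label indices 0, 1, ..., n-1.
    std::vector<long> order(scores.size());
    std::iota(order.begin(), order.end(), 0L);

    // Strict weak ordering: higher score first, ties by lower label index.
    auto by_score = [&scores](long a, long b) {
        if (scores[a] != scores[b]) return scores[a] > scores[b];
        return a < b;
    };

    // Only the first k positions are sorted; the rest stay unordered.
    std::partial_sort(order.begin(), order.begin() + k, order.end(), by_score);

    TopK result;
    result.indices.assign(order.begin(), order.begin() + k);
    result.values.reserve(static_cast<std::size_t>(k));
    for (long idx : result.indices) {
        result.values.push_back(scores[idx]);
    }
    return result;
}
```

partial_sort costs roughly O(n log k) per instance. That is typically cheaper than a full sort when the number of labels is much larger than K.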

◆ num_tasks()

long TopKPredictionTaskGenerator::num_tasks ( ) const
override virtual

◆ prepare()

void TopKPredictionTaskGenerator::prepare ( long  num_threads,
long  chunk_size 
)
override virtual

Called to notify the TaskGenerator about the number of threads.

This function is called from the main thread before distributed work is started. It gives the TaskGenerator a chance to allocate working memory for each thread, so that these allocations need not be repeated in every call to run_tasks().

Note
For memory that is used inside the computations done by each thread, init_thread() should be used instead. It is called from inside the thread that will do the actual computations, so that on a NUMA system the first-touch policy has a chance to place the allocation on the memory node local to that thread. In that case, this function should only allocate an array of pointers that init_thread() will fill in.
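The prepare()/init_thread() split described in this note can be sketched with plain std::thread. This is a minimal sketch in which a std::vector<float> stands in for a working buffer. PerThreadBuffers and run_workers() are hypothetical and are not the TaskGenerator API. Whether the pages actually land on the local NUMA node depends on the operating system's placement policy.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <thread>
#include <vector>

// Hypothetical stand-in for a per-thread working buffer.
using Buffer = std::vector<float>;

class PerThreadBuffers {
public:
    // "prepare" phase, main thread: create one empty slot (a null pointer)
    // per thread; no buffer memory is allocated or touched here.
    void prepare(long num_threads) {
        assert(num_threads > 0);
        m_Buffers.clear();
        m_Buffers.resize(static_cast<std::size_t>(num_threads));
    }

    // "init_thread" phase, worker thread: allocate and zero-fill the buffer
    // from the thread that will use it, so first-touch can place its pages.
    // Each thread writes only its own slot, so there is no data race.
    void init_thread(long thread_id, std::size_t buffer_size) {
        m_Buffers.at(static_cast<std::size_t>(thread_id)) =
            std::make_unique<Buffer>(buffer_size, 0.0f);
    }

    Buffer* get(long thread_id) {
        return m_Buffers.at(static_cast<std::size_t>(thread_id)).get();
    }

private:
    std::vector<std::unique_ptr<Buffer>> m_Buffers;
};

// Runs both phases: prepare on the caller's thread, init on each worker.
void run_workers(PerThreadBuffers& buffers, long num_threads, std::size_t buffer_size) {
    buffers.prepare(num_threads);
    std::vector<std::thread> workers;
    for (long t = 0; t < num_threads; ++t) {
        workers.emplace_back([&buffers, t, buffer_size] { buffers.init_thread(t, buffer_size); });
    }
    for (std::thread& worker : workers) {
        worker.join();
    }
}
```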
Parameters
num_threads: Number of threads that will be used.
chunk_size: A hint for the size of the chunks used when running this task. If the total number of tasks is not a multiple of chunk_size, some calls to run_tasks() may receive fewer than chunk_size tasks.

Reimplemented from dismec::parallel::TaskGenerator.

Definition at line 103 of file prediction.cpp.

References m_K, dismec::prediction::PredictionBase::m_Model, m_ThreadLocalConfusionMatrix, m_ThreadLocalPredictionCache, m_ThreadLocalTopKIndices, m_ThreadLocalTopKValues, and dismec::prediction::PredictionBase::make_thread_local_features().
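The chunk_size note means that the task range is handed out as half-open intervals [begin, end), and the last interval may be shorter than the others. The hypothetical make_chunks() helper below illustrates the split. It is not the DiSMEC++ scheduler.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Splits [0, num_tasks) into half-open chunks of at most chunk_size tasks.
// Each pair (begin, end) corresponds to one run_tasks(begin, end, thread_id)
// call; the final chunk is shorter when num_tasks % chunk_size != 0.
std::vector<std::pair<long, long>> make_chunks(long num_tasks, long chunk_size) {
    assert(num_tasks >= 0 && chunk_size > 0);
    std::vector<std::pair<long, long>> chunks;
    for (long begin = 0; begin < num_tasks; begin += chunk_size) {
        const long end = std::min(begin + chunk_size, num_tasks);
        chunks.emplace_back(begin, end);
    }
    return chunks;
}
```

For num_tasks = 10 and chunk_size = 4, it produces [0, 4), [4, 8), and [8, 10), so the final run_tasks() call would receive only two tasks.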

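◆ run_tasks()

void TopKPredictionTaskGenerator::run_tasks ( long begin,
long end,
thread_id_t thread_id 
)
override virtual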

◆ update_model()

void TopKPredictionTaskGenerator::update_model ( std::shared_ptr< const Model > model )

Definition at line 210 of file prediction.cpp.

References dismec::prediction::PredictionBase::m_Model.

Member Data Documentation

◆ FALSE_NEGATIVES

constexpr const int dismec::prediction::TopKPredictionTaskGenerator::FALSE_NEGATIVES = 3
static constexpr

Definition at line 104 of file prediction.h.

Referenced by main(), and run_tasks().

◆ FALSE_POSITIVES

constexpr const int dismec::prediction::TopKPredictionTaskGenerator::FALSE_POSITIVES = 1
static constexpr

Definition at line 102 of file prediction.h.

Referenced by main(), and run_tasks().

◆ m_ConfusionMatrix

std::array<std::int64_t, 4> dismec::prediction::TopKPredictionTaskGenerator::m_ConfusionMatrix
private

Definition at line 117 of file prediction.h.

Referenced by finalize(), get_confusion_matrix(), and TopKPredictionTaskGenerator().

◆ m_GroundTruth

std::vector<std::vector<long> > dismec::prediction::TopKPredictionTaskGenerator::m_GroundTruth
private

Definition at line 116 of file prediction.h.

Referenced by run_tasks(), and TopKPredictionTaskGenerator().

◆ m_K

long dismec::prediction::TopKPredictionTaskGenerator::m_K
private

Definition at line 106 of file prediction.h.

Referenced by prepare(), run_tasks(), and TopKPredictionTaskGenerator().

◆ m_ThreadLocalConfusionMatrix

std::vector<std::array<std::int64_t, 4> > dismec::prediction::TopKPredictionTaskGenerator::m_ThreadLocalConfusionMatrix
private

Definition at line 114 of file prediction.h.

Referenced by finalize(), prepare(), and run_tasks().

◆ m_ThreadLocalPredictionCache

std::vector<PredictionMatrix> dismec::prediction::TopKPredictionTaskGenerator::m_ThreadLocalPredictionCache
private

Definition at line 111 of file prediction.h.

Referenced by finalize(), prepare(), and run_tasks().

◆ m_ThreadLocalTopKIndices

std::vector<IndexMatrix> dismec::prediction::TopKPredictionTaskGenerator::m_ThreadLocalTopKIndices
private

Definition at line 113 of file prediction.h.

Referenced by prepare(), and run_tasks().

◆ m_ThreadLocalTopKValues

std::vector<PredictionMatrix> dismec::prediction::TopKPredictionTaskGenerator::m_ThreadLocalTopKValues
private

Definition at line 112 of file prediction.h.

Referenced by prepare(), and run_tasks().

◆ m_TopKIndices

IndexMatrix dismec::prediction::TopKPredictionTaskGenerator::m_TopKIndices
private

Definition at line 109 of file prediction.h.

Referenced by get_top_k_indices(), run_tasks(), and TopKPredictionTaskGenerator().

◆ m_TopKValues

PredictionMatrix dismec::prediction::TopKPredictionTaskGenerator::m_TopKValues
private

Definition at line 108 of file prediction.h.

Referenced by get_top_k_values(), run_tasks(), and TopKPredictionTaskGenerator().

◆ TRUE_NEGATIVES

constexpr const int dismec::prediction::TopKPredictionTaskGenerator::TRUE_NEGATIVES = 2
static constexpr

Definition at line 103 of file prediction.h.

Referenced by main(), and run_tasks().

◆ TRUE_POSITIVES

constexpr const int dismec::prediction::TopKPredictionTaskGenerator::TRUE_POSITIVES = 0
static constexpr

Definition at line 101 of file prediction.h.

Referenced by main(), and run_tasks().


The documentation for this class was generated from the following files: