DiSMEC++
#include <prediction.h>
Public Member Functions
  TopKPredictionTaskGenerator (const DatasetBase *data, std::shared_ptr< const Model > model, long K)
  void update_model (std::shared_ptr< const Model > model)
  void run_tasks (long begin, long end, thread_id_t thread_id) override
  long num_tasks () const override
  void prepare (long num_threads, long chunk_size) override
      Called to notify the TaskGenerator about the number of threads.
  void finalize () override
      Called after all threads have finished their tasks.
  const PredictionMatrix & get_top_k_values () const
  const IndexMatrix & get_top_k_indices () const
  const std::array< std::int64_t, 4 > & get_confusion_matrix () const

Public Member Functions inherited from dismec::prediction::PredictionBase
  PredictionBase (const DatasetBase *data, std::shared_ptr< const Model > model)
      Constructor; checks that data and model are compatible.

Public Member Functions inherited from dismec::parallel::TaskGenerator
  virtual ~TaskGenerator ()=default

Static Public Attributes
  static constexpr const int TRUE_POSITIVES = 0
  static constexpr const int FALSE_POSITIVES = 1
  static constexpr const int TRUE_NEGATIVES = 2
  static constexpr const int FALSE_NEGATIVES = 3

Private Attributes
  long m_K
  PredictionMatrix m_TopKValues
  IndexMatrix m_TopKIndices
  std::vector< PredictionMatrix > m_ThreadLocalPredictionCache
  std::vector< PredictionMatrix > m_ThreadLocalTopKValues
  std::vector< IndexMatrix > m_ThreadLocalTopKIndices
  std::vector< std::array< std::int64_t, 4 > > m_ThreadLocalConfusionMatrix
  std::vector< std::vector< long > > m_GroundTruth
  std::array< std::int64_t, 4 > m_ConfusionMatrix

Additional Inherited Members

Public Types inherited from dismec::parallel::TaskGenerator
  using thread_id_t = dismec::parallel::thread_id_t

Protected Member Functions inherited from dismec::prediction::PredictionBase
  void make_thread_local_features (long num_threads)
  void init_thread (thread_id_t thread_id) final
      Called once a thread has spun up, but before it runs its first task.
  void do_prediction (long begin, long end, thread_id_t thread_id, Eigen::Ref< PredictionMatrix > target)
      Predicts the scores for a subset of the instances given by the half-open interval [begin, end).

Protected Attributes inherited from dismec::prediction::PredictionBase
  const DatasetBase * m_Data
      Data on which the prediction is run.
  std::shared_ptr< const Model > m_Model
      Model (possibly partial) for which prediction is run.

Detailed Description
Definition at line 84 of file prediction.h.
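The class plugs into the TaskGenerator protocol described on this page: prepare() runs on the main thread before any work, init_thread() runs once per worker before its first task, run_tasks() receives half-open chunks [begin, end) of the num_tasks() tasks, and finalize() runs on the main thread afterwards. The sketch below models only that call order. SketchTaskGenerator, run_sequentially and CountingGenerator are hypothetical names, thread_id_t is simplified to long, and the chunks are processed on the calling thread with round-robin thread ids; this is not DiSMEC++'s own parallel runner.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the TaskGenerator protocol; thread ids are plain longs here.
struct SketchTaskGenerator {
    virtual ~SketchTaskGenerator() = default;
    virtual long num_tasks() const = 0;
    virtual void prepare(long /*num_threads*/, long /*chunk_size*/) {}
    virtual void init_thread(long /*thread_id*/) {}
    virtual void run_tasks(long begin, long end, long thread_id) = 0;
    virtual void finalize() {}
};

// Hypothetical driver: calls the hooks in the documented order, but runs every chunk
// on the calling thread and hands out thread ids round-robin instead of using workers.
inline void run_sequentially(SketchTaskGenerator& gen, long num_threads, long chunk_size) {
    assert(num_threads > 0 && chunk_size > 0);
    gen.prepare(num_threads, chunk_size);                       // main thread, before any work
    for (long t = 0; t < num_threads; ++t) gen.init_thread(t);  // would run inside each worker
    const long total = gen.num_tasks();
    long thread_id = 0;
    for (long begin = 0; begin < total; begin += chunk_size) {
        const long end = std::min(begin + chunk_size, total);   // last chunk may be shorter
        gen.run_tasks(begin, end, thread_id);                   // half-open range [begin, end)
        thread_id = (thread_id + 1) % num_threads;
    }
    gen.finalize();                                             // main thread, after all work
}

// Toy generator: counts tasks per thread in run_tasks() and reduces them in finalize().
struct CountingGenerator : SketchTaskGenerator {
    explicit CountingGenerator(long n) : m_N(n) {}
    long num_tasks() const override { return m_N; }
    void prepare(long num_threads, long /*chunk_size*/) override {
        per_thread.assign(static_cast<std::size_t>(num_threads), 0L);
    }
    void run_tasks(long begin, long end, long thread_id) override {
        per_thread.at(static_cast<std::size_t>(thread_id)) += end - begin;
    }
    void finalize() override {
        total = 0;
        for (long count : per_thread) total += count;
    }
    std::vector<long> per_thread;
    long total = 0;
private:
    long m_N;
};
```

With 10 tasks, 3 threads and a chunk size of 4, the last call receives [8, 10), i.e. fewer than chunk_size tasks, which is the situation the chunk_size parameter description of prepare() warns about.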

Constructor & Destructor Documentation

TopKPredictionTaskGenerator::TopKPredictionTaskGenerator (const DatasetBase *data, std::shared_ptr< const Model > model, long K)
Definition at line 80 of file prediction.cpp.
References m_ConfusionMatrix, m_GroundTruth, m_K, m_TopKIndices, m_TopKValues, dismec::DatasetBase::num_examples(), dismec::DatasetBase::num_labels(), and dismec::opaque_int_type< Tag, T >::to_index().
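The constructor references num_examples(), num_labels() and to_index() together with m_TopKValues, m_TopKIndices and m_GroundTruth (a std::vector< std::vector< long > >), which suggests it sizes the result storage and builds the ground truth by iterating over labels. Assuming m_GroundTruth holds one list of label indices per example, and that the dataset can list the examples carrying each label, the inversion could be sketched as below. LabelInstances and per_example_ground_truth are hypothetical names, not DiSMEC++ API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical input: for each label index, the indices of the examples that carry it.
using LabelInstances = std::vector<std::vector<long>>;

// Hypothetical helper: invert the per-label lists into per-example label lists.
// Labels are visited in ascending order, so each example's list comes out sorted.
inline std::vector<std::vector<long>> per_example_ground_truth(const LabelInstances& by_label,
                                                               long num_examples) {
    assert(num_examples >= 0);
    std::vector<std::vector<long>> truth(static_cast<std::size_t>(num_examples));
    for (long label = 0; label < static_cast<long>(by_label.size()); ++label) {
        for (long example : by_label[static_cast<std::size_t>(label)]) {
            assert(0 <= example && example < num_examples);
            truth[static_cast<std::size_t>(example)].push_back(label);
        }
    }
    return truth;
}
```

Keeping each per-example list sorted is a design choice that allows a binary search when predictions are later compared against the true labels.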

Member Function Documentation

void TopKPredictionTaskGenerator::finalize ( )  [override, virtual]
Called after all threads have finished their tasks.
This function is called from the main thread after all worker threads have finished their work. It can be used to perform single-threaded reductions or to clean up per-thread buffers.
Reimplemented from dismec::parallel::TaskGenerator.
Definition at line 124 of file prediction.cpp.
References m_ConfusionMatrix, m_ThreadLocalConfusionMatrix, and m_ThreadLocalPredictionCache.
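finalize() references m_ConfusionMatrix and m_ThreadLocalConfusionMatrix, which fits the single-threaded reduction described above: each worker counts into its own array without locking, and the main thread sums the arrays once every worker is done. A minimal sketch of such a reduction, where reduce_confusion is a hypothetical helper rather than the actual implementation:

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Confusion counts, indexed by TRUE_POSITIVES (0), FALSE_POSITIVES (1),
// TRUE_NEGATIVES (2) and FALSE_NEGATIVES (3).
using Confusion = std::array<std::int64_t, 4>;

// Hypothetical helper: sums the per-thread counters into a single matrix.
// It runs on one thread after all workers are done, so no synchronization is needed.
inline Confusion reduce_confusion(const std::vector<Confusion>& thread_local_counts) {
    Confusion total{};  // value-initialization sets all four counters to zero
    for (const Confusion& local : thread_local_counts) {
        for (std::size_t i = 0; i < total.size(); ++i) {
            assert(local[i] >= 0);  // counts can never be negative
            total[i] += local[i];
        }
    }
    return total;
}
```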

const std::array< std::int64_t, 4 > & TopKPredictionTaskGenerator::get_confusion_matrix ( ) const  [inline]
Definition at line 99 of file prediction.h.
References m_ConfusionMatrix.

const IndexMatrix & TopKPredictionTaskGenerator::get_top_k_indices ( ) const  [inline]
Definition at line 97 of file prediction.h.
References m_TopKIndices.

const PredictionMatrix & TopKPredictionTaskGenerator::get_top_k_values ( ) const  [inline]
Definition at line 96 of file prediction.h.
References m_TopKValues.

long TopKPredictionTaskGenerator::num_tasks ( ) const  [override, virtual]
Implements dismec::parallel::TaskGenerator.
Definition at line 99 of file prediction.cpp.
References dismec::prediction::PredictionBase::m_Data, and dismec::DatasetBase::num_examples().

void TopKPredictionTaskGenerator::prepare (long num_threads, long chunk_size)  [override, virtual]
Called to notify the TaskGenerator about the number of threads.
This function is called from the main thread, before distributed work is started. It gives the TaskGenerator a chance to allocate working memory for each thread, so these allocations don't need to be done and repeated in run_tasks().
If the memory should be allocated close to the worker thread that uses it, init_thread() should be used instead. It is called from inside the thread that will do the actual computations, so that on a NUMA system the first-touch policy has a chance to place the allocation in the correct RAM. In that case, this function should only allocate an array of pointers that will be filled in by init_thread().
Parameters
  num_threads  Number of threads that will be used.
  chunk_size   A hint for the size of chunks used when running this task. If the total number of tasks is not a multiple of chunk_size, some calls to run_tasks() may receive fewer than chunk_size tasks.
Reimplemented from dismec::parallel::TaskGenerator.
Definition at line 103 of file prediction.cpp.
References m_K, dismec::prediction::PredictionBase::m_Model, m_ThreadLocalConfusionMatrix, m_ThreadLocalPredictionCache, m_ThreadLocalTopKIndices, m_ThreadLocalTopKValues, and dismec::prediction::PredictionBase::make_thread_local_features().
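The NUMA-friendly allocation pattern described for prepare() and init_thread() can be sketched as follows. BufferOwner and ThreadBuffer are hypothetical, and sizing each buffer as chunk_size * num_labels scores is an assumption about what a per-thread prediction cache needs; the point is that prepare() only creates empty pointer slots on the main thread, while each worker allocates and first-touches its own buffer in init_thread().

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical per-thread working buffer.
struct ThreadBuffer {
    std::vector<double> scores;
};

class BufferOwner {
public:
    // Main thread: remember the sizes and create one empty (null) slot per thread.
    void prepare(long num_threads, long chunk_size, long num_labels) {
        assert(num_threads > 0 && chunk_size > 0 && num_labels > 0);
        m_ChunkSize = chunk_size;
        m_NumLabels = num_labels;
        m_Buffers.clear();
        m_Buffers.resize(static_cast<std::size_t>(num_threads));  // no large allocation yet
    }
    // Worker thread: allocate this thread's buffer; assign() writes every element,
    // which is the "first touch" that lets the OS place the pages near this thread.
    void init_thread(long thread_id) {
        auto buffer = std::make_unique<ThreadBuffer>();
        buffer->scores.assign(static_cast<std::size_t>(m_ChunkSize * m_NumLabels), 0.0);
        m_Buffers.at(static_cast<std::size_t>(thread_id)) = std::move(buffer);
    }
    bool is_initialized(long thread_id) const {
        return m_Buffers.at(static_cast<std::size_t>(thread_id)) != nullptr;
    }
    ThreadBuffer& buffer(long thread_id) { return *m_Buffers.at(static_cast<std::size_t>(thread_id)); }

private:
    long m_ChunkSize = 0;
    long m_NumLabels = 0;
    std::vector<std::unique_ptr<ThreadBuffer>> m_Buffers;
};
```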

void TopKPredictionTaskGenerator::run_tasks (long begin, long end, thread_id_t thread_id)  [override, virtual]
Implements dismec::parallel::TaskGenerator.
Definition at line 133 of file prediction.cpp.
References dismec::prediction::PredictionBase::do_prediction(), FALSE_NEGATIVES, FALSE_POSITIVES, m_GroundTruth, m_K, dismec::prediction::PredictionBase::m_Model, m_ThreadLocalConfusionMatrix, m_ThreadLocalPredictionCache, m_ThreadLocalTopKIndices, m_ThreadLocalTopKValues, m_TopKIndices, m_TopKValues, dismec::opaque_int_type< Tag, T >::to_index(), TRUE_NEGATIVES, and TRUE_POSITIVES.
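run_tasks() references do_prediction(), the thread-local prediction cache, the thread-local top-K buffers and confusion counters, m_GroundTruth and the four confusion-matrix indices. A plausible per-example step is to select the K highest-scoring labels and compare them with the true labels. The sketch below does this for a single row of scores using std::partial_sort; treating exactly the top-K labels as the positive predictions is an assumption, and top_k_indices and add_confusion are hypothetical helpers.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Index constants with the values listed on this page.
constexpr int TRUE_POSITIVES = 0;
constexpr int FALSE_POSITIVES = 1;
constexpr int TRUE_NEGATIVES = 2;
constexpr int FALSE_NEGATIVES = 3;

// Hypothetical helper: label indices of the k largest scores, best first.
// Ties are broken by the smaller label index so the result is deterministic.
inline std::vector<long> top_k_indices(const std::vector<double>& scores, long k) {
    assert(k >= 0 && k <= static_cast<long>(scores.size()));
    std::vector<long> idx(scores.size());
    std::iota(idx.begin(), idx.end(), 0L);
    // Only the first k positions end up sorted; the rest stay in unspecified order.
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(), [&scores](long a, long b) {
        return scores[a] > scores[b] || (scores[a] == scores[b] && a < b);
    });
    idx.resize(static_cast<std::size_t>(k));
    return idx;
}

// Hypothetical helper: adds one example's counts to cm. `truth` must be sorted
// ascending because membership is tested with binary_search.
inline void add_confusion(const std::vector<long>& predicted, const std::vector<long>& truth,
                          long num_labels, std::array<std::int64_t, 4>& cm) {
    std::int64_t tp = 0;
    for (long label : predicted) {
        if (std::binary_search(truth.begin(), truth.end(), label)) ++tp;
    }
    const std::int64_t fp = static_cast<std::int64_t>(predicted.size()) - tp;
    const std::int64_t fn = static_cast<std::int64_t>(truth.size()) - tp;
    cm[TRUE_POSITIVES] += tp;
    cm[FALSE_POSITIVES] += fp;
    cm[FALSE_NEGATIVES] += fn;
    cm[TRUE_NEGATIVES] += num_labels - tp - fp - fn;  // labels neither predicted nor relevant
}
```

For scores {0.1, 0.9, 0.5, 0.7}, K = 2 and true labels {1, 2}, the top-2 labels are {1, 3}: label 1 is a true positive, label 3 a false positive, label 2 a false negative and label 0 a true negative.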

void TopKPredictionTaskGenerator::update_model (std::shared_ptr< const Model > model)
Definition at line 210 of file prediction.cpp.
References dismec::prediction::PredictionBase::m_Model.

Member Data Documentation

constexpr const int TopKPredictionTaskGenerator::FALSE_NEGATIVES = 3  [static, constexpr]
Definition at line 104 of file prediction.h.
Referenced by main(), and run_tasks().

constexpr const int TopKPredictionTaskGenerator::FALSE_POSITIVES = 1  [static, constexpr]
Definition at line 102 of file prediction.h.
Referenced by main(), and run_tasks().

std::array< std::int64_t, 4 > TopKPredictionTaskGenerator::m_ConfusionMatrix  [private]
Definition at line 117 of file prediction.h.
Referenced by finalize(), get_confusion_matrix(), and TopKPredictionTaskGenerator().

std::vector< std::vector< long > > TopKPredictionTaskGenerator::m_GroundTruth  [private]
Definition at line 116 of file prediction.h.
Referenced by run_tasks(), and TopKPredictionTaskGenerator().

long TopKPredictionTaskGenerator::m_K  [private]
Definition at line 106 of file prediction.h.
Referenced by prepare(), run_tasks(), and TopKPredictionTaskGenerator().
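
std::vector< std::array< std::int64_t, 4 > > TopKPredictionTaskGenerator::m_ThreadLocalConfusionMatrix  [private]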
Definition at line 114 of file prediction.h.
Referenced by finalize(), prepare(), and run_tasks().

std::vector< PredictionMatrix > TopKPredictionTaskGenerator::m_ThreadLocalPredictionCache  [private]
Definition at line 111 of file prediction.h.
Referenced by finalize(), prepare(), and run_tasks().

std::vector< IndexMatrix > TopKPredictionTaskGenerator::m_ThreadLocalTopKIndices  [private]
Definition at line 113 of file prediction.h.
Referenced by prepare(), and run_tasks().

std::vector< PredictionMatrix > TopKPredictionTaskGenerator::m_ThreadLocalTopKValues  [private]
Definition at line 112 of file prediction.h.
Referenced by prepare(), and run_tasks().

IndexMatrix TopKPredictionTaskGenerator::m_TopKIndices  [private]
Definition at line 109 of file prediction.h.
Referenced by get_top_k_indices(), run_tasks(), and TopKPredictionTaskGenerator().

PredictionMatrix TopKPredictionTaskGenerator::m_TopKValues  [private]
Definition at line 108 of file prediction.h.
Referenced by get_top_k_values(), run_tasks(), and TopKPredictionTaskGenerator().

constexpr const int TopKPredictionTaskGenerator::TRUE_NEGATIVES = 2  [static, constexpr]
Definition at line 103 of file prediction.h.
Referenced by main(), and run_tasks().

constexpr const int TopKPredictionTaskGenerator::TRUE_POSITIVES = 0  [static, constexpr]
Definition at line 101 of file prediction.h.
Referenced by main(), and run_tasks().
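The four static constants are indices into the std::array< std::int64_t, 4 > returned by get_confusion_matrix(). As an illustration of how a caller might use them, micro-averaged precision and recall can be derived from the pooled counts as below; precision() and recall() are hypothetical helpers and not part of the class.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Index constants with the values listed on this page.
constexpr int TRUE_POSITIVES = 0;
constexpr int FALSE_POSITIVES = 1;
constexpr int TRUE_NEGATIVES = 2;
constexpr int FALSE_NEGATIVES = 3;

using Confusion = std::array<std::int64_t, 4>;

// Hypothetical helper: TP / (TP + FP), or 0 if nothing was predicted.
inline double precision(const Confusion& cm) {
    const std::int64_t predicted = cm[TRUE_POSITIVES] + cm[FALSE_POSITIVES];
    assert(predicted >= 0);
    return predicted == 0 ? 0.0
                          : static_cast<double>(cm[TRUE_POSITIVES]) / static_cast<double>(predicted);
}

// Hypothetical helper: TP / (TP + FN), or 0 if there are no relevant labels.
inline double recall(const Confusion& cm) {
    const std::int64_t relevant = cm[TRUE_POSITIVES] + cm[FALSE_NEGATIVES];
    assert(relevant >= 0);
    return relevant == 0 ? 0.0
                         : static_cast<double>(cm[TRUE_POSITIVES]) / static_cast<double>(relevant);
}
```

Returning 0 for an empty denominator is one common convention; other code may prefer to report such cases separately.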