DiSMEC++
#include <prediction.h>
Public Member Functions
FullPredictionTaskGenerator (const DatasetBase *data, std::shared_ptr< const Model > model)
void prepare (long num_threads, long chunk_size) override
    Called to notify the TaskGenerator about the number of threads.
void run_tasks (long begin, long end, thread_id_t thread_id) override
long num_tasks () const override
const PredictionMatrix & get_predictions () const
Public Member Functions inherited from dismec::prediction::PredictionBase
PredictionBase (const DatasetBase *data, std::shared_ptr< const Model > model)
    Constructor, checks that data and model are compatible.
Public Member Functions inherited from dismec::parallel::TaskGenerator
virtual ~TaskGenerator ()=default
virtual void finalize ()
    Called after all threads have finished their tasks.
Private Attributes
PredictionMatrix m_Predictions
Additional Inherited Members
using thread_id_t = dismec::parallel::thread_id_t
Protected Member Functions inherited from dismec::prediction::PredictionBase
void make_thread_local_features (long num_threads)
void init_thread (thread_id_t thread_id) final
    Called once a thread has spun up, but before it runs its first task.
void do_prediction (long begin, long end, thread_id_t thread_id, Eigen::Ref< PredictionMatrix > target)
    Predicts the scores for a subset of the instances given by the half-open interval [begin, end).
Protected Attributes inherited from dismec::prediction::PredictionBase
const DatasetBase * m_Data
    Data on which the prediction is run.
std::shared_ptr< const Model > m_Model
    Model (possibly partial) for which prediction is run.
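The hooks listed above imply a fixed call order: prepare() on the main thread, init_thread() once inside each worker, run_tasks() on chunks of the half-open range [0, num_tasks()), and finally finalize() once all workers are done. The following is a minimal sketch of a driver that follows this order. It assumes a simplified, hypothetical SketchTaskGenerator interface, a run_generator() function and a plain integer thread_id_t; none of these are part of DiSMEC++, and the library's own parallel runner may schedule chunks differently.

```cpp
#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstddef>
#include <thread>
#include <vector>

// Hypothetical stand-in for dismec::parallel::thread_id_t.
using thread_id_t = long;

// Hypothetical interface mirroring the hooks listed above; not the DiSMEC++ declaration.
class SketchTaskGenerator {
public:
    virtual ~SketchTaskGenerator() = default;
    virtual long num_tasks() const = 0;
    virtual void run_tasks(long begin, long end, thread_id_t thread_id) = 0;
    virtual void prepare(long /*num_threads*/, long /*chunk_size*/) {}
    virtual void init_thread(thread_id_t /*thread_id*/) {}
    virtual void finalize() {}
};

// prepare() runs on the calling thread; each worker then calls init_thread()
// once and keeps claiming chunks of [0, num_tasks()) until none are left;
// finalize() runs after every worker has been joined.
void run_generator(SketchTaskGenerator& gen, long num_threads, long chunk_size) {
    assert(num_threads > 0 && chunk_size > 0);  // chunk_size == 0 would never advance
    gen.prepare(num_threads, chunk_size);
    const long total = gen.num_tasks();
    std::atomic<long> next{0};  // start index of the next unclaimed chunk

    std::vector<std::thread> workers;
    workers.reserve(static_cast<std::size_t>(num_threads));
    for (long t = 0; t < num_threads; ++t) {
        workers.emplace_back([&gen, &next, total, chunk_size, t]() {
            gen.init_thread(t);
            while (true) {
                const long begin = next.fetch_add(chunk_size);
                if (begin >= total) break;
                // The last chunk may contain fewer than chunk_size tasks.
                const long end = std::min(begin + chunk_size, total);
                gen.run_tasks(begin, end, t);
            }
        });
    }
    for (auto& worker : workers) worker.join();
    gen.finalize();
}
```

Claiming chunks through a shared atomic counter gives simple dynamic load balancing: a thread that finishes early just takes the next chunk. With 103 tasks and a chunk_size of 10, for example, the final chunk is [100, 103).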
Definition at line 70 of file prediction.h.
FullPredictionTaskGenerator::FullPredictionTaskGenerator (const DatasetBase * data, std::shared_ptr< const Model > model)
Definition at line 59 of file prediction.cpp.
References m_Predictions, dismec::DatasetBase::num_examples(), and dismec::DatasetBase::num_labels().
const PredictionMatrix & FullPredictionTaskGenerator::get_predictions () const [inline]
Definition at line 79 of file prediction.h.
References m_Predictions.
long FullPredictionTaskGenerator::num_tasks () const [override, virtual]
Implements dismec::parallel::TaskGenerator.
Definition at line 65 of file prediction.cpp.
References dismec::prediction::PredictionBase::m_Data, and dismec::DatasetBase::num_examples().
void FullPredictionTaskGenerator::prepare (long num_threads, long chunk_size) [override, virtual]
Called to notify the TaskGenerator about the number of threads.
This function is called from the main thread, before distributed work is started. It gives the TaskGenerator a chance to allocate working memory for each thread, so that these allocations don't need to be repeated in every call to run_tasks().
If the memory should be allocated by the thread that will use it, init_thread() should be used instead. It is called from inside the thread that does the actual computations, so that on a NUMA system the first-touch policy has a chance to place the allocation in the correct RAM. In that case, this function should only allocate an array of pointers that will be filled in by init_thread().
Parameters
    num_threads  Number of threads that will be used.
    chunk_size   A hint for the size of the chunks used when running this task. If the total number of tasks is not a multiple of chunk_size, some calls to run_tasks() may receive fewer than chunk_size tasks.
Reimplemented from dismec::parallel::TaskGenerator.
Definition at line 75 of file prediction.cpp.
References dismec::prediction::PredictionBase::make_thread_local_features().
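The split between prepare() and init_thread() described above can be sketched as follows. BufferedGenerator, its buffer() accessor and the launch_init_threads() helper are hypothetical names chosen for illustration, and thread_id_t is assumed to be a plain integer. Whether the pages actually end up on the local NUMA node depends on the operating system's first-touch policy.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <thread>
#include <vector>

using thread_id_t = long;  // hypothetical stand-in for dismec::parallel::thread_id_t

// Hypothetical generator: prepare() only sizes an array of empty pointers,
// while init_thread(), executed on the worker thread, allocates and first
// writes that thread's buffer.
class BufferedGenerator {
public:
    explicit BufferedGenerator(std::size_t buffer_size) : m_BufferSize(buffer_size) {}

    // Main thread: no per-thread working memory is touched here.
    void prepare(long num_threads, long /*chunk_size*/) {
        m_Buffers.clear();
        m_Buffers.resize(static_cast<std::size_t>(num_threads));  // all nullptr
    }

    // Worker thread: value-initialising the vector writes every element, so
    // under first-touch the pages are placed close to this thread.
    void init_thread(thread_id_t thread_id) {
        assert(thread_id >= 0 && static_cast<std::size_t>(thread_id) < m_Buffers.size());
        m_Buffers[static_cast<std::size_t>(thread_id)] =
            std::make_unique<std::vector<double>>(m_BufferSize, 0.0);
    }

    std::vector<double>* buffer(thread_id_t thread_id) const {
        return m_Buffers[static_cast<std::size_t>(thread_id)].get();
    }

private:
    std::size_t m_BufferSize;
    std::vector<std::unique_ptr<std::vector<double>>> m_Buffers;
};

// Hypothetical helper: calls init_thread() on one fresh thread per id.
// A real runner would keep each thread alive to run its tasks afterwards.
void launch_init_threads(BufferedGenerator& gen, long num_threads) {
    std::vector<std::thread> workers;
    for (long t = 0; t < num_threads; ++t)
        workers.emplace_back([&gen, t]() { gen.init_thread(t); });
    for (auto& worker : workers) worker.join();
}
```

Each worker writes only its own slot of m_Buffers, so concurrent init_thread() calls need no locking.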
void FullPredictionTaskGenerator::run_tasks (long begin, long end, thread_id_t thread_id) [override, virtual]
Implements dismec::parallel::TaskGenerator.
Definition at line 70 of file prediction.cpp.
References dismec::prediction::PredictionBase::do_prediction(), and m_Predictions.
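run_tasks() references do_prediction() and m_Predictions, and the constructor references num_examples() and num_labels(). One plausible reading is that m_Predictions holds one row per example and one column per label, and that each call fills rows [begin, end). The sketch below illustrates that reading. DenseMatrix and FullPredictionSketch are hypothetical simplifications (dense features, no Eigen, a purely linear score), not the DiSMEC++ implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using thread_id_t = long;  // hypothetical stand-in for dismec::parallel::thread_id_t

// Hypothetical dense row-major matrix standing in for PredictionMatrix.
struct DenseMatrix {
    long rows;
    long cols;
    std::vector<double> values;
    DenseMatrix(long r, long c)
        : rows(r), cols(c), values(static_cast<std::size_t>(r * c), 0.0) {}
    double& at(long r, long c) { return values[static_cast<std::size_t>(r * cols + c)]; }
    double at(long r, long c) const { return values[static_cast<std::size_t>(r * cols + c)]; }
};

// Hypothetical full-prediction generator for a linear model: the score of
// label l for example i is the dot product of feature row i with weight column l.
class FullPredictionSketch {
public:
    FullPredictionSketch(const DenseMatrix& features, const DenseMatrix& weights)
        : m_Features(features), m_Weights(weights),
          m_Predictions(features.rows, weights.cols) {  // examples x labels
        assert(features.cols == weights.rows);
    }

    long num_tasks() const { return m_Features.rows; }  // one task per example

    // Writes only rows [begin, end) of m_Predictions.
    void run_tasks(long begin, long end, thread_id_t /*thread_id*/) {
        for (long i = begin; i < end; ++i) {
            for (long l = 0; l < m_Weights.cols; ++l) {
                double score = 0.0;
                for (long k = 0; k < m_Features.cols; ++k)
                    score += m_Features.at(i, k) * m_Weights.at(k, l);
                m_Predictions.at(i, l) = score;
            }
        }
    }

    const DenseMatrix& get_predictions() const { return m_Predictions; }

private:
    DenseMatrix m_Features;
    DenseMatrix m_Weights;
    DenseMatrix m_Predictions;
};
```

Because each call writes only its own rows, calls on disjoint ranges from different threads modify distinct elements and do not race, which is what lets a single shared result matrix be filled in parallel.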
PredictionMatrix FullPredictionTaskGenerator::m_Predictions [private]
Definition at line 81 of file prediction.h.
Referenced by FullPredictionTaskGenerator(), get_predictions(), and run_tasks().