DiSMEC++
#include <prediction.h>
Public Member Functions
|  | FullPredictionTaskGenerator (const DatasetBase *data, std::shared_ptr< const Model > model) |
| void | prepare (long num_threads, long chunk_size) override |
|  | Called to notify the TaskGenerator about the number of threads. |
| void | run_tasks (long begin, long end, thread_id_t thread_id) override |
| long | num_tasks () const override |
| const PredictionMatrix & | get_predictions () const |
Public Member Functions inherited from dismec::prediction::PredictionBase
|  | PredictionBase (const DatasetBase *data, std::shared_ptr< const Model > model) |
|  | Constructor; checks that data and model are compatible. |
Public Member Functions inherited from dismec::parallel::TaskGenerator
| virtual | ~TaskGenerator ()=default |
| virtual void | finalize () |
|  | Called after all threads have finished their tasks. |
Private Attributes
| PredictionMatrix | m_Predictions |
Additional Inherited Members
Public Types inherited from dismec::parallel::TaskGenerator
| using | thread_id_t = dismec::parallel::thread_id_t |
Protected Member Functions inherited from dismec::prediction::PredictionBase
| void | make_thread_local_features (long num_threads) |
| void | init_thread (thread_id_t thread_id) final |
|  | Called once a thread has spun up, but before it runs its first task. |
| void | do_prediction (long begin, long end, thread_id_t thread_id, Eigen::Ref< PredictionMatrix > target) |
|  | Predicts the scores for a subset of the instances given by the half-open interval [begin, end). |
Protected Attributes inherited from dismec::prediction::PredictionBase
| const DatasetBase * | m_Data |
|  | Data on which the prediction is run. |
| std::shared_ptr< const Model > | m_Model |
|  | Model (possibly partial) for which prediction is run. |
Definition at line 70 of file prediction.h.
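The summary above names the TaskGenerator hooks and when each one runs: prepare() on the main thread before work starts, init_thread() once per thread before its first task, run_tasks() on half-open chunks [begin, end), and finalize() after all threads are done. The following is a minimal sketch of that call order under stated assumptions: MiniTaskGenerator, run_all and SquareTasks are hypothetical names, thread_id_t is assumed to be a plain long, and the driver hands chunks to "threads" round-robin on a single thread instead of running them concurrently. It is not the DiSMEC++ scheduler.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical, simplified mirror of the TaskGenerator hooks described above.
class MiniTaskGenerator {
public:
    using thread_id_t = long;  // assumption: the real thread_id_t may be a distinct type
    virtual ~MiniTaskGenerator() = default;
    virtual void prepare(long /*num_threads*/, long /*chunk_size*/) {}
    virtual void init_thread(thread_id_t /*thread_id*/) {}
    virtual void run_tasks(long begin, long end, thread_id_t thread_id) = 0;
    virtual long num_tasks() const = 0;
    virtual void finalize() {}
};

// Hypothetical sequential driver that calls the hooks in the documented order.
void run_all(MiniTaskGenerator& gen, long num_threads, long chunk_size) {
    assert(num_threads > 0 && chunk_size > 0);
    gen.prepare(num_threads, chunk_size);                       // main thread, before any work
    for (long t = 0; t < num_threads; ++t) gen.init_thread(t);  // once per thread
    const long total = gen.num_tasks();
    long chunk_index = 0;
    for (long begin = 0; begin < total; begin += chunk_size, ++chunk_index) {
        const long end = std::min(begin + chunk_size, total);   // last chunk may be shorter
        gen.run_tasks(begin, end, chunk_index % num_threads);
    }
    gen.finalize();                                             // after all tasks are done
}

// Hypothetical example generator: task i stores i * i.
struct SquareTasks : MiniTaskGenerator {
    explicit SquareTasks(long n) : results(static_cast<std::size_t>(n), 0L) {}
    void prepare(long num_threads, long /*chunk_size*/) override {
        initialized.assign(static_cast<std::size_t>(num_threads), false);
    }
    void init_thread(thread_id_t thread_id) override {
        initialized[static_cast<std::size_t>(thread_id)] = true;
    }
    void run_tasks(long begin, long end, thread_id_t thread_id) override {
        assert(initialized[static_cast<std::size_t>(thread_id)]);  // init_thread() ran first
        for (long i = begin; i < end; ++i) results[static_cast<std::size_t>(i)] = i * i;
        ++num_calls;
    }
    long num_tasks() const override { return static_cast<long>(results.size()); }
    void finalize() override { finalized = true; }

    std::vector<long> results;
    std::vector<bool> initialized;
    long num_calls = 0;
    bool finalized = false;
};
```

With 10 tasks and a chunk size of 4, the driver issues the chunks [0, 4), [4, 8) and [8, 10), so the last call receives fewer than chunk_size tasks.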
Constructor & Destructor Documentation
FullPredictionTaskGenerator::FullPredictionTaskGenerator (const DatasetBase * data, std::shared_ptr< const Model > model)
Definition at line 59 of file prediction.cpp.
References m_Predictions, dismec::DatasetBase::num_examples(), and dismec::DatasetBase::num_labels().
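The constructor references num_examples() and num_labels(), which suggests that it sizes m_Predictions once, with one row per example and one column per label, before any worker runs. The sketch below illustrates that idea under stated assumptions: ToyDataset, ToyPredictionMatrix and ToyFullPrediction are hypothetical stand-ins for DatasetBase, PredictionMatrix and this class, and the dense row-major layout is an assumption, not the library's type.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for DatasetBase: only the two queries the constructor uses.
struct ToyDataset {
    long examples;
    long labels;
    long num_examples() const { return examples; }
    long num_labels() const { return labels; }
};

// Hypothetical dense, row-major stand-in for PredictionMatrix.
struct ToyPredictionMatrix {
    long rows = 0;
    long cols = 0;
    std::vector<double> data;
    void resize(long r, long c) {
        rows = r;
        cols = c;
        data.assign(static_cast<std::size_t>(r * c), 0.0);
    }
};

// Sketch: allocate the full (examples x labels) result up front, so that
// run_tasks() only ever writes into preallocated memory.
class ToyFullPrediction {
public:
    explicit ToyFullPrediction(const ToyDataset* data) {
        assert(data != nullptr);
        m_Predictions.resize(data->num_examples(), data->num_labels());
    }
    // Returned by const reference, so reading the result does not copy it.
    const ToyPredictionMatrix& get_predictions() const { return m_Predictions; }

private:
    ToyPredictionMatrix m_Predictions;
};
```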
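Member Function Documentation
const PredictionMatrix & FullPredictionTaskGenerator::get_predictions () const [inline]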
Definition at line 79 of file prediction.h.
References m_Predictions.
long FullPredictionTaskGenerator::num_tasks () const [override, virtual]
Implements dismec::parallel::TaskGenerator.
Definition at line 65 of file prediction.cpp.
References dismec::prediction::PredictionBase::m_Data, and dismec::DatasetBase::num_examples().
void FullPredictionTaskGenerator::prepare (long num_threads, long chunk_size) [override, virtual]
Called to notify the TaskGenerator about the number of threads.
This function is called from the main thread, before distributed work is started. It gives the TaskGenerator a chance to allocate working memory for each thread, so these allocations don't need to be done and repeated in run_tasks().
Where the working memory should be local to the computing thread, init_thread() should be used for the allocation instead. It is called from inside the thread that will do the actual computations, so that on a NUMA system the first-touch policy has a chance to place the allocation in the correct RAM. In that case, this function should only allocate an array of pointers that init_thread() will fill in.
Parameters:
| num_threads | Number of threads that will be used. |
| chunk_size | A hint for the size of chunks used when running this task. If the total number of tasks is not a multiple of chunk_size, some calls to run_tasks() may receive fewer than chunk_size tasks. |
Reimplemented from dismec::parallel::TaskGenerator.
Definition at line 75 of file prediction.cpp.
References dismec::prediction::PredictionBase::make_thread_local_features().
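The recommended split between prepare() and init_thread() can be sketched as follows, under stated assumptions: PerThreadBuffers is a hypothetical class, its buffers stand in for whatever make_thread_local_features() and init_thread() actually allocate in PredictionBase, and the buffer size is passed to the constructor rather than derived from chunk_size. prepare() only creates one empty slot per thread; init_thread() allocates and zero-fills the buffer, so that under a first-touch policy its pages are likely placed on the NUMA node of the thread that called it.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical per-thread working memory following the prepare()/init_thread() split.
class PerThreadBuffers {
public:
    explicit PerThreadBuffers(long buffer_size) : m_BufferSize(buffer_size) {}

    // Main thread, before work starts: one empty (null) slot per thread, no large allocation.
    void prepare(long num_threads) {
        m_Buffers.clear();
        m_Buffers.resize(static_cast<std::size_t>(num_threads));
    }

    // Meant to run inside the worker thread: the zero-fill writes every element,
    // so this thread is the first to touch the buffer's pages.
    void init_thread(long thread_id) {
        m_Buffers.at(static_cast<std::size_t>(thread_id)) =
            std::make_unique<std::vector<double>>(static_cast<std::size_t>(m_BufferSize), 0.0);
    }

    std::vector<double>& buffer(long thread_id) {
        auto& slot = m_Buffers.at(static_cast<std::size_t>(thread_id));
        assert(slot && "init_thread() must run before the buffer is used");
        return *slot;
    }

    std::size_t num_slots() const { return m_Buffers.size(); }

private:
    long m_BufferSize;
    std::vector<std::unique_ptr<std::vector<double>>> m_Buffers;
};
```

In a real parallel run, each worker would call init_thread() with its own id from its own thread; calling it sequentially from one thread is functionally correct but forfeits the NUMA placement benefit.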
void FullPredictionTaskGenerator::run_tasks (long begin, long end, thread_id_t thread_id) [override, virtual]
Implements dismec::parallel::TaskGenerator.
Definition at line 70 of file prediction.cpp.
References dismec::prediction::PredictionBase::do_prediction(), and m_Predictions.
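run_tasks() references do_prediction() and m_Predictions, so presumably it passes the rows [begin, end) of m_Predictions as the target of do_prediction(). The sketch below shows why that is safe to parallelize without locks: each chunk writes only its own rows, and chunks never overlap. It assumes a dense linear model (score = features times weights). Dense and predict_rows are hypothetical names, and the real do_prediction() additionally uses thread_id to select thread-local feature storage, which is omitted here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical dense row-major matrix used only for this sketch.
struct Dense {
    long rows, cols;
    std::vector<double> v;
    Dense(long r, long c) : rows(r), cols(c), v(static_cast<std::size_t>(r * c), 0.0) {}
    double& at(long r, long c) { return v[static_cast<std::size_t>(r * cols + c)]; }
    double at(long r, long c) const { return v[static_cast<std::size_t>(r * cols + c)]; }
};

// Computes scores for the examples in [begin, end) and writes them to exactly
// those rows of `predictions`; all other rows are left untouched.
void predict_rows(const Dense& features, const Dense& weights, long begin, long end,
                  Dense& predictions) {
    assert(0 <= begin && begin <= end && end <= features.rows && end <= predictions.rows);
    assert(features.cols == weights.rows && predictions.cols == weights.cols);
    for (long i = begin; i < end; ++i) {
        for (long l = 0; l < weights.cols; ++l) {
            double score = 0.0;
            for (long f = 0; f < features.cols; ++f) score += features.at(i, f) * weights.at(f, l);
            predictions.at(i, l) = score;
        }
    }
}
```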
Member Data Documentation
PredictionMatrix FullPredictionTaskGenerator::m_Predictions [private]
Definition at line 81 of file prediction.h.
Referenced by FullPredictionTaskGenerator(), get_predictions(), and run_tasks().