DiSMEC++
dismec::prediction::FullPredictionTaskGenerator Class Reference

#include <prediction.h>

Inheritance diagram for dismec::prediction::FullPredictionTaskGenerator:
dismec::prediction::FullPredictionTaskGenerator derives from dismec::prediction::PredictionBase, which in turn derives from dismec::parallel::TaskGenerator.

Public Member Functions

 FullPredictionTaskGenerator (const DatasetBase *data, std::shared_ptr< const Model > model)
 
void prepare (long num_threads, long chunk_size) override
 Called to notify the TaskGenerator about the number of threads.
 
void run_tasks (long begin, long end, thread_id_t thread_id) override
 
long num_tasks () const override
 
const PredictionMatrix & get_predictions () const
 
- Public Member Functions inherited from dismec::prediction::PredictionBase
 PredictionBase (const DatasetBase *data, std::shared_ptr< const Model > model)
 Constructor, checks that data and model are compatible.
 
- Public Member Functions inherited from dismec::parallel::TaskGenerator
virtual ~TaskGenerator ()=default
 
virtual void finalize ()
 Called after all threads have finished their tasks.
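The lifecycle described by these members (prepare() on the main thread before distributed work starts, init_thread() once per thread before its first task, run_tasks() on half-open chunks of [0, num_tasks()), and finalize() after all threads are done) can be sketched as a serial driver. The interface below is a simplified stand-in reconstructed from this page, not the actual DiSMEC++ header: `Generator`, `CountingGenerator` and `run_serially` are hypothetical names, thread_id_t is assumed to be a plain integer, and the round-robin thread assignment is only for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical stand-in for dismec::parallel::TaskGenerator.
// thread_id_t is assumed to be a plain integer in this sketch.
using thread_id_t = long;

struct Generator {
    virtual ~Generator() = default;
    virtual void prepare(long num_threads, long chunk_size) = 0;
    virtual void init_thread(thread_id_t /*thread_id*/) {}
    virtual void run_tasks(long begin, long end, thread_id_t thread_id) = 0;
    virtual long num_tasks() const = 0;
    virtual void finalize() {}
};

// Serial simulation of the call order; a real runner executes the
// per-thread parts concurrently and may assign chunks dynamically.
void run_serially(Generator& gen, long num_threads, long chunk_size) {
    assert(num_threads > 0 && chunk_size > 0);
    gen.prepare(num_threads, chunk_size);          // main thread, before work starts
    for (long t = 0; t < num_threads; ++t) {
        gen.init_thread(t);                        // once per thread, before its first task
    }
    long next_thread = 0;
    for (long begin = 0; begin < gen.num_tasks(); begin += chunk_size) {
        long end = std::min(begin + chunk_size, gen.num_tasks());
        gen.run_tasks(begin, end, next_thread);    // half-open range [begin, end)
        next_thread = (next_thread + 1) % num_threads;  // round-robin, illustration only
    }
    gen.finalize();                                // after all tasks are done
}

// Toy generator that records which thread processed each task.
struct CountingGenerator : Generator {
    long n;
    std::vector<long> owner;  // owner[i] = thread that ran task i
    bool finalized = false;

    explicit CountingGenerator(long n_tasks) : n(n_tasks) {}
    void prepare(long, long) override { owner.assign(n, -1); }
    void run_tasks(long begin, long end, thread_id_t thread_id) override {
        for (long i = begin; i < end; ++i) owner[i] = thread_id;
    }
    long num_tasks() const override { return n; }
    void finalize() override { finalized = true; }
};
```

With 7 tasks, 2 threads and a chunk size of 3, the driver issues the chunks [0, 3), [3, 6) and [6, 7).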
 

Private Attributes

PredictionMatrix m_Predictions
 

Additional Inherited Members

- Public Types inherited from dismec::parallel::TaskGenerator
using thread_id_t = dismec::parallel::thread_id_t
 
- Protected Member Functions inherited from dismec::prediction::PredictionBase
void make_thread_local_features (long num_threads)
 
void init_thread (thread_id_t thread_id) final
 Called once a thread has spun up, but before it runs its first task.
 
void do_prediction (long begin, long end, thread_id_t thread_id, Eigen::Ref< PredictionMatrix > target)
 Predicts the scores for a subset of the instances given by the half-open interval [begin, end).
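For a linear one-vs-rest model of the kind DiSMEC trains, predicting scores for the instances in [begin, end) amounts to multiplying those feature rows by the weight matrix. The sketch below shows that computation with plain row-major std::vector storage instead of Eigen. `DenseMatrix` and `predict_rows` are hypothetical, dense features are used only for simplicity, and it is assumed that `target` covers exactly rows [begin, end) of the full prediction matrix, so instance `i` is written to row `i - begin` of `target`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major dense matrix; the real code uses Eigen types.
struct DenseMatrix {
    long rows = 0, cols = 0;
    std::vector<double> values;  // rows * cols entries

    DenseMatrix(long r, long c)
        : rows(r), cols(c), values(static_cast<std::size_t>(r * c), 0.0) {}
    double& at(long r, long c) { return values[static_cast<std::size_t>(r * cols + c)]; }
    double at(long r, long c) const { return values[static_cast<std::size_t>(r * cols + c)]; }
};

// Scores for the instances in the half-open range [begin, end):
// target(i - begin, l) = <features(i, :), weights(:, l)>.
void predict_rows(const DenseMatrix& features, const DenseMatrix& weights,
                  long begin, long end, DenseMatrix& target) {
    assert(features.cols == weights.rows);
    assert(target.rows == end - begin && target.cols == weights.cols);
    for (long i = begin; i < end; ++i) {
        for (long l = 0; l < weights.cols; ++l) {
            double score = 0.0;
            for (long k = 0; k < features.cols; ++k) {
                score += features.at(i, k) * weights.at(k, l);
            }
            target.at(i - begin, l) = score;
        }
    }
}
```

Writing into a block that starts at row `begin` lets each call fill its own slice of a shared result matrix without touching the rows of other chunks.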
 
- Protected Attributes inherited from dismec::prediction::PredictionBase
const DatasetBase * m_Data
 Data on which the prediction is run.
 
std::shared_ptr< const Model > m_Model
 Model (possibly partial) for which prediction is run.
 

Detailed Description

Definition at line 70 of file prediction.h.

Constructor & Destructor Documentation

◆ FullPredictionTaskGenerator()

FullPredictionTaskGenerator::FullPredictionTaskGenerator ( const DatasetBase * data,
std::shared_ptr< const Model > model
)

Member Function Documentation

◆ get_predictions()

const PredictionMatrix& dismec::prediction::FullPredictionTaskGenerator::get_predictions ( ) const
inline

Definition at line 79 of file prediction.h.

References m_Predictions.

◆ num_tasks()

long FullPredictionTaskGenerator::num_tasks ( ) const
override virtual

◆ prepare()

void FullPredictionTaskGenerator::prepare ( long num_threads,
long chunk_size
)
override virtual

Called to notify the TaskGenerator about the number of threads.

This function is called from the main thread before distributed work is started. It gives the TaskGenerator a chance to allocate working memory for each thread, so that these allocations need not be repeated in every call to run_tasks().

Note
For memory that is used inside the computations done by each thread, use init_thread() instead. It is called from inside the thread that will do the actual computations, so on a NUMA system the first-touch policy has a chance to place the allocation in memory local to that thread. In that case, this function should only allocate an array of pointers that init_thread() will fill in.
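The pattern recommended in this note, where prepare() only sizes a per-thread array of pointers and each worker allocates its own buffer in init_thread(), might look like the sketch below. `ScratchOwner`, `buffer_size` and `spin_up` are hypothetical names. Whether the pages actually end up on the worker's NUMA node depends on the operating system's first-touch policy and on the thread staying on that node (for example through pinning), which this sketch does not set up.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <thread>
#include <vector>

// Hypothetical owner of per-thread scratch buffers.
class ScratchOwner {
public:
    explicit ScratchOwner(std::size_t buffer_size) : m_BufferSize(buffer_size) {}

    // Main thread: only allocate the small array of per-thread slots (all null).
    void prepare(long num_threads, long /*chunk_size*/) {
        m_Scratch.clear();
        m_Scratch.resize(static_cast<std::size_t>(num_threads));
    }

    // Worker thread: allocate and zero-fill the large buffer here, so the
    // first write to its pages happens on the thread that will use it.
    void init_thread(long thread_id) {
        m_Scratch[static_cast<std::size_t>(thread_id)] =
            std::make_unique<std::vector<double>>(m_BufferSize, 0.0);
    }

    const std::vector<double>* scratch(long thread_id) const {
        return m_Scratch[static_cast<std::size_t>(thread_id)].get();
    }

private:
    std::size_t m_BufferSize;
    std::vector<std::unique_ptr<std::vector<double>>> m_Scratch;
};

// Starts one std::thread per slot and calls init_thread() from inside each.
// Every thread writes a different slot, so no locking is required.
void spin_up(ScratchOwner& owner, long num_threads) {
    owner.prepare(num_threads, 1);
    std::vector<std::thread> workers;
    for (long t = 0; t < num_threads; ++t) {
        workers.emplace_back([&owner, t] { owner.init_thread(t); });
    }
    for (auto& w : workers) w.join();
}
```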
Parameters
num_threads: Number of threads that will be used.
chunk_size: A hint for the size of chunks used when running this task. If the total number of tasks is not a multiple of chunk_size, some calls to run_tasks() may receive fewer than chunk_size tasks.
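To make the remark about chunk_size concrete: splitting the half-open range [0, num_tasks) into chunks of chunk_size leaves a shorter final chunk whenever num_tasks is not a multiple of chunk_size. `split_into_chunks` is a hypothetical helper, not part of DiSMEC++, and the actual runner may schedule chunks differently.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Splits [0, num_tasks) into half-open chunks [begin, end) of at most chunk_size tasks.
std::vector<std::pair<long, long>> split_into_chunks(long num_tasks, long chunk_size) {
    assert(chunk_size > 0);
    std::vector<std::pair<long, long>> chunks;
    for (long begin = 0; begin < num_tasks; begin += chunk_size) {
        chunks.emplace_back(begin, std::min(begin + chunk_size, num_tasks));
    }
    return chunks;
}
```

For num_tasks = 10 and chunk_size = 4 this yields [0, 4), [4, 8) and [8, 10), so the last call to run_tasks() receives only two tasks.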

Reimplemented from dismec::parallel::TaskGenerator.

Definition at line 75 of file prediction.cpp.

References dismec::prediction::PredictionBase::make_thread_local_features().

◆ run_tasks()

void FullPredictionTaskGenerator::run_tasks ( long begin,
long end,
thread_id_t thread_id
)
override virtual

Member Data Documentation

◆ m_Predictions

PredictionMatrix dismec::prediction::FullPredictionTaskGenerator::m_Predictions
private

Definition at line 81 of file prediction.h.

Referenced by FullPredictionTaskGenerator(), get_predictions(), and run_tasks().
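The references listed here (m_Predictions is used by the constructor, run_tasks() and get_predictions()) suggest a plausible shape for the class: the constructor sizes the prediction matrix once, num_tasks() counts instances, and each run_tasks() call fills a disjoint block of rows, so threads never write the same entries and need no locking. The sketch below illustrates that design with a hypothetical `ToyFullPrediction` class and a caller-supplied scoring function in place of DatasetBase, Model and Eigen; it is an assumption-based illustration, not the DiSMEC++ implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical analogue of a full-prediction task generator.
class ToyFullPrediction {
public:
    // score(instance, label) stands in for applying the model to the dataset.
    using ScoreFn = std::function<double(long, long)>;

    ToyFullPrediction(long num_instances, long num_labels, ScoreFn score)
        : m_Instances(num_instances), m_Labels(num_labels), m_Score(std::move(score)),
          // Allocated once up front; workers only write into it afterwards.
          m_Predictions(static_cast<std::size_t>(num_instances * num_labels), 0.0) {}

    // One task per instance (assumption).
    long num_tasks() const { return m_Instances; }

    // Fills rows [begin, end) of the row-major prediction matrix. Different
    // chunks touch disjoint rows, so concurrent calls need no synchronization.
    void run_tasks(long begin, long end, long /*thread_id*/) {
        for (long i = begin; i < end; ++i) {
            for (long l = 0; l < m_Labels; ++l) {
                m_Predictions[static_cast<std::size_t>(i * m_Labels + l)] = m_Score(i, l);
            }
        }
    }

    // Read-only access, meaningful once all tasks have finished.
    const std::vector<double>& get_predictions() const { return m_Predictions; }

private:
    long m_Instances;
    long m_Labels;
    ScoreFn m_Score;
    std::vector<double> m_Predictions;
};
```

Because get_predictions() returns a const reference, the result is tied to the generator's lifetime; a caller that needs the predictions afterwards should copy them.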

