6 #ifndef DISMEC_TRAINING_H 
    7 #define DISMEC_TRAINING_H 
   33         void prepare(long num_threads, long chunk_size) override;
 
   36         [[nodiscard]] long num_tasks() const override;
 
   38         [[nodiscard]] const std::shared_ptr<model::Model>& get_model() const { return m_Model; }
 
   74         std::shared_ptr<model::Model> Model;
 
Generates tasks for training weights for the i-th label.
 
long num_tasks() const override
 
std::vector< DenseRealVector > m_ThreadLocalWorkingVector
 
label_id_t m_LabelRangeEnd
 
std::vector< std::unique_ptr< postproc::PostProcessor > > m_ThreadLocalPostProc
 
void prepare(long num_threads, long chunk_size) override
Called to notify the TaskGenerator about the number of threads.
 
void run_task(long task_id, thread_id_t thread_id)
 
label_id_t m_LabelRangeBegin
 
const std::shared_ptr< model::Model > & get_model() const
 
std::shared_ptr< TrainingSpec > m_TaskSpec
 
const std::vector< solvers::MinimizationResult > & get_results() const
 
void finalize() override
Called after all threads have finished their tasks.
 
std::vector< std::unique_ptr< ResultStatsGatherer > > m_ResultGatherers
 
~TrainingTaskGenerator() override
 
std::vector< std::unique_ptr< init::WeightsInitializer > > m_ThreadLocalWeightInit
 
std::vector< std::unique_ptr< solvers::Minimizer > > m_ThreadLocalMinimizer
 
TrainingTaskGenerator(std::shared_ptr< TrainingSpec > spec, label_id_t begin_label=label_id_t{0}, label_id_t end_label=label_id_t{-1})
 
solvers::MinimizationResult train_label(label_id_t label_id, thread_id_t thread_id)
Runs the training of a single label.
 
std::shared_ptr< model::Model > m_Model
 
void run_tasks(long begin, long end, thread_id_t thread_id) override
 
std::vector< solvers::MinimizationResult > m_Results
 
std::vector< std::shared_ptr< objective::Objective > > m_ThreadLocalObjective
 
void init_thread(thread_id_t thread_id) override
Called once a thread has spun up, but before it runs its first task.
 
Strong typedef for an int to signify a label id.
 
Base class for all parallelized operations.
 
dismec::parallel::thread_id_t thread_id_t
 
Strong typedef for an int to signify a thread id.
 
Main namespace in which all types, classes, and functions are defined.
 
TrainingResult run_training(parallel::ParallelRunner &runner, std::shared_ptr< TrainingSpec > spec, label_id_t begin_label=label_id_t{0}, label_id_t end_label=label_id_t{-1})
 
float real_t
The default type for floating point values.
 
std::shared_ptr< model::Model > Model