6 #ifndef DISMEC_TRAINING_H
7 #define DISMEC_TRAINING_H
33 void prepare(
long num_threads,
long chunk_size)
override;
36 [[nodiscard]]
long num_tasks()
const override;
38 [[nodiscard]]
const std::shared_ptr<model::Model>&
get_model()
const {
return m_Model; }
74 std::shared_ptr<model::Model>
Model;
Generates tasks for training weights for the i'th label.
long num_tasks() const override
std::vector< DenseRealVector > m_ThreadLocalWorkingVector
label_id_t m_LabelRangeEnd
std::vector< std::unique_ptr< postproc::PostProcessor > > m_ThreadLocalPostProc
void prepare(long num_threads, long chunk_size) override
Called to notify the TaskGenerator about the number of threads.
void run_task(long task_id, thread_id_t thread_id)
label_id_t m_LabelRangeBegin
const std::shared_ptr< model::Model > & get_model() const
std::shared_ptr< TrainingSpec > m_TaskSpec
const std::vector< solvers::MinimizationResult > & get_results() const
void finalize() override
Called after all threads have finished their tasks.
std::vector< std::unique_ptr< ResultStatsGatherer > > m_ResultGatherers
~TrainingTaskGenerator() override
std::vector< std::unique_ptr< init::WeightsInitializer > > m_ThreadLocalWeightInit
std::vector< std::unique_ptr< solvers::Minimizer > > m_ThreadLocalMinimizer
TrainingTaskGenerator(std::shared_ptr< TrainingSpec > spec, label_id_t begin_label=label_id_t{0}, label_id_t end_label=label_id_t{-1})
solvers::MinimizationResult train_label(label_id_t label_id, thread_id_t thread_id)
Runs the training of a single label.
std::shared_ptr< model::Model > m_Model
void run_tasks(long begin, long end, thread_id_t thread_id) override
std::vector< solvers::MinimizationResult > m_Results
std::vector< std::shared_ptr< objective::Objective > > m_ThreadLocalObjective
void init_thread(thread_id_t thread_id) override
Called once a thread has spun up, but before it runs its first task.
Strong typedef for an int to signify a label id.
Base class for all parallelized operations.
dismec::parallel::thread_id_t thread_id_t
Strong typedef for an int to signify a thread id.
Main namespace in which all types, classes, and functions are defined.
TrainingResult run_training(parallel::ParallelRunner &runner, std::shared_ptr< TrainingSpec > spec, label_id_t begin_label=label_id_t{0}, label_id_t end_label=label_id_t{-1})
float real_t
The default type for floating point values.
std::shared_ptr< model::Model > Model