DiSMEC++
TrainingProgram Class Reference

Public Member Functions

 TrainingProgram ()
 
int run (int argc, const char **argv)
 
 TrainingProgram ()
 
int run (int argc, const char **argv)
 

Private Member Functions

void setup_source_cmdline ()
 
void setup_save_cmdline ()
 
void setup_label_range ()
 
void parse_label_range ()
 
void setup_hyper_params ()
 
CascadeTrainingConfig make_config (const std::shared_ptr< MultiLabelData > &data, std::shared_ptr< const GenericFeatureMatrix > dense)
 
void setup_source_cmdline ()
 
void setup_save_cmdline ()
 
void setup_label_range ()
 
void parse_label_range ()
 
void setup_hyper_params ()
 
void setup_regularization ()
 
DismecTrainingConfig make_config (const std::shared_ptr< MultiLabelData > &data)
 

Private Attributes

CLI::App app {"DiSMEC-Cascade"}
 
std::string TfIdfFile
 
std::string DenseFile
 
std::string ShortlistFile
 
std::filesystem::path ModelFile
 
io::model::SaveOption SaveOptions
 
int FirstLabel = 0
 
int NumLabels = -1
 
bool ContinueRun = false
 
label_id_t LabelsBegin {0}
 
label_id_t LabelsEnd {-1}
 
CLI::Option * FirstLabelOpt
 
CLI::Option * NumLabelsOpt
 
HyperParameters hps
 
std::string StatsOutFile = "stats.json"
 
std::string StatsLevelFile = {}
 
bool NormalizeSparse = false
 
bool NormalizeDense = false
 
DatasetTransform TransformSparse = DatasetTransform::IDENTITY
 
std::filesystem::path DenseWeightsFile
 
std::filesystem::path DenseBiasesFile
 
bool InitSparseMSI = false
 
bool InitDenseMSI = false
 
real_t RegScaleSparse = 1.0
 
real_t RegScaleDense = 1.0
 
bool AugmentDenseWithBias = false
 
bool AugmentSparseWithBias = false
 
long NumThreads = -1
 
long Timeout = -1
 
long BatchSize = -1
 
std::filesystem::path ExportProcessedData
 
int Verbose = 0
 
bool ReorderFeatures = false
 
DataProcessing DataProc
 
std::string WeightingMode
 
double PropA = 0.55
 
double PropB = 1.5
 
std::string WeightingPosFile
 
std::string WeightingNegFile
 
std::filesystem::path SourceModel
 
CLI::Option * PreTrainedOpt
 
std::string InitMode
 
std::optional< real_t > BiasInitValue
 
real_t MSI_PFac = 1
 
real_t MSI_NFac = -2
 
int InitMaxPos = 1
 
RegularizerType Regularizer = RegularizerType::REG_L2
 
real_t RegScale = 1.0
 
bool RegBias = false
 
real_t Sparsify = -1
 
LossType Loss = LossType::SQUARED_HINGE
 

Detailed Description

Definition at line 26 of file cascade.cpp.

Constructor & Destructor Documentation

◆ TrainingProgram() [1/2]

TrainingProgram::TrainingProgram ( )

◆ TrainingProgram() [2/2]

TrainingProgram::TrainingProgram ( )

Member Function Documentation

◆ make_config() [1/2]

◆ make_config() [2/2]

◆ parse_label_range() [1/2]

◆ parse_label_range() [2/2]

void TrainingProgram::parse_label_range ( )
private

◆ run() [1/2]

◆ run() [2/2]

int TrainingProgram::run ( int  argc,
const char **  argv 
)

◆ setup_hyper_params() [1/2]

void TrainingProgram::setup_hyper_params ( )
private

Definition at line 144 of file cascade.cpp.

◆ setup_hyper_params() [2/2]

void TrainingProgram::setup_hyper_params ( )
private

◆ setup_label_range() [1/2]

void TrainingProgram::setup_label_range ( )
private

Definition at line 131 of file cascade.cpp.

◆ setup_label_range() [2/2]

void TrainingProgram::setup_label_range ( )
private

◆ setup_regularization()

void TrainingProgram::setup_regularization ( )
private

◆ setup_save_cmdline() [1/2]

void TrainingProgram::setup_save_cmdline ( )
private

Definition at line 104 of file cascade.cpp.

References dismec::io::model::SPARSE_TXT.

◆ setup_save_cmdline() [2/2]

void TrainingProgram::setup_save_cmdline ( )
private

◆ setup_source_cmdline() [1/2]

void TrainingProgram::setup_source_cmdline ( )
private

Definition at line 121 of file cascade.cpp.

◆ setup_source_cmdline() [2/2]

void TrainingProgram::setup_source_cmdline ( )
private

Member Data Documentation

◆ app

CLI::App TrainingProgram::app {"DiSMEC-Cascade"}
private

Definition at line 31 of file cascade.cpp.

◆ AugmentDenseWithBias

bool TrainingProgram::AugmentDenseWithBias = false
private

Definition at line 82 of file cascade.cpp.

◆ AugmentSparseWithBias

bool TrainingProgram::AugmentSparseWithBias = false
private

Definition at line 83 of file cascade.cpp.

◆ BatchSize

long TrainingProgram::BatchSize = -1
private

Definition at line 89 of file cascade.cpp.

◆ BiasInitValue

std::optional<real_t> TrainingProgram::BiasInitValue
private

Definition at line 77 of file train.cpp.

◆ ContinueRun

bool TrainingProgram::ContinueRun = false
private

Definition at line 49 of file cascade.cpp.

◆ DataProc

DataProcessing TrainingProgram::DataProc
private

Definition at line 40 of file train.cpp.

◆ DenseBiasesFile

std::filesystem::path TrainingProgram::DenseBiasesFile
private

Definition at line 73 of file cascade.cpp.

◆ DenseFile

std::string TrainingProgram::DenseFile
private

Definition at line 37 of file cascade.cpp.

◆ DenseWeightsFile

std::filesystem::path TrainingProgram::DenseWeightsFile
private

Definition at line 72 of file cascade.cpp.

◆ ExportProcessedData

std::filesystem::path TrainingProgram::ExportProcessedData
private

Definition at line 90 of file cascade.cpp.

◆ FirstLabel

int TrainingProgram::FirstLabel = 0
private

Definition at line 47 of file cascade.cpp.

◆ FirstLabelOpt

CLI::Option * TrainingProgram::FirstLabelOpt
private

Definition at line 55 of file cascade.cpp.

◆ hps

HyperParameters TrainingProgram::hps
private

Definition at line 60 of file cascade.cpp.

◆ InitDenseMSI

bool TrainingProgram::InitDenseMSI = false
private

Definition at line 75 of file cascade.cpp.

◆ InitMaxPos

int TrainingProgram::InitMaxPos = 1
private

Definition at line 80 of file train.cpp.

◆ InitMode

std::string TrainingProgram::InitMode
private

Definition at line 76 of file train.cpp.

◆ InitSparseMSI

bool TrainingProgram::InitSparseMSI = false
private

Definition at line 74 of file cascade.cpp.

◆ LabelsBegin

label_id_t TrainingProgram::LabelsBegin {0}
private

Definition at line 52 of file cascade.cpp.

◆ LabelsEnd

label_id_t TrainingProgram::LabelsEnd {-1}
private

Definition at line 53 of file cascade.cpp.

◆ Loss

LossType TrainingProgram::Loss = LossType::SQUARED_HINGE
private

Definition at line 88 of file train.cpp.

◆ ModelFile

std::filesystem::path TrainingProgram::ModelFile
private

Definition at line 42 of file cascade.cpp.

◆ MSI_NFac

real_t TrainingProgram::MSI_NFac = -2
private

Definition at line 79 of file train.cpp.

◆ MSI_PFac

real_t TrainingProgram::MSI_PFac = 1
private

Definition at line 78 of file train.cpp.

◆ NormalizeDense

bool TrainingProgram::NormalizeDense = false
private

Definition at line 68 of file cascade.cpp.

◆ NormalizeSparse

bool TrainingProgram::NormalizeSparse = false
private

Definition at line 67 of file cascade.cpp.

◆ NumLabels

int TrainingProgram::NumLabels = -1
private

Definition at line 48 of file cascade.cpp.

◆ NumLabelsOpt

CLI::Option * TrainingProgram::NumLabelsOpt
private

Definition at line 56 of file cascade.cpp.

◆ NumThreads

long TrainingProgram::NumThreads = -1
private

Definition at line 87 of file cascade.cpp.

◆ PreTrainedOpt

CLI::Option* TrainingProgram::PreTrainedOpt
private

Definition at line 74 of file train.cpp.

◆ PropA

double TrainingProgram::PropA = 0.55
private

Definition at line 64 of file train.cpp.

◆ PropB

double TrainingProgram::PropB = 1.5
private

Definition at line 65 of file train.cpp.

◆ RegBias

bool TrainingProgram::RegBias = false
private

Definition at line 84 of file train.cpp.

◆ RegScale

real_t TrainingProgram::RegScale = 1.0
private

Definition at line 83 of file train.cpp.

◆ RegScaleDense

real_t TrainingProgram::RegScaleDense = 1.0
private

Definition at line 79 of file cascade.cpp.

◆ RegScaleSparse

real_t TrainingProgram::RegScaleSparse = 1.0
private

Definition at line 78 of file cascade.cpp.

◆ Regularizer

RegularizerType TrainingProgram::Regularizer = RegularizerType::REG_L2
private

Definition at line 82 of file train.cpp.

◆ ReorderFeatures

bool TrainingProgram::ReorderFeatures = false
private

Definition at line 38 of file train.cpp.

◆ SaveOptions

io::model::SaveOption TrainingProgram::SaveOptions
private

Definition at line 43 of file cascade.cpp.

◆ ShortlistFile

std::string TrainingProgram::ShortlistFile
private

Definition at line 38 of file cascade.cpp.

◆ SourceModel

std::filesystem::path TrainingProgram::SourceModel
private

Definition at line 73 of file train.cpp.

◆ Sparsify

real_t TrainingProgram::Sparsify = -1
private

Definition at line 86 of file train.cpp.

◆ StatsLevelFile

std::string TrainingProgram::StatsLevelFile = {}
private

Definition at line 64 of file cascade.cpp.

◆ StatsOutFile

std::string TrainingProgram::StatsOutFile = "stats.json"
private

Definition at line 63 of file cascade.cpp.

◆ TfIdfFile

std::string TrainingProgram::TfIdfFile
private

Definition at line 36 of file cascade.cpp.

◆ Timeout

long TrainingProgram::Timeout = -1
private

Definition at line 88 of file cascade.cpp.

◆ TransformSparse

DatasetTransform TrainingProgram::TransformSparse = DatasetTransform::IDENTITY
private

Definition at line 69 of file cascade.cpp.

◆ Verbose

int TrainingProgram::Verbose = 0
private

Definition at line 92 of file cascade.cpp.

◆ WeightingMode

std::string TrainingProgram::WeightingMode
private

Definition at line 63 of file train.cpp.

◆ WeightingNegFile

std::string TrainingProgram::WeightingNegFile
private

Definition at line 67 of file train.cpp.

◆ WeightingPosFile

std::string TrainingProgram::WeightingPosFile
private

Definition at line 66 of file train.cpp.


The documentation for this class was generated from the following files: