17 #include <nlohmann/json.hpp>
22 m_TargetFile(std::move(target_file)) {
24 m_Config = std::make_unique<nlohmann::json>();
26 std::fstream source_stream(source, std::fstream::in);
27 m_Config = std::make_unique<nlohmann::json>(nlohmann::json::parse(source_stream));
50 for (
auto&& accu : entries) {
51 for (
auto&& meta : accu.second->get_statistics_meta()) {
52 if (!accu.second->is_enabled_by_name(meta.Name))
continue;
53 std::string qualified_name = accu.first +
'.'+ meta.Name;
54 if (
m_Merged.count(qualified_name) == 0) {
55 m_Merged[qualified_name] = {meta, accu.second->get_stat(meta.Name).clone()};
58 m_Merged.at(qualified_name).Stat->merge(accu.second->get_stat(meta.Name));
68 target << std::setw(2) << result <<
"\n";
75 auto raw = stat.second.Stat->to_json();
76 if(!stat.second.Meta.Unit.empty())
77 raw[
"Unit"] = stat.second.Meta.Unit;
78 result[stat.first] = std::move(raw);
124 long pos = m_Data->num_positives(label);
135 if(m_InitWeightsCache)
136 *m_InitWeightsCache = init_weights;
138 m_InitWeightsCache = std::make_unique<DenseRealVector>(init_weights);
152 return weights - *m_InitWeightsCache;
163 auto gather = std::make_unique<DefaultGatherer>(*spec);
164 add_accu(
"result", thread, gather->get_stats());
169 std::lock_guard<std::mutex> lck{
m_Lock};
176 accumulator->provide_tags(*entry.second);
177 entry.second->provide_tags(*accumulator);
181 THROW_EXCEPTION(std::runtime_error,
"Could not emplace key {} for statistics collection", key);
185 for (
auto& entry :
m_Config->at(key).items()) {
186 if(!accumulator->has_stat(entry.key())) {
187 spdlog::warn(
"Statistics {} has been defined in json, but has not been declared", entry.key());
DefaultGatherer(const TrainingSpec &spec)
const DatasetBase * m_Data
void start_training(const DenseRealVector &init_weights) override
void start_label(label_id_t label) override
void record_result(const DenseRealVector &weights, const solvers::MinimizationResult &result) override
std::unique_ptr< DenseRealVector > m_InitWeightsCache
virtual ~ResultStatsGatherer()
This class gathers the setting-specific parts of the training process.
void add_accu(const std::string &key, thread_id_t thread, const std::shared_ptr< stats::StatisticsCollection > &accumulator)
void setup_postproc(thread_id_t thread, stats::Tracked &objective)
std::unique_ptr< nlohmann::json > m_Config
void setup_minimizer(thread_id_t thread, stats::Tracked &minimizer)
NOTE: these functions will be called concurrently.
void setup_initializer(thread_id_t thread, stats::Tracked &initializer)
nlohmann::json to_json() const
std::vector< std::unordered_map< std::string, collection_ptr_t > > m_PerThreadCollections
void setup_objective(thread_id_t thread, stats::Tracked &objective)
std::unordered_map< std::string, StatData > m_Merged
TrainingStatsGatherer(std::string source, std::string target_file)
std::unique_ptr< ResultStatsGatherer > create_results_gatherer(thread_id_t thread, const std::shared_ptr< const TrainingSpec > &spec)
Strong typedef for an int to signify a label id.
constexpr T to_index() const
Explicitly convert to an integer.
Strong typedef for an int to signify a thread id.
A base class to be used for all types that implement some form of statistics tracking.
std::shared_ptr< StatisticsCollection > get_stats() const
Gets an ownership-sharing reference to the StatisticsCollection.
constexpr const stats::tag_id_t TAG_LABEL_ID
constexpr const stats::stat_id_t STAT_FINAL_LOSS
constexpr const stats::stat_id_t STAT_LABEL_ID
constexpr const stats::stat_id_t STAT_INIT_LOSS
constexpr const stats::stat_id_t STAT_TRAINING_SHIFT
constexpr const stats::stat_id_t STAT_WEIGHT_VECTOR
constexpr const stats::stat_id_t STAT_INIT_VECTOR
constexpr const stats::stat_id_t STAT_LABEL_FREQ
constexpr const stats::stat_id_t STAT_FINAL_GRAD
constexpr const stats::stat_id_t STAT_DURATION
constexpr const stats::stat_id_t STAT_INIT_GRAD
constexpr const stats::stat_id_t STAT_NUM_ITERS
constexpr const stats::tag_id_t TAG_LABEL_FREQ
std::unique_ptr< stats::Statistics > make_stat_from_json(const nlohmann::json &source)
Generates a stats::Statistics object based on a json configuration.
Main namespace in which all types, classes, and functions are defined.
constexpr auto ssize(const C &c) -> std::common_type_t< std::ptrdiff_t, std::make_signed_t< decltype(c.size())>>
Free function that returns the size of a container as a signed integer. Taken from https://en.cppreference.com/w/cpp/iterator/size
types::DenseVector< real_t > DenseRealVector
Any dense, real-valued vector.
float real_t
The default type for floating point values.
std::chrono::milliseconds Duration
#define THROW_EXCEPTION(exception_type,...)