10 #include "spdlog/spdlog.h"
11 #include "spdlog/fmt/fmt.h"
18 std::shared_ptr<const Model> model) :
19 m_Data(data), m_Model(std::move(model)), m_FeatureReplicator(m_Data->
get_features())
22 throw std::invalid_argument(
23 fmt::format(
"Mismatched number of labels between model ({}) and data ({})",
28 throw std::invalid_argument(
29 fmt::format(
"Mismatched number of features between model ({}) and data ({})",
53 visit([&](
const auto& features){
85 m_TopKValues.setConstant(-std::numeric_limits<real_t>::infinity());
88 std::vector<std::vector<long>> examples_to_labels(data->
num_examples());
90 for(
auto example :
dynamic_cast<const MultiLabelData*
>(data)->get_label_instances(label)) {
91 examples_to_labels[example].push_back(label.to_index());
106 cache.resize(chunk_size,
m_Model->num_weights());
110 cache.resize(chunk_size,
m_K);
114 cache.resize(chunk_size,
m_K);
127 for(
int i = 0; i < 4; ++i) {
140 long index_offset =
m_Model->labels_begin().to_index();
141 long last_index =
m_Model->labels_end().to_index();
148 do_prediction(begin, end, thread_id, prediction_matrix.middleRows(0, end-begin));
151 std::int64_t true_positives = 0;
152 std::int64_t num_gt_positives = 0;
153 for(
long sample = begin; sample < end; ++sample) {
158 if(gt < index_offset)
continue;
159 if(gt >= last_index)
break;
162 if(prediction_matrix.coeff(sample - begin, gt - index_offset) > 0) {
169 std::int64_t positive_prediction = 0;
170 for(
long t = 0; t < end - begin; ++t) {
171 double threshold = topk_vals.coeff(t,
m_TopKValues.cols() - 1);
174 for(
long j = 0; j < prediction_matrix.cols(); ++j)
176 real_t value = prediction_matrix.coeff(t, j);
177 if(value > 0) ++positive_prediction;
178 if(value < threshold) {
182 long index = index_offset + j;
183 for(
long k = 0; k <
m_K; ++k) {
186 if(value > topk_vals.coeff(t, k)) {
187 value = std::exchange(topk_vals.coeffRef(t, k), value);
188 index = std::exchange(topk_idx.coeffRef(t, k), index);
193 threshold = topk_vals.coeff(t, topk_vals.cols() - 1);
197 std::int64_t total = (end - begin) * prediction_matrix.cols();
198 std::int64_t true_neg = total - positive_prediction - num_gt_positives + true_positives;
long num_examples() const noexcept
Get the total number of instances, i.e. the number of rows in the feature matrix.
virtual long num_labels() const noexcept=0
long num_features() const noexcept
Get the total number of features, i.e. the number of columns in the feature matrix.
Strong typedef for an int to signify a label id.
constexpr T to_index() const
Explicitly convert to an integer.
Strong typedef for an int to signify a thread id.
long num_tasks() const override
void run_tasks(long begin, long end, thread_id_t thread_id) override
PredictionMatrix m_Predictions
FullPredictionTaskGenerator(const DatasetBase *data, std::shared_ptr< const Model > model)
void prepare(long num_threads, long chunk_size) override
Called to notify the TaskGenerator about the number of threads.
Base class for handling predictions.
void do_prediction(long begin, long end, thread_id_t thread_id, Eigen::Ref< PredictionMatrix > target)
Predicts the scores for a subset of the instances given by the half-open interval [begin, end).
void init_thread(thread_id_t thread_id) final
Called once a thread has spun up, but before it runs its first task.
PredictionBase(const DatasetBase *data, std::shared_ptr< const Model > model)
Constructor, checks that data and model are compatible.
std::vector< std::shared_ptr< const GenericFeatureMatrix > > m_ThreadLocalFeatures
std::shared_ptr< const Model > m_Model
Model (possibly partial) for which prediction is run.
const DatasetBase * m_Data
Data on which the prediction is run.
parallel::NUMAReplicator< const GenericFeatureMatrix > m_FeatureReplicator
The NUMAReplicator that generates NUMA-local copies for the feature matrices.
void make_thread_local_features(long num_threads)
static constexpr const int TRUE_POSITIVES
IndexMatrix m_TopKIndices
void finalize() override
Called after all threads have finished their tasks.
std::vector< PredictionMatrix > m_ThreadLocalTopKValues
std::vector< PredictionMatrix > m_ThreadLocalPredictionCache
void update_model(std::shared_ptr< const Model > model)
std::vector< std::vector< long > > m_GroundTruth
std::vector< IndexMatrix > m_ThreadLocalTopKIndices
static constexpr const int TRUE_NEGATIVES
TopKPredictionTaskGenerator(const DatasetBase *data, std::shared_ptr< const Model > model, long K)
static constexpr const int FALSE_POSITIVES
static constexpr const int FALSE_NEGATIVES
void run_tasks(long begin, long end, thread_id_t thread_id) override
long num_tasks() const override
void prepare(long num_threads, long chunk_size) override
Called to notify the TaskGenerator about the number of threads.
std::array< std::int64_t, 4 > m_ConfusionMatrix
PredictionMatrix m_TopKValues
std::vector< std::array< std::int64_t, 4 > > m_ThreadLocalConfusionMatrix
Eigen::Ref< SparseRowMajor< T > > SparseRowMajorRef
Eigen::Ref< DenseRowMajor< T > > DenseRowMajorRef
Model::FeatureMatrixIn make_matrix(const SparseFeatures &features, long begin, long end)
auto get_features(const DatasetBase &ds)
auto visit(F &&f, Variants &&... variants)
Main namespace in which all types, classes, and functions are defined.
types::DenseRowMajor< real_t > DenseFeatures
Dense Feature Matrix in Row Major format.
types::SparseRowMajor< real_t > SparseFeatures
Sparse Feature Matrix in Row Major format.
float real_t
The default type for floating point values.