DiSMEC++
initializer.h File Reference
#include "matrix_types.h"
#include "fwd.h"
#include "parallel/numa.h"
#include "stats/tracked.h"
#include "spec.h"
#include <memory>
#include <filesystem>
#include <optional>

Go to the source code of this file.

Classes

class  dismec::init::WeightsInitializer
 Base class for all weight initializers. More...
 
class  dismec::init::WeightInitializationStrategy
 Base class for all weight init strategies. More...
 

Namespaces

 dismec
 Main namespace in which all types, classes, and functions are defined.
 
 dismec::init
 

Functions

std::shared_ptr< WeightInitializationStrategy > dismec::init::create_zero_initializer ()
 Creates an initialization strategy that initializes all weight vectors to zero. More...
 
std::shared_ptr< WeightInitializationStrategy > dismec::init::create_constant_initializer (DenseRealVector vec)
 
std::shared_ptr< WeightInitializationStrategy > dismec::init::create_pretrained_initializer (std::shared_ptr< model::Model > model)
 Creates an initialization strategy that uses an already trained model to set the initial weights. More...
 
std::shared_ptr< WeightInitializationStrategy > dismec::init::create_numpy_initializer (const std::filesystem::path &weights, std::optional< std::filesystem::path > biases)
 Creates an initialization strategy that uses weights loaded from a npy file. More...
 
std::shared_ptr< WeightInitializationStrategy > dismec::init::create_feature_mean_initializer (std::shared_ptr< DatasetBase > data, real_t pos=1, real_t neg=-2)
 Creates an initialization strategy based on the mean of positive and negative features. More...
 
std::shared_ptr< WeightInitializationStrategy > dismec::init::create_multi_pos_mean_strategy (std::shared_ptr< DatasetBase > data, int max_pos, real_t pos=1, real_t neg=-2)
 Creates an initialization strategy based on the mean of positive and negative features. More...
 
std::shared_ptr< WeightInitializationStrategy > dismec::init::create_ova_primal_initializer (const std::shared_ptr< DatasetBase > &data, RegularizerSpec regularizer, LossType loss)