DiSMEC++
#include "matrix_types.h"
#include "fwd.h"
#include "parallel/numa.h"
#include "stats/tracked.h"
#include "spec.h"
#include <memory>
#include <filesystem>
#include <optional>
Classes
class dismec::init::WeightsInitializer
Base class for all weight initializers.
class dismec::init::WeightInitializationStrategy
Base class for all weight init strategies.
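The header separates a strategy, which describes how weights are initialized, from an initializer, which does the per-label work. The following is a minimal, self-contained sketch of that split. It assumes, rather than takes from the header, that a strategy acts as a factory that hands out initializer objects (for example one per worker thread). The names `InitStrategy`, `Initializer`, `ZeroStrategy`, `make_zero_strategy` and the method signatures are hypothetical and do not mirror the real DiSMEC++ interface.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stand-in for a dense weight vector; DiSMEC++ uses its own types.
using Vector = std::vector<double>;

// Hypothetical per-worker object that writes the starting weights for one label.
class Initializer {
public:
    virtual ~Initializer() = default;
    virtual void get_initial_weight(std::size_t label, Vector& target) = 0;
};

// Hypothetical shareable description of an initialization scheme; acts as a factory.
class InitStrategy {
public:
    virtual ~InitStrategy() = default;
    virtual std::unique_ptr<Initializer> make_initializer() const = 0;
};

// Concrete example: every weight vector starts at zero, whatever the label.
class ZeroInitializer : public Initializer {
public:
    void get_initial_weight(std::size_t /*label*/, Vector& target) override {
        std::fill(target.begin(), target.end(), 0.0);
    }
};

class ZeroStrategy : public InitStrategy {
public:
    std::unique_ptr<Initializer> make_initializer() const override {
        return std::make_unique<ZeroInitializer>();
    }
};

// Factory function returning a shared strategy, in the style of create_zero_initializer().
std::shared_ptr<InitStrategy> make_zero_strategy() {
    return std::make_shared<ZeroStrategy>();
}
```

Returning the strategy through a `std::shared_ptr`, as the factory functions in this header do, would let several threads share one strategy while each owns its own initializer, so per-thread state needs no locking.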
Namespaces
dismec
Main namespace in which all types, classes, and functions are defined.
dismec::init
Functions
std::shared_ptr<WeightInitializationStrategy> dismec::init::create_zero_initializer()
Creates an initialization strategy that initializes all weight vectors to zero.
std::shared_ptr<WeightInitializationStrategy> dismec::init::create_constant_initializer(DenseRealVector vec)
std::shared_ptr<WeightInitializationStrategy> dismec::init::create_pretrained_initializer(std::shared_ptr<model::Model> model)
Creates an initialization strategy that uses an already trained model to set the initial weights.
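Initializing from a trained model amounts to a warm start: the optimization for each label begins at the weights an earlier model learned for that label. Below is a minimal sketch assuming the model can be viewed as one dense weight vector per label. `TrainedModel` and `PretrainedInitializer` are hypothetical names and are not the `model::Model` API.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

using Vector = std::vector<double>;

// Hypothetical trained model: one weight vector per label.
struct TrainedModel {
    std::vector<Vector> weights;
};

// Hypothetical warm-start initializer: copies the trained weights of a label.
class PretrainedInitializer {
public:
    explicit PretrainedInitializer(std::shared_ptr<const TrainedModel> model)
        : m_model(std::move(model)) {}

    void get_initial_weight(std::size_t label, Vector& target) const {
        // at() throws std::out_of_range for a label the model does not know.
        const Vector& source = m_model->weights.at(label);
        // The new problem must have the same number of features as the old one.
        if (source.size() != target.size()) {
            throw std::invalid_argument("feature dimension mismatch");
        }
        std::copy(source.begin(), source.end(), target.begin());
    }

private:
    std::shared_ptr<const TrainedModel> m_model;
};
```

Holding the model through a `std::shared_ptr`, as `create_pretrained_initializer` does, keeps it alive for as long as any initializer still reads from it.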
std::shared_ptr<WeightInitializationStrategy> dismec::init::create_numpy_initializer(const std::filesystem::path &weights, std::optional<std::filesystem::path> biases)
Creates an initialization strategy that uses weights loaded from a npy file.
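The `.npy` format is specified by NumPy: a 6-byte magic string `\x93NUMPY`, a major and a minor version byte, a little-endian header length (2 bytes in version 1.0, 4 bytes in version 2.0), and an ASCII Python dict literal with the keys `descr`, `fortran_order` and `shape`; the raw array data follows directly after. The sketch below parses only that header from an in-memory buffer. `NpyHeader` and `parse_npy_header` are hypothetical helpers, not part of DiSMEC++, and a real loader would also have to check `descr` (dtype and byte order) and `fortran_order` before interpreting the data as a weight matrix.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

struct NpyHeader {
    std::string dict;                // e.g. "{'descr': '<f8', 'fortran_order': False, 'shape': (3, 2), }"
    std::size_t data_offset = 0;     // byte offset at which the raw array data begins
    std::vector<std::size_t> shape;  // parsed from the 'shape' tuple
};

// Parses the header of an in-memory .npy file (format versions 1.x and 2.x only).
NpyHeader parse_npy_header(const std::string& buf) {
    const std::string magic = "\x93NUMPY";
    if (buf.size() < 10 || buf.compare(0, 6, magic) != 0) {
        throw std::runtime_error("not a .npy file");
    }
    const auto byte = [&](std::size_t i) {
        return static_cast<std::size_t>(static_cast<unsigned char>(buf[i]));
    };
    std::size_t header_len = 0;
    std::size_t prefix = 0;
    if (byte(6) == 1) {         // version 1.x: 2-byte little-endian length
        header_len = byte(8) | (byte(9) << 8);
        prefix = 10;
    } else if (byte(6) == 2) {  // version 2.x: 4-byte little-endian length
        if (buf.size() < 12) throw std::runtime_error("truncated header");
        header_len = byte(8) | (byte(9) << 8) | (byte(10) << 16) | (byte(11) << 24);
        prefix = 12;
    } else {
        throw std::runtime_error("unsupported .npy version");
    }
    if (buf.size() < prefix + header_len) throw std::runtime_error("truncated header");

    NpyHeader h;
    h.dict = buf.substr(prefix, header_len);
    h.data_offset = prefix + header_len;

    // Collect the integers between the parentheses of "'shape': (...)".
    const std::size_t key = h.dict.find("'shape':");
    if (key == std::string::npos) throw std::runtime_error("missing shape");
    const std::size_t open = h.dict.find('(', key);
    const std::size_t close = h.dict.find(')', open);
    if (open == std::string::npos || close == std::string::npos) {
        throw std::runtime_error("malformed shape");
    }
    std::size_t value = 0;
    bool in_number = false;
    for (std::size_t i = open + 1; i < close; ++i) {
        const char c = h.dict[i];
        if (c >= '0' && c <= '9') {
            value = value * 10 + static_cast<std::size_t>(c - '0');
            in_number = true;
        } else if (in_number) {
            h.shape.push_back(value);
            value = 0;
            in_number = false;
        }
    }
    if (in_number) h.shape.push_back(value);
    return h;
}
```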
std::shared_ptr<WeightInitializationStrategy> dismec::init::create_feature_mean_initializer(std::shared_ptr<DatasetBase> data, real_t pos=1, real_t neg=-2)
Creates an initialization strategy based on the mean of positive and negative features.
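The header does not spell out how `pos` and `neg` are used. One reading, assumed here rather than taken from the source, treats them as target scores: the initial weight vector w is chosen in the span of the mean positive feature vector m_pos and the mean negative feature vector m_neg such that <w, m_pos> = pos and <w, m_neg> = neg. Writing w = a·m_pos + b·m_neg turns this into a 2×2 linear system in a and b, solved below with Cramer's rule. The sketch uses dense rows and a single binary label; `feature_mean_init`, `class_mean`, `dot` and the row-per-instance `Matrix` type are hypothetical, whereas the real function works on a `DatasetBase`.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

using Vector = std::vector<double>;
using Matrix = std::vector<Vector>;  // hypothetical: one dense row per instance

double dot(const Vector& a, const Vector& b) {
    double s = 0.0;
    for (std::size_t j = 0; j < a.size(); ++j) s += a[j] * b[j];
    return s;
}

// Mean of all rows whose label flag equals `positive`.
Vector class_mean(const Matrix& x, const std::vector<bool>& y, bool positive) {
    Vector mean(x.at(0).size(), 0.0);
    std::size_t count = 0;
    for (std::size_t i = 0; i < x.size(); ++i) {
        if (y[i] != positive) continue;
        for (std::size_t j = 0; j < mean.size(); ++j) mean[j] += x[i][j];
        ++count;
    }
    if (count == 0) throw std::invalid_argument("no instances in this class");
    for (double& v : mean) v /= static_cast<double>(count);
    return mean;
}

// w = a*m_pos + b*m_neg with <w, m_pos> = pos and <w, m_neg> = neg.
Vector feature_mean_init(const Matrix& x, const std::vector<bool>& y,
                         double pos = 1.0, double neg = -2.0) {
    const Vector mp = class_mean(x, y, true);
    const Vector mn = class_mean(x, y, false);
    const double pp = dot(mp, mp), pn = dot(mp, mn), nn = dot(mn, mn);
    // Gram determinant; zero when the two class means are collinear.
    const double det = pp * nn - pn * pn;
    if (std::abs(det) < 1e-12) throw std::runtime_error("class means are (nearly) collinear");
    const double a = (pos * nn - neg * pn) / det;
    const double b = (neg * pp - pos * pn) / det;
    Vector w(mp.size());
    for (std::size_t j = 0; j < w.size(); ++j) w[j] = a * mp[j] + b * mn[j];
    return w;
}
```

When the two class means point in the same direction the system is singular, which is why the sketch throws; a full implementation would need some fallback for that case.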
std::shared_ptr<WeightInitializationStrategy> dismec::init::create_multi_pos_mean_strategy(std::shared_ptr<DatasetBase> data, int max_pos, real_t pos=1, real_t neg=-2)
Creates an initialization strategy based on the mean of positive and negative features.
std::shared_ptr<WeightInitializationStrategy> dismec::init::create_ova_primal_initializer(const std::shared_ptr<DatasetBase> &data, RegularizerSpec regularizer, LossType loss)