Horizontal DenseNet
Introduction
The Horizontal DenseNet model builds DenseNet, the model proposed in the paper “Densely Connected Convolutional Networks”, on the horizontal federated learning system. It is implemented on top of a deep learning framework.
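The training loop that a horizontal federation runs around a model like this one can be sketched as follows. This is a minimal, self-contained illustration under stated assumptions, not this framework's implementation. A toy one-parameter linear model stands in for DenseNet, every party holds the same feature but different samples, and aggregation uses the standard FedAvg rule (local models weighted by local sample count). The assumption that the assist_trainer performs the aggregation is also ours. The names `global_epoch` and `local_epoch` mirror the parameters listed below.

```python
import random


def local_train(w, data, local_epoch, lr=0.1):
    """Plain SGD on the toy model y = w * x for `local_epoch` passes over one party's data."""
    for _ in range(local_epoch):
        for x, y in data:
            grad = 2 * (w * x - y) * x  # derivative of (w * x - y) ** 2 with respect to w
            w -= lr * grad
    return w


def fedavg(weights, sizes):
    """FedAvg: average the local models, weighting each by its local sample count."""
    total = sum(sizes)
    return sum(w * n for w, n in zip(weights, sizes)) / total


def horizontal_training(parties, global_epoch, local_epoch):
    w_global = 0.0
    for _ in range(global_epoch):
        # Each party starts from the current global model and trains on its own samples.
        local_ws = [local_train(w_global, data, local_epoch) for data in parties]
        # The aggregating party (assumed here to be the assist_trainer) combines the local models.
        w_global = fedavg(local_ws, [len(data) for data in parties])
    return w_global


rng = random.Random(0)
TRUE_W = 3.0
# Horizontal split: every party has the same feature x but different samples.
parties = [[(x, TRUE_W * x) for x in [rng.uniform(-1, 1) for _ in range(n)]] for n in (20, 50)]
print(horizontal_training(parties, global_epoch=10, local_epoch=2))
```

Only model parameters travel between parties in this loop; the raw samples stay local, which is the point of the horizontal setting.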
Parameter List
identity: str Federated identity of the party; should be one of “label_trainer” or “assist_trainer”.
- model_info:
name: str Model name; should be “horizontal_densenet”.
- config:
num_classes: int Number of output classes.
layers: int Number of DenseNet layers; 121, 169, 201 and 264 are supported.
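To make the `layers` choices concrete, the sketch below maps each supported depth to the dense-block sizes reported in the DenseNet paper for its ImageNet variants, and computes how many channels reach the classifier (whose output size `num_classes` then sets). It assumes the paper's DenseNet-BC settings (growth rate 32, 64 initial channels, compression 0.5 in each transition layer); whether `horizontal_densenet` uses exactly these settings is an assumption, and `final_feature_width` is a hypothetical helper.

```python
# Dense-block sizes of the ImageNet DenseNet variants in the DenseNet paper.
BLOCK_CONFIGS = {
    121: (6, 12, 24, 16),
    169: (6, 12, 32, 32),
    201: (6, 12, 48, 32),
    264: (6, 12, 64, 48),
}


def final_feature_width(layers, growth_rate=32, init_features=64, compression=0.5):
    """Channel count fed to the classifier for a DenseNet-BC variant (assumed settings)."""
    if layers not in BLOCK_CONFIGS:
        raise ValueError(f"unsupported DenseNet depth {layers}; choose from {sorted(BLOCK_CONFIGS)}")
    channels = init_features
    blocks = BLOCK_CONFIGS[layers]
    for i, num_layers in enumerate(blocks):
        # Every layer in a dense block concatenates `growth_rate` new feature maps
        # onto all feature maps produced before it.
        channels += num_layers * growth_rate
        if i < len(blocks) - 1:
            # A transition layer between blocks compresses the channel count.
            channels = int(channels * compression)
    return channels


for depth in sorted(BLOCK_CONFIGS):
    print(depth, BLOCK_CONFIGS[depth], final_feature_width(depth))
```

Deeper variants only grow the third and fourth blocks, so the classifier width grows with them while the first two blocks stay the same.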
- input:
- trainset:
type: str File type of the training dataset, such as “npz”.
path: str Folder path of the training dataset.
name: str File name of the training dataset.
- valset:
type: str File type of the validation dataset, such as “npz”.
path: str Folder path of the validation dataset.
name: str File name of the validation dataset.
- output:
- model:
type: str File format of the output model.
path: str Folder path of the output model.
name: str File name of the output model.
- metrics:
type: str File format of the output metrics.
path: str Folder path of the output metrics.
header: bool Whether to include column names.
- evaluation:
type: str File format of the output evaluation results.
path: str Folder path of the output evaluation results.
header: bool Whether to include column names.
- train_info:
device: str Device on which the algorithm runs; supports “cpu” or a specific GPU device such as “cuda:0”.
- interaction_params:
save_frequency: int Number of epochs between two model saves.
save_probabilities: bool Whether to save the probabilities output by the model.
save_probabilities_bins_number: int Number of bins in the probability histogram.
write_training_prediction: bool Whether to save the predictions on the training set.
write_validation_prediction: bool Whether to save the predictions on the validation set.
echo_training_metrics: bool Whether to print the metrics on the training set.
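One plausible reading of `save_probabilities_bins_number` is the number of equal-width bins used to summarize predicted probabilities over [0, 1]. The sketch below shows that reading; the equal-width binning and the helper name `probability_histogram` are assumptions, not the framework's documented behavior.

```python
def probability_histogram(probs, bins_number):
    """Count probabilities in [0, 1] into `bins_number` equal-width bins (assumed scheme)."""
    counts = [0] * bins_number
    for p in probs:
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"probability out of range: {p}")
        # p == 1.0 would map to index bins_number, so clamp it into the last bin.
        idx = min(int(p * bins_number), bins_number - 1)
        counts[idx] += 1
    edges = [i / bins_number for i in range(bins_number + 1)]
    return edges, counts


edges, counts = probability_histogram([0.05, 0.2, 0.25, 0.9, 1.0], bins_number=4)
print(edges)   # [0.0, 0.25, 0.5, 0.75, 1.0]
print(counts)  # [2, 1, 0, 2]
```

Saving a histogram rather than every probability keeps the output size fixed regardless of how many samples are scored.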
- params:
global_epoch: int Number of global training epochs.
local_epoch: int Number of local training epochs of each participating party.
batch_size: int Batch size of samples in the local and global processes.
- aggregation_config:
type: str Aggregation method; supports “fedavg”, “fedprox” and “scaffold”.
- encryption:
method: str Encryption method; “otp” is recommended.
key_bitlength: int Key length of the one-time pad encryption; supports 64 and 128. 128 is recommended for better security.
data_type: str Input data type; supports “torch.Tensor” and “numpy.ndarray”, depending on the data type of the model.
- key_exchange:
key_bitlength: int Bit length of the Paillier key; a value greater than or equal to 2048 is recommended.
optimized: bool Whether to use the optimized method.
- csprng:
name: str Pseudo-random number generation method.
method: str Corresponding hash method.
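The general idea behind one-time-pad ("otp") secure aggregation is that each pair of parties derives the same random mask from a shared secret. The lower-indexed party adds the mask and the higher-indexed one subtracts it, so all masks cancel in the sum while every individual upload looks random. The sketch below illustrates this under loud assumptions. Floats are fixed-point encoded with a scale of 2**16. `key_bitlength` is taken as the size of the ring Z_(2^bits) in which masking happens. The pairwise seeds are drawn directly instead of coming from key exchange. Python's `random.Random` stands in for the configured CSPRNG and is NOT cryptographically secure. All helper names are hypothetical.

```python
import random

SCALE = 2 ** 16  # assumed fixed-point scale


def encode(x, bits):
    """Fixed-point encode a float into the ring Z_(2^bits)."""
    return round(x * SCALE) % (1 << bits)


def decode(v, bits):
    """Map a ring element back to a signed float."""
    if v >= 1 << (bits - 1):
        v -= 1 << bits
    return v / SCALE


def pair_mask(seed, bits):
    # Stand-in for a CSPRNG seeded with a shared secret; random.Random is NOT secure.
    return random.Random(seed).getrandbits(bits)


def mask_update(party, value, shared_seeds, bits):
    """Add the pairwise mask for higher-indexed peers, subtract it for lower-indexed ones."""
    masked = encode(value, bits)
    for peer, seed in shared_seeds[party].items():
        m = pair_mask(seed, bits)
        masked = (masked + m if party < peer else masked - m) % (1 << bits)
    return masked


def secure_sum(masked_values, bits):
    # Pairwise masks cancel, so the aggregator recovers only the sum.
    return decode(sum(masked_values) % (1 << bits), bits)


bits = 128
values = [0.5, -1.25, 2.0]
n = len(values)
# Hypothetical pairwise seeds; in practice they would come from key exchange.
seeds = {(i, j): random.SystemRandom().getrandbits(64) for i in range(n) for j in range(i + 1, n)}
shared = {i: {j: seeds[min(i, j), max(i, j)] for j in range(n) if j != i} for i in range(n)}
masked = [mask_update(i, v, shared, bits) for i, v in enumerate(values)]
print(secure_sum(masked, bits))  # 1.25
```

A weighted average such as FedAvg can be built on top of this by having each party mask its already-weighted update and dividing the recovered sum by the total weight.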
- optimizer_config: Supports optimizers, and their parameters, defined in PyTorch or registered by the user. For example:
- SGD:
lr: float Learning rate.
momentum: float Momentum factor.
weight_decay: float Weight decay rate.
- lr_scheduler_config: Supports learning rate schedulers, and their parameters, defined in PyTorch or registered by the user. For example:
- CosineAnnealingLR:
T_max: int Maximum number of iterations.
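The closed-form rule documented for PyTorch's `CosineAnnealingLR` is eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2, where eta_max is the optimizer's initial learning rate. The small function below (a hypothetical helper, not part of this framework) evaluates that formula, showing how `T_max` sets the length of one decay from the base rate down to `eta_min`.

```python
import math


def cosine_annealing_lr(base_lr, t, T_max, eta_min=0.0):
    """Closed-form cosine annealing learning rate at step t."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / T_max)) / 2


schedule = [round(cosine_annealing_lr(0.1, t, T_max=4), 4) for t in range(5)]
print(schedule)  # [0.1, 0.0854, 0.05, 0.0146, 0.0]
```

Past `T_max` the cosine term rises again, so with this formula the rate would climb back toward `base_lr` over the next `T_max` steps.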
- lossfunc_config: Supports loss functions, and their parameters, defined in PyTorch or registered by the user. For example:
- CrossEntropyLoss:
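The optimizer, scheduler, and loss-function configs share one pattern: a class name mapped to its keyword parameters, resolved either from PyTorch or from a user registry. The sketch below shows one way such a lookup could work. `build_from_config` and the lookup order are assumptions for illustration, not this framework's API. With PyTorch installed, `vars(torch.optim)` can serve as the built-in namespace, because `torch.optim.SGD` exists there.

```python
from typing import Any, Callable, Iterable, List, Mapping


def build_from_config(config: Mapping[str, Any],
                      namespaces: Iterable[Mapping[str, Callable[..., Any]]],
                      *args: Any) -> List[Any]:
    """Instantiate each `{name: params}` entry, searching the namespaces in order."""
    namespaces = list(namespaces)
    objects = []
    for name, params in config.items():
        for ns in namespaces:
            if name in ns:
                # Positional args (e.g. model parameters) come first, then config kwargs.
                objects.append(ns[name](*args, **(params or {})))
                break
        else:
            raise KeyError(f"{name!r} is neither a built-in nor a registered entry")
    return objects


# With PyTorch installed, for example:
#   import torch
#   optimizer, = build_from_config({"SGD": {"lr": 0.01, "momentum": 0.9}},
#                                  [vars(torch.optim)], model.parameters())
```

Searching the built-in namespace before the user registry means a registered entry cannot silently shadow a PyTorch class of the same name; reversing the order would allow deliberate overrides.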
- metric_config: Supports multiple metrics:
accuracy: Accuracy.
precision: Precision.
recall: Recall.
f1_score: F1 score.
auc: Area Under Curve.
ks: Kolmogorov-Smirnov (KS) statistic.
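For a binary classifier, the KS statistic is the largest gap between the true positive rate and the false positive rate over all score thresholds. The sketch below computes it directly from labels and scores. It is an illustration of the definition (with label 1 assumed to mark positives), not the framework's own metric code.

```python
def ks_statistic(labels, scores):
    """KS = max over thresholds of |TPR - FPR| for binary labels (1 = positive)."""
    pos = sum(1 for y in labels if y == 1)
    neg = len(labels) - pos
    if pos == 0 or neg == 0:
        raise ValueError("need both positive and negative samples")
    pairs = sorted(zip(scores, labels), key=lambda p: p[0], reverse=True)
    tp = fp = 0
    best = 0.0
    i = 0
    while i < len(pairs):
        # Lower the threshold past every sample that shares the current score.
        s = pairs[i][0]
        while i < len(pairs) and pairs[i][0] == s:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        best = max(best, abs(tp / pos - fp / neg))
    return best


print(ks_statistic([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.5
```

A KS of 1.0 means some threshold separates the classes perfectly; 0.0 means the score distributions of the two classes are indistinguishable at every threshold.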
- early_stopping:
key: str Metric monitored by the early stopping strategy, such as “acc”.
patience: int Patience of the early stopping strategy, i.e. how many times it tolerates no improvement before stopping.
delta: float Tolerance range of the early stopping strategy.
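A common way to combine `key`, `patience` and `delta` is to stop once the monitored metric has failed to improve on its best value by more than `delta` for `patience` consecutive evaluations. The class below sketches that convention. It is an assumption about the semantics, and the `higher_is_better` flag is a hypothetical addition for metrics such as loss.

```python
class EarlyStopping:
    """Stop when `key` has not improved by more than `delta` for `patience` checks."""

    def __init__(self, key="acc", patience=3, delta=0.0, higher_is_better=True):
        self.key, self.patience, self.delta = key, patience, delta
        # Flip the sign for metrics where lower is better, such as a loss.
        self.sign = 1.0 if higher_is_better else -1.0
        self.best = None
        self.bad_checks = 0

    def step(self, metrics):
        """Record one evaluation; return True if training should stop."""
        score = self.sign * metrics[self.key]
        if self.best is None or score > self.best + self.delta:
            self.best = score
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


es = EarlyStopping(key="acc", patience=2, delta=0.01)
print([es.step({"acc": a}) for a in (0.80, 0.85, 0.855, 0.851)])  # [False, False, False, True]
```

Here 0.855 counts as no improvement because it beats the best value of 0.85 by less than `delta`, so two such evaluations in a row trigger the stop.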