Horizontal Bert
Introduction
Horizontal Bert is the BERT model proposed in the paper “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding”, built on the horizontal federation system. It targets sentiment analysis tasks in our scenario. “bert” is implemented with the TensorFlow framework, and “bert_torch” is implemented with the PyTorch framework.
Parameter List
identity: str Federated identity of the party, should be one of “label_trainer” or “assist_trainer”.
- model_info:
name: str Model name, should be “horizontal_bert” or “horizontal_bert_torch”.
- config:
from_pretrained: bool Whether to use a pretrained model. Only True is supported.
num_labels: int Number of output labels.
- input:
- trainset:
type: str Train dataset file type, such as “tsv”.
path: str Folder path of the train dataset.
name: str File name of the train dataset.
- valset:
type: str Validation dataset file type, such as “tsv”.
path: str Folder path of the validation dataset.
name: str File name of the validation dataset.
- output:
- model:
type: str Model output format, supports “file”.
path: str Folder path of the output model.
name: str File name of the output model.
- metrics:
type: str Metrics output file format.
path: str Folder path of the output metrics.
header: bool Whether to include the column names.
- evaluation:
type: str Evaluation output file format.
path: str Folder path of the output evaluation.
header: bool Whether to include the column names.
- train_info:
device: str Device on which the algorithm runs; supports “cpu” and a specific GPU device such as “cuda:0”.
- interaction_params:
save_frequency: int Interval, in epochs, between model saves.
save_probabilities: bool Whether to save the probabilities output by the model.
save_probabilities_bins_number: int Number of bins of the probability histogram.
write_training_prediction: bool Whether to save the predictions on the training set.
write_validation_prediction: bool Whether to save the predictions on the validation set.
echo_training_metrics: bool Whether to print the metrics on the training set.
- params:
global_epoch: int Global training epoch.
local_epoch: int Local training epoch of the involved parties.
batch_size: int Batch size of samples in the local and global process.
- aggregation_config:
type: str Aggregation method, supports “fedavg”, “fedprox” and “scaffold”.
- encryption:
method: str Encryption method, “otp” is recommended.
key_bitlength: int Key length of the one-time pad encryption, supports 64 and 128. 128 is recommended for better security.
data_type: str Input data type, supports “numpy.ndarray” for TensorFlow and “torch.Tensor” for PyTorch.
- key_exchange:
key_bitlength: int Bit length of the Paillier key, recommended to be greater than or equal to 2048.
optimized: bool Whether to use the optimized method.
- csprng:
name: str Pseudo-random number generation method.
method: str Corresponding hash method.
- optimizer_config: Supports the optimizers and their parameters defined in TensorFlow or PyTorch, or registered by the user. For example:
- Adam:
lr: float Learning rate.
epsilon: float Epsilon.
clipnorm: float Clipnorm.
- lr_scheduler_config: Supports the learning rate schedulers and their parameters defined in TensorFlow or PyTorch, or registered by the user. For example:
- CosineAnnealingLR:
T_max: int Maximum iterations.
- lossfunc_config: Supports the loss functions and their parameters defined in TensorFlow or PyTorch, or registered by the user. For example:
- SparseCategoricalCrossentropy:
- metric_config: Supports multiple metrics.
accuracy: Accuracy.
precision: Precision.
recall: Recall.
f1_score: F1 score.
auc: Area Under Curve.
ks: Kolmogorov-Smirnov (KS) Statistics.
- early_stopping:
key: str Indicator used by the early stopping strategy, such as “acc”.
patience: int Tolerance number of the early stopping strategy.
delta: float Tolerance range of the early stopping strategy.
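
The parameters above can be gathered into a single configuration object. The following is a minimal sketch written as a Python dictionary for the TensorFlow variant. The key names are taken from the parameter list, but the nesting, the file paths, and every value (including the csprng settings) are illustrative assumptions rather than defaults of the operator. The helper check_config is hypothetical and only tests a few constraints stated in the list.

```python
import json

# Illustrative only: key names follow the parameter list above; the nesting
# and all values are assumptions, not documented defaults.
config = {
    "identity": "label_trainer",
    "model_info": {
        "name": "horizontal_bert",
        "config": {"from_pretrained": True, "num_labels": 3},
    },
    "input": {
        "trainset": {"type": "tsv", "path": "./data", "name": "train.tsv"},
        "valset": {"type": "tsv", "path": "./data", "name": "val.tsv"},
    },
    "output": {
        "model": {"type": "file", "path": "./out", "name": "bert_model"},
        "metrics": {"type": "csv", "path": "./out", "header": True},
        "evaluation": {"type": "csv", "path": "./out", "header": True},
    },
    "train_info": {
        "device": "cpu",
        "interaction_params": {
            "save_frequency": 1,
            "save_probabilities": True,
            "save_probabilities_bins_number": 10,
            "write_training_prediction": True,
            "write_validation_prediction": True,
            "echo_training_metrics": True,
        },
        "params": {
            "global_epoch": 2,
            "local_epoch": 1,
            "batch_size": 32,
            "aggregation_config": {
                "type": "fedavg",
                "encryption": {
                    "method": "otp",
                    "key_bitlength": 128,
                    "data_type": "numpy.ndarray",
                    "key_exchange": {"key_bitlength": 3072, "optimized": True},
                    "csprng": {"name": "hmac_drbg", "method": "sha512"},
                },
            },
            "optimizer_config": {
                "Adam": {"lr": 5e-5, "epsilon": 1e-8, "clipnorm": 1.0}
            },
            "lossfunc_config": {"SparseCategoricalCrossentropy": {}},
            "metric_config": {"accuracy": {}, "f1_score": {}},
            "early_stopping": {"key": "acc", "patience": 3, "delta": 0.001},
        },
    },
}


def check_config(cfg):
    """Hypothetical helper: return the names of fields that violate
    constraints stated in the parameter list."""
    problems = []
    # The underscore spelling of both identities is assumed here.
    if cfg["identity"] not in ("label_trainer", "assist_trainer"):
        problems.append("identity")
    if cfg["model_info"]["name"] not in ("horizontal_bert", "horizontal_bert_torch"):
        problems.append("model_info.name")
    if cfg["model_info"]["config"]["from_pretrained"] is not True:
        problems.append("model_info.config.from_pretrained")
    enc = cfg["train_info"]["params"]["aggregation_config"]["encryption"]
    if enc["method"] == "otp" and enc["key_bitlength"] not in (64, 128):
        problems.append("encryption.key_bitlength")
    return problems


if __name__ == "__main__":
    print(check_config(config))
    print(json.dumps(config, indent=2))
```

For the PyTorch variant, the same sketch would use “horizontal_bert_torch” as the model name and “torch.Tensor” as the encryption data_type, with an optimizer and loss function defined in PyTorch.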