本教程详述了如何利用NVIDIA FLARE构建先进的联邦学习实验,并在非独立同分布(non-IID)的CIFAR-10数据集上比较FedAvg和FedProx两种主流算法。为模拟真实环境中客户端数据标签的不均衡,实验采用狄利克雷(Dirichlet)分布来划分数据。整个流程通过NVFlare的Job API定义并启动任务,Client API则处理本地训练、模型交换及客户端与服务器的通信。
首先,需安装必要的库,如nvflare和torch等。接着,设定关键实验参数:包括客户端数量(NUM_SITES=3)、通信轮次(NUM_ROUNDS=5)、本地训练周期(LOCAL_EPOCHS=1)、控制非IID程度的Dirichlet alpha值(ALPHA=0.3)、批处理大小(BATCH_SIZE=64)和学习率(LR=0.01)。CIFAR-10数据集被下载一次,供所有模拟客户端安全复用。
核心客户端脚本定义了一个小型卷积神经网络Net,专为CIFAR-10图像分类设计,特别移除了批量归一化层,以简化FedAvg的模型状态字典。脚本中还包含关键的dirichlet_partition函数,该函数使用狄利克雷分布实现确定性的非IID标签偏斜数据划分。通过在所有客户端进程中使用相同的随机种子,确保了对全局数据划分的一致性,从而有效模拟了各客户端数据标签分布的显著差异。这一设置使得能够直接可视化FedAvg和FedProx在相同分区数据上,其全局模型精度在不同通信轮次中的演变。