返回顶部
首页 > 资讯 > 后端开发 > Python >PyTorch实现联邦学习的基本算法FedAvg
  • 180
分享到

PyTorch实现联邦学习的基本算法FedAvg

2024-04-02 19:04:59 180人浏览 八月长安

Python 官方文档:入门教程 => 点击学习

摘要

目录I. 前言II. 数据介绍特征构造III. 联邦学习1. 整体框架2. 服务器端3. 客户端IV. 代码实现1. 初始化2. 服务器端3. 客户端4. 测试V. 实验及结果VI.

I. 前言

在之前的一篇博客联邦学习基本算法FedAvg的代码实现中利用numpy手搭神经网络实现了FedAvg,手搭的神经网络效果已经很好了,不过这还是属于自己造轮子,建议优先使用PyTorch来实现。

II. 数据介绍

联邦学习中存在多个客户端,每个客户端都有自己的数据集,这个数据集他们是不愿意共享的。

本文选用的数据集为中国北方某城市十个区/县从2016年到2019年三年的真实用电负荷数据,采集时间间隔为1小时,即每一天都有24个负荷值。

我们假设这10个地区的电力部门不愿意共享自己的数据,但是他们又想得到一个由所有数据统一训练得到的全局模型。

除了电力负荷数据以外,还有一个备选数据集:风功率数据集。两个数据集通过参数type指定:type == 'load’表示负荷数据,'wind’表示风功率数据。

特征构造

用某一时刻前24个时刻的负荷值以及该时刻的相关气象数据(如温度、湿度、压强等)来预测该时刻的负荷值。

对于风功率数据,同样使用某一时刻前24个时刻的风功率值以及该时刻的相关气象数据来预测该时刻的风功率值。

各个地区应该就如何制定特征集达成一致意见,本文使用的各个地区上的数据的特征是一致的,可以直接使用。

III. 联邦学习

1. 整体框架

原始论文中提出的FedAvg的框架为:

在这里插入图片描述

客户端模型采用PyTorch搭建:

class ANN(nn.Module):
    def __init__(self, input_dim, name, B, E, type, lr):
        super(ANN, self).__init__()
        self.name = name
        self.B = B
        self.E = E
        self.len = 0
        self.type = type
        self.lr = lr
        self.loss = 0
        self.fc1 = nn.Linear(input_dim, 20)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout()
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, 20)
        self.fc4 = nn.Linear(20, 1)
    def forward(self, data):
        x = self.fc1(data)
        x = self.sigmoid(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        x = self.fc3(x)
        x = self.sigmoid(x)
        x = self.fc4(x)
        x = self.sigmoid(x)
        return x

2. 服务器端

服务器端执行以下步骤:

简单来说,每一轮通信时都只是选择部分客户端,这些客户端利用本地的数据进行参数更新,然后将更新后的参数传给服务器,服务器汇总客户端更新后的参数形成最新的全局参数。下一轮通信时,服务器端将最新的参数分发给被选中的客户端,进行下一轮更新。

3. 客户端

客户端没什么可说的,就是利用本地数据对神经网络模型的参数进行更新。

IV. 代码实现

1. 初始化

class FedAvg:
    def __init__(self, options):
        self.C = options['C']
        self.E = options['E']
        self.B = options['B']
        self.K = options['K']
        self.r = options['r']
        self.input_dim = options['input_dim']
        self.type = options['type']
        self.lr = options['lr']
        self.clients = options['clients']
        self.nn = ANN(input_dim=self.input_dim, name='server', B=B, E=E, type=self.type, lr=self.lr).to(device)
        self.nns = []
        for i in range(K):
            temp = copy.deepcopy(self.nn)
            temp.name = self.clients[i]
            self.nns.append(temp)

参数:

  • K,客户端数量,本文为10个,也就是10个地区。
  • C:选择率,每一轮通信时都只是选择C * K个客户端。
  • E:客户端更新本地模型的参数时,在本地数据集上训练E轮。
  • B:客户端更新本地模型的参数时,本地数据集batch大小为B
  • r:服务器端和客户端一共进行r轮通信。
  • clients:客户端集合
  • type:指定数据类型,负荷预测or风功率预测。
  • lr:学习率。
  • input_dim:数据输入维度。
  • nn:全局模型。
  • nns: 客户端模型集合。

2. 服务器端

服务器端代码如下:

def server(self):
     for t in range(self.r):
          print('第', t + 1, '轮通信:')
          m = np.max([int(self.C * self.K), 1])
          # sampling
          index = random.sample(range(0, self.K), m)
          # dispatch
          self.dispatch(index)
          # local updating
          self.client_update(index)
          # aggregation
          self.aggregation(index)
     # return global model
     return self.nn

其中client_update(index):

def client_update(self, index):  # update nn
     for k in index:
          self.nns[k] = train(self.nns[k])

aggregation(index):

def aggregation(self, index):
     s = 0
     for j in index:
          # nORMal
          s += self.nns[j].len
     params = {}
     with torch.no_grad():
          for k, v in self.nns[0].named_parameters():
               params[k] = copy.deepcopy(v)
               params[k].zero_()
     for j in index:
          with torch.no_grad():
               for k, v in self.nns[j].named_parameters():
                    params[k] += v * (self.nns[j].len / s)
     with torch.no_grad():
          for k, v in self.nn.named_parameters():
               v.copy_(params[k])

dispatch(index):

def dispatch(self, index):
     params = {}
     with torch.no_grad():
          for k, v in self.nn.named_parameters():
               params[k] = copy.deepcopy(v)
     for j in index:
          with torch.no_grad():
               for k, v in self.nns[j].named_parameters():
                    v.copy_(params[k])

下面对重要代码进行分析:

客户端的选择

m = np.max([int(self.C * self.K), 1])
index = random.sample(range(0, self.K), m)

index中存储中m个0~10间的整数,表示被选中客户端的序号。

客户端的更新

for k in index:
    self.client_update(self.nns[k])

服务器端汇总客户端模型的参数

关于模型汇总方式,可以参考一下我的另一篇文章:对FedAvg中模型聚合过程的理解。

当然,这只是一种很简单的汇总方式,还有一些其他类型的汇总方式。

论文Electricity Consumer Characteristics Identification: A Federated Learning Approach中总结了三种汇总方式:

normal:原始论文中的方式,即根据样本数量来决定客户端参数在最终组合时所占比例。

LA:根据客户端模型的损失占所有客户端损失和的比重来决定最终组合时参数所占比例。

LS:根据损失与样本数量的乘积所占的比重来决定。 将更新后的参数分发给被选中的客户端

def dispatch(self, index):
     params = {}
     with torch.no_grad():
          for k, v in self.nn.named_parameters():
               params[k] = copy.deepcopy(v)
     for j in index:
          with torch.no_grad():
               for k, v in self.nns[j].named_parameters():
                    v.copy_(params[k])

3. 客户端

客户端只需要利用本地数据来进行更新就行了:

def client_update(self, index):  # update nn
     for k in index:
          self.nns[k] = train(self.nns[k])

其中train():

def train(ann):
    ann.train()
    # print(p)
    if ann.type == 'load':
        Dtr, Dte = nn_seq(ann.name, ann.B, ann.type)
    else:
        Dtr, Dte = nn_seq_wind(ann.named, ann.B, ann.type)
    ann.len = len(Dtr)
    # print(len(Dtr))
    loss_function = nn.MSELoss().to(device)
    loss = 0
    optimizer = torch.optim.Adam(ann.parameters(), lr=ann.lr)
    for epoch in range(ann.E):
        cnt = 0
        for (seq, label) in Dtr:
            cnt += 1
            seq = seq.to(device)
            label = label.to(device)
            y_pred = ann(seq)
            loss = loss_function(y_pred, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('epoch', epoch, ':', loss.item())
    return ann

4. 测试

def global_test(self):
     model = self.nn
     model.eval()
     c = clients if self.type == 'load' else clients_wind
     for client in c:
          model.name = client
          test(model)

V. 实验及结果

本次实验的参数选择为:

KCEBr
100.550505
if __name__ == '__main__':
    K, C, E, B, r = 10, 0.5, 50, 50, 5
    type = 'load'
    input_dim = 30 if type == 'load' else 28
    _client = clients if type == 'load' else clients_wind
    lr = 0.08
    options = {'K': K, 'C': C, 'E': E, 'B': B, 'r': r, 'type': type, 'clients': _client,
               'input_dim': input_dim, 'lr': lr}
    fedavg = FedAvg(options)
    fedavg.server()
    fedavg.global_test()

各个客户端单独训练(训练50轮,batch大小为50)后在本地的测试集上的表现为:

客户端编号12345678910
MAPE / %5.334.113.034.203.022.702.942.992.304.10

可以看到,由于各个客户端的数据都十分充足,所以每个客户端自己训练的本地模型的预测精度已经很高了。

服务器与客户端通信5轮后,服务器上的全局模型在10个客户端测试集上的表现如下所示:

客户端编号12345678910
MAPE / %6.844.543.565.113.754.474.303.903.154.58

可以看到,经过联邦学习框架得到全局模型在各个客户端上表现同样很好ÿ0c;这是因为十个地区上的数据分布类似。

给出numpy和PyTorch的对比:

客户端编号12345678910
本地5.334.113.034.203.022.702.942.992.304.10
numpy6.584.193.175.133.584.694.713.752.944.77
PyTorch6.844.543.565.113.754.474.303.903.154.58

同样本地模型的效果是最好的,PyTorch搭建的网络和numpy搭建的网络效果差不多,但推荐使用PyTorch,不要造轮子。

VI. 源码及数据

我把数据和代码放在了GitHub上:源码及数据,原创不易,下载时请随手给个follow和star,感谢!

以上就是PyTorch实现联邦学习的基本算法FedAvg的详细内容,更多关于PyTorch实现FedAvg算法的资料请关注编程网其它相关文章!

--结束END--

本文标题: PyTorch实现联邦学习的基本算法FedAvg

本文链接: https://lsjlt.com/news/117904.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • PyTorch实现联邦学习的基本算法FedAvg
    目录I. 前言II. 数据介绍特征构造III. 联邦学习1. 整体框架2. 服务器端3. 客户端IV. 代码实现1. 初始化2. 服务器端3. 客户端4. 测试V. 实验及结果VI....
    99+
    2024-04-02
  • 联邦学习神经网络FedAvg算法实现
    目录I. 前言II. 数据介绍1. 特征构造III. 联邦学习1. 整体框架2. 服务器端3. 客户端4. 代码实现4.1 初始化4.2 服务器端4.3 客户端4.4 测试IV. 实...
    99+
    2024-04-02
  • PyTorch实现FedProx联邦学习算法
    目录I. 前言III. FedProx1. 模型定义2. 服务器端3. 客户端更新IV. 完整代码I. 前言 FedProx的原理请见:FedAvg联邦学习FedProx异质网络优化...
    99+
    2024-04-02
  • PyTorch怎么实现FedProx联邦学习算法
    这篇文章主要介绍了PyTorch怎么实现FedProx联邦学习算法的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇PyTorch怎么实现FedProx联邦学习算法文章都会有所收获,下面我们一起来看看吧。I. 前言...
    99+
    2023-06-30
  • PyTorch怎么实现基本算法FedAvg
    本文小编为大家详细介绍“PyTorch怎么实现基本算法FedAvg”,内容详细,步骤清晰,细节处理妥当,希望这篇“PyTorch怎么实现基本算法FedAvg”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。数据介绍联...
    99+
    2023-06-30
  • 联邦学习算法介绍-FedAvg详细案例-Python代码获取
    联邦学习算法介绍-FedAvg详细案例-Python代码获取 一、联邦学习系统框架二、联邦平均算法(FedAvg)三、联邦随梯度下降算法 (FedSGD)四、差分隐私随联邦梯度下降算法 (DP...
    99+
    2023-08-31
    python 算法 机器学习
  • FedAvg联邦学习FedProx异质网络优化实验总结
    目录前言I. FedAvgII. FedProxIII. 实验IV. 总结前言 题目: Federated Optimization for Heterogeneous Netwo...
    99+
    2024-04-02
  • 联邦学习FedAvg中模型聚合过程的理解分析
    目录问题聚合1. 聚合所有客户端2. 仅聚合被选中的客户端3. 选择问题 联邦学习原始论文中给出的FedAvg的算法框架为: 参数介绍: K 表示客户端的个数, B表示每一次本地更...
    99+
    2024-04-02
  • 使用Pytorch实现强化学习——DQN算法
    目录 一、强化学习的主要构成 二、基于python的强化学习框架 三、gym 四、DQN算法 1.经验回放 2.目标网络 五、使用pytorch实现DQN算法 1.replay memory 2.神经网络部分 3.Agent 4.模型训练...
    99+
    2023-09-24
    python 开发语言
  • PyTorch深度学习实战(5)——计算机视觉基础
    PyTorch深度学习实战(5)——计算机视觉基础 0. 前言 1. 图像表示 2. 将图像转换为结构化数组 2.1 灰度图像表示 ...
    99+
    2023-09-07
    深度学习 pytorch 计算机视觉 原力计划
  • javascript算法学习实现代码
    排序 var len = 100000; var i; var arr = []; for(i=0; i...
    99+
    2022-11-21
    javascript 算法学习
  • Python机器学习之基于Pytorch实现猫狗分类
    目录一、环境配置二、数据集的准备三、猫狗分类的实例四、实现分类预测测试五、参考资料一、环境配置 安装Anaconda 具体安装过程,请点击本文 配置Pytorch pip install -i https://...
    99+
    2022-06-02
    Pytorch实现猫狗分类 Python Pytorch
  • avaScript基础学习-基本的语法规则
    目录一、运算符二、分支语句三、循环语句四、异常的捕获与处理五、js中的this关键字六、let与const定义变量使用规则七、js中的void链接八、异步编程setTimeout九、...
    99+
    2024-04-02
  • 入门学习Go的基本语法
    目录1. 变量与常量Golang 中的标识符与关键字Golang 中的变量Golang 中的常量Golang 中的iota常量计数器2. 基本数据类型Golang 中的整型Golan...
    99+
    2024-04-02
  • 学习CSS的基本框架构建原理与实现方法
    随着互联网的快速发展,网页的设计越来越受到重视。而CSS作为网页设计的重要部分之一,其制作网页基本框架的原理和实现方法也就备受关注了。本文将通过具体代码示例讲解CSS制作网页基本框架的原理与实现方法。 一、HTML和CSS基本语...
    99+
    2024-01-16
    CSS 网页 基本框架
  • Python实现机器学习算法的分类
    Python算法的分类 对葡萄酒数据集进行测试,由于数据集是多分类且数据的样本分布不平衡,所以直接对数据测试,效果不理想。所以使用SMOTE过采样对数据进行处理,对数据去重,去空,处...
    99+
    2024-04-02
  • Golang简介与基本语法的学习
    目录一、什么是Golang?二、安装Golang三、编写Hello World程序四、基本语法4.1 变量4.2 数组和切片4.3 控制流五、并发编程一、什么是Golang? Gol...
    99+
    2023-05-16
    Golang简介 Golang基本语法
  • Python机器学习之如何基于Pytorch实现猫狗分类
    这篇文章给大家分享的是有关Python机器学习之如何基于Pytorch实现猫狗分类的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。一、环境配置安装Anaconda具体安装过程,请点击本文配置Pytorchpip&n...
    99+
    2023-06-15
  • pytorch机器学习softmax回归的简洁实现
    目录初始化模型参数重新审视softmax的实现优化算法通过深度学习框架的高级API也能更方便地实现分类模型。让我们继续使用Fashion-MNIST数据集,并保持批量大小为256。 ...
    99+
    2024-04-02
  • pyTorch深度学习多层感知机的实现
    目录激活函数多层感知机的PyTorch实现激活函数 前两节实现的传送门 pyTorch深度学习softmax实现解析 pyTorch深入学习梯度和Linear Regression实...
    99+
    2024-04-02
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作