返回顶部
首页 > 资讯 > 后端开发 > Python >PyTorch实现FedProx联邦学习算法
  • 610
分享到

PyTorch实现FedProx联邦学习算法

2024-04-02 19:04:59 610人浏览 薄情痞子

Python 官方文档:入门教程 => 点击学习

摘要

目录I. 前言III. FedProx1. 模型定义2. 服务器端3. 客户端更新IV. 完整代码I. 前言 FedProx的原理请见:FedAvg联邦学习FedProx异质网络优化

I. 前言

FedProx的原理请见:FedAvg联邦学习FedProx异质网络优化实验总结

联邦学习中存在多个客户端,每个客户端都有自己的数据集,这个数据集他们是不愿意共享的。

数据集为某城市十个地区的风电功率,我们假设这10个地区的电力部门不愿意共享自己的数据,但是他们又想得到一个由所有数据统一训练得到的全局模型。

III. FedProx

算法伪代码:

1. 模型定义

客户端的模型为一个简单的四层神经网络模型:

# -*- coding:utf-8 -*-
"""
@Time: 2022/03/03 12:23
@Author: KI
@File: model.py
@Motto: Hungry And Humble
"""
from torch import nn
class ANN(nn.Module):
    def __init__(self, args, name):
        super(ANN, self).__init__()
        self.name = name
        self.len = 0
        self.loss = 0
        self.fc1 = nn.Linear(args.input_dim, 20)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout()
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, 20)
        self.fc4 = nn.Linear(20, 1)
    def forward(self, data):
        x = self.fc1(data)
        x = self.sigmoid(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        x = self.fc3(x)
        x = self.sigmoid(x)
        x = self.fc4(x)
        x = self.sigmoid(x)
        return x

2. 服务器端

服务器端和FedAvg一致,即重复进行客户端采样、参数传达、参数聚合三个步骤:

# -*- coding:utf-8 -*-
"""
@Time: 2022/03/03 12:50
@Author: KI
@File: server.py
@Motto: Hungry And Humble
"""
import copy
import random
import numpy as np
import torch
from model import ANN
from client import train, test
class FedProx:
    def __init__(self, args):
        self.args = args
        self.nn = ANN(args=self.args, name='server').to(args.device)
        self.nns = []
        for i in range(self.args.K):
            temp = copy.deepcopy(self.nn)
            temp.name = self.args.clients[i]
            self.nns.append(temp)
    def server(self):
        for t in range(self.args.r):
            print('round', t + 1, ':')
            # sampling
            m = np.max([int(self.args.C * self.args.K), 1])
            index = random.sample(range(0, self.args.K), m)  # st
            # dispatch
            self.dispatch(index)
            # local updating
            self.client_update(index, t)
            # aggregation
            self.aggregation(index)
        return self.nn
    def aggregation(self, index):
        s = 0
        for j in index:
            # nORMal
            s += self.nns[j].len
        params = {}
        for k, v in self.nns[0].named_parameters():
            params[k] = torch.zeros_like(v.data)
        for j in index:
            for k, v in self.nns[j].named_parameters():
                params[k] += v.data * (self.nns[j].len / s)
        for k, v in self.nn.named_parameters():
            v.data = params[k].data.clone()
    def dispatch(self, index):
        for j in index:
            for old_params, new_params in zip(self.nns[j].parameters(), self.nn.parameters()):
                old_params.data = new_params.data.clone()
    def client_update(self, index, global_round):  # update nn
        for k in index:
            self.nns[k] = train(self.args, self.nns[k], self.nn, global_round)
    def global_test(self):
        model = self.nn
        model.eval()
        for client in self.args.clients:
            model.name = client
            test(self.args, model)

3. 客户端更新

FedProx中客户端需要优化的函数为:

作者在FedAvg损失函数的基础上,引入了一个proximal term,我们可以称之为近端项。引入近端项后,客户端在本地训练后得到的模型参数 w将不会与初始时的服务器参数wt偏离太多。

对应的代码为:

def train(args, model, server, global_round):
    model.train()
    Dtr, Dte = nn_seq_wind(model.name, args.B)
    model.len = len(Dtr)
    global_model = copy.deepcopy(server)
    if args.weight_decay != 0:
        lr = args.lr * pow(args.weight_decay, global_round)
    else:
        lr = args.lr
    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                     weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                    momentum=0.9, weight_decay=args.weight_decay)
    print('training...')
    loss_function = nn.MSELoss().to(args.device)
    loss = 0
    for epoch in range(args.E):
        for (seq, label) in Dtr:
            seq = seq.to(args.device)
            label = label.to(args.device)
            y_pred = model(seq)
            optimizer.zero_grad()
            # compute proximal_term
            proximal_term = 0.0
            for w, w_t in zip(model.parameters(), global_model.parameters()):
                proximal_term += (w - w_t).norm(2)
            loss = loss_function(y_pred, label) + (args.mu / 2) * proximal_term
            loss.backward()
            optimizer.step()
        print('epoch', epoch, ':', loss.item())
    return model

我们在原有MSE损失函数的基础上加上了一个近端项:

for w, w_t in zip(model.parameters(), global_model.parameters()):
    proximal_term += (w - w_t).norm(2)

然后再反向传播求梯度,然后优化器step更新参数。

原始论文中还提出了一个不精确解的概念:

不过值得注意的是,我并没有在原始论文的实验部分找到如何选择 γ \gamma γ的说明。查了一下资料后发现是涉及到了近端梯度下降的知识,本文代码并没有考虑不精确解,后期可能会补上。

IV. 完整代码

链接:https://pan.baidu.com/s/1hj2EOcqIUmM-C6R1cyjE5Q 

提取码:fghp 

项目结构:

其中:

  • server.py为服务器端操作。
  • client.py为客户端操作。
  • data_process.py为数据处理部分。
  • model.py为模型定义文件。
  • args.py为参数定义文件。
  • main.py为主文件,如想要运行此项目可直接运行:
python main.py

以上就是PyTorch实现FedProx的联邦学习算法的详细内容,更多关于PyTorch实现FedProx算法的资料请关注编程网其它相关文章!

--结束END--

本文标题: PyTorch实现FedProx联邦学习算法

本文链接: https://lsjlt.com/news/117896.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • PyTorch实现FedProx联邦学习算法
    目录I. 前言III. FedProx1. 模型定义2. 服务器端3. 客户端更新IV. 完整代码I. 前言 FedProx的原理请见:FedAvg联邦学习FedProx异质网络优化...
    99+
    2024-04-02
  • PyTorch怎么实现FedProx联邦学习算法
    这篇文章主要介绍了PyTorch怎么实现FedProx联邦学习算法的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇PyTorch怎么实现FedProx联邦学习算法文章都会有所收获,下面我们一起来看看吧。I. 前言...
    99+
    2023-06-30
  • PyTorch实现联邦学习的基本算法FedAvg
    目录I. 前言II. 数据介绍特征构造III. 联邦学习1. 整体框架2. 服务器端3. 客户端IV. 代码实现1. 初始化2. 服务器端3. 客户端4. 测试V. 实验及结果VI....
    99+
    2024-04-02
  • FedAvg联邦学习FedProx异质网络优化实验总结
    目录前言I. FedAvgII. FedProxIII. 实验IV. 总结前言 题目: Federated Optimization for Heterogeneous Netwo...
    99+
    2024-04-02
  • 联邦学习神经网络FedAvg算法实现
    目录I. 前言II. 数据介绍1. 特征构造III. 联邦学习1. 整体框架2. 服务器端3. 客户端4. 代码实现4.1 初始化4.2 服务器端4.3 客户端4.4 测试IV. 实...
    99+
    2024-04-02
  • 使用Pytorch实现强化学习——DQN算法
    目录 一、强化学习的主要构成 二、基于python的强化学习框架 三、gym 四、DQN算法 1.经验回放 2.目标网络 五、使用pytorch实现DQN算法 1.replay memory 2.神经网络部分 3.Agent 4.模型训练...
    99+
    2023-09-24
    python 开发语言
  • 联邦学习算法介绍-FedAvg详细案例-Python代码获取
    联邦学习算法介绍-FedAvg详细案例-Python代码获取 一、联邦学习系统框架二、联邦平均算法(FedAvg)三、联邦随梯度下降算法 (FedSGD)四、差分隐私随联邦梯度下降算法 (DP...
    99+
    2023-08-31
    python 算法 机器学习
  • javascript算法学习实现代码
    排序 var len = 100000; var i; var arr = []; for(i=0; i...
    99+
    2022-11-21
    javascript 算法学习
  • pyTorch深度学习softmax实现解析
    目录用PyTorch实现linear模型模拟数据集定义模型加载数据集optimizer模型训练softmax回归模型Fashion-MNISTcross_entropy模型的实现利用...
    99+
    2024-04-02
  • Pytorch实现LSTM案例总结学习
    目录前言模型构建部分主要工作1、构建网络层、前向传播forward()2、实例化网络,定义损失函数和优化器3、训练模型、反向传播backward()4、测试模型前言 关键步骤主要分为...
    99+
    2024-04-02
  • PyTorch中如何实现自监督学习
    自监督学习是一种无需人工标注数据的学习方法,通过模型自身生成标签或目标来进行训练。在PyTorch中,可以通过以下几种方式实现自监督...
    99+
    2024-03-05
    PyTorch
  • pyTorch深入学习梯度和Linear Regression实现
    目录梯度线性回归(linear regression)模拟数据集加载数据集定义loss_function梯度 PyTorch的数据结构是tensor,它有个属性叫做requires_...
    99+
    2024-04-02
  • 如何在PyTorch中实现半监督学习
    在PyTorch中实现半监督学习可以使用一些已有的半监督学习方法,比如自训练(self-training)、伪标签(pseudo-l...
    99+
    2024-03-05
    PyTorch
  • Python实现机器学习算法的分类
    Python算法的分类 对葡萄酒数据集进行测试,由于数据集是多分类且数据的样本分布不平衡,所以直接对数据测试,效果不理想。所以使用SMOTE过采样对数据进行处理,对数据去重,去空,处...
    99+
    2024-04-02
  • python如何实现感知器学习算法
    这篇文章主要介绍python如何实现感知器学习算法,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!我们将研究一种判别式分类方法,其中直接学习评估 g(x)所需的 w 参数。我们将使用感知器学习算法。感知器学习算法很容易...
    99+
    2023-06-29
  • PyTorch深度学习实战(5)——计算机视觉基础
    PyTorch深度学习实战(5)——计算机视觉基础 0. 前言 1. 图像表示 2. 将图像转换为结构化数组 2.1 灰度图像表示 ...
    99+
    2023-09-07
    深度学习 pytorch 计算机视觉 原力计划
  • 利用PyTorch实现爬山算法
    目录0. 前言1. 使用 PyTorch 实现爬山算法1.1 爬山算法简介1.2 使用爬山算法进行 CartPole 游戏2. 改进爬山算法0. 前言 在随机搜索策略中,每个回合都是...
    99+
    2024-04-02
  • pytorch机器学习softmax回归的简洁实现
    目录初始化模型参数重新审视softmax的实现优化算法通过深度学习框架的高级API也能更方便地实现分类模型。让我们继续使用Fashion-MNIST数据集,并保持批量大小为256。 ...
    99+
    2024-04-02
  • pyTorch深度学习多层感知机的实现
    目录激活函数多层感知机的PyTorch实现激活函数 前两节实现的传送门 pyTorch深度学习softmax实现解析 pyTorch深入学习梯度和Linear Regression实...
    99+
    2024-04-02
  • PyTorch中如何实现模型的集成学习
    在PyTorch中实现模型的集成学习,可以通过以下步骤进行: 定义多个模型:首先需要定义多个不同的模型,可以是同一种模型的不同实例...
    99+
    2024-03-06
    PyTorch
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作