返回顶部
首页 > 资讯 > 精选 >pytorch实践线性模型3d源码分析
  • 515
分享到

pytorch实践线性模型3d源码分析

2023-07-06 01:07:29 515人浏览 独家记忆
摘要

这篇文章主要介绍“PyTorch实践线性模型3D源码分析”的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“pytorch实践线性模型3d源码分析”文章能帮助大家解决问题。y = wx +b通过meshg

这篇文章主要介绍“PyTorch实践线性模型3D源码分析”的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“pytorch实践线性模型3d源码分析”文章能帮助大家解决问题。

y = wx +b
通过meshgrid 得到两个二维矩阵
关键理解:
plot_surface需要的xyz是二维np数组
这里提前准备meshgrid来生产x和y需要的参数
下图的W和I即plot_surface需要xy

pytorch实践线性模型3d源码分析

Z即我们需要的权重损失
计算方式要和W,I. I的每行中内容是一样的就是y=wx+b的b是一样的

    fig = plt.figure()    ax = fig.add_axes(Axes3D(fig))    ax.plot_surface(W, I, Z=MSE_data)

总的实验代码

import matplotlib.pyplot as pltimport numpy as npfrom mpl_toolkits.mplot3d import Axes3Dclass LinearModel:    @staticmethod    def forward(w, x):        return w * x    @staticmethod    def forward_with_intercept(w, x, b):        return w * x + b    @staticmethod    def get_loss(w, x, y_origin, exp=2, b=None):        if b:            y = LinearModel.forward_with_intercept(w, x, b)        else:            y = LinearModel.forward(w, x)        return pow(y_origin - y, exp)def test_2d():    x_data = [1.0, 2.0, 3.0]    y_data = [2.0, 4.0, 6.0]    weight_data = []    MSE_data = []    # 设定实验的权重范围    for w in np.arange(0.0, 4.1, 0.1):        weight_data.append(w)        loss_total = 0        # 计算每个权重在数据集上的MSE平均平方方差        for x_val, y_val in zip(x_data, y_data):            loss_total += LinearModel.get_loss(w, x_val, y_val)        MSE_data.append(loss_total / len(x_data))    # 绘图    plt.xlabel("weight")    plt.ylabel("MSE")    plt.plot(weight_data, MSE_data)    plt.show()def test_3d():    x_data = [1.0, 2.0, 3.0]    y_data = [5.0, 8.0, 11.0]    weight_data = np.arange(0.0, 4.1, 0.1)    intercept_data = np.arange(0.0, 4.1, 0.1)    W, I = np.meshgrid(weight_data, intercept_data)    MSE_data = []    # 设定实验的权重范围 循环要先写截距的 meshgrid 的返回第二个是相当于41*41 同一行值相同 ,要在第二层循环去遍历权重    for intercept in intercept_data:        MSE_data_tmp = []        for w in weight_data:            loss_total = 0            # 计算每个权重在数据集上的MSE平均平方方差            for x_val, y_val in zip(x_data, y_data):                loss_total += LinearModel.get_loss(w, x_val, y_val, b=intercept)            MSE_data_tmp.append(loss_total / len(x_data))        MSE_data.append(MSE_data_tmp)    MSE_data = np.array(MSE_data)    fig = plt.figure()    ax = fig.add_axes(Axes3D(fig))    ax.plot_surface(W, I, Z=MSE_data)    plt.xlabel("weight")    plt.ylabel("intercept")    plt.show()if __name__ == '__main__':    test_2d()    test_3d()

关于“pytorch实践线性模型3d源码分析”的内容就介绍到这里了,感谢大家的阅读。如果想了解更多行业相关的知识,可以关注编程网精选频道,小编每天都会为大家更新不同的知识点。

--结束END--

本文标题: pytorch实践线性模型3d源码分析

本文链接: https://lsjlt.com/news/357291.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作