Torch 模型 感受野可视化

前言:感受野是卷积神经网络 (CNN) 中一个重要的概念,它表示 CNN 每一层输出的特征图上的像素点在输入图像上映射的区域。感受野的大小和形状直接影响到网络对输入图像的感知范围和精度,进而调整网络结构、卷积核大小和步长等参数,以改善网络的性能。

效果:本文的实验在 torchvision.models 中的 resnet18 上进行,分别绘制了理论感受野、训练前感受野、训练后感受野

5db41ff89046413db29b9d2546c6e5b9.png

开发环境:PyTorch 1.9.0

适用模型:最大池化层使用 nn.MaxPool 而不是 torch.nn.functional.max_pool 的模型

声明:本文所使用代码不开源,觉得本文的思路可行的话,请加 QQ - 1398173074 购买 (¥40,注明来意)

商品仅包含一份 120+ 行的代码。本文所使用的代码基于 torch、matplotlib 以及其它标准库。其中包含一个名为 ReceptiveField 的类,用于绘制 CNN 的感受野

代码实现

ReceptiveField 提供了以下函数:

  • _replace:将 MaxPool (这种求最大值的操作会影响感受野的正确性) 替换为 AvgPool
  • __init__:注册前向传播的“挂钩”,用于提取目标层的特征图用于反向传播
  • _backward:前向推导图像,利用“挂钩”获取特征图,从特征图中心点反向传播梯度,进行一系列处理后将梯度图转换为感受野图
  • theoretical:结合 _backward 函数求解理论感受野,其结果经过 sum、sqrt 之后即为理论感受野的尺寸
  • effective:默认情况下结合 _backward 函数求解训练前感受野 (即随机权重的模型);给定 state_dict 时将加载权重,求解训练后的感受野
  • compare:使用 matplotlib 绘制理论感受野、训练前感受野、训练后感受野
class ReceptiveField:
    """ :param model: 需要进行可视化的模型
        :param tar_layer: 感兴趣的层, 其所输出特征图需有 4 个维度 [B, C, H, W]
        :param img_size: 测试时使用的图像尺寸
        :cvar n_sample: 生成的随机图像的数量, 详见 effective 方法"""
    n_sample = 8

    def __init__(self,
                 model: nn.Module,
                 tar_layer: Union[int, nn.Module],
                 img_size: Union[int, Tuple[int, int]],
                 in_channels: int = 3,
                 use_cuda: bool = False,
                 use_copy: bool = False): ...

    def compare(self, theoretical=True, original=True, state_dict=None, **imshow_kw):
        """ :param theoretical: 是否绘制理论感受野
            :param original: 是否绘制训练前的感受野
            :param state_dict: 模型权值, 如果提供则绘制训练后的感受野"""

    def effective(self, state_dict=None):
        """ :param state_dict: 模型权值, 如果提供则绘制训练后的感受野"""

    def theoretical(self, light=1.):
        """ :param light: 理论感受野的亮度 [0, 1]"""

    def _replace(self, model): ...

    def _backward(self, x): ...

在本文的示例中,对 resnet18 的 layer3 进行了可视化,并计算出理论感受野的尺寸为 211×211

if __name__ == "__main__":
    from torchvision.models import resnet18

    # Step 1: 刚完成初始化的模型, 权重<完全随机>, 表 "训练前"
    m = resnet18()

    # Step 2: 训练完成后的 state_dict, 等待 ReceptiveField 加载
    state_dict = resnet18(pretrained=True).state_dict()

    # Step 3: 绘制感受野 (设置 ReceptiveField 的 use_copy=True, 将创建模型的深拷贝副本)
    with ReceptiveField(m, tar_layer=m.layer3, img_size=256, use_copy=True) as r:
        r.compare(state_dict=state_dict)
        # 理论感受野的尺寸
        s = round(r.theoretical().sum() ** 0.5)
        print(f"Theoretical RF: {s}×{s}")
    plt.show()

    # Step 4: 加载模型的参数
    m.load_state_dict(state_dict)

如果将 resnet18 中的某一个卷积改成空洞卷积,感受野将进一步增大到 243×243

if __name__ == "__main__":
    from torchvision.models import resnet18

    # Step 1: 刚完成初始化的模型, 权重<完全随机>, 表 "训练前"
    m = resnet18()
    print(m)
    m.layer3[1].conv1.dilation = 2
    m.layer3[1].conv1.padding = 2

    # Step 2: 训练完成后的 state_dict, 等待 ReceptiveField 加载
    state_dict = resnet18(pretrained=True).state_dict()

    # Step 3: 绘制感受野 (设置 ReceptiveField 的 use_copy=True, 将创建模型的深拷贝副本)
    with ReceptiveField(m, tar_layer=m.layer3, img_size=256, use_copy=True) as r:
        r.compare(state_dict=state_dict)
        # 理论感受野的尺寸
        s = round(r.theoretical().sum() ** 0.5)
        print(f"Theoretical RF: {s}×{s}")
    plt.show()

    # Step 4: 加载模型的参数
    m.load_state_dict(state_dict)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/556029.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

后端-MySQL-week11 多表查询

tips: distinct————紧跟“select”之后&#xff0c;用于去重 多表查询 概述 一对多&#xff08;多对一&#xff09; 多对多 一对一 多表查询概述 分类 连接查询 内连接 外连接 自连接 必须起别名&#xff01; 联合查询-union&#xff0c;union all 子查询 概念 分类 …

OpenMesh 极小曲面(局部迭代法)

文章目录 一、简介二、实现代码三、实现效果参考资料一、简介 我们的目标是想得到一个曲率处处为0的曲面,具体操作如下所述: 二、实现代码 #define _USE_MATH_DEFINES #include

量子时代加密安全与区块链应用的未来

量子时代加密安全与区块链应用的未来 现代密码学仍然是一门相对年轻的学科&#xff0c;但其历史却显示了一种重要的模式。大多数的发展都是基于几年甚至几十年前的研究。而这种缓慢的发展速度也是有原因的&#xff0c;就像药物和疫苗在进入市场之前需要经过多年的严格测试一样&…

【Web】2022DASCTF X SU 三月春季挑战赛 题解(全)

目录 ezpop calc upgdstore ezpop 瞪眼看链子 fin#__destruct -> what#__toString -> fin.run() -> crow#__invoke -> fin#__call -> mix.get_flag() exp <?php class crow {public $v1;public $v2;}class fin {public $f1; }class what {public $a; }…

MATLAB中gurobi 运行报错与调试

问题背景如下&#xff1a;刚拿到一份MATLAB的代码&#xff0c;但是电脑第一次安装gurobi&#xff0c;在运行过程中发生了报错&#xff0c;使用断点进行调试和步进调试方法&#xff0c;最终发现&#xff0c;这个问题出在了哪一步&#xff0c;然后向了人工智能和CSDN、百度寻求答…

VScode远程连接虚拟机提示: 无法建立连接:XHR failed.问题解决方案

一问题描述 在vscode下载插件Remote-SSH远程连接虚拟机时提示无法建立连接 二.最大嫌疑原因&#xff1a; 我也是在网上找了许久&#xff0c;发现就是网络原因&#xff0c;具体不知&#xff0c;明明访问别的网页没问题&#xff0c;就是连不上&#xff0c;然后发现下载vscode的…

数据赋能(61)——要求:数据管理部门职责

“要求&#xff1a;数据管理部门职责”是作为标准的参考内容编写的。 数据管理部门职责在于以数据资源为核心&#xff0c;将原始数据转化为可被业务部门与数据服务部门有效利用的数据资源&#xff0c;以支持业务赋能的实现。 数据管理要确保数据的完整性、准确性与一致性&…

Debian12 中重新安装MSSQL 并指定服务器、数据库、数据表字段的字符排序规则和默认语言等参数

在 Linux 上配置 SQL Server 设置 - SQL Server | Microsoft Learn 零、查看sql server 服务器支持的字符排序规则 SELECT Name from sys.fn_helpcollations() where name Like Chinese% go------ Chinese_PRC_CI_AI Chinese_PRC_CI_AI_WS Chinese_PRC_CI_AI_KS Chinese_PRC_…

【工具使用】CSDN中如何给文章添加目录跳转

这里写需要添加的目录名称 一级标题二级标题三级标题 一级标题 二级标题 三级标题 文章添加标题示例&#xff1a;

YoloV8改进策略:注意力改进、Neck层改进|自研全新的Mamba注意力|即插即用,简单易懂|附结构图|检测、分割、关键点均适用(独家原创,全世界首发)

摘要 无Mamba不狂欢,本文打造基于Mamba的注意力机制。全世界首发基于Mamba的注意力啊!对Mamba感兴趣的朋友一定不要错过啊! 基于Mamba的高效注意力代码和结构图 import torch import torch.nn as nn # 导入自定义的Mamba模块 from mamba_ssm import Mamba class Eff…

MySql安装(Linux)

一、清除原来的mysql环境 在前期建议使用root用户来进行操作&#xff0c;使用 su -来切换成root用户&#xff0c;但是如果老是提示认证失败&#xff0c;那么有可能我们的root密码并没有被设置&#xff0c; 我们可以先设置root的密码 sudo passwd root 然后就可以切换了。 …

爬虫 | 基于 Python 实现有道翻译工具

Hi&#xff0c;大家好&#xff0c;我是半亩花海。本项目旨在利用 Python 语言实现一个简单的有道翻译工具。有道翻译是一款常用的在线翻译服务&#xff0c;能够实现多种语言的互译&#xff0c;提供高质量的翻译结果。 目录 一、项目功能 二、注意事项 三、代码解析 1. 导入…

【Linux】socket编程3

欢迎来到Cefler的博客&#x1f601; &#x1f54c;博客主页&#xff1a;折纸花满衣 &#x1f3e0;个人专栏&#xff1a;题目解析 &#x1f30e;推荐文章&#xff1a;【Linux】socket套接字 前言 下面的编程代码中&#xff0c;一些socket接口需要参考【Linux】socket套接字 目录…

【C语言】冒泡排序算法详解

目录 一、算法原理二、算法分析时间复杂度空间复杂度稳定性 三、C语言实现四、Python实现 冒泡排序&#xff08;Bubble Sort&#xff09;是一种基础的排序算法。它重复地遍历要排序的数列&#xff0c;一次比较两个元素&#xff0c;如果他们的顺序错误就把他们交换过来。遍历数列…

IDEA 使用备忘录(不断更新)

IDEA 项目结构&#xff08;注意层级结构&#xff0c;新建相应结构时&#xff0c;按照以下顺序新建&#xff09;&#xff1a; project&#xff08;项目&#xff09; module&#xff08;模块&#xff09; package&#xff08;包&#xff09; class&#xff08;类&#xff09; 项…

Matlab|【免费】【sci】考虑不同充电需求的电动汽车有序充电调度方法

目录 1 主要内容 2 部分代码 3 程序结果 4 下载链接 1 主要内容 该程序复现sci文献《A coordinated charging scheduling method for electric vehicles considering different charging demands》&#xff0c;主要实现电动汽车协调充电调度方法&#xff0c;该方法主要有以…

【JAVA进阶篇教学】第三篇:JDK8中Stream API使用

博主打算从0-1讲解下java进阶篇教学&#xff0c;今天教学第三篇&#xff1a;JDK8中Stream API使用。 Java 8 中的 Stream API 提供了一种便捷、高效的方式来处理集合数据&#xff0c;它支持函数式编程风格的操作&#xff0c;包括过滤、映射、归约等。Stream API 可以大大简化集…

Ubuntu 22最新dockers部署redis哨兵模式,并整合spring boot的详细记录(含spring boot项目包)

dockers部署redis哨兵模式&#xff0c;并整合spring boot 环境说明相关学习博客一、在docker中安装redis1、下载dockers镜像包和redis配置文件&#xff08;主从一样&#xff09;2、编辑配置文件&#xff08;主从一样&#xff09;3、启动redis&#xff08;主从一样&#xff09;4…

4-Java方法详解

目录 Java方法详解 1、什么是方法 2、方法的定义及调用 3、方法重载 4、命令行传参 5、可变参数 6、递归 例题&#xff1a;代码实现一个计算机 Java方法详解 1、什么是方法 2、方法的定义及调用 形参&#xff1a;用来定义作用的 实参&#xff1a;实际调用传递给他的参数…

【Qt 学习笔记】Qt常用控件 | 显示类控件Progress Bar的使用及说明

博客主页&#xff1a;Duck Bro 博客主页系列专栏&#xff1a;Qt 专栏关注博主&#xff0c;后期持续更新系列文章如果有错误感谢请大家批评指出&#xff0c;及时修改感谢大家点赞&#x1f44d;收藏⭐评论✍ Qt常用控件 | 显示类控件Progress Bar的使用及说明 文章编号&#xff…
最新文章