注意力机制 YOLOv8添加注意力机制_yolov8引入注意力机制-程序员宅基地

技术标签: YOLO  python  yolov8  人工智能  开发语言  

一、注意力机制介绍:

注意力机制(Attention Mechanism)是深度学习中一种重要的技术,它可以帮助模型更好地关注输入数据中的关键信息,从而提高模型的性能。注意力机制最早在自然语言处理领域的序列到序列(seq2seq)模型中得到广泛应用,后来逐渐扩展到了计算机视觉、语音识别等多个领域。 

注意力机制的基本思想是为输入数据的每个部分分配一个权重,这个权重表示该部分对于当前任务的重要程度。在自然语言处理任务中,这通常意味着对输入句子中的每个单词分配一个权重,而在计算机视觉任务中,这可能意味着为输入图像的每个像素或区域分配一个权重。

添加方法

总结:1.在conv.py加入注意力代码

           2.在__init.oy__和tasks.py引用GAM

           3.修改yaml文件

1.在conv.py代码中加入注意力代码

conv.py的路径:ultralytics-main\ultralytics\nn\modules\conv.py 

如图下所示:

在conv.py的最下面添加注意力代码:

代码如下:

#-----------注意力机制代码-----------------
import torch.nn as nn
import torch
 
class GAM_Attention(nn.Module):
    def __init__(self, in_channels,c2, rate=4):
        super(GAM_Attention, self).__init__()
 
        self.channel_attention = nn.Sequential(
            nn.Linear(in_channels, int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(in_channels / rate), in_channels)
        )
 
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
            nn.BatchNorm2d(int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(int(in_channels / rate), in_channels, kernel_size=7, padding=3),
            nn.BatchNorm2d(in_channels)
        )
 
    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2).sigmoid()
        x = x * x_channel_att
        x_spatial_att = self.spatial_attention(x).sigmoid()
        out = x * x_spatial_att
 
        return out
 
if __name__ == '__main__':
    x = torch.randn(1, 64, 20, 20)
    b, c, h, w = x.shape
    net = GAM_Attention(in_channels=c)
    y = net(x)
    print(y.size())

效果如图下所示:

 

 2.注册及引用GAM_Attention

__init__.py文件中引用GAM_Attention

路径:ultralytics-main\ultralytics\nn\modules\__init__.py

如图下:

在__init__.py文件中,在导包里面找到from .conv import和__all__,最后面添加GAM_Attention。

如图下所示:

tasks.py 文件中引用GAM_Attention

路径:ultralytics-main\ultralytics\nn\tasks.py

如图下:

在tasks.py文件中,在导包里面找到from ultralytics.nn.modules最后面添加GAM_Attention

如图下所示:

 在tasks.py里写入调用方式

打开tasks.py,Ctrl键+F查找n = 1(有空格)就可以找到添加的位置,如效果图:

        # """**************add Attention***************"""
        elif m in {GAM_Attention}:
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if not output
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, c2, *args[1:]]

效果如图下所示:

 3.修改自己的yolov8.yaml文件:

路径如下:ultralytics-main\ultralytics\cfg\models\v8\my_yolov8.yaml

如图下所示:

 修改后的代码如下(可以直接复制到自己的yaml里面):

# Ultralytics YOLO , GPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8-SPPCSPC.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 3, GAM_Attention, [1024]]
  - [-1, 1, SPPF, [1024, 5]]  # 10

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 13

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 22 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

自己修改可以根据下图(修改后的图)红色箭头是需要修改的:

 

 完成以上就可以训练了

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/2301_79152843/article/details/132555870

智能推荐

1005 继续(3n+1)猜想(python)_phython 3n+1问题-程序员宅基地

文章浏览阅读122次。卡拉兹(Callatz)猜想已经在1001中给出了描述。在这个题目里,情况稍微有些复杂。当我们验证卡拉兹猜想的时候,为了避免重复计算,可以记录下递推过程中遇到的每一个数。例如对n=3进行验证的时候,我们需要计算 3、5、8、4、2、1,则当我们对n=5、8、4、2 进行验证的时候,就可以直接判定卡拉兹猜想的真伪,而不需要重复计算,因为这 4 个数已经在验证3的时候遇到过了,我们称 5、8、4、2 是被 3“覆盖”的数。我们称一个数列中的某个数n为“关键数”,如果n不能被数列中的其他数字所覆..._phython 3n+1问题

jupyter下的python基本使用和信号处理编程_jupyterlab fft-程序员宅基地

文章浏览阅读1.5k次。jupyter下的python基本使用和信号处理编程简介:jupyter notebook是一种 Web 应用,能让用户将说明文本、数学方程、代码和可视化内容全部组合到一个易于共享的文档中。它可以直接在代码旁写出叙述性文档,而不是另外编写单独的文档。也就是它可以能将代码、文档等这一切集中到一处,让用户一目了然。实验环境:腾讯云服务器centos7一、安装jupyter notebook..._jupyterlab fft

Notepad++ 安装XML Tools插件格式化XML文件-程序员宅基地

文章浏览阅读7.8k次,点赞3次,收藏5次。Notepad++ 安装XML Tools插件格式化XML文件Ritchie_Li2022.02.06 20:37:12字数 183阅读 0编辑文章1. 打开Notepad++ 软件2. 选择插件,选择“插件管理”3. 搜索 XML Tools,找到该插件后,勾选该文件,点击“安装”在Notepad++ 中安装,如果没有成功,可以在多尝试2次,我是第3次成功的,具体原因不知,但有的电脑一次就能安装成功的。4. 安装的进入如下:5.成功之后,插件栏显示6. 格式化XML文件, 单击 "_xml tools

Linux驱动开发———imx6ull的pinctrl子系统源码分析_0x4001b8b0-程序员宅基地

文章浏览阅读1.2k次,点赞3次,收藏21次。目录前言前言 最近在配置pinctrl时,配置了引脚复用寄存器的SION位,配置如下图中的所示,0x4001b8b0中的第30位表示SION位 按照个人理解,imx6ull在设备树中配置的pinctrl节点,后面所带的值应该为配置寄存器的值,而SION位是复用寄存器的第三十位..._0x4001b8b0

TCP三次握手四次挥手及各状态解释_计算机网络中seq是什么意思-程序员宅基地

文章浏览阅读2.1k次,点赞2次,收藏5次。常说的三次握手和四次挥手的意思就是TCP建立连接和断开连接的过程下图为TCP三次握手和四次挥手的过程图状态或符号解释seq(sequence number),序列号,用来标记数据段的顺序,TCP把连接中发送的数据字节都编上一个序号,第一个字节的编号由本地随机产生ack(acknowlege number),确认号,指的是期望接收到下一个字节的编号,因此当前报文段最后一个字节的编号+1即为确认号ACK(acknowledgement),确认,当ACK=1确认号字段才有效,ACK=0确认号无效S_计算机网络中seq是什么意思

基于SpringBoot+Vue+uniapp的企业人事管理系统的详细设计和实现(源码+lw+部署文档+讲解等)-程序员宅基地

文章浏览阅读822次,点赞22次,收藏16次。博主介绍:全网粉丝15W+,CSDN特邀作者、211毕业、高级全栈开发程序员、大厂多年工作经验、码云/掘金/华为云/阿里云/InfoQ/StackOverflow/github等平台优质作者、专注于Java、小程序技术领域和毕业项目实战,以及程序定制化开发、全栈讲解、就业辅导精彩专栏 推荐订阅2023-2024年最值得选的微信小程序毕业设计选题大全:100个热门选题推荐2023-2024年最值得选的Java毕业设计选题大全:500个热门选题推荐Java精品实战案例《500套》

随便推点

Linux那些事儿之我是U盘(37)彼岸花的传说(五)_unsigned soft : 1;-程序员宅基地

文章浏览阅读4.1k次。 燕子去了,有再来的时候;杨柳枯了,有再青的时候;桃花谢了,有再开的时候;老婆离了,有再找的时候,孩子跑了,有回来的时候;煮熟的鸭子飞了,有飞回来的时候.一个函数没讲完就跳走了,有再回来的时候.其实,那些人,那些事,终究不曾远离.于是,她再一次进入我们的视野. 她就是usb_stor_control_thread().唤醒她的是来自queuecommand的up(&(us->sema)_unsigned soft : 1;

usb-serial controller d感叹号_usb serial converter驱动感叹号-程序员宅基地

文章浏览阅读4k次,点赞6次,收藏12次。2. 安装正确的驱动程序:USB-Serial设备通常需要安装驱动程序才能正常工作。这些驱动程序通常可从设备制造商的官方网站下载。请确保下载并安装与您的操作系统兼容的最新驱动程序。解决:1. 确认设备已正确连接:检查USB-Serial设备是否正确插入计算机的USB接口,并确保插头没有松动或损坏。感叹号可能是对于USB-Serial设备发生的问题或错误的表达。这可能是指设备无法被识别、驱动安装问题、通信错误等。_usb serial converter驱动感叹号

Laravel定时任务_laravel 停止schedule:run-程序员宅基地

文章浏览阅读576次。Laravel 定时任务首先:Laravel 制定定时任务很简单的!在app/console 文件夹下面,执行 php artisan make:console TestSchedule,他会生成TestSchedule.php这个文件TestSchedule.php,这个文件写你要定时执行的代码逻辑;class TestSchedule extends Command { //..._laravel 停止schedule:run

LeetCode刷题—树的遍历(前中后序、层次)_leetcod遍历一棵树-程序员宅基地

文章浏览阅读295次。此篇用于梳理二叉树的遍历方式:深度优先遍历(前、中、后序遍历)和广度优先遍历,不仅能快速领会思想和总结规律,还可以顺便刷下这些题:144,二叉树的前序遍历,medium145,二叉树的后序遍历,medium94,二叉树的中序遍历,medium102,二叉树的层序遍历,easy230,二叉搜索树中第k小的元素,medium501,二叉搜索树中的众数,easy530,二叉树搜索树的最小绝对差,easy一、二叉树的遍历有四种方式:1. 前序遍历:根-左-右2. 中序遍历:左-根-右3. 后序_leetcod遍历一棵树

查询冗余数据-程序员宅基地

文章浏览阅读402次。[code="sql"]-- 冗余数据SELECT l.* FROM t_lifeservice_orders l, (SELECT t.* FROM t_lifeservice_orders t WHERE t.orderStatus = 2 GROUP BY t.orderNum, t.orderStatu..._数据冗余查询比联合查询快多少

ACM模式输入输出攻略 | C++篇-程序员宅基地

文章浏览阅读1.4w次,点赞97次,收藏328次。本文内容干货非常非常多,从笔试面试环境的要点,到C++输入输出的具体函数,再到几乎覆盖全部情况的ACM模式写法,最后也给出了链表和二叉树的定义和输入输出。_acm模式

推荐文章

热门文章

相关标签