windows swin transformer训练自己的目标检测数据集_# do not use mmdet version fp16 fp16 = none-程序员宅基地

技术标签: 算法  python  深度学习  模型训练  pytorch  

主要是有几个地方的文件要修改一下

config/swin下的配置文件,我用的是mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py

img_scale主要是内存不够要改小

_base_ = [
    '../_base_/models/mask_rcnn_swin_fpn.py',
    # '../_base_/datasets/coco_instance.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]

model = dict(
    backbone=dict(
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        ape=False,
        drop_path_rate=0.1,
        patch_norm=True,
        use_checkpoint=False
    ),
    neck=dict(in_channels=[96, 192, 384, 768]))

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=False),
    # dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='AutoAugment',
         policies=[
             [
                 dict(type='Resize',
                      # img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
                      #            (608, 1333), (640, 1333), (672, 1333), (704, 1333),
                      #            (736, 1333), (768, 1333), (800, 1333)],
                      img_scale=[(224,224)],
                      multiscale_mode='value',
                      keep_ratio=True)
             ],
             [
                 dict(type='Resize',
                      # img_scale=[(400, 1333), (500, 1333), (600, 1333)],
                      img_scale=[(224, 224)],
                      multiscale_mode='value',
                      keep_ratio=True),
                 dict(type='RandomCrop',
                      crop_type='absolute_range',
                      crop_size=(384, 600),
                      allow_negative_crop=True),
                 dict(type='Resize',
                      # img_scale=[(480, 1333), (512, 1333), (544, 1333),
                      #            (576, 1333), (608, 1333), (640, 1333),
                      #            (672, 1333), (704, 1333), (736, 1333),
                      #            (768, 1333), (800, 1333)],
                      img_scale=[(224, 224)],
                      multiscale_mode='value',
                      override=True,
                      keep_ratio=True)
             ]
         ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
    # dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
data = dict(train=dict(pipeline=train_pipeline))

optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
                                                 'relative_position_bias_table': dict(decay_mult=0.),
                                                 'norm': dict(decay_mult=0.)}))
lr_config = dict(step=[8, 11])
# runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
runner = dict(type='EpochBasedRunner', max_epochs=12)
# do not use mmdet version fp16
fp16 = None
# optimizer_config = dict(
#     type="DistOptimizerHook",
#     update_interval=1,
#     grad_clip=None,
#     coalesce=True,
#     bucket_size_mb=-1,
#     use_fp16=True,
# )

configs\_base_\default_runtime.py (如果下载了官方预训练模型)

checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
custom_hooks = [dict(type='NumClassCheckHook')]

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
# 预训练模型的路径按需修改,不用预训练模型训练的结果会很差
#load_from = 'D:/Users/Downloads/Swin-Transformer-Object-Detection/model/mask_rcnn_swin_tiny_patch4_window7_1x.pth'
resume_from = None
workflow = [('train', 1)]

configs\_base_\datasets\coco_detection.py 数据集路径,格式和coco保持一致

dataset_type = 'CocoDataset'
# data_root = 'data/coco/'
data_root = 'F:/myproject/dataset/COCO_/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    # dict(type='LoadAnnotations', with_bbox=True),
    dict(type='LoadAnnotations', with_bbox=True,with_mask=False ,with_seg=False,poly2mask=False),
    dict(type='Resize', img_scale=(224, 224), keep_ratio=True),#img_scale=(1333, 800),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(224, 224), #img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric='bbox')

configs\_base_\models\mask_rcnn_swin_fpn.py 注意修改类别和mask

# model settings
model = dict(
    type='MaskRCNN',
    pretrained=None,
    backbone=dict(
        type='SwinTransformer',
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.2,
        ape=False,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        use_checkpoint=False),
    neck=dict(
        type='FPN',
        in_channels=[96, 192, 384, 768],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            # 修改为自己的数据类数
            # num_classes=80,
            num_classes=20,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
        # mask_roi_extractor=dict(
        #     type='SingleRoIExtractor',
        #     roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
        #     out_channels=256,
        #     featmap_strides=[4, 8, 16, 32]),
        # mask_head=dict(
        #     type='FCNMaskHead',
        #     num_convs=4,
        #     in_channels=256,
        #     conv_out_channels=256,
        # 修改为自己的数据类数
        #     num_classes=80,
        #     loss_mask=dict(
        #         type='CrossEntropyLoss', use_mask=True, loss_weight=1.0
        #     )
        # )
    ),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            # mask_size=28,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            # mask_thr_binary=0.5
        )))

mmdet\datasets\coco.py 替换自定义数据集的类别

class CocoDataset(CustomDataset):

    # CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    #            'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
    #            'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
    #            'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
    #            'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    #            'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    #            'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    #            'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    #            'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
    #            'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    #            'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
    #            'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
    #            'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
    #            'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')

    CLASSES = (
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
        'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
    )

mmdet\core\evaluation\class_names.py中的67行也要替换自定义数据集类别信息

# def coco_classes():
#     return [
#         'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
#         'truck', 'boat', 'traffic_light', 'fire_hydrant', 'stop_sign',
#         'parking_meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
#         'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
#         'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
#         'sports_ball', 'kite', 'baseball_bat', 'baseball_glove', 'skateboard',
#         'surfboard', 'tennis_racket', 'bottle', 'wine_glass', 'cup', 'fork',
#         'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
#         'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut', 'cake', 'chair',
#         'couch', 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv',
#         'laptop', 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
#         'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
#         'scissors', 'teddy_bear', 'hair_drier', 'toothbrush'
#     ]
def coco_classes():
    return [
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
        'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
    ]

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/athrunsunny/article/details/124470524

智能推荐

GitHub上fork项目后与作者源代码保持一致的方法_如何保证checkout的代码是一致的-程序员宅基地

文章浏览阅读5.3k次。今天在找一些javaweb的项目练练手的时候,因为有很多的项目在GitHub上,在拿的时候,看到很多的大佬都是说尽量点击star不要fork,刚开始以为是只是为了给自己涨人气的,后来看到一个人的评论才知道是为什么。 因为fork过来的项目源代码只是目前题主上传到GitHub的源代码,以后如果题主对源代码进行了修改,那么更改的源代码不会和fork过来的同时更改,也就是自己的源代码是过期的项目_如何保证checkout的代码是一致的

DNS support edns-client-subnet_edns-client-subnet机制-程序员宅基地

文章浏览阅读7.6k次。转自:http://noops.me/?p=653&utm_source=tuicool&utm_medium=referral作者: wilbur | 7,104 浏览 | 2013/06/26 4:07 下午看了2天RFC,终于让DNS支持edns-client-subnet协议,通过google dns resolver的请求,可以获取用户_edns-client-subnet机制

AD19(Altium Designer)如何显示和隐藏网络_ad显示网络连线-程序员宅基地

文章浏览阅读2.1w次,点赞6次,收藏27次。AD19(Altium Designer)1、快捷键N+H+(点击要隐藏的网络)n’h点击即可2、选择工具栏view(视图) >>> Connect(连接)>>>Show network(显示网络)显示网络和隐藏网络同理_ad显示网络连线

android phone 模块分析_startnetstatpoll-程序员宅基地

文章浏览阅读1.7k次。2010-08-30 09:323585人阅读 评论(0)收藏 举报http://hi.baidu.com/anly%5Fjun/blog/index/0Andriod Phone模块相关(总览)2010-01-30 13:501、从java端发送at命令的处理流程。2、unsolicited 消息从modem上报到java的流程。3、猫相关_startnetstatpoll

软件工程(2018)第五次团队作业-程序员宅基地

文章浏览阅读29次。1 前言经过一学期的努力,我们终于完成了全部的教学工作,大家的团队作业也应该结束了吧,接下来请各队展示一下你们的成果吧!2 题目要求各团队将自己的项目介绍写一篇博文,4~5班同学现场介绍,1~3班同学将介绍过程录制为视频,要求现场或视频介绍必须有PPT,组长负责主要介绍,但各成员都要介绍自己在项目中承担的工作,1~3班各组将视频传到视频网站中(网站各组自选),并将视频链接发布到团队博客中1..._世强的软件类的题目

MMIO 与 Port I/O-程序员宅基地

文章浏览阅读3.1k次。I/O作为CPU和外设交流的一个渠道,主要分为两种,一种是Port I/O,一种是MMIO(Memory mapping I/O)。_port i/o

随便推点

Linux下如何查看tomcat是否启动、查看tomcat启动日志_centos7 查看指定tomcat 是否启动,没有则再次启动-程序员宅基地

文章浏览阅读5.1w次,点赞6次,收藏40次。在Linux系统下,重启Tomcat使用命令的操作!1.首先,进入Tomcat下的bin目录cd /usr/local/tomcat/bin使用Tomcat关闭命令./shutdown.sh查看Tomcat是否以关闭ps -ef|grep java如果显示以下相似信息,说明Tomcat还没有关闭root 7010 1 0 Apr19 _centos7 查看指定tomcat 是否启动,没有则再次启动

视频数据帧YUV格式转RGB格式--ffmpeg转换Qt利用Image显示_ffmpeg qt转image-程序员宅基地

文章浏览阅读2.3k次,点赞2次,收藏15次。接之前帖子,自己编写算法视频格式转换CPU占用率高,且转换效率低,故在此使用ffmpeg库进行转化;将一帧数据转换rgb后使用qt中QImage显示一帧数据;#ifndef VOIDSHOWIMAGE_H#define VOIDSHOWIMAGE_H#include <QWidget>#include <QLabel>#include <QImage&..._ffmpeg qt转image

POJ 2184 Cow Exhibition(变形01背包)_y - cow exhibition-程序员宅基地

文章浏览阅读367次。题目链接: POJ 2184 Cow Exhibition 题意: 给n头牛,每头牛有两个属性:smart和fun,选出若干头牛使得这些牛的smart和fun之和最大,并且smart和与fun和均不为负。 每头牛的smart和fun可以为负。 分析: 01背包和滚动数组。 用dp[j]表示得到smart和为j时的fun和最大值。但是因为j可能为负,一开始我是用map,但是一直TLE。。_y - cow exhibition

二项分布比例的置信区间计算_二项分布计算率的置信区间-程序员宅基地

文章浏览阅读2.4w次。之前一网友遇到类似问题,特查了相关文献,归结一下方法有二。1.根据各方法的计算公式进行编辑公式,data 步就可以搞定。相关文献有:《A Comparison of Binomial Proportion Interval Estimation Methods 》、《Confidence Interval Calculation for Binomial Proportions》,这两篇文章_二项分布计算率的置信区间

idea 创建vue项目提示:TypeError: this.CliEngine is not a constructor_idea报错typeerror:this cliengine is not a constructo-程序员宅基地

文章浏览阅读1.9k次,点赞2次,收藏3次。1、选择Details2、调试窗口下,找到eslint-plugin.js,并编辑3、找到如下方法function ESLintPlugin(state) {this.filterSource = state.filterSource;this.additionalRulesDirectory = state.additionalRootDirectory;this.calcBasic..._idea报错typeerror:this cliengine is not a constructor

写出命令实现1分钟后定时关机 linux,linux下实现定时关机-程序员宅基地

文章浏览阅读572次。ZDNetChina服务器站 操作系统技巧用crontab命令就可以了,下面看一下它的详细用法。名称 : crontab使用权限 : 所有使用者使用方式 :crontab [ -u user ] filecrontab [ -u user ] { -l | -r | -e }说明 :crontab 是用来让使用者在固定时间或固定间隔执行程序之用,换句话说,也就是类似使用者的时程表。-u user ..._linux案例练习 l.计算机将在 1分钟后关机,并且会显 “this server will be sh

推荐文章

热门文章

相关标签