【pytorch】给训练踩踩油门-- Pytorch 加速数据读取_训练模型 prefetch-程序员宅基地

技术标签: pytorch  

本文转载于 https://zhuanlan.zhihu.com/p/80695364

训练模型的时候有时候会发现显卡的占用一直跑不满,会很浪费,往往是因为IO瓶颈导致的训练速度降低。
本文可以从以下几个方面进行对模型加速:

一, prefetch_generator

使用 prefetch_generator 库在后台加载下一 batch 的数据。

安装:

pip install prefetch_generator

使用:

# 新建DataLoaderX类
from torch.utils.data import DataLoader
from prefetch_generator import BackgroundGenerator

class DataLoaderX(DataLoader):

    def __iter__(self):
        return BackgroundGenerator(super().__iter__())

然后用 DataLoaderX 替换原本的 DataLoader

提速原因

原本 PyTorch 默认的 DataLoader 会创建一些 worker 线程来预读取新的数据,但是除非这些线程的数据全部都被清空,这些线程才会读下一批数据。
使用 prefetch_generator,我们可以保证线程不会等待,每个线程都总有至少一个数据在加载。

二, data_prefetcher

使用 data_prefetcher 新开 cuda stream 来拷贝 tensor 到 gpu。

class DataPrefetcher():
    def __init__(self, loader, opt):
        self.loader = iter(loader)
        self.opt = opt
        self.stream = torch.cuda.Stream()
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.batch = next(self.loader)
        except StopIteration:
            self.batch = None
            return
        with torch.cuda.stream(self.stream):
            for k in self.batch:
                if k != 'meta':
                    self.batch[k] = self.batch[k].to(device=self.opt.device, non_blocking=True)

            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            #     self.next_input = self.next_input.float()

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()
        return batch
# ----改造前----
for iter_id, batch in enumerate(data_loader):
    if iter_id >= num_iters:
        break
    for k in batch:
        if k != 'meta':
            batch[k] = batch[k].to(device=opt.device, non_blocking=True)
    run_step()
    
# ----改造后----
prefetcher = DataPrefetcher(data_loader, opt)
batch = prefetcher.next()
iter_id = 0
while batch is not None:
    iter_id += 1
    if iter_id >= num_iters:
        break
    run_step()
    batch = prefetcher.next()

提速原因
默认情况下,PyTorch 将所有涉及到 GPU 的操作(比如内核操作,cpu->gpu,gpu->cpu)都排入同一个 stream(default stream)中,并对同一个流的操作序列化,它们永远不会并行。要想并行,两个操作必须位于不同的 stream 中。
而前向传播位于 default stream 中,因此,要想将下一个 batch 数据的预读取(涉及 cpu->gpu)与当前 batch 的前向传播并行处理,就必须:
(1) cpu 上的数据 batch 必须 pinned;
(2)预读取操作必须在另一个 stream 上进行
上面的 data_prefetcher 类满足这两个要求。注意 dataloader 必须设置 pin_memory=True 来满足第一个条件

三, 把内存当硬盘

把数据放内存里,降低 io 延迟。

sudo mount tmpfs /path/to/your/data -t tmpfs -o size=30G

然后把数据放挂载的目录下,即可。

  • size 指定的是 tmpfs 动态大小的上限,实际大小根据实际使用情况而定;
  • 数据不一定放在物理内存中,系统根据情况,有可能放在 swap 的页面,swap 一般是在系统盘;
  • 重启或者断电后数据全部清空。

如果想系统启动时自动挂载,可以编辑 /etc/fstab,在最后添加如下内容:

mount tmpfs in /tmp/
tmpfs /tmp tmpfs size=30G 0 0

四, 设置num_worker

DataLoader 的 num_worker 如果设置太小,则不能充分利用多线程提速,如果设置太大,会造成线程阻塞,或者撑爆内存,反而导致训练变慢甚至程序崩溃。

他的大小和具体的硬件和软件都有关系,所以没有一个统一的标准,可以通过一些简单的实验来确定。

我的经验是设置成 cpu 的核心数或者 gpu 的数量比较合适。

五、 优化数据预处理

主要有两个方面:

  • 尽量简化预处理的操作,使用 numpy、opencv 等优化过的库,多多利用向量化代码,提升代码运行效率;
  • 尽量缩减数据大小,不要传输无用信息。
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/shwan_ma/article/details/103331166

智能推荐

PostgreSql 主从同步搭建_delphi查询postgresql数据构建主从关系-程序员宅基地

文章浏览阅读899次。环境操作系统:CentOS Linux release 7.6.1810 (Core)数据库版本:PostgreSQL 12.4IP:192.168.100.170 主库192.168.100.202 从库实施步骤主库创建账号同步数据postgres=# CREATE ROLE replica login replication encrypted password 'replica';CREATE ROLE主库 pg_hba.conf 文件增加备库访问控制host r_delphi查询postgresql数据构建主从关系

最近iOS开发中遇到的问题-程序员宅基地

文章浏览阅读310次。背景:最近在做一款通讯软件1.通话后并dismiss通话的controller后,视图会导致下移的情况出现,比如我的uitabbar就出现下移,通过[selfperformSelector:@selector(dismissSelfController)withObject:nilafterDelay:1.0f];使界面过段时间才被dismiss,可以解决上述问题。_ios开发中遇到的问题

SHA256withRSA密钥加签_sha256withrsa签名算法-程序员宅基地

文章浏览阅读5.6k次。银联对接的一个加密方式,sha256withrsa其实就是rsa2加密方式,完整加签代码记录以作备忘/** * 根据原文生成签名内容 * * @param string $data 原文内容 * * @return string */ private function sign($data) { $f..._sha256withrsa签名算法

展示 用户画像_从0搭建用户画像系统(二)之数据看板-程序员宅基地

文章浏览阅读309次。作者介绍酒仙桥@道明学长自如数据PM一只告别野路子,带你探索数据新世界上期我们了解了《从0搭建用户画像系统(一)之系统五大常规模块介绍》,本期将和大家分享用户画像系统中数据看板模块的一些思考。笔者之前经历多个企业级画像系统搭建,总结起来,搭建数据看板的目的不外乎两类:秀”肌肉”和“方便看数”。分享一下笔者经历过两家公司考虑增加数据看板的案例。第一家三方大数据公司,核心盈利模式是通过沉淀的海量用户数..._立体用户画像 怎么可视化呈现

常见数据库类型-程序员宅基地

文章浏览阅读9k次。数据库(DataBase,简称DB),是指可以长期存放在计算机内部的、可以进行数据管理的仓库(可以直接理解为储存数据的仓库)。数据库是依据数据结构来构建的,所以我们看到的数据是比较”条理化“的(数据库分为库、表和一条条记录)查找的速度较快数据共享不过数据库有哪些类型,估计很多人都分不清楚,目前数据库类型大致分为三种数据库类型有哪些?数据库共有3种类型,为关系数据库、非关系型数据库和键值数据库。_数据库类型

matlab程序实例光频梳,一种基于光频梳和频谱整形的任意波形发生装置及方法与流程...-程序员宅基地

文章浏览阅读2.2k次。本申请属于微波光子信号生成领域,具体涉及一种基于光频梳和频谱整形的任意波形发生装置及方法。背景技术:目前半导体激光器、光纤光学和微波天线、微波单片集成电路等光学与微波相结合的技术正在高速发展,产生了微波光子学这种结合了微波和光纤通信的交叉学科,而微波光子方法中的任意波形信号发生方法,正是基于这一学科研究,生成具有更高频率和个更高带宽的波形信号,能有效避免其在电域内合成时受限于电子器件的采样率而产生..._oeo震荡的matlab 程序

随便推点

SQL注入(二)-程序员宅基地

文章浏览阅读108次。接上一篇SQL注入(一)的内容说说防范SQL第5点 5.限制输入长度   如果在Web页面上使用文本框收集用户输入的数据,使用文本框的MaxLength属性来限制用户输入过长的字符也是一个很好的方法,因为用户的输入不够长,也就减少了贴入大量脚本的可能性。程序员可以针对需要收集的数据类型作出一个相应的限制策略。 6.URL重写技术 我们利用URL重写技术过滤一些SQ

Linux之grep命令详解_linux grep-程序员宅基地

文章浏览阅读2.6w次,点赞19次,收藏140次。注: 部分概念介绍来源于网络一、简介grep (global search regular expression(RE) and print out the line,全面搜索正则表达式并把行打印出来)是一种强大的文本搜索工具,它能使用正则表达式搜索文本,并把匹配的行打印出来。二、grep常用用法1、grep [-acinv] [--color=auto] '搜寻字符串' filename.txt选项与参数:-w :被匹配的文本只能是单词,而不能是单词中的某一部分,如文本中有liker,而我_linux grep

android开发6.0权限适配-程序员宅基地

文章浏览阅读542次。写博客只是为了方便记忆,希望自己能够坚持下去。在android开发中,如果项目的targetSdkVersion >= 23,在手机版本为android6.0+的手机上运行时,某些危险的权限需要用户授权,如果用户不同意,而直接运行某些代码,会造成程序的崩溃。 Runtime Permissions This release introduces a new permissions model,

php新特性--持续更新-程序员宅基地

文章浏览阅读54次。命名空间  在其他语言中不算新鲜事,但php是5.3.0中引入,具体定义就不复述了,其主要作用是 封装和组织相关php类 。命名空间被引入之前php主要是通过Zend方式组织代码,这种方式带来的问题是类名依赖于目录(虽然命名空间之后规范要求也要和目录一样)导致类名特别特别长,如:Zend_Cloud_DocumentService_Adapter_WindowsAzure_Query。  ...

python: numpy-- 函数 argsort 用法-程序员宅基地

文章浏览阅读6.5k次。argsort() 函数将数组的值从小到大排序后,并按照其相对应的索引值输出举例说明:一维数组>>> a = array([3,1,2])>>> argsort(a)array([1, 2, 0])二维数组>>> b = array([[1,2],[2,3]])>>> argsort(b,axis=1) #按行排序array([[0, 1], [0, 1

AUTO CAD系统变量的设置方法_cad2014标题栏怎么加入变量-程序员宅基地

文章浏览阅读1.6k次。一般情况下,我们无需对AutoCAD的系统变量值作修改和设置,取其缺省值就能正常工作。但在有特殊要求时,就必须修改相关的系统变量。如果我们能熟练地掌握一些常用系统变量的使用方法和功能,就能使我们的工作更为便利、顺畅,大大地提高我们的绘图水平和工作效率。现将一些常用的系统变量的用法和功能展示出来,供各位同仁参考。 系统变量的设置方法     在命令行Command:_中输入命令set_cad2014标题栏怎么加入变量

推荐文章

热门文章

相关标签