使用Tensorflow进行计算机视觉研究时,常常会遇到磁盘读写瓶颈,具体表现为CPU和磁盘使用率极高,而GPU使用率很低。对于这种IO瓶颈,可以使用NVIDIA开源的DALI库进行数据读写加速,以下是安装和使用教程。

安装(需要TensorFlow 版本1.7或更高,CUDA9.0或更高)

CUDA9.0:

pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/cuda/9.0 nvidia-dali-tf-plugin

CUDA10.0:

pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/cuda/10.0 nvidia-dali-tf-plugin

例子:

class SimplePipeline(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, img_dir):
        super(SimplePipeline, self).__init__(batch_size, num_threads, device_id, seed = 12)
        self.input = ops.FileReader(file_root = img_dir)
        self.decode = ops.ImageDecoder(device = 'mixed', output_type = types.RGB)
    def define_graph(self):
        pngs, labels = self.input()
        images = self.decode(pngs)
        return images

dark_pipe = SimplePipeline(batch_size, 1, 0, dark_img_dir)
gt_pipe = SimplePipeline(batch_size, 1, 0, gt_img_dir)
daliop = dali_tf.DALIIterator()
in_image= daliop(pipeline = dark_pipe, shapes=[[batch_size,400,600,3]],dtypes=[tf.uint8])
in_image=tf.to_float(in_image[0])/255.0

gt_image= daliop(pipeline = gt_pipe, shapes=[[batch_size,400,600,3]],dtypes=[tf.uint8])
gt_image=tf.to_float(gt_image[0])/255.0

with tf.Session() as sess:
    in_img,gt_img = sess.run([in_image, gt_image])

以上对应文件路径如下

root
    -dark_img_dir
        -images
        -labels
    -gt_img_dir
        -images
        -labels

标签: none

评论已关闭