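Accelerating TensorFlow training with the NVIDIA DALI TensorFlow plugin
Computer vision research with TensorFlow often runs into a disk I/O bottleneck: CPU and disk utilization are very high while GPU utilization stays low. NVIDIA's open-source DALI library can speed up data loading and relieve this bottleneck. The installation and usage steps follow.
Installation (requires TensorFlow 1.7 or later and CUDA 9.0 or later)
CUDA 9.0: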
pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/cuda/9.0 nvidia-dali-tf-plugin
CUDA 10.0:
pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/cuda/10.0 nvidia-dali-tf-plugin
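After installing, it can help to confirm that the packages are visible to the Python interpreter used for training. The following is a minimal sketch using only the standard library; the helper names `is_importable` and `check_install` are hypothetical, and the module names checked are the ones imported by the example in this post.

```python
import importlib.util


def is_importable(module_name):
    """Return True if `module_name` can be located in this environment.

    For dotted names, parent packages are imported while searching.
    """
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # Raised when a parent package (for example `nvidia`) is missing.
        return False


def check_install(modules=("tensorflow", "nvidia.dali", "nvidia.dali.plugin.tf")):
    """Map each module name to whether it was found."""
    return {name: is_importable(name) for name in modules}


if __name__ == "__main__":
    for name, ok in check_install().items():
        print(f"{name}: {'found' if ok else 'MISSING'}")
```

If any entry is reported as missing, the pip command was probably run against a different interpreter or virtual environment than the one running the check.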
Example:
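from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
import nvidia.dali.plugin.tf as dali_tf
import tensorflow as tf

# batch_size, dark_img_dir and gt_img_dir are assumed to be defined beforehand.

class SimplePipeline(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, img_dir):
        super(SimplePipeline, self).__init__(batch_size, num_threads, device_id, seed = 12)
        # Read encoded files from disk, then decode them ('mixed' = CPU input, GPU output)
        self.input = ops.FileReader(file_root = img_dir)
        self.decode = ops.ImageDecoder(device = 'mixed', output_type = types.RGB)

    def define_graph(self):
        pngs, labels = self.input()
        images = self.decode(pngs)
        return images  # the labels are not used in this example

# One pipeline per image set: 1 worker thread each, both on GPU 0
dark_pipe = SimplePipeline(batch_size, 1, 0, dark_img_dir)
gt_pipe = SimplePipeline(batch_size, 1, 0, gt_img_dir)

daliop = dali_tf.DALIIterator()
# Each call returns a list of tensors; take the first and scale uint8 to [0, 1]
in_image = daliop(pipeline = dark_pipe, shapes=[[batch_size,400,600,3]], dtypes=[tf.uint8])
in_image = tf.to_float(in_image[0])/255.0
gt_image = daliop(pipeline = gt_pipe, shapes=[[batch_size,400,600,3]], dtypes=[tf.uint8])
gt_image = tf.to_float(gt_image[0])/255.0

with tf.Session() as sess:
    in_img, gt_img = sess.run([in_image, gt_image])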
The code above assumes the following directory layout:
root
  -dark_img_dir
    -images
    -labels
  -gt_img_dir
    -images
    -labels
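Because the example reads the dark images and the ground-truth images through two separate pipelines, the batches only line up if both directories contain matching files. The sketch below is one way to check this before training. It rests on an assumption that is not stated in the original post: each dark image and its ground-truth counterpart share the same relative path under their respective roots. The function names are hypothetical.

```python
import os


def list_relative_files(root):
    """Return the sorted paths of all files under `root`, relative to `root`."""
    found = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            full_path = os.path.join(dirpath, name)
            found.append(os.path.relpath(full_path, root))
    return sorted(found)


def unpaired_files(dark_root, gt_root):
    """Return (only_in_dark, only_in_gt): relative paths without a counterpart.

    Assumption: paired files have identical relative paths in both trees.
    """
    dark = set(list_relative_files(dark_root))
    gt = set(list_relative_files(gt_root))
    return sorted(dark - gt), sorted(gt - dark)
```

Calling `unpaired_files(dark_img_dir, gt_img_dir)` and finding both returned lists empty suggests the two trees mirror each other. It does not by itself guarantee that the two readers emit files in the same order, so shuffling should stay disabled or be seeded identically in both pipelines.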