import os
import tensorflow as tf
import numpy as np

def mkdir(result_dir):
    """Create result_dir (including parents) if it does not already exist.

    Args:
        result_dir: path of the directory to create.
    """
    # exist_ok avoids the TOCTOU race between the original isdir() check
    # and makedirs(), and is a no-op when the directory is already there.
    os.makedirs(result_dir, exist_ok=True)


def load_ckpt_initialize(checkpoint_dir, sess):
    """Initialize all variables, then restore the latest checkpoint if any.

    Args:
        checkpoint_dir: directory scanned by tf.train.get_checkpoint_state.
        sess: active tf.Session used for initialization and restore.

    Returns:
        The tf.train.Saver instance, for later checkpoint saving.
    """
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    state = tf.train.get_checkpoint_state(checkpoint_dir)
    if state is not None:
        # Restoring after the initializer overwrites fresh values with
        # the checkpointed ones.
        print('loaded ' + state.model_checkpoint_path)
        saver.restore(sess, state.model_checkpoint_path)
    return saver
 

def sobel_loss(img1, img2):
    """L1 distance between binarized Sobel edge maps of two images.

    Each edge map is thresholded at its own mean magnitude. Originally
    intended to counteract blur; the original author notes it is
    problematic -- the network can learn edges that do not exist.
    """
    def _edge_mask(img):
        # |Sobel| magnitude, binarized at the map's own mean.
        magnitude = tf.math.abs(tf.image.sobel_edges(img))
        threshold = tf.reduce_mean(magnitude)
        return tf.cast(magnitude > threshold, tf.float32)

    mask_a = _edge_mask(img1)
    mask_b = _edge_mask(img2)
    return tf.reduce_mean(tf.math.abs(mask_a - mask_b))

def load_image_to_memory(image_dir):
    """Load every file under image_dir into memory as a list of images.

    Args:
        image_dir: directory to scan. A trailing separator is no longer
            required -- paths are built with os.path.join (the original
            check `image_dir[-2:] != "\\"` compared a 2-character slice
            against a single backslash and could never match).

    Returns:
        List of images as returned by plt.imread, or -1 if image_dir is
        not a directory (error convention kept from the original).
    """
    if not os.path.isdir(image_dir):
        print("invalid dir")
        return -1
    # NOTE(review): `plt` is never imported in this file -- a
    # `import matplotlib.pyplot as plt` is missing at module level; confirm.
    return [plt.imread(os.path.join(image_dir, name))
            for name in os.listdir(image_dir)]

def generate_batch(train_pics, batch_size, dark_pack, gt_pack):
    """Sample a random batch of paired, augmented patches from image packs.

    NOTE: the original docstring was unindented, which made this function
    a SyntaxError; it is re-indented (and translated) here.

    Args:
        train_pics: number of images available (indexes into the packs).
        batch_size: number of patches to draw.
        dark_pack: list of train_pics input images, each an [h, w, c] ndarray.
        gt_pack: list of train_pics ground-truth images; the GT crop is taken
            at 2x the input coordinates, so each GT image is expected to be
            [2h, 2w, c] -- TODO confirm against the caller.

    Returns:
        (inputs, gts): ndarrays of shape [batch_size, ps, ps, c] and
        [batch_size, 2*ps, 2*ps, c]. ps is derived per image, so all images
        sampled in one call must share a size for concatenation to succeed.
    """
    input_patches = []
    gt_patches = []
    # Draw batch_size distinct image indices at random.
    for ind in np.random.permutation(train_pics)[:batch_size]:
        imgrgb = dark_pack[ind]
        imggt = gt_pack[ind]

        # shape[0] is the row count, shape[1] the column count
        # (the original called these W and H, swapped).
        rows = imgrgb.shape[0]
        cols = imgrgb.shape[1]
        # Patch size: a quarter of the smaller dimension, at least 8.
        # Assumes rows/cols > ps so randint gets a positive range.
        ps = max(min(rows // 4, cols // 4), 8)

        xx = np.random.randint(0, rows - ps)
        yy = np.random.randint(0, cols - ps)

        img_feed_in = imgrgb[np.newaxis, :, :, :]
        img_feed_gt = imggt[np.newaxis, :, :, :]
        # Random crop; GT is cropped at 2x coordinates / 2x size.
        input_patch = img_feed_in[:, xx:xx + ps, yy:yy + ps, :]
        gt_patch = img_feed_gt[:, xx * 2:xx * 2 + ps * 2, yy * 2:yy * 2 + ps * 2, :]
        # Augmentation: each applied with probability 1/2, identically
        # to both patches so the pair stays aligned.
        if np.random.randint(2, size=1)[0] == 1:  # vertical flip
            input_patch = np.flip(input_patch, axis=1)
            gt_patch = np.flip(gt_patch, axis=1)
        if np.random.randint(2, size=1)[0] == 1:  # horizontal flip
            input_patch = np.flip(input_patch, axis=2)
            gt_patch = np.flip(gt_patch, axis=2)
        if np.random.randint(2, size=1)[0] == 1:  # transpose rows/cols
            input_patch = np.transpose(input_patch, (0, 2, 1, 3))
            gt_patch = np.transpose(gt_patch, (0, 2, 1, 3))
        # Clamp inputs to [.., 1.0]; GT is left untouched.
        input_patch = np.minimum(input_patch, 1.0)
        input_patches.append(input_patch)
        gt_patches.append(gt_patch)
    return np.concatenate(input_patches, 0), np.concatenate(gt_patches, 0)

def cov(x, y):
    """Per-sample covariance of two NHWC image batches.

    Args:
        x, y: tensors of shape [n, h, w, c].

    Returns:
        Length-n tensor: for each sample, the mean over (h, w, c) of
        (x - mean(x)) * (y - mean(y)), with means taken per sample.
    """
    # keepdims=True lets the per-sample means broadcast back over
    # (h, w, c), replacing the original einsum-based tiling (and the
    # unused `mshape` local is dropped).
    x_mean = tf.reduce_mean(x, axis=[1, 2, 3], keepdims=True)
    y_mean = tf.reduce_mean(y, axis=[1, 2, 3], keepdims=True)
    return tf.reduce_mean((x - x_mean) * (y - y_mean), [1, 2, 3])

def cumsum(xs):
    """Inclusive cumulative sum along axis 0 (TF analogue of np.cumsum).

    The original unstack/accumulate/stack loop required a statically
    known leading dimension and emitted O(n) graph ops; tf.cumsum is the
    built-in equivalent (inclusive by default, along axis 0) and also
    handles dynamic shapes.
    """
    return tf.cumsum(xs, axis=0)

def piecewise_linear_fn(x, x1, x2, y1, y2):
    """Linear interpolation through (x1, y1)-(x2, y2), zero outside [x1, x2].

    Args:
        x: input tensor.
        x1, x2: interval endpoints (x1 < x2 assumed -- TODO confirm).
        y1, y2: values at the endpoints.

    Returns:
        Tensor shaped like the interpolant: y1 + slope*(x - x1) inside
        the interval, 0 outside.
    """
    interp = y1 + (y2 - y1) / (x2 - x1) * (x - x1)
    outside = tf.logical_or(tf.less(x, x1), tf.greater(x, x2))
    # The original used tf.constant(0.0, shape=[1]) here, which breaks
    # tf.where whenever x is not of shape [1] (TF1 requires the branches
    # to match the condition's shape); zeros_like is shape-safe.
    return tf.where(outside, tf.zeros_like(interp), interp)

def count(matrix, minval, maxval):
    """Count (as float32) the elements strictly inside (minval, maxval)."""
    in_range = tf.logical_and(tf.greater(matrix, minval),
                              tf.less(matrix, maxval))
    # Casting the boolean mask to float32 and summing is equivalent to
    # selecting between ones_like/zeros_like and summing.
    return tf.reduce_sum(tf.cast(in_range, tf.float32))

def generate_opt(loss):
    """Build an Adam minimizer for `loss` with a feedable learning rate.

    Returns:
        (train_op, lr_placeholder): the minimize op and the float32
        placeholder to feed the learning rate through.
    """
    learning_rate = tf.placeholder(tf.float32)
    train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
    return train_op, learning_rate

def generate_weights(shape):
    """Trainable variable of the given shape, initialized uniformly in [0, 1)."""
    initial = tf.random_uniform(shape=shape, minval=0.0, maxval=1.0)
    return tf.Variable(initial)

# Tags: none
# Comments are closed
# (blog-scrape artifacts preserved as comments so the file parses as Python)