import numpy as np
import cv2

def zmMinFilterGray(src, r=7):
    """Minimum filter: grey-level erosion of `src` with a square window.

    Args:
        src: image passed to cv2.erode.
        r: window radius; the all-ones kernel is (2*r + 1) x (2*r + 1).

    Returns:
        The eroded image (each pixel replaced by the window minimum).
    """
    side = 2 * r + 1
    kernel = np.ones((side, side))
    return cv2.erode(src, kernel)
def guidedfilter(I, p, r, eps):
    """Guided filter, ported from a MATLAB reference implementation.

    Args:
        I: 2-D guidance image (non-2-D input raises ValueError on unpacking).
        p: image to be filtered, same shape as I.
        r: side length of the square box window given to cv2.boxFilter as
           an (r, r) kernel -- a window width, not a radius.
        eps: regularisation added to the local variance of I.

    Returns:
        mean(a) * I + mean(b), where a and b are the per-window linear
        coefficients.
    """
    _rows, _cols = I.shape  # only 2-D guides are accepted

    def box(x):
        # Normalised box mean over an r x r window, same depth as input.
        return cv2.boxFilter(x, -1, (r, r))

    mean_I = box(I)
    mean_p = box(p)
    cov_Ip = box(I * p) - mean_I * mean_p
    var_I = box(I * I) - mean_I * mean_I

    a = cov_Ip / (var_I + eps)
    b = mean_p - a * mean_I
    return box(a) * I + box(b)
  
def getV1(m, r, eps, w, maxV1):  # input: RGB image, values in [0, 1]
    """Estimate the veil image V1 and the atmospheric light A.

    deHaze later divides by ``1 - V1/A``, i.e. it treats that quantity as
    the transmission.

    Args:
        m: H x W x C image; the original author notes RGB in [0, 1].
        r, eps: window size and regulariser forwarded to guidedfilter.
        w: scale factor applied to the refined dark channel.
        maxV1: upper bound applied to the scaled V1.

    Returns:
        (V1, A): V1 is the refined, scaled and capped dark channel. A is the
        largest channel-mean among pixels whose refined dark channel lies in
        the top histogram bins (about the brightest 0.1%).
    """
    dark = np.min(m, 2)  # dark channel: per-pixel minimum over channels
    # Refine the min-filtered dark channel, using the raw dark channel as guide.
    V1 = guidedfilter(dark, zmMinFilterGray(dark, 7), r, eps)

    bins = 2000
    counts, edges = np.histogram(V1, bins)
    cdf = np.cumsum(counts) / float(V1.size)
    # Largest bin index in [1, bins-1] whose CDF is still <= 0.999 (1 if none).
    hits = np.flatnonzero(cdf[1:] <= 0.999)
    lmax = hits[-1] + 1 if hits.size else 1

    A = np.mean(m, 2)[V1 >= edges[lmax]].max()
    V1 = np.minimum(V1 * w, maxV1)  # limit the veil's value range
    return V1, A

  
def deHaze(m, r=81, eps=0.001, w=0.95, maxV1=0.80):
    """Remove haze from image `m` (values expected in [0, 1]).

    Each of channels 0-2 is recovered as (m - V1) / (1 - V1/A), using V1 and
    A from getV1. The result is clipped to [0, 1]. The output is float64 with
    m's shape; any channels beyond the third are left at 0.
    """
    veil, airlight = getV1(m, r, eps, w, maxV1)
    transmission = 1 - veil / airlight
    restored = np.zeros(m.shape)
    for ch in (0, 1, 2):
        restored[:, :, ch] = (m[:, :, ch] - veil) / transmission  # colour correction
    return np.clip(restored, 0, 1)

def interpo(x, middle_ratio, max_x, max_num):
    r"""Triangular histogram profile: the sample count to emit for level `x`.

    Shape of the profile (count vs. level)::

        |      /\
        |    /    \
        |  /        \
        |/____ratio___\___

    The count rises linearly from 1 at x = 0 to ``max_num + 1`` at the peak
    ``middle_ratio * max_x``, then falls back to about 1 at ``x == max_x``.

    Args:
        x: level to evaluate.
        middle_ratio: peak position as a fraction of `max_x` (0..1).
        max_x: last level of the profile.
        max_num: height of the triangle.

    Returns:
        0 when x lies outside [0, max_x]; otherwise an int >= 1 (for
        non-negative `max_num`).
    """
    if not 0 <= x <= max_x:
        return 0
    peak = middle_ratio * max_x
    if x < peak:
        return int(max_num / peak * x) + 1
    # The falling edge spans [peak, max_x]. It previously used a hard-coded
    # 255 here, which skewed the triangle whenever max_x != 255.
    tail = max_x - peak
    if tail == 0:
        # Peak sits at max_x, so x == peak: return the apex value without
        # dividing by zero.
        return int(max_num) + 1
    return int(max_num - max_num / tail * (x - peak)) + 1

def func(ratio):
    """Build a sample array whose histogram follows the triangular profile.

    Level i (0..255) is repeated ``interpo(i, ratio, 255, 1000)`` times, so
    the array can be used as the `specified` target of hist_match.

    Args:
        ratio: peak position of the triangle as a fraction of 255.

    Returns:
        1-D integer array of levels in non-decreasing order.
    """
    counts = [interpo(i, ratio, 255, 1000) for i in range(256)]
    # np.repeat builds the result in one pass. Repeated list concatenation
    # would copy the growing list on every iteration (quadratic).
    return np.repeat(np.arange(256), counts)

def find_nearest_above(my_array, target):
    """Index of the element of `my_array` nearest to `target` from above.

    Elements with ``my_array - target <= -1`` are excluded. Among the rest,
    the smallest difference wins (first index on ties). If every element is
    excluded, this falls back to the index of the smallest absolute
    difference.
    """
    offsets = my_array - target
    below = np.ma.less_equal(offsets, -1)
    if not np.all(below):
        # Hide the excluded (too-low) entries and take the smallest remainder.
        return np.ma.masked_array(offsets, below).argmin()
    # Target exceeds every value: fall back to the closest one overall.
    return np.abs(offsets).argmin()

def hist_match(original, specified):
    """Histogram-specify `original` so its grey levels follow `specified`.

    Both CDFs are quantised to 0..255. Each source level is mapped to the
    first target level whose quantised CDF is >= the source level's.

    Args:
        original: array of levels (any shape), e.g. one image channel.
        specified: array whose value distribution is the target (flattened).

    Returns:
        uint8 array with `original`'s shape. Every value is a level that
        occurs in `specified`.
    """
    oldshape = original.shape
    original = original.ravel()
    specified = specified.ravel()

    # Unique levels, the inverse index that rebuilds `original`, and counts.
    _, bin_idx, s_counts = np.unique(original, return_inverse=True, return_counts=True)
    t_values, t_counts = np.unique(specified, return_counts=True)

    # Normalised CDFs of the source and target distributions.
    s_quantiles = np.cumsum(s_counts).astype(np.float64)
    s_quantiles /= s_quantiles[-1]
    t_quantiles = np.cumsum(t_counts).astype(np.float64)
    t_quantiles /= t_quantiles[-1]

    # Quantise both CDFs to 0..255.
    sour = np.around(s_quantiles * 255)
    temp = np.around(t_quantiles * 255)

    # `temp` is nondecreasing and integer-valued. The first index with
    # temp >= sour is therefore what find_nearest_above(temp, s) picks, and
    # searchsorted finds it for all levels at once. temp[-1] == 255, so the
    # clip only guards an unreachable case.
    idx = np.minimum(np.searchsorted(temp, sour, side="left"), temp.size - 1)

    # Map to target *values*. Indices into t_values equal values only when
    # `specified` contains every level 0..N-1.
    mapping = t_values[idx].astype(np.uint8)
    return mapping[bin_idx].reshape(oldshape)

def hist_match_all(img, ratio=0.2):
    """Histogram-specify channels 0-2 of `img` to a triangular target histogram.

    Args:
        img: H x W x C image; only channels 0-2 are processed (others stay 0).
        ratio: peak position of the triangular profile, passed to func().
            The default 0.2 matches the previous hard-coded value.

    Returns:
        Array shaped and typed like `img`.
    """
    result = np.zeros_like(img)
    # The target distribution is identical for every channel: build it once.
    specified = func(ratio)
    for i in range(3):
        result[:, :, i] = hist_match(img[:, :, i], specified)
    return result

def image_hist(image):
    """Plot the 256-bin histogram of each of the first three channels of `image`.

    Channel 0 is drawn blue, 1 green and 2 red, consistent with the BGR
    images returned by cv2.imread.

    NOTE(review): `plt` is not imported anywhere in this file; presumably it
    is `matplotlib.pyplot`. Add `import matplotlib.pyplot as plt` before
    calling this.
    """
    for channel, colour in enumerate(("blue", "green", "red")):
        hist = cv2.calcHist([image], [channel], None, [256], [0, 256])
        plt.plot(hist, color=colour)
    plt.xlim([0, 256])
    plt.show()

def normalize(img):
    """Linearly rescale `img` so its maximum becomes 255; returns uint8.

    Values are scaled by 255 / max(img) and truncated by the uint8 cast. An
    image whose maximum is 0 is returned as all zeros, instead of dividing
    by zero and casting NaN to uint8.
    """
    peak = np.max(img)
    if peak == 0:
        return np.zeros(np.shape(img), dtype=np.uint8)
    return (img / peak * 255).astype(np.uint8)

def hist_match_dark_prior(img, dark=False):
    """Histogram-specify each channel toward the image's dark/bright prior.

    Steps:
    1. Equalise channels 0-2 with cv2.equalizeHist.
    2. Build a prior as the per-pixel channel minimum when ``dark == False``,
       otherwise the per-pixel maximum.
    3. Run hist_match of every channel against that prior three times.
    4. Rescale with normalize().

    Args:
        img: H x W x C image accepted by cv2.equalizeHist per channel
            (presumably uint8 -- confirm with callers).
        dark: selects the reduction as described above.

    Returns:
        uint8 image from normalize().
    """
    out = img.copy()
    for ch in range(3):
        out[:, :, ch] = cv2.equalizeHist(out[:, :, ch])
    # Compared with `== False` (not `not dark`) on purpose, so that values
    # such as None keep selecting the maximum.
    reduce_channels = np.min if dark == False else np.max  # noqa: E712
    prior = reduce_channels(out, axis=2)
    for _pass in range(3):
        for ch in range(3):
            out[:, :, ch] = hist_match(out[:, :, ch], prior)
    return normalize(out)

if __name__=="__main__":
    # "filename" / "filename_write" are placeholders; replace with real paths
    # (cv2.imwrite picks the encoder from the file extension).
    img = cv2.imread("filename")
    if img is None:
        raise SystemExit("could not read input image: filename")
    img = hist_match_all(img)
    # deHaze works on floats in [0, 1] and returns values clipped to [0, 1];
    # scale in and out of the uint8 range that cv2 reads and writes.
    img = (deHaze(img / 255.0) * 255).round().astype(np.uint8)
    cv2.imwrite("filename_write", img)

下面再给出一个 PyTorch 版本：

import torch


def torch_equalize(image):
    """Equalise channels 0-2 of an H x W x C image, mimicking PIL's Equalize.

    Port of the TensorFlow implementation at:
    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L352

    Args:
        image: tensor indexed as image[:, :, c]. Values must lie in [0, 255],
            because they index a 256-entry lookup table.

    Returns:
        uint8 tensor of shape (H, W, 3). A channel whose computed step is 0
        (e.g. a constant channel) is returned unchanged.
    """
    def build_lut(histo, step):
        # Cumulative sum shifted by step // 2, then normalised by step.
        lut = (torch.cumsum(histo, 0) + (step // 2)) // step
        # Shift right by one, prepending 0. new_zeros keeps lut's dtype and
        # device, so this also works for tensors that are not on the CPU.
        lut = torch.cat([lut.new_zeros(1), lut[:-1]])
        # Clip to the pixel range (PIL does this in C for image.point).
        return torch.clamp(lut, 0, 255)

    def scale_channel(im, c):
        """Equalise one channel via a 256-entry lookup table."""
        im = im[:, :, c]
        histo = torch.histc(im, bins=256, min=0, max=255)
        # The step is computed from the non-empty bins only.
        nonzero_histo = torch.reshape(histo[histo != 0], [-1])
        step = (torch.sum(nonzero_histo) - nonzero_histo[-1]) // 255
        if step == 0:
            result = im
        else:
            # gather needs a 1-D index: flatten, look up, restore the shape.
            lut = build_lut(histo, step)
            result = torch.gather(lut, 0, im.flatten().long()).reshape_as(im)
        return result.type(torch.uint8)

    image = image.type(torch.float)
    # Channels are scaled independently, then restacked along dim 2.
    return torch.stack([scale_channel(image, c) for c in range(3)], 2)


def find_nearest_above(my_array, target):
    """Index of the element of `my_array` nearest to `target` from above.

    Elements with ``my_array - target <= -1`` are excluded. Among the rest,
    the smallest difference wins (argmin returns the first index on ties).
    If every element is excluded, this falls back to the index of the
    smallest absolute difference.
    """
    gap = my_array - target
    too_low = gap <= -1
    if not torch.all(too_low):
        candidates = gap.clone()
        candidates[too_low] = 9999  # sentinel above any gap for 8-bit levels
        return candidates.argmin()
    # Target exceeds every value: fall back to the closest one overall.
    return torch.abs(gap).argmin()


def hist_match(source, template):
    """Histogram-specify `source` so its grey levels follow `template`.

    Both CDFs are quantised to 0..255 by rounding, as in the NumPy version.
    Each source level is mapped to the first template level whose quantised
    CDF is >= the source level's.

    Args:
        source: tensor of levels (any shape; non-contiguous slices are fine).
        template: tensor whose value distribution is the target.

    Returns:
        Tensor with `source`'s shape, in the default floating dtype. It holds
        levels that occur in `template`.
    """
    s = source.reshape(-1)
    t = template.reshape(-1)
    _, bin_idx, s_counts = torch.unique(s, return_inverse=True, return_counts=True)
    t_values, t_counts = torch.unique(t, return_counts=True)

    # Normalised CDFs of source and template.
    s_quantities = torch.cumsum(s_counts, 0).type(torch.float)
    t_quantities = torch.cumsum(t_counts, 0).type(torch.float)
    s_quantities = s_quantities / s_quantities[-1]
    t_quantities = t_quantities / t_quantities[-1]

    # Quantise to 0..255 with round-half-to-even, matching np.around.
    sour = torch.round(s_quantities * 255).long()
    temp = torch.round(t_quantities * 255).long()

    # `temp` is nondecreasing, so the first index with temp >= sour is
    # exactly what find_nearest_above(temp, s) selects. temp[-1] == 255, so
    # the clamp only guards an unreachable case.
    idx = torch.searchsorted(temp, sour).clamp(max=temp.numel() - 1)

    # Map to template *values* (indices equal values only when the template
    # contains every level 0..N-1).
    matched = t_values[idx].type(torch.get_default_dtype())
    return matched[bin_idx].reshape(source.shape)
    
def hist_match_dark_prior(img):
    """Equalise `img`, then histogram-specify each channel toward its dark channel.

    Args:
        img: tensor of shape (H, W, C); only channels 0-2 are used.

    Returns:
        uint8 tensor of shape (H, W, 3). The dark prior is the per-pixel
        minimum over the equalised channels.
    """
    equalized = torch_equalize(img.clone())
    dark_prior = torch.min(equalized, dim=2).values
    for ch in range(3):
        equalized[:, :, ch] = hist_match(equalized[:, :, ch], dark_prior)
    return equalized

if __name__=='__main__':
    from PIL import Image
    import numpy as np

    # convert("RGB") guarantees an (H, W, 3) array: the pipeline indexes three
    # channels and would fail on greyscale ("L") input.
    with Image.open("Night-schene-03.jpg") as im:
        img = torch.from_numpy(np.array(im.convert("RGB")))
    img1 = hist_match_dark_prior(img).numpy()
    Image.fromarray(img1).save('out.png')

标签: none

评论已关闭