# 直方图规定化 + 暗通道去雾 (Python) — histogram specification + dark-channel-prior dehazing
import numpy as np
import cv2
def zmMinFilterGray(src, r=7):
    """Minimum filter: erode ``src`` with a square all-ones kernel.

    Args:
        src: image passed straight to ``cv2.erode``.
        r: filter radius; the kernel is ``(2*r + 1) x (2*r + 1)``.
    Returns:
        The eroded image (per-pixel minimum over the window).
    """
    side = 2 * r + 1
    kernel = np.ones((side, side))
    return cv2.erode(src, kernel)
def guidedfilter(I, p, r, eps):
    """Guided filter, ported from a MATLAB reference implementation.

    Smooths ``p`` while following the edges of the guide ``I`` by fitting a
    local linear model ``q = a * I + b`` inside every box window.

    Args:
        I: 2-D guide image. Presumably float: ``I*I`` and ``I*p`` are computed
           in the input dtype, so uint8 inputs would overflow.
        p: 2-D image to filter, same shape as ``I``.
        r: passed to ``cv2.boxFilter`` as the kernel size ``(r, r)``, i.e. this
           is the window *size*, not a radius.
        eps: regularizer added to the local variance of ``I``; larger values
           give smoother output.
    Returns:
        The filtered image, same shape as ``I``.
    Raises:
        ValueError: if ``I`` is not 2-D (previously an opaque unpacking error
        from an unused ``height, width = I.shape``).
    """
    if np.ndim(I) != 2:
        raise ValueError(
            f"guidedfilter expects a 2-D guide image, got shape {np.shape(I)}")

    def box(x):
        # Normalized box mean over an r x r window, same depth as the input.
        return cv2.boxFilter(x, -1, (r, r))

    mean_I = box(I)
    mean_p = box(p)
    cov_Ip = box(I * p) - mean_I * mean_p
    var_I = box(I * I) - mean_I * mean_I

    # Per-window linear coefficients; eps keeps a bounded in flat regions.
    a = cov_Ip / (var_I + eps)
    b = mean_p - a * mean_I

    # Average the coefficients of all windows covering each pixel.
    return box(a) * I + box(b)
def getV1(m, r, eps, w, maxV1):  # expects an RGB image with values in [0, 1]
    """Estimate the haze veil ``V1`` and the atmospheric light ``A``.

    ``V1`` starts as the per-pixel minimum over channels (dark channel). It is
    refined by ``guidedfilter`` (guide = dark channel, input = its radius-7
    minimum filter), then scaled by ``w`` and capped at ``maxV1``. ``deHaze``
    uses it through the transmission ``t = 1 - V1 / A``, i.e. ``V1 ~ A * (1 - t)``.

    ``A`` is the largest per-pixel channel mean among pixels whose refined dark
    channel lies in roughly the top 0.1 % (located with a 2000-bin histogram).

    Returns:
        tuple ``(V1, A)``.
    """
    dark = np.min(m, 2)
    dark = guidedfilter(dark, zmMinFilterGray(dark, 7), r, eps)

    n_bins = 2000
    counts, edges = np.histogram(dark, n_bins)
    cdf = np.cumsum(counts) / float(dark.size)

    # Largest bin index in [1, n_bins - 1] whose CDF is <= 0.999; if none
    # qualifies, the index falls back to 1.
    qualifying = np.flatnonzero(cdf[1:] <= 0.999)
    cut = qualifying[-1] + 1 if qualifying.size else 1

    A = np.mean(m, 2)[dark >= edges[cut]].max()
    V1 = np.minimum(dark * w, maxV1)  # limit the veil strength
    return V1, A
def deHaze(m, r=81, eps=0.001, w=0.95, maxV1=0.80):
    """Remove haze from ``m`` with the dark-channel prior.

    Each of the first three channels is recovered as ``(I - V1) / (1 - V1 / A)``,
    which equals ``J = (I - A) / t + A`` with ``t = 1 - V1 / A``. The result is
    clipped to [0, 1]. ``m`` is presumably a float image in [0, 1] (see getV1).

    Returns:
        float64 array with the same shape as ``m``; any channel beyond the third
        is left at 0.
    """
    veil, airlight = getV1(m, r, eps, w, maxV1)
    transmission = 1 - veil / airlight
    out = np.zeros(m.shape)
    for ch in range(3):
        out[:, :, ch] = (m[:, :, ch] - veil) / transmission
    return np.clip(out, 0, 1)
def interpo(x, middle_ratio, max_x, max_num):
    """Return the target count for grey level ``x`` under a triangular histogram.

    The count rises linearly from ``x = 0`` to ``max_num`` at the peak
    ``x = middle_ratio * max_x``, then falls linearly to ~0 at ``x = max_x``::

        count
          |     /\\
          |    /  \\
          |   /    \\
          |  /      \\
          |_/________\\___ x
          0   peak   max_x

    Both branches add 1, so (for positive ``max_num``) every level inside
    ``[0, max_x]`` receives at least one sample.

    Args:
        x: grey level.
        middle_ratio: peak position as a fraction of ``max_x``.
        max_x: largest grey level of the curve.
        max_num: count at the peak (before the ``+ 1``).
    Returns:
        int count; 0 when ``x`` lies outside ``[0, max_x]``.
    """
    if not 0 <= x <= max_x:
        return 0
    peak = middle_ratio * max_x
    if x < peak:
        return int(max_num / peak * x) + 1
    # The falling edge spans peak..max_x (the original hard-coded 255 here,
    # which was only right when max_x == 255).
    falling_width = max_x - peak
    if falling_width <= 0:
        # Peak at (or beyond) max_x: here x == max_x, the top of the triangle.
        return int(max_num) + 1
    return int(max_num - max_num / falling_width * (x - peak)) + 1
def func(ratio):
    """Build a sample array whose histogram is the triangular curve of ``interpo``.

    Grey level ``i`` (0..255) appears ``interpo(i, ratio, 255, 1000)`` times, in
    ascending order. The array is meant to be used as the ``specified`` template
    for histogram specification.

    Args:
        ratio: peak position as a fraction of 255.
    Returns:
        1-D integer array of grey levels.
    """
    counts = [interpo(i, ratio, 255, 1000) for i in range(256)]
    # np.repeat builds the array in one pass; the original `ret = ret + [...]`
    # copied the growing list on every iteration (quadratic).
    return np.repeat(np.arange(256), counts)
def find_nearest_above(my_array, target):
    """Return the index of the element closest to ``target`` from above.

    Elements with ``my_array - target <= -1`` count as "below" and are skipped.
    Among the remaining elements the smallest difference wins (first index on
    ties). If every element is below, fall back to the index of the element
    nearest to ``target``.
    """
    delta = my_array - target
    below = np.ma.less_equal(delta, -1)
    if not np.all(below):
        return np.ma.masked_array(delta, below).argmin()
    return np.abs(delta).argmin()
def hist_match(original, specified):
    """Histogram specification: remap ``original`` so its histogram follows ``specified``.

    Each distinct value of ``original`` is placed at its cumulative quantile,
    scaled to 0..255 and rounded. It is then mapped to the first template value
    whose (equally scaled) cumulative quantile is at least as large.

    Fix: the original returned the *index* into the template's unique values
    rather than the value itself. That is only correct when the template
    contains every level 0..255 (as ``func`` does), not e.g. for a dark-channel
    template.

    Args:
        original: array of grey levels (any shape).
        specified: array whose value distribution is the target; its values are
            assumed to fit in uint8.
    Returns:
        uint8 array with the shape of ``original``.
    """
    oldshape = original.shape
    original = original.ravel()
    specified = specified.ravel()
    if original.size == 0:
        return np.zeros(oldshape, dtype=np.uint8)

    _, bin_idx, s_counts = np.unique(
        original, return_inverse=True, return_counts=True)
    t_values, t_counts = np.unique(specified, return_counts=True)

    s_quantiles = np.cumsum(s_counts).astype(np.float64)
    s_quantiles /= s_quantiles[-1]
    t_quantiles = np.cumsum(t_counts).astype(np.float64)
    t_quantiles /= t_quantiles[-1]

    sour = np.around(s_quantiles * 255)
    temp = np.around(t_quantiles * 255)

    # `temp` is sorted, so "first template quantile >= source quantile" is a
    # left binary search. This matches the old per-value find_nearest_above
    # loop; the clamp covers the (normally impossible) past-the-end case.
    idx = np.minimum(np.searchsorted(temp, sour, side="left"), temp.size - 1)
    mapping = t_values[idx].astype(np.uint8)
    return mapping[bin_idx].reshape(oldshape)
def hist_match_all(img, ratio=0.2):
    """Apply histogram specification to each of the first three channels of ``img``.

    Every channel is matched to the same triangular template from ``func``.

    Args:
        img: image indexed as ``img[:, :, c]``.
        ratio: peak position of the template as a fraction of 255; the default
            0.2 is the previously hard-coded value.
    Returns:
        Array like ``img`` (same dtype); channels beyond the third stay 0.
    """
    specified = func(ratio)  # identical for every channel: build it once
    result = np.zeros_like(img)
    for i in range(3):
        result[:, :, i] = hist_match(img[:, :, i], specified)
    return result
def image_hist(image):
    """Plot the 256-bin histogram of each of the three channels and show the figure.

    Channels 0/1/2 are drawn in blue/green/red, so the function assumes the
    OpenCV BGR channel order (as produced by ``cv2.imread``). The [0, 256)
    range presumes 8-bit data.
    """
    # `plt` was used here but never imported anywhere in the file (NameError).
    import matplotlib.pyplot as plt

    # Distinct loop name: the original re-bound `color` over its own tuple.
    for channel, colour in enumerate(("blue", "green", "red")):
        hist = cv2.calcHist([image], [channel], None, [256], [0, 256])
        plt.plot(hist, color=colour)
        plt.xlim([0, 256])
    plt.show()
def normalize(img):
    """Linearly stretch ``img`` so its maximum becomes 255, returned as uint8.

    Values are scaled by ``255 / max`` and truncated by the uint8 cast. An
    all-zero image is returned as zeros; previously it divided 0/0, producing
    NaN that the cast turned into undefined values.
    """
    peak = np.max(img)
    if peak == 0:
        return np.zeros(np.shape(img), dtype=np.uint8)
    return (img / peak * 255).astype(np.uint8)
def hist_match_dark_prior(img, dark=False):
    """Equalize each channel, then match every channel to a per-pixel channel prior.

    With ``dark == False`` the prior is the per-pixel minimum over channels
    (dark channel). Otherwise it is the per-pixel maximum (bright channel).
    NOTE(review): the flag reads inverted relative to what it selects
    (``dark=True`` picks the *bright* channel); confirm the intent.

    The matching pass over the three channels is repeated three times. The
    result is then stretched to a maximum of 255 with ``normalize``.
    """
    out = img.copy()
    for ch in range(3):
        out[:, :, ch] = cv2.equalizeHist(out[:, :, ch])

    # `== False` (not `not dark`) keeps the original comparison semantics.
    reduce_channels = np.min if dark == False else np.max  # noqa: E712
    prior = reduce_channels(out, axis=2)

    for _ in range(3):
        for ch in range(3):
            out[:, :, ch] = hist_match(out[:, :, ch], prior)
    return normalize(out)
if __name__ == "__main__":
    # Demo pipeline: histogram specification, then dark-channel dehazing.
    # NOTE: the paths are placeholders; cv2.imwrite needs an output name with an
    # image extension (e.g. ".png") to pick an encoder.
    img = cv2.imread("filename")
    if img is None:  # cv2.imread signals failure by returning None
        raise SystemExit("could not read input image: filename")
    img = hist_match_all(img)
    # The original called an undefined `dehaze255`. deHaze works on floats in
    # [0, 1] and clips its output there, so scale in and out around it.
    img = (deHaze(img / 255.0) * 255).round().astype(np.uint8)
    cv2.imwrite("filename_write", img)
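# 再放一个 PyTorch 版的: a PyTorch version follows (it redefines find_nearest_above, hist_match and hist_match_dark_prior).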
import torch
def torch_equalize(image):
    """Histogram-equalize the first three channels, like PIL's Equalize.

    Port of the equalize op in
    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L352

    Args:
        image: tensor indexed as ``image[:, :, c]``, values in [0, 255].
    Returns:
        uint8 tensor of the three equalized channels stacked on axis 2.
    """
    def scale_channel(im, c):
        """Equalize channel ``c`` of ``im`` and return it as uint8."""
        im = im[:, :, c]
        # 256-bin histogram over [0, 255].
        histo = torch.histc(im, bins=256, min=0, max=255)
        nonzero_histo = histo[histo != 0]
        if nonzero_histo.numel() == 0:
            # Empty channel: nothing to equalize (indexing [-1] below would fail).
            return im.type(torch.uint8)
        # The step ignores the last populated bin, as in PIL.
        step = (torch.sum(nonzero_histo) - nonzero_histo[-1]) // 255

        def build_lut(histo, step):
            # Cumulative sum shifted by step // 2, then normalized by step.
            lut = (torch.cumsum(histo, 0) + (step // 2)) // step
            # Prepend 0 and drop the last entry. The zero is created on lut's
            # device: a bare torch.zeros(1) is CPU-only and breaks CUDA inputs.
            lut = torch.cat(
                [torch.zeros(1, dtype=lut.dtype, device=lut.device), lut[:-1]])
            return torch.clamp(lut, 0, 255)

        if step == 0:
            # Degenerate histogram: leave the channel unchanged.
            result = im
        else:
            # gather needs a 1-D index, so flatten and reshape back.
            result = torch.gather(build_lut(histo, step), 0, im.flatten().long())
            result = result.reshape_as(im)
        return result.type(torch.uint8)

    # Assumes three channels (e.g. RGB); each is scaled independently.
    image = image.type(torch.float)
    return torch.stack([scale_channel(image, c) for c in range(3)], 2)
def find_nearest_above(my_array, target):
    """Return the index of the element closest to ``target`` from above (torch).

    Elements with ``my_array - target <= -1`` count as "below" and are skipped.
    Among the remaining elements the smallest difference wins (first index on
    ties). If every element is below, fall back to the index of the element
    nearest to ``target``.
    """
    diff = my_array - target
    mask = diff <= -1
    if torch.all(mask):
        return torch.abs(diff).argmin()
    # Push "below" entries past every real difference. A fixed 9999 sentinel
    # picked the wrong index once valid differences reached 9999.
    sentinel = diff.max() + 1
    masked_diff = torch.where(mask, sentinel, diff)
    return masked_diff.argmin()
def hist_match(source, template):
    """Histogram specification in torch: remap ``source`` to follow ``template``.

    Each distinct source value is placed at its cumulative quantile, scaled to
    0..255 and truncated. It is then mapped to the first template value whose
    (equally scaled) quantile is at least as large.

    Fixes:
      * returns template *values*; the original returned their indices;
      * uses ``reshape``, so non-contiguous slices work (``view`` can fail);
      * keeps all work on the input's device (the old ``torch.zeros`` buffer
        was always on the CPU).

    Returns:
        Tensor with the shape of ``source`` and the default floating dtype (as
        before).
    """
    s = source.reshape(-1)
    t = template.reshape(-1)
    _, bin_idx, s_counts = torch.unique(s, return_inverse=True, return_counts=True)
    t_values, t_counts = torch.unique(t, return_counts=True)

    s_quantities = torch.cumsum(s_counts, 0).type(torch.float)
    t_quantities = torch.cumsum(t_counts, 0).type(torch.float)
    s_quantities = s_quantities / s_quantities[-1]
    t_quantities = t_quantities / t_quantities[-1]

    sour = (s_quantities * 255).type(torch.long)
    temp = (t_quantities * 255).type(torch.long)

    # `temp` is sorted, so "first template quantile >= source quantile" is a
    # left binary search. This is equivalent to the old per-element
    # find_nearest_above loop.
    idx = torch.searchsorted(temp, sour).clamp(max=temp.numel() - 1)
    mapping = t_values[idx].to(torch.get_default_dtype())
    return mapping[bin_idx].reshape(source.shape)
def hist_match_dark_prior(img):
    """Equalize ``img`` per channel, then match each channel to the dark channel.

    Args:
        img: tensor laid out as ``[h, w, c]``.
    Returns:
        ``[h, w, c]`` tensor from ``torch_equalize`` with each of the three
        channels replaced by its histogram-matched version.
    """
    equalized = torch_equalize(img.clone())
    # Per-pixel minimum over the channel axis = dark channel.
    dark_channel = torch.min(equalized, dim=2).values
    for channel in range(3):
        equalized[:, :, channel] = hist_match(equalized[:, :, channel], dark_channel)
    return equalized
if __name__ == '__main__':
    from PIL import Image
    import numpy as np

    # torch_equalize indexes three channels, so force RGB. Greyscale or palette
    # images would otherwise load as 2-D arrays and crash. The `with` closes
    # the file handle.
    with Image.open("Night-schene-03.jpg") as im:
        img = torch.from_numpy(np.array(im.convert("RGB")))
    img1 = hist_match_dark_prior(img).numpy()
    im1 = Image.fromarray(img1)
    im1.save('out.png')
# 评论已关闭 (page residue from the original post: "comments closed")