***(F)ISTA algorithm for image deblurring using wavelets***

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pywt
from pywt._doc_utils import wavedec2_keys, draw_2d_wp_basis

In [None]:
## to embed figures in the notebook
%matplotlib inline

In [None]:
# image creation
# input image
u = plt.imread('simpson512.png')
u = u[:, :, 1]  # keep channel 1 (green) as the black and white version

# start with small images for your experimentations
n = 128  # side of the square crop, up to 512
i = 80   # offset of the crop along both axes, 0 for top left corner
u = u[i:i+n, i:i+n]
nr, nc = u.shape

# add noise
# np.random.rand draws uniform samples in [0, 1): they have mean 0.5, so the
# "noise" would brighten the whole image and sig would not be its standard
# deviation. Draw zero-mean, unit-variance Gaussian samples instead, from a
# seeded generator so that the notebook gives the same result on every run.
sig = 0.05  # noise standard deviation
rng = np.random.default_rng(0)
noise = rng.standard_normal((nr, nc))
ub = u + sig*noise


In [None]:
# display the clean crop and its noisy version side by side
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 20))
panels = [(u, 'original image'), (ub, 'noisy image')]
for ax, (img, title) in zip(axes, panels):
    ax.imshow(img, cmap='gray')
    ax.set_title(title)
fig.tight_layout()



In [None]:
# Wavelet transform (see the PyWavelets documentation):
# top row = subband layout, bottom row = normalized Haar coefficients,
# one column per decomposition level (column 0 shows the input image).

x = ub
shape = x.shape

max_lev = 3       # deepest decomposition level that gets a column
label_levels = 3  # number of levels labelled explicitly on the layout plots

fig, axes = plt.subplots(2, 4, figsize=[14, 8])

# column 0: nothing to draw on top, the image before decomposition below
axes[0, 0].set_axis_off()
axes[1, 0].imshow(x, cmap=plt.cm.gray)
axes[1, 0].set_title('Image')
axes[1, 0].set_axis_off()

for level in range(1, max_lev + 1):
    layout_ax = axes[0, level]
    coeff_ax = axes[1, level]

    # subband boundaries of a standard DWT basis
    draw_2d_wp_basis(shape, wavedec2_keys(level), ax=layout_ax,
                     label_levels=label_levels)
    layout_ax.set_title('{} level\ndecomposition'.format(level))

    # 2D DWT; every coefficient array is rescaled on its own so that
    # all subbands are visible in the same gray-level range
    c = pywt.wavedec2(x, 'haar', mode='periodization', level=level)
    c[0] /= np.abs(c[0]).max()
    for detail_level in range(level):
        bands = c[detail_level + 1]
        c[detail_level + 1] = [np.abs(d)/np.abs(d).max() for d in bands]

    # pack the coefficients into a single array and show it
    arr, slices = pywt.coeffs_to_array(c)
    coeff_ax.imshow(arr, cmap=plt.cm.gray)
    coeff_ax.set_title('Coefficients\n({} level)'.format(level))
    coeff_ax.set_axis_off()

plt.tight_layout()



**Question:**
- Code a soft_thresholding(x,lmbd) function that applies the soft-thresholding operator to an array:
$$
(\operatorname{prox}_{\lambda \|\cdot\|_1}(x))_i = 
\begin{cases}
0 & \text{if $x_i\in[-\lambda,\lambda]$},\\
x_i+\lambda & \text{if $x_i<-\lambda$},\\
x_i-\lambda & \text{if $x_i>\lambda$}.
\end{cases}
$$
- Plot the soft-thresholding function in 1D or in 2D

In [None]:
def soft_thresholding(x, lmbd):
    """Apply the soft-thresholding operator entrywise to an array.

    Each entry x_i is mapped to 0 when |x_i| <= lmbd, to x_i + lmbd when
    x_i < -lmbd and to x_i - lmbd when x_i > lmbd, i.e.
    sign(x_i) * max(|x_i| - lmbd, 0).

    Parameters
    ----------
    x : array_like
        Input values (any shape).
    lmbd : float
        Non-negative threshold.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as x.
    """
    return np.sign(x) * np.maximum(np.abs(x) - lmbd, 0.0)

# TODO: plot the soft-thresholding function in 1D or in 2D


**Question:**
- Let $W\in\mathcal{O}_n(\mathbb{R})$ be an orthogonal matrix and let $g(x) = f(Wx)$. Show that
$$
\operatorname{prox}_{g}(x) = W^T\operatorname{prox}_{f}(Wx)
$$
- What is computed by the function prox_l1_wavelet below?


In [None]:
level = 3
def prox_l1_wavelet(x,lmbd):
    """Soft-threshold all Haar wavelet coefficients of x and reconstruct.

    x is decomposed with `level` levels of the 2D Haar transform
    (periodization mode), soft_thresholding with threshold lmbd is applied
    to the approximation array and to every detail array, and the result
    of the inverse transform is returned.
    """
    coeffs = pywt.wavedec2(x, 'haar', mode='periodization', level=level)
    # approximation coefficients first, then the detail bands of each level
    shrunk = [soft_thresholding(coeffs[0], lmbd)]
    for bands in coeffs[1:level + 1]:
        shrunk.append([soft_thresholding(band, lmbd) for band in bands])
    # inverse transform of the thresholded coefficients
    return pywt.waverec2(shrunk, 'haar', mode='periodization')
    

**Question:**
- Observe that the soft-thresholding has a tendency to remove the noise. Explain why.
- Vary the value of lmbd to visualize extreme cases


In [None]:
# wavelet soft-thresholding of the noisy image with threshold 0.02
v = prox_l1_wavelet(ub, 0.02)

# display input and output side by side
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 20))
panels = [(ub, 'noisy image'), (v, 'soft thresholded')]
for ax, (img, title) in zip(axes, panels):
    ax.imshow(img, cmap='gray')
    ax.set_title(title)
fig.tight_layout()


**Question:**
In what follows we want to apply the ISTA algorithm to a deblurring problem: $\min_{x\in\mathbb{R}^n} F(x)$ with
$$
F(x) = f(x) + g(x) = \frac{1}{2} \| A x - u_b\|^2 + \lambda \|W x\|_1 
$$
where $x\mapsto Ax$ is a blurring operator (obtained by a convolution with a positive kernel), and $x\mapsto W x$ is the wavelet transform. $g$ has been studied above, we now turn to the term $f(x) = \frac{1}{2} \| A x - u_b\|^2$.

- Read and test the code below.
- What is the $\ell_1$-norm of the image "gausskernel" that we denote $u_g$?
- Recall the $L^p$-$L^1$ convolution theorem and justify that $\|A u \| = \|u \ast u_g\| \leq \|u\|$. Deduce that $\|A\|=1$.
- Give the expression of $x\mapsto \nabla f(x)$ and justify that the Lipschitz constant of $x\mapsto \nabla f(x)$ is $L=1$.
- Check that $L=1$ experimentally by applying the power method (50 iterations) to the appropriate linear operator.

In [None]:
# Blurring

def convol(a, b):
    """Multiply the 2D FFTs of a and b and return the real part of the
    inverse FFT (circular 2D convolution of a and b)."""
    return np.real(np.fft.ifft2(np.fft.fft2(a)*np.fft.fft2(b)))

s = 5        # kernel half-width: offsets range over -s..s
sigblur = 4  # width parameter of the Gaussian weights
offsets = np.arange(-s, s + 1)

# 1D Gaussian weights, normalized to sum to one. Negative offsets index from
# the end of the array, so the weight of offset 0 sits at index 0.
kernel = np.zeros(2*s+1)
kernel[offsets] = [np.exp(-t**2/(2*sigblur**2)) for t in offsets]
kernel /= sum(kernel)

# separable 2D kernel, same size as the image, stored with the same
# wrap-around convention on both axes
gausskernel = np.zeros(ub.shape)
gausskernel[np.ix_(offsets, offsets)] = np.outer(kernel[offsets], kernel[offsets])
        
def gaussianblur(x):
    """Return convol(x, gausskernel), i.e. x blurred by the Gaussian kernel."""
    blurred = convol(x, gausskernel)
    return blurred

# blur the clean image and compare it with the original
v = gaussianblur(u)

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 20))
# title of the second panel used to read 'burred image' (typo)
panels = [(u, 'original image'), (v, 'blurred image')]
for ax, (img, title) in zip(axes, panels):
    ax.imshow(img, cmap='gray')
    ax.set_title(title)
fig.tight_layout()


In [None]:
# Power method to check the Lipschitz constant
#TODO
    

In [None]:
# Deblurring + denoising: simulation of the observed data
# observation = blurred original image + scaled noise
# (reuses the `noise` array drawn in the image-creation cell)
sig = 0.001  # noise standard deviation
ub = gaussianblur(u) + sig * noise

# display the original next to the simulated observation
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 20))
panels = [(u, 'original image'), (ub, 'noisy image')]
for ax, (img, title) in zip(axes, panels):
    ax.imshow(img, cmap='gray')
    ax.set_title(title)
fig.tight_layout()


**Question:**
- Implement ISTA algorithm below to minimize F(x) = f(x) + g(x) with the iterations:
$$
x^{(k+1)} = \operatorname{prox}_{\frac{1}{L}g}\left( x^{(k)} - \frac{1}{L}\nabla f(x^{(k)}) \right)
$$
It is assumed that proxg(x) returns $\operatorname{prox}_{\frac{1}{L}g}(x)$.

In [None]:
def ISTA(x0, f, gradf, g, proxg, L, n_iter=50):
    """Minimize F(x) = f(x) + g(x) with the ISTA (proximal gradient) iteration.

    Each iteration computes x <- proxg(x - gradf(x) / L).

    Parameters
    ----------
    x0 : numpy.ndarray
        Starting point; it is copied, not modified.
    f, g : callable
        Return the value of the smooth term and of the non-smooth term at x.
    gradf : callable
        Returns the gradient of f at x.
    proxg : callable
        Returns prox_{g/L}(x); the 1/L scaling is the caller's responsibility.
    L : float
        Lipschitz constant of gradf (the step size is 1/L).
    n_iter : int, optional
        Number of iterations.

    Returns
    -------
    x : numpy.ndarray
        Last iterate.
    Fhist : list
        Objective values F(x) at the starting point and after every
        iteration (n_iter + 1 entries).
    """
    x = x0.copy()
    Fx = f(x) + g(x)
    Fhist = [Fx]
    print([0, Fx])
    for k in range(1, n_iter + 1):
        # gradient step on f followed by the proximal step on g
        x = proxg(x - gradf(x) / L)
        Fx = f(x) + g(x)
        Fhist.append(Fx)

    return (x, Fhist)

**Question**
- Implement all necessary input functions to run ISTA for the deblurring problem
$\min_{x\in\mathbb{R}^n} F(x)$ with
$$
F(x) = f(x) + g(x) = \frac{1}{2} \| A x - u_b\|^2 + \lambda \|W x\|_1
$$

In [None]:
# Setting operators for ISTA with Deblurring + denoising
lmbd = 1e-5  # weight lambda multiplying the wavelet l1 sum returned by g
L = 1  # constant handed to ISTA/FISTA as the Lipschitz constant L
def f(x):
    """Data-fidelity term 0.5 * ||A x - ub||^2, where A x = gaussianblur(x)."""
    residual = gaussianblur(x) - ub
    return 0.5 * np.sum(residual**2)
    

def gradf(x):
    """Gradient of f: A^T (A x - ub), where A x = gaussianblur(x).

    The Gaussian kernel built in this notebook is symmetric (its weights
    depend on the offsets only through their squares), so A^T = A and the
    adjoint is applied by calling gaussianblur a second time.
    """
    return gaussianblur(gaussianblur(x) - ub)

def g(x):
    """Return lmbd times the sum of absolute values of all Haar wavelet
    coefficients of x (decomposition with `level` levels, periodization)."""
    coeffs = pywt.wavedec2(x, 'haar', mode='periodization', level=level)
    # approximation array first, then every detail band, level by level
    arrays = [coeffs[0]]
    for bands in coeffs[1:level + 1]:
        arrays.extend(bands)
    total = sum(np.sum(np.abs(a)) for a in arrays)
    return lmbd * total

def proxg(x):
    """Return prox_{g/L}(x), as expected by ISTA and FISTA.

    g carries the weight lmbd, so g/L carries the weight lmbd/L: the wavelet
    coefficients are soft-thresholded with threshold lmbd / L.
    """
    return prox_l1_wavelet(x, lmbd / L)




In [None]:
# Run ISTA for 400 iterations, starting from the observed image ub
usol, Fhist = ISTA(x0=ub, f=f, gradf=gradf, g=g, proxg=proxg, L=L, n_iter=400)

In [None]:
# display: original and observation on the top row,
# original and ISTA solution on the bottom row
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(20, 20))
panels = [
    (u, 'original image'), (ub, 'noisy image'),
    (u, 'original image'), (usol, 'Sol ISTA'),
]
for ax, (img, title) in zip(axes.flat, panels):
    ax.imshow(img, cmap='gray')
    ax.set_title(title)
fig.tight_layout()





**Question**:
The FISTA algorithm is a variant of the ISTA algorithm that uses a sequence of intermediate vectors $y^{(k)}$ to accelerate the convergence.

Implement the FISTA algorithm that is given by:
- Initialization: Take $y^{(1)} = x^{(0)}\in\mathbb{R}^n$, and set $t^{(1)}=1$
- Step For $k\geq 1$: Compute
$$
x^{(k)} = \operatorname{prox}_{\frac{1}{L}g}\left( y^{(k)} - \frac{1}{L}\nabla f(y^{(k)}) \right),
$$
$$
t^{(k+1)} = \frac{1+\sqrt{1+4 (t^{(k)})^2}}{2},
$$
$$
y^{(k+1)} = x^{(k)} + \left(\frac{t^{(k)}-1}{t^{(k+1)}}\right)(x^{(k)} - x^{(k-1)}).
$$
- Return the last $x^{(k)}$ value.

In [None]:
def FISTA(x0, f, gradf, g, proxg, L, n_iter=50):
    """Minimize F(x) = f(x) + g(x) with the FISTA iteration.

    Starting from y = x0 and t = 1, each iteration computes
        x_new  = proxg(y - gradf(y) / L)
        t_new  = (1 + sqrt(1 + 4 t^2)) / 2
        y      = x_new + ((t - 1) / t_new) * (x_new - x)

    Parameters
    ----------
    x0 : numpy.ndarray
        Starting point; it is copied, not modified.
    f, g : callable
        Return the value of the smooth term and of the non-smooth term at x.
    gradf : callable
        Returns the gradient of f at its argument.
    proxg : callable
        Returns prox_{g/L}(x); the 1/L scaling is the caller's responsibility.
    L : float
        Lipschitz constant of gradf (the step size is 1/L).
    n_iter : int, optional
        Number of iterations.

    Returns
    -------
    x : numpy.ndarray
        Last iterate x^(k).
    Fhist : list
        Objective values F(x) at the starting point and after every
        iteration (n_iter + 1 entries).
    """
    x = x0.copy()
    y = x0.copy()
    t = 1.0
    Fhist = [f(x) + g(x)]
    for k in range(1, n_iter + 1):
        x_prev = x
        # proximal gradient step taken from the extrapolated point y
        x = proxg(y - gradf(y) / L)
        t_next = (1.0 + np.sqrt(1.0 + 4.0 * t**2)) / 2.0
        # extrapolation along the direction of the last move
        y = x + ((t - 1.0) / t_next) * (x - x_prev)
        t = t_next
        Fhist.append(f(x) + g(x))

    return (x, Fhist)

In [None]:
# Run FISTA with the same operators, starting point and iteration count as ISTA
usolfista, Fhistfista = FISTA(x0=ub, f=f, gradf=gradf, g=g, proxg=proxg, L=L, n_iter=400)


In [None]:

# display: original and observation on the top row,
# FISTA and ISTA solutions on the bottom row
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(20, 20))
panels = [
    (u, 'original image'), (ub, 'noisy image'),
    (usolfista, 'Sol FISTA'), (usol, 'Sol ISTA'),
]
for ax, (img, title) in zip(axes.flat, panels):
    ax.imshow(img, cmap='gray')
    ax.set_title(title)
fig.tight_layout()

**Question:**
- Plot the sequence of objective values for ISTA and FISTA using log scale for the y-axis for 400 iterations (or more).
- What is the FISTA iteration that matches the last ISTA objective value?

In [None]:
# plot the objective values of FISTA and ISTA in logscale

#TODO
