{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "***(F)ISTA algorithm for image deblurring using wavelets***" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import pywt\n", "from pywt._doc_utils import wavedec2_keys, draw_2d_wp_basis" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "## to embed figures in the notebook\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# image creation\n", "# input image\n", "u = plt.imread('simpson512.png')\n", "u = u[:,:,1] # extract the green channel to get a grayscale image\n", "\n", "# start with small images for your experiments\n", "n = 128 # up to 512 \n", "i = 80 # 0 for top left corner\n", "u = u[i:i+n,i:i+n]\n", "nr,nc = u.shape\n", "\n", "# add Gaussian noise\n", "sig = 0.05 # noise standard deviation\n", "noise = np.random.randn(nr,nc) # standard normal samples\n", "ub = u + sig*noise\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# display\n", "fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 20))\n", "axes[0].imshow(u,cmap='gray')\n", "axes[0].set_title('original image')\n", "axes[1].imshow(ub,cmap='gray')\n", "axes[1].set_title('noisy image')\n", "fig.tight_layout()\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Wavelet transform: Documentation:\n", "\n", "x = ub\n", "shape = x.shape\n", "\n", "max_lev = 3 # how many levels of decomposition to draw\n", "label_levels = 3 # how many levels to explicitly label on the plots\n", "\n", "fig, axes = plt.subplots(2, 4, figsize=[14, 8])\n", "for level in range(0, max_lev + 1):\n", "    if level == 0:\n", "        # show the original image before decomposition\n", "        axes[0, 0].set_axis_off()\n", "        axes[1, 0].imshow(x, 
cmap=plt.cm.gray)\n", "        axes[1, 0].set_title('Image')\n", "        axes[1, 0].set_axis_off()\n", "        continue\n", "\n", "    # plot subband boundaries of a standard DWT basis\n", "    draw_2d_wp_basis(shape, wavedec2_keys(level), ax=axes[0, level],\n", "                     label_levels=label_levels)\n", "    axes[0, level].set_title('{} level\\ndecomposition'.format(level))\n", "\n", "    # compute the 2D DWT\n", "    c = pywt.wavedec2(x, 'haar', mode='periodization', level=level)\n", "    # normalize each coefficient array independently for better visibility\n", "    c[0] /= np.abs(c[0]).max()\n", "    for detail_level in range(level):\n", "        c[detail_level + 1] = [np.abs(d)/np.abs(d).max() for d in c[detail_level + 1]]\n", "    # show the normalized coefficients\n", "    arr, slices = pywt.coeffs_to_array(c)\n", "    axes[1, level].imshow(arr, cmap=plt.cm.gray)\n", "    axes[1, level].set_title('Coefficients\\n({} level)'.format(level))\n", "    axes[1, level].set_axis_off()\n", "\n", "plt.tight_layout()\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Question:**\n", "- Write a function soft_thresholding(x,lmbd) that applies the soft-thresholding operator elementwise to an array:\n", "$$\n", "(\\operatorname{prox}_{\\lambda \\|\\cdot\\|_1}(x))_i = \n", "\\begin{cases}\n", "0 & \\text{if $x_i\\in[-\\lambda,\\lambda]$},\\\\\n", "x_i+\\lambda & \\text{if $x_i<-\\lambda$},\\\\\n", "x_i-\\lambda & \\text{if $x_i>\\lambda$}.\n", "\\end{cases}\n", "$$\n", "- Plot the soft-thresholding function in 1D or in 2D." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#TODO\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Question:**\n", "- Let $W\\in\\mathcal{O}_n(\\mathbb{R})$ be an orthogonal matrix and let $g(x) = f(Wx)$. 
Show that\n", "$$\n", "\\operatorname{prox}_{g}(x) = W^T\\operatorname{prox}_{f}(Wx)\n", "$$\n", "- What is computed by the function prox_l1_wavelet below?\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "level = 3\n", "def prox_l1_wavelet(x,lmbd):\n", "    # Wavelet decomposition\n", "    c = pywt.wavedec2(x, 'haar', mode='periodization', level=level)\n", "    # Apply soft_thresholding to all wavelet coefficients\n", "    c[0] = soft_thresholding(c[0],lmbd)\n", "    for detail_level in range(level):\n", "        c[detail_level + 1] = [soft_thresholding(d,lmbd) for d in c[detail_level + 1]]\n", "    # Wavelet reconstruction (inverse of the wavelet decomposition)\n", "    y = pywt.waverec2(c, 'haar', mode='periodization')\n", "    #print(y.shape)\n", "    return(y)\n", " " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Question:**\n", "- Observe that soft-thresholding tends to remove the noise. Explain why.\n", "- Vary the value of lmbd to visualize extreme cases.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "v = prox_l1_wavelet(ub,0.02)\n", "# display\n", "fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 20))\n", "axes[0].imshow(ub,cmap='gray')\n", "axes[0].set_title('noisy image')\n", "axes[1].imshow(v,cmap='gray')\n", "axes[1].set_title('soft thresholded')\n", "fig.tight_layout()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Question:**\n", "In what follows, we apply the ISTA algorithm to a deblurring problem: $\\min_{x\\in\\mathbb{R}^n} F(x)$ with\n", "$$\n", "F(x) = f(x) + g(x) = \\frac{1}{2} \\| A x - u_b\\|^2 + \\lambda \\|W x\\|_1 \n", "$$\n", "where $x\\mapsto Ax$ is a blurring operator (obtained by a convolution with a positive kernel), and $x\\mapsto W x$ is the wavelet transform. 
$g$ has been studied above; we now turn to the term $f(x) = \\frac{1}{2} \\| A x - u_b\\|^2$.\n", "\n", "- Read and test the code below.\n", "- What is the $\\ell_1$-norm of the image \"gausskernel\", which we denote by $u_g$?\n", "- Recall the $L^p$-$L^1$ convolution inequality (Young's inequality: $\\|u \\ast v\\|_p \\leq \\|u\\|_p \\|v\\|_1$) and justify that $\\|A u \\| = \\|u \\ast u_g\\| \\leq \\|u\\|$. Deduce that $\\|A\\|=1$.\n", "- Give the expression of $x\\mapsto \\nabla f(x)$ and justify that the Lipschitz constant of $x\\mapsto \\nabla f(x)$ is $L=1$.\n", "- Check that $L=1$ experimentally by implementing the power method (50 iterations) applied to the appropriate linear operator." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Blurring\n", "\n", "# circular convolution computed with the FFT\n", "convol = lambda a,b: np.real(np.fft.ifft2(np.fft.fft2(a)*np.fft.fft2(b)))\n", "s = 5\n", "sigblur = 4\n", "# 1D Gaussian kernel; negative indices wrap around, so the kernel is centered at index 0\n", "kernel = np.zeros(2*s+1)\n", "for t in np.arange(-s,s+1):\n", "    kernel[t] = np.exp(-t**2/(2*sigblur**2))\n", "kernel /= np.sum(kernel)\n", "\n", "# separable 2D kernel of the same size as the image, centered at (0,0) with periodic wrap-around\n", "gausskernel = np.zeros(ub.shape)\n", "for t1 in np.arange(-s,s+1):\n", "    for t2 in np.arange(-s,s+1):\n", "        gausskernel[t1,t2] = kernel[t1]*kernel[t2]\n", "\n", "def gaussianblur(x):\n", "    return(convol(x,gausskernel))\n", "\n", "v = gaussianblur(u)\n", "\n", "fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 20))\n", "axes[0].imshow(u,cmap='gray')\n", "axes[0].set_title('original image')\n", "axes[1].imshow(v,cmap='gray')\n", "axes[1].set_title('blurred image')\n", "fig.tight_layout()\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Power method to check the Lipschitz constant\n", "#TODO\n", " " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Deblurring + denoising: simulation of the data\n", "# data: blurred and noisy version of the original image\n", "sig = 0.001 # noise standard deviation\n", "ub = gaussianblur(u) + sig*noise\n", "\n", "# display\n", "fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 
20))\n", "axes[0].imshow(u,cmap='gray')\n", "axes[0].set_title('original image')\n", "axes[1].imshow(ub,cmap='gray')\n", "axes[1].set_title('blurred and noisy image')\n", "fig.tight_layout()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Question:**\n", "- Implement the ISTA algorithm below to minimize $F(x) = f(x) + g(x)$ with the iterations:\n", "$$\n", "x^{(k+1)} = \\operatorname{prox}_{\\frac{1}{L}g}\\left( x^{(k)} - \\frac{1}{L}\\nabla f(x^{(k)}) \\right)\n", "$$\n", "The function proxg(x) is assumed to return $\\operatorname{prox}_{\\frac{1}{L}g}(x)$." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def ISTA(x0, f, gradf, g, proxg, L, n_iter=50):\n", " x = x0.copy()\n", " Fx = f(x)+g(x)\n", " Fhist = []\n", " Fhist.append(Fx)\n", " k=0\n", " print([0, Fx])\n", " while(k