#!/usr/bin/env python # coding: utf-8 # # Continuous Linear Optimal Transport Transform (CLOT) # # This tutorial will demonstrate: how to use the forward and inverse operations of the CLOT in the the *PyTransKit* package. # ## Class:: CLOT # # Continuous Linear Optimal Transport Transform. # # Parameters # ---------- # lr : float (default=0.01) # Learning rate. # momentum : float (default=0.) # Nesterov accelerated gradient descent momentum. # decay : float (default=0.) # Learning rate decay over each update. # max_iter : int (default=300) # Maximum number of iterations. # tol : float (default=0.001) # Stop iterating when change in cost function is below this threshold. # verbose : int (default=1) # Verbosity during optimization. 0=no output, 1=print cost, # 2=print all metrics. # # Attributes # ----------- # displacements_ : array, shape (2, height, width) # Displacements u. First index denotes direction: displacements_[0] is # y-displacements, and displacements_[1] is x-displacements. # transport_map_ : array, shape (2, height, width) # Transport map f. First index denotes direction: transport_map_[0] is # y-map, and transport_map_[1] is x-map. # displacements_initial_ : array, shape (2, height, width) # Initial displacements computed using the method by Haker et al. # transport_map_initial_ : array, shape (2, height, width) # Initial transport map computed using the method by Haker et al. # cost_ : list of float # Value of cost function at each iteration. # curl_ : list of float # Curl at each iteration. # # References # ---------- # [A continuous linear optimal transport approach for pattern analysis in # image datasets] # (https://www.sciencedirect.com/science/article/pii/S0031320315003507) # [Optimal mass transport for registration and warping] # (https://link.springer.com/article/10.1023/B:VISI.0000036836.66311.97) # # Functions: # -------------- # 1. Forward transform: # lot = forward(sig0, sig1) # # Inputs: # ---------------- # sig0 : array, shape (height, width) # Reference image. # sig1 : array, shape (height, width) # Signal to transform. # # Outputs: # ---------------- # lot : array, shape (2, height, width) # LOT transform of input image sig1. First index denotes direction: # lot[0] is y-LOT, and lot[1] is x-LOT. # # 2. Apply forward transport map: # sig0_recon = apply_forward_map(transport_map, sig1) # # Inputs: # ---------------- # transport_map : array, shape (2, height, width) # Forward transport map. # sig1 : array, shape (height, width) # Signal to transform. # # Outputs: # ---------------- # sig0_recon : array, shape (height, width) # Reconstructed reference signal sig0. # # 3. Apply inverse transport map: # sig1_recon = inverse(transport_map, sig0) # # Inputs: # ---------------- # transport_map : array, shape (2, height, width) # Forward transport map. Inverse is computed in this function. # sig0 : array, shape (height, width) # Reference signal. # # Outputs: # ---------------- # sig1_recon : array, shape (height, width) # Reconstructed signal sig1. # # ## Definition # The Continuous Linear Optimal Transport (CLOT) transform $\widehat s$ of a density function $s(\mathbf x)$ is defined as the optimal transport map from a reference density $s_0(\mathbf x)$ to $s(\mathbf x)$. Specifically, let $s_0(\mathbf x), s(\mathbf x)$ be positive functions defined on domains $\Omega_{s_0}, \Omega_{s}\subseteq \mathbb R^d$ respectively and such that $$\int_{\Omega_{s_0}}s_0(\mathbf x) d\mathbf x = \int_{\Omega_{s}}s(\mathbf x) d\mathbf x =1 \quad \text{(normalized)}.$$
# # Assuming that the density functions $s_0, s$ have finite second moments, there is an unique solution # to the Monge optimal transport problem:
# \begin{align} \label{Monge} # &\text{min Monge }(T)=\int_{R^d}\big|x-T(x)\big|^2s_0(x)dx, \qquad (1)\\ # &\text{s.t.}\quad \int_{B}s(\mathbf y)d\mathbf y= \int_{T^{-1}(B)} s_0(\mathbf x) d\mathbf x, \quad \text{for all open} \ B\subseteq \mathbb R^d. \qquad (2) \label{mass-preserving} # \end{align} # Any map $T$ satisfying constraint in (2) is called a transport (mass-preserving) map between $s_0$ and $s$. In particular, when $T$ is bijective and continuously differentiable, the mass-preserving constraint in (2) becomes
# \begin{equation} \label{MassCons} # s_0(x)= \big|\det \big(\nabla T(x)\big)\big|s\big(T(x)\big). # \end{equation} # # The minimizer to the above Monge problem is called an optimal transport map. Given a fixed reference density $s_0$, the LOT transform $\widehat s$ of a density function $s$ is defined to the unique optimal transport map from $s_0$ to $s$. # Moreover Brenier [1] shows that any optimal transport map can be written as the gradient of a convex function, i.e., $\widehat s = \nabla \phi$ where $\phi$ is a convex function. Following the generic approach described in [2], Kolouri et al. [3] employed an iterative algorithm minimizing (1) with constraint (2) via the gradient descent idea. # # ### References # [1] Y. Brenier. Polar factorization and monotone rearrangement of vector-valuedfunctions.Commun. Pure Appl. Math., 44(4):375–417, 1991.1
# [2] S. Haker, L. Zhu, A. Tannenbaum, and S. Angenent. Optimal mass transport forregistration and warping.Int. J. Comput. Vis., 60(4):225–240, 2004.
# [3] S. Kolouri, A. Tosun, J. Ozolek, and G. Rohde. A continuous linear optimal trans-port approach for pattern analysis in image datasets.Pattern Recognit., 51:453–462, 2016. # ## CLOT Demo # The examples will cover the following operations: # * Forward operation of the CLOT # * Apply forward map to transport $I_1$ to $I_0$ # * Apply inverse map to reconstruct $I_1$ from $I_0$ # ## Forward CLOT # ### Import necessary python packages # In[1]: import numpy as np import matplotlib.pyplot as plt # ### Read and normalize two images $I_0$ and $I_1$. # In[2]: import matplotlib.image as mpimg import sys sys.path.append('../') from pytranskit.optrans.utils import signal_to_pdf I0 = mpimg.imread('images/I0.bmp') I1 = mpimg.imread('images/I1.bmp') # Convert images to PDFs img0 = signal_to_pdf(I0, sigma=1., total=100.) img1 = signal_to_pdf(I1, sigma=1., total=100.) fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(5,10)) ax[0].imshow(img0,cmap='gray') ax[1].imshow(img1,cmap='gray') ax[0].set_title('$I_0$') ax[1].set_title('$I_1$') ax[0].axis('off') ax[1].axis('off') plt.show() # ### Compute CLOT and apply forward map # In[7]: from pytranskit.optrans.continuous.clot import CLOT from pytranskit.optrans.utils import plot_displacements2d clot = CLOT(max_iter=500, lr=1e-6, tol=1e-4,verbose=0) # calculate CLOT lot = clot.forward(img0, img1) # transport map and displacement map from I1 to I0 tmap10 = clot.transport_map_ disp = clot.displacements_ # apply forward map to transport I1 to I0 img0_recon = clot.apply_forward_map(tmap10, img1) fig, ax = plt.subplots(1, 4, sharex=True, sharey=True, figsize=(10,20)) ax[0].imshow(img0, cmap='gray') ax[0].set_title('$I_0$') ax[1].imshow(img1, cmap='gray') ax[1].set_title('$I_1$') ax[2].imshow(img0_recon, cmap='gray') ax[2].set_title('$f^{\'}I_1\circ f$') plot_displacements2d(disp, ax=ax[3], count=20) ax[3].set_title('Displacement') plt.show() # ## Inverse CLOT # Apply inverse map on $I_0$ to reconstruct $I_1$ # In[4]: img1_recon = clot.apply_inverse_map(tmap10, img0) fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(8,15)) ax[0].imshow(img1, cmap='gray') ax[0].set_title('$I_1$') ax[1].imshow(img0, cmap='gray') ax[1].set_title('$I_0$') ax[2].imshow(img1_recon, cmap='gray') ax[2].set_title('$(f^{-1})\'I_0\circ f^{-1}$') ax[0].axis('off') ax[1].axis('off') ax[2].axis('off') plt.show() # ## Geodesic # Show points on the geodesic between $I_0$ and $I_1$ # In[5]: lot11 = clot.forward(img1, img1) tmap11 = clot.transport_map_ alpha = np.linspace(0,1,5) img_recon = [] fig, ax = plt.subplots(1, len(alpha), sharex=True, sharey=True, figsize=(10,5*len(alpha))) for i in range(len(alpha)): tmap = alpha[i]*tmap10 + (1-alpha[i])*tmap11 img_recon.append(clot.apply_forward_map(tmap, img1)) ax[i].imshow(img_recon[i],cmap='gray') ax[i].axis('off') plt.show # In[ ]: