We demonstrate multigrid to solve a discrete Laplace equation: \(-\Delta_h u = f\) on interior grid point, \(u=0\) on boundary grid points.

First, define a function v = mg5pt(f, u) which implements one iteration of multigrid for the 5-point Laplacian with zero Dirichlet boundary conditions.

The grid size is \(h=1/n\) where \(n\) is assumed to be a power of \(2\). Grid functions are stored as numpy arrays with shape \((n+1, n+1)\), although only the interior \((n-1)\times (n-1)\) values are true unknowns, the others being zero. The inputs f and u are grid functions, with u giving the current iterate. The output is the next iterate. This routine can be called repeatedly as u = mg5pt(f, u) to perform multigrid iteration and, hopefully, converge to the solution.

In [1]:
import numpy as np

def gridfn(n):
    # create a zero grid function with shape (n+1, n+1)
    return np.zeros((n+1, n+1))
    
def mg5pt(f, u):
    n = f.shape[0] - 1  # deduce n from the shape of f
    if n == 2:
        # exact solve when there is only 1 unknown
        v = gridfn(2)
        v[1, 1] = f[1, 1]/16  # h = 1/2, so 4/h^2 = 16
    else:
        v = u.copy()
        h = 1./n
        n2 = n/2
        # presmooth with Gauss-Seidel
        for i in range(1, n):
            for j in range(1, n):
                v[i, j] = (v[i+1, j] + v[i-1, j] + v[i, j+1] + v[i, j-1] + h**2 * f[i, j])/4.
        # compute residual
        r = gridfn(n)
        r[1:n, 1:n] = f[1:n, 1:n] - (4.*v[1:n, 1:n] - v[2:n+1, 1:n] - v[0:n-1, 1:n] - v[1:n, 2:n+1] - v[1:n, 0:n-1])/h**2
        # transfer to next coarser mesh
        r2 = gridfn(n2)
        r2[1:n2, 1:n2] = ( 4.*r[2:n-1:2, 2:n-1:2] \
                     + 2.*(r[1:n-2:2, 2:n-1:2] + r[3:n:2, 2:n-1:2] + r[2:n-1:2, 1:n-2:2] + r[2:n-1:2, 3:n:2]) \
                     +    r[1:n-2:2, 1:n-2:2] + r[3:n:2, 1:n-2:2] + r[1:n-2:2, 3:n:2] + r[3:n:2, 3:n:2] ) / 16.
        # recursively apply multigrid iteration on coarse mesh with initial guess zero to estimate error
        e2 = mg5pt(r2, gridfn(n2))
        # transfer error to fine mesh
        e = gridfn(n)
        e[2:n-1:2, 2:n-1:2] = e2[1:n2, 1:n2]
        e[1:n:2, 2:n-1:2] = (e2[0:n2, 1:n2] + e2[1:n2+1, 1:n2])/2.
        e[2:n-1:2, 1:n:2] = (e2[1:n2, 0:n2] + e2[1:n2, 1:n2+1])/2.
        e[1:n:2, 1:n:2] = (e2[0:n2, 0:n2] + e2[1:n2+1, 0:n2] + e2[0:n2, 1:n2+1] + e2[1:n2+1, 1:n2+1])/4.
        # correct
        v = u + e
        # postsmooth with Gauss-Seidel in reverse order
        for i in range(n-1, 0, -1):
            for j in range(n-1, 0, -1):
                v[i, j] = (v[i+1, j] + v[i-1, j] + v[i, j+1] + v[i, j-1] + h**2 * f[i, j])/4.
     
    return v

Next we try it out. We define \(n\) (a power of 2), set up the grid, and define the rhs grid function f.

In [2]:
# set up the grid
n = 2**5
h = 1./n
x = np.linspace(0, 1, n+1)
y = np.linspace(0, 1, n+1)
[X, Y] = np.meshgrid(x, y, indexing='ij')      # (X(i,j), Y(i,j)) = (i h, j h)

# specify the function f(x,y)
def ff(x, y):
    return -190*x**2 * (1-y) + 26.4*np.exp((x-.25)**2 + (y-.25)**2)
# evaluate it to define the grid function f
f = ff(X, Y)

For comparison purposes only, we form the matrix and compute the solution with a direct solve.

In [3]:
from scipy.sparse import *
from scipy.sparse.linalg import spsolve
I = eye(n-1) # (n-1) x (n-1) identity matrix
e = np.ones(n-1) # vector (1, 1, ..., 1) of length n-1
e0 = np.ones(n-2) # vector (1, 1, ..., 1) of length n-2
A1 = diags([e0, -4*e, e0], [-1, 0, 1]) # (n-1) x (n-1) tridiagonal matrix
A2 = diags([e0, e0], [-1, 1]) # same with zeros on diagonal
A = -kronsum(A1, A2, format='csr')
b = h**2 * np.reshape(f[1:n, 1:n], (n-1)**2) # right hand side
U = spsolve(A, b)
uex = gridfn(n)
uex[1:n, 1:n] = np.reshape(U, (n-1, n-1))

Start plotting the output. Create two axes, one for direct solve and one for multigrid iterates. Plot the solution just obtained by direct solve.

In [4]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
plt.ion()
fig = plt.figure(figsize=(13, 6))
ax1 = fig.add_subplot(121, projection='3d')
ax2 = fig.add_subplot(122, projection='3d')
ax1.set_xlim(0., 1.)
ax1.set_ylim(0., 1.)
ax1.set_zlim(-1., 1.)
ax1.grid(False)
ax2.set_xlim(0., 1.)
ax2.set_ylim(0., 1.)
ax2.set_zlim(-1., 1.)
ax2.grid(False)
plt.show()
ax1.plot_surface(X, Y, uex, rstride=1, cstride=1, color='yellow', linewidth=0.5, antialiased=True, shade=True)
ax1.set_title('Direct solve')
Out[4]:
<matplotlib.text.Text at 0xaf40feec>

Now plot the multigrid iterates.

In [5]:
# set and plot initial guess, here random
u = gridfn(n)  # zero grid function
u[1:n, 1:n] = np.random.uniform(-1., 1., (n-1, n-1))
ax2.plot_surface(X, Y, u, rstride=1, cstride=1, color='yellow', linewidth=0.5, antialiased=True, shade=True)
ax2.set_title('Multigrid iteration {}'.format(0))
plt.draw()
niters = 10
for niter in range(niters):
    plt.waitforbuttonpress() # wait for input to continue
    u = mg5pt(f, u)
    ax2.clear()
    ax2.set_xlim(0., 1.)
    ax2.set_ylim(0., 1.)
    ax2.set_zlim(-1., 1.)
    ax2.grid(False)
    ax2.plot_surface(X, Y, u, rstride=1, cstride=1, color='yellow', linewidth=0.5, antialiased=True, shade=True)
    ax2.set_title('Multigrid iteration {}'.format(niter+1))
    plt.draw()
---------------------------------------------------------------------------
TclError                                  Traceback (most recent call last)
<ipython-input-5-8280484d1cfd> in <module>()
      7 niters = 10
      8 for niter in range(niters):
----> 9     plt.waitforbuttonpress() # wait for input to continue
     10     u = mg5pt(f, u)
     11     ax2.clear()

/usr/lib/pymodules/python2.7/matplotlib/pyplot.pyc in waitforbuttonpress(*args, **kwargs)
    586     If *timeout* is negative, does not timeout.
    587     """
--> 588     return gcf().waitforbuttonpress(*args, **kwargs)
    589 
    590 

/usr/lib/pymodules/python2.7/matplotlib/figure.pyc in waitforbuttonpress(self, timeout)
   1537 
   1538         blocking_input = BlockingKeyMouseInput(self)
-> 1539         return blocking_input(timeout=timeout)
   1540 
   1541     def get_default_bbox_extra_artists(self):

/usr/lib/pymodules/python2.7/matplotlib/blocking_input.pyc in __call__(self, timeout)
    367         """
    368         self.keyormouse = None
--> 369         BlockingInput.__call__(self, n=1, timeout=timeout)
    370 
    371         return self.keyormouse

/usr/lib/pymodules/python2.7/matplotlib/blocking_input.pyc in __call__(self, n, timeout)
    109         try:
    110             # Start event loop
--> 111             self.fig.canvas.start_event_loop(timeout=timeout)
    112         finally:  # Run even on exception like ctrl-c
    113             # Disconnect the callbacks

/usr/lib/pymodules/python2.7/matplotlib/backends/backend_tkagg.pyc in start_event_loop(self, timeout)
    500 
    501     def start_event_loop(self,timeout):
--> 502         FigureCanvasBase.start_event_loop_default(self,timeout)
    503     start_event_loop.__doc__=FigureCanvasBase.start_event_loop_default.__doc__
    504 

/usr/lib/pymodules/python2.7/matplotlib/backend_bases.pyc in start_event_loop_default(self, timeout)
   2413         self._looping = True
   2414         while self._looping and counter * timestep < timeout:
-> 2415             self.flush_events()
   2416             time.sleep(timestep)
   2417             counter += 1

/usr/lib/pymodules/python2.7/matplotlib/backends/backend_tkagg.pyc in flush_events(self)
    497 
    498     def flush_events(self):
--> 499         self._master.update()
    500 
    501     def start_event_loop(self,timeout):

/usr/lib/python2.7/lib-tk/Tkinter.pyc in update(self)
    965     def update(self):
    966         """Enter event loop until all pending events have been processed by Tcl."""
--> 967         self.tk.call('update')
    968     def update_idletasks(self):
    969         """Enter event loop until all idle callbacks have been called. This

TclError: can't invoke "update" command: application has been destroyed
/usr/lib/pymodules/python2.7/matplotlib/backend_bases.py:2407: MatplotlibDeprecationWarning: Using default event loop until function specific to this GUI is implemented
  warnings.warn(str, mplDeprecation)

In []: