| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| from tempfile import TemporaryFile |
| import matplotlib as mpl |
| import matplotlib.pyplot as plt |
| from mpl_toolkits.basemap import Basemap |
| from mpl_toolkits.axes_grid1 import ImageGrid |
| import scipy.stats.mstats as mstats |
| import numpy as np |
| import numpy.ma as ma |
| |
# Module-wide default colormap: functions below that resolve cmap=None via
# plt.get_cmap will pick up 'coolwarm' unless set_cmap() changes it.
mpl.rcParams['image.cmap'] = 'coolwarm'
| |
def set_cmap(name):
    '''
    Change the default colormap used whenever a plotting function in this
    module is called with cmap=None.

    Available colormap names are listed at
    http://matplotlib.org/examples/pylab_examples/show_colormaps.html
    and any of them can be reversed by appending '_r' to the name.

    :param name: The name of the colormap.
    :type name: str
    '''
    # Looking the colormap up first means an unknown name fails loudly here
    # instead of leaving a bogus value in the 'image.cmap' rc setting.
    resolved_name = plt.get_cmap(name).name
    mpl.rc('image', cmap=resolved_name)
| |
def _nice_intervals(data, nlevs):
    '''
    Purpose::
        Choose "nice" color levels for colorbars and contour plots. The
        range covered is the 5th to 95th percentile of the data, so a few
        outliers cannot stretch the color scale.

    Input::
        data - an array of data to be plotted
        nlevs - an int giving the target number of intervals

    Output::
        clevs - an array of floats for the resultant colorbar levels
    '''
    flat = data.ravel()

    # Trim 5% off each tail of the distribution to suppress outliers
    lower, upper = (mstats.scoreatpercentile(flat, pct) for pct in (5, 95))

    candidates = mpl.ticker.MaxNLocator(nlevs).tick_values(lower, upper)

    # MaxNLocator may return levels beyond the requested range; drop those
    in_range = np.logical_and(candidates >= lower, candidates <= upper)
    return candidates[in_range]
| |
| def _best_grid_shape(nplots, oldshape): |
| ''' |
| Purpose:: |
| Calculate a better grid shape in case the user enters more columns |
| and rows than needed to fit a given number of subplots. |
| |
| Input:: |
| nplots - an int giving the number of plots that will be made |
| oldshape - a tuple denoting the desired grid shape (nrows, ncols) for arranging |
| the subplots originally requested by the user. |
| |
| Output:: |
| newshape - the smallest possible subplot grid shape needed to fit nplots |
| ''' |
| nrows, ncols = oldshape |
| size = nrows * ncols |
| diff = size - nplots |
| if diff < 0: |
| raise ValueError('gridshape=(%d, %d): Cannot fit enough subplots for data' %(nrows, ncols)) |
| else: |
| # If the user enters an excessively large number of |
| # rows and columns for gridshape, automatically |
| # correct it so that it fits only as many plots |
| # as needed |
| while diff >= ncols: |
| nrows -= 1 |
| size = nrows * ncols |
| diff = size - nplots |
| |
| # Don't forget to remove unnecessary columns too |
| if nrows == 1: |
| ncols = nplots |
| |
| newshape = nrows, ncols |
| return newshape |
| |
| def _fig_size(gridshape, aspect=None): |
| ''' |
| Purpose:: |
| Calculates the figure dimensions from a subplot gridshape |
| |
| Input:: |
| gridshape - Tuple denoting the subplot gridshape |
| aspect - Float denoting approximate aspect ratio of each subplot |
| (width / height). Default is 8.5 / 5.5 |
| |
| Output:: |
| width - float for width of the figure in inches |
| height - float for height of the figure in inches |
| ''' |
| if aspect is None: |
| aspect = 8.5 / 5.5 |
| |
| # Default figure size is 8.5" x 5.5" if nrows == ncols |
| # and then adjusted by given aspect ratio |
| nrows, ncols = gridshape |
| if nrows >= ncols: |
| # If more rows keep width constant |
| width, height = (aspect * 5.5), 5.5 * (nrows / ncols) |
| else: |
| # If more columns keep height constant |
| width, height = (aspect * 5.5) * (ncols / nrows), 5.5 |
| |
| return width, height |
| |
def draw_taylor_diagram(results, names, refname, fname, fmt='png',
                        gridshape=(1,1), ptitle='', subtitles=None,
                        pos='upper right', frameon=True, radmax=1.5):
    '''
    Purpose::
        Draws one or more Taylor diagrams and saves them to a file

    Input::
        results - an Nx2 array of (normalized standard deviation, correlation
                  coefficient) pairs, one row per evaluated dataset. A 3d array
                  of shape (nplots, N, 2) draws one diagram per subplot.
        names - list of names for each evaluated dataset (row of results)
        refname - The name of the reference dataset
        fname - a string specifying the filename of the plot (without extension)
        fmt - an optional string specifying the filetype, default is .png
        gridshape - optional tuple denoting the desired grid shape (nrows, ncols) for arranging
                    the subplots.
        ptitle - an optional string specifying the plot title
        subtitles - an optional list of strings specifying the title for each subplot
        pos - an optional string or tuple of float for determining
              the position of the legend
        frameon - an optional boolean that determines whether to draw a frame
                  around the legend box
        radmax - an optional float to adjust the extent of the axes in terms of
                 standard deviation.
    '''
    from matplotlib.gridspec import GridSpec

    # Handle the single plot case.
    if results.ndim == 2:
        results = results.reshape(1, *results.shape)

    # Make sure gridshape is compatible with input data
    nplots = results.shape[0]
    gridshape = _best_grid_shape(nplots, gridshape)

    # Set up the figure
    fig = plt.figure()
    fig.set_size_inches((8.5, 11))
    fig.dpi = 300

    # A GridSpec yields one SubplotSpec per panel (row-major, like subplot
    # numbering). Unlike the 3-digit "rci" string form, it also works for
    # grids with more than 9 cells.
    gspec = GridSpec(*gridshape)
    for i, data in enumerate(results):
        # Create Taylor Diagram object (reference std is 1 for normalized data)
        dia = TaylorDiagram(1, fig=fig, rect=gspec[i], label=refname, radmax=radmax)

        # Use a separate index for the samples so the panel index i used
        # for the subtitle below is not clobbered
        for j, (stddev, corrcoef) in enumerate(data):
            dia.add_sample(stddev, corrcoef, marker='$%d$' % (j + 1), ms=6,
                           label=names[j])

        if subtitles is not None:
            dia._ax.set_title(subtitles[i])

    # Add a figure-wide legend built from the markers of the last diagram drawn
    legend = fig.legend(dia.samplePoints,
                        [p.get_label() for p in dia.samplePoints],
                        handlelength=0., prop={'size': 10}, numpoints=1,
                        loc=pos)
    legend.set_frame_on(frameon)
    plt.subplots_adjust(wspace=0)

    # Add title and save the figure
    fig.suptitle(ptitle)
    plt.tight_layout(pad=.05, h_pad=.05)
    fig.savefig('%s.%s' %(fname, fmt), bbox_inches='tight', dpi=fig.dpi)

    # Close (not just clear) the figure so pyplot releases it
    plt.close(fig)
| |
def draw_subregions(subregions, lats, lons, fname, fmt='png', ptitle='',
                    parallels=None, meridians=None, subregion_masks=None):
    '''
    Purpose::
        Function to draw subregion domain(s) on a map

    Input::
        subregions - a list of subRegion objects; each must provide name,
                     latmin, latmax, lonmin and lonmax attributes
        lats - array of latitudes
        lons - array of longitudes
        fname - a string specifying the filename of the plot (without extension)
        fmt - an optional string specifying the filetype, default is .png
        ptitle - an optional string specifying plot title
        parallels - an optional list of ints or floats for the parallels to be drawn
        meridians - an optional list of ints or floats for the meridians to be drawn
        subregion_masks - optional dictionary, keyed by subregion name, of
                          boolean arrays for each subRegion for giving finer
                          control of the domain to be drawn; by default the
                          entire domain is drawn.
    '''
    # Set up the figure
    fig = plt.figure()
    fig.set_size_inches((8.5, 11.))
    fig.dpi = 300
    ax = fig.add_subplot(111)

    # Determine the map boundaries and construct a Basemap object
    lonmin = lons.min()
    lonmax = lons.max()
    latmin = lats.min()
    latmax = lats.max()
    m = Basemap(projection='cyl', llcrnrlat=latmin, urcrnrlat=latmax,
                llcrnrlon=lonmin, urcrnrlon=lonmax, resolution='l', ax=ax)

    # Draw the borders for coastlines and countries
    m.drawcoastlines(linewidth=1)
    m.drawcountries(linewidth=.75)
    m.drawstates()

    # Create default meridians and parallels. The interval between
    # them should be 1, 5, 10, 20, 30, or 40 depending on the size
    # of the domain
    length = max((latmax - latmin), (lonmax - lonmin)) / 5
    if length <= 1:
        dlatlon = 1
    elif length <= 5:
        dlatlon = 5
    else:
        dlatlon = np.round(length, decimals=-1)

    if meridians is None:
        meridians = np.r_[np.arange(0, -180, -dlatlon)[::-1], np.arange(0, 180, dlatlon)]
    if parallels is None:
        parallels = np.r_[np.arange(0, -90, -dlatlon)[::-1], np.arange(0, 90, dlatlon)]

    # Draw parallels / meridians
    m.drawmeridians(meridians, labels=[0, 0, 0, 1], linewidth=.75, fontsize=10)
    m.drawparallels(parallels, labels=[1, 0, 0, 1], linewidth=.75, fontsize=10)

    # Set up the color scaling: subregion i is filled with the value i + 1,
    # so each one gets its own color bin
    cmap = plt.cm.rainbow
    norm = mpl.colors.BoundaryNorm(np.arange(1, len(subregions) + 3), cmap.N)

    # Process the subregions
    for i, reg in enumerate(subregions):
        if subregion_masks is not None and reg.name in subregion_masks:
            domain = (i + 1) * subregion_masks[reg.name]
        else:
            # No mask given: a 2x2 block spanning the box fills the whole domain
            domain = (i + 1) * np.ones((2, 2))

        nlats, nlons = domain.shape
        # Masked-out (zero) cells are left transparent
        domain = ma.masked_equal(domain, 0)
        reglats = np.linspace(reg.latmin, reg.latmax, nlats)
        reglons = np.linspace(reg.lonmin, reg.lonmax, nlons)
        reglons, reglats = np.meshgrid(reglons, reglats)

        # Convert to projection coordinates. Not really necessary
        # for cylindrical projections but keeping it here in case we need
        # support for other projections.
        x, y = m(reglons, reglats)

        # Draw the subregion domain
        m.pcolormesh(x, y, domain, cmap=cmap, norm=norm, alpha=.5)

        # Label the subregion at the center of its grid
        xm, ym = x.mean(), y.mean()
        m.plot(xm, ym, marker='$%s$' %(reg.name), markersize=12, color='k')

    # Add the title
    ax.set_title(ptitle)

    # Save the figure, then close (not just clear) it so pyplot releases it
    fig.savefig('%s.%s' %(fname, fmt), bbox_inches='tight', dpi=fig.dpi)
    plt.close(fig)
| |
def draw_time_series(results, times, labels, fname, fmt='png', gridshape=(1, 1),
                     xlabel='', ylabel='', ptitle='', subtitles=None,
                     label_month=False, yscale='linear', aspect=None):
    '''
    Purpose::
        Function to draw a time series plot

    Input::
        results - a 3d array of time series (nplots, nseries, ntimes); a 2d
                  array is treated as a single plot
        times - a list of python datetime objects
        labels - a list of strings with the names of each set of data
        fname - a string specifying the filename of the plot (without extension)
        fmt - an optional string specifying the output filetype
        gridshape - optional tuple denoting the desired grid shape (nrows, ncols) for arranging
                    the subplots.
        xlabel - a string specifying the x-axis title
        ylabel - a string specifying the y-axis title
        ptitle - a string specifying the plot title
        subtitles - an optional list of strings specifying the title for each subplot
        label_month - optional bool to toggle drawing month labels
        yscale - optional string for setting the y-axis scale, 'linear' for linear
                 and 'log' for log base 10.
        aspect - Float denoting approximate aspect ratio of each subplot
                 (width / height). Default is 8.5 / 5.5
    '''
    # Handle the single plot case.
    if results.ndim == 2:
        results = results.reshape(1, *results.shape)

    # Make sure gridshape is compatible with input data
    nplots = results.shape[0]
    gridshape = _best_grid_shape(nplots, gridshape)

    # Set up the figure, honoring the requested subplot aspect ratio
    width, height = _fig_size(gridshape, aspect)
    fig = plt.figure()
    fig.set_size_inches((width, height))
    fig.dpi = 300

    # Make the subplot grid. add_all is left at its default (True) because
    # newer matplotlib releases no longer accept that keyword.
    grid = ImageGrid(fig, 111,
                     nrows_ncols=gridshape,
                     axes_pad=0.3,
                     share_all=True,
                     ngrids=nplots,
                     label_mode='L',
                     aspect=False,
                     cbar_mode='single',
                     cbar_location='bottom',
                     cbar_size=.05,
                     cbar_pad=.20
                     )

    # Make the plots
    for i, ax in enumerate(grid):
        data = results[i]
        if label_month:
            xfmt = mpl.dates.DateFormatter('%b')
            xloc = mpl.dates.MonthLocator()
            ax.xaxis.set_major_formatter(xfmt)
            ax.xaxis.set_major_locator(xloc)

        # Set the y-axis scale
        ax.set_yscale(yscale)

        # Set up list of lines for legend
        lines = []

        # Track the data extremes. Starting from +/-inf rather than 0 keeps
        # zero out of the range when the data don't contain it (0 would
        # also produce a non-positive limit on a log axis).
        ymin, ymax = np.inf, -np.inf

        # Plot each line
        for tSeries in data:
            line = ax.plot_date(times, tSeries, '')
            lines.extend(line)
            ymin = min(ymin, tSeries.min())
            ymax = max(ymax, tSeries.max())

        # Add 10% of the data range as padding on both ends so lines don't
        # touch the bottom and top of the plot. The pad is computed once so
        # both ends get the same amount. For flat (or empty) data the limits
        # are left to matplotlib's autoscaling.
        pad = (ymax - ymin) * 0.1
        if pad > 0:
            ax.set_ylim((ymin - pad, ymax + pad))

        # Set the subplot title if desired
        if subtitles is not None:
            ax.set_title(subtitles[i], fontsize='small')

    # Create a master axes rectangle for figure wide labels
    fax = fig.add_subplot(111, frameon=False)
    fax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
    fax.set_ylabel(ylabel)
    fax.set_title(ptitle, fontsize=16)
    fax.title.set_y(1.04)

    # Create the legend using a 'fake' colorbar axes. This lets us have a nice
    # legend that is in sync with the subplot grid
    cax = ax.cax
    cax.set_frame_on(False)
    cax.set_xticks([])
    cax.set_yticks([])
    cax.legend((lines), labels, loc='upper center', ncol=10, fontsize='small',
               mode='expand', frameon=False)

    # Note that due to weird behavior by axes_grid, it is more convenient to
    # place the x-axis label relative to the colorbar axes instead of the
    # master axes rectangle.
    cax.set_title(xlabel, fontsize=12)
    cax.title.set_y(-1.5)

    # Rotate the x-axis tick labels
    for ax in grid:
        for xtick in ax.get_xticklabels():
            xtick.set_ha('right')
            xtick.set_rotation(30)

    # Save the figure, then close (not just clear) it so pyplot releases it
    fig.savefig('%s.%s' %(fname, fmt), bbox_inches='tight', dpi=fig.dpi)
    plt.close(fig)
| |
def draw_contour_map(dataset, lats, lons, fname, fmt='png', gridshape=(1, 1),
                     clabel='', ptitle='', subtitles=None, cmap=None,
                     clevs=None, nlevs=10, parallels=None, meridians=None,
                     extend='neither', aspect=8.5/2.5):
    '''
    Purpose::
        Create a multiple panel contour map plot.

    Input::
        dataset - 3d array of the field to be plotted with shape (nT, nLat, nLon);
                  a 2d (nLat, nLon) array is treated as a single plot
        lats - array of latitudes
        lons - array of longitudes
        fname - a string specifying the filename of the plot (without extension)
        fmt - an optional string specifying the filetype, default is .png
        gridshape - optional tuple denoting the desired grid shape (nrows, ncols) for arranging
                    the subplots.
        clabel - an optional string specifying the colorbar title
        ptitle - an optional string specifying plot title
        subtitles - an optional list of strings specifying the title for each subplot
        cmap - an string or optional matplotlib.colors.LinearSegmentedColormap instance
               denoting the colormap
        clevs - an optional list of ints or floats specifying contour levels
        nlevs - an optional integer specifying the target number of contour levels if
                clevs is None
        parallels - an optional list of ints or floats for the parallels to be drawn
        meridians - an optional list of ints or floats for the meridians to be drawn
        extend - an optional string to toggle whether to place arrows at the colorbar
                 boundaries. Default is 'neither', but can also be 'min', 'max', or
                 'both'. Will be automatically set to 'both' if clevs is None.
        aspect - currently unused; accepted for API compatibility (the figure
                 size is fixed at 8.5" x 11")
    '''
    # Handle the single plot case. Meridians and Parallels are not labeled for
    # multiple plots to save space.
    if dataset.ndim == 2 or (dataset.ndim == 3 and dataset.shape[0] == 1):
        if dataset.ndim == 2:
            dataset = dataset.reshape(1, *dataset.shape)
        mlabels = [0, 0, 0, 1]
        plabels = [1, 0, 0, 1]
    else:
        mlabels = [0, 0, 0, 0]
        plabels = [0, 0, 0, 0]

    # Make sure gridshape is compatible with input data
    nplots = dataset.shape[0]
    gridshape = _best_grid_shape(nplots, gridshape)

    # Set up the figure
    fig = plt.figure()
    fig.set_size_inches((8.5, 11.))
    fig.dpi = 300

    # Make the subplot grid. add_all is left at its default (True) because
    # newer matplotlib releases no longer accept that keyword.
    grid = ImageGrid(fig, 111,
                     nrows_ncols=gridshape,
                     axes_pad=0.3,
                     share_all=True,
                     ngrids=nplots,
                     label_mode='L',
                     cbar_mode='single',
                     cbar_location='bottom',
                     cbar_size=.15,
                     cbar_pad='0%'
                     )

    # Determine the map boundaries and construct a Basemap object
    lonmin = lons.min()
    lonmax = lons.max()
    latmin = lats.min()
    latmax = lats.max()
    m = Basemap(projection='cyl', llcrnrlat=latmin, urcrnrlat=latmax,
                llcrnrlon=lonmin, urcrnrlon=lonmax, resolution='l')

    # Expand 1d coordinate vectors to 2d (nLat, nLon) grids
    if lats.ndim == 1 and lons.ndim == 1:
        lons, lats = np.meshgrid(lons, lats)

    # Calculate contour levels if not given
    if clevs is None:
        # Cut off the tails of the distribution
        # for more representative contour levels
        clevs = _nice_intervals(dataset, nlevs)
        extend = 'both'

    cmap = plt.get_cmap(cmap)

    # Create default meridians and parallels. The interval between
    # them should be 1, 5, 10, 20, 30, or 40 depending on the size
    # of the domain
    length = max((latmax - latmin), (lonmax - lonmin)) / 5
    if length <= 1:
        dlatlon = 1
    elif length <= 5:
        dlatlon = 5
    else:
        dlatlon = np.round(length, decimals=-1)
    if meridians is None:
        meridians = np.r_[np.arange(0, -180, -dlatlon)[::-1], np.arange(0, 180, dlatlon)]
    if parallels is None:
        parallels = np.r_[np.arange(0, -90, -dlatlon)[::-1], np.arange(0, 90, dlatlon)]

    # Convert lats and lons to projection coordinates
    x, y = m(lons, lats)
    for i, ax in enumerate(grid):
        # Load the data to be plotted
        data = dataset[i]
        # Reuse the single Basemap instance, pointing it at this panel
        m.ax = ax

        # Draw the borders for coastlines and countries
        m.drawcoastlines(linewidth=1)
        m.drawcountries(linewidth=.75)

        # Draw parallels / meridians
        m.drawmeridians(meridians, labels=mlabels, linewidth=.75, fontsize=10)
        m.drawparallels(parallels, labels=plabels, linewidth=.75, fontsize=10)

        # Draw filled contours
        cs = m.contourf(x, y, data, cmap=cmap, levels=clevs, extend=extend)

        # Add title
        if subtitles is not None:
            ax.set_title(subtitles[i], fontsize='small')

    # Add colorbar (shared by all panels since cbar_mode is 'single')
    cbar = fig.colorbar(cs, cax=ax.cax, drawedges=True, orientation='horizontal',
                        extendfrac='auto')
    cbar.set_label(clabel)
    cbar.set_ticks(clevs)
    cbar.ax.xaxis.set_ticks_position('none')
    cbar.ax.yaxis.set_ticks_position('none')

    # This is an ugly hack to make the title show up at the correct height.
    # Basically save the figure once to achieve tight layout and calculate
    # the adjusted heights of the axes, then draw the title slightly above
    # that height and save the figure again. The scratch file is closed
    # (and therefore deleted) as soon as the first save is done.
    with TemporaryFile() as scratch:
        fig.savefig(scratch, bbox_inches='tight', dpi=fig.dpi)
    ymax = 0
    for ax in grid:
        bbox = ax.get_position()
        ymax = max(ymax, bbox.ymax)

    # Add figure title
    fig.suptitle(ptitle, y=ymax + .06, fontsize=16)
    fig.savefig('%s.%s' %(fname, fmt), bbox_inches='tight', dpi=fig.dpi)

    # Close (not just clear) the figure so pyplot releases it
    plt.close(fig)
| |
def draw_portrait_diagram(results, rowlabels, collabels, fname, fmt='png',
                          gridshape=(1, 1), xlabel='', ylabel='', clabel='',
                          ptitle='', subtitles=None, cmap=None, clevs=None,
                          nlevs=10, extend='neither', aspect=None):
    '''
    Purpose::
        Makes a portrait diagram plot.

    Input::
        results - 3d array of the field to be plotted. The second dimension
                  should correspond to the number of rows in the diagram and the
                  third should correspond to the number of columns. A 2d array
                  is treated as a single plot.
        rowlabels - a list of strings denoting labels for each row
        collabels - a list of strings denoting labels for each column
        fname - a string specifying the filename of the plot (without extension)
        fmt - an optional string specifying the output filetype
        gridshape - optional tuple denoting the desired grid shape (nrows, ncols) for arranging
                    the subplots.
        xlabel - an optional string specifying the x-axis title
        ylabel - an optional string specifying the y-axis title
        clabel - an optional string specifying the colorbar title
        ptitle - a string specifying the plot title
        subtitles - an optional list of strings specifying the title for each subplot
        cmap - an optional string or matplotlib.colors.LinearSegmentedColormap instance
               denoting the colormap
        clevs - an optional list of ints or floats specifying colorbar levels
        nlevs - an optional integer specifying the target number of contour levels if
                clevs is None
        extend - an optional string to toggle whether to place arrows at the colorbar
                 boundaries. Default is 'neither', but can also be 'min', 'max', or
                 'both'. Will be automatically set to 'both' if clevs is None.
        aspect - Float denoting approximate aspect ratio of each subplot
                 (width / height). Default is 8.5 / 5.5

    Raises::
        ValueError - if rowlabels / collabels don't match the data shape
    '''
    # Handle the single plot case.
    if results.ndim == 2:
        results = results.reshape(1, *results.shape)

    nplots = results.shape[0]

    # Make sure gridshape is compatible with input data
    gridshape = _best_grid_shape(nplots, gridshape)

    # Row and Column labels must be consistent with the shape of
    # the input data too
    prows, pcols = results.shape[1:]
    if len(rowlabels) != prows or len(collabels) != pcols:
        raise ValueError('rowlabels and collabels must have %d and %d elements respectively' %(prows, pcols))

    # Set up the figure, honoring the requested subplot aspect ratio
    width, height = _fig_size(gridshape, aspect)
    fig = plt.figure()
    fig.set_size_inches((width, height))
    fig.dpi = 300

    # Make the subplot grid. add_all is left at its default (True) because
    # newer matplotlib releases no longer accept that keyword.
    grid = ImageGrid(fig, 111,
                     nrows_ncols=gridshape,
                     axes_pad=0.4,
                     share_all=True,
                     aspect=False,
                     ngrids=nplots,
                     label_mode='all',
                     cbar_mode='single',
                     cbar_location='bottom',
                     cbar_size=.15,
                     cbar_pad='3%'
                     )

    # Calculate colorbar levels if not given
    if clevs is None:
        # Cut off the tails of the distribution
        # for more representative colorbar levels
        clevs = _nice_intervals(results, nlevs)
        extend = 'both'

    cmap = plt.get_cmap(cmap)
    norm = mpl.colors.BoundaryNorm(clevs, cmap.N)

    # matshow puts cell centers on integer positions, so cell borders lie on
    # half-integers. One tick per cell keeps the tick and label counts equal,
    # and no tick is placed past the image, which would widen the view.
    xcenters = np.arange(pcols)
    ycenters = np.arange(prows)
    xedges = np.arange(pcols + 1) - .5
    yedges = np.arange(prows + 1) - .5

    # Do the plotting
    for i, ax in enumerate(grid):
        data = results[i]
        cs = ax.matshow(data, cmap=cmap, aspect='auto', origin='lower', norm=norm)

        # Add grid lines along the cell borders
        ax.xaxis.set_ticks(xcenters)
        ax.yaxis.set_ticks(ycenters)
        ax.vlines(xedges, yedges.min(), yedges.max())
        ax.hlines(yedges, xedges.min(), xedges.max())

        # Configure ticks
        ax.xaxis.tick_bottom()
        ax.xaxis.set_ticks_position('none')
        ax.yaxis.set_ticks_position('none')
        ax.set_xticklabels(collabels, fontsize='xx-small')
        ax.set_yticklabels(rowlabels, fontsize='xx-small')

        # Add axes title
        if subtitles is not None:
            ax.text(0.5, 1.04, subtitles[i], va='center', ha='center',
                    transform=ax.transAxes, fontsize='small')

    # Create a master axes rectangle for figure wide labels
    fax = fig.add_subplot(111, frameon=False)
    fax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
    fax.set_ylabel(ylabel)
    fax.set_title(ptitle, fontsize=16)
    fax.title.set_y(1.04)

    # Add colorbar
    cax = ax.cax
    cbar = fig.colorbar(cs, cax=cax, norm=norm, boundaries=clevs, drawedges=True,
                        extend=extend, orientation='horizontal', extendfrac='auto')
    cbar.set_label(clabel)
    cbar.set_ticks(clevs)
    cbar.ax.xaxis.set_ticks_position('none')
    cbar.ax.yaxis.set_ticks_position('none')

    # Note that due to weird behavior by axes_grid, it is more convenient to
    # place the x-axis label relative to the colorbar axes instead of the
    # master axes rectangle.
    cax.set_title(xlabel, fontsize=12)
    cax.title.set_y(1.5)

    # Save the figure, then close (not just clear) it so pyplot releases it
    fig.savefig('%s.%s' %(fname, fmt), bbox_inches='tight', dpi=fig.dpi)
    plt.close(fig)
| |
class TaylorDiagram(object):
    """ Taylor diagram helper class

    Plot model standard deviation and correlation to reference (data)
    sample in a single-quadrant polar plot, with r=stddev and
    theta=arccos(correlation).
    """

    def __init__(self, refstd, radmax=1.5, fig=None, rect=111, label='_'):
        """Set up Taylor diagram axes, i.e. single quadrant polar
        plot, using mpl_toolkits.axisartist.floating_axes.

        refstd - the reference standard deviation to be compared to
        radmax - extent of the standard deviation (radial) axis, as a
                 multiple of refstd
        fig    - figure to draw into; a new figure is created if None
        rect   - subplot position handed to FloatingSubplot (e.g. 111 or '121')
        label  - legend label for the reference point
        """

        from matplotlib.projections import PolarAxes
        import mpl_toolkits.axisartist.floating_axes as FA
        import mpl_toolkits.axisartist.grid_finder as GF

        self.refstd = refstd            # Reference standard deviation

        tr = PolarAxes.PolarTransform()

        # Correlation labels, placed at angle arccos(correlation)
        rlocs = np.concatenate((np.arange(10)/10., [0.95, 0.99]))
        tlocs = np.arccos(rlocs)        # Conversion to polar angles
        gl1 = GF.FixedLocator(tlocs)    # Positions
        tf1 = GF.DictFormatter(dict(zip(tlocs, map(str, rlocs))))

        # Standard deviation axis extent
        self.smin = 0
        self.smax = radmax*self.refstd

        ghelper = FA.GridHelperCurveLinear(tr,
                                           extremes=(0, np.pi/2,  # 1st quadrant
                                                     self.smin, self.smax),
                                           grid_locator1=gl1,
                                           tick_formatter1=tf1,
                                           )

        if fig is None:
            fig = plt.figure()

        ax = FA.FloatingSubplot(fig, rect, grid_helper=ghelper)
        fig.add_subplot(ax)

        # Adjust axes
        ax.axis["top"].set_axis_direction("bottom")  # "Angle axis"
        ax.axis["top"].toggle(ticklabels=True, label=True)
        ax.axis["top"].major_ticklabels.set_axis_direction("top")
        ax.axis["top"].label.set_axis_direction("top")
        ax.axis["top"].label.set_text("Correlation")

        ax.axis["left"].set_axis_direction("bottom")  # "X axis"
        ax.axis["left"].label.set_text("Standard deviation")

        ax.axis["right"].set_axis_direction("top")  # "Y axis"
        ax.axis["right"].toggle(ticklabels=True)
        ax.axis["right"].major_ticklabels.set_axis_direction("left")

        ax.axis["bottom"].set_visible(False)  # Useless

        # No default grid; the reference stddev arc is drawn explicitly below
        ax.grid(False)

        self._ax = ax                   # Graphical axes
        self.ax = ax.get_aux_axes(tr)   # Polar coordinates

        # Add reference point and stddev contour
        ref_point, = self.ax.plot([0], self.refstd, 'k*',
                                  ls='', ms=10, label=label)
        t = np.linspace(0, np.pi/2)
        r = np.zeros_like(t) + self.refstd
        self.ax.plot(t, r, 'k--', label='_')

        # Collect sample points for later use (e.g. legend)
        self.samplePoints = [ref_point]

    def add_sample(self, stddev, corrcoef, *args, **kwargs):
        """Add sample (stddev, corrcoef) to the Taylor diagram. args
        and kwargs are passed straight through to the plot call on the
        polar auxiliary axes. Returns the created line."""

        line, = self.ax.plot(np.arccos(corrcoef), stddev,
                             *args, **kwargs)  # (theta, radius)
        self.samplePoints.append(line)

        return line

    def add_rms_contours(self, levels=5, **kwargs):
        """Add constant centered RMS difference contours and return the
        contour set. kwargs are passed to the contour call."""

        rs, ts = np.meshgrid(np.linspace(self.smin, self.smax),
                             np.linspace(0, np.pi/2))
        # Compute centered RMS difference (law of cosines)
        rms = np.sqrt(self.refstd**2 + rs**2 - 2*self.refstd*rs*np.cos(ts))

        return self.ax.contour(ts, rs, rms, levels, **kwargs)

    @staticmethod
    def _line_style(kwargs):
        """Return a copy of the plot kwargs with the historical defaults
        (red, linewidth 2) filled in unless the caller already set them,
        either by full name or by the 'c' / 'lw' aliases."""
        style = dict(kwargs)
        if 'c' not in style:
            style.setdefault('color', 'red')
        if 'lw' not in style:
            style.setdefault('linewidth', 2)
        return style

    def add_stddev_contours(self, std, corr1, corr2, **kwargs):
        """Add a curved line with a radius of std between two points
        [std, corr1] and [std, corr2]. kwargs are passed to the plot call
        (default: red, linewidth 2)."""

        t = np.linspace(np.arccos(corr1), np.arccos(corr2))
        r = np.zeros_like(t) + std
        return self.ax.plot(t, r, **self._line_style(kwargs))

    def add_contours(self, std1, corr1, std2, corr2, **kwargs):
        """Add a line between two points
        [std1, corr1] and [std2, corr2]. kwargs are passed to the plot call
        (default: red, linewidth 2)."""

        t = np.linspace(np.arccos(corr1), np.arccos(corr2))
        r = np.linspace(std1, std2)

        return self.ax.plot(t, r, **self._line_style(kwargs))