In [1]:
# This notebook uses the CESM Large Ensemble (LENS) data hosted on AWS
# to reproduce figures from the paper introducing the ensemble:
# Kay et al. (2015), BAMS,
# https://journals.ametsoc.org/doi/full/10.1175/BAMS-D-13-00255.1
In [2]:
%matplotlib inline
import warnings

warnings.filterwarnings("ignore")
# Silence dask.distributed logs
import logging

logger = logging.getLogger("distributed.utils_perf")
logger.setLevel(logging.ERROR)

import intake
import numpy as np
import pandas as pd
import xarray as xr

Create and Connect to Dask Distributed Cluster

In [4]:
from dask_kubernetes import KubeCluster
from dask.distributed import Client

# Create an adaptive cluster: scale between 2 and 100 workers, and only
# retire a worker after it has been flagged as idle 60 consecutive times
cluster = KubeCluster()
cluster.adapt(minimum=2, maximum=100, wait_count=60)
client = Client(cluster)
cluster

☝️ Don't forget to click the link above to view the scheduler dashboard!
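
If the widget above does not render a clickable link in your frontend, the dashboard URL is also exposed programmatically by dask.distributed; a minimal sketch:

In [ ]:
# Print the scheduler dashboard URL
print(client.dashboard_link)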

Load data into xarray from an intake catalog

In [5]:
cat = intake.open_catalog(
    "https://raw.githubusercontent.com/NCAR/cesm-lens-aws/master/intake-catalogs/atmosphere/daily.yaml"
)

ds_20C = cat["reference_height_temperature_20C"].to_dask()
ds_rcp = cat["reference_height_temperature_RCP85"].to_dask()
t_20c = ds_20C["TREFHT"]
t_rcp = ds_rcp["TREFHT"]
In [6]:
# Subset the 1961-1990 reference period used for the anomaly calculation
t_ref = t_20c.sel(time=slice("1961", "1990"))
t_ref
Out[6]:
<xarray.DataArray 'TREFHT' (member_id: 40, time: 10950, lat: 192, lon: 288)>
dask.array<shape=(40, 10950, 192, 288), dtype=float32, chunksize=(2, 365, 192, 288)>
Coordinates:
  * lat        (lat) float64 -90.0 -89.06 -88.12 -87.17 ... 88.12 89.06 90.0
  * lon        (lon) float64 0.0 1.25 2.5 3.75 5.0 ... 355.0 356.2 357.5 358.8
  * member_id  (member_id) int64 1 2 3 4 5 6 7 8 ... 34 35 101 102 103 104 105
  * time       (time) object 1961-01-01 00:00:00 ... 1990-12-31 00:00:00
Attributes:
    cell_methods:  time: mean
    long_name:     Reference height temperature
    units:         K
In [7]:
areacella = ds_20C.area
total_area = areacella.sum()
areacella
Out[7]:
<xarray.DataArray 'area' (lat: 192, lon: 288)>
dask.array<shape=(192, 288), dtype=float32, chunksize=(192, 288)>
Coordinates:
  * lat      (lat) float64 -90.0 -89.06 -88.12 -87.17 ... 87.17 88.12 89.06 90.0
  * lon      (lon) float64 0.0 1.25 2.5 3.75 5.0 ... 355.0 356.2 357.5 358.8
Attributes:
    long_name:      Grid-Cell Area
    standard_name:  cell_area
    units:          m2

Compute Weighted Means
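
The global average is the cell-area-weighted mean, $\overline{T} = \sum_{i,j} T_{ij} A_{ij} \, / \, \sum_{i,j} A_{ij}$, where $A_{ij}$ is the grid-cell area (`areacella`) loaded above. The cell below first takes annual means (`"AS"` = year-start frequency) and then applies this weighting.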

In [8]:
# Reference value: annual means, area-weighted global average, then
# averaged over the reference period and all ensemble members
t_ref_ts = (
    (t_ref.resample(time="AS").mean("time") * areacella).sum(dim=("lat", "lon"))
    / total_area
).mean(dim=("time", "member_id"))

# Per-member global-average annual-mean time series
t_20c_ts = (
    (t_20c.resample(time="AS").mean("time") * areacella).sum(dim=("lat", "lon"))
) / total_area

t_rcp_ts = (
    (t_rcp.resample(time="AS").mean("time") * areacella).sum(dim=("lat", "lon"))
) / total_area
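
Newer xarray releases also provide a built-in weighted-reduction helper that is equivalent to the manual multiply-sum-divide above; a minimal sketch, assuming xarray >= 0.15.1 (not used in the computation here):

In [ ]:
# Area-weighted global mean via DataArray.weighted (equivalent to the
# manual version above)
t_20c_ts_alt = (
    t_20c.resample(time="AS").mean("time").weighted(areacella).mean(dim=("lat", "lon"))
)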
In [9]:
%%time
t_ref_mean = t_ref_ts.load()
t_ref_mean
CPU times: user 10.7 s, sys: 459 ms, total: 11.1 s
Wall time: 18.1 s
Out[9]:
<xarray.DataArray ()>
array(286.38766, dtype=float32)
In [10]:
t_20c_ts_df = t_20c_ts.to_series().unstack().T
t_20c_ts_df.head()
Out[10]:
member_id 1 2 3 4 5 6 7 8 9 10 ... 31 32 33 34 35 101 102 103 104 105
time
1920-01-01 00:00:00 286.311310 286.346710 286.283875 286.363983 286.328400 286.373444 286.386017 286.302185 286.374878 286.348358 ... 286.243469 286.283783 286.173859 286.309509 286.296234 286.341064 286.341187 286.376831 286.321167 286.254822
1921-01-01 00:00:00 286.250641 286.198181 286.287292 286.390564 286.309204 286.334229 286.311310 286.300232 286.315857 286.305603 ... 286.179413 286.315674 286.075104 286.295990 286.318085 286.375275 286.246063 286.356201 286.492523 286.224274
1922-01-01 00:00:00 286.293488 286.296356 286.265686 286.336517 286.293579 286.220093 286.010773 286.195099 286.205170 286.396545 ... 286.142365 286.316254 286.140167 286.293549 286.327972 286.142365 286.412598 286.369232 286.503418 286.282074
1923-01-01 00:00:00 286.329163 286.322662 286.251099 286.322723 286.237457 286.152069 286.066040 286.204498 286.271454 286.292236 ... 286.168762 286.300781 286.095490 286.116302 286.227905 286.226440 286.512909 286.381348 286.215302 286.396332
1924-01-01 00:00:00 286.307465 286.237366 286.148895 286.311890 286.361694 286.185974 286.248352 286.288177 286.330444 286.411835 ... 286.143066 286.287079 286.234100 286.199890 286.252777 286.322815 286.256165 286.221588 286.247437 286.422028

5 rows × 40 columns

In [11]:
t_rcp_ts_df = t_rcp_ts.to_series().unstack().T
t_rcp_ts_df.head()
Out[11]:
member_id 1 2 3 4 5 6 7 8 9 10 ... 31 32 33 34 35 101 102 103 104 105
time
2006-01-01 00:00:00 286.764832 286.960358 286.679230 286.793152 286.754547 287.022339 286.850464 287.089844 286.960022 286.775787 ... 286.866089 286.925049 286.663971 286.955414 286.712524 287.115601 286.863556 286.881683 287.308411 287.030334
2007-01-01 00:00:00 287.073792 286.908539 286.808746 286.998901 286.841675 286.993042 286.914124 286.938965 286.933563 286.675385 ... 286.804108 286.849548 286.628204 287.010529 286.811523 287.187225 286.862823 287.008240 287.222534 287.239044
2008-01-01 00:00:00 287.104095 286.815033 286.995056 287.081543 287.100708 286.960510 286.854706 286.878937 287.062927 286.702454 ... 286.825653 286.844086 286.811859 286.803741 286.956635 287.080994 286.930084 286.945801 287.087128 287.157745
2009-01-01 00:00:00 286.984497 287.059418 287.010498 287.144745 286.948700 287.092316 286.888458 287.050964 287.138428 286.890839 ... 286.785797 286.876556 286.953094 287.060364 287.056885 287.124908 287.005615 287.083984 287.254211 287.060730
2010-01-01 00:00:00 286.991821 287.102295 286.988159 286.875183 286.954407 287.121796 286.938843 287.116211 286.957245 287.049622 ... 286.937317 286.928284 286.980499 287.118713 287.178040 287.030212 287.114716 287.083038 287.256927 287.066528

5 rows × 40 columns

Get Observations (HadCRUT4; Morice et al. 2012)

In [12]:
ds = xr.open_dataset(
    "https://www.esrl.noaa.gov/psd/thredds/dodsC/Datasets/cru/hadcrut4/air.mon.anom.median.nc"
).load()
ds
Out[12]:
<xarray.Dataset>
Dimensions:    (lat: 36, lon: 72, nbnds: 2, time: 2036)
Coordinates:
  * lat        (lat) float32 87.5 82.5 77.5 72.5 ... -72.5 -77.5 -82.5 -87.5
  * lon        (lon) float32 -177.5 -172.5 -167.5 -162.5 ... 167.5 172.5 177.5
  * time       (time) datetime64[ns] 1850-01-01 1850-02-01 ... 2019-08-01
Dimensions without coordinates: nbnds
Data variables:
    time_bnds  (time, nbnds) datetime64[ns] 1850-01-01 1850-01-31 ... 2019-08-31
    air        (time, lat, lon) float32 nan nan nan nan nan ... nan nan nan nan
Attributes:
    platform:                        Surface
    title:                           HADCRUT4 Combined Air Temperature/SST An...
    history:                         Originally created at NOAA/ESRL PSD by C...
    Conventions:                     CF-1.0
    Comment:                         This dataset supersedes V3
    Source:                          Obtained from http://hadobs.metoffice.co...
    version:                         4.2.0
    dataset_title:                   HadCRUT4
    References:                      https://www.esrl.noaa.gov/psd/data/gridd...
    DODS_EXTRA.Unlimited_Dimension:  time

  • Obs mean: weight each month by its number of days
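That is, month $m$ receives weight $w_m = d_m / \sum_{k \in \mathrm{year}(m)} d_k$, where $d_m$ is the length of the month in days, so the weights within each year sum to 1 (asserted inside the function below).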
In [13]:
def weighted_temporal_mean(ds):
    """Compute annual means of monthly data, weighting each month by its length."""
    # Month lengths, from the difference of the time bounds
    time_bound_diff = ds.time_bnds.diff(dim="nbnds")[:, 0]
    # Normalize so the weights within each year sum to 1
    # (`...` reduces over all dimensions; it replaces the removed xr.ALL_DIMS)
    wgts = time_bound_diff.groupby("time.year") / time_bound_diff.groupby(
        "time.year"
    ).sum(...)
    np.testing.assert_allclose(wgts.groupby("time.year").sum(...), 1.0)
    obs = ds["air"]
    # Mask that is 1 where observations exist and 0 where they are missing
    cond = obs.isnull()
    ones = xr.where(cond, 0.0, 1.0)
    # Weighted annual sums of the data and of the valid-data mask; their
    # ratio is the weighted mean over valid points only
    obs_sum = (obs * wgts).resample(time="AS").sum(dim="time")
    ones_out = (ones * wgts).resample(time="AS").sum(dim="time")
    obs_s = (obs_sum / ones_out).mean(("lat", "lon")).to_series()
    return obs_s
In [14]:
obs_s = weighted_temporal_mean(ds)
obs_s.head()
Out[14]:
time
1850-01-01   -0.338822
1851-01-01   -0.245482
1852-01-01   -0.291014
1853-01-01   -0.342457
1854-01-01   -0.276820
Freq: AS-JAN, dtype: float64
In [15]:
# Anomalies relative to the 1961-1990 reference mean computed above
all_ts_anom = pd.concat([t_20c_ts_df, t_rcp_ts_df]) - t_ref_mean.data
years = [val.year for val in all_ts_anom.index]
  • Confirm that, after area-weighted averaging, the maximum temperature anomaly is ~5 K
In [16]:
np.testing.assert_allclose(all_ts_anom.values.max(), 5.0, rtol=0.02)

Figure 2: Global surface temperature anomaly (1961-90 base period) for individual ensemble members, and observations

In [17]:
import matplotlib.pyplot as plt
In [18]:
ax = plt.axes()

ax.tick_params(right=True, top=True, direction="out", length=6, width=2, grid_alpha=0.5)
# All 40 ensemble members in grey, member 1 highlighted in black, observations in red
ax.plot(years, all_ts_anom, color="grey")
ax.plot(years, all_ts_anom[1], color="black")
ax.plot(obs_s.index.year.tolist(), obs_s, color="red")

ax.text(
    0.3,
    0.4,
    "observations",
    verticalalignment="bottom",
    horizontalalignment="left",
    transform=ax.transAxes,
    color="red",
    fontsize=10,
)
ax.text(
    0.3,
    0.3,
    "members 1-40",
    verticalalignment="bottom",
    horizontalalignment="left",
    transform=ax.transAxes,
    color="grey",
    fontsize=10,
)

ax.set_xticks([1850, 1920, 1950, 2000, 2050, 2100])
plt.ylim(-1, 5)
plt.xlim(1850, 2100)
plt.ylabel("Global Surface\nTemperature Anomaly (K)")
plt.show()

Compute Linear Trend for Winter Seasons

In [19]:
def linear_trend(da, dim="time"):
    # Rechunk so the trend dimension is a single chunk, as required by
    # apply_ufunc with dask="parallelized"
    da_chunk = da.chunk({dim: -1})
    trend = xr.apply_ufunc(
        calc_slope,
        da_chunk,
        vectorize=True,
        input_core_dims=[[dim]],
        output_core_dims=[[]],
        output_dtypes=[np.float64],  # np.float was removed in NumPy 1.24
        dask="parallelized",
    )
    return trend


def calc_slope(y):
    """ufunc to be used by linear_trend: least-squares slope of y vs. its index"""
    x = np.arange(len(y))
    return np.polyfit(x, y, 1)[0]
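
A quick sanity check of `calc_slope` on synthetic data (illustrative cell, not part of the original analysis):

In [ ]:
# polyfit should recover the slope of an exactly linear series
np.testing.assert_allclose(calc_slope(2.0 * np.arange(10) + 5.0), 2.0)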
In [20]:
# Concatenate the 20th-century and RCP8.5 runs along time, then compute
# seasonal means ("QS-DEC": quarters starting in December, i.e. DJF first)
t = xr.concat([t_20c, t_rcp], dim="time")
seasons = t.sel(time=slice("1979", "2012")).resample(time="QS-DEC").mean("time")
# Include only full seasons from 1979 through 2012
seasons = seasons.sel(time=slice("1979", "2012")).load()
seasons
Out[20]:
<xarray.DataArray 'TREFHT' (member_id: 40, time: 136, lat: 192, lon: 288)>
array([[[[219.35846, ..., 219.44753],
         ...,
         [249.04382, ..., 249.0427 ]],

        ...,

        [[239.91365, ..., 240.39209],
         ...,
         [235.72987, ..., 235.72905]]],


       ...,


       [[[220.02869, ..., 220.16699],
         ...,
         [247.35379, ..., 247.3539 ]],

        ...,

        [[241.33533, ..., 242.03595],
         ...,
         [241.77695, ..., 241.7759 ]]]], dtype=float32)
Coordinates:
  * time       (time) object 1979-03-01 00:00:00 ... 2012-12-01 00:00:00
  * lat        (lat) float64 -90.0 -89.06 -88.12 -87.17 ... 88.12 89.06 90.0
  * lon        (lon) float64 0.0 1.25 2.5 3.75 5.0 ... 355.0 356.2 357.5 358.8
  * member_id  (member_id) int64 1 2 3 4 5 6 7 8 ... 34 35 101 102 103 104 105
In [21]:
# Keep only the DJF seasons (QS-DEC labels each season by its December start)
winter_seasons = seasons.sel(
    time=seasons.time.where(seasons.time.dt.month == 12, drop=True)
)
# The fitted slope is per season; multiplying by the number of seasons
# gives the total change over the period (K per 34 years)
winter_trends = linear_trend(
    winter_seasons.chunk({"lat": 20, "lon": 20, "time": -1})
).load() * len(winter_seasons.time)
In [22]:
# Make sure that we have 34 seasons
assert len(winter_seasons.time) == 34
In [23]:
!pip install git+https://github.com/hhuangwx/cmaps.git -q
In [24]:
import cmaps  # for NCL colormaps
import cartopy.crs as ccrs
import matplotlib.colors as colors

cmap = cmaps.amwg_blueyellowred
levels = [-7, -6, -5, -4, -3, -2, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7]
norm = colors.Normalize(vmin=levels[0], vmax=levels[-1], clip=True)
In [25]:
# Set the resolution via rcParams: xarray's facet plot creates its own
# figure, so a bare plt.figure() here would be left empty
plt.rcParams["figure.dpi"] = 300
fg = winter_trends.plot(
    col="member_id",
    col_wrap=4,
    transform=ccrs.PlateCarree(),
    subplot_kws={"projection": ccrs.Robinson(central_longitude=180)},
    add_colorbar=False,
    levels=levels,
    cmap=cmap,
    norm=norm,
    extend="neither",
)

for ax in fg.axes.flat:
    ax.coastlines(color="grey")

# TODO: move the subplot title to lower left corners
# TODO: Add obs panel and ensemble mean at the end

fg.add_colorbar(orientation="horizontal")
fg.cbar.set_label("1979-2012 DJF surface air temperature trends (K/34 years)")
fg.cbar.set_ticks(levels)
fg.cbar.set_ticklabels(levels)