Data preparation¶
NeuralGCM models take and produce data on defined on 37 pressure levels, including the following variables, provided in SI units and on the NeuralGCM model’s native grid:
Inputs/outputs (on pressure levels):
u_component_of_wind
,v_component_of_wind
,geopotential
,temperature
,specific_humidity
,specific_cloud_ice_water_content
,specific_cloud_liquid_water_content
.Forcings (surface level only):
sea_surface_temperature
,sea_ice_cover
Regridding data¶
Preparing a dataset stored on a different horizontal grid for NeuralGCM requires two steps:
Horizontal regridding to a Gaussian grid. For processing fine-resolution data conservative regridding is most appropriate (and is what we used to train NeuralGCM).
Filling in all missing values (NaN), to ensure all inputs are valid. Forcing fields like
sea_surface_temperature
are only defined over ocean in ERA5, and NeuralGCM’s surface model also includes a mask that ignores values over land, but we still need to fill all NaN values to them leaking into our model outputs.
Utilities for both of these operations are packaged as part of Dinosaur. We’ll show how to use them on the Zarr copy of ERA5 from the ARCO-ERA5 project:
import jax
import numpy as np
import gcsfs
import neuralgcm
import xarray
# load demo model
checkpoint = neuralgcm.demo.load_checkpoint_tl63_stochastic()
model = neuralgcm.PressureLevelModel.from_checkpoint(checkpoint)
# create a xarray.Dataset with required variables for NeuralGCM
gcs = gcsfs.GCSFileSystem(token='anon')
path = 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3'
full_era5 = xarray.open_zarr(gcs.get_mapper(path), chunks=None)
full_era5 = full_era5[model.input_variables + model.forcing_variables]
full_era5
<xarray.Dataset> Dimensions: (time: 1087704, level: 37, latitude: 721, longitude: 1440) Coordinates: * latitude (latitude) float32 90.0 89.75 ... -90.0 * level (level) int64 1 2 3 5 ... 950 975 1000 * longitude (longitude) float32 0.0 0.25 ... 359.8 * time (time) datetime64[ns] 1900-01-01 ...... Data variables: geopotential (time, level, latitude, longitude) float32 ... specific_humidity (time, level, latitude, longitude) float32 ... temperature (time, level, latitude, longitude) float32 ... u_component_of_wind (time, level, latitude, longitude) float32 ... v_component_of_wind (time, level, latitude, longitude) float32 ... specific_cloud_ice_water_content (time, level, latitude, longitude) float32 ... specific_cloud_liquid_water_content (time, level, latitude, longitude) float32 ... sea_ice_cover (time, latitude, longitude) float32 ... sea_surface_temperature (time, latitude, longitude) float32 ...
Based on this dataset and our model grid, we can build a Regridder
object:
from dinosaur import horizontal_interpolation
from dinosaur import spherical_harmonic
from dinosaur import xarray_utils
full_era5_grid = spherical_harmonic.Grid(
latitude_nodes=full_era5.sizes['latitude'],
longitude_nodes=full_era5.sizes['longitude'],
latitude_spacing=xarray_utils.infer_latitude_spacing(full_era5.latitude),
longitude_offset=xarray_utils.infer_longitude_offset(full_era5.longitude),
)
# Other available regridders include BilinearRegridder and NearestRegridder.
regridder = horizontal_interpolation.ConservativeRegridder(
full_era5_grid, model.data_coords.horizontal, skipna=True
)
Note
skipna=True
in ConservativeRegridder
means grid cells with a mix of NaN/non-NaN values should be filled skipping NaN values. This ensures sea surface temperature and sea ice cover remains defined in coarse grid cells that overlap coastlines.
regridder
ConservativeRegridder(source_grid=Grid(longitude_wavenumbers=0, total_wavenumbers=0, longitude_nodes=1440, latitude_nodes=721, latitude_spacing='equiangular_with_poles', longitude_offset=0.0, radius=1.0, spherical_harmonics_impl=<class 'dinosaur.spherical_harmonic.RealSphericalHarmonics'>, spmd_mesh=None), target_grid=Grid(longitude_wavenumbers=64, total_wavenumbers=65, longitude_nodes=128, latitude_nodes=64, latitude_spacing='gauss', longitude_offset=0.0, radius=1.0, spherical_harmonics_impl=<class 'dinosaur.spherical_harmonic.RealSphericalHarmonicsWithZeroImag'>, spmd_mesh=None))
full_era5.nbytes / 1e12
1178.986912912812
Regridding requires the data to be first loaded into memory. Because this full dataset is gigantic (100s of TB) we’ll only regrid a single time point:
sliced_era5 = full_era5.sel(time='2020-01-01T00').compute()
sliced_era5
<xarray.Dataset> Dimensions: (level: 37, latitude: 721, longitude: 1440) Coordinates: * latitude (latitude) float32 90.0 89.75 ... -90.0 * level (level) int64 1 2 3 5 ... 950 975 1000 * longitude (longitude) float32 0.0 0.25 ... 359.8 time datetime64[ns] 2020-01-01 Data variables: geopotential (level, latitude, longitude) float32 ... specific_humidity (level, latitude, longitude) float32 ... temperature (level, latitude, longitude) float32 ... u_component_of_wind (level, latitude, longitude) float32 ... v_component_of_wind (level, latitude, longitude) float32 ... specific_cloud_ice_water_content (level, latitude, longitude) float32 ... specific_cloud_liquid_water_content (level, latitude, longitude) float32 ... sea_ice_cover (latitude, longitude) float32 1.0 ..... sea_surface_temperature (latitude, longitude) float32 271.5 ...
regridded = xarray_utils.regrid(sliced_era5, regridder)
regridded
<xarray.Dataset> Dimensions: (level: 37, longitude: 128, latitude: 64) Coordinates: * level (level) int64 1 2 3 5 ... 950 975 1000 time datetime64[ns] 2020-01-01 * longitude (longitude) float64 0.0 2.812 ... 357.2 * latitude (latitude) float64 -87.86 ... 87.86 Data variables: geopotential (level, longitude, latitude) float32 ... specific_humidity (level, longitude, latitude) float32 ... temperature (level, longitude, latitude) float32 ... u_component_of_wind (level, longitude, latitude) float32 ... v_component_of_wind (level, longitude, latitude) float32 ... specific_cloud_ice_water_content (level, longitude, latitude) float32 ... specific_cloud_liquid_water_content (level, longitude, latitude) float32 ... sea_ice_cover (longitude, latitude) float32 nan ..... sea_surface_temperature (longitude, latitude) float32 nan .....
Looking at the data, we see that sea surface temperature is now on a much coarser grid (roughly 2.8°).
sliced_era5.sea_surface_temperature.plot(x='longitude', y='latitude', aspect=2, size=2.5);
regridded.sea_surface_temperature.plot(x='longitude', y='latitude', aspect=2, size=2.5);
However, we still have missing values (NaN) in the locations shown in white over land. We’ll fill those with values from the nearest non-missing locations:
regridded_and_filled = xarray_utils.fill_nan_with_nearest(regridded)
regridded_and_filled.sea_surface_temperature.plot(x='longitude', y='latitude', aspect=2, size=2.5);
Now we have a dataset ready for feeding into NeuralGCM!
Converting to/from Xarray¶
NeuralGCM’s Learned methods transform data in the form of dictionaries of NumPy or JAX arrays:
Expected dict keys for
inputs
andforcings
are indicated bymodel.input_variables
andmodel.forcing_variables
.Dictionary values should be arrays with shape
(level, longitude, latitude)
or(time, level, longitude, latitude)
, depending upon whether the function expects or produces outputs with a time axis. Thelevel
axis is size 1 for surface variables (i.e., sea surface temperature and sea ice concentration).In addition to surface and 3D fields, a
sim_time
variable is used to calculate incident solar radiation. Unlike the other vairables,sim_time
needs to be provided already converted into the model’s internal non-dimensional units (JAX does not support NumPy’sdatetime64
dtype).
inputs_from_xarray()
and forcings_from_xarray()
convert xarray.Dataset
objects with the appropriate variables into dictionary of array format, either with or without a leading time dimension:
inputs = model.inputs_from_xarray(regridded_and_filled)
forcings = model.forcings_from_xarray(regridded_and_filled)
jax.tree.map(np.shape, inputs)
{'geopotential': (37, 128, 64),
'sim_time': (),
'specific_cloud_ice_water_content': (37, 128, 64),
'specific_cloud_liquid_water_content': (37, 128, 64),
'specific_humidity': (37, 128, 64),
'temperature': (37, 128, 64),
'u_component_of_wind': (37, 128, 64),
'v_component_of_wind': (37, 128, 64)}
jax.tree.map(np.shape, forcings)
{'sea_ice_cover': (1, 128, 64),
'sea_surface_temperature': (1, 128, 64),
'sim_time': ()}
Notice that sim_time
was calculated automatically. It can also be calculated explicitly from numpy.datetime64
arrays using datetime64_to_simtime()
:
inputs['sim_time']
array(188693.6256)
model.datetime64_to_sim_time(regridded_and_filled.time.data)
188693.6256
Outputs from NeuralGCM can be converted back into an xarray.Dataset
with appropriate coordinates via data_to_xarray()
. Pass in times=None
if there is no leading time-axis, or supply a 1D numpy array of time values:
model.data_to_xarray(
model.inputs_from_xarray(regridded_and_filled),
times=None, # times=None indicates no leading time-axis
)
<xarray.Dataset> Dimensions: (level: 37, longitude: 128, latitude: 64) Coordinates: * longitude (longitude) float64 0.0 2.812 ... 357.2 * latitude (latitude) float64 -87.86 ... 87.86 * level (level) int64 1 2 3 5 ... 950 975 1000 Data variables: u_component_of_wind (level, longitude, latitude) float32 ... specific_cloud_ice_water_content (level, longitude, latitude) float32 ... temperature (level, longitude, latitude) float32 ... v_component_of_wind (level, longitude, latitude) float32 ... specific_cloud_liquid_water_content (level, longitude, latitude) float32 ... sim_time float64 1.887e+05 geopotential (level, longitude, latitude) float32 ... specific_humidity (level, longitude, latitude) float32 ... Attributes: longitude_wavenumbers: 64 total_wavenumbers: 65 longitude_nodes: 128 latitude_nodes: 64 latitude_spacing: gauss longitude_offset: 0.0 radius: 1.0 spherical_harmonics_impl: RealSphericalHarmonicsWithZeroImag spmd_mesh: centers: [1, 2, 3, 5, 7, 10, 20, 30, 50, 70, 100, 125, ... horizontal_grid_type: Grid vertical_grid_type: PressureCoordinates
Tip
In principle, simtime_to_datetime64()
can calculate output times automatically (e.g., from the outputs of unroll()
), but this isn’t recommended. By default, JAX does math in float32 mode, which can result in significant rounding errors (e.g., up to a few minutes). As illustrated in the Forecasting quick start, we recommend calculating times directly with NumPy or pandas.
Time shifting¶
In the NeuralGCM paper, we used one other data preparation trick for forcing variables: we shifted them backwards in time (by one day), so we could not be accused of leaking data from the future into our weather forecasts.
This can be reproduced with the selective_temporal_shift()
utility, which acts lazily even on datasets that do not fit into memory:
xarray_utils.selective_temporal_shift(
dataset=full_era5,
variables=model.forcing_variables,
time_shift='24 hours',
)
<xarray.Dataset> Dimensions: (time: 1087680, level: 37, latitude: 721, longitude: 1440) Coordinates: * latitude (latitude) float32 90.0 89.75 ... -90.0 * level (level) int64 1 2 3 5 ... 950 975 1000 * longitude (longitude) float32 0.0 0.25 ... 359.8 * time (time) datetime64[ns] 1900-01-02 ...... Data variables: geopotential (time, level, latitude, longitude) float32 ... specific_humidity (time, level, latitude, longitude) float32 ... temperature (time, level, latitude, longitude) float32 ... u_component_of_wind (time, level, latitude, longitude) float32 ... v_component_of_wind (time, level, latitude, longitude) float32 ... specific_cloud_ice_water_content (time, level, latitude, longitude) float32 ... specific_cloud_liquid_water_content (time, level, latitude, longitude) float32 ... sea_ice_cover (time, latitude, longitude) float32 ... sea_surface_temperature (time, latitude, longitude) float32 ...