Stable API¶
PressureLevelModel¶
The user-facing API for NeuralGCM models centers around PressureLevelModel:
- class neuralgcm.PressureLevelModel(structure: WhirlModel, params: dict[str, dict[str, float | ndarray | Array]], gin_config: str)[source]¶
Inference-only API for models that predict dense data on pressure levels.
These models are trained on ECMWF ERA5 data on pressure-levels as stored in the Copernicus Data Store.
This class encapsulates the details of defining models (e.g., with Haiku) and hence should remain stable even for future NeuralGCM models.
Constructor¶
Use this class method to create a new model:
- classmethod PressureLevelModel.from_checkpoint(checkpoint: Any) PressureLevelModel [source]¶
Creates a PressureLevelModel from a checkpoint.
- Parameters:
checkpoint – dictionary with keys “model_config_str”, “aux_ds_dict” and “params” that specifies model gin configuration, supplemental xarray dataset with model-specific static features, and model parameters.
- Returns:
Instance of a PressureLevelModel with weights and configuration specified by the checkpoint.
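As a sketch, a checkpoint is a plain dictionary with the three keys described above. The values below are placeholders to show the structure, not a real trained model:

```python
# Sketch of the checkpoint structure expected by from_checkpoint.
# All values here are placeholders, not a real trained model.
checkpoint = {
    'model_config_str': '...',  # gin configuration string for the model
    'aux_ds_dict': {},  # dict form of the supplemental xarray dataset
    'params': {},       # nested dict of model parameter arrays
}

# A real checkpoint (e.g., one distributed with the paper) would then be
# loaded with:
# model = neuralgcm.PressureLevelModel.from_checkpoint(checkpoint)
```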
Properties¶
These properties describe the coordinate system and variables for which a model is defined:
- property PressureLevelModel.timestep: timedelta64[source]¶
Spacing between internal model timesteps.
- property PressureLevelModel.data_coords: CoordinateSystem[source]¶
Coordinate system for input and output data.
- property PressureLevelModel.model_coords: CoordinateSystem[source]¶
Coordinate system for internal model state.
Learned methods¶
These methods use trained model parameters to convert from input variables defined on data coordinates (i.e., pressure levels) to internal model state variables defined on model coordinates (i.e., sigma levels) and back. advance and unroll allow for stepping forward in time.
- PressureLevelModel.encode(inputs: dict[str, float | ndarray | Array], forcings: dict[str, float | ndarray | Array], rng_key: Any | None = None) Any [source]¶
Encode from pressure-level inputs & forcings to model state.
- Parameters:
inputs – input data on pressure-levels, as a dict where each entry is an array with shape [level, longitude, latitude] matching data_coords.
forcings – forcing data on pressure-levels, as a dict where each entry is an array with shape [level, longitude, latitude] matching data_coords. Single level data (e.g., sea surface temperature) should have a level dimension of size 1.
rng_key – optional JAX RNG key to use for encoding the state. Required if using stochastic models, otherwise ignored.
- Returns:
Dynamical core state on sigma levels, where all arrays have dimensions [level, zonal_wavenumber, total_wavenumber] matching model_coords.
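For example, inputs and forcings are plain dicts of arrays following the [level, longitude, latitude] convention. The variable names and grid sizes below are hypothetical, not fixed by the API:

```python
import numpy as np

# Hypothetical grid: 37 pressure levels, 64 longitudes, 32 latitudes.
level, lon, lat = 37, 64, 32

inputs = {
    'temperature': np.zeros((level, lon, lat)),
    'geopotential': np.zeros((level, lon, lat)),
}
forcings = {
    # Single-level fields keep a level axis of size 1.
    'sea_surface_temperature': np.zeros((1, lon, lat)),
}

# With a loaded model, encoding would then look like:
# state = model.encode(inputs, forcings, rng_key=jax.random.PRNGKey(0))
```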
- PressureLevelModel.decode(state: Any, forcings: dict[str, float | ndarray | Array]) dict[str, Array] [source]¶
Decode from model state to pressure-level outputs.
- Parameters:
state – dynamical core state on sigma levels, where all arrays have dimensions [level, zonal_wavenumber, total_wavenumber] matching model_coords.
forcings – forcing data on pressure-levels, as a dict where each entry is an array with shape [level, longitude, latitude] matching data_coords. Single level data (e.g., sea surface temperature) should have a level dimension of size 1.
- Returns:
Outputs on pressure-levels, as a dict where each entry is an array with shape [level, longitude, latitude] matching data_coords.
- PressureLevelModel.advance(state: Any, forcings: dict[str, float | ndarray | Array]) Any [source]¶
Advance model state one timestep forward.
- Parameters:
state – dynamical core state on sigma levels, where all arrays have dimensions [level, zonal_wavenumber, total_wavenumber] matching model_coords
forcings – forcing data on pressure-levels, as a dict where each entry is an array with shape [level, longitude, latitude] matching data_coords. Single level data (e.g., sea surface temperature) should have a level dimension of size 1.
- Returns:
State advanced one time-step forward.
- PressureLevelModel.unroll(state: Any, forcings: dict[str, float | ndarray | Array], *, steps: int, timedelta: str | timedelta64 | Timestamp | timedelta | None = None, start_with_input: bool = False, post_process_fn: Callable[[Any], Any] | None = None) tuple[Any, dict[str, Array]] [source]¶
Unroll predictions over many time-steps.
Usage:
advanced_state, outputs = model.unroll(state, forcings, steps=N)
where advanced_state is the advanced model state after N steps and outputs is a trajectory of decoded states on pressure-levels with a leading dimension of size N.
- Parameters:
state – initial model state.
forcings – forcing data over the time-period spanned by the desired output trajectory. Should include a leading time-axis, but times can be at any desired granularity (e.g., it should be fine to supply daily forcing data, even if producing hourly outputs). The nearest forcing in time will be used for each internal advance() and decode() call.
steps – number of time-steps to take.
timedelta – size of each time-step to take, which must be a multiple of the internal model timestep. By default uses the internal model timestep.
start_with_input – if True, outputs are at times [0, ..., (steps - 1) * timestep] relative to the initial time; if False, outputs are at times [timestep, ..., steps * timestep].
post_process_fn – optional function to apply to each advanced state and current forcings to create outputs like post_process_fn(state, forcings), where forcings does not include a time axis. By default, uses model.decode.
- Returns:
A tuple of the advanced state at time steps * timedelta, and outputs with a leading time axis at the time-steps specified by steps, timedelta and start_with_input.
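The output time offsets implied by steps and start_with_input can be computed directly. This sketch assumes a hypothetical 1-hour step:

```python
import numpy as np

timestep = np.timedelta64(1, 'h')  # hypothetical step size
steps = 4

# start_with_input=False (the default): first output is one step after t0.
times = timestep * np.arange(1, steps + 1)      # [1h, 2h, 3h, 4h]
# start_with_input=True: the trajectory begins at the initial time instead.
times_with_input = timestep * np.arange(steps)  # [0h, 1h, 2h, 3h]
```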
Unit conversion¶
The internal state of NeuralGCM models uses non-dimensional units and “simulation time,” instead of SI units and numpy.datetime64. These utilities allow for converting arrays back and forth, including inside JAX code:
- PressureLevelModel.to_nondim_units(value: float | ndarray | Array | DataArray, units: str) float | ndarray | Array | DataArray [source]¶
Scale a value to the model’s internal non-dimensional units.
- PressureLevelModel.from_nondim_units(value: float | ndarray | Array | DataArray, units: str) float | ndarray | Array | DataArray [source]¶
Scale a value from the model’s internal non-dimensional units.
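These two methods are inverses of each other. The sketch below illustrates that contract with a single hypothetical length scale; NeuralGCM's actual scale factors are internal and depend on the units string:

```python
import math

LENGTH_SCALE = 6.37122e6  # hypothetical characteristic scale (meters)

def to_nondim_units(value, scale=LENGTH_SCALE):
    # Non-dimensionalize by dividing out the characteristic scale.
    return value / scale

def from_nondim_units(value, scale=LENGTH_SCALE):
    # Restore physical units by multiplying the scale back in.
    return value * scale

# Round trip recovers the original value (up to floating-point error).
assert math.isclose(from_nondim_units(to_nondim_units(1000.0)), 1000.0)
```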
Xarray conversion¶
Xarray is convenient for data preparation and evaluation, but is not compatible with JAX. Use these methods to convert between xarray.Dataset objects and inputs/outputs from learned methods:
- PressureLevelModel.inputs_from_xarray(dataset: Dataset) dict[str, ndarray] [source]¶
Extract inputs from an xarray.Dataset.
- PressureLevelModel.forcings_from_xarray(dataset: Dataset) dict[str, ndarray] [source]¶
Extract forcings from an xarray.Dataset.
- PressureLevelModel.data_to_xarray(data: dict[str, float | ndarray | Array], times: ndarray | None, decoded: bool = True) Dataset [source]¶
Converts decoded model predictions to xarray.Dataset format.
- Parameters:
data – dict of arrays with shapes matching input/outputs or encoded model state for this model, i.e., with shape ([time,] level, longitude, latitude), where [time,] indicates an optional leading time dimension.
times – either None indicating no leading time dimension on any variables, or a coordinate array of times with shape (time,).
decoded – if True, use self.data_coords to determine the output coordinates; otherwise use self.model_coords.
- Returns:
An xarray.Dataset with appropriate coordinates and dimensions.
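For instance, decoded outputs from unroll already have the shape data_to_xarray expects. The variable name and sizes below are hypothetical:

```python
import numpy as np

# Hypothetical trajectory: 4 output steps on a 37 x 64 x 32 grid, i.e.,
# arrays of shape (time, level, longitude, latitude).
steps, level, lon, lat = 4, 37, 64, 32
data = {'temperature': np.zeros((steps, level, lon, lat))}
times = np.arange(steps) * np.timedelta64(1, 'h')

# With a loaded model:
# ds = model.data_to_xarray(data, times=times)  # decoded=True by default
```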
Demo dataset & models¶
These constructors are useful for testing purposes, to avoid the need to load large datasets from cloud storage. Instead, they rely on small test datasets packaged with the neuralgcm code.
For non-testing purposes, see the model checkpoints from the paper in the Forecasting quick start.