Stable API

PressureLevelModel

The user-facing API for NeuralGCM models centers on PressureLevelModel:

class neuralgcm.PressureLevelModel(structure: WhirlModel, params: dict[str, dict[str, float | ndarray | Array]], gin_config: str)[source]

Inference-only API for models that predict dense data on pressure levels.

These models are trained on ECMWF ERA5 data on pressure levels, as stored in the Copernicus Data Store.

This class encapsulates the details of defining models (e.g., with Haiku) and hence should remain stable even for future NeuralGCM models.

Constructor

Use this class method to create a new model:

classmethod PressureLevelModel.from_checkpoint(checkpoint: Any) PressureLevelModel[source]

Creates a PressureLevelModel from a checkpoint.

Parameters:

checkpoint – dictionary with keys “model_config_str”, “aux_ds_dict” and “params” that specifies model gin configuration, supplemental xarray dataset with model-specific static features, and model parameters.

Returns:

Instance of a PressureLevelModel with weights and configuration specified by the checkpoint.
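As a sketch of the checkpoint contract, the snippet below builds a minimal dictionary with the three required keys and round-trips it through pickle. The placeholder values are hypothetical, and the actual on-disk serialization format may differ; only the key names come from the documentation above.

```python
import pickle

# Hypothetical sketch: a NeuralGCM checkpoint is a dict with (at least)
# these three keys. The values shown here are placeholders, not real
# model configuration or weights.
checkpoint = {
    "model_config_str": "# gin configuration string",
    "aux_ds_dict": {},   # supplemental static features, as a dict of arrays
    "params": {},        # nested dict of model parameters
}

# Round-trip through pickle, as one might when reading a saved checkpoint
# from disk.
blob = pickle.dumps(checkpoint)
restored = pickle.loads(blob)
assert set(restored) >= {"model_config_str", "aux_ds_dict", "params"}

# With the real library, construction would then be:
# model = neuralgcm.PressureLevelModel.from_checkpoint(restored)
```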

Properties

These properties describe the coordinate system and variables for which a model is defined:

property PressureLevelModel.timestep: timedelta64[source]

Spacing between internal model timesteps.

property PressureLevelModel.data_coords: CoordinateSystem[source]

Coordinate system for input and output data.

property PressureLevelModel.model_coords: CoordinateSystem[source]

Coordinate system for internal model state.

property PressureLevelModel.input_variables: list[str][source]

List of variable names required in inputs by this model.

property PressureLevelModel.forcing_variables: list[str][source]

List of variable names required in forcings by this model.

Learned methods

These methods use trained model parameters to convert from input variables defined on data coordinates (i.e., pressure levels) to internal model state variables defined on model coordinates (i.e., sigma levels) and back.

advance and unroll allow for stepping forward in time.

PressureLevelModel.encode(inputs: dict[str, float | ndarray | Array], forcings: dict[str, float | ndarray | Array], rng_key: Any | None = None) Any[source]

Encode from pressure-level inputs & forcings to model state.

Parameters:
  • inputs – input data on pressure-levels, as a dict where each entry is an array with shape [level, longitude, latitude] matching data_coords.

  • forcings – forcing data on pressure-levels, as a dict where each entry is an array with shape [level, longitude, latitude] matching data_coords. Single level data (e.g., sea surface temperature) should have a level dimension of size 1.

  • rng_key – optional JAX RNG key to use for encoding the state. Required if using stochastic models, otherwise ignored.

Returns:

Dynamical core state on sigma levels, where all arrays have dimensions [level, zonal_wavenumber, total_wavenumber] matching model_coords.
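To illustrate the expected array layout for encode, the snippet below builds dummy inputs and forcings with the documented [level, longitude, latitude] shape. The grid sizes and variable names are hypothetical stand-ins; real names and sizes come from model.input_variables, model.forcing_variables, and model.data_coords.

```python
import numpy as np

# Hypothetical grid sizes for illustration only; real values come from
# model.data_coords.
n_levels, n_lon, n_lat = 37, 64, 32

# Each input variable is a [level, longitude, latitude] array.
inputs = {
    "temperature": np.zeros((n_levels, n_lon, n_lat)),
    "geopotential": np.zeros((n_levels, n_lon, n_lat)),
}

# Single-level forcings (e.g., sea surface temperature) keep a level
# dimension of size 1.
forcings = {
    "sea_surface_temperature": np.zeros((1, n_lon, n_lat)),
}

# Every array, including single-level data, is three-dimensional.
for name, array in {**inputs, **forcings}.items():
    assert array.ndim == 3, name

# With a real model: state = model.encode(inputs, forcings, rng_key=key)
```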

PressureLevelModel.decode(state: Any, forcings: dict[str, float | ndarray | Array]) dict[str, Array][source]

Decode from model state to pressure-level outputs.

Parameters:
  • state – dynamical core state on sigma levels, where all arrays have dimensions [level, zonal_wavenumber, total_wavenumber] matching model_coords.

  • forcings – forcing data on pressure-levels, as a dict where each entry is an array with shape [level, longitude, latitude] matching data_coords. Single level data (e.g., sea surface temperature) should have a level dimension of size 1.

Returns:

Outputs on pressure-levels, as a dict where each entry is an array with shape [level, longitude, latitude] matching data_coords.

PressureLevelModel.advance(state: Any, forcings: dict[str, float | ndarray | Array]) Any[source]

Advance model state one timestep forward.

Parameters:
  • state – dynamical core state on sigma levels, where all arrays have dimensions [level, zonal_wavenumber, total_wavenumber] matching model_coords.

  • forcings – forcing data on pressure-levels, as a dict where each entry is an array with shape [level, longitude, latitude] matching data_coords. Single level data (e.g., sea surface temperature) should have a level dimension of size 1.

Returns:

State advanced one time-step forward.
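The advance/decode pattern can be sketched generically: thread the state through repeated advance calls, decoding wherever outputs are needed. Stand-in functions replace the learned methods here (the real calls would be model.advance and model.decode); an integer stands in for the spectral model state.

```python
# Minimal sketch of the advance -> decode loop, with stand-in functions
# in place of the learned methods.
def fake_advance(state, forcings):
    return state + 1  # stand-in: real states are spectral arrays

def fake_decode(state, forcings):
    return {"step_count": state}  # stand-in for pressure-level outputs

state = 0       # stand-in for the encoded model state
forcings = {}   # stand-in for forcing data at each step
outputs = []
for _ in range(3):
    state = fake_advance(state, forcings)
    outputs.append(fake_decode(state, forcings))

assert [o["step_count"] for o in outputs] == [1, 2, 3]
```

In practice, unroll (below) wraps exactly this loop, with nearest-in-time forcing selection at each step.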

PressureLevelModel.unroll(state: Any, forcings: dict[str, float | ndarray | Array], *, steps: int, timedelta: str | timedelta64 | Timestamp | timedelta | None = None, start_with_input: bool = False, post_process_fn: Callable[[Any], Any] | None = None) tuple[Any, dict[str, Array]][source]

Unroll predictions over many time-steps.

Usage:

advanced_state, outputs = model.unroll(state, forcings, steps=N)

where advanced_state is the advanced model state after N steps and outputs is a trajectory of decoded states on pressure-levels with a leading dimension of size N.

Parameters:
  • state – initial model state.

  • forcings – forcing data over the time-period spanned by the desired output trajectory. Should include a leading time-axis, but times can be at any desired granularity (e.g., it should be fine to supply daily forcing data, even if producing hourly outputs). The nearest forcing in time will be used for each internal advance() and decode() call.

  • steps – number of time-steps to take.

  • timedelta – size of each time-step to take, which must be a multiple of the internal model timestep. By default uses the internal model timestep.

  • start_with_input – if True, outputs are at times [0, ..., (steps - 1) * timestep] relative to the initial time; if False, outputs are at times [timestep, ..., steps * timestep].

  • post_process_fn – optional function to apply to each advanced state and current forcings to create outputs like post_process_fn(state, forcings), where forcings does not include a time axis. By default, uses model.decode.

Returns:

A tuple of the advanced state at time steps * timedelta, and outputs with a leading time axis at the time-steps specified by steps, timedelta and start_with_input.
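The output times implied by steps, timedelta, and start_with_input can be worked out with plain numpy datetime arithmetic. The initial time and step size below are arbitrary examples.

```python
import numpy as np

# Sketch of unroll's output-time semantics.
init_time = np.datetime64("2020-01-01T00:00")
timedelta = np.timedelta64(6, "h")  # must be a multiple of model.timestep
steps = 4

offsets = np.arange(steps)
# start_with_input=True: times [0, ..., (steps - 1) * timedelta]
times_with_input = init_time + offsets * timedelta
# start_with_input=False (the default): times [timedelta, ..., steps * timedelta]
times_after_input = init_time + (offsets + 1) * timedelta

assert times_with_input[0] == init_time
assert times_after_input[-1] == init_time + steps * timedelta
```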

Unit conversion

The internal state of NeuralGCM models uses non-dimensional units and "simulation time," instead of SI units and numpy.datetime64. These utilities allow for converting arrays back and forth, including inside JAX code:

PressureLevelModel.to_nondim_units(value: float | ndarray | Array | DataArray, units: str) float | ndarray | Array | DataArray[source]

Scale a value to the model’s internal non-dimensional units.

PressureLevelModel.from_nondim_units(value: float | ndarray | Array | DataArray, units: str) float | ndarray | Array | DataArray[source]

Scale a value from the model’s internal non-dimensional units.
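Conceptually, these conversions apply a fixed scale factor per physical dimension. The sketch below shows the idea with a single made-up length scale; the scale factors actually used by NeuralGCM's internal unit system are not reproduced here.

```python
# Illustrative sketch of dimensional <-> non-dimensional conversion by a
# single scale factor. LENGTH_SCALE is a hypothetical value, not the one
# NeuralGCM uses internally.
LENGTH_SCALE = 6.37e6  # hypothetical: metres per non-dimensional length unit

def to_nondim(value_si):
    return value_si / LENGTH_SCALE

def from_nondim(value_nondim):
    return value_nondim * LENGTH_SCALE

# Round-tripping recovers the original value (up to float rounding).
x = 1.2e7  # metres
assert abs(from_nondim(to_nondim(x)) - x) < 1e-6
```

The real methods take a units string (e.g., model.to_nondim_units(value, "m/s")) and look up the appropriate scale for that dimension.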

PressureLevelModel.datetime64_to_sim_time(datetime64: ndarray) ndarray[source]

Converts a datetime64 array to sim_time.

PressureLevelModel.sim_time_to_datetime64(sim_time: ndarray) ndarray[source]

Converts a sim_time array to datetime64.
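Sim-time conversion amounts to measuring offsets from a reference epoch in the model's time unit. Both the epoch and the time scale below are hypothetical; the real model carries its own reference time and non-dimensional time scale.

```python
import numpy as np

# Sketch of datetime64 <-> sim_time conversion. EPOCH and TIME_SCALE are
# made-up values for illustration.
EPOCH = np.datetime64("1979-01-01T00:00:00")
TIME_SCALE = np.timedelta64(1, "h")  # hypothetical: one sim-time unit per hour

def datetime64_to_sim_time(dt):
    return (dt - EPOCH) / TIME_SCALE

def sim_time_to_datetime64(sim_time):
    return EPOCH + sim_time * TIME_SCALE

t = np.datetime64("1979-01-02T12:00:00")
assert datetime64_to_sim_time(t) == 36.0
assert sim_time_to_datetime64(36.0) == t
```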

Xarray conversion

Xarray is convenient for data preparation and evaluation, but is not compatible with JAX. Use these methods to convert between xarray.Dataset objects and inputs/outputs from learned methods:

PressureLevelModel.inputs_from_xarray(dataset: Dataset) dict[str, ndarray][source]

Extract inputs from an xarray.Dataset.

PressureLevelModel.forcings_from_xarray(dataset: Dataset) dict[str, ndarray][source]

Extract forcings from an xarray.Dataset.
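Conceptually, these helpers select the variables the model requires from a larger dataset and return them as a dict of arrays. The sketch below uses a plain dict of numpy arrays in place of an xarray.Dataset, with hypothetical variable names; the real lists come from model.input_variables and model.forcing_variables.

```python
import numpy as np

# A plain dict stands in for the xarray.Dataset; variable names are
# hypothetical examples.
dataset = {
    "temperature": np.zeros((37, 64, 32)),
    "geopotential": np.zeros((37, 64, 32)),
    "sea_surface_temperature": np.zeros((1, 64, 32)),
    "unused_diagnostic": np.zeros((64, 32)),
}
input_variables = ["temperature", "geopotential"]  # cf. model.input_variables
forcing_variables = ["sea_surface_temperature"]    # cf. model.forcing_variables

# Select only what the model asks for, dropping everything else.
inputs = {k: dataset[k] for k in input_variables}
forcings = {k: dataset[k] for k in forcing_variables}

assert "unused_diagnostic" not in inputs
```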

PressureLevelModel.data_to_xarray(data: dict[str, float | ndarray | Array], times: ndarray | None, decoded: bool = True) Dataset[source]

Converts model predictions (decoded or encoded, per the decoded flag) to xarray.Dataset format.

Parameters:
  • data – dict of arrays with shapes matching input/outputs or encoded model state for this model, i.e., with shape ([time,] level, longitude, latitude), where [time,] indicates an optional leading time dimension.

  • times – either None indicating no leading time dimension on any variables, or a coordinate array of times with shape (time,).

  • decoded – if True, use self.data_coords to determine the output coordinates; otherwise use self.model_coords.

Returns:

An xarray.Dataset with appropriate coordinates and dimensions.

Demo dataset & models

These constructors are useful for testing purposes, to avoid the need to load large datasets from cloud storage. Instead, they rely on small test datasets packaged with the neuralgcm code.

For non-testing purposes, see the model checkpoints from the paper in the Forecasting quick start.

neuralgcm.demo.load_data(coords: CoordinateSystem) Dataset[source]

Load demo data for the given coordinate system.

neuralgcm.demo.load_checkpoint_tl63_stochastic()[source]

Load a checkpoint for a toy TL63 stochastic model.