Lanczos Filter

Lanczos Filter

此程式碼根據 liv0505/Lanczos-Filter 以及NCL的bandpass filter、filwgts_lanczos進行修改 (感謝台大大氣系廖建泓協助更新)。

首先建立自訂函式:

#-------------------------filter function-------------------------------------------------#
def low_pass_weights(nwts, cutoff):
    """Calculate weights for a low-pass Lanczos filter.

    Args:
        nwts: int  (Source: NCL)
            Total number of weights before trimming; must be an odd integer
            (nwts >= 3). The more weights, the better the filter, but the
            more data are lost (NaN) at both ends of the filtered series.
        cutoff: float
            The cutoff frequency in inverse time steps (cycles per sample).

    Returns:
        numpy.ndarray of length ``nwts - 2``. The outermost entries of the
        length-``nwts`` work array are never filled (the Lanczos sigma factor
        is zero at lag ``nwts // 2``), so they are dropped.

    Raises:
        ValueError: if ``nwts`` is not an odd integer >= 3.
    """
    import numpy as np

    # Reject sizes for which the symmetric fill below is undefined
    # (even sizes fail to broadcast; sizes < 3 yield an empty filter).
    if nwts < 3 or nwts % 2 == 0:
        raise ValueError(f"nwts must be an odd integer >= 3, got {nwts!r}")

    w = np.zeros([nwts])
    n = nwts // 2
    # Central (lag 0) weight of the ideal low-pass response.
    w[n] = 2 * cutoff
    k = np.arange(1., n)
    # Lanczos sigma factor, tapering the ideal response towards the window edges.
    sigma = np.sin(np.pi * k / n) * n / (np.pi * k)
    # Ideal low-pass impulse response at lags 1..n-1.
    firstfactor = np.sin(2. * np.pi * cutoff * k) / (np.pi * k)
    # Mirror the lagged weights symmetrically around the centre.
    w[n-1:0:-1] = firstfactor * sigma
    w[n+1:-1] = firstfactor * sigma
    return w[1:-1]

def high_pass_weights(nwts, cutoff):
    """Calculate weights for a high-pass Lanczos filter.

    Args:
        nwts: int  (Source: NCL)
            Total number of weights before trimming; must be an odd integer
            (nwts >= 3). The more weights, the better the filter, but the
            more data are lost (NaN) at both ends of the filtered series.
        cutoff: float
            The cutoff frequency in inverse time steps (cycles per sample).

    Returns:
        numpy.ndarray of length ``nwts - 2``. The outermost entries of the
        length-``nwts`` work array are never filled (the Lanczos sigma factor
        is zero at lag ``nwts // 2``), so they are dropped.

    Raises:
        ValueError: if ``nwts`` is not an odd integer >= 3.
    """
    import numpy as np

    # Reject sizes for which the symmetric fill below is undefined
    # (even sizes fail to broadcast; sizes < 3 yield an empty filter).
    if nwts < 3 or nwts % 2 == 0:
        raise ValueError(f"nwts must be an odd integer >= 3, got {nwts!r}")

    w = np.zeros([nwts])
    n = nwts // 2
    # Central (lag 0) weight: identity minus the ideal low-pass centre weight.
    w[n] = 1 - 2 * cutoff
    k = np.arange(1., n)
    # Lanczos sigma factor, tapering the ideal response towards the window edges.
    sigma = np.sin(np.pi * k / n) * n / (np.pi * k)
    # Ideal low-pass impulse response at lags 1..n-1 (negated below for high-pass).
    firstfactor = np.sin(2. * np.pi * cutoff * k) / (np.pi * k)
    # Mirror the lagged weights symmetrically around the centre.
    w[n-1:0:-1] = -firstfactor * sigma
    w[n+1:-1] = -firstfactor * sigma
    return w[1:-1]

def lanczos_hp_filter(data, nwts, fca, srate):
    """Apply a high-pass Lanczos filter along the ``time`` dimension.

    Args:
        data: xarray.DataArray
            Input data; must have a dimension named ``time``.
        nwts: int  (Source: NCL)
            Total number of weights (must be an odd number; nwts >= 3).
            The more weights, the better the filter, but the more data are
            lost (NaN) at both ends of the series.
        fca: float
            Cut-off frequency in cycles per day; e.g. ``1/10`` keeps periods
            shorter than 10 days. ``fca / srate`` must satisfy
            ``0.0 < fca / srate < 0.5``.
        srate: float
            Number of data points per day.

    Returns:
        xarray.DataArray with the high-pass-filtered data. Time steps whose
        centred window extends past either end of the series come out NaN.
    """
    import xarray as xr

    # Convert the cut-off from cycles/day to cycles per time step.
    hfw = high_pass_weights(nwts, fca * (1 / srate))
    weight_high = xr.DataArray(hfw, dims=['window'])

    # Centred moving window along time, then a weighted sum over the window
    # (a discrete convolution with the symmetric Lanczos weights).
    highpass = (
        data.rolling(time=len(hfw), center=True)
        .construct('window')
        .dot(weight_high)
    )
    return highpass

def lanczos_lp_filter(data, nwts, fca, srate):
    """Apply a low-pass Lanczos filter along the ``time`` dimension.

    Args:
        data: xarray.DataArray
            Input data; must have a dimension named ``time``.
        nwts: int  (Source: NCL)
            Total number of weights (must be an odd number; nwts >= 3).
            The more weights, the better the filter, but the more data are
            lost (NaN) at both ends of the series.
        fca: float
            Cut-off frequency in cycles per day; e.g. ``1/120`` keeps periods
            longer than 120 days. ``fca / srate`` must satisfy
            ``0.0 < fca / srate < 0.5``.
        srate: float
            Number of data points per day.

    Returns:
        xarray.DataArray with the low-pass-filtered data. Time steps whose
        centred window extends past either end of the series come out NaN.
    """
    import xarray as xr

    # Convert the cut-off from cycles/day to cycles per time step.
    lfw = low_pass_weights(nwts, fca * (1 / srate))
    weight_low = xr.DataArray(lfw, dims=['window'])

    # Centred moving window along time, then a weighted sum over the window
    # (a discrete convolution with the symmetric Lanczos weights).
    lowpass = (
        data.rolling(time=len(lfw), center=True)
        .construct('window')
        .dot(weight_low)
    )
    return lowpass

def lanczos_bp_filter(data, nwts, fca, fcb, srate):
    """Apply a band-pass Lanczos filter along the ``time`` dimension.

    The band-pass response is the difference of two low-pass filters: one
    with cut-off ``fcb`` minus one with cut-off ``fca``.

    Args:
        data: xarray.DataArray
            Input data; must have a dimension named ``time``.
        nwts: int  (Source: NCL)
            Total number of weights (must be an odd number; nwts >= 3).
            The more weights, the better the filter, but the more data are
            lost (NaN) at both ends of the series.
        fca: float
            Lower cut-off frequency in cycles per day (longest period kept is
            ``1/fca`` days).
        fcb: float
            Upper cut-off frequency in cycles per day (shortest period kept is
            ``1/fcb`` days). Requires ``0 < fca < fcb`` and ``fcb / srate < 0.5``.
        srate: float
            Number of data points per day.

    Returns:
        xarray.DataArray with the band-pass-filtered data. Time steps whose
        centred window extends past either end of the series come out NaN.
    """
    import xarray as xr

    # Low-pass weights for both cut-offs, converted to cycles per time step.
    hfw = low_pass_weights(nwts, fcb * (1 / srate))
    lfw = low_pass_weights(nwts, fca * (1 / srate))

    # The bandpass is the difference of two lowpass filters. The filter is
    # linear and both weight sets have the same length, so subtract the
    # weights once and convolve the data a single time instead of twice.
    weight_band = xr.DataArray(hfw - lfw, dims=['window'])

    bandpass = (
        data.rolling(time=len(hfw), center=True)
        .construct('window')
        .dot(weight_band)
    )
    return bandpass

接著我們只需要呼叫 lanczos_hp_filter、lanczos_lp_filter 與 lanczos_bp_filter 函數,就可以分別計算高通、低通和帶通濾波。這些函式需要提供以下引數:

  • data: 資料的DataArray。

  • nwts: 濾波權重 (weights) 的總數,必須為奇數且滿足 \(\mathrm{nwts} \ge 3\)。數值越大時,濾波效果越好,但序列兩端也會有越多資料損失 (NaN)。一般而言,Lanczos 濾波常取 nwts=201。

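  • fca, fcb: 帶通濾波的兩個截止頻率 (單位為 cycles/day),必須滿足 \(0 < \mathrm{fca} < \mathrm{fcb} < 0.5 \times \mathrm{srate}\),也就是換算成每個時間步的頻率 (fca/srate、fcb/srate) 後須低於 Nyquist 頻率 0.5。例如 fca=1/30、fcb=1/10 代表保留 10–30 天的週期。(如果是高通和低通濾波,只需要考慮 fca)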

  • srate: 資料的時間解析度,即一天有幾個資料點。

Example 1: 計算冬季 (DJF) OLR在10天以下、10-30天、30-60天、120天以上的variance佔total variance的比例。

import xarray as xr 
import numpy as np 
import cmaps
import matplotlib as mpl
from matplotlib import pyplot as plt
from cartopy import crs as ccrs   
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
mpl.rcParams['figure.dpi'] = 200

# Subset OLR to 20S-30N, 40E-180E.
olr = (xr.open_dataset('data/olr.nc')
         .sel(lat=slice(-20, 30), lon=slice(40, 180)).olr)

# Feed OLR into the filter functions. srate=1 (one sample per day), so the
# cut-offs are in cycles/day and 1/T selects a period of T days.
olr_10_hp    = lanczos_hp_filter(data=olr, nwts=201, fca=(1./10.),                srate=1)  # periods < 10 d
olr_30_60_bp = lanczos_bp_filter(data=olr, nwts=201, fca=(1./60.), fcb=(1./30.), srate=1)  # 30-60 d
# fca=1/30 so the band really spans 10-30 days, matching the variable name
# and the "(b) 10-30-day band pass" panel.
olr_10_30_bp = lanczos_bp_filter(data=olr, nwts=201, fca=(1./30.), fcb=(1./10.), srate=1)  # 10-30 d
olr_120_lp   = lanczos_lp_filter(data=olr, nwts=201, fca=(1./120.),               srate=1)  # periods > 120 d
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[3], line 3
      1 # 將OLR送進濾波函數
      2 olr_10_hp    = lanczos_hp_filter(data=olr, nwts=201,fca=(1./10.)             ,srate=1)
----> 3 olr_30_60_bp = lanczos_bp_filter(data=olr, nwts=201,fca=(1./60.),fcb=(1./30.),srate=1)
      4 olr_10_30_bp = lanczos_bp_filter(data=olr, nwts=201,fca=(1./20.),fcb=(1./10.),srate=1)
      5 olr_120_lp   = lanczos_lp_filter(data=olr, nwts=201,fca=(1./120.)            ,srate=1)

Cell In[1], line 117, in lanczos_bp_filter(data, nwts, fca, fcb, srate)
    114 weight_low  = xr.DataArray(lfw, dims = ['window'])
    116 # apply the filters using the rolling method with the weights
--> 117 lowpass_hf = data.rolling(time = len(hfw), center = True).construct('window').dot(weight_high)
    118 lowpass_lf = data.rolling(time = len(lfw), center = True).construct('window').dot(weight_low)
    120 # the bandpass is the difference of two lowpass filters.

File /data/wtsai/micromamba/p3t/lib/python3.10/site-packages/xarray/util/deprecation_helpers.py:143, in deprecate_dims.<locals>.wrapper(*args, **kwargs)
    135     emit_user_level_warning(
    136         f"The `{old_name}` argument has been renamed to `dim`, and will be removed "
    137         "in the future. This renaming is taking place throughout xarray over the "
   (...)
    140         PendingDeprecationWarning,
    141     )
    142     kwargs["dim"] = kwargs.pop(old_name)
--> 143 return func(*args, **kwargs)

File /data/wtsai/micromamba/p3t/lib/python3.10/site-packages/xarray/core/dataarray.py:5110, in DataArray.dot(self, other, dim)
   5107 if not isinstance(other, DataArray):
   5108     raise TypeError("dot only operates on DataArrays.")
-> 5110 return computation.dot(self, other, dim=dim)

File /data/wtsai/micromamba/p3t/lib/python3.10/site-packages/xarray/util/deprecation_helpers.py:143, in deprecate_dims.<locals>.wrapper(*args, **kwargs)
    135     emit_user_level_warning(
    136         f"The `{old_name}` argument has been renamed to `dim`, and will be removed "
    137         "in the future. This renaming is taking place throughout xarray over the "
   (...)
    140         PendingDeprecationWarning,
    141     )
    142     kwargs["dim"] = kwargs.pop(old_name)
--> 143 return func(*args, **kwargs)

File /data/wtsai/micromamba/p3t/lib/python3.10/site-packages/xarray/core/computation.py:1872, in dot(dim, *arrays, **kwargs)
   1869 # subscripts should be passed to np.einsum as arg, not as kwargs. We need
   1870 # to construct a partial function for apply_ufunc to work.
   1871 func = functools.partial(duck_array_ops.einsum, subscripts, **kwargs)
-> 1872 result = apply_ufunc(
   1873     func,
   1874     *arrays,
   1875     input_core_dims=input_core_dims,
   1876     output_core_dims=output_core_dims,
   1877     join=join,
   1878     dask="allowed",
   1879 )
   1880 return result.transpose(*all_dims, missing_dims="ignore")

File /data/wtsai/micromamba/p3t/lib/python3.10/site-packages/xarray/core/computation.py:1271, in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, on_missing_core_dim, *args)
   1269 # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc
   1270 elif any(isinstance(a, DataArray) for a in args):
-> 1271     return apply_dataarray_vfunc(
   1272         variables_vfunc,
   1273         *args,
   1274         signature=signature,
   1275         join=join,
   1276         exclude_dims=exclude_dims,
   1277         keep_attrs=keep_attrs,
   1278     )
   1279 # feed Variables directly through apply_variable_ufunc
   1280 elif any(isinstance(a, Variable) for a in args):

File /data/wtsai/micromamba/p3t/lib/python3.10/site-packages/xarray/core/computation.py:313, in apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args)
    308 result_coords, result_indexes = build_output_coords_and_indexes(
    309     args, signature, exclude_dims, combine_attrs=keep_attrs
    310 )
    312 data_vars = [getattr(a, "variable", a) for a in args]
--> 313 result_var = func(*data_vars)
    315 out: tuple[DataArray, ...] | DataArray
    316 if signature.num_outputs > 1:

File /data/wtsai/micromamba/p3t/lib/python3.10/site-packages/xarray/core/computation.py:824, in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args)
    819     if vectorize:
    820         func = _vectorize(
    821             func, signature, output_dtypes=output_dtypes, exclude_dims=exclude_dims
    822         )
--> 824 result_data = func(*input_data)
    826 if signature.num_outputs == 1:
    827     result_data = (result_data,)

File /data/wtsai/micromamba/p3t/lib/python3.10/site-packages/xarray/core/duck_array_ops.py:59, in einsum(*args, **kwargs)
     57 else:
     58     xp = get_array_namespace(*args)
---> 59     return xp.einsum(*args, **kwargs)

File /data/wtsai/micromamba/p3t/lib/python3.10/site-packages/numpy/_core/einsumfunc.py:1429, in einsum(out, optimize, *operands, **kwargs)
   1427     if specified_out:
   1428         kwargs['out'] = out
-> 1429     return c_einsum(*operands, **kwargs)
   1431 # Check the kwargs to avoid a more cryptic error later, without having to
   1432 # repeat default values here
   1433 valid_einsum_kwargs = ['dtype', 'order', 'casting']

KeyboardInterrupt: 
# Boreal-winter months (December, January, February) and the analysis period.
DJF_MONTHS = [1, 2, 12]
ANALYSIS_PERIOD = slice('1999-01-01', '2020-12-31')


def _select_djf(da):
    """Return the DJF-month subset of *da* along its ``time`` dimension."""
    return da.sel(time=da.time.dt.month.isin(DJF_MONTHS))


olr_djf = _select_djf(olr)
# Total DJF variance: computed once and reused as the denominator of every
# ratio. Reduce over the named time dimension rather than a positional axis.
olr_djf_total_var = olr_djf.sel(time=ANALYSIS_PERIOD).var(dim='time')

# Ratio of each band's variance to the total variance.
olr_10_hp_djf = _select_djf(olr_10_hp)
djf_10hp_var_rt = olr_10_hp_djf.sel(time=ANALYSIS_PERIOD).var(dim='time') / olr_djf_total_var

olr_30_60_bp_djf = _select_djf(olr_30_60_bp)
djf_30_60_var_rt = olr_30_60_bp_djf.sel(time=ANALYSIS_PERIOD).var(dim='time') / olr_djf_total_var

olr_10_30_bp_djf = _select_djf(olr_10_30_bp)
djf_10_30_var_rt = olr_10_30_bp_djf.sel(time=ANALYSIS_PERIOD).var(dim='time') / olr_djf_total_var

olr_120_lp_djf = _select_djf(olr_120_lp)
djf_120lp_var_rt = olr_120_lp_djf.sel(time=ANALYSIS_PERIOD).var(dim='time') / olr_djf_total_var
proj = ccrs.PlateCarree()
fig, axes = plt.subplots(2, 2, figsize=(12, 7), subplot_kw={'projection': proj})
ax = axes.flatten()

# Plotting: one filled-contour map of variance ratio per frequency band.
cmap = cmaps.precip3_16lev


def _draw_band_panel(ratio, panel_index, levels, label):
    """Contour *ratio* on ax[panel_index] and add a left-aligned panel label."""
    target = ax[panel_index]
    artist = ratio.plot.contourf(
        "lon", "lat",
        transform=proj,
        ax=target,
        levels=levels,
        cmap=cmap,
        add_colorbar=True,
        extend='both',
        # A fresh dict per call so no panel can see another's colorbar options.
        cbar_kwargs={'orientation': 'horizontal', 'aspect': 30, 'label': ' '},
    )
    target.set_title(' ')
    target.set_title(label, loc='left')
    return artist


# Panels are labelled down the left column first, then the right column.
cf_10d_hp = _draw_band_panel(djf_10hp_var_rt, 0,
                             np.arange(0.2, 0.65, 0.05),
                             '(a) 10-day high pass')
cf_10_30d = _draw_band_panel(djf_10_30_var_rt, 2,
                             [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.4],
                             '(b) 10-30-day band pass')
cf_30_60d = _draw_band_panel(djf_30_60_var_rt, 1,
                             np.arange(0.03, 0.21, 0.03),
                             '(c) 30-60-day band pass')
cf_120d_lp = _draw_band_panel(djf_120lp_var_rt, 3,
                              [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.4],
                              '(d) 120-day low pass')

# Shared map decoration for every panel.
lon_formatter = LONGITUDE_FORMATTER
lat_formatter = LATITUDE_FORMATTER
for panel in ax[:4]:
    panel.set_title('')
    panel.set_extent([40, 180, -20, 30], crs=proj)
    panel.set_xticks(np.arange(40, 200, 20), crs=proj)
    panel.set_yticks(np.arange(-20, 40, 10), crs=proj)
    panel.xaxis.set_major_formatter(lon_formatter)
    panel.yaxis.set_major_formatter(lat_formatter)
    panel.coastlines()
    panel.set_ylabel(' ')
    panel.set_xlabel(' ')

plt.show()
/Users/waynetsai/micromamba/envs/p3/lib/python3.10/site-packages/shapely/predicates.py:798: RuntimeWarning: invalid value encountered in intersects
  return lib.intersects(a, b, **kwargs)
_images/6c3222b41f433b198cb0ed29a8ab8734ff14805477ad3d2f0c3c7500d5c61ec3.png