import logging
import warnings
import numpy as np
from stdatamodels.jwst import datamodels
from jwst.datamodels import ModelContainer
from jwst.datamodels.utils.flat_multispec import expand_flat_spec
from jwst.extract_1d.spec_wcs import create_spectral_wcs
log = logging.getLogger(__name__)
# attributes that we want to copy unmodified from input to output spectra
SPECMETA_ATTRIBUTES = [
"source_id",
"dispersion_direction",
"source_type",
"source_ra",
"source_dec",
]
__all__ = [
"InputSpectrumModel",
"OutputSpectrumModel",
"count_input",
"compute_output_wl",
"check_exptime",
"combine_1d_spectra",
]
[docs]
class OutputSpectrumModel:
"""
Model an output spectrum.
Attributes
----------
wavelength : ndarray
Output wavelength.
flux : ndarray
Output flux.
flux_error : ndarray
Output error on the flux.
surf_bright : ndarray
Output surface brightness.
sb_error : ndarray
Output error on the surface brightness.
dq : ndarray
Output DQ array.
weight : ndarray
Weight value for each spectral element.
count : ndarray
Input value count for each output spectral element.
wcs : gwcs.wcs.WCS
Output spectral WCS.
normalized : bool
Flag to indicate data has been combined (sums are normalized).
"""
def __init__(self):
self.wavelength = None
self.flux = None
self.flux_error = None
self.surf_bright = None
self.sb_error = None
self.dq = None
self.weight = None
self.count = None
self.wcs = None
self.normalized = False
self.source_id = None
self.flux_unit = None
self.sb_unit = None
[docs]
def assign_wavelengths(self, input_spectra):
"""
Create an array of wavelengths to use for the output spectrum.
Take the union of all input wavelengths, then call :func:`compute_output_wl`
to bin wavelengths in groups of the number of overlapping spectra.
Parameters
----------
input_spectra : list of `~jwst.combine_1d.combine1d.InputSpectrumModel`
List of input spectra.
"""
(wl, n_input_spectra) = count_input(input_spectra)
self.wavelength = np.sort(compute_output_wl(wl, n_input_spectra))
self.wcs = create_spectral_wcs(
input_spectra[0].right_ascension[0], input_spectra[0].declination[0], self.wavelength
)
[docs]
def accumulate_sums(self, input_spectra, sigma_clip=None):
"""
Compute a weighted sum of all the input spectra.
Each pixel of each input spectrum will be added to one pixel of
the output spectrum. The wavelength spacing of the input and
output should be comparable, so each input pixel will usually
correspond to one output pixel (i.e., have the same wavelength,
to within half a pixel). However, for any given input spectrum,
there can be output pixels that are incremented by more than one
input pixel, and there can be output pixels that are not incremented.
If there are several input spectra, such gaps will hopefully be
filled in.
Pixels in an input spectrum that are flagged (via the DQ array)
as DO_NOT_USE will not be used, i.e., they will simply be ignored
when incrementing output arrays from an input spectrum. If an
input pixel is flagged with a non-zero value other than DO_NOT_USE,
the value will be propagated to the output DQ array via bitwise OR.
Parameters
----------
input_spectra : list of `~jwst.combine_1d.combine1d.InputSpectrumModel`
List of input spectra.
sigma_clip : float, optional
Factor for clipping outliers in spectral combination.
"""
# This is the data type for the output spectrum. We'll use double
# precision for accumulating sums for most columns, but for the DQ
# array, use the correct output data type.
cmb_dtype = datamodels.CombinedSpecModel().spec_table.dtype
dq_dtype = cmb_dtype.fields["DQ"][0]
nelem = self.wavelength.shape[0]
nspec = len(input_spectra)
dq = np.zeros(nelem, dtype=dq_dtype)
flux = np.zeros((nspec, nelem), dtype=np.float64)
flux_error = np.zeros((nspec, nelem), dtype=np.float64)
surf_bright = np.zeros((nspec, nelem), dtype=np.float64)
sb_error = np.zeros((nspec, nelem), dtype=np.float64)
weight = np.zeros((nspec, nelem), dtype=np.float64)
count = np.zeros((nspec, nelem), dtype=np.float64)
self.flux_unit = input_spectra[0].flux_unit
self.sb_unit = input_spectra[0].sb_unit
n_nan = 0 # initial value
ninputs = 0
for s, in_spec in enumerate(input_spectra):
ninputs += 1
if in_spec.name is not None:
slit_name = f"{ninputs}, slit {in_spec.name}"
else:
slit_name = ninputs
log.info(f"Accumulating data from input spectrum {slit_name}")
# Get the pixel numbers in the output corresponding to the
# wavelengths of the current input spectrum.
out_pixel = self.wcs.invert(
in_spec.right_ascension, in_spec.declination, in_spec.wavelength
)
# i is a pixel number in the current input spectrum, and
# k is the corresponding pixel number in the output spectrum.
nan_flag = np.isnan(out_pixel)
n_nan += nan_flag.sum()
for i in range(len(out_pixel)):
# Need to check on dq and nan flux because dq is not set for some x1d
if (in_spec.dq[i] & datamodels.dqflags.pixel["DO_NOT_USE"] > 0) | np.isnan(
in_spec.flux[i]
):
continue
# Round to the nearest pixel.
if nan_flag[i]: # skip if the pixel number is NaN
continue
k = round(float(out_pixel[i]))
if k < 0 or k >= nelem:
continue
dq[k] |= in_spec.dq[i]
flux[s, k] = in_spec.flux[i]
flux_error[s, k] = in_spec.flux_error[i]
surf_bright[s, k] = in_spec.surf_bright[i]
sb_error[s, k] = in_spec.sb_error[i]
weight[s, k] = in_spec.weight[i]
count[s, k] = 1.0
(flux, flux_error, surf_bright, sb_error, weight, count) = self.combine_spectra(
flux, flux_error, surf_bright, sb_error, weight, count, sigma_clip=sigma_clip
)
if n_nan > 0:
log.warning(f"{int(n_nan)} output pixel numbers were NaN")
self.flux = flux
self.flux_error = flux_error
self.surf_bright = surf_bright
self.sb_error = sb_error
self.dq = dq
self.weight = weight
self.count = count
# Since the output wavelengths will not usually be exactly the same
# as the input wavelengths, it's possible that there will be output
# pixels for which there is no corresponding pixel in any of the
# input spectra. Check for this case.
index = np.where(self.count > 0.0)
n_good = len(index[0])
if nelem > n_good:
log.warning(
f"{int(nelem - n_good)} elements of output had no corresponding input data;"
)
log.warning(" these elements will be omitted.")
self.wavelength = self.wavelength[index]
self.flux = self.flux[index]
self.flux_error = self.flux_error[index]
self.surf_bright = self.surf_bright[index]
self.sb_error = self.sb_error[index]
self.dq = self.dq[index]
self.weight = self.weight[index]
self.count = self.count[index]
del index
[docs]
def combine_spectra(
self, flux, flux_error, surf_bright, sb_error, weight, count, sigma_clip=None
):
"""
Combine accumulated spectra.
Parameters
----------
flux : ndarray, 2-D
Tabulated fluxes for the input spectra in the format
``[N spectra, M wavelengths]``.
flux_error : ndarray, 2-D
Flux errors of input spectra.
surf_bright : ndarray, 2-D
Surface brightnesses of input spectra.
sb_error : ndarray, 2-D
Surface brightness errors for input spectra.
weight : ndarray, 2-D
Pixel weights for input spectra
count : ndarray, 2-D
Count of how many values at each index in the input arrays.
sigma_clip : float, optional
Factor for clipping outliers. Compares input spectra to the
median and medaian absolute devaition, by default None.
Returns
-------
flux : ndarray, 1-D
Combined 1-D fluxes.
flux_error : ndarray, 1-D
Combined 1-D flux errors.
surf_bright : ndarray, 1-D
Combined 1-D surface brightnesses.
sb_error : ndarray, 1-D
Combined 1-D surface brightness errors.
weight : ndarray, 1-D
Total, per wavelength weights.
count : ndarray, 1-D
Total count of spectra contributing to each wavelength.
"""
# Catch warnings for all NaN slices in an array.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
if sigma_clip is not None:
# Copy the fluxes for modifying
flux_2d = np.array(flux * weight)
# Mask any missing pixels in the input spectra
missing = (count < 1) | (flux_2d == 0.0)
flux_2d[missing] = np.nan
# Calculate median and median absolute deviation
med_flux = np.nanmedian(flux_2d, axis=0)
mad = np.nanmedian(np.abs(flux_2d - med_flux))
# Clip any outlier pixels in the input spectra
clipped = np.abs(flux * weight - med_flux) > sigma_clip * mad
flux[clipped] = np.nan
flux_error[clipped] = np.nan
surf_bright[clipped] = np.nan
sb_error[clipped] = np.nan
count[clipped] = 0
weight[clipped] = 0
# Perform a weighted sum of the input spectra
sum_weight = np.nansum(weight, axis=0)
sum_weight_nonzero = np.where(sum_weight > 0.0, sum_weight, 1.0)
flux = np.nansum(flux * weight, axis=0) / sum_weight_nonzero
flux_error = np.sqrt(np.nansum((flux_error * weight) ** 2, axis=0)) / sum_weight_nonzero
surf_bright = np.nansum(surf_bright * weight, axis=0) / sum_weight_nonzero
sb_error = np.sqrt(np.nansum((sb_error * weight) ** 2, axis=0)) / sum_weight_nonzero
count = np.nansum(count, axis=0)
self.normalized = True
return flux, flux_error, surf_bright, sb_error, sum_weight, count
[docs]
def create_output_data(self):
"""
Create the output data.
Returns
-------
output_model : `~stdatamodels.jwst.datamodels.CombinedSpecModel`
A table of combined spectral data.
"""
if not self.normalized:
log.warning("Data have not been divided by the sum of the weights.")
cmb_dtype = datamodels.CombinedSpecModel().spec_table.dtype
# Note that these arrays have to be in the right order.
data = np.array(
list(
zip(
self.wavelength,
self.flux,
self.flux_error,
self.surf_bright,
self.sb_error,
self.dq,
self.weight,
self.count,
strict=False,
)
),
dtype=cmb_dtype,
)
output_model = datamodels.CombinedSpecModel(spec_table=data)
output_model.spec_table.columns["wavelength"].unit = "um"
output_model.spec_table.columns["flux"].unit = self.flux_unit
output_model.spec_table.columns["error"].unit = self.flux_unit
output_model.spec_table.columns["surf_bright"].unit = self.sb_unit
output_model.spec_table.columns["sb_error"].unit = self.sb_unit
return output_model
[docs]
def close(self):
"""Set data attributes to null values."""
self.wavelength = None
self.flux = None
self.flux_error = None
self.surf_bright = None
self.sb_error = None
self.dq = None
self.weight = None
self.count = None
self.wcs = None
self.normalized = False
self.source_id = None
[docs]
def compute_output_wl(wl, n_input_spectra):
"""
Compute output wavelengths.
In summary, the output wavelengths are computed by binning the
input wavelengths in groups of the number of overlapping spectra.
The merged and sorted input wavelengths are in ``wl``.
If the wavelength arrays of the input spectra are nearly all the
same, the values in wl will be in clumps of N nearly identical
values (for N input files). We want to average the wavelengths
in those clumps to get the output wavelengths. However, if the
wavelength scales of the input files differ slightly, the values
in ``wl`` may clump in some regions but not in others. The algorithm
below tries to find clumps; a clump is identified by having a
small standard deviation of all the wavelengths in a slice of ``wl``,
compared with neighboring slices. It is intended that this should
not find clumps where they're not present, i.e., in regions where
the wavelengths of the input spectra are not well aligned with each
other. These regions will be gaps in the initial determination
of output wavelengths based on clumps. In such regions, output
wavelengths will then be computed by just averaging groups of
input wavelengths to fill in the gaps.
Parameters
----------
wl : 1-D array
An array containing all the wavelengths from all the input
spectra, sorted in increasing order.
n_input_spectra : 1-D array
An integer array of the same length as ``wl``. For a given
array index ``k`` (for example), ``n_input_spectra[k]`` is the number of
input spectra that cover wavelength ``wl[k]``.
Returns
-------
wavelength : 1-D array
Array of wavelengths for the output spectrum.
"""
nwl = len(wl)
# sigma is an array of the standard deviation at each element
# of wl, over n_input_spectra elements. A small value implies that
# there's a clump, i.e. several elements of wl with nearly the
# same wavelength.
sigma = np.zeros(nwl, dtype=np.float64) + 9999.0
# mean_wl is the mean wavelength over the same slice of wl that we
# used to compute sigma. If sigma is small enough that it looks
# as if there's a clump, we'll copy the mean_wl value to temp_wl
# to be one element of the output wavelengths.
mean_wl = np.zeros(nwl, dtype=np.float64) - 99.0
# temp_wl has the same number of elements as wl, but we expect the
# array of output wavelengths to be significantly smaller, so
# temp_wl is initialized to a negative value as a flag. Positive
# elements will be copied to the array of output wavelengths.
temp_wl = np.zeros(nwl, dtype=np.float64) - 99.0
for k in range(nwl):
n = n_input_spectra[k]
if n == 1:
sigma[k] = 0.0
mean_wl[k] = wl[k]
temp_wl[k] = mean_wl[k]
else:
k1 = k + n // 2 + 1
k0 = k1 - n
if k0 >= 0 and k1 <= nwl:
sigma[k] = wl[k0:k1].std()
mean_wl[k] = wl[k0:k1].mean()
if sigma[k] == 0.0:
temp_wl[k] = mean_wl[k]
cutoff = 0.8
for k in range(nwl):
# If sigma[k] equals 0, temp_wl has already been assigned.
if sigma[k] > 0.0:
if k == 0:
if sigma[k] < cutoff * sigma[1]:
temp_wl[k] = mean_wl[k]
elif k == nwl - 1:
if sigma[k] < cutoff * sigma[nwl - 2]:
temp_wl[k] = mean_wl[k]
else:
if sigma[k] < cutoff * (sigma[k - 1] + sigma[k + 1]) / 2.0:
temp_wl[k] = mean_wl[k]
# Fill gaps in the output wavelengths by taking averages of the
# input wavelengths. If there are n overlapping input spectra,
# average a block of n elements of `wl`.
done = False
i = 0
while not done:
if i >= nwl:
done = True
else:
n = n_input_spectra[i]
middle = n // 2
if i + middle < nwl:
n = n_input_spectra[i + middle]
if i + n < nwl:
# The nominal range is i - 1 to i + n (inclusive) for
# checking whether an output wavelength has already
# been assigned.
low = max(i - 1, 0)
high = min(i + n + 1, nwl - 1)
if temp_wl[low:high].max() <= 0.0:
temp_wl[i] = wl[i : i + n].mean()
i += n
return temp_wl[np.where(temp_wl > 0.0)].copy()
[docs]
def check_exptime(exptime_key):
"""
Check exptime_key for validity.
This function checks ``exptime_key``. If it is valid, the corresponding
value used by the metadata interface will be returned. This will be
either "integration_time" or "exposure_time". If it is invalid,
"unit weight" will be returned (meaning that a weight of 1 will be
used when averaging spectra), and a warning will be logged.
Parameters
----------
exptime_key : str
A keyword or string indicating what value (integration time or
exposure time) should be used as a weight when combing spectra.
Returns
-------
exptime_key : str
The value will be either "integration_time", "exposure_time",
or "unit_weight".
"""
exptime_lwr = exptime_key.lower()
if exptime_lwr.startswith("integration") or exptime_lwr == "effinttm":
exptime_key = "integration_time"
log.info("Using integration time as the weight.")
elif exptime_lwr.startswith("exposure") or exptime_lwr == "effexptm":
exptime_key = "exposure_time"
log.info("Using exposure time as the weight.")
elif exptime_lwr == "unit_weight" or exptime_lwr == "unit weight":
exptime_key = "unit_weight"
log.info("Using weight = 1.")
else:
log.warning(f"Don't understand exptime_key = '{exptime_key}'; using unit weight.")
log.info("The options for exptime_key are:")
log.info(" integration_time, effinttm, exposure_time, effexptm, unit_weight, unit weight")
exptime_key = "unit_weight"
return exptime_key
def _read_input_spectra(input_model, exptime_key, input_spectra):
"""
Read input spectra from a datamodel.
Parameters
----------
input_model : `~stdatamodels.jwst.datamodels.MultiSpecModel`, \
`~stdatamodels.jwst.datamodels.TSOMultiSpecModel`, or \
`~stdatamodels.jwst.datamodels.MRSSpecModel`
A datamodel with a ``spec`` attribute, containing spectra.
If `~stdatamodels.jwst.datamodels.TSOMultiSpecModel`,
integrations in the spectral table rows
are expanded into separate spectra.
exptime_key : str
Exposure time key to use for weighting.
input_spectra : dict
Dictionary to hold input spectra, keyed by spectral order;
Updated in place.
Returns
-------
input_spectra : dict
The updated dictionary, holding all spectra in the input model.
Raises
------
TypeError
If the input datamodel does not have a ``spec`` attribute.
"""
if not hasattr(input_model, "spec"):
raise TypeError(f"Invalid input datamodel: {type(input_model)}")
if isinstance(input_model, datamodels.TSOMultiSpecModel):
spectra = expand_flat_spec(input_model).spec
else:
spectra = input_model.spec
for in_spec in spectra:
if not np.any(np.isfinite(in_spec.spec_table.field("flux"))):
if in_spec.meta.hasattr("group_id"):
msg = (
f"Input spectrum {in_spec.source_id} order {in_spec.spectral_order} "
f"from group_id {in_spec.meta.group_id} has no valid flux values; skipping."
)
else:
msg = (
f"Input spectrum {in_spec.source_id} order {in_spec.spectral_order} "
"has no valid flux values; skipping."
)
log.warning(msg)
continue
spectral_order = in_spec.spectral_order
if spectral_order not in input_spectra:
input_spectra[spectral_order] = []
input_spectra[spectral_order].append(InputSpectrumModel(input_model, in_spec, exptime_key))
return input_spectra
[docs]
def combine_1d_spectra(input_model, exptime_key, sigma_clip=None):
"""
Combine the input spectra.
Parameters
----------
input_model : `~stdatamodels.jwst.datamodels.JwstDataModel`
The input spectra. This will likely be a
`~jwst.datamodels.container.ModelContainer` object,
but may also be a multi-spectrum model, such as
`~stdatamodels.jwst.datamodels.MultiSpecModel` or
`~stdatamodels.jwst.datamodels.TSOMultiSpecModel`.
Input spectra may have different spectral orders
or wavelengths but should all share the same target.
May be updated in place if processing is skipped.
exptime_key : str
A string identifying which keyword to use to get the exposure time,
which is used as a weight when combining spectra. The value should
be one of: "exposure_time" (the default), "integration_time",
or "unit_weight".
Returns
-------
output_model : `~stdatamodels.jwst.datamodels.MultiCombinedSpecModel`
A combined spectra datamodel.
"""
log.debug(f"Using exptime_key = {exptime_key}.")
exptime_key = check_exptime(exptime_key)
input_spectra = {}
output_spectra = {}
if isinstance(input_model, ModelContainer):
for ms in input_model:
_read_input_spectra(ms, exptime_key, input_spectra)
else:
_read_input_spectra(input_model, exptime_key, input_spectra)
if len(input_spectra) == 0:
log.error("No valid input spectra found for source. Skipping.")
input_model.meta.cal_step.combine_1d = "SKIPPED"
return input_model
for order in input_spectra:
output_spectra[order] = OutputSpectrumModel()
output_spectra[order].assign_wavelengths(input_spectra[order])
output_spectra[order].accumulate_sums(input_spectra[order], sigma_clip=sigma_clip)
output_model = datamodels.MultiCombinedSpecModel()
for order in output_spectra:
output_order = output_spectra[order].create_output_data()
output_order.spectral_order = order
for attr in SPECMETA_ATTRIBUTES:
setattr(output_order, attr, getattr(input_spectra[order][0], attr))
output_model.spec.append(output_order)
# Copy one of the input headers to output.
if isinstance(input_model, ModelContainer):
output_model.update(input_model[0], only="PRIMARY")
else:
output_model.update(input_model, only="PRIMARY")
# Looks clunky, but need an output_spec instance to copy wcs
output_model.meta.wcs = output_spectra[list(output_spectra)[0]].wcs
output_model.meta.cal_step.combine_1d = "COMPLETE"
for order in input_spectra:
for in_spec in input_spectra[order]:
in_spec.close()
for order in output_spectra:
output_spectra[order].close()
return output_model