Skip to content

DM-49008: Add options to use cell-based coadds #1121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
50 changes: 45 additions & 5 deletions python/lsst/pipe/tasks/deblendCoaddSourcesPipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from lsst.pipe.base import (Struct, PipelineTask, PipelineTaskConfig, PipelineTaskConnections)
import lsst.pipe.base.connectionTypes as cT

from lsst.pex.config import ConfigurableField
from lsst.pex.config import ConfigurableField, Field
from lsst.meas.base import SkyMapIdGeneratorConfig
from lsst.meas.deblender import SourceDeblendTask
from lsst.meas.extensions.scarlet import ScarletDeblendTask
Expand Down Expand Up @@ -120,6 +120,20 @@ class DeblendCoaddSourcesMultiConnections(PipelineTaskConnections,
multiple=True,
dimensions=("tract", "patch", "band", "skymap")
)
coadds_cell = cT.Input(
doc="Exposure on which to run deblending",
name="{inputCoaddName}CoaddCell",
storageClass="MultipleCellCoadd",
multiple=True,
dimensions=("tract", "patch", "band", "skymap")
)
backgrounds = cT.Input(
doc="Background model to subtract from the cell-based coadd",
name="{inputCoaddName}Coadd_calexp_background",
storageClass="Background",
multiple=True,
dimensions=("tract", "patch", "band", "skymap")
)
outputSchema = cT.InitOutput(
doc="Output of the schema used in deblending task",
name="{outputCoaddName}Coadd_deblendedFlux_schema",
Expand Down Expand Up @@ -161,9 +175,20 @@ def __init__(self, *, config=None):
del self.fluxCatalogs
del self.templateCatalogs

if config:
if config.useCellCoadds:
del self.coadds
else:
del self.coadds_cell
del self.backgrounds


class DeblendCoaddSourcesMultiConfig(PipelineTaskConfig,
pipelineConnections=DeblendCoaddSourcesMultiConnections):
useCellCoadds = Field[bool](
doc="Use cell-based coadds instead of regular coadds?",
default=False,
)
multibandDeblend = ConfigurableField(
target=ScarletDeblendTask,
doc="Task to deblend an images in multiple bands"
Expand Down Expand Up @@ -246,13 +271,28 @@ def __init__(self, initInputs, **kwargs):
def runQuantum(self, butlerQC, inputRefs, outputRefs):
# Obtain the list of bands, sort them (alphabetically), then reorder
# all input lists to match this band order.
bandOrder = [dRef.dataId["band"] for dRef in inputRefs.coadds]
coaddRefs = inputRefs.coadds_cell if self.config.useCellCoadds else inputRefs.coadds
bandOrder = [dRef.dataId["band"] for dRef in coaddRefs]
bandOrder.sort()
inputRefs = reorderRefs(inputRefs, bandOrder, dataIdKey="band")
inputs = butlerQC.get(inputRefs)
inputs["idFactory"] = self.config.idGenerator.apply(butlerQC.quantum.dataId).make_table_id_factory()
inputs["bands"] = [dRef.dataId["band"] for dRef in inputRefs.coadds]
outputs = self.run(**inputs)
bands = [dRef.dataId["band"] for dRef in coaddRefs]
mergedDetections = inputs.pop("mergedDetections")
if self.config.useCellCoadds:
exposures = [mcc.stitch().asExposure() for mcc in inputs.pop("coadds_cell")]
backgrounds = inputs.pop("backgrounds")
for exposure, background in zip(exposures, backgrounds):
exposure.image -= background.getImage()
coadds = exposures
else:
coadds = inputs.pop("coadds")
assert not inputs, "runQuantum got extra inputs"
outputs = self.run(
coadds=coadds,
bands=bands,
mergedDetections=mergedDetections,
idFactory=self.config.idGenerator.apply(butlerQC.quantum.dataId).make_table_id_factory(),
)
butlerQC.put(outputs, outputRefs)

def run(self, coadds, bands, mergedDetections, idFactory):
Expand Down
58 changes: 49 additions & 9 deletions python/lsst/pipe/tasks/fit_coadd_multiband.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,20 @@ class CoaddMultibandFitInputConnections(
dimensions=("tract", "patch", "band", "skymap"),
multiple=True,
)
coadds_cell = cT.Input(
doc="Cell-coadd exposures on which to run fits",
name="{name_coadd}CoaddCell",
storageClass="MultipleCellCoadd",
dimensions=("tract", "patch", "band", "skymap"),
multiple=True,
)
backgrounds = cT.Input(
doc="Background models to subtract from the coadds_cell",
name="{name_coadd}Coadd_calexp_background",
storageClass="Background",
dimensions=("tract", "patch", "band", "skymap"),
multiple=True,
)
models_psf = cT.Input(
doc="Input PSF model parameter catalog",
# Consider allowing independent psf fit method
Expand Down Expand Up @@ -198,9 +212,19 @@ def adjustQuantum(self, inputs, outputs, label, data_id):
return adjusted_inputs, {}

def __init__(self, *, config=None):
super().__init__(config=config)
if config is None:
return

if config.drop_psf_connection:
del self.models_psf

if config.use_cell_coadds:
del self.coadds
else:
del self.coadds_cell
del self.backgrounds


class CoaddMultibandFitConnections(CoaddMultibandFitInputConnections):
cat_output = cT.Output(
Expand Down Expand Up @@ -295,6 +319,10 @@ class CoaddMultibandFitBaseConfig(
target=CoaddMultibandFitSubTask,
doc="Task to fit sources using multiple bands",
)
use_cell_coadds = pexConfig.Field[bool](
doc="Use cell coadd images for object fitting?",
default=False,
)
idGenerator = SkyMapIdGeneratorConfig.make_field()

def get_band_sets(self):
Expand Down Expand Up @@ -332,18 +360,30 @@ class CoaddMultibandFitBase:
def build_catexps(self, butlerQC, inputRefs, inputs) -> list[CatalogExposureInputs]:
id_tp = self.config.idGenerator.apply(butlerQC.quantum.dataId).catalog_id
# This is a roundabout way of ensuring all inputs get sorted and matched
keys = ["cats_meas", "coadds"]
if self.config.use_cell_coadds:
keys = ["cats_meas", "coadds_cell", "backgrounds"]
else:
keys = ["cats_meas", "coadds"]
has_psf_models = "models_psf" in inputs
if has_psf_models:
keys.append("models_psf")
input_refs_objs = ((getattr(inputRefs, key), inputs[key]) for key in keys)
inputs_sorted = tuple(
{dRef.dataId: obj for dRef, obj in zip(refs, objs)}
for refs, objs in input_refs_objs
)
cats = inputs_sorted[0]
exps = inputs_sorted[1]
models_psf = inputs_sorted[2] if has_psf_models else None
input_refs_objs = {key: (getattr(inputRefs, key), inputs[key]) for key in keys}
inputs_sorted = {
key: {dRef.dataId: obj for dRef, obj in zip(refs, objs, strict=True)}
for key, (refs, objs) in input_refs_objs.items()
}
cats = inputs_sorted["cats_meas"]
if self.config.use_cell_coadds:
exps = {}
for data_id, background in inputs_sorted["backgrounds"].items():
mcc = inputs_sorted["coadds_cell"][data_id]
stitched_coadd = mcc.stitch()
exposure = stitched_coadd.asExposure()
exposure.image -= background.getImage()
exps[data_id] = exposure
else:
exps = inputs_sorted["coadds"]
models_psf = inputs_sorted["models_psf"] if has_psf_models else None
dataIds = set(cats).union(set(exps))
models_scarlet = inputs["models_scarlet"]
catexp_dict = {}
Expand Down
45 changes: 43 additions & 2 deletions python/lsst/pipe/tasks/fit_coadd_psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,18 @@ class CoaddPsfFitConnections(
storageClass="ExposureF",
dimensions=("tract", "patch", "band", "skymap"),
)
coadd_cell = cT.Input(
doc="Cell-coadd image to fit a PSF model to",
name="{name_coadd}CoaddCell",
storageClass="MultipleCellCoadd",
dimensions=("tract", "patch", "band", "skymap"),
)
background = cT.Input(
doc="Background model to subtract from the coadd_cell",
name="{name_coadd}Coadd_calexp_background",
storageClass="Background",
dimensions=("tract", "patch", "band", "skymap"),
)
cat_meas = cT.Input(
doc="Deblended single-band source catalog",
name="{name_coadd}Coadd_meas",
Expand All @@ -76,6 +88,17 @@ class CoaddPsfFitConnections(
dimensions=("tract", "patch", "band", "skymap"),
)

def __init__(self, *, config=None):
super().__init__(config=config)
if config is None:
return

if config.use_cell_coadds:
del self.coadd
else:
del self.coadd_cell
del self.background


class CoaddPsfFitSubConfig(pexConfig.Config):
"""Base config class for the CoaddPsfFitTask.
Expand Down Expand Up @@ -135,6 +158,11 @@ class CoaddPsfFitConfig(
):
"""Configure a CoaddPsfFitTask, including a configurable fitting subtask.
"""
use_cell_coadds = pexConfig.Field(
dtype=bool,
default=False,
doc="Use cell coadd images for PSF fitting",
)
fit_coadd_psf = pexConfig.ConfigurableField(
target=CoaddPsfFitSubTask,
doc="Task to fit PSF models for a single coadd",
Expand Down Expand Up @@ -162,13 +190,26 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):
inputs = butlerQC.get(inputRefs)
id_tp = self.config.idGenerator.apply(butlerQC.quantum.dataId).catalog_id
dataId = inputRefs.cat_meas.dataId
for dataRef in (inputRefs.coadd,):

if self.config.use_cell_coadds:
coaddDataRef = inputRefs.coadd_cell
multiple_cell_coadd = inputs.pop('coadd_cell')
background = inputs.pop('background')
exposure = multiple_cell_coadd.stitch().asExposure()
exposure.image -= background.getImage()
else:
coaddDataRef = inputRefs.coadd
exposure = inputs.pop('coadd')

for dataRef in (coaddDataRef,):
if dataRef.dataId != dataId:
raise RuntimeError(f'{dataRef=}.dataId != {inputRefs.cat_meas.dataId=}')

catalog = inputs.pop('cat_meas')
catexp = CatalogExposurePsf(
catalog=inputs['cat_meas'], exposure=inputs['coadd'], dataId=dataId, id_tract_patch=id_tp,
catalog=catalog, exposure=exposure, dataId=dataId, id_tract_patch=id_tp,
)
assert not inputs, "runQuantum got more inputs than expected"
outputs = self.run(catexp=catexp)
butlerQC.put(outputs, outputRefs)

Expand Down
Loading