Skip to content

Commit

Permalink
Merge pull request #142 from karllark/usability_updates
Browse files Browse the repository at this point in the history
Usability updates
  • Loading branch information
karllark authored Feb 27, 2025
2 parents 7de4bd1 + 6d19482 commit 32c78f4
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 62 deletions.
2 changes: 1 addition & 1 deletion measure_extinction/extdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ def trans_elv_alav(self, av=None, akav=0.112):
-------
Updates self.(exts, uncs)
"""
if (self.type_rel_band != "V") and (self.type_rel_band != 0.55 * u.micron):
if (self.type_rel_band != "V") and (self.type_rel_band != 0.55 * u.micron) and (self.type_rel_band != 5500. * u.angstrom):
warnings.warn(
"attempt to normalize a non-E(lambda-V) curve with A(V)", UserWarning
)
Expand Down
99 changes: 72 additions & 27 deletions measure_extinction/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def __init__(self, modinfo=None, obsdata=None):
if obsdata is not None:
self.logf = {}
for cspec in obsdata.data.keys():
self.logf[cspec] = -1.0
self.logf[cspec] = MEParameter(value=-3.0, bounds=(-9.0, 9.0))

def pprint_parameters(self):
"""
Expand Down Expand Up @@ -165,7 +165,7 @@ def pprint_parameters(self):
tline = ""
for cname in self.logf.keys():
hline += f"{cname} "
tline += f"{self.logf[cname]:.2f} "
tline += f"{self.logf[cname].value:.2f} "
print(f"{tline[:-1]} ({hline[:-1]})")

def parameters(self):
Expand All @@ -182,7 +182,7 @@ def parameters(self):
vals.append(getattr(self, cname).value)
if hasattr(self, "logf"):
for ckey in self.logf.keys():
vals.append(self.logf[ckey])
vals.append(self.logf[ckey].value)
return np.array(vals)

def parameters_to_fit(self):
Expand All @@ -200,7 +200,7 @@ def parameters_to_fit(self):
vals.append(getattr(self, cname).value)
if hasattr(self, "logf"):
for ckey in self.logf.keys():
vals.append(self.logf[ckey])
vals.append(self.logf[ckey].value)
return np.array(vals)

def fit_to_parameters(self, fit_params, uncs=None):
Expand All @@ -223,7 +223,7 @@ def fit_to_parameters(self, fit_params, uncs=None):
i += 1
if hasattr(self, "logf"):
for ckey in self.logf.keys():
self.logf[ckey] = fit_params[i]
self.logf[ckey].value = fit_params[i]
i += 1

def get_nonfixed_paramnames(self):
Expand All @@ -235,6 +235,10 @@ def get_nonfixed_paramnames(self):
cparam = getattr(self, cname)
if not cparam.fixed:
names.append(cname)
if hasattr(self, "logf"):
for ckey in self.logf.keys():
if not self.logf[ckey].fixed:
names.append(f"logf[{ckey}]")

return names

Expand All @@ -253,6 +257,18 @@ def check_param_limits(self):
raise ValueError(
f"{cname} = {pval} is above the bounds ({pbounds[0]}, {pbounds[1]})"
)
if hasattr(self, "logf"):
for ckey in self.logf.keys():
pval = self.logf[ckey].value
pbounds = self.logf[ckey].bounds
if (pbounds[0] is not None) and (pval < pbounds[0]):
raise ValueError(
f"logf[{ckey}] = {pval} is below the bounds ({pbounds[0]}, {pbounds[1]})"
)
elif (pbounds[1] is not None) and (pval > pbounds[1]):
raise ValueError(
f"logf[{cname}] = {pval} is above the bounds ({pbounds[0]}, {pbounds[1]})"
)

def add_exclude_region(self, exreg):
"""
Expand Down Expand Up @@ -285,7 +301,11 @@ def fit_weights(self, obsdata):
for cspec in list(obsdata.data.keys()):
# base weights
self.weights[cspec] = np.full(len(obsdata.data[cspec].fluxes), 0.0)
gvals = (obsdata.data[cspec].npts > 0) & np.isfinite(obsdata.data[cspec].fluxes)
gvals = (
(obsdata.data[cspec].npts > 0)
& np.isfinite(obsdata.data[cspec].fluxes)
& (obsdata.data[cspec].uncs.value > 0.0)
)
self.weights[cspec][gvals] = 1.0 / obsdata.data[cspec].uncs[gvals].value

x = 1.0 / obsdata.data[cspec].waves
Expand Down Expand Up @@ -583,15 +603,16 @@ def lnlike(self, obsdata, modeldata):

if hasattr(self, "logf"):
unc = 1.0 / self.weights[cspec][gvals]
unc2 = unc**2 + model**2 + np.exp(2.0 * self.logf[cspec])
unc2 = unc**2 + model**2 + np.exp(2.0 * self.logf[cspec].value)
weights = 1.0 / np.sqrt(unc2)
lnextra = np.log(unc2)
else:
weights = self.weights[cspec][gvals]
lnextra = 0.0

chiarr = np.square(
((obsdata.data[cspec].fluxes[gvals].value - model) * weights + lnextra)
chiarr = (
np.square(((obsdata.data[cspec].fluxes[gvals].value - model) * weights))
+ lnextra
)
lnl += -0.5 * np.sum(chiarr)

Expand Down Expand Up @@ -623,6 +644,15 @@ def lnprior(self):
if pprior is not None:
lnp += -0.5 * ((pval - pprior[0]) / pprior[1]) ** 2

if hasattr(self, "logf"):
for ckey in self.logf.keys():
pval = self.logf[ckey].value
pbounds = self.logf[ckey].bounds
if (pbounds[0] is not None) and (pval < pbounds[0]):
return self.lnp_bignum
elif (pbounds[1] is not None) and (pval > pbounds[1]):
return self.lnp_bignum

return lnp

def lnprob(self, obsdata, modeldata):
Expand Down Expand Up @@ -700,7 +730,15 @@ def nll(params, memodel, *args):

return (outmod, result)

def fit_sampler(self, obsdata, modinfo, nsteps=1000, burnfrac=0.1, multiproc=False):
def fit_sampler(
self,
obsdata,
modinfo,
nsteps=1000,
burnfrac=0.1,
save_samples=None,
multiproc=False,
):
"""
Run a samplier (specifically emcee) to find the detailed
parameters including uncertainties.
Expand All @@ -719,6 +757,9 @@ def fit_sampler(self, obsdata, modinfo, nsteps=1000, burnfrac=0.1, multiproc=Fal
burnfrac : float
fraction of nsteps to discard as the burn in [default=0.1]
save_samples : filename
name of hd5 file to save the MCMC samples
multiproc : boolean
set to run the emcee in parallel (does not speed up things much) [default=False]
Expand All @@ -740,21 +781,21 @@ def fit_sampler(self, obsdata, modinfo, nsteps=1000, burnfrac=0.1, multiproc=Fal
if not np.isfinite(outmod.lnprior()):
raise ValueError("ln(prior) is not finite")

# simple function to turn the log(likelihood) into the chisqr
# required as op.minimize function searches for the minimum chisqr (not max likelihood like MCMC algorithms)
# def lnprob(params, memodel, *args):
# memodel.fit_to_parameters(params)
# return memodel.lnprob(*args)

# get the non-fixed initial parameters
p0 = outmod.parameters_to_fit()

# setup the sampliers
ndim = len(p0)
nwalkers = 2 * ndim

# setting up the walkers to start "near" the inital guess
p = [p0 * (1 + 0.01 * np.random.normal(0, 1.0, ndim)) for k in range(nwalkers)]

if save_samples:
# Don't forget to clear it in case the file already exists
save_backend = emcee.backends.HDFBackend(save_samples)
save_backend.reset(nwalkers, ndim)

# setup and run the sampler
if multiproc:
with Pool() as pool:
Expand All @@ -772,6 +813,7 @@ def fit_sampler(self, obsdata, modinfo, nsteps=1000, burnfrac=0.1, multiproc=Fal
ndim,
_lnprob,
args=(outmod, obsdata, modinfo),
backend=save_backend,
)
sampler.run_mcmc(p, nsteps, progress=True)

Expand Down Expand Up @@ -835,6 +877,7 @@ def plot(self, obsdata, modinfo):
hi_ext_modsed = self.hi_abs_sed(modinfo, ext_modsed)

ax = axes[0]
yrange = [100.0, -100.0]
for cspec in obsdata.data.keys():
if cspec == "BAND":
ptype = "o"
Expand Down Expand Up @@ -879,19 +922,21 @@ def plot(self, obsdata, modinfo):
alpha=calpha,
)

# info for y limits of plot - make sure not not include Ly-alpha
gvals = np.logical_or(
modinfo.waves[cspec] > 0.125 * u.micron,
modinfo.waves[cspec] < 0.118 * u.micron,
)
gvals = np.logical_and(gvals, modinfo.waves[cspec] > 0.11 * u.micron)
multval = self.norm.value * np.power(modinfo.waves[cspec][gvals], 4.0)
mflux = (hi_ext_modsed[cspec][gvals] * multval).value
tyrange = np.log10([np.nanmin(mflux), np.nanmax(mflux)])
yrange[0] = np.min([tyrange[0], yrange[0]])
yrange[1] = np.max([tyrange[1], yrange[1]])

ax.set_xscale("log")
ax.set_yscale("log")

# get a reasonable y range
cspec = "MODEL_FULL_LOWRES"
gvals = np.logical_or(
modinfo.waves[cspec] > 0.125 * u.micron,
modinfo.waves[cspec] < 0.118 * u.micron,
)
gvals = np.logical_and(gvals, modinfo.waves[cspec] > 0.11 * u.micron)
multval = self.norm.value * np.power(modinfo.waves[cspec][gvals], 4.0)
mflux = (hi_ext_modsed[cspec][gvals] * multval).value
yrange = np.log10([np.nanmin(mflux), np.nanmax(mflux)])
ydelt = yrange[1] - yrange[0]
yrange[0] = 10 ** (yrange[0] - 0.1 * ydelt)
yrange[1] = 10 ** (yrange[1] + 0.1 * ydelt)
Expand All @@ -903,7 +948,7 @@ def plot(self, obsdata, modinfo):
ax.tick_params("both", length=10, width=2, which="major")
ax.tick_params("both", length=5, width=1, which="minor")
axes[1].set_ylim(-10.0, 10.0)
axes[1].plot([0.1, 2.5], [0.0, 0.0], "k:")
axes[1].axhline(0.0, color="k", linestyle=":")

k = 0
for cname in self.paramnames:
Expand Down
10 changes: 9 additions & 1 deletion measure_extinction/modeldata.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,15 @@ def __init__(
bp_cam = "UVIS1"
bp = STS.band(f"WFC3,{bp_cam},{bp_info[1]}")
else:
band_filename = f"John{cband}.dat"
if (
("WISE" in cband)
or ("IRAC" in cband)
or ("MIPS" in cband)
):
estr = ""
else:
estr = "John"
band_filename = f"{estr}{cband}.dat"
bp = SpectralElement.from_file(
f"{band_resp_path}/{band_filename}"
)
Expand Down
30 changes: 17 additions & 13 deletions measure_extinction/plotting/plot_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def plot_multi_extinction(
fitmodel=False,
HI_lines=False,
range=None,
spread=False,
spread=0.0,
exclude=[],
log=False,
text_offsets=[],
Expand Down Expand Up @@ -454,8 +454,8 @@ def plot_multi_extinction(
range : list of 2 floats [default=None]
Wavelength range to be plotted (in micron) - [min,max]
spread : boolean [default=False]
Whether or not to spread the extinction curves out by adding a vertical offset to each curve
spread : float [default=0]
Amount to addiatively spread the curves
exclude : list of strings [default=[]]
List of data type(s) to exclude from the plot (e.g., "IRS", "IRAC1")
Expand Down Expand Up @@ -515,12 +515,17 @@ def plot_multi_extinction(
text_angles = np.full(len(starpair_list), 10)

for i, starpair in enumerate(starpair_list):
if ".fits" not in starpair:
fname = "%s%s_ext.fits" % (path, starpair.lower())
else:
fname = starpair

# read in the extinction curve data
extdata = ExtData("%s%s_ext.fits" % (path, starpair.lower()))
extdata = ExtData(fname)

# spread out the curves if requested
if spread:
yoffset = 0.25 * i
if spread != 0.0:
yoffset = spread * i
else:
yoffset = 0.0

Expand Down Expand Up @@ -759,7 +764,7 @@ def plot_extinction(
ax.set_title(starpair, fontsize=50)
else:
ax.set_title(extdata.red_file.replace(".dat", ""), fontsize=50)
if extdata.comp_file != "":
if (extdata.comp_file is not None) and (extdata.comp_file != ""):
ax.text(
0.99,
0.95,
Expand Down Expand Up @@ -795,14 +800,12 @@ def main():
# commandline parser
parser = argparse.ArgumentParser()
parser.add_argument(
"starpair_list",
nargs="+",
help='pairs of star names for which to plot the extinction curve, in the format "reddenedstarname_comparisonstarname", without spaces',
"starpair_list", nargs="+", help="filenames of extinction curves"
)
parser.add_argument(
"--path",
help="path to the data files",
default=get_datapath(),
default="./",
)
parser.add_argument("--alax", help="plot A(lambda)/A(X)", action="store_true")
parser.add_argument(
Expand All @@ -827,8 +830,9 @@ def main():
)
parser.add_argument(
"--spread",
help="spread the curves out over the figure; can only be used in combination with --onefig",
action="store_true",
help="spread the curves out over the figure by this amount; can only be used in combination with --onefig",
default=0.0,
type=float,
)
parser.add_argument(
"--exclude",
Expand Down
3 changes: 1 addition & 2 deletions measure_extinction/plotting/plot_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import astropy.units as u
import pandas as pd

from measure_extinction.utils.helpers import get_datapath
from measure_extinction.stardata import StarData


Expand Down Expand Up @@ -475,7 +474,7 @@ def main():
parser.add_argument(
"--path",
help="path to the data files",
default=get_datapath(),
default="./",
)
parser.add_argument("--mlam4", help="plot lambda^4*F(lambda)", action="store_true")
parser.add_argument("--HI_lines", help="indicate the HI-lines", action="store_true")
Expand Down
Loading

0 comments on commit 32c78f4

Please sign in to comment.