Source code for phrosty.pipeline

# Imports STANDARD
import sys
import argparse
import cupy as cp
from functools import partial
import logging
import matplotlib.pyplot as plt
from multiprocessing import Pool
import numpy as np
import nvtx
import pathlib
import re
import shutil
import tracemalloc
import uuid

# Imports ASTRO
from astropy.coordinates import SkyCoord
from astropy.io import fits
import astropy.table
from astropy.table import Table
import astropy.units as u
from astropy.wcs import WCS
from astropy.wcs.utils import skycoord_to_pixel
from galsim import roman

# Imports INTERNAL
from phrosty.imagesubtraction import sky_subtract, stampmaker
from phrosty.photometry import ap_phot, psfmodel, psf_phot
from phrosty.utils import get_exptime, read_truth_txt
from sfft.SpaceSFFTCupyFlow import SpaceSFFT_CupyFlow
from snappl.image import OpenUniverse2024FITSImage
from snappl.psf import PSF
from snpit_utils.config import Config
from snpit_utils.logger import SNLogger


class PipelineImage:
    """Holds a snappl.image.Image, with some other stuff the pipeline needs."""

    def __init__( self, imagepath, pointing, sca, pipeline ):
        """Create a PipelineImage.

        Parameters
        ----------
        imagepath : str or Path
            A path to the image.  This will be passed on to an Image subclass
            constructor; which subclass depends on the config option
            photometry.phrosty.image_type

        pointing : str or int
            An identifier of the pointing of this image.  Used e.g. to pull PSFs.

        sca : int
            The SCA (detector) number of this image.

        pipeline : phrosty.pipeline.Pipeline
            The pipeline that owns this image.

        """
        # self.psf is an object of a subclass of snappl.psf.PSF
        self.config = Config.get()
        self.temp_dir = pipeline.temp_dir
        self.keep_intermediate = self.config.value( 'photometry.phrosty.keep_intermediate' )
        if self.keep_intermediate:
            self.save_dir = pathlib.Path( self.config.value( 'photometry.phrosty.paths.scratch_dir' ) )
        else:
            self.save_dir = self.temp_dir

        if self.config.value( 'photometry.phrosty.image_type' ) == 'ou2024fits':
            self.image = OpenUniverse2024FITSImage( imagepath, None, sca )
        else:
            raise RuntimeError( "At the moment, phrosty only works with ou2024fits images. "
                                "We hope this will change soon." )

        self.pointing = pointing
        self.band = pipeline.band

        # Intermediate files
        if self.keep_intermediate:
            # Set to None.  The paths get defined later on.
            # They have to be defined here in __init__ so that they exist
            # and are accessible in later functions.
            self.skysub_path = None
            self.detmask_path = None
            self.input_sci_psf_path = None
            self.input_templ_psf_path = None
            self.aligned_templ_img_path = None
            self.aligned_templ_var_path = None
            self.aligned_templ_psf_path = None
            self.crossconv_sci_path = None
            self.crossconv_templ_path = None
            self.diff_path = None
            self.decorr_kernel_path = None

        # Always save and output these
        self.decorr_psf_path = {}
        self.decorr_zptimg_path = {}
        self.decorr_diff_path = {}
        self.zpt_stamp_path = {}
        self.diff_var_path = {}
        self.diff_var_stamp_path = {}
        self.diff_stamp_path = {}

        # Held in memory
        self.skyrms = None
        self.psfobj = None
        self.psf_data = None

    def run_sky_subtract( self, mp=True ):
        # Eventually, we may not want to save the sky subtracted image, but keep
        # it in memory.  (Reduce I/O.)  Will require SFFT changes.
        # (This may not be practical, as it will increase memory usage a *lot*.
        # We may still need to write files.)
        try:
            imname = self.image.name
            # HACK ALERT : we're stripping the .gz off of the end of filenames
            # if they have them, and making sure filenames end in .fits,
            # because that's what SFFT needs.  This can go away if we
            # refactor to pass data.
            if imname[-3:] == '.gz':
                imname = imname[:-3]
            if imname[-5:] != '.fits':
                imname = f'{imname}.fits'

            # Hey Rob--the below is broken.  The pipeline runs with mp if I use the
            # example science image file from examples/perlmutter, but not if I use
            # my own file.  If I run with one process, my own file runs.
            # if mp:
            #     SNLogger.multiprocessing_replace()

            SNLogger.debug( f"run_sky_subtract on {imname}" )
            self.skysub_path = self.save_dir / f"skysub_{imname}"
            self.detmask_path = self.save_dir / f"detmask_{imname}"
            self.skyrms = sky_subtract( self.image.path, self.skysub_path, self.detmask_path,
                                        temp_dir=self.save_dir,
                                        force=self.config.value( 'photometry.phrosty.force_sky_subtract' ) )
            SNLogger.debug( f"...done running sky subtraction on {self.image.name}" )
            return ( self.skysub_path, self.detmask_path, self.skyrms )
        except Exception as ex:
            SNLogger.exception( ex )
            raise

    def save_sky_subtract_info( self, info ):
        SNLogger.debug( f"Saving sky_subtract info for path {info[0]}" )
        self.skysub_path = info[0]
        self.detmask_path = info[1]
        self.skyrms = info[2]
    def get_psf( self, ra, dec ):
        """Get the PSF at the right spot on the image.

        Parameters
        ----------
        ra, dec : float
            The coordinates in decimal degrees where we want the PSF.

        """
        # TODO: right now snappl.psf.PSF.get_psf_object just
        # passes the keyword arguments on to whatever makes
        # the psf... and it's different for each type of
        # PSF.  We need to fix that... somehow....
        wcs = self.image.get_wcs()
        x, y = wcs.world_to_pixel( ra, dec )
        if self.psfobj is None:
            psftype = self.config.value( 'photometry.phrosty.psf.type' )
            self.psfobj = PSF.get_psf_object( psftype, band=self.band, pointing=self.pointing,
                                              sca=self.image.sca, x=x, y=y )
        stamp = self.psfobj.get_stamp( x, y )
        if self.keep_intermediate:
            outfile = self.save_dir / f"psf_{self.image.name}.fits"
            fits.writeto( outfile, stamp, overwrite=True )
        return stamp
    def keep_psf_data( self, psf_data ):
        self.psf_data = psf_data
    def free( self ):
        """Try to free memory.  More might be done here."""
        self.image.free()
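
# NOTE: The block below is an illustrative sketch (not part of the module) of
# how the pipeline uses a PipelineImage: it wraps one snappl image plus the
# per-image intermediates (sky-subtracted image, detection mask, PSF stamp).
# The path, pointing, sca, and coordinates are hypothetical placeholders.
#
#     pipeimg = PipelineImage( "path/to/Roman_TDS_simple_model_R062_100_7.fits.gz",
#                              pointing=100, sca=7, pipeline=pipeline )
#     skysub_path, detmask_path, skyrms = pipeimg.run_sky_subtract( mp=False )
#     pipeimg.save_sky_subtract_info( ( skysub_path, detmask_path, skyrms ) )
#     pipeimg.keep_psf_data( pipeimg.get_psf( 7.55, -44.80 ) )   # PSF stamp at the transient position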
class Pipeline:
    """Phrosty's top-level pipeline."""

    def __init__( self, object_id, ra, dec, band, science_images, template_images,
                  nprocs=1, nwrite=5, verbose=False ):
        """Create a pipeline object.

        Parameters
        ----------
        object_id : int

        ra, dec : float
            Position of transient in decimal degrees

        band : str
            One of R062, Z087, Y106, J129, H158, F184, K213

        science_images : list of tuple
            ( path_to_image, pointing, sca, mjd, band )

        template_images : list of tuple
            ( path_to_image, pointing, sca, mjd, band )

        nprocs : int, default 1
            Number of cpus for the CPU multiprocessing segments of the pipeline.
            (GPU segments will run a single process.)

        nwrite : int, default 5
            Number of asynchronous FITS writer processes.

        verbose : bool, default False
            If True, log at DEBUG level instead of INFO.

        """
        SNLogger.setLevel( logging.DEBUG if verbose else logging.INFO )

        self.config = Config.get()
        self.tds_base_dir = pathlib.Path( self.config.value( 'ou24.tds_base' ) )
        self.image_base_dir = self.tds_base_dir / 'images'
        self.dia_out_dir = pathlib.Path( self.config.value( 'photometry.phrosty.paths.dia_out_dir' ) )
        self.scratch_dir = pathlib.Path( self.config.value( 'photometry.phrosty.paths.scratch_dir' ) )
        self.temp_dir_parent = pathlib.Path( self.config.value( 'photometry.phrosty.paths.temp_dir' ) )
        self.temp_dir = self.temp_dir_parent / str(uuid.uuid1())
        self.temp_dir.mkdir()
        self.ltcv_dir = pathlib.Path( self.config.value( 'photometry.phrosty.paths.ltcv_dir' ) )

        self.object_id = object_id
        self.ra = ra
        self.dec = dec
        self.band = band

        self.science_images = [ PipelineImage( self.image_base_dir / ppsmb[0], ppsmb[1], ppsmb[2], self )
                                for ppsmb in science_images if ppsmb[4] == self.band ]
        self.template_images = [ PipelineImage( self.image_base_dir / ppsmb[0], ppsmb[1], ppsmb[2], self )
                                 for ppsmb in template_images if ppsmb[4] == self.band ]

        self.nprocs = nprocs
        self.nwrite = nwrite

        self.keep_intermediate = self.config.value( 'photometry.phrosty.keep_intermediate' )
        self.remove_temp_dir = self.config.value( 'photometry.phrosty.remove_temp_dir' )
        self.mem_trace = self.config.value( 'photometry.phrosty.mem_trace' )

    def sky_sub_all_images( self ):
        # Currently, this writes out a bunch of FITS files.  Further refactoring needed
        # to support more general image types.
        all_imgs = self.science_images.copy()    # shallow copy
        all_imgs.extend( self.template_images )

        def log_error( img, x ):
            SNLogger.error( f"Sky subtraction subprocess failure: {x} for image {img.image.path}" )

        if self.nprocs > 1:
            with Pool( self.nprocs ) as pool:
                for img in all_imgs:
                    pool.apply_async( img.run_sky_subtract, (), {},
                                      callback=img.save_sky_subtract_info,
                                      error_callback=partial(log_error, img) )
                pool.close()
                pool.join()
        else:
            for img in all_imgs:
                img.save_sky_subtract_info( img.run_sky_subtract( mp=False ) )

    def get_psfs( self ):
        all_imgs = self.science_images.copy()    # shallow copy
        all_imgs.extend( self.template_images )

        if self.nprocs > 1:
            with Pool( self.nprocs ) as pool:
                for img in all_imgs:
                    # callback_partial = partial( img.save_psf_path, all_imgs )
                    pool.apply_async( img.get_psf, (self.ra, self.dec), {},
                                      img.keep_psf_data,
                                      lambda x: SNLogger.error( f"get_psf subprocess failure: {x}" ) )
                pool.close()
                pool.join()
        else:
            for img in all_imgs:
                img.keep_psf_data( img.get_psf(self.ra, self.dec) )
    def align_and_pre_convolve( self, templ_image, sci_image ):
        """Align and pre-convolve a single template/science pair.

        Parameters
        ----------
        sci_image : phrosty.pipeline.PipelineImage
            The science (new) image.

        templ_image : phrosty.pipeline.PipelineImage
            The template (ref) image that will be subtracted from sci_image.

        """
        with fits.open( sci_image.skysub_path ) as hdul:
            hdr_sci = hdul[0].header
            data_sci = cp.array( np.ascontiguousarray(hdul[0].data.T), dtype=cp.float64 )

        with fits.open( templ_image.skysub_path ) as hdul:
            hdr_templ = hdul[0].header
            data_templ = cp.array( np.ascontiguousarray(hdul[0].data.T), dtype=cp.float64 )

        sci_psf = cp.ascontiguousarray( cp.array( sci_image.psf_data.T, dtype=cp.float64 ) )
        templ_psf = cp.ascontiguousarray( cp.array( templ_image.psf_data.T, dtype=cp.float64 ) )

        with fits.open( sci_image.detmask_path ) as hdul:
            sci_detmask = cp.array( np.ascontiguousarray( hdul[0].data.T ) )

        with fits.open( templ_image.detmask_path ) as hdul:
            templ_detmask = cp.array( np.ascontiguousarray( hdul[0].data.T ) )

        sfftifier = SpaceSFFT_CupyFlow(
            hdr_sci, hdr_templ,
            sci_image.skyrms, templ_image.skyrms,
            data_sci, data_templ,
            cp.array( sci_image.image.noise ), cp.array( templ_image.image.noise ),
            sci_detmask, templ_detmask,
            sci_psf, templ_psf,
            KerPolyOrder=Config.get().value('photometry.phrosty.kerpolyorder')
        )
        sfftifier.resampling_image_mask_psf()
        sfftifier.cross_convolution()

        return sfftifier
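
    # The SFFT flow for one science/template pair, as driven by __call__ below,
    # looks roughly like this (illustrative sketch only; __call__ is the real
    # sequence and also handles NVTX ranges, logging, and file output):
    #
    #     sfftifier = self.align_and_pre_convolve( templ_image, sci_image )
    #     sfftifier.sfft_subtraction()                   # build the difference image
    #     sfftifier.find_decorrelation()                 # solve for the decorrelation kernel
    #     diff_var = sfftifier.create_variance_image()
    #     decorr_diff = sfftifier.apply_decorrelation( sfftifier.PixA_DIFF_GPU )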
    def phot_at_coords( self, img, err, psf, pxcoords=(50, 50), ap_r=4 ):
        """Do photometry at a forced set of pixel coordinates."""
        forcecoords = Table([[float(pxcoords[0])], [float(pxcoords[1])]], names=["x", "y"])
        init = ap_phot(img, forcecoords, ap_r=ap_r)
        init['flux_init'] = init['aperture_sum']
        final = psf_phot(img, err, psf, init, forced_phot=True)

        flux = final['flux_fit'][0]
        flux_err = final['flux_err'][0]
        mag = -2.5 * np.log10(final["flux_fit"][0])
        mag_err = (2.5 / np.log(10)) * np.abs(final["flux_err"][0] / final["flux_fit"][0])

        results_dict = {
            'aperture_sum': init['aperture_sum'][0],
            'flux_fit': flux,
            'flux_fit_err': flux_err,
            'mag_fit': mag,
            'mag_fit_err': mag_err
        }

        return results_dict
    def get_stars(self, truthpath, nx=4088, ny=4088, transform=False, wcs=None):
        """Get the stars in the science images.  Optionally transform them to another WCS."""
        truth_tab = read_truth_txt(path=truthpath)
        truth_tab['mag'].name = 'mag_truth'
        truth_tab['flux'].name = 'flux_truth'

        if transform:
            assert wcs is not None, 'You need to provide a WCS to transform to!'
            truth_tab['x'].name, truth_tab['y'].name = 'x_orig', 'y_orig'
            worldcoords = SkyCoord(ra=truth_tab['ra'] * u.deg, dec=truth_tab['dec'] * u.deg)
            x, y = skycoord_to_pixel(worldcoords, wcs)
            truth_tab['x'] = x
            truth_tab['y'] = y

        if not transform:
            truth_tab['x'] -= 1
            truth_tab['y'] -= 1

        idx = np.where(truth_tab['obj_type'] == 'star')[0]
        stars = truth_tab[idx]
        stars = stars[np.logical_and(stars["x"] < nx, stars["x"] > 0)]
        stars = stars[np.logical_and(stars["y"] < ny, stars["y"] > 0)]

        return stars
    def get_galsim_values(self):
        exptime = get_exptime(self.band)
        area_eff = roman.collecting_area
        gs_zpt = roman.getBandpasses()[self.band].zeropoint

        return {'exptime': exptime, 'area_eff': area_eff, 'gs_zpt': gs_zpt}
    def get_zpt(self, zptimg, err, psf, band, stars, ap_r=4, zpt_plot=None,
                oid=None, sci_pointing=None, sci_sca=None):
        # TODO : Need to move this code all over into snappl Image.  It sounds like
        #   for Roman images we may have to do our own zeropoints (which is what
        #   is happening here), but for actual Roman we're going to use
        #   calibration information we get from elsewhere, so we don't want to bake
        #   doing the calibration into the pipeline as we do here.
        #
        # Also Issue #70
        """Get the zeropoint based on the stars."""

        # First, need to do photometry on the stars.
        init_params = ap_phot(zptimg, stars, ap_r=ap_r)
        init_params['flux_init'] = init_params['aperture_sum']
        final_params = psf_phot(zptimg, err, psf, init_params, forced_phot=True)

        # Do not need to cross match.  Can just merge tables because they
        # will be in the same order.
        photres = astropy.table.join(stars, init_params,
                                     keys=['object_id', 'ra', 'dec', 'realized_flux',
                                           'flux_truth', 'mag_truth', 'obj_type'])
        photres = astropy.table.join(photres, final_params, keys=['id'])

        # Get the zero point.
        galsim_vals = self.get_galsim_values()
        star_ap_mags = -2.5 * np.log10(photres['aperture_sum'])
        star_fit_mags = -2.5 * np.log10(photres['flux_fit'])
        star_truth_mags = ( -2.5 * np.log10(photres['flux_truth']) + galsim_vals['gs_zpt']
                            + 2.5 * np.log10(galsim_vals['exptime'] * galsim_vals['area_eff']) )

        # Eventually, this should be a S/N cut, not a mag cut.
        zpt_mask = np.logical_and(star_truth_mags > 19, star_truth_mags < 21.5)
        zpt = np.nanmedian(star_truth_mags[zpt_mask] - star_fit_mags[zpt_mask])
        ap_zpt = np.nanmedian(star_truth_mags[zpt_mask] - star_ap_mags[zpt_mask])

        if zpt_plot is not None:
            assert oid is not None, 'If zpt_plot=True, oid must be provided.'
            assert sci_pointing is not None, 'If zpt_plot=True, sci_pointing must be provided.'
            assert sci_sca is not None, 'If zpt_plot=True, sci_sca must be provided.'

            savedir = self.ltcv_dir / f'figs/{oid}/zpt_plots'
            savedir.mkdir(parents=True, exist_ok=True)
            savepath = savedir / f'zpt_stars_{band}_{sci_pointing}_{sci_sca}.png'

            plt.figure(figsize=(8, 8))
            yaxis = star_fit_mags + zpt - star_truth_mags
            plt.plot(star_truth_mags, yaxis, marker='o', linestyle='')
            plt.axhline(0, linestyle='--', color='k')
            plt.xlabel('Truth mag')
            plt.ylabel('Fit mag + zpt - truth mag')
            plt.title(f'{band} {sci_pointing} {sci_sca}')
            plt.savefig(savepath, dpi=300, bbox_inches='tight')
            plt.close()

            SNLogger.info(f'zpt debug plot saved to {savepath}')

            # savepath = os.path.join(savedir, f'hist_truth-fit_{band}_{sci_pointing}_{sci_sca}.png')
            # plt.hist(star_truth_mags[zpt_mask] - star_fit_mags[zpt_mask])
            # plt.title(f'{band} {sci_pointing} {sci_sca}')
            # plt.xlabel('star_truth_mags[zpt_mask] - star_fit_mags[zpt_mask]')
            # plt.savefig(savepath, dpi=300, bbox_inches='tight')
            # plt.close()

        return zpt, ap_zpt
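
    # Illustrative note (not part of the pipeline logic): zpt is the median of
    # (truth mag - fit mag) for the selected stars, so a calibrated magnitude
    # for a fitted flux measured on the same image would be
    #
    #     mag_calib = -2.5 * np.log10(flux_fit) + zpt
    #
    # The lightcurve CSV written by make_lightcurve() stores 'mag_fit' (the
    # instrumental -2.5*log10(flux_fit)) and 'zpt' as separate columns, leaving
    # that sum to downstream analysis.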
    def make_phot_info_dict( self, sci_image, templ_image, ap_r=4 ):
        # Do photometry on the stamp because it will read faster
        diff_img_stamp_path = sci_image.diff_stamp_path[ templ_image.image.name ]
        diff_img_var_stamp_path = sci_image.diff_var_stamp_path[ templ_image.image.name ]

        results_dict = {}
        results_dict['sci_name'] = sci_image.image.name
        results_dict['templ_name'] = templ_image.image.name
        results_dict['success'] = False
        results_dict['ra'] = self.ra
        results_dict['dec'] = self.dec
        results_dict['mjd'] = sci_image.image.mjd
        results_dict['filter'] = self.band
        results_dict['pointing'] = sci_image.pointing
        results_dict['sca'] = sci_image.image.sca
        results_dict['template_pointing'] = templ_image.pointing
        results_dict['template_sca'] = templ_image.image.sca

        if diff_img_stamp_path.is_file():
            # Load in the difference image stamp.
            with fits.open(diff_img_stamp_path) as diff_hdu:
                diffimg = diff_hdu[0].data
                wcs = WCS(diff_hdu[0].header)
            with fits.open( diff_img_var_stamp_path ) as var_hdu:
                err = np.sqrt( var_hdu[0].data )

            # Load in the decorrelated PSF.
            psfpath = sci_image.decorr_psf_path[ templ_image.image.name ]
            with fits.open( psfpath ) as hdu:
                psf = psfmodel( hdu[0].data )

            coord = SkyCoord(ra=self.ra * u.deg, dec=self.dec * u.deg)
            pxcoords = skycoord_to_pixel(coord, wcs)
            results_dict.update( self.phot_at_coords(diffimg, err, psf, pxcoords=pxcoords, ap_r=ap_r) )

            # Get the zero point from the decorrelated, convolved science image.
            # First, get the table of known stars.
            # TODO -- take this galsim-specific code out, move it to a separate module.  Define a general
            #   zeropointing interface, of which the galsim-specific one will be one instance.
            truthpath = str( self.tds_base_dir / f'truth/{self.band}/{sci_image.pointing}/'
                             f'Roman_TDS_index_{self.band}_{sci_image.pointing}_{sci_image.image.sca}.txt' )
            stars = self.get_stars(truthpath)

            # Now, calculate the zero point based on those stars.
            zptimg_path = sci_image.decorr_zptimg_path[ templ_image.image.name ]
            with fits.open(zptimg_path) as hdu:
                zptimg = hdu[0].data
            zpt, ap_zpt = self.get_zpt(zptimg, sci_image.image.noise, psf, self.band, stars,
                                       oid=self.object_id, sci_pointing=sci_image.pointing,
                                       sci_sca=sci_image.image.sca)

            # Add additional info to the results dictionary so it can be merged into a nice file later.
            results_dict['zpt'] = zpt
            results_dict['ap_zpt'] = ap_zpt
            results_dict['success'] = True
        else:
            SNLogger.warning( f"Post-processed image files for "
                              f"{self.band}_{sci_image.pointing}_{sci_image.image.sca}-"
                              f"{self.band}_{templ_image.pointing}_{templ_image.image.sca} "
                              f"do not exist.  Skipping." )
            results_dict['zpt'] = np.nan
            results_dict['ap_zpt'] = np.nan
            results_dict['aperture_sum'] = np.nan
            results_dict['flux_fit'] = np.nan
            results_dict['flux_fit_err'] = np.nan
            results_dict['mag_fit'] = np.nan
            results_dict['mag_fit_err'] = np.nan

        return results_dict

    def add_to_results_dict( self, one_pair ):
        for key, arr in self.results_dict.items():
            arr.append( one_pair[ key ] )

    def save_stamp_paths( self, sci_image, templ_image, paths ):
        sci_image.zpt_stamp_path[ templ_image.image.name ] = paths[0]
        sci_image.diff_stamp_path[ templ_image.image.name ] = paths[1]
        sci_image.diff_var_stamp_path[ templ_image.image.name ] = paths[2]

    def do_stamps( self, sci_image, templ_image ):
        zptname = sci_image.decorr_zptimg_path[ templ_image.image.name ]
        zpt_stampname = stampmaker( self.ra, self.dec, np.array([100, 100]), zptname,
                                    savedir=self.dia_out_dir, savename=f"stamp_{zptname.name}" )

        diffname = sci_image.decorr_diff_path[ templ_image.image.name ]
        diff_stampname = stampmaker( self.ra, self.dec, np.array([100, 100]), diffname,
                                     savedir=self.dia_out_dir, savename=f"stamp_{diffname.name}" )

        diffvarname = sci_image.diff_var_path[ templ_image.image.name ]
        diffvar_stampname = stampmaker( self.ra, self.dec, np.array([100, 100]), diffvarname,
                                        savedir=self.dia_out_dir, savename=f"stamp_{diffvarname.name}" )

        SNLogger.info(f"Decorrelated diff stamp path: {pathlib.Path( diff_stampname )}")
        SNLogger.info(f"Zpt image stamp path: {pathlib.Path( zpt_stampname )}")
        SNLogger.info(f"Decorrelated diff variance stamp path: {pathlib.Path( diffvar_stampname )}")

        return pathlib.Path( zpt_stampname ), pathlib.Path( diff_stampname ), pathlib.Path( diffvar_stampname )

    def make_lightcurve( self ):
        SNLogger.info( "Making lightcurve." )

        self.results_dict = {
            'ra': [],
            'dec': [],
            'mjd': [],
            'filter': [],
            'pointing': [],
            'sca': [],
            'template_pointing': [],
            'template_sca': [],
            'zpt': [],
            'ap_zpt': [],
            'aperture_sum': [],
            'flux_fit': [],
            'flux_fit_err': [],
            'mag_fit': [],
            'mag_fit_err': [],
        }

        if self.nprocs > 1:
            with Pool( self.nprocs ) as pool:
                for sci_image in self.science_images:
                    for templ_image in self.template_images:
                        pool.apply_async( self.make_phot_info_dict, (sci_image, templ_image), {},
                                          self.add_to_results_dict,
                                          error_callback=lambda x: SNLogger.error( f"make_phot_info_dict "
                                                                                   f"subprocess failure: {x}" ) )
                pool.close()
                pool.join()
        else:
            for i, sci_image in enumerate( self.science_images ):
                SNLogger.debug( f"Doing science image {i} of {len(self.science_images)}" )
                for templ_image in self.template_images:
                    self.add_to_results_dict( self.make_phot_info_dict( sci_image, templ_image ) )

        results_tab = Table(self.results_dict)
        results_tab.sort('mjd')
        results_savedir = self.ltcv_dir / 'data' / str(self.object_id)
        results_savedir.mkdir( exist_ok=True, parents=True )
        results_savepath = results_savedir / f'{self.object_id}_{self.band}_all.csv'
        results_tab.write(results_savepath, format='csv', overwrite=True)

        SNLogger.info(f'Results saved to {results_savepath}')

    def write_fits_file( self, data, header, savepath ):
        fits.writeto( savepath, data, header=header, overwrite=True )

    def clear_contents( self, directory ):
        for f in directory.iterdir():
            try:
                if f.is_dir():
                    shutil.rmtree( f )
                else:
                    f.unlink()
            except Exception as e:
                print( f'Oops! Deleting {f} from {directory} did not work.\nReason: {e}' )

    def __call__( self, through_step=None ):
        if self.mem_trace:
            tracemalloc.start()
            tracemalloc.reset_peak()

        if through_step is None:
            through_step = 'make_lightcurve'

        steps = [ 'sky_subtract', 'get_psfs', 'align_and_preconvolve', 'subtract', 'find_decorrelation',
                  'apply_decorrelation', 'make_stamps', 'make_lightcurve' ]
        stepdex = steps.index( through_step )
        if stepdex < 0:
            raise ValueError( f"Unknown step {through_step}" )
        steps = steps[:stepdex+1]

        if 'sky_subtract' in steps:
            SNLogger.info( "Running sky subtraction" )
            with nvtx.annotate( "skysub", color=0xff8888 ):
                self.sky_sub_all_images()
            if self.mem_trace:
                SNLogger.info( f"After sky_subtract, memory usage = "
                               f"{tracemalloc.get_traced_memory()[1]/(1024**2):.2f} MB" )

        if 'get_psfs' in steps:
            SNLogger.info( "Getting PSFs" )
            with nvtx.annotate( "getpsfs", color=0xff8888 ):
                self.get_psfs()
            if self.mem_trace:
                SNLogger.info( f"After get_psfs, memory usage = "
                               f"{tracemalloc.get_traced_memory()[1]/(1024**2):.2f} MB" )

        # Create a process pool to write fits files
        with Pool( self.nwrite ) as fits_writer_pool:

            def log_fits_write_error( savepath, x ):
                SNLogger.error( f"Exception writing FITS file {savepath}: {x}" )
                # raise?

            # Do the hardcore processing
            for templ_image in self.template_images:
                for sci_image in self.science_images:
                    SNLogger.info( f"Processing {sci_image.image.name} minus {templ_image.image.name}" )
                    sfftifier = None

                    if 'align_and_preconvolve' in steps:
                        SNLogger.info( "...align_and_preconvolve" )
                        with nvtx.annotate( "align_and_pre_convolve", color=0x8888ff ):
                            sfftifier = self.align_and_pre_convolve( templ_image, sci_image )

                    if 'subtract' in steps:
                        SNLogger.info( "...subtract" )
                        with nvtx.annotate( "subtraction", color=0x44ccff ):
                            sfftifier.sfft_subtraction()

                    if 'find_decorrelation' in steps:
                        SNLogger.info( "...find_decorrelation" )
                        with nvtx.annotate( "find_decor", color=0xcc44ff ):
                            sfftifier.find_decorrelation()

                        SNLogger.info( "...generate variance image" )
                        with nvtx.annotate( "variance", color=0x44ccff ):
                            diff_var = sfftifier.create_variance_image()
                        mess = ( f"{self.band}_{sci_image.pointing}_{sci_image.image.sca}_-"
                                 f"_{self.band}_{templ_image.pointing}_{templ_image.image.sca}.fits" )
                        diff_var_path = self.dia_out_dir / f"diff_var_{mess}"

                    if 'apply_decorrelation' in steps:
                        mess = ( f"{self.band}_{sci_image.pointing}_{sci_image.image.sca}_-"
                                 f"_{self.band}_{templ_image.pointing}_{templ_image.image.sca}.fits" )
                        decorr_psf_path = self.dia_out_dir / f"decorr_psf_{mess}"
                        decorr_zptimg_path = self.dia_out_dir / f"decorr_zptimg_{mess}"
                        decorr_diff_path = self.dia_out_dir / f"decorr_diff_{mess}"

                        images = [ sfftifier.PixA_DIFF_GPU, diff_var,
                                   sfftifier.PixA_Ctarget_GPU, sfftifier.PSF_target_GPU ]
                        savepaths = [ decorr_diff_path, diff_var_path,
                                      decorr_zptimg_path, decorr_psf_path ]
                        headers = [ sfftifier.hdr_target, sfftifier.hdr_target,
                                    sfftifier.hdr_target, None ]

                        for img, savepath, hdr in zip( images, savepaths, headers ):
                            with nvtx.annotate( "apply_decor", color=0xccccff ):
                                SNLogger.info( f"...apply_decor to {savepath}" )
                                decorimg = sfftifier.apply_decorrelation( img )
                            with nvtx.annotate( "submit writefits", color=0xff8888 ):
                                SNLogger.info( f"...writefits {savepath}" )
                                fits_writer_pool.apply_async( self.write_fits_file,
                                                              ( cp.asnumpy( decorimg ).T, hdr, savepath ), {},
                                                              error_callback=partial(log_fits_write_error, savepath) )

                        sci_image.decorr_psf_path[ templ_image.image.name ] = decorr_psf_path
                        sci_image.decorr_zptimg_path[ templ_image.image.name ] = decorr_zptimg_path
                        sci_image.decorr_diff_path[ templ_image.image.name ] = decorr_diff_path
                        sci_image.diff_var_path[ templ_image.image.name ] = diff_var_path

                        if self.keep_intermediate:
                            # Each key is the file prefix addition.
                            # Each list has [descriptive filetype, image file name, data, header].
                            # TODO: Include multiprocessing.
                            # In the future, we may want to write these things right after they happen
                            # instead of saving it all for the end of the SFFT stuff.
                            sci_filepathpart = f'{sci_image.band}_{sci_image.pointing}_{sci_image.image.sca}'
                            templ_filepathpart = f'{templ_image.band}_{templ_image.pointing}_{templ_image.image.sca}'

                            write_filepaths = {
                                'aligned': [ ['img', f'{templ_filepathpart}_-_{sci_filepathpart}.fits',
                                              cp.asnumpy(sfftifier.PixA_resamp_object_GPU.T), sfftifier.hdr_target],
                                             ['var', f'{templ_filepathpart}_-_{sci_filepathpart}.fits',
                                              cp.asnumpy(sfftifier.PixA_resamp_objectVar_GPU.T), sfftifier.hdr_target],
                                             ['psf', f'{templ_filepathpart}_-_{sci_filepathpart}.fits',
                                              cp.asnumpy(sfftifier.PSF_resamp_object_GPU.T), sfftifier.hdr_target],
                                             ['detmask', f'{sci_filepathpart}_-_{templ_filepathpart}.fits',
                                              cp.asnumpy(sfftifier.PixA_resamp_object_DMASK_GPU.T), sfftifier.hdr_target]
                                             ],
                                'convolved': [ ['img', f'{sci_filepathpart}_-_{templ_filepathpart}.fits',
                                                cp.asnumpy(sfftifier.PixA_Ctarget_GPU.T), sfftifier.hdr_target],
                                               ['img', f'{templ_filepathpart}_-_{sci_filepathpart}.fits',
                                                cp.asnumpy(sfftifier.PixA_Cresamp_object_GPU.T), sfftifier.hdr_target]
                                               ],
                                'diff': [ ['img', f'{sci_filepathpart}_-_{templ_filepathpart}.fits',
                                           cp.asnumpy(sfftifier.PixA_DIFF_GPU.T), sfftifier.hdr_target]
                                          ],
                                'decorr': [ ['kernel', f'{sci_filepathpart}_-_{templ_filepathpart}.fits',
                                             cp.asnumpy(sfftifier.FKDECO_GPU.T), sfftifier.hdr_target]
                                            ]
                            }

                            # Write the aligned images
                            for key in write_filepaths.keys():
                                for (imgtype, name, data, header) in write_filepaths[key]:
                                    savepath = self.scratch_dir / f'{key}_{imgtype}_{name}'
                                    self.write_fits_file( data, header, savepath=savepath )

                    SNLogger.info( f"DONE processing {sci_image.image.name} minus {templ_image.image.name}" )
                    if self.mem_trace:
                        SNLogger.info( f"After preprocessing, subtracting, and postprocessing "
                                       f"a science image, memory usage = "
                                       f"{tracemalloc.get_traced_memory()[1]/(1024**2):.2f} MB" )

                    sci_image.free()

                SNLogger.info( f"DONE with all science images for template {templ_image.image.name}" )
                templ_image.free()

            SNLogger.info( "Waiting for FITS writer processes to finish" )
            with nvtx.annotate( "fits_write_wait", color=0xff8888 ):
                fits_writer_pool.close()
                fits_writer_pool.join()
            SNLogger.info( "...FITS writer processes done." )

        if 'make_stamps' in steps:
            SNLogger.info( "Starting to make stamps..." )
            with nvtx.annotate( "make stamps", color=0xff8888 ):
                partialstamp = partial(stampmaker, self.ra, self.dec, np.array([100, 100]))
                # template path, savedir, savename
                templstamp_args = ( (ti.image.path, self.dia_out_dir, f'stamp_{str(ti.image.name)}')
                                    for ti in self.template_images )

                if self.nwrite > 1:
                    with Pool( self.nwrite ) as templ_stamp_pool:
                        templ_stamp_pool.starmap_async( partialstamp, templstamp_args )
                        templ_stamp_pool.close()
                        templ_stamp_pool.join()

                    with Pool( self.nwrite ) as sci_stamp_pool:
                        for sci_image in self.science_images:
                            for templ_image in self.template_images:
                                pair = (sci_image, templ_image)
                                sci_stamp_pool.apply_async( self.do_stamps, pair, {},
                                                            callback=partial(self.save_stamp_paths,
                                                                             sci_image, templ_image),
                                                            error_callback=partial(SNLogger.error,
                                                                                   "do_stamps subprocess failure: {x}") )
                        sci_stamp_pool.close()
                        sci_stamp_pool.join()
                else:
                    for tsargs in templstamp_args:
                        partialstamp(*tsargs)
                    for sci_image in self.science_images:
                        for templ_image in self.template_images:
                            stamp_paths = self.do_stamps( sci_image, templ_image )
                            self.save_stamp_paths( sci_image, templ_image, stamp_paths )

            SNLogger.info('...finished making stamps.')

            if self.mem_trace:
                SNLogger.info( f"After make_stamps, memory usage = "
                               f"{tracemalloc.get_traced_memory()[1]/(1024**2):.2f} MB" )

        if 'make_lightcurve' in steps:
            SNLogger.info( "Making lightcurve" )
            with nvtx.annotate( "make_lightcurve", color=0xff8888 ):
                self.make_lightcurve()
            if self.mem_trace:
                SNLogger.info( f"After make_lightcurve, memory usage = "
                               f"{tracemalloc.get_traced_memory()[1]/(1024**2):.2f} MB" )

        if self.remove_temp_dir:
            self.clear_contents( self.temp_dir )
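
# Illustrative (hypothetical) programmatic use of the Pipeline class, mirroring
# what main() below does after parsing its arguments.  The image tuples follow
# the ( path, pointing, sca, mjd, band ) layout that main() reads from the list
# files; the paths, pointings, MJDs, and object values here are placeholders.
#
#     science_images = [ ( pathlib.Path("R062/100/Roman_TDS_simple_model_R062_100_7.fits.gz"),
#                          100, 7, 62000.0, "R062" ) ]
#     template_images = [ ( pathlib.Path("R062/6/Roman_TDS_simple_model_R062_6_17.fits.gz"),
#                           6, 17, 61444.0, "R062" ) ]
#     pipeline = Pipeline( 12345, 7.55, -44.80, "R062",
#                          science_images, template_images, nprocs=4, nwrite=5 )
#     pipeline( "make_lightcurve" )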
# ======================================================================

def main():
    # Run one arg pass just to get the config file, so we can augment
    #   the full arg parser later with config options
    configparser = argparse.ArgumentParser( add_help=False )
    configparser.add_argument( '-c', '--config-file', default=None,
                               help=( "Location of the .yaml config file; defaults to the value of the "
                                      "SNPIT_CONFIG environment variable." ) )
    args, leftovers = configparser.parse_known_args()

    try:
        cfg = Config.get( args.config_file, setdefault=True )
    except RuntimeError as e:
        if str(e) == 'No default config defined yet; run Config.init(configfile)':
            sys.stderr.write( "Error, no configuration file defined.\n"
                              "Either run phrosty with -c <configfile>\n"
                              "or set the SNPIT_CONFIG environment variable.\n" )
            sys.exit(1)
        else:
            raise

    parser = argparse.ArgumentParser()
    # Put in the config_file argument, even though it will never be found, so it shows up in help
    parser.add_argument( '-c', '--config-file', help="Location of the .yaml config file" )
    parser.add_argument( '--oid', type=int, required=True, help="Object ID" )
    parser.add_argument( '-r', '--ra', type=float, required=True, help="Object RA" )
    parser.add_argument( '-d', '--dec', type=float, required=True, help="Object Dec" )
    parser.add_argument( '-b', '--band', type=str, required=True,
                         help="Band: R062, Z087, Y106, J129, H158, F184, or K213" )
    parser.add_argument( '-t', '--template-images', type=str, required=True,
                         help="Path to file with, per line, ( path_to_image, pointing, sca )" )
    parser.add_argument( '-s', '--science-images', type=str, required=True,
                         help="Path to file with, per line, ( path_to_image, pointing, sca )" )
    parser.add_argument( '-p', '--nprocs', type=int, default=1,
                         help="Number of processes for multiprocessing steps (e.g. skysub)" )
    parser.add_argument( '-w', '--nwrite', type=int, default=5,
                         help="Number of parallel FITS writing processes" )
    parser.add_argument( '-v', '--verbose', action='store_true', default=False,
                         help="Show debug log info" )
    parser.add_argument( '--through-step', default='make_lightcurve',
                         help="Stop after this step; one of (see above)" )
    cfg.augment_argparse( parser )
    args = parser.parse_args( leftovers )
    cfg.parse_args( args )

    science_images = []
    template_images = []
    for infile, imlist in zip( [ args.science_images, args.template_images ],
                               [ science_images, template_images ] ):
        with open( infile ) as ifp:
            hdrline = ifp.readline()
            if not re.search( r"^\s*path\s+pointing\s+sca\s+mjd\s+band\s*$", hdrline ):
                raise ValueError( f"First line of list file {infile} didn't match what was expected." )
            for line in ifp:
                img, point, sca, mjd, band = line.split()
                imlist.append( ( pathlib.Path(img), int(point), int(sca), float(mjd), band ) )

    pipeline = Pipeline( args.oid, args.ra, args.dec, args.band, science_images, template_images,
                         nprocs=args.nprocs, nwrite=args.nwrite, verbose=args.verbose )
    pipeline( args.through_step )


# ======================================================================
if __name__ == "__main__":
    main()
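
# The list files passed via --science-images and --template-images are plain
# whitespace-delimited text.  A minimal (hypothetical) example consistent with
# the header check and the parsing in main() above:
#
#     path                                                pointing  sca  mjd      band
#     R062/100/Roman_TDS_simple_model_R062_100_7.fits.gz  100       7    62000.0  R062
#     R062/101/Roman_TDS_simple_model_R062_101_7.fits.gz  101       7    62005.0  R062
#
# Each path is interpreted relative to <ou24.tds_base>/images when the Pipeline
# constructs its PipelineImage objects.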