Source code for paircars.utils.image_utils

import numpy as np
import traceback
import warnings
import copy
import os
from collections import defaultdict
from astropy.io import fits
from astropy.wcs import WCS
from astropy.wcs import FITSFixedWarning
from .basic_utils import average_timestamp, timestamp_to_mjdsec
from .udocker_utils import run_wsclean

# Ignore astropy's FITSFixedWarning, which astropy.wcs emits when it
# auto-corrects non-standard FITS header values. warnings.simplefilter installs
# this filter for the whole process as a side effect of importing this module.
warnings.simplefilter("ignore", category=FITSFixedWarning)


##########################
# Image analysis related
##########################
def create_circular_mask(msname, cellsize, imsize, mask_radius=20):
    """
    Create fits solar mask

    WSClean is run once on ``msname`` (no cleaning, first channel and first
    time interval only) to obtain a FITS image on the requested grid. That
    image is used as a template: pixels of its first frequency/Stokes plane
    lying within ``mask_radius`` of the image centre are set to 1 and all
    other pixels to 0. The mask is written next to the measurement set, and
    every WSClean product of the temporary ``<ms-stem>_solar`` prefix is
    deleted.

    Parameters
    ----------
    msname : str
        Name of the measurement set
    cellsize : float
        Cell size in arcsec
    imsize : int
        Imsize in number of pixels
    mask_radius : float
        Mask radius in arcmin

    Returns
    -------
    str or None
        Fits mask file name (``<ms-stem>_solar-mask.fits``), or None if the
        mask could not be created
    """
    import glob
    import shutil

    def _remove_wsclean_products(prefix):
        # Equivalent of ``rm -rf <prefix>*`` without a shell, so paths with
        # spaces or shell metacharacters are handled safely.
        for path in glob.glob(glob.escape(prefix) + "*"):
            try:
                if os.path.isdir(path) and not os.path.islink(path):
                    shutil.rmtree(path)
                else:
                    os.remove(path)
            except OSError as err:
                print(f"Could not remove {path}: {err}")

    try:
        msname = msname.rstrip("/")
        imagename_prefix = (
            os.path.dirname(os.path.abspath(msname))
            + "/"
            + os.path.basename(msname).split(".ms")[0]
            + "_solar"
        )
        wsclean_args = [
            "-quiet",
            "-scale " + str(cellsize) + "asec",
            "-size " + str(imsize) + " " + str(imsize),
            "-nwlayers 1",
            "-niter 0 -name " + imagename_prefix,
            "-channel-range 0 1",
            "-interval 0 1",
        ]
        wsclean_cmd = "wsclean " + " ".join(wsclean_args) + " " + msname
        msg = run_wsclean(wsclean_cmd, "paircarswsclean", verbose=False)
        if msg != 0:
            print("Circular mask could not be created.")
            return

        # Read the template straight from WSClean's output rather than via a
        # temporary copy in the current working directory. The WSClean
        # products are removed even if reading fails.
        try:
            with fits.open(imagename_prefix + "-image.fits") as hdul:
                header = hdul[0].header.copy()
                # np.array() forces an in-memory copy, so the data stay valid
                # after the file is closed and deleted.
                data = np.array(hdul[0].data)
        finally:
            _remove_wsclean_products(imagename_prefix)

        # mask_radius [arcmin] * 60 -> arcsec; / cellsize [arcsec/pixel] -> pixels
        radius = mask_radius * 60 / cellsize
        center = (int(imsize / 2), int(imsize / 2))
        Y, X = np.ogrid[:imsize, :imsize]
        dist_from_center = np.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2)
        mask = dist_from_center <= radius
        # 1 inside the circle, 0 outside, on the first frequency/Stokes plane.
        data[0, 0, ...] = np.where(mask, 1.0, 0.0)

        mask_name = imagename_prefix + "-mask.fits"
        fits.writeto(mask_name, data=data, header=header, overwrite=True)
        if os.path.exists(mask_name):
            return mask_name
        print("Circular mask could not be created.")
        return
    except Exception:
        traceback.print_exc()
        return
def create_circular_mask_array(data, radius, center_x=None, center_y=None):
    """
    Build a boolean disc mask matching the first two axes of an array.

    Parameters
    ----------
    data : numpy.array
        Array whose first two dimensions (rows, columns) set the mask shape;
        its values are not used.
    radius : int or float
        Disc radius in pixels; pixels exactly on the radius are included.
    center_x : int or float, optional
        Column index of the disc centre (default: ``data.shape[1] // 2``).
    center_y : int or float, optional
        Row index of the disc centre (default: ``data.shape[0] // 2``).

    Returns
    -------
    numpy.array
        Boolean array of shape ``data.shape[:2]``; True inside the disc.
    """
    n_rows, n_cols = data.shape[0], data.shape[1]
    cx = n_cols // 2 if center_x is None else center_x
    cy = n_rows // 2 if center_y is None else center_y
    rows, cols = np.ogrid[:n_rows, :n_cols]
    # Compare squared distances so no square root is needed per pixel.
    return (cols - cx) ** 2 + (rows - cy) ** 2 <= radius**2
def calc_solar_image_stat(imagename, disc_size=50):
    """
    Calculate solar image statistics and dynamic range

    A circular region centred on the image, of radius ``disc_size``
    (converted to pixels with ``CDELT1``), is treated as on-disc. The peak,
    total, mean and median are measured inside it; the minimum and RMS are
    measured outside it. Only the first plane of a 3-D or 4-D image is used.

    Parameters
    ----------
    imagename : str
        Fits image name
    disc_size : float, optional
        Solar disc size in arcmin (default : 50); used directly as the radius
        of the on-disc region

    Returns
    -------
    float
        Maximum value (inside the disc)
    float
        Minimum value (outside the disc)
    float
        RMS value (outside the disc)
    float
        Total value (inside the disc)
    float
        Mean value (inside the disc)
    float
        Median value (inside the disc)
    float
        RMS dynamic range (maximum / RMS; NaN if RMS is 0)
    float
        Min-max dynamic range (maximum / |minimum|; NaN if minimum is 0)

    All values are rounded to 2 decimal places.
    """
    # Read data and header in one pass over the file.
    data, header = fits.getdata(imagename, header=True)
    total_pix = int(header["NAXIS1"])
    pix_size = abs(header["CDELT1"]) * 3600.0  # In arcsec
    radius = int((disc_size * 60) / pix_size)  # arcmin -> pixels
    if radius > total_pix:
        # NOTE(review): only radii larger than the full image width fall back
        # to a quarter of the width. A radius between half the image diagonal
        # and NAXIS1 still covers every pixel, leaving no off-disc pixels, so
        # the off-disc statistics become NaN.
        radius = total_pix / 4.0
    if data.ndim == 4:
        data = data[0, 0, ...]
    elif data.ndim == 3:
        data = data[0, ...]
    if not np.issubdtype(data.dtype, np.floating):
        # NaN marks excluded pixels below and cannot be stored in integer arrays.
        data = data.astype(np.float64)

    mask = create_circular_mask_array(data, radius)
    off_disc = data.copy()  # noise region: disc pixels blanked
    off_disc[mask] = np.nan
    on_disc = data.copy()  # source region: pixels outside the disc blanked
    on_disc[~mask] = np.nan

    maxval = float(np.nanmax(on_disc))
    minval = float(np.nanmin(off_disc))
    rms = float(np.nanstd(off_disc))
    total_val = float(np.nansum(on_disc))
    rms_dyn = float(maxval / rms) if rms != 0 else np.nan
    minmax_dyn = float(maxval / abs(minval)) if abs(minval) != 0 else np.nan
    mean_val = float(np.nanmean(on_disc))
    median_val = float(np.nanmedian(on_disc))
    return (
        round(maxval, 2),
        round(minval, 2),
        round(rms, 2),
        round(total_val, 2),
        round(mean_val, 2),
        round(median_val, 2),
        round(rms_dyn, 2),
        round(minmax_dyn, 2),
    )
def calc_dyn_range(imagename, modelname, residualname, fits_mask=""):
    """
    Calculate dynamic ranges.

    For each image (paired by position with a residual), the peak of the
    image's first plane, restricted to the mask when one is given, is divided
    by the standard deviation of the whole residual. These ratios are summed.
    Model flux is the sum over the first plane of every model image, also
    restricted to the mask when one is given.

    Parameters
    ----------
    imagename : list or str
        Image FITS file(s)
    modelname : list or str
        Model FITS file(s)
    residualname : list or str
        Residual FITS file(s); needs at least as many entries as ``imagename``
    fits_mask : str, optional
        FITS file mask. Ignored if empty or if the file does not exist;
        otherwise its first plane is used as a boolean pixel selection.

    Returns
    -------
    model_flux : float
        Total model flux.
    dyn_range_rms : float
        Max/RMS dynamic range, summed over images (rounded to 2 decimals).
    rms : float
        Sum of the residual RMS values divided by sqrt(number of residuals)
        (rounded to 2 decimals).
    """

    def as_list(value):
        # Wrap a single file name so both forms are iterated the same way.
        if isinstance(value, str):
            return [value]
        return value

    images = as_list(imagename)
    models = as_list(modelname)
    residuals = as_list(residualname)

    have_mask = bool(fits_mask and os.path.exists(fits_mask))
    region = fits.getdata(fits_mask).astype(bool)[0, 0, ...] if have_mask else None

    dyn_range_total = 0
    rms_total = 0
    for idx, image_file in enumerate(images):
        image_cube = fits.getdata(image_file)
        residual_cube = fits.getdata(residuals[idx])
        noise = np.nanstd(residual_cube)
        plane = image_cube[0, 0, ...]
        peak = np.nanmax(plane[region]) if have_mask else np.nanmax(plane)
        if noise:
            # A zero RMS contributes nothing instead of dividing by zero.
            dyn_range_total += peak / noise
        rms_total += noise

    model_flux = 0
    for model_file in models:
        model_plane = fits.getdata(model_file)[0, 0, ...]
        model_flux += np.nansum(model_plane[region] if have_mask else model_plane)

    combined_rms = rms_total / np.sqrt(len(residuals))
    return (
        float(model_flux),
        round(float(dyn_range_total), 2),
        round(float(combined_rms), 2),
    )
def generate_tb_map(imagename, outfile=""):
    """
    Function to generate brightness temperature map

    Every pixel is multiplied by ``1.222e6 / (freq**2 * bmaj * bmin)``. Here
    the frequency is in GHz, taken from CRVAL of whichever of axis 3 or 4 has
    CTYPE ``FREQ``, and the beam axes BMAJ/BMIN are in arcsec. This is the
    Rayleigh-Jeans conversion for an image in Jy/beam (presumably the units of
    the flux-calibrated input; confirm). The output BUNIT is set to ``K``.

    Parameters
    ----------
    imagename : str
        Name of the flux calibrated image
    outfile : str, optional
        Output brightness temperature image name (default:
        ``<imagename without .fits>_TB.fits``)

    Returns
    -------
    str or None
        Output image name, or None if the header has no frequency axis
    """
    if not outfile:
        outfile = imagename.split(".fits")[0] + "_TB.fits"
    image_data, image_header = fits.getdata(imagename, header=True)
    major = float(image_header["BMAJ"]) * 3600.0  # In arcsec
    minor = float(image_header["BMIN"]) * 3600.0  # In arcsec
    # Use .get() so images without a 3rd/4th axis reach the "no frequency"
    # branch instead of raising KeyError.
    if image_header.get("CTYPE3") == "FREQ":
        freq = image_header["CRVAL3"] / 10**9  # In GHz
    elif image_header.get("CTYPE4") == "FREQ":
        freq = image_header["CRVAL4"] / 10**9  # In GHz
    else:
        print(f"No frequency information is present in header for {imagename}.")
        return
    TB_conv_factor = (1.222e6) / ((freq**2) * major * minor)
    TB_data = image_data * TB_conv_factor
    image_header["BUNIT"] = "K"
    fits.writeto(outfile, data=TB_data, header=image_header, overwrite=True)
    return outfile
def cutout_image(fits_file, output_file, x_deg=2):
    """
    Cutout central part of the image

    Parameters
    ----------
    fits_file : str
        Input fits file
    output_file : str
        Output fits file name (If same as input, input image will be overwritten)
    x_deg : float, optional
        Size of the output image in degree

    Returns
    -------
    str
        Output image name

    Notes
    -----
    The last two data axes are taken as (y, x). Any leading axes (e.g.
    frequency and Stokes in a 4-D image) are kept whole. Pixels are assumed
    square (only CDELT1 is used), and the cutout is clipped to the image when
    ``x_deg`` exceeds the field of view. CRPIX1/CRPIX2 are shifted so the WCS
    still refers to the same sky positions.
    """
    # Close the input file deterministically. The cutout is copied into
    # memory first, so output_file == fits_file is safe.
    with fits.open(fits_file) as hdul:
        hdu = hdul[0]
        header = hdu.header.copy()
        ny, nx = hdu.data.shape[-2:]
        center_x, center_y = nx // 2, ny // 2

        # Half-size of the cutout in pixels, clipped to the image extent.
        pix_scale_deg = np.abs(header["CDELT1"])  # deg/pixel
        x_pix = min(int((x_deg / pix_scale_deg) / 2), nx // 2)
        y_pix = min(x_pix, ny // 2)  # Assume square pixels

        x0, x1 = center_x - x_pix, center_x + x_pix
        y0, y1 = center_y - y_pix, center_y + y_pix

        cutout_data = np.array(hdu.data[..., y0:y1, x0:x1])

    header["NAXIS1"] = x1 - x0
    header["NAXIS2"] = y1 - y0
    header["CRPIX1"] -= x0
    header["CRPIX2"] -= y0

    fits.writeto(output_file, cutout_data, header=header, overwrite=True)
    return output_file
def make_timeavg_image(wsclean_images, outfile_name, keep_wsclean_images=True):
    """
    Convert WSClean images into a time averaged image

    Pixels are averaged with ``np.nanmean`` across the input images. An image
    whose data shape differs from the images already accepted is skipped,
    with a message, and excluded from both the pixel average and the output
    DATE-OBS. The output header is that of ``wsclean_images[0]``, with
    DATE-OBS replaced by the average of the accepted images' DATE-OBS values.

    Parameters
    ----------
    wsclean_images : list
        List of WSClean images.
    outfile_name : str
        Name of the output file.
    keep_wsclean_images : bool, optional
        Whether to retain the original WSClean images (default: True).

    Returns
    -------
    str
        Output image name.
    """
    import shutil

    timestamps = []
    data = []
    for image in wsclean_images:
        image_data, image_header = fits.getdata(image, header=True)
        if data and image_data.shape != data[-1].shape:
            print(
                f"Image data shape: {image_data.shape} does not match with last data: {data[-1].shape}"
            )
            # Skip the image entirely so its timestamp does not bias the
            # average time of the frames that were actually averaged.
            continue
        data.append(image_data)
        timestamps.append(image_header["DATE-OBS"])

    avg_data = np.nanmean(np.array(data), axis=0)
    header = fits.getheader(wsclean_images[0])
    header["DATE-OBS"] = average_timestamp(timestamps)
    fits.writeto(outfile_name, data=avg_data, header=header, overwrite=True)

    if not keep_wsclean_images:
        # Remove inputs without a shell (safe for spaces/metacharacters).
        for img in wsclean_images:
            try:
                if os.path.isdir(img) and not os.path.islink(img):
                    shutil.rmtree(img)
                elif os.path.lexists(img):
                    os.remove(img)
            except OSError as err:
                print(f"Could not remove {img}: {err}")
    return outfile_name
def make_freqavg_image(wsclean_images, outfile_name, keep_wsclean_images=True):
    """
    Convert WSClean images into a frequency averaged image

    Pixel data are averaged with a plain mean, so NaNs propagate. The output
    header is that of the first image, with two changes on the frequency axis
    (axis 3 or 4, whichever has CTYPE ``FREQ``):

    * CRVAL is set to the mean input frequency.
    * CDELT is set to the spread between the lowest and highest input
      frequency. When all inputs share one frequency the original CDELT is
      kept, because a zero increment is not a valid FITS WCS value.

    Parameters
    ----------
    wsclean_images : list
        List of WSClean images.
    outfile_name : str
        Name of the output file.
    keep_wsclean_images : bool, optional
        Whether to retain the original WSClean images (default: True).

    Returns
    -------
    str
        Output image name.

    Raises
    ------
    ValueError
        If ``wsclean_images`` is empty.
    """
    import shutil

    if len(wsclean_images) == 0:
        raise ValueError("No WSClean images were given to average.")

    freqs = []
    freqaxis = None
    data = None
    for image in wsclean_images:
        image_data, image_header = fits.getdata(image, header=True)
        if data is None:
            # Private in-memory accumulator. Integer data are promoted so the
            # in-place division below is allowed.
            data = np.array(image_data)
            if not np.issubdtype(data.dtype, np.floating):
                data = data.astype(np.float64)
        else:
            data += image_data
        if image_header.get("CTYPE3") == "FREQ":
            freqs.append(float(image_header["CRVAL3"]))
            freqaxis = 3
        elif image_header.get("CTYPE4") == "FREQ":
            freqs.append(float(image_header["CRVAL4"]))
            freqaxis = 4
    data /= len(wsclean_images)

    header = fits.getheader(wsclean_images[0])
    if freqs:
        mean_freq = float(np.nanmean(freqs))
        width = max(freqs) - min(freqs)
        header[f"CRVAL{freqaxis}"] = mean_freq
        if width > 0:
            header[f"CDELT{freqaxis}"] = width
    fits.writeto(outfile_name, data=data, header=header, overwrite=True)

    if not keep_wsclean_images:
        # Remove inputs without a shell (safe for spaces/metacharacters).
        for img in wsclean_images:
            try:
                if os.path.isdir(img) and not os.path.islink(img):
                    shutil.rmtree(img)
                elif os.path.lexists(img):
                    os.remove(img)
            except OSError as err:
                print(f"Could not remove {img}: {err}")
    return outfile_name
def make_stokes_wsclean_imagecube(
    wsclean_images, outfile_name, keep_wsclean_images=True
):
    """
    Convert WSClean images into a Stokes cube image.

    The polarisation of each input is parsed from its file name: the
    second-to-last ``" - "``-separated field of the base name, or Stokes I
    when the path has no such separator. Planes are stacked along the first
    numpy axis (FITS axis 4) in order of increasing ``|code|`` in the FITS
    Stokes convention (I=1, Q=2, U=3, V=4, RR=-1, LL=-2, XX=-5, YY=-6, ...).
    CRVAL4/CDELT4 are set so that every plane is labelled with its own code.

    Parameters
    ----------
    wsclean_images : list
        List of WSClean images.
    outfile_name : str
        Name of the output file.
    keep_wsclean_images : bool, optional
        Whether to retain the original WSClean images (default: True).

    Returns
    -------
    str or None
        Output image name, or None for an unsupported Stokes combination.
    """
    import shutil

    # Conventional FITS values for a STOKES axis (WCS Paper I, Table 7).
    stokes_code = {
        "I": 1, "Q": 2, "U": 3, "V": 4,
        "RR": -1, "LL": -2, "RL": -3, "LR": -4,
        "XX": -5, "YY": -6, "XY": -7, "YX": -8,
    }

    def stokes_label(path):
        # NOTE(review): WSClean normally joins name parts with a bare "-";
        # confirm the spaced " - " separator matches the real file names.
        if " - " in path:
            return os.path.basename(path).split(".fits")[0].split(" - ")[-2]
        return "I"

    labels = [stokes_label(img) for img in wsclean_images]
    stokes = sorted(set(labels))
    valid_stokes = [
        {"I"},
        {"I", "V"},
        {"I", "Q", "U", "V"},
        {"XX", "YY"},
        {"LL", "RR"},
        {"Q", "U"},
        {"I", "Q"},
    ]
    if set(stokes) not in valid_stokes:
        print("Invalid Stokes combination.")
        return

    # Order planes by |code| so the axis is uniformly spaced and matches the
    # CRVAL4/CDELT4 written below. The sort is stable for repeated labels.
    ordered = [
        img
        for _, img in sorted(
            zip(labels, wsclean_images), key=lambda pair: abs(stokes_code[pair[0]])
        )
    ]
    header = fits.getheader(ordered[0])
    data = np.concatenate([fits.getdata(img) for img in ordered], axis=0)

    codes = sorted((stokes_code[s] for s in stokes), key=abs)
    # e.g. I,V -> 1, +3;  XX,YY -> -5, -1;  RR,LL -> -1, -1;  Q,U -> 2, +1
    cdelt = codes[1] - codes[0] if len(codes) > 1 else 1
    header.update({"NAXIS4": len(stokes), "CRVAL4": codes[0], "CDELT4": cdelt})
    fits.writeto(outfile_name, data=data, header=header, overwrite=True)

    if not keep_wsclean_images:
        # Remove inputs without a shell (safe for spaces/metacharacters).
        for img in wsclean_images:
            try:
                if os.path.isdir(img) and not os.path.islink(img):
                    shutil.rmtree(img)
                elif os.path.lexists(img):
                    os.remove(img)
            except OSError as err:
                print(f"Could not remove {img}: {err}")
    return outfile_name
def filter_images(imagelist, min_time_sep=60.0):
    """
    Select images with maximum bandwidth, then for each frequency keep images
    separated by at least `min_time_sep` seconds.

    Bandwidth (|CDELT| in MHz, rounded to 0.01) and frequency (CRVAL) are read
    from whichever of axis 3 or 4 has CTYPE ``FREQ``; both are -1 if neither
    axis does. Only images with the largest bandwidth are kept. Within each
    frequency, images are ordered by DATE-OBS (truncated to whole seconds)
    and picked greedily, so consecutive picks are at least ``min_time_sep``
    seconds apart.

    Parameters
    ----------
    imagelist : list
        Image list
    min_time_sep : float, optional
        Minimum time separation in seconds. If <= 0, only the middle image
        (in time) of each frequency group is kept.

    Returns
    -------
    list
        Sorted filtered image list; empty if ``imagelist`` is empty or an
        error occurs
    """
    image_info = []
    try:
        for image in imagelist:
            header = fits.getheader(image)
            bw = -1
            freq = -1
            if header.get("CTYPE3") == "FREQ":
                bw = abs(float(header.get("CDELT3", -1))) / 1e6
                freq = float(header.get("CRVAL3", -1))
            elif header.get("CTYPE4") == "FREQ":
                bw = abs(float(header.get("CDELT4", -1))) / 1e6
                freq = float(header.get("CRVAL4", -1))
            bw = round(bw, 2)
            timeobs = header["DATE-OBS"].split(".")[0]
            mjdsec = timestamp_to_mjdsec(timeobs, date_format=1)
            image_info.append(
                {"image": image, "bw": bw, "freq": freq, "mjdsec": mjdsec}
            )
        if not image_info:
            # Nothing to select from; np.max() would raise on an empty list.
            return []

        max_bw = np.max([info["bw"] for info in image_info])
        freq_groups = defaultdict(list)
        for info in image_info:
            if info["bw"] == max_bw:
                freq_groups[info["freq"]].append(info)

        final_images = []
        for group in freq_groups.values():
            group.sort(key=lambda x: x["mjdsec"])
            if min_time_sep > 0:
                last_time = -np.inf
                for info in group:
                    if info["mjdsec"] - last_time >= min_time_sep:
                        final_images.append(info["image"])
                        last_time = info["mjdsec"]
            else:
                final_images.append(group[len(group) // 2]["image"])
        return sorted(final_images)
    except Exception:
        print("Error in filtering out images.")
        traceback.print_exc()
        return []