# Source code for paircars.utils.flagging

import numpy as np
import traceback
import os
import dask
from datetime import datetime as dt, timezone
from daskms.experimental.zarr import xds_from_zarr, xds_to_zarr
from .basic_utils import suppress_output
from .calibration import get_quartical_soltype
from .resource_utils import limit_threads
from .imaging import calc_maxuv


###############################
# Flagging related functions
################################
def flagsummary(msname, summary_file):
    """
    Save flag summary

    Runs CASA ``flagdata`` in summary mode and writes one line per
    ``<category> <item>`` breakdown entry that carries both ``flagged`` and
    ``total`` counts, in the form ``"<category> <item> <flagged percent>"``.

    Parameters
    ----------
    msname : str
        Measurement set name
    summary_file : str
        Summary file name (overwritten)

    Returns
    -------
    str
        Summary file
    """
    from casatasks import flagdata

    with suppress_output():
        s = flagdata(vis=msname, mode="summary")
    with open(summary_file, "w") as f:
        f.write(f"Flag summary of: {msname}\n")
        for x, group in s.items():
            # Scalar entries of the summary (not per-item breakdowns) are skipped.
            if not isinstance(group, dict):
                continue
            for y, counts in group.items():
                if not isinstance(counts, dict):
                    continue
                flagged = counts.get("flagged")
                total = counts.get("total")
                # Skip entries without counts or with no data (avoids a
                # division by zero instead of silently swallowing it).
                if flagged is None or not total:
                    continue
                flagged_percent = 100.0 * (flagged / total)
                f.write(f"{x} {y} {flagged_percent}\n")
    return summary_file
def do_flag_backup(msname, flagtype="flagdata"):
    """
    Take a flag backup

    The current flags are saved as version ``"<flagtype>_<N>"``, where ``N`` is
    one more than the highest existing ``"<flagtype>_<N>"`` version (1 if
    there is none).

    Parameters
    ----------
    msname : str
        Measurement set name
    flagtype : str, optional
        Flag type (prefix of the saved flag version name)

    Returns
    -------
    str
        Name of the saved flag version
    """
    from casatools import agentflagger

    prefix = f"{flagtype}_"
    af = agentflagger()
    af.open(msname)
    try:
        existing = []
        for version in af.getflagversionlist():
            # The version name is the first token before the ":" separator.
            tokens = version.split(":")[0].split()
            if not tokens:
                continue
            name = tokens[0]
            # Exact "<flagtype>_<int>" match only, so e.g. "uvbin_flagdata_3"
            # is not counted as a "flagdata" version.
            suffix = name[len(prefix):] if name.startswith(prefix) else ""
            if suffix.isdigit():
                existing.append(int(suffix))
        # Use the highest existing number, not the last one seen in the list.
        version_name = f"{prefix}{max(existing, default=0) + 1}"
        dt_string = dt.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
        af.saveflagversion(version_name, "Flags autosave on " + dt_string)
    finally:
        # Always release the flagger, even if saving fails.
        af.done()
    return version_name
def uvbin_flag(
    msname,
    uvbin_size=10,
    datacolumn="data",
    mode="rflag",
    threshold=10.0,
    flagbackup=True,
):
    """
    Perform uv-bin flag

    The uv-plane is split into bins of ``uvbin_size`` wavelengths out to the
    maximum uv-distance of the measurement set, and ``flagdata`` is run on
    each bin separately. Flagging of individual bins is best-effort: a
    failure in one bin is reported and the remaining bins are still flagged.

    Parameters
    ----------
    msname : str
        Measurement set
    uvbin_size : float, optional
        UV-bin size in wavelength (truncated to an integer, must be >= 1)
    datacolumn : str, optional
        Data column
    mode : str, optional
        Flag mode (rflag or tfcrop)
    threshold : float, optional
        Flagging threshold
    flagbackup : bool, optional
        Flag backup

    Returns
    -------
    int
        0 on success, 1 if the uv-bins could not be set up (no flags are
        modified by this call in that case)
    """
    from casatasks import flagdata

    try:
        _, maxuv_l = calc_maxuv(msname)
        bin_width = int(uvbin_size)
        if bin_width < 1:
            raise ValueError(
                f"uvbin_size must be at least 1 wavelength, got {uvbin_size}"
            )
        # Round up so the last, partial bin still reaches the maximum
        # uv-distance; int() truncation left the longest baselines unflagged.
        uv_limit = int(np.ceil(maxuv_l))
        if flagbackup:
            do_flag_backup(msname, flagtype="uvbin_flagdata")
    except Exception:
        # Nothing has been flagged by this call yet, so there is nothing to
        # restore (restoring a fixed backup name could revert older flags).
        traceback.print_exc()
        return 1

    if mode == "rflag":
        mode_kwargs = {
            "timedevscale": threshold,
            "freqdevscale": threshold,
            "ntime": "2s",
        }
    else:
        mode_kwargs = {"timecutoff": threshold, "freqcutoff": threshold}

    for start in range(0, uv_limit, bin_width):
        uvrange = f"{start}~{start + bin_width}lambda"
        try:
            with suppress_output():
                flagdata(
                    vis=msname,
                    mode=mode,
                    datacolumn=datacolumn,
                    uvrange=uvrange,
                    flagbackup=False,
                    **mode_kwargs,
                )
        except Exception as err:
            print(f"uv-bin flagging failed for uvrange {uvrange}: {err}")
    return 0
def get_chans_flag_per_time(msname):
    """
    Get flagged fraction per time

    For every timestamp, the fraction of flagged (polarization, channel,
    baseline) samples is returned.

    Parameters
    ----------
    msname : str
        Measurement set

    Returns
    -------
    list
        Timestamps (in MJD seconds) of spectral window 0
    list
        Flag fraction per timestamp
    """
    from casatools import ms as casamstool, msmetadata

    msmd = msmetadata()
    msmd.open(msname)
    try:
        times = msmd.timesforspws(0)
    finally:
        msmd.close()

    mstool = casamstool()
    mstool.open(msname)
    try:
        # NOTE(review): FLAG is read for the whole MS while the timestamps
        # come from spw 0 only; they line up for single-spw data — confirm.
        flag = mstool.getdata("FLAG", ifraxis=True)["flag"]
    finally:
        mstool.close()

    if len(times) > 1:
        # Indexed as (pol, chan, baseline, time): average over all but time.
        npol, nchan, nbaseline = flag.shape[:3]
        flag_frac = np.nansum(flag, axis=(0, 1, 2)) / (npol * nchan * nbaseline)
    else:
        flag_frac = np.array([np.nansum(flag) / np.size(flag)])
    return times.tolist(), flag_frac.tolist()
def get_unflagged_antennas(
    msname="",
    scan="",
    n_threads=-1,
):
    """
    Get unflagged antennas of a scan

    Note that the process working directory is changed to the directory
    containing the measurement set.

    Parameters
    ----------
    msname : str
        Name of the measurement set
    scan : str
        Scans
    n_threads : int, optional
        Thread limit (values below 1 are treated as 1)

    Returns
    -------
    numpy.array
        Names of antennas that are not completely flagged
    numpy.array
        Flag fraction of each of those antennas (same order)
    """
    n_threads = max(1, n_threads)
    limit_threads(n_threads=n_threads)
    from casatasks import flagdata

    msname = msname.rstrip("/")
    mspath = os.path.dirname(os.path.abspath(msname))
    os.chdir(mspath)
    with suppress_output():
        flag_summary = flagdata(vis=msname, scan=str(scan), mode="summary")
    unflagged_antenna_names = []
    flag_frac_list = []
    for ant, counts in flag_summary["antenna"].items():
        total = counts["total"]
        if not total:
            # No data for this antenna in the selection: nothing usable
            # (previously a ZeroDivisionError).
            continue
        flag_frac = counts["flagged"] / total
        if flag_frac < 1.0:
            unflagged_antenna_names.append(ant)
            flag_frac_list.append(flag_frac)
    return np.array(unflagged_antenna_names), np.array(flag_frac_list)
def get_chans_flag(
    msname="",
    field="",
    n_threads=-1,
):
    """
    Get flag/unflag channel list

    Note that the process working directory is changed to the directory
    containing the measurement set.

    Parameters
    ----------
    msname : str
        Measurement set name
    field : str, optional
        Field name or ID
    n_threads : int, optional
        Thread limit (values below 1 are treated as 1)

    Returns
    -------
    list
        Unflag channel list (channels not completely flagged)
    list
        Flag channel list (channels completely flagged or without data)
    """
    n_threads = max(1, n_threads)
    limit_threads(n_threads=n_threads)
    from casatasks import flagdata

    msname = msname.rstrip("/")
    mspath = os.path.dirname(os.path.abspath(msname))
    os.chdir(mspath)
    with suppress_output():
        summary = flagdata(vis=msname, field=field, mode="summary", spwchan=True)
    unflag_chans = []
    flag_chans = []
    for chan, counts in summary["spw:channel"].items():
        # Keys are "<spw>:<channel>"; take the part after the last ":" so any
        # spw id parses (splitting on "0:" failed for spw 1 and mis-read "10:5").
        # NOTE(review): channel numbers of different spws are not
        # distinguished in the returned lists.
        chan_number = int(chan.split(":")[-1])
        # flagged >= total also treats channels with no data as flagged
        # instead of dividing by zero.
        if counts["flagged"] >= counts["total"]:
            flag_chans.append(chan_number)
        else:
            unflag_chans.append(chan_number)
    return unflag_chans, flag_chans
def calc_flag_fraction(
    msname="",
    field="",
    scan="",
    n_threads=-1,
):
    """
    Return the flagged fraction of the selected data

    The working directory is switched to the directory holding the
    measurement set before CASA ``flagdata`` is run in summary mode.

    Parameters
    ----------
    msname : str
        Name of the measurement set
    field : str, optional
        Field selection
    scan : str, optional
        Scan selection
    n_threads : int, optional
        Thread limit (values below 1 become 1)

    Returns
    -------
    float
        Ratio of flagged to total samples over the selection
    """
    limit_threads(n_threads=max(1, n_threads))
    from casatasks import flagdata

    ms_name = msname.rstrip("/")
    os.chdir(os.path.dirname(os.path.abspath(ms_name)))
    with suppress_output():
        counts = flagdata(vis=ms_name, field=field, scan=scan, mode="summary")
    return counts["flagged"] / counts["total"]
def flag_outside_uvrange(
    vis,
    uvrange,
    flagbackup=True,
    n_threads=-1,
):
    """
    Flag outside the given uv range

    Parameters
    ----------
    vis : str
        Measurement set name
    uvrange : str
        UV-range in CASA syntax: ``"a~b<unit>"``, ``">a<unit>"`` or
        ``"<b<unit>"``. Any other string flags nothing.
    flagbackup : bool, optional
        Flag backup
    n_threads : int, optional
        Thread limit (values below 1 are treated as 1)

    Returns
    -------
    int
        0 on success, 1 on failure
    """
    n_threads = max(1, n_threads)
    limit_threads(n_threads=n_threads)
    from casatasks import flagdata

    try:
        cmds = [
            {"mode": "manual", "uvrange": selection, "flagbackup": flagbackup}
            for selection in _uvrange_complement(uvrange)
        ]
        for cmd in cmds:
            print(f"Flagging command: {cmd}")
            flagdata(vis=vis, **cmd)
        return 0
    except Exception:
        traceback.print_exc()
        return 1


def _uvrange_complement(uvrange):
    """Return the CASA uvrange selections lying outside ``uvrange``."""
    import string

    def _split_unit(value):
        # Split e.g. "0.5klambda" into ("0.5", "klambda").
        value = value.strip()
        number = value.rstrip(string.ascii_letters)
        return number.strip(), value[len(number):]

    if "~" in uvrange:
        low, high = uvrange.split("~")
        low_num, low_unit = _split_unit(low)
        high_num, high_unit = _split_unit(high)
        # A range usually carries its unit once ("0.1~2klambda"); apply it to
        # both ends so the low end is not silently read in other units.
        low_unit = low_unit or high_unit
        high_unit = high_unit or low_unit
        return [f"<{low_num}{low_unit}", f">{high_num}{high_unit}"]
    if ">" in uvrange:
        return [f"<{uvrange.split('>')[-1].strip()}"]
    if "<" in uvrange:
        return [f">{uvrange.split('<')[-1].strip()}"]
    return []
def flag_quartical_table(caltable, threshold=10.0):
    """
    Flag quartical caltable

    Only the first solution term returned by ``get_quartical_soltype`` is
    flagged; any other terms are rewritten unchanged. For every dataset of
    the flagged term:

    * a solution is flagged when the real or imaginary part of any
      correlation deviates (in either direction) from that correlation's
      median by more than ``threshold`` times its standard deviation;
    * NaN solutions are flagged;
    * antennas with more than half of their solutions flagged are flagged
      entirely;
    * flagged solutions are reset to the identity gain.

    Parameters
    ----------
    caltable : str
        Caltable name
    threshold : float
        Flagging threshold (in units of standard deviation)

    Returns
    -------
    str
        Flagged caltable name
    """
    import shutil

    caltable = caltable.rstrip("/")
    soltypes = get_quartical_soltype(caltable)
    if len(soltypes) == 0:
        print("No solution is present. Not performing any flagging.")
        return caltable

    # xds_from_zarr is lazy: hold every term in memory before the on-disk
    # table is replaced, otherwise the write would read deleted data and any
    # term other than the flagged one would be lost.
    tables = {
        name: [xds.persist() for xds in xds_from_zarr(f"{caltable}::{name}")]
        for name in soltypes
    }
    soltype = soltypes[0]
    tables[soltype] = [_flag_gain_dataset(xds, threshold) for xds in tables[soltype]]

    # Write to a scratch table first so a failed write cannot destroy the input.
    tmp_table = f"{caltable}.flagging_tmp"
    shutil.rmtree(tmp_table, ignore_errors=True)
    writes = []
    for name, xds_list in tables.items():
        writes.extend(xds_to_zarr(xds_list, f"{tmp_table}::{name}"))
    dask.compute(writes)
    shutil.rmtree(caltable)
    os.rename(tmp_table, caltable)
    return caltable


def _flag_gain_dataset(xds, threshold, ant_flag_frac=0.5):
    """Return ``xds`` with outlier solutions flagged and reset to identity."""
    import warnings

    # Copy so the source buffer is never modified in place.
    gain_data = np.array(xds.gains.to_numpy())  # (time, chan, ant, dir, corr)
    gain_flag = xds.gain_flags.to_numpy().astype(bool)  # (time, chan, ant, dir)
    gain_data[gain_flag] = np.nan

    axes = tuple(range(gain_data.ndim - 1))
    outlier = np.zeros(gain_flag.shape, dtype=bool)
    with warnings.catch_warnings():
        # All-NaN correlations make nanmedian/nanstd warn; harmless here since
        # comparisons against NaN are False.
        warnings.simplefilter("ignore", RuntimeWarning)
        for part in (gain_data.real, gain_data.imag):
            centred = part - np.nanmedian(part, axis=axes)
            std = np.nanstd(centred, axis=axes)
            # abs(): deviations below the median are outliers too.
            outlier |= np.any(np.abs(centred) > threshold * std, axis=-1)
        gain_flag |= outlier | np.isnan(gain_data).any(axis=-1)
        ant_frac = gain_flag.mean(axis=(0, 1, 3))
    gain_flag[:, :, ant_frac > ant_flag_frac, :] = True

    # Identity gain: 1 on the diagonal correlations (0 and 3 for a full 2x2
    # Jones matrix, every correlation otherwise), 0 off-diagonal.
    ncorr = gain_data.shape[-1]
    identity = np.zeros(ncorr, dtype=gain_data.dtype)
    identity[[0, 3] if ncorr == 4 else slice(None)] = 1
    gain_data = np.where(gain_flag[..., None], identity, gain_data)

    return xds.assign(
        gains=xds.gains.copy(data=gain_data),
        gain_flags=xds.gain_flags.copy(data=gain_flag.astype(xds.gain_flags.dtype)),
    )