"""Source code for ``paircars.pipeline.do_target_split``: split target measurement sets into per-coarse-channel measurement sets."""

import logging
import numpy as np
import argparse
import time
import sys
import os
from casatools import msmetadata
from dask import delayed
from astropy.io import fits
from paircars.utils.basic_utils import print_banner, capture_all_output
from paircars.utils.casatasks import single_mstransform
from paircars.utils.logger_utils import (
    SmartDefaultsHelpFormatter,
    clean_shutdown,
    init_logger,
    get_logger_safe,
)
from paircars.utils.ms_metadata import get_timeranges
from paircars.utils.mwa_utils import (
    get_MWA_coarse_bands,
    get_MWA_coarse_chan,
)
from paircars.utils.proc_manage_utils import (
    scale_worker_and_wait,
    get_local_dask_cluster,
)
from paircars.utils.resource_utils import drop_cache

# Raise the threshold of the dask ``distributed`` logger to ERROR and of the
# ``tornado.application`` logger to CRITICAL, so lower-severity records from
# these libraries are not emitted while this pipeline step runs.
logging.getLogger("distributed").setLevel(logging.ERROR)
logging.getLogger("tornado.application").setLevel(logging.CRITICAL)


def chanlist_to_str(lst):
    """
    Compress channel indices into a ``;``-separated channel selection string.

    Runs of consecutive channels are written as ``start~end``; isolated
    channels are written on their own (the CASA channel-selection syntax).

    Parameters
    ----------
    lst : iterable of int
        Channel indices in any order. Duplicate entries are ignored.

    Returns
    -------
    str
        Selection string, e.g. ``[3, 1, 2, 5, 7, 8]`` -> ``"1~3;5;7~8"``.
        An empty input yields ``""``.
    """
    # De-duplicate before sorting: a repeated channel would otherwise break a
    # run and emit overlapping pieces such as "1;1~2".
    chans = sorted(set(lst))
    if not chans:
        return ""

    def _fmt(first, last):
        # A single-channel run is written without the "~" range separator
        return f"{first}~{last}" if last > first else f"{first}"

    ranges = []
    start = prev = chans[0]
    for chan in chans[1:]:
        if chan != prev + 1:
            # Gap found: close the current run and open a new one at ``chan``
            ranges.append(_fmt(start, prev))
            start = chan
        prev = chan
    ranges.append(_fmt(start, prev))
    return ";".join(ranges)
def single_mstransform_wrapper(**kwargs):
    """
    Run ``single_mstransform`` and hand back whatever it printed.

    Used as the unit of work submitted to dask by ``split_target_scans``, so
    the worker's output can be returned and logged by the caller.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to ``single_mstransform``.

    Returns
    -------
    tuple
        ``(result, out_text, err_text)``: the return value of
        ``single_mstransform`` followed by the text collected in the two
        buffers yielded by ``capture_all_output`` (presumably stdout and
        stderr).
    """
    with capture_all_output() as captured:
        stdout_buf, stderr_buf = captured
        transform_result = single_mstransform(**kwargs)
    stdout_text = stdout_buf.getvalue()
    stderr_text = stderr_buf.getvalue()
    return transform_result, stdout_text, stderr_text
def _remove_path(path, logger):
    """Remove a file or directory tree like ``rm -rf`` but without a shell."""
    import shutil

    try:
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path)
        else:
            os.remove(path)
    except OSError as err:
        # Best effort, as with the previous ``rm -rf``: report and carry on
        logger.warning(f"Could not remove {path}: {err}")


def split_target_scans(
    mslist,
    metafits,
    dask_client,
    workdir,
    timeres,
    freqres,
    datacolumn,
    split_coarse_chans=None,
    scan=1,
    prefix="targets",
    time_interval=-1,
    time_window=-1,
    quack_timestamps=-1,
    force_split=False,
    only_disk=False,
    n_threads=-1,
    logger=None,
):
    """
    Split target scans

    Every measurement set is split into one output measurement set per selected
    coarse channel, named ``{workdir}/{prefix}_{obsid}_ch_{coarse_chan}.ms``
    (``obsid`` is the ``GPSTIME`` keyword of the metafits header). The splits
    are submitted as ``single_mstransform`` tasks to ``dask_client`` and run in
    parallel. An output that already contains a ``.splited`` marker is reused
    unless ``force_split`` is set; otherwise any pre-existing output (and its
    ``.flagversions``) is deleted first.

    Parameters
    ----------
    mslist : list
        Measurement set list
    metafits : str
        Metafits file
    dask_client : dask.client
        Dask client
    workdir : str
        Work directory (the process changes into it)
    timeres : float
        Time resolution in seconds (<= 0 disables time averaging)
    freqres : float
        Frequency resolution in MHz (<= 0 disables spectral averaging)
    datacolumn : str
        Data column to split
    split_coarse_chans : list, optional
        Coarse channels to split. ``None`` or an empty list selects every
        coarse channel of each measurement set.
    scan : int, optional
        Scan to split (accepted for interface compatibility; not used here)
    prefix : str, optional
        Splited ms prefix
    time_interval : float
        Time interval in seconds
    time_window : float
        Time window in seconds
    quack_timestamps : int, optional
        Number of timestamps ignored at the start and end of each scan
    force_split : bool, optional
        Force split even if a previous split completed
    only_disk : bool, optional
        Split only disk
    n_threads : int, optional
        Number of threads per task (values below 1 are raised to 1)
    logger : logging.Logger, optional
        Logger to use; defaults to ``get_logger_safe()``

    Returns
    -------
    int
        Success message (0 on success, 1 on failure)
    list
        Splited ms list (empty on failure)
    """
    if logger is None:
        logger = get_logger_safe()
    n_threads = max(1, n_threads)
    if split_coarse_chans is None:
        split_coarse_chans = []
    if len(mslist) == 0:
        logger.critical("Please provide a valid measurement set list.")
        return 1, []

    # Bound before the ``try`` so the exception handler can always name what
    # failed, even when the error happens before the first MS is processed.
    msname = ",".join(str(ms) for ms in mslist)
    try:
        os.chdir(workdir)
        logger.debug(f"Current working directory: {os.getcwd()}")
        #######################################
        # Extracting time frequency information
        #######################################
        header = fits.getheader(metafits)
        obsid = header["GPSTIME"]
        mode = header["MODE"]
        # The central channel is flagged for every MODE not containing "MWAX"
        flag_central_chan = "MWAX" not in mode
        logger.debug(f"Flag central channel: {flag_central_chan} for {mode}")

        tasks = []
        splited_ms_list = []
        for msname in mslist:
            msmd = msmetadata()
            msmd.open(msname)
            try:
                chanres = msmd.chanres(0, unit="MHz")[0]
            finally:
                # Release the MS even if the metadata query fails
                msmd.close()

            # freqres (MHz) -> number of native channels to average, at least 1
            if freqres > 0:
                chanwidth = max(1, int(freqres / chanres))
            else:
                chanwidth = 1
            # timeres (s) -> timebin string; empty string when timeres <= 0
            timebin = str(timeres) + "s" if timeres > 0 else ""

            #############################
            # Making spectral chunks
            #############################
            coarse_channel_bands = get_MWA_coarse_bands(
                msname, flag_central_chan=flag_central_chan
            )
            coarse_chans = get_MWA_coarse_chan(msname)
            logger.debug(f"Coarse channels for {msname} are: {coarse_chans}")
            if len(split_coarse_chans) == 0:
                use_coarse_chans = coarse_chans
            else:
                use_coarse_chans = split_coarse_chans
            logger.debug(f"Using coarse channels for {msname} are: {use_coarse_chans}")

            coarse_chlist = []
            good_spwlist = []
            for c, band in enumerate(coarse_channel_bands):
                coarse_chan = coarse_chans[c]
                if coarse_chan in use_coarse_chans:
                    # band[2] is assumed to hold the channel indices to keep
                    # for this coarse channel (see get_MWA_coarse_bands)
                    good_chans = ";".join(f"{i}" for i in band[2])
                    good_spwlist.append(f"0:{good_chans}")
                    coarse_chlist.append(f"{coarse_chan}")

            timerange_list = get_timeranges(
                msname,
                time_interval,
                time_window,
                only_disk=only_disk,
                quack_timestamps=quack_timestamps,
            )
            timerange = ",".join(timerange_list)

            for good_spw, coarse_chan in zip(good_spwlist, coarse_chlist):
                outputvis = f"{workdir}/{prefix}_{obsid}_ch_{coarse_chan}.ms"
                if os.path.exists(f"{outputvis}/.splited") and not force_split:
                    logger.info(f"{outputvis} is already splited successfully.")
                    splited_ms_list.append(outputvis)
                    continue
                if os.path.exists(outputvis):
                    logger.debug(f"Deleteing pre-existing output ms: {outputvis}")
                    _remove_path(outputvis, logger)
                if os.path.exists(f"{outputvis}.flagversions"):
                    logger.debug(
                        f"Deleteing pre-existing output ms flags: {outputvis}.flagversions"
                    )
                    _remove_path(f"{outputvis}.flagversions", logger)
                logger.debug("Spliting parameters:")
                logger.debug(
                    f"Channel width: {chanwidth}, timebin: {timebin}, datacolumn: {datacolumn}, spectral window: {good_spw}, time range: {timerange}"
                )
                tasks.append(
                    delayed(single_mstransform_wrapper)(
                        msname=msname,
                        outputms=outputvis,
                        width=chanwidth,
                        timebin=timebin,
                        datacolumn=datacolumn,
                        spw=good_spw,
                        corr="",
                        timerange=timerange,
                        n_threads=n_threads,
                    )
                )

        # Run every pending split in parallel, then replay each worker's output
        futures = dask_client.compute(tasks)
        task_results = dask_client.gather(futures)
        new_ms_list = []
        for out_ms, out_log, err_log in task_results:
            # NOTE(review): the value returned by single_mstransform is counted
            # as a split MS unconditionally; presumably it is the output MS
            # path -- confirm how failures are reported.
            new_ms_list.append(out_ms)
            logger.debug("================")
            logger.debug(f"Worker log for: {os.path.basename(out_ms)}")
            logger.debug("================")
            for line in out_log.splitlines():
                logger.debug(line)
            for line in err_log.splitlines():
                logger.debug(line)
        splited_ms_list = splited_ms_list + new_ms_list

        if len(splited_ms_list) == 0:
            logger.error(f"Spliting of measurement set: {msname} is unsuccessful.")
            return 1, []
        logger.info(f"Spliting of measurement set: {msname} is done successfully.")
        for splited_ms in splited_ms_list:
            drop_cache(splited_ms)
        return 0, splited_ms_list
    except Exception:
        logger.exception(
            f"Spliting of measurement set: {msname} is unsuccessful.", exc_info=True
        )
        return 1, []
def main(
    mslist,
    metafits,
    workdir="",
    datacolumn="data",
    split_coarse_chans=None,
    scan=1,
    time_window=-1,
    time_interval=-1,
    quack_timestamps=-1,
    freqres=-1,
    timeres=-1,
    prefix="targets",
    force_split=False,
    only_disk=False,
    cpu_frac=0.8,
    mem_frac=0.8,
    logfile=None,
    jobid=0,
    start_remote_log=False,
    verbose=False,
    dask_client=None,
):
    """
    Split target scans from measurement sets into per-coarse-channel measurement sets.

    Parameters
    ----------
    mslist : str or list of str
        Measurement sets, either comma separated in one string or as a
        sequence. Empty entries are ignored.
    metafits : str
        Metafits file
    workdir : str, optional
        Working directory for intermediate and output products. If empty,
        defaults to a ``workdir`` directory next to the first measurement set.
    datacolumn : str, optional
        Column of the MS to use for splitting. Default is "data".
    split_coarse_chans : list, optional
        Coarse channels to split. ``None``/empty selects all coarse channels.
    scan : int, optional
        Scan number (forwarded to ``split_target_scans``).
    time_window : float, optional
        Time window in seconds for a single time chunk. Set -1 to disable.
    time_interval : float, optional
        Time interval in seconds between two time chunks. Set -1 to disable.
    quack_timestamps : int, optional
        Number of timestamps ignored at the start and end of each scan.
        -1 to disable.
    freqres : float, optional
        Frequency resolution in MHz for spectral averaging. Set -1 to disable.
    timeres : float, optional
        Time resolution in seconds for time averaging. Set -1 to disable.
    prefix : str, optional
        Prefix for the output split MS files. Default is "targets".
    force_split : bool, optional
        Force to split even if a previous split completed
    only_disk : bool, optional
        Split only disk visible times
    cpu_frac : float, optional
        Fraction of available CPUs for the local dask cluster (absolute value,
        capped at 0.8).
    mem_frac : float, optional
        Fraction of available memory for the local dask cluster (absolute
        value, capped at 0.8).
    logfile : str or None, optional
        Log file watched by the remote logger; remote logging needs it.
    jobid : int, optional
        Job identifier (accepted for interface compatibility; not used here).
    start_remote_log : bool, optional
        If True, enables remote logging using credentials stored in workdir.
    verbose : bool, optional
        Verbose logs
    dask_client : dask.client, optional
        Dask client. If given it is used as is and not shut down here;
        otherwise a local cluster is created and torn down afterwards.

    Returns
    -------
    int
        Success message (0 on success, 1 on failure)
    int
        Expected splited ms
    int
        Succeeded splited ms
    """
    import shutil

    logger = get_logger_safe()
    if verbose:
        logger.setLevel(logging.DEBUG)
    cpu_frac = min(0.8, abs(cpu_frac))
    mem_frac = min(0.8, abs(mem_frac))
    if split_coarse_chans is None:
        split_coarse_chans = []

    # "".split(",") is [""], so drop empty entries; otherwise the emptiness
    # check below can never trigger and "" would be treated as an MS name.
    if isinstance(mslist, str):
        mslist = mslist.split(",")
    mslist = [str(ms).strip() for ms in mslist if str(ms).strip()]
    if len(mslist) == 0:
        logger.critical("Please provide a valid measurement set list.")
        return 1, 0, 0

    if workdir == "":
        workdir = os.path.dirname(os.path.abspath(mslist[0])) + "/workdir"
    os.makedirs(workdir, exist_ok=True)
    os.chdir(workdir)
    logger.debug(f"Current working directory: {os.getcwd()}")

    ############
    # Logger
    ############
    observer = None
    password_file = f"{workdir}/.jobname_password.npy"
    if start_remote_log and os.path.exists(password_file) and logfile is not None:
        time.sleep(5)
        # NOTE(review): allow_pickle=True unpickles this file; it must only
        # ever be written by the pipeline itself, never by untrusted users.
        jobname, password = np.load(password_file, allow_pickle=True)
        if os.path.exists(logfile):
            observer = init_logger(
                "do_target_split", logfile, jobname=jobname, password=password
            )
    if observer is None:
        logger.info(
            "Remote link or jobname is blank. Not transmiting to remote logger."
        )

    total_ncoarse = 0
    for msname in mslist:
        ms_coarse_chans = get_MWA_coarse_chan(msname)
        if len(split_coarse_chans) > 0:
            ms_coarse_chans = set(ms_coarse_chans) & set(split_coarse_chans)
        total_ncoarse += len(ms_coarse_chans)
    # Always expect at least one output; also keeps max_worker >= 2 below
    total_ncoarse = max(1, total_ncoarse)
    logger.debug(f"Total usable coarse channels: {total_ncoarse}")

    expected = total_ncoarse
    succeed = 0
    msg = 1
    dask_cluster = None
    dask_dir = None  # only set when a local cluster is created here
    if dask_client is None:
        dask_client, dask_cluster, dask_dir, nworker = get_local_dask_cluster(
            workdir,
            cpu_frac=cpu_frac,
            mem_frac=mem_frac,
            max_worker=total_ncoarse + 1,
        )
        if dask_client is None:
            logger.critical("Error occured in creating local cluster.")
            clean_shutdown(observer)
            return 1, expected, succeed
        scale_worker_and_wait(dask_cluster, dask_client, nworker)

    try:
        for banner in print_banner(
            "Starting spliting measurement sets.", no_print=True
        ).splitlines():
            logger.info(banner)

        ##################################
        # Parallel spliting
        ##################################
        client_info = dask_client.scheduler_info()["workers"]
        njobs = len(client_info)
        worker_mem_list = [w["memory_limit"] / 1024**3 for w in client_info.values()]
        mem_limit = round(min(worker_mem_list), 3) if worker_mem_list else 1
        # A missing or malformed OMP_NUM_THREADS falls back to one thread
        try:
            n_threads = int(os.environ.get("OMP_NUM_THREADS", "1"))
        except ValueError:
            n_threads = 1

        logger.info("#################################")
        logger.info(f"Total dask worker: {njobs}")
        logger.info(f"CPU per worker: {n_threads}")
        logger.info(f"Memory per worker: {mem_limit} GB")
        logger.info("#################################")

        msg, splited_mslist = split_target_scans(
            mslist,
            metafits,
            dask_client,
            workdir,
            float(timeres),
            float(freqres),
            datacolumn,
            split_coarse_chans=split_coarse_chans,
            time_window=float(time_window),
            time_interval=float(time_interval),
            quack_timestamps=int(quack_timestamps),
            force_split=force_split,
            scan=scan,
            prefix=prefix,
            only_disk=only_disk,
            n_threads=n_threads,
            logger=logger,
        )
        succeed = len(splited_mslist)
        logger.info(f"Total measurement sets: {len(mslist)}")
        logger.info(f"Total expected splited ms: {total_ncoarse}")
        logger.info(f"Total splited ms: {succeed}")
        if len(splited_mslist) == 0:
            logger.debug("No splited measurement sets.")
            msg = 1
        else:
            logger.debug("List of splited measurement sets:")
            logger.debug(f"{splited_mslist}")
            msg = 0
    except Exception:
        logger.exception("Exception occured in spliting.", exc_info=True)
        msg = 1
    finally:
        time.sleep(5)
        clean_shutdown(observer)
        for msname in mslist:
            drop_cache(msname)
        # Only tear down a cluster this function created itself
        if dask_cluster is not None:
            dask_client.shutdown()
            dask_client.close()
            dask_cluster.close()
        drop_cache(workdir)
        if dask_dir:
            shutil.rmtree(dask_dir, ignore_errors=True)

    if msg == 0:
        logger.info("All measurement sets are splited successfully.")
    else:
        logger.error("Error occured in spliting measurement sets.")
    return msg, expected, succeed
def cli():
    """
    Command line entry point: split measurement sets into coarse channels.

    Returns
    -------
    int
        Status returned by ``main`` (0 on success, 1 on failure), or 1 when
        invoked without arguments (the help text is printed to stderr).
    """
    parser = argparse.ArgumentParser(
        description="Split measurement set into coarse channels",
        formatter_class=SmartDefaultsHelpFormatter,
    )

    # Essential parameters
    basic_args = parser.add_argument_group(
        "###################\nEssential parameters\n###################"
    )
    basic_args.add_argument(
        "mslist",
        type=str,
        help="Name of measurement sets (required positional argument)",
    )
    basic_args.add_argument(
        "metafits",
        type=str,
        help="Metafits file (required positional argument)",
    )
    basic_args.add_argument(
        "--workdir",
        type=str,
        default="",
        help="Name of work directory",
    )

    # Advanced parameters
    adv_args = parser.add_argument_group(
        "###################\nAdvanced parameters\n###################"
    )
    adv_args.add_argument(
        "--datacolumn",
        type=str,
        default="data",
        help="Data column to split",
    )
    adv_args.add_argument(
        "--scan",
        type=int,
        default=1,
        help="Target scan to split",
    )
    adv_args.add_argument(
        "--time_window",
        type=float,
        default=-1,
        help="Time window in seconds of a single time chunk",
    )
    adv_args.add_argument(
        "--time_interval",
        type=float,
        default=-1,
        help="Time interval in seconds between two time chunks",
    )
    adv_args.add_argument(
        "--quack_timestamps",
        type=int,
        default=-1,
        help="Time stamps to ignore at the start and end of the each scan",
    )
    adv_args.add_argument(
        "--freqres",
        type=float,
        default=-1,
        help="Frequency to average in MHz",
        metavar="Float",
    )
    adv_args.add_argument(
        "--timeres",
        type=float,
        default=-1,
        help="Time bin to average in seconds",
        metavar="Float",
    )
    adv_args.add_argument(
        "--prefix",
        type=str,
        default="targets",
        help="Splited ms prefix name",
    )
    adv_args.add_argument("--force_split", action="store_true", help="Force to split")
    # main() supports only_disk; expose it so CLI users can reach it too
    adv_args.add_argument(
        "--only_disk", action="store_true", help="Split only disk visible times"
    )
    adv_args.add_argument("--verbose", action="store_true", help="Verbose logs")
    adv_args.add_argument("--jobid", type=int, default=0, help="Job ID")

    # Resource management parameters
    hard_args = parser.add_argument_group(
        "###################\nHardware resource management parameters\n###################"
    )
    hard_args.add_argument(
        "--cpu_frac",
        type=float,
        default=0.8,
        help="CPU fraction to use",
        metavar="Float",
    )
    hard_args.add_argument(
        "--mem_frac",
        type=float,
        default=0.8,
        help="Memory fraction to use",
        metavar="Float",
    )

    # No arguments at all: show usage instead of an argparse error
    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)
        return 1

    args = parser.parse_args()
    msg, _, _ = main(
        args.mslist,
        args.metafits,
        workdir=args.workdir,
        datacolumn=args.datacolumn,
        scan=args.scan,
        time_window=args.time_window,
        time_interval=args.time_interval,
        quack_timestamps=args.quack_timestamps,
        force_split=args.force_split,
        only_disk=args.only_disk,
        freqres=args.freqres,
        timeres=args.timeres,
        prefix=args.prefix,
        cpu_frac=args.cpu_frac,
        mem_frac=args.mem_frac,
        jobid=args.jobid,
        verbose=args.verbose,
    )
    return msg