import logging
import numpy as np
import argparse
import time
import sys
import os
from casatools import msmetadata
from dask import delayed
from astropy.io import fits
from paircars.utils.basic_utils import print_banner, capture_all_output
from paircars.utils.casatasks import single_mstransform
from paircars.utils.logger_utils import (
SmartDefaultsHelpFormatter,
clean_shutdown,
init_logger,
get_logger_safe,
)
from paircars.utils.ms_metadata import get_timeranges
from paircars.utils.mwa_utils import (
get_MWA_coarse_bands,
get_MWA_coarse_chan,
)
from paircars.utils.proc_manage_utils import (
scale_worker_and_wait,
get_local_dask_cluster,
)
from paircars.utils.resource_utils import drop_cache
# Raise the thresholds of two third-party loggers: "distributed" only emits
# ERROR and above, "tornado.application" only CRITICAL.
for _logger_name, _threshold in (
    ("distributed", logging.ERROR),
    ("tornado.application", logging.CRITICAL),
):
    logging.getLogger(_logger_name).setLevel(_threshold)
del _logger_name, _threshold
[docs]
def chanlist_to_str(lst):
    """
    Convert a collection of integer channel numbers into a compact range string.

    Runs of consecutive values are written as ``start~end``, isolated values
    as a single number, and the pieces are joined with ``;``.

    Parameters
    ----------
    lst : iterable of int
        Channel numbers in any order; duplicates are ignored

    Returns
    -------
    str
        Range string (e.g. ``"1~3;5;7~8"``); empty string for empty input
    """
    # De-duplicate before sorting: a repeated value would otherwise close the
    # current run early and produce overlapping ranges such as "1;1~2".
    chans = sorted(set(lst))
    if not chans:
        return ""
    ranges = []
    start = prev = chans[0]
    for chan in chans[1:]:
        if chan != prev + 1:
            # Gap found: flush the run that just ended and start a new one.
            ranges.append(f"{start}~{prev}" if prev > start else f"{start}")
            start = chan
        prev = chan
    # Flush the final run.
    ranges.append(f"{start}~{prev}" if prev > start else f"{start}")
    return ";".join(ranges)
[docs]
def split_target_scans(
    mslist,
    metafits,
    dask_client,
    workdir,
    timeres,
    freqres,
    datacolumn,
    split_coarse_chans=None,
    scan=1,
    prefix="targets",
    time_interval=-1,
    time_window=-1,
    quack_timestamps=-1,
    force_split=False,
    only_disk=False,
    n_threads=-1,
    logger=None,
):
    """
    Split target scans

    For every measurement set and every selected coarse channel one delayed
    split task is queued; all tasks are then computed together through
    ``dask_client``. An output that already contains a ``.splited`` marker
    file is reused unless ``force_split`` is set.

    Parameters
    ----------
    mslist : list
        Measurement set list
    metafits : str
        Metafits file (``GPSTIME`` and ``MODE`` header keys are read)
    dask_client : dask.client
        Dask client
    workdir : str
        Work directory (the process changes into it)
    timeres : float
        Time resolution in seconds; if not positive an empty ``timebin`` is passed
    freqres : float
        Frequency resolution in MHz; if not positive a channel width of 1 is used
    datacolumn : str
        Data column to split
    split_coarse_chans : list, optional
        Split coarse channels; ``None`` or empty selects all coarse channels of each ms
    scan : int, optional
        Scan to split (currently not referenced in the function body)
    prefix : str, optional
        Splited ms prefix
    time_interval : float
        Time interval in seconds
    time_window : float
        Time window in seconds
    quack_timestamps : int, optional
        Number of timestamps ignored at the start and end of each scan
    force_split : bool, optional
        Force split
    only_disk : bool, optional
        Split only disk
    n_threads : int, optional
        Number of threads to use (values below 1 are raised to 1)
    logger : logging.Logger, optional
        Logger to use; if None, ``get_logger_safe()`` is called

    Returns
    -------
    int
        Success message (0 on success, 1 on failure)
    list
        Splited ms list (empty on failure)
    """
    import subprocess

    if logger is None:
        logger = get_logger_safe()
    if split_coarse_chans is None:
        split_coarse_chans = []
    n_threads = max(1, n_threads)
    if len(mslist) == 0:
        logger.critical("Please provide a valid measurement set list.")
        return 1, []
    # Bind ``msname`` before the ``try`` so the exception handler can always
    # format its message, even when the failure happens before the loop starts.
    msname = mslist[0]
    try:
        os.chdir(workdir)
        logger.debug(f"Current working directory: {os.getcwd()}")
        #######################################
        # Extracting time frequency information
        #######################################
        header = fits.getheader(metafits)
        obsid = header["GPSTIME"]
        mode = header["MODE"]
        # The central channel is flagged for every mode that is not MWAX.
        flag_central_chan = "MWAX" not in mode
        logger.debug(f"Flag central channel: {flag_central_chan} for {mode}")
        # Time averaging does not depend on the measurement set.
        if timeres > 0:  # Image resolution is in seconds
            timebin = str(timeres) + "s"
        else:
            timebin = ""
        tasks = []
        splited_ms_list = []
        for msname in mslist:
            msmd = msmetadata()
            msmd.open(msname)
            try:
                chanres = msmd.chanres(0, unit="MHz")[0]
            finally:
                # Release the metadata handle even if the query fails.
                msmd.close()
            if freqres > 0:  # Image resolution is in MHz
                chanwidth = max(1, int(freqres / chanres))
            else:
                chanwidth = 1
            #############################
            # Making spectral chunks
            #############################
            coarse_channel_bands = get_MWA_coarse_bands(
                msname, flag_central_chan=flag_central_chan
            )
            coarse_chans = get_MWA_coarse_chan(msname)
            logger.debug(f"Coarse channels for {msname} are: {coarse_chans}")
            if len(split_coarse_chans) == 0:
                use_coarse_chans = coarse_chans
            else:
                use_coarse_chans = split_coarse_chans
            logger.debug(f"Using coarse channels for {msname} are: {use_coarse_chans}")
            # (coarse channel label, spw selection string) for each selected band.
            selections = []
            for c, band in enumerate(coarse_channel_bands):
                coarse_chan = coarse_chans[c]
                if coarse_chan in use_coarse_chans:
                    # Third element of a band entry holds the channels to keep.
                    good_chans = [f"{i}" for i in band[2]]
                    selections.append(
                        (f"{coarse_chan}", f"0:{';'.join(good_chans)}")
                    )
            timerange_list = get_timeranges(
                msname,
                time_interval,
                time_window,
                only_disk=only_disk,
                quack_timestamps=quack_timestamps,
            )
            timerange = ",".join(timerange_list)
            for coarse_chan, good_spw in selections:
                outputvis = f"{workdir}/{prefix}_{obsid}_ch_{coarse_chan}.ms"
                if os.path.exists(f"{outputvis}/.splited") and force_split is False:
                    logger.info(f"{outputvis} is already splited successfully.")
                    splited_ms_list.append(outputvis)
                    continue
                # Remove stale outputs without going through a shell, so paths
                # containing spaces or shell metacharacters are handled safely.
                if os.path.exists(outputvis):
                    logger.debug(f"Deleteing pre-existing output ms: {outputvis}")
                    subprocess.run(["rm", "-rf", outputvis], check=False)
                if os.path.exists(f"{outputvis}.flagversions"):
                    logger.debug(
                        f"Deleteing pre-existing output ms flags: {outputvis}.flagversions"
                    )
                    subprocess.run(
                        ["rm", "-rf", f"{outputvis}.flagversions"], check=False
                    )
                logger.debug("Spliting parameters:")
                logger.debug(
                    f"Channel width: {chanwidth}, timebin: {timebin}, datacolumn: {datacolumn}, spectral window: {good_spw}, time range: {timerange}"
                )
                # NOTE(review): `single_mstransform_wrapper` is neither defined
                # nor imported in the visible part of this module (only
                # `single_mstransform` and `capture_all_output` are imported).
                # Confirm it exists; otherwise this raises NameError, which the
                # handler below turns into a failure return.
                tasks.append(
                    delayed(single_mstransform_wrapper)(
                        msname=msname,
                        outputms=outputvis,
                        width=chanwidth,
                        timebin=timebin,
                        datacolumn=datacolumn,
                        spw=good_spw,
                        corr="",
                        timerange=timerange,
                        n_threads=n_threads,
                    )
                )
        future = dask_client.compute(tasks)
        result_wrapper = dask_client.gather(future)
        result = []
        # Each gathered item is indexed as [0], [1], [2]; presumably
        # (output ms, captured stdout, captured stderr) -- TODO confirm.
        for r in result_wrapper:
            result.append(r[0])
            logger.debug("================")
            logger.debug(f"Worker log for: {os.path.basename(r[0])}")
            logger.debug("================")
            for line in r[1].splitlines():
                logger.debug(line)
            for line in r[2].splitlines():
                logger.debug(line)
        splited_ms_list = splited_ms_list + result
        if len(splited_ms_list) == 0:
            logger.error(f"Spliting of measurement set: {msname} is unsuccessful.")
            return 1, []
        logger.info(f"Spliting of measurement set: {msname} is done successfully.")
        for splited_ms in splited_ms_list:
            drop_cache(splited_ms)
        return 0, splited_ms_list
    except Exception:
        logger.exception(
            f"Spliting of measurement set: {msname} is unsuccessful.", exc_info=True
        )
        return 1, []
[docs]
def main(
    mslist,
    metafits,
    workdir="",
    datacolumn="data",
    split_coarse_chans=None,
    scan=1,
    time_window=-1,
    time_interval=-1,
    quack_timestamps=-1,
    freqres=-1,
    timeres=-1,
    prefix="targets",
    force_split=False,
    only_disk=False,
    cpu_frac=0.8,
    mem_frac=0.8,
    logfile=None,
    jobid=0,
    start_remote_log=False,
    verbose=False,
    dask_client=None,
):
    """
    Split target scans from a measurement set into smaller chunks for parallel processing.

    Parameters
    ----------
    mslist : str
        Measurement sets (comma separated). Blank entries are ignored.
    metafits : str
        Metafits file
    workdir : str, optional
        Working directory for intermediate and output products. If empty, defaults to
        ``<directory containing the first measurement set>/workdir``.
    datacolumn : str, optional
        Column of the MS to use for splitting (e.g., "DATA", "CORRECTED"). Default is "data".
    split_coarse_chans : list, optional
        Split coarse channels. ``None`` or empty selects all coarse channels.
    scan : int, optional
        Scan numbers to split.
    time_window : float, optional
        Time window in seconds for a single time chunk. Set -1 to disable. Default is -1.
    time_interval : float, optional
        Time interval in seconds between two time chunks. Set -1 to disable. Default is -1.
    quack_timestamps : int, optional
        Number of timestamps to flag at the beginning and end of each scan ("quack"). -1 to disable. Default is -1.
    freqres : float, optional
        Frequency resolution in MHz for spectral averaging. Set -1 to disable. Default is -1.
    timeres : float, optional
        Time resolution in seconds for time averaging. Set -1 to disable. Default is -1.
    prefix : str, optional
        Prefix for the output split MS files. Default is "targets".
    force_split : bool, optional
        Force to split
    only_disk : bool, optional
        Split only disk visible times
    cpu_frac : float, optional
        Fraction of available CPUs to allocate (absolute value, capped at 0.8). Default is 0.8.
    mem_frac : float, optional
        Fraction of available memory to allocate (absolute value, capped at 0.8). Default is 0.8.
    logfile : str or None, optional
        Path to log file, used for remote logging. Default is None.
    jobid : int, optional
        Job identifier (currently not referenced in this function). Default is 0.
    start_remote_log : bool, optional
        If True, enables remote logging using credentials stored in workdir. Default is False.
    verbose : bool, optional
        Verbose logs
    dask_client : dask.client, optional
        Dask client. If None, a local cluster is created here and shut down on exit.

    Returns
    -------
    int
        Success message (0 on success, 1 on failure)
    int
        Expected splited ms
    int
        Succeeded splited ms
    """
    import subprocess

    logger = get_logger_safe()
    if verbose:
        logger.setLevel(logging.DEBUG)
    if split_coarse_chans is None:
        split_coarse_chans = []
    cpu_frac = min(0.8, abs(cpu_frac))
    mem_frac = min(0.8, abs(mem_frac))
    # ``str.split`` never returns an empty list, so blank entries are dropped
    # explicitly; otherwise the empty-input guard below could never trigger.
    mslist = [ms.strip() for ms in mslist.split(",") if ms.strip()]
    if len(mslist) == 0:
        logger.critical("Please provide a valid measurement set list.")
        return 1, 0, 0
    if workdir == "":
        workdir = os.path.dirname(os.path.abspath(mslist[0])) + "/workdir"
    os.makedirs(workdir, exist_ok=True)
    os.chdir(workdir)
    logger.debug(f"Current working directory: {os.getcwd()}")
    ############
    # Logger
    ############
    observer = None
    if (
        start_remote_log
        and os.path.exists(f"{workdir}/.jobname_password.npy")
        and logfile is not None
    ):
        time.sleep(5)
        # NOTE(security): allow_pickle=True unpickles the credential file, which
        # can execute arbitrary code; only safe while workdir is trusted.
        jobname, password = np.load(
            f"{workdir}/.jobname_password.npy", allow_pickle=True
        )
        if os.path.exists(logfile):
            observer = init_logger(
                "do_target_split", logfile, jobname=jobname, password=password
            )
    if observer is None:
        logger.info(
            "Remote link or jobname is blank. Not transmiting to remote logger."
        )
    # Count the coarse channels that will be split, over all measurement sets.
    total_ncoarse = 0
    for msname in mslist:
        ms_coarse_chans = get_MWA_coarse_chan(msname)
        if len(split_coarse_chans) > 0:
            ms_coarse_chans = list(set(ms_coarse_chans) & set(split_coarse_chans))
        total_ncoarse += len(ms_coarse_chans)
    total_ncoarse = max(1, total_ncoarse)
    logger.debug(f"Total usable coarse channels: {total_ncoarse}")
    expected = total_ncoarse
    succeed = 0
    dask_cluster = None
    dask_dir = None
    if dask_client is None:
        dask_client, dask_cluster, dask_dir, nworker = get_local_dask_cluster(
            workdir,
            cpu_frac=cpu_frac,
            mem_frac=mem_frac,
            max_worker=total_ncoarse + 1,
        )
        if dask_client is None:
            logger.critical("Error occured in creating local cluster.")
            return 1, expected, succeed
        scale_worker_and_wait(dask_cluster, dask_client, nworker)
    try:
        for banner in print_banner(
            "Starting spliting measurement sets.", no_print=True
        ).splitlines():
            logger.info(banner)
        ##################################
        # Parallel spliting
        ##################################
        client_info = dask_client.scheduler_info()["workers"]
        njobs = len(client_info)
        # Per-worker memory limits converted from bytes to GB (1024**3).
        worker_mem_list = [
            w["memory_limit"] / 1024**3 for w in client_info.values()
        ]
        if len(worker_mem_list) > 0:
            mem_limit = round(min(worker_mem_list), 3)
        else:
            mem_limit = 1
        n_threads = os.environ.get("OMP_NUM_THREADS")
        n_threads = int(n_threads) if n_threads is not None else 1
        logger.info("#################################")
        logger.info(f"Total dask worker: {njobs}")
        logger.info(f"CPU per worker: {n_threads}")
        logger.info(f"Memory per worker: {mem_limit} GB")
        logger.info("#################################")
        msg, splited_mslist = split_target_scans(
            mslist,
            metafits,
            dask_client,
            workdir,
            float(timeres),
            float(freqres),
            datacolumn,
            split_coarse_chans=split_coarse_chans,
            time_window=float(time_window),
            time_interval=float(time_interval),
            quack_timestamps=int(quack_timestamps),
            force_split=force_split,
            scan=scan,
            prefix=prefix,
            only_disk=only_disk,
            n_threads=n_threads,
            logger=logger,
        )
        succeed = len(splited_mslist)
        logger.info(f"Total measurement sets: {len(mslist)}")
        logger.info(f"Total expected splited ms: {total_ncoarse}")
        logger.info(f"Total splited ms: {succeed}")
        # The final status depends only on whether any split ms was produced.
        if len(splited_mslist) == 0:
            logger.debug("No splited measurement sets.")
            msg = 1
        else:
            logger.debug("List of splited measurement sets:")
            logger.debug(f"{splited_mslist}")
            msg = 0
    except Exception:
        logger.exception("Exception", exc_info=True)
        msg = 1
    finally:
        time.sleep(5)
        clean_shutdown(observer)
        for msname in mslist:
            drop_cache(msname)
        # Only tear down the cluster (and its scratch directory) when it was
        # created in this function; a caller-supplied client is left running.
        if dask_cluster is not None:
            dask_client.shutdown()
            dask_client.close()
            dask_cluster.close()
            drop_cache(workdir)
            if dask_dir:
                # Argument list instead of a shell string, so unusual paths
                # cannot be misparsed by the shell.
                subprocess.run(["rm", "-rf", str(dask_dir)], check=False)
    if msg == 0:
        logger.info("All measurement sets are splited successfully.")
    else:
        logger.error("Error occured in spliting measurement sets.")
    return msg, expected, succeed
[docs]
def cli():
    """
    Command-line entry point: build the argument parser, parse ``sys.argv`` and call ``main``.

    Returns
    -------
    int
        1 if invoked without any argument (the help text is printed to stderr),
        otherwise the success message returned by ``main``
    """
    parser = argparse.ArgumentParser(
        description="Split measurement set into coarse channels",
        formatter_class=SmartDefaultsHelpFormatter,
    )
    # (group title, [(argument name, add_argument keyword options), ...]),
    # listed in the order in which groups and arguments are registered.
    layout = [
        (
            "###################\nEssential parameters\n###################",
            [
                (
                    "mslist",
                    dict(
                        type=str,
                        help="Name of measurement sets (required positional argument)",
                    ),
                ),
                (
                    "metafits",
                    dict(
                        type=str,
                        help="Metafits file (required positional argument)",
                    ),
                ),
                (
                    "--workdir",
                    dict(type=str, default="", help="Name of work directory"),
                ),
            ],
        ),
        (
            "###################\nAdvanced parameters\n###################",
            [
                (
                    "--datacolumn",
                    dict(type=str, default="data", help="Data column to split"),
                ),
                (
                    "--scan",
                    dict(type=int, default=1, help="Target scan to split"),
                ),
                (
                    "--time_window",
                    dict(
                        type=float,
                        default=-1,
                        help="Time window in seconds of a single time chunk",
                    ),
                ),
                (
                    "--time_interval",
                    dict(
                        type=float,
                        default=-1,
                        help="Time interval in seconds between two time chunks",
                    ),
                ),
                (
                    "--quack_timestamps",
                    dict(
                        type=int,
                        default=-1,
                        help="Time stamps to ignore at the start and end of the each scan",
                    ),
                ),
                (
                    "--freqres",
                    dict(
                        type=float,
                        default=-1,
                        help="Frequency to average in MHz",
                        metavar="Float",
                    ),
                ),
                (
                    "--timeres",
                    dict(
                        type=float,
                        default=-1,
                        help="Time bin to average in seconds",
                        metavar="Float",
                    ),
                ),
                (
                    "--prefix",
                    dict(type=str, default="targets", help="Splited ms prefix name"),
                ),
                (
                    "--force_split",
                    dict(action="store_true", help="Force to split"),
                ),
                (
                    "--verbose",
                    dict(action="store_true", help="Verbose logs"),
                ),
                (
                    "--jobid",
                    dict(type=int, default=0, help="Job ID"),
                ),
            ],
        ),
        (
            "###################\nHardware resource management parameters\n###################",
            [
                (
                    "--cpu_frac",
                    dict(
                        type=float,
                        default=0.8,
                        help="CPU fraction to use",
                        metavar="Float",
                    ),
                ),
                (
                    "--mem_frac",
                    dict(
                        type=float,
                        default=0.8,
                        help="Memory fraction to use",
                        metavar="Float",
                    ),
                ),
            ],
        ),
    ]
    for group_title, argument_specs in layout:
        group = parser.add_argument_group(group_title)
        for argument_name, options in argument_specs:
            group.add_argument(argument_name, **options)

    # Called with no arguments at all: show the help text and report failure.
    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)
        return 1

    opts = parser.parse_args()
    status, _, _ = main(
        opts.mslist,
        opts.metafits,
        workdir=opts.workdir,
        datacolumn=opts.datacolumn,
        scan=opts.scan,
        time_window=opts.time_window,
        time_interval=opts.time_interval,
        quack_timestamps=opts.quack_timestamps,
        force_split=opts.force_split,
        freqres=opts.freqres,
        timeres=opts.timeres,
        prefix=opts.prefix,
        cpu_frac=opts.cpu_frac,
        mem_frac=opts.mem_frac,
        jobid=opts.jobid,
        verbose=opts.verbose,
    )
    return status