import psutil
import numpy as np
import os
from casatools import msmetadata, ms as casamstool, table, measures
from .basic_utils import timestamp_to_mjdsec, mjdsec_to_timestamp
from .selfcal_utils import determine_disk_visibility
from .resource_utils import limit_threads
##########################
# Measurement set metadata
##########################
[docs]
def get_phasecenter(msname, fieldID=0):
    """
    Get the phase center of one field of the measurement set

    Parameters
    ----------
    msname : str
        Name of the measurement set
    fieldID : int, optional
        Zero based field ID

    Returns
    -------
    float
        RA in degree, wrapped with modulo 360 and rounded to 5 decimals
    float
        DEC in degree, wrapped with modulo 360 and rounded to 5 decimals
    """
    msmd = msmetadata()
    msmd.open(msname)
    try:
        phasecenter = msmd.phasecenter(fieldID)
    finally:
        # Release the tool even when the field lookup fails
        msmd.close()
        msmd.done()
    radeg = np.rad2deg(phasecenter["m0"]["value"]) % 360
    # NOTE(review): the modulo also maps a negative declination to 360 + DEC;
    # kept unchanged because callers may rely on it -- confirm this is intended
    decdeg = np.rad2deg(phasecenter["m1"]["value"]) % 360
    return round(radeg, 5), round(decdeg, 5)
[docs]
def get_timeranges(
    msname,
    time_interval,
    time_window,
    quack_timestamps=-1,
    only_disk=False,
):
    """
    Get time ranges for a scan with certain time intervals

    Parameters
    ----------
    msname : str
        Name of the measurement set
    time_interval : float
        Time interval in seconds between two time chunks
    time_window : float
        Time window in seconds of a single time chunk
    quack_timestamps : int, optional
        Number of timestamps ignored at the start and end of each scan
        (only applied when more than ``2 * quack_timestamps + 3`` timestamps
        exist; one extra timestamp is dropped at each end)
    only_disk : bool, optional
        Whether select timestamps with disk visibilities

    Returns
    -------
    list
        List of time ranges. Each entry is a ``start~end`` string, or a single
        timestamp when start and end coincide.
    """
    msmd = msmetadata()
    msmd.open(msname)
    try:
        times = msmd.timesforspws(0)
    finally:
        msmd.close()
        msmd.done()
    time_ranges = []
    if len(times) == 1:
        time_ranges.append(mjdsec_to_timestamp(times[0], str_format=1))
        return time_ranges
    if (
        quack_timestamps > 0 and len(times) > 2 * quack_timestamps + 3
    ):  # At least 3 timestamps remain after quack flagging
        quack_timestamps += 1
        times = times[quack_timestamps:-quack_timestamps]
    # Without disk filtering every timestamp may act as a chunk boundary.
    # Setting this default avoids an undefined name on the chunked path
    # when only_disk is False.
    filtered_timestamps = times
    if only_disk:
        _, _, disk_timestamps = determine_disk_visibility(msname)
        if len(disk_timestamps) == 0:
            print(f"No timestamp with disk visibility for ms: {msname}.")
        elif len(disk_timestamps) >= len(times):
            print(f"All timestamps have disk visibilties for ms: {msname}.")
        else:
            # presumably disk_timestamps holds indices into times -- verify
            # against determine_disk_visibility
            filtered_timestamps = [
                times[t] for t in range(len(times)) if t in disk_timestamps
            ]
            print(
                f"{len(filtered_timestamps)} number of timestamps among {len(times)} have disk visibiltiies for ms: {msname}."
            )
            if len(filtered_timestamps) == 0:
                filtered_timestamps = times
    filtered_timestamps = np.asarray(filtered_timestamps)
    start_time = min(times)
    end_time = max(times)
    if only_disk is False:
        if time_interval < 0 or time_window < 0 or time_interval <= time_window:
            # No usable chunking requested: return the full span as one range
            t = (
                mjdsec_to_timestamp(start_time, str_format=1)
                + "~"
                + mjdsec_to_timestamp(end_time, str_format=1)
            )
            time_ranges.append(t)
            return time_ranges
    timeres = times[1] - times[0]
    ntime_chunk = max(1, int(time_interval / timeres))
    ntime = int(time_window / timeres)
    for i in range(0, len(times), ntime_chunk):
        start_time = times[i]
        if start_time not in filtered_timestamps:
            # Snap to the closest allowed timestamp
            nearpos = np.argmin(abs(start_time - filtered_timestamps))
            start_time = filtered_timestamps[nearpos]
        try:
            end_time = times[i + ntime]
        except IndexError:
            # Window runs past the last timestamp
            end_time = times[-1]
        if end_time not in filtered_timestamps:
            nearpos = np.argmin(abs(end_time - filtered_timestamps))
            end_time = filtered_timestamps[nearpos]
        # Snapping can swap the order, so always emit earliest~latest
        first, last = sorted((start_time, end_time))
        if first == last:
            time_ranges.append(f"{mjdsec_to_timestamp(first, str_format=1)}")
        else:
            time_ranges.append(
                f"{mjdsec_to_timestamp(first, str_format=1)}~{mjdsec_to_timestamp(last, str_format=1)}"
            )
    return time_ranges
[docs]
def calc_fractional_bandwidth(msname):
    """
    Calculate fractional bandwidth of spectral window 0

    Parameters
    ----------
    msname : str
        Name of measurement set

    Returns
    -------
    float
        Fractional bandwidth in percentage, i.e. the span of the channel
        frequencies divided by the mean frequency, times 100
    """
    msmd = msmetadata()
    msmd.open(msname)
    try:
        freqs = msmd.chanfreqs(0)
        mean_freq = msmd.meanfreq(0)
    finally:
        # Close the tool even when a metadata query fails
        msmd.close()
    bw = max(freqs) - min(freqs)
    return bw / mean_freq * 100.0
[docs]
def baseline_names(msname):
    """
    Get baseline names

    Parameters
    ----------
    msname : str
        Measurement set name

    Returns
    -------
    list
        Sorted list of unique baseline names in ``ant1&&ant2`` form, built
        from the antenna1/antenna2 values read from the measurement set
    """
    mstool = casamstool()
    mstool.open(msname)
    try:
        ants = mstool.getdata(["antenna1", "antenna2"])
    finally:
        # Close the tool even when reading the antenna columns fails
        mstool.close()
    baseline_ids = set(zip(ants["antenna1"], ants["antenna2"]))
    return [str(ant1) + "&&" + str(ant2) for ant1, ant2 in sorted(baseline_ids)]
[docs]
def get_ms_size(msname, only_autocorr=False):
    """
    Get measurement set total size on-disk
    (Note: it could be smaller than actual data size, because of data compression)

    Parameters
    ----------
    msname : str
        Measurement set name
    only_autocorr : bool, optional
        Only auto-correlation. The total is then scaled by
        ``nant / (nant * nant / 2)``.

    Returns
    -------
    float
        Size in GB
    """
    total_size = 0
    for dirpath, _dirnames, filenames in os.walk(msname):
        for fname in filenames:
            try:
                total_size += os.path.getsize(os.path.join(dirpath, fname))
            except OSError:
                # Dangling symlink or file removed during the walk: there is
                # nothing on disk to count, so skip it instead of aborting
                continue
    if only_autocorr:
        msmd = msmetadata()
        msmd.open(msname)
        try:
            nant = msmd.nantennas()
        finally:
            msmd.close()
        all_baselines = (nant * nant) / 2
        total_size /= all_baselines
        total_size *= nant
    return total_size / (1024**3)  # in GB
[docs]
def get_column_size(msname, only_autocorr=False):
    """
    Get datacolumn size
    (Note: this is true datasize in memory)

    Parameters
    ----------
    msname : str
        Measurement set
    only_autocorr : bool, optional
        Only auto-correlations. The size is then scaled by
        ``nant / (nant * nant / 2)``.

    Returns
    -------
    float
        A single datacolumn data size in GB
    """
    msmd = msmetadata()
    msmd.open(msname)
    try:
        nrow = int(msmd.nrows())
        nchan = msmd.nchan(0)
        npol = msmd.ncorrforpol()[0]
        nant = msmd.nantennas()
    finally:
        # Close the tool even when a metadata query fails
        msmd.close()
    # 16 bytes per element (presumably complex128 visibilities -- confirm)
    datasize = nrow * nchan * npol * 16 / (1024.0**3)
    if only_autocorr:
        all_baselines = (nant * nant) / 2
        datasize = datasize / all_baselines * nant
    return datasize
[docs]
def get_ms_scan_size(msname, scan, only_autocorr=False):
    """
    Get measurement set scan size

    The size of one data column is split in proportion to the number of rows
    that belong to the scan.

    Parameters
    ----------
    msname : str
        Measurement set
    scan : int
        Scan number
    only_autocorr : bool, optional
        Only for auto-correlations

    Returns
    -------
    float
        Size in GB (0.0 when the measurement set has no rows)
    """
    tb = table()
    tb.open(msname)
    try:
        nrow = tb.nrows()
    finally:
        tb.close()
    mstool = casamstool()
    mstool.open(msname)
    try:
        mstool.select({"scan_number": int(scan)})
        scan_nrow = mstool.nrow(True)
    finally:
        mstool.close()
    if nrow == 0:
        # Empty table: avoid dividing by zero below
        return 0.0
    ms_size = get_column_size(msname, only_autocorr=only_autocorr)
    return scan_nrow * (ms_size / nrow)
[docs]
def get_chunk_size(msname, mem_limit=-1, only_autocorr=False):
    """
    Get number of chunks needed to keep a data column within a memory limit

    Parameters
    ----------
    msname : str
        Measurement set
    mem_limit : int, optional
        Memory limit in GB. Any non-positive value (default -1) means
        "use the currently available system memory".
    only_autocorr : bool, optional
        Only auto-correlation

    Returns
    -------
    int
        Number of chunks (at least 1)
    """
    import math

    if mem_limit <= 0:
        mem_limit = psutil.virtual_memory().available / 1024**3  # In GB
    col_size = get_column_size(msname, only_autocorr=only_autocorr)
    # Round up: flooring would leave chunks larger than the memory limit
    return max(1, math.ceil(col_size / mem_limit))
[docs]
def check_datacolumn_valid(msname, datacolumn="DATA"):
    """
    Check whether a data column exists and holds non-zero values

    Only the first row of the column is inspected.

    Parameters
    ----------
    msname : str
        Measurement set
    datacolumn : str, optional
        Data column string in table (e.g., DATA, CORRECTED_DATA, MODEL_DATA, FLAG, WEIGHT, WEIGHT_SPECTRUM, SIGMA, SIGMA_SPECTRUM)

    Returns
    -------
    bool
        True if the column is present and its first row is non-empty and not
        all zero; False otherwise, including when the table cannot be read
    """
    handle = table()
    ms_path = os.path.abspath(msname.rstrip("/"))
    try:
        handle.open(ms_path)
        if datacolumn not in handle.colnames():
            return False
        first_row = handle.getcol(datacolumn, startrow=0, nrow=1)
        if first_row is None or first_row.size == 0:
            return False
        return not (first_row == 0).all()
    except Exception:
        # Any failure while opening or reading counts as "not valid"
        return False
    finally:
        try:
            handle.close()
        except Exception:
            pass
[docs]
def get_bad_ants(msname="", fieldnames=None, n_threads=-1):
    """
    Get bad antennas

    For every field the median auto-correlation statistic of each antenna is
    taken from ``visstat``. An antenna is bad for that field when its value
    is more than 5 standard deviations below the mean over antennas. Only
    antennas that are bad in all fields are returned.

    Parameters
    ----------
    msname : str
        Name of the ms
    fieldnames : list, optional
        Fluxcal field names
    n_threads : int, optional
        Thread count passed to ``limit_threads`` (values below 1 become 1)

    Returns
    -------
    list
        Bad antenna list (sorted antenna indices)
    str
        Bad antenna string (comma separated)
    """
    n_threads = max(1, n_threads)
    limit_threads(n_threads=n_threads)
    from casatasks import visstat

    if fieldnames is None or len(fieldnames) == 0:
        print("Provide field names.")
        return [], ""
    # Make the path absolute BEFORE changing directory, otherwise a relative
    # msname would no longer point at the measurement set
    msname = os.path.abspath(msname.rstrip("/"))
    os.chdir(os.path.dirname(msname))
    msmd = msmetadata()
    msmd.open(msname)
    try:
        nant = msmd.nantennas()
    finally:
        msmd.close()
        msmd.done()
    all_field_bad_ants = []
    for field in fieldnames:
        ant_medians = np.array(
            [
                visstat(
                    vis=msname,
                    field=str(field),
                    uvrange="0lambda",
                    antenna=str(ant) + "&&" + str(ant),
                    useflags=False,
                )["DATA_DESC_ID=0"]["median"]
                for ant in range(nant)
            ]
        )
        threshold = np.nanmean(ant_medians) - (5 * np.nanstd(ant_medians))
        all_field_bad_ants.append(set(np.where(ant_medians < threshold)[0].tolist()))
    # Keep only antennas flagged as bad in every field; sort for a stable result
    bad_ants = sorted(set.intersection(*all_field_bad_ants))
    bad_ants_str = ",".join(str(i) for i in bad_ants)
    return bad_ants, bad_ants_str
[docs]
def get_common_spw(spw1, spw2):
    """
    Return common spectral windows in merged CASA string format.

    Both inputs must carry explicit channel selections. Several channel
    ranges of one spectral window are separated by ``;`` and several spectral
    windows by ``,`` (e.g. ``0:5~10;20~30,1:0~7``).

    Parameters
    ----------
    spw1 : str
        First spectral window (0:xx~yy)
    spw2 : str
        Second spectral window (0:xx1~yy1)

    Returns
    -------
    str
        Merged spectral window holding the channels present in both inputs,
        or an empty string when they share no channel
    """
    from itertools import groupby
    from collections import defaultdict

    def to_set(s):
        # Expand a selection string into a set of (spw, channel) pairs
        out = set()
        for entry in s.split(","):
            cur = None  # every comma-separated entry must name its own spw
            for part in entry.split(";"):
                if ":" in part:
                    cur, rng = part.split(":")
                else:
                    rng = part
                cur = int(cur)
                a, *b = map(int, rng.split("~"))
                out.update((cur, i) for i in range(a, (b[0] if b else a) + 1))
        return out

    def to_str(pairs):
        # Collapse (spw, channel) pairs back into "spw:a~b;c~d" groups
        spw_dict = defaultdict(list)
        for spw, ch in sorted(pairs):
            spw_dict[spw].append(ch)
        result = []
        for spw, chans in spw_dict.items():
            ranges = []
            # Consecutive channels share the same (value - position) offset
            for _, g in groupby(enumerate(chans), lambda x: x[1] - x[0]):
                grp = list(g)
                a, b = grp[0][1], grp[-1][1]
                ranges.append(f"{a}" if a == b else f"{a}~{b}")
            # Prefix with the actual spw id rather than a hard-coded 0
            result.append(f"{spw}:" + ";".join(ranges))
        return ",".join(result)

    return to_str(to_set(spw1) & to_set(spw2))
[docs]
def scans_in_timerange(msname="", timerange=""):
    """
    Get scans in the given timerange

    Parameters
    ----------
    msname : str
        Measurement set
    timerange : str
        Time range with date and time. Several ranges may be given separated
        by commas; each is ``start~end`` or a single timestamp.

    Returns
    -------
    dict
        Scan dict for timerange. Keys are integer scan numbers and values are
        ``start~end`` strings covering the overlap of that scan with the
        requested range(s).
    """
    from casatools import ms, quanta

    # Make the path absolute BEFORE changing directory, otherwise a relative
    # msname would no longer point at the measurement set
    msname = os.path.abspath(msname.rstrip("/"))
    os.chdir(os.path.dirname(msname))
    qa = quanta()
    ms_tool = ms()
    ms_tool.open(msname)
    try:
        scan_summary = ms_tool.getscansummary()
    finally:
        ms_tool.close()
    # Scan boundaries converted once from days to seconds
    scan_bounds = {}
    for scan_id, scan_info in scan_summary.items():
        begin = scan_info["0"]["BeginTime"]
        end = scan_info["0"]["EndTime"]
        scan_bounds[int(scan_id)] = (
            qa.convert(qa.quantity(begin, "d"), "s")["value"],
            qa.convert(qa.quantity(end, "d"), "s")["value"],
        )
    overlaps = {}
    for single_range in timerange.split(","):
        tr_start_str, sep, tr_end_str = single_range.partition("~")
        if not sep:
            # A lone timestamp is treated as a zero-length range
            tr_end_str = tr_start_str
        tr_start = timestamp_to_mjdsec(tr_start_str)
        tr_end = timestamp_to_mjdsec(tr_end_str)
        for scan_id, (scan_start, scan_end) in scan_bounds.items():
            # Check overlap
            if scan_end >= tr_start and scan_start <= tr_end:
                s = max(scan_start, tr_start)
                e = min(scan_end, tr_end)
                if scan_id in overlaps:
                    # Same scan hit by an earlier range: keep the union
                    old_s, old_e = overlaps[scan_id]
                    s = min(s, old_s)
                    e = max(e, old_e)
                overlaps[scan_id] = (s, e)
    return {
        scan_id: mjdsec_to_timestamp(s, str_format=1)
        + "~"
        + mjdsec_to_timestamp(e, str_format=1)
        for scan_id, (s, e) in overlaps.items()
    }
[docs]
def get_refant(
    msname="",
    field="",
    n_threads=-1,
):
    """
    Get reference antenna

    ``visstat`` is run on the first few antennas. Among those whose median is
    above the median over the inspected antennas, the one with the lowest rms
    is chosen.

    Parameters
    ----------
    msname : str
        Name of the measurement set
    field : str, optional
        Field name
    n_threads : int, optional
        Thread count passed to ``limit_threads`` (values below 1 become 1)

    Returns
    -------
    str
        Reference antenna (antenna index as a string)

    Raises
    ------
    ValueError
        If the measurement set has no antenna
    """
    n_threads = max(1, n_threads)
    limit_threads(n_threads=n_threads)
    from casatasks import visstat, casalog

    # Make the path absolute BEFORE changing directory, otherwise a relative
    # msname would no longer point at the measurement set
    msname = os.path.abspath(msname.rstrip("/"))
    os.chdir(os.path.dirname(msname))
    casalog.filter("SEVERE")
    msmd = msmetadata()
    msmd.open(msname)
    try:
        nant = msmd.nantennas()
    finally:
        msmd.close()
        msmd.done()
    # 10% of the array, at most 10 antennas, but at least one so that small
    # arrays do not end up with nothing to inspect.
    # NOTE(review): the original min(10, 0.1 * nant) may have been meant as
    # max(10, ...) -- confirm the intended sample size
    selected_nant = min(nant, max(1, min(10, int(0.1 * nant))))
    if selected_nant < 1:
        raise ValueError(f"No antenna found in ms: {msname}")
    antamp = []
    antrms = []
    for ant in range(selected_nant):
        stats = visstat(
            vis=msname,
            field=field,
            antenna=str(ant),
            timeaverage=True,
            timebin="500min",
            timespan="state,scan",
            reportingaxes="field",
        )
        item = str(list(stats.keys())[0])
        antamp.append(float(stats[item]["median"]))
        antrms.append(float(stats[item]["rms"]))
    antamp = np.array(antamp)
    antrms = np.array(antrms)
    # Candidates: antennas whose amplitude is above the median amplitude
    goodant = np.where(antamp > np.median(antamp))[0]
    if len(goodant) == 0:
        # No antenna stands out (e.g. single antenna): consider all of them
        goodant = np.arange(len(antamp))
    # Map the position of the lowest rms back to the real antenna index
    referenceant = int(goodant[np.argmin(antrms[goodant])])
    return str(referenceant)
[docs]
def get_uvrange_exclude(uvrange):
    """
    Build the uv-range selection(s) complementary to a given uv-range

    Parameters
    ----------
    uvrange : str
        UV-range in CASA format with ``lambda`` units: ``>X``, ``<X`` or
        ``LOW~HIGH``. Surrounding whitespace and letter case are ignored.

    Returns
    -------
    list
        List of uvranges excluding the given uv-range. One entry for the
        ``>``/``<`` forms, two entries (below LOW and above HIGH) for the
        ``~`` form.

    Raises
    ------
    ValueError
        If the unit is missing, the format is not supported, the bounds are
        not numeric, or the lower bound exceeds the upper bound
    """
    uvrange = uvrange.strip().lower()
    if "lambda" not in uvrange:
        raise ValueError("uvrange must contain 'lambda' units")

    # One-sided selections simply flip their comparison sign
    opposite = {">": "<", "<": ">"}.get(uvrange[0])
    if opposite is not None:
        bound = uvrange[1:].replace("lambda", "").strip()
        return [opposite + f"{bound}lambda"]

    if "~" not in uvrange:
        raise ValueError(f"Unsupported uvrange format: '{uvrange}'")

    pieces = uvrange.replace("lambda", "").split("~")
    if len(pieces) != 2:
        raise ValueError("Invalid uvrange format with '~'")
    low, high = (piece.strip() for piece in pieces)
    try:
        low_val, high_val = float(low), float(high)
    except ValueError:
        raise ValueError("uvrange bounds must be numeric")
    if low_val > high_val:
        raise ValueError(
            f"Lower bound {low_val} > upper bound {high_val} in uvrange"
        )
    return [f"<{low}lambda", f">{high}lambda"]
[docs]
def get_ms_scans(msname):
    """
    Get scans of the measurement set

    Parameters
    ----------
    msname : str
        Measurement set

    Returns
    -------
    list
        Scan list
    """
    msmd = msmetadata()
    msmd.open(msname)
    try:
        return msmd.scannumbers().tolist()
    finally:
        # Close the tool even when the scan query fails
        msmd.close()
[docs]
def get_observatory_name(msname):
    """
    Get observatory name

    Parameters
    ----------
    msname : str
        Measurement set

    Returns
    -------
    str
        Observatory name in all upper case, or an empty string when it cannot
        be determined
    """
    observatory = ""
    msmd = None
    try:
        msmd = msmetadata()
        msmd.open(msname)
        observatory = msmd.observatorynames()[0].upper()
    except Exception:
        # Best effort by design: any failure yields an empty name
        pass
    finally:
        # Close the tool on the failure path too
        if msmd is not None:
            try:
                msmd.close()
            except Exception:
                pass
    return observatory
[docs]
def get_observatory_coord(msname):
    """
    Get observatory coordinate

    The first observatory name of the measurement set is looked up with the
    ``measures`` tool.

    Parameters
    ----------
    msname : str
        Measurement set

    Returns
    -------
    float
        Latitude in degrees, rounded to 3 decimals
    float
        Longitude in degrees, rounded to 3 decimals
    float
        Height in meters, rounded to 3 decimals
    """
    msmd = msmetadata()
    msmd.open(msname)
    try:
        observatory = msmd.observatorynames()[0]
    finally:
        # Close the tool even when the name lookup fails
        msmd.close()
    me = measures()
    obs_pos = me.observatory(observatory)
    # m0 / m1 are converted from radians, as the original code did by hand
    lon = float(np.rad2deg(obs_pos["m0"]["value"]))
    lat = float(np.rad2deg(obs_pos["m1"]["value"]))
    height = obs_pos["m2"]["value"]
    return round(lat, 3), round(lon, 3), round(height, 3)
[docs]
def get_pol_names(msname, fullpol=True):
    """
    Get correlation names

    Parameters
    ----------
    msname : str
        Measurement set
    fullpol : bool, optional
        Full polarization products or not. When not True, only XX, YY, RR, LL
        and I are kept.

    Returns
    -------
    list
        List of correlation product names for polarization ID 0

    Raises
    ------
    KeyError
        If the measurement set reports a correlation type code outside 1-12
    """
    CASA_POL_PRODUCTS = {
        1: "I",
        2: "Q",
        3: "U",
        4: "V",
        5: "RR",
        6: "RL",
        7: "LR",
        8: "LL",
        9: "XX",
        10: "XY",
        11: "YX",
        12: "YY",
    }
    msmd = msmetadata()
    msmd.open(msname)
    try:
        pols = msmd.corrtypesforpol(0)
    finally:
        # Close the tool even when the query fails
        msmd.close()
    pol_names = [CASA_POL_PRODUCTS[int(p)] for p in pols]
    if fullpol is True:
        return pol_names
    return [name for name in pol_names if name in ["XX", "YY", "RR", "LL", "I"]]