import os
import shutil
import time
import traceback

import numpy as np
from casatools import table, msmetadata

from .basic_utils import suppress_output
from .resource_utils import limit_threads
#############################
# General CASA tasks
#############################
[docs]
def check_scan_in_caltable(caltable, scan):
    """
    Check scan number available in caltable or not

    Parameters
    ----------
    caltable : str
        Name of the caltable
    scan : int
        Scan number (converted with ``int()`` before comparison)

    Returns
    -------
    bool
        Whether scan is present in the SCAN_NUMBER column of the caltable
    """
    tb = table()
    tb.open(caltable)
    try:
        scans = tb.getcol("SCAN_NUMBER")
    finally:
        # Always release the table, even if the column read fails.
        tb.close()
    return int(scan) in scans
[docs]
def reset_weights_and_flags(
    msname="",
    restore_flag=True,
    force_reset=False,
    n_threads=-1,
):
    """
    Reset weights and flags for the ms

    The reset is skipped when a ``.reset`` marker file already exists inside
    the measurement set, unless ``force_reset`` is True. After a reset the
    marker file is created (or its timestamp updated).

    Parameters
    ----------
    msname : str
        Measurement set
    restore_flag : bool, optional
        Restore flags or not. When True, any ``<msname>.flagversions``
        directory is deleted and all data are unflagged with ``flagdata``.
    force_reset : bool, optional
        Force reset even if the ``.reset`` marker exists
    n_threads : int, optional
        Thread count passed to ``limit_threads``; values below 1 become 1

    Notes
    -----
    WEIGHT and SIGMA are set to ones and WEIGHT_SPECTRUM / SIGMA_SPECTRUM
    are removed when present. The process working directory is changed to
    the directory containing the measurement set and is not restored.
    """
    n_threads = max(1, n_threads)
    limit_threads(n_threads=n_threads)
    from casatasks import flagdata

    msname = msname.rstrip("/")
    # Resolve before chdir below: a relative msname containing a directory
    # component would otherwise point at the wrong place afterwards.
    ms_abs = os.path.abspath(msname)
    marker = os.path.join(ms_abs, ".reset")
    if os.path.exists(marker) and not force_reset:
        return
    os.chdir(os.path.dirname(ms_abs))

    if restore_flag:
        print(f"Restoring flags of measurement set : {msname}")
        flagversions = ms_abs + ".flagversions"
        if os.path.exists(flagversions):
            # Best-effort removal, like the previous `rm -rf`.
            shutil.rmtree(flagversions, ignore_errors=True)
        flagdata(vis=ms_abs, mode="unflag", flagbackup=False)

    print(f"Resetting previous weights of the measurement set: {msname}")
    msmd = msmetadata()
    msmd.open(ms_abs)
    try:
        # Number of correlations of the first polarization setup.
        npol = msmd.ncorrforpol()[0]
    finally:
        msmd.close()

    tb = table()
    tb.open(ms_abs, nomodify=False)
    try:
        colnames = tb.colnames()
        nrows = tb.nrows()
        if "WEIGHT" in colnames:
            print(f"Resetting weight column to ones of measurement set : {msname}.")
            weight = np.ones((npol, nrows))
            tb.putcol("WEIGHT", weight)
        if "SIGMA" in colnames:
            print(f"Resetting sigma column to ones of measurement set: {msname}.")
            sigma = np.ones((npol, nrows))
            tb.putcol("SIGMA", sigma)
        if "WEIGHT_SPECTRUM" in colnames:
            print(f"Removing weight spectrum of measurement set: {msname}.")
            tb.removecols("WEIGHT_SPECTRUM")
        if "SIGMA_SPECTRUM" in colnames:
            print(f"Removing sigma spectrum of measurement set: {msname}.")
            tb.removecols("SIGMA_SPECTRUM")
        tb.flush()
    finally:
        tb.close()

    # Equivalent of `touch`: create the marker or refresh its timestamp.
    with open(marker, "a"):
        os.utime(marker, None)
    return
[docs]
def calc_normzlized_crosscorr(data, flag, ant1, ant2, time):
    """
    Calculate normalized cross correlation

    Each cross-correlation visibility is divided by
    ``sqrt(|auto_ant1| * |auto_ant2|)``, where the two autocorrelations are
    the rows with the same timestamp for each antenna. For each correlation
    product the matching auto polarization is used: with 2 correlations the
    first/last auto polarizations are paired with themselves; otherwise the
    order is (first, first), (first, last), (last, first), (last, last).

    Parameters
    ----------
    data : numpy.array
        Visibility array of shape (npol, nchan, nrow)
    flag : numpy.array
        Flag array with the same shape as ``data``; updated in place
    ant1 : numpy.array
        Antenna 1 index per row
    ant2 : numpy.array
        Antenna 2 index per row
    time : numpy.array
        Timestamp per row; autocorrelations are matched by exact equality

    Returns
    -------
    numpy.array
        Normalized data (complex64). Autocorrelation rows and cross rows
        lacking either matching autocorrelation are set to zero.
    numpy.array
        Flags, with every non-finite normalized value flagged (and zeroed)

    Raises
    ------
    ValueError
        If ``data`` has rows but none of them is an autocorrelation.
    """
    ant1 = np.asarray(ant1)
    ant2 = np.asarray(ant2)
    time = np.asarray(time)
    npol, _, nrow = data.shape
    norm = np.zeros_like(data, dtype=np.complex64)
    if nrow == 0:
        return norm, flag

    auto_rows = np.flatnonzero(ant1 == ant2)
    if auto_rows.size == 0:
        raise ValueError(
            "No autocorrelation rows found; cannot normalize cross-correlations."
        )

    # Encode each (time, antenna) pair as one integer so rows can be matched
    # with a plain sorted search instead of structured-record comparisons.
    time_index = np.unique(time, return_inverse=True)[1].reshape(-1).astype(np.int64)
    nant = int(max(ant1.max(), ant2.max())) + 1
    key1 = time_index * nant + ant1.astype(np.int64)
    key2 = time_index * nant + ant2.astype(np.int64)

    auto_keys = key1[auto_rows]
    order = np.argsort(auto_keys, kind="stable")
    sorted_keys = auto_keys[order]
    sorted_rows = auto_rows[order]

    def _match(keys):
        # Row index of the autocorrelation with the same key, or -1 if none.
        pos = np.clip(np.searchsorted(sorted_keys, keys), 0, sorted_keys.size - 1)
        return np.where(sorted_keys[pos] == keys, sorted_rows[pos], -1)

    idx1 = _match(key1)
    idx2 = _match(key2)
    # Cross-correlation rows where both antennas have an autocorrelation.
    # NOTE(review): cross rows without matching autos end up zero but are not
    # flagged; confirm this is intended.
    rows = np.flatnonzero((idx1 >= 0) & (idx2 >= 0) & (ant1 != ant2))
    a1 = idx1[rows]
    a2 = idx2[rows]

    # Auto amplitudes of (first, last) polarization, each (nchan, nvalid).
    amp1 = (np.abs(data[0][:, a1]), np.abs(data[-1][:, a1]))
    amp2 = (np.abs(data[0][:, a2]), np.abs(data[-1][:, a2]))
    if npol == 2:
        pol_pairs = [(0, 0), (1, 1)]
    else:
        pol_pairs = [(0, 0), (0, 1), (1, 0)] + [(1, 1)] * max(npol - 3, 0)

    # Zero autocorrelations divide by zero; those results are flagged below.
    with np.errstate(divide="ignore", invalid="ignore"):
        for p, (i, j) in enumerate(pol_pairs[:npol]):
            denom = np.sqrt(amp1[i] * amp2[j])
            norm[p][:, rows] = data[p][:, rows] / denom

    # Flag inf as well as NaN: a nonzero visibility over a zero auto is inf.
    bad = ~np.isfinite(norm)
    flag[bad] = True
    norm[bad] = 0.0 + 0.0j
    return norm, flag
[docs]
def normalized_crosscorr_ms(msname, datacolumn="DATA"):
    """
    Perform normalized cross-correlation of the measurement set

    The measurement set is copied to ``<msname>.norm`` (replacing any
    existing copy). The selected data column of the copy is overwritten with
    the normalized visibilities from ``calc_normzlized_crosscorr``, and its
    FLAG column with the returned flags. The input measurement set itself is
    not modified.

    Parameters
    ----------
    msname : str
        Measurement set
    datacolumn : str, optional
        Data column to normalize (case-insensitive). ``"CORRECTED"`` and
        ``"MODEL"`` are aliases for ``CORRECTED_DATA`` and ``MODEL_DATA``;
        a column not present in the table falls back to ``DATA``.

    Returns
    -------
    str
        Normalized measurement set, or ``msname`` if any step failed
        (the traceback is printed)
    """
    try:
        msname = msname.rstrip("/")
        outfile = f"{msname}.norm"
        if os.path.isdir(outfile) and not os.path.islink(outfile):
            shutil.rmtree(outfile)
        elif os.path.lexists(outfile):
            os.remove(outfile)
        # symlinks=True copies links as links, like a plain `cp -r`.
        shutil.copytree(msname, outfile, symlinks=True)
        with suppress_output():
            tb = table()
            tb.open(outfile, nomodify=False)
            try:
                aliases = {"CORRECTED": "CORRECTED_DATA", "MODEL": "MODEL_DATA"}
                datacolumn = datacolumn.upper()
                datacolumn = aliases.get(datacolumn, datacolumn)
                if datacolumn not in tb.colnames():
                    datacolumn = "DATA"
                data = tb.getcol(datacolumn)  # (npol, nchan, nrow)
                flag = tb.getcol("FLAG")
                ant1 = tb.getcol("ANTENNA1")
                ant2 = tb.getcol("ANTENNA2")
                time = tb.getcol("TIME")
                norm, flag = calc_normzlized_crosscorr(data, flag, ant1, ant2, time)
                tb.putcol(datacolumn, norm)
                tb.putcol("FLAG", flag)
                tb.flush()
            finally:
                # Release the table lock even if reading/normalizing failed.
                tb.close()
        return outfile
    except Exception:
        traceback.print_exc()
        return msname