import os
import warnings
import numexpr as ne
import numpy as np
import subprocess
from casatools import table as casatable
from .basic_utils import suppress_output, average_with_padding, filter_outliers
from .flagging import get_chans_flag
# NOTE(review): with default arguments this installs an "ignore" filter for
# every warning category, so importing this module silences warnings for the
# whole process, not just for the code below -- confirm this is intended.
warnings.filterwarnings("ignore")
[docs]
def create_blank_table(msname, caltable):
    """
    Create a blank bandpass ("B Jones", complex-valued) calibration table.

    Any existing file, symlink or directory at ``caltable`` is removed first.

    Parameters
    ----------
    msname : str
        Name of the measurement set the table is created against.
    caltable : str
        Caltable name.

    Returns
    -------
    str
        Blank caltable name (``caltable`` unchanged).
    """
    import shutil

    from casatools import calibrater

    # Remove the old table without going through a shell, so paths containing
    # spaces or shell metacharacters are handled safely.
    if os.path.islink(caltable) or os.path.isfile(caltable):
        os.remove(caltable)
    elif os.path.isdir(caltable):
        shutil.rmtree(caltable)

    cb = calibrater()
    cb.open(msname)
    try:
        cb.createcaltable(caltable, "Complex", "B Jones", False)
    finally:
        # Always release the measurement set, even if table creation fails.
        cb.close()
    return caltable
[docs]
def create_crossphase_table(msname, caltable, freqs, crossphase, flags):
    """
    Create cross phase CASA caltable.

    A blank "B Jones" table is made with ``create_blank_table``. It is
    populated by the external ``run-mwa-fill-caltable`` command. The
    following columns are then overwritten:

    * ``SPECTRAL_WINDOW``: ``CHAN_FREQ`` from ``freqs``; ``NUM_CHAN`` is the
      channel count; ``REF_FREQUENCY`` is the mean frequency; ``CHAN_WIDTH``,
      ``EFFECTIVE_BW`` and ``RESOLUTION`` are all ``freqs[1] - freqs[0]``.
    * ``CPARAM``: the first correlation is set to
      ``exp(1j * deg2rad(crossphase))`` and the second to 1. Any NaN gain is
      replaced by 1.
    * ``FLAG``: the per-channel ``flags`` broadcast over correlations and rows.
    * ``TIME``: the mean of the measurement set's ``TIME`` column.

    Parameters
    ----------
    msname : str
        Measurement set.
    caltable : str
        Caltable name.
    freqs : numpy.array
        Channel frequencies. At least two are required. Only the first
        spacing is used as the channel width, so uniform spacing is assumed.
    crossphase : numpy.array
        Cross-hand phase per channel, in degrees.
    flags : numpy.array
        Per-channel boolean flags (True = flagged).

    Returns
    -------
    str or None
        Caltable name. Returns None, after printing a message, if the
        caltable does not exist after the fill step.

    Raises
    ------
    subprocess.CalledProcessError
        If ``run-mwa-fill-caltable`` exits with a non-zero status.
    """
    nchan = len(freqs)
    caltable = create_blank_table(msname, caltable)
    cmd = [
        "run-mwa-fill-caltable",
        "--msname",
        msname,
        "--caltable",
        caltable,
        "--nchan",
        str(nchan),
    ]
    subprocess.run(
        cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
    )
    freqres = freqs[1] - freqs[0]
    if not os.path.exists(caltable):
        print("Caltable is not made.")
        return None

    with suppress_output():
        tb = casatable()
        tb.open(msname)
        try:
            mean_time = np.nanmean(tb.getcol("TIME"))
        finally:
            tb.close()
        del tb

    # One spectral-window row: column arrays are shaped (nchan, 1).
    freq_col = np.array(freqs)[:, np.newaxis]
    width_col = np.full(freq_col.shape, freqres, dtype=float)
    tb = casatable()
    tb.open(caltable + "/SPECTRAL_WINDOW", nomodify=False)
    try:
        tb.putcol("CHAN_FREQ", freq_col)
        tb.putcol("NUM_CHAN", nchan)
        tb.putcol("REF_FREQUENCY", np.nanmean(freq_col))
        tb.putcol("CHAN_WIDTH", width_col)
        tb.putcol("EFFECTIVE_BW", width_col)
        tb.putcol("RESOLUTION", width_col)
    finally:
        tb.close()

    tb = casatable()
    tb.open(caltable, nomodify=False)
    try:
        nrow = len(tb.getcol("ANTENNA1"))
        # CPARAM is (corr, chan, row); the same per-channel phase goes to
        # every row.
        gain = tb.getcol("CPARAM")
        phase_gain = np.exp(1j * np.deg2rad(crossphase))
        gain[0, ...] = np.repeat(phase_gain[..., np.newaxis], nrow, axis=-1)
        gain[1, ...] = 1.0
        # Channels without a usable phase (NaN) become unit gain.
        gain[np.isnan(gain)] = 1.0
        tb.putcol("CPARAM", gain)

        # FLAG has the same shape as CPARAM. Broadcast the per-channel flags
        # across correlations and rows.
        chan_flags = np.asarray(flags, dtype=bool)[np.newaxis, :, np.newaxis]
        tb.putcol("FLAG", np.broadcast_to(chan_flags, gain.shape).copy())
        tb.putcol("TIME", np.full(nrow, mean_time))
    finally:
        tb.close()
    return caltable
[docs]
def fitted_crossphase(freqs, crossphase, min_deg=3, max_deg=8):
    """
    Smooth a cross-phase spectrum with a low-order polynomial fit.

    The cosine and sine of the cross-phase are fitted independently, which
    keeps the fit continuous across +/-180 degree wraps. Each component is
    passed through ``filter_outliers`` and then fitted with polynomials of
    increasing degree, from ``min_deg`` to ``max_deg``. The search stops at
    the first degree that does not lower the residual standard deviation,
    and the best fit so far is kept. Residuals are measured against the
    unfiltered component. The phase is rebuilt from the two fitted
    components.

    Parameters
    ----------
    freqs : array-like
        Frequency values (same length as crossphase).
    crossphase : array-like
        Cross-phase in degrees (can include NaNs).
    min_deg : int, optional
        Lowest polynomial degree tried (default 3).
    max_deg : int, optional
        Highest polynomial degree tried (default 8).

    Returns
    -------
    np.ndarray
        Fitted cross-phase in degrees. Channels before the first and after
        the last non-NaN input channel are NaN. If a component has no usable
        samples, the whole output is NaN.

    Raises
    ------
    ValueError
        If ``max_deg < min_deg``.
    """
    if max_deg < min_deg:
        raise ValueError("max_deg must be >= min_deg")
    freqs = np.asarray(freqs, dtype=float)
    phase_rad = np.radians(crossphase)
    fitted_components = []
    for component in (np.cos(phase_rad), np.sin(phase_rad)):
        clean = filter_outliers(component)
        # Channels with an original measurement; they define the span over
        # which the fit is reported and where residuals are evaluated.
        measured = np.flatnonzero(~np.isnan(component))
        # Samples used for fitting: survived outlier filtering and have a
        # finite frequency.
        valid = ~np.isnan(clean) & np.isfinite(freqs)
        if measured.size == 0 or not valid.any():
            fitted_components.append(np.full(np.shape(component), np.nan))
            continue

        # Fit in a normalised coordinate on [-1, 1]. Large raw frequency
        # values raised to degree 8 make the least-squares problem badly
        # conditioned; the fitted polynomial space is unchanged.
        f_lo, f_hi = freqs[valid].min(), freqs[valid].max()
        center = 0.5 * (f_hi + f_lo)
        half_span = 0.5 * (f_hi - f_lo)
        if half_span == 0:
            half_span = 1.0
        x = (freqs - center) / half_span

        best_fit, best_std = None, np.inf
        for deg in range(min_deg, max_deg + 1):
            model = np.polyval(np.polyfit(x[valid], clean[valid], deg), x)
            new_std = np.nanstd(component[measured] - model[measured])
            if new_std >= best_std:
                break
            best_fit, best_std = model, new_std
        if best_fit is None:
            best_fit = model

        # Limit the fit to the measured span, inclusive of its last channel.
        best_fit[: measured[0]] = np.nan
        best_fit[measured[-1] + 1 :] = np.nan
        fitted_components.append(best_fit)
    return np.angle(fitted_components[0] + 1j * fitted_components[1], deg=True)
[docs]
def _uvrange_limits_m(uvrange, wavelength, uvdist):
    """
    Convert a ``uvrange`` string given in wavelengths into metre limits.

    The text before ``lambda`` is parsed as one of three forms:

    * ``"a~b"``: a range.
    * ``">a"``: a lower bound only; the upper bound is ``nanmax(uvdist)``.
    * anything else is parsed as ``"<b"`` with a fixed 0.1 m lower bound.

    Parameters
    ----------
    uvrange : str
        UV-range string, e.g. ``"10~500lambda"``, ``">10lambda"``,
        ``"<500lambda"``.
    wavelength : float
        Wavelength used to convert wavelengths to metres.
    uvdist : numpy.ndarray
        Baseline lengths, in the same units as ``wavelength``.

    Returns
    -------
    tuple of float
        ``(minuv_m, maxuv_m)``.
    """
    spec = uvrange.split("lambda")[0]
    if "~" in uvrange:
        parts = spec.split("~")
        return float(parts[0]) * wavelength, float(parts[-1]) * wavelength
    if ">" in uvrange:
        return float(spec.split(">")[-1]) * wavelength, np.nanmax(uvdist)
    # Upper bound only. The 0.1 floor also excludes zero-length baselines.
    return 0.1, float(spec.split("<")[-1]) * wavelength


def crossphasecal(
    msname,
    caltable,
    uvrange="",
    gaintable="",
    chanwidth=1,
    n_threads=-1,
):
    """
    Function to calculate MWA cross hand phase.

    For each channel, the phase of the weighted sum over baselines of
    ``xy_data * conj(xy_model) + yx_model * conj(yx_data)`` is computed.
    If ``gaintable`` is given, the X/Y antenna gains are included. The result
    is optionally smoothed with ``fitted_crossphase`` and written with
    ``create_crossphase_table``.

    Parameters
    ----------
    msname : str
        Name of the measurement set.
    caltable : str
        Name of the caltable. If empty, it defaults to
        ``<msname without .ms>.kcross``.
    uvrange : str, optional
        UV-range for calibration, in wavelengths:
        ``"a~blambda"``, ``">alambda"`` or ``"<blambda"``.
    gaintable : str or list of str, optional
        Previous gaintable. If a list is given, only the first entry is used.
    chanwidth : int, optional
        Number of channels to average. When > 1, per-channel flags from
        ``get_chans_flag`` are not applied.
    n_threads : int, optional
        Number of CPU threads for numexpr. Values below 1 are raised to 1.

    Returns
    -------
    str
        Name of the caltable.
    """
    # NOTE(review): every value below 1, including the default -1, yields a
    # single numexpr thread; confirm whether -1 was meant as "all cores".
    n_threads = max(1, n_threads)
    ne.set_num_threads(n_threads)
    if caltable == "":
        caltable = msname.split(".ms")[0] + ".kcross"

    # Channel frequencies and the reference frequency of the first spw.
    with suppress_output():
        tb = casatable()
        tb.open(msname + "/SPECTRAL_WINDOW")
        try:
            freqs = tb.getcol("CHAN_FREQ").flatten()
            cent_freq = tb.getcol("REF_FREQUENCY")[0]
        finally:
            tb.close()
        del tb
    wavelength = (3 * 10**8) / cent_freq

    tb = casatable()
    tb.open(msname)
    try:
        ant1 = tb.getcol("ANTENNA1")
        ant2 = tb.getcol("ANTENNA2")
        data = tb.getcol("DATA")
        model_data = tb.getcol("MODEL_DATA")
        flag = tb.getcol("FLAG")
        uvw = tb.getcol("UVW")
        weight = tb.getcol("WEIGHT")
    finally:
        tb.close()
    # Columns are (corr, chan, row). Spread the first-correlation row weight
    # over channels, giving (chan, row).
    weight = np.repeat(weight[0, np.newaxis, :], model_data.shape[1], axis=0)

    # Unwrap a list of gaintables BEFORE opening; only the first is used.
    if isinstance(gaintable, (list, tuple)):
        gaintable = gaintable[0] if len(gaintable) > 0 else ""
    gaintable_supplied = bool(gaintable)
    if gaintable_supplied:
        with suppress_output():
            tb = casatable()
            tb.open(gaintable)
            try:
                gain = tb.getcol("CPARAM")
            finally:
                tb.close()
            del tb

    if uvrange:
        uvdist = np.sqrt(uvw[0, :] ** 2 + uvw[1, :] ** 2)
        minuv_m, maxuv_m = _uvrange_limits_m(uvrange, wavelength, uvdist)
        uv_filter = (uvdist >= minuv_m) & (uvdist <= maxuv_m)
        data = data[..., uv_filter]
        model_data = model_data[..., uv_filter]
        flag = flag[..., uv_filter]
        weight = weight[..., uv_filter]
        ant1 = ant1[uv_filter]
        ant2 = ant2[uv_filter]

    # Flagged samples become NaN so that the nansum below skips them.
    data[flag] = np.nan
    model_data[flag] = np.nan
    # NOTE(review): assumes correlations ordered XX, XY, YX, YY.
    xy_data = data[1, ...]
    yx_data = data[2, ...]
    xy_model = model_data[1, ...]
    yx_model = model_data[2, ...]
    if gaintable_supplied:
        # NOTE(review): assumes CPARAM row index == antenna id (one row per
        # antenna). Result shape is (chan, row), matching the visibilities.
        gainX1 = gain[0, :, ant1].T
        gainY1 = gain[-1, :, ant1].T
        gainX2 = gain[0, :, ant2].T
        gainY2 = gain[-1, :, ant2].T
        del gain
    del data, model_data, uvw, flag

    if chanwidth > 1:
        xy_data = average_with_padding(xy_data, chanwidth, axis=1, pad_value=np.nan)
        yx_data = average_with_padding(yx_data, chanwidth, axis=1, pad_value=np.nan)
        xy_model = average_with_padding(xy_model, chanwidth, axis=1, pad_value=np.nan)
        yx_model = average_with_padding(yx_model, chanwidth, axis=1, pad_value=np.nan)
        if gaintable_supplied:
            gainX1 = average_with_padding(gainX1, chanwidth, axis=1, pad_value=np.nan)
            gainX2 = average_with_padding(gainX2, chanwidth, axis=1, pad_value=np.nan)
            gainY1 = average_with_padding(gainY1, chanwidth, axis=1, pad_value=np.nan)
            gainY2 = average_with_padding(gainY2, chanwidth, axis=1, pad_value=np.nan)
        weight = average_with_padding(weight, chanwidth, axis=1, pad_value=np.nan)

    # numexpr resolves the names in these expressions from this frame's locals.
    if gaintable_supplied:
        argument = ne.evaluate(
            "weight * xy_data * conj(xy_model * gainX1) * gainY2 + weight * yx_model * gainY1 * conj(gainX2 * yx_data)"
        )
    else:
        argument = ne.evaluate(
            "weight * xy_data * conj(xy_model) + weight * yx_model * conj(yx_data)"
        )
    crossphase = np.angle(np.nansum(argument, axis=-1), deg=True)
    freqs = average_with_padding(freqs, chanwidth, axis=0, pad_value=np.nan)

    chan_flags = np.zeros(len(crossphase), dtype=bool)
    if chanwidth <= 1:
        # Channel flags only map one-to-one onto unaveraged channels.
        _, flag_chans = get_chans_flag(msname)
        chan_flags[flag_chans] = True
        crossphase[flag_chans] = np.nan

    # Smooth only when there are enough channels and they are narrow
    # (<= 40e3, in the units of CHAN_FREQ).
    if len(freqs) > 8:
        freqres = freqs[1] - freqs[0]
        if freqres <= 40 * 10**3:
            crossphase = fitted_crossphase(freqs, crossphase)
    # NOTE(review): the caltable name is returned even if
    # create_crossphase_table reports that the table was not made.
    create_crossphase_table(msname, caltable, freqs, crossphase, chan_flags)
    return caltable