import os
import shutil
import traceback
import warnings

import dask
import numpy as np
import zarr
from casatools import msmetadata, table
from daskms.experimental.zarr import xds_from_zarr, xds_to_zarr
from numpy.linalg import inv

from .basic_utils import (
    get_datadir,
    ceil_to_multiple,
    timestamp_to_mjdsec,
    mjdsec_to_timestamp,
)
from .imaging import calc_sun_dia, calc_maxuv, calc_field_of_view
#####################################
# Calibration related
#####################################
[docs]
def fill_nan_gains(x, data):
    """
    Interpolate NaN values of a real-valued series across frequency.

    ``data`` is cast to ``float``, so complex gains must be split by the
    caller into real and imaginary parts and passed separately.

    Parameters
    ----------
    x : array_like
        1D array of frequencies (any order).
    data : array_like
        1D array of real values (e.g. the real or imaginary part of gains),
        same length as ``x``; NaN marks a missing value.

    Returns
    -------
    numpy.ndarray
        1D float array in the same order as the inputs, with NaNs replaced by
        linear interpolation (linear extrapolation beyond the edges). If fewer
        than 3 non-NaN values are present, ``data`` is returned unchanged
        (as a float array, NaNs kept).
    """
    from scipy.interpolate import interp1d

    x = np.asarray(x)
    data = np.asarray(data, dtype=float)  # float dtype so NaN is representable
    valid = ~np.isnan(data)
    if np.sum(valid) < 3:
        return data
    # interp1d sorts the sample points itself (assume_sorted=False), and
    # evaluating at the caller's ``x`` keeps the output aligned with the
    # input order even when ``x`` is not sorted.
    interp_func = interp1d(
        x[valid],
        data[valid],
        kind="linear",
        bounds_error=False,
        fill_value="extrapolate",
    )
    return interp_func(x)
[docs]
def fluxcal_caltable(caltable, attn=10):
    """
    Scale an MWA bandpass table in place to account for attenuation.

    Digital gain corrections should already have been applied to the table.
    Each channel of the CPARAM column is multiplied by a reference bandpass
    polynomial (loaded from ``Ref_mean_bandpass_final.npy`` in the package
    data directory) evaluated at the channel frequency in MHz, times
    ``sqrt(10 ** (-(attn - 1) / 10))``.

    Parameters
    ----------
    caltable : str
        Name of the caltable (modified in place)
    attn : float, optional
        Attenuation in dB

    Returns
    -------
    str
        Flux calibrated caltable (same name as ``caltable``)
    """
    data_dir = get_datadir()
    tbl = table()
    tbl.open(f"{caltable}/SPECTRAL_WINDOW")
    chan_freqs_mhz = tbl.getcol("CHAN_FREQ") / 10**6
    tbl.close()
    reference_bandpass = np.poly1d(
        np.load(f"{data_dir}/Ref_mean_bandpass_final.npy", allow_pickle=True)[0]
    )
    power_factor = 10 ** (-(attn - 1) / 10.0)
    channel_scale = reference_bandpass(chan_freqs_mhz) * np.sqrt(power_factor)
    tbl.open(caltable, nomodify=False)
    cparam = tbl.getcol("CPARAM")
    nchan = cparam.shape[1]
    for chan in range(nchan):
        cparam[:, chan, :] *= channel_scale[chan]
    tbl.putcol("CPARAM", cparam)
    tbl.flush()
    tbl.close()
    return caltable
[docs]
def merge_caltables(caltables, merged_caltable, append=False, keepcopy=False):
    """
    Merge multiple caltables of the same type into one caltable.

    Parameters
    ----------
    caltables : list
        Caltable list (the list object itself is not modified)
    merged_caltable : str
        Merged caltable name
    append : bool, optional
        If True and ``merged_caltable`` already exists, append every input
        caltable to it. Otherwise ``merged_caltable`` is (re)created from the
        first input caltable and the rest are appended.
    keepcopy : bool, optional
        Keep the input caltables (copy) or remove them (move/delete)

    Returns
    -------
    str or None
        Merged caltable, or None if ``caltables`` is not a non-empty list
    """
    if not isinstance(caltables, list) or len(caltables) == 0:
        print("Please provide a list of caltable.")
        return
    # Work on a copy so the caller's list is left untouched.
    remaining = list(caltables)
    if not (append and os.path.exists(merged_caltable)):
        # Start a fresh merged table from the first input caltable.
        if os.path.exists(merged_caltable):
            shutil.rmtree(merged_caltable)
        first = remaining.pop(0)
        if keepcopy:
            shutil.copytree(first, merged_caltable)
        else:
            shutil.move(first, merged_caltable)
    if remaining:
        tb = table()
        for caltable in remaining:
            if not os.path.exists(caltable):
                continue
            tb.open(caltable)
            tb.copyrows(merged_caltable)
            tb.close()
            if not keepcopy:
                shutil.rmtree(caltable)
    return merged_caltable
[docs]
def interpolate_bpass(caltables, overwrite=False):
    """
    Interpolate bandpass/crossphase tables for missing frequency solutions.

    Solutions from all tables are combined on one sorted frequency axis.
    Flagged solutions are treated as missing (NaN) and filled per
    polarization and row by linear interpolation/extrapolation of the real
    and imaginary parts (``fill_nan_gains``). Values that still cannot be
    filled become ``1 + 0j``. In each output table the flagged CPARAM
    entries are replaced by the interpolated values and the FLAG column is
    reset so that only NaN-valued solutions remain flagged.

    Parameters
    ----------
    caltables : list
        List of bandpass/crossphase tables. CPARAM arrays are concatenated
        along the channel axis, so all tables must share the other two
        dimensions (polarization and row count).
    overwrite : bool, optional
        Overwrite the input tables (True) or write ``<table>.interp`` copies

    Returns
    -------
    list
        Output bandpass/crossphase tables. With fewer than two inputs,
        ``caltables`` is returned unchanged.
    """
    if len(caltables) < 2:
        return caltables
    tb = table()
    freq_chunks = []
    gain_chunks = []
    for cal in caltables:
        tb.open(cal + "/SPECTRAL_WINDOW")
        freq_chunks.append(tb.getcol("CHAN_FREQ"))
        tb.close()
        tb.open(cal)
        gain = tb.getcol("CPARAM")
        flag = tb.getcol("FLAG")
        tb.close()
        gain[flag] = np.nan + 1j * np.nan  # flagged solutions become gaps
        gain_chunks.append(gain)
    all_freqs = np.concatenate(freq_chunks, axis=0).flatten()
    all_gains = np.concatenate(gain_chunks, axis=1)
    order = np.argsort(all_freqs)
    all_freqs_sorted = all_freqs[order]
    all_gains_sorted = all_gains[:, order, :]
    interpolated_gains = np.ones(all_gains_sorted.shape, dtype="complex")
    npol = all_gains_sorted.shape[0]
    nant = all_gains_sorted.shape[-1]
    for p in range(npol):
        for a in range(nant):
            series = all_gains_sorted[p, :, a]
            interp_gain = fill_nan_gains(
                all_freqs_sorted, np.real(series)
            ) + 1j * fill_nan_gains(all_freqs_sorted, np.imag(series))
            interp_gain[np.isnan(interp_gain)] = 1.0 + 1j * 0.0
            interpolated_gains[p, :, a] = interp_gain
    outlist = []
    for cal in caltables:
        cal = cal.rstrip("/")
        if overwrite:
            outcal = cal
        else:
            outcal = f"{cal}.interp"
            if os.path.exists(outcal):
                shutil.rmtree(outcal)
            shutil.copytree(cal, outcal)
        tb.open(f"{outcal}/SPECTRAL_WINDOW")
        freqs = tb.getcol("CHAN_FREQ")
        tb.close()
        # Positions of this table's channels on the combined sorted axis.
        idx = np.searchsorted(all_freqs_sorted, freqs.flatten())
        interp_gain_out = interpolated_gains[:, idx, :]
        tb.open(outcal, nomodify=False)
        flags = tb.getcol("FLAG")
        gains = tb.getcol("CPARAM")
        gains[flags] = interp_gain_out[flags]
        tb.putcol("CPARAM", gains)
        # Only solutions that are still NaN stay flagged.
        flags[...] = np.isnan(np.abs(gains))
        tb.putcol("FLAG", flags)
        tb.flush()
        tb.close()
        outlist.append(outcal)
    return outlist
[docs]
def interpolate_quartical(caltables, overwrite=False):
    """
    Interpolate flagged full-Jones QuartiCal solutions across frequency.

    The first solution type of every caltable is read and all solutions are
    combined on one sorted frequency axis. Gains are concatenated along the
    frequency axis, so the other gain dimensions must match. Flagged
    solutions are treated as missing (NaN). For direction 0 only, they are
    filled per time, antenna and correlation by linear
    interpolation/extrapolation of the real and imaginary parts
    (``fill_nan_gains``).

    Values that still cannot be filled are set to the identity Jones matrix:
    1 for correlations 0 and 3, and 0 for correlations 1 and 2. Four
    correlations are assumed. Flagged entries of each caltable are then
    replaced with the interpolated values. Flags are cleared wherever the
    resulting solution contains no NaN.

    Parameters
    ----------
    caltables : list
        Names of the full Jones QuartiCal caltables
    overwrite : bool, optional
        Overwrite the input caltables (if not, ``<caltable>.interp`` is written)

    Returns
    -------
    list
        Names of the written caltables (caltables without solutions are skipped)
    """
    all_freqs = []
    all_gains = []
    for caltable in caltables:
        caltable = caltable.rstrip("/")
        soltypes = get_quartical_soltype(caltable)
        if len(soltypes) == 0:
            print("No solution is present. Not performing interpolation.")
            continue
        gains = xds_from_zarr(f"{caltable}::{soltypes[0]}")
        # Shape: ntime, nchan, nant, ndir, npol
        gain_data = gains[0].gains.to_numpy()
        gain_flag = gains[0].gain_flags.values.astype(bool)
        gain_data[gain_flag, :] = np.nan  # flagged solutions become gaps
        all_freqs.append(gains[0].gain_freq.to_numpy())
        all_gains.append(gain_data)
    if len(all_gains) == 0:
        # Nothing to interpolate; np.concatenate would fail on empty lists.
        return []
    all_freqs = np.concatenate(all_freqs, axis=0).flatten()
    all_gains = np.concatenate(all_gains, axis=1)
    order = np.argsort(all_freqs)
    all_freqs_sorted = all_freqs[order]
    all_gains_sorted = all_gains[:, order, ...]
    # Start from the identity Jones matrix [1, 0, 0, 1].
    interpolated_gains = np.ones(all_gains_sorted.shape, dtype="complex")
    interpolated_gains[..., 1] = 0.0
    interpolated_gains[..., 2] = 0.0
    ntime = all_gains_sorted.shape[0]
    nant = all_gains_sorted.shape[2]
    npol = all_gains_sorted.shape[-1]
    for t in range(ntime):
        for p in range(npol):
            fill_value = 1.0 + 1j * 0.0 if p in (0, 3) else 0.0 + 1j * 0.0
            for a in range(nant):
                series = all_gains_sorted[t, :, a, 0, p]
                interp_gain = fill_nan_gains(
                    all_freqs_sorted, np.real(series)
                ) + 1j * fill_nan_gains(all_freqs_sorted, np.imag(series))
                interp_gain[np.isnan(interp_gain)] = fill_value
                interpolated_gains[t, :, a, 0, p] = interp_gain
    outlist = []
    for caltable in caltables:
        caltable = caltable.rstrip("/")
        soltypes = get_quartical_soltype(caltable)
        if len(soltypes) == 0:
            print("No solution is present. Not performing interpolation.")
            continue
        soltype = soltypes[0]
        gains = xds_from_zarr(f"{caltable}::{soltype}")
        # Shape: ntime, nchan, nant, ndir, npol
        gain_data = gains[0].gains.to_numpy()
        gain_flag = gains[0].gain_flags.to_numpy()
        bool_gain_flag = gain_flag.astype(bool)
        freqs = gains[0].gain_freq.to_numpy()
        idx = np.searchsorted(all_freqs_sorted, freqs.flatten())
        interp_gain_out = interpolated_gains[:, idx, ...]
        gain_data[bool_gain_flag, :] = interp_gain_out[bool_gain_flag, :]
        # Keeping the old flags would leave the interpolated values flagged;
        # clear flags wherever a NaN-free solution now exists (mirrors the
        # FLAG reset done for CASA tables in interpolate_bpass).
        new_flag = np.any(np.isnan(gain_data), axis=-1).astype(gain_flag.dtype)
        gains[0].update(
            {
                "gain_flags": (
                    ["gain_time", "gain_freq", "antenna", "direction"],
                    new_flag,
                )
            }
        )
        gains[0].update(
            {
                "gains": (
                    [
                        "gain_time",
                        "gain_freq",
                        "antenna",
                        "direction",
                        "correlation",
                    ],
                    gain_data,
                )
            }
        )
        output_name = caltable if overwrite else f"{caltable}.interp"
        if overwrite and os.path.exists(caltable):
            # Remove exactly this table; a shell glob ("caltable*") would also
            # delete unrelated paths sharing the prefix.
            # NOTE(review): any dataset variable still lazily backed by this
            # store becomes unreadable once it is removed — confirm that only
            # the in-memory gains/gain_flags (and loaded coords) are written.
            shutil.rmtree(caltable)
        write_xds_list = xds_to_zarr(gains, f"{output_name}::{soltype}")
        dask.compute(write_xds_list)
        outlist.append(output_name)
    return outlist
[docs]
def get_cal_flag_info(caltable):
    """
    Summarise the flags of a caltable.

    The FLAG column is reshaped to (npol, nchan, ntime, nrow_per_time) using
    the number of unique TIME values. This assumes the rows are grouped by
    timestamp with an equal number of rows per timestamp -- TODO confirm.
    A channel, antenna or timestamp is reported as flagged only when every
    one of its entries is flagged.

    Parameters
    ----------
    caltable : str
        Caltable

    Returns
    -------
    list
        Flagged channel list
    list
        Flagged antenna list
    list
        Flagged timestamp list
    float
        Total flag fraction
    float
        Channel flag fraction
    float
        Antenna flag fraction
    float
        Time flag fraction
    """
    tbl = table()
    tbl.open(caltable)
    flag_col = tbl.getcol("FLAG")
    ntime = len(np.unique(tbl.getcol("TIME")))
    tbl.close()
    npol, nchan = flag_col.shape[0], flag_col.shape[1]
    cube = flag_col.reshape(npol, nchan, ntime, flag_col.shape[2] // ntime)
    npol, nchan, ntime, nant = cube.shape
    chan_idx = np.flatnonzero(np.all(cube, axis=(0, 2, 3)))
    ant_idx = np.flatnonzero(np.all(cube, axis=(0, 1, 2)))
    time_idx = np.flatnonzero(np.all(cube, axis=(0, 1, 3)))
    return (
        chan_idx.tolist(),
        ant_idx.tolist(),
        time_idx.tolist(),
        np.mean(cube),
        chan_idx.size / nchan,
        ant_idx.size / nant,
        time_idx.size / ntime,
    )
[docs]
def get_psf_size(msname, chan_number=-1):
    """
    Estimate the PSF size of a measurement set, in arcsec.

    Uses ``1.2 / max_uv`` (max uv in wavelengths from ``calc_maxuv``)
    converted from radians to arcseconds and rounded to 2 decimals.

    Parameters
    ----------
    msname : str
        Name of the measurement set
    chan_number : int, optional
        Channel number passed on to ``calc_maxuv``

    Returns
    -------
    float
        PSF size in arcsec
    """
    _maxuv_m, maxuv_lambda = calc_maxuv(msname, chan_number=chan_number)
    psf_arcsec = np.rad2deg(1.2 / maxuv_lambda) * 3600.0
    return round(psf_arcsec, 2)
[docs]
def calc_bw_smearing_freqwidth(msname, full_FoV=False, FWHM=True):
    """
    Calculate the spectral averaging width allowed by bandwidth smearing.

    Computes ``sqrt(1/R**2 - 1) * (psf / fov) * freq`` with ``R = 0.9``,
    rounded up to a multiple of the channel width and then to 2 decimals.

    Parameters
    ----------
    msname : str
        Name of the measurement set
    full_FoV : bool, optional
        Consider smearing over the full FoV (True) or within twice the solar
        diameter (False)
    FWHM : bool, optional
        If using full FoV, consider up to the FWHM or the first null

    Returns
    -------
    float
        Spectral width in MHz
    """
    spw_tbl = table()
    spw_tbl.open(f"{msname}/SPECTRAL_WINDOW")
    ref_freq_mhz = float(spw_tbl.getcol("REF_FREQUENCY")[0]) / 10**6
    chan_width_mhz = float(spw_tbl.getcol("CHAN_WIDTH")[0][0]) / 10**6
    spw_tbl.close()
    R = 0.9  # presumably the tolerated smeared/unsmeared amplitude ratio
    if not full_FoV:
        fov = 2 * calc_sun_dia(np.nanmean(ref_freq_mhz)) * 60  # 2 times sun size
    else:
        fov = calc_field_of_view(msname, FWHM=FWHM)  # In arcsec
    psf = get_psf_size(msname)
    width_mhz = np.sqrt((1 / R**2) - 1) * (psf / fov) * ref_freq_mhz
    return round(ceil_to_multiple(width_mhz, chan_width_mhz), 2)
[docs]
def calc_time_smearing_timewidth(msname, full_FoV=False, FWHM=True):
    """
    Calculate the maximum time averaging allowed by time smearing.

    Uses ``lambda / (2 * pi * max_uv_m * omega_earth * fov_rad)``, rounded
    up to a multiple of the time step between the first two timestamps of
    spectral window 0 and then to 2 decimals. The MS must therefore contain
    at least two timestamps.

    Parameters
    ----------
    msname : str
        Measurement set name
    full_FoV : bool, optional
        Consider smearing over the full FoV (True) or within twice the solar
        diameter (False, default)
    FWHM : bool, optional
        If using full FoV, consider up to the FWHM or the first null

    Returns
    -------
    delta_t_max : float
        Maximum allowable time averaging in seconds.
    """
    md = msmetadata()
    md.open(msname)
    freq_hz = md.chanfreqs(0)[0]
    timestamps = md.timesforspws(0)
    md.close()
    time_step = timestamps[1] - timestamps[0]
    speed_of_light = 299792458.0  # m/s
    earth_rotation_rate = 7.2921159e-5  # rad/s
    wavelength_m = speed_of_light / freq_hz
    freq_mhz = freq_hz / 10**6
    if not full_FoV:
        fov = 2 * calc_sun_dia(np.nanmean(freq_mhz)) * 60  # 2 times sun size
    else:
        fov = calc_field_of_view(msname, FWHM=FWHM)  # In arcsec
    fov_rad = np.deg2rad(fov / 3600.0)
    maxuv_m, _maxuv_lambda = calc_maxuv(msname)
    # Approximate maximum allowable time to avoid >10% amplitude loss
    t_max = wavelength_m / (2 * np.pi * maxuv_m * earth_rotation_rate * fov_rad)
    return round(ceil_to_multiple(t_max, time_step), 2)
[docs]
def max_time_solar_smearing(msname):
    """
    Maximum time averaging allowed by the apparent motion of the Sun.

    Returns the time the Sun (moving at 2.5 arcsec/min) takes to cross half
    of the PSF, rounded to 2 decimals.

    Parameters
    ----------
    msname : str
        Measurement set name

    Returns
    -------
    t_max : float
        Maximum time averaging in seconds.
    """
    sun_rate = 2.5 / 60.0  # 2.5 arcsec/min expressed in arcsec/s
    psf_arcsec = get_psf_size(msname)
    return round(0.5 * (psf_arcsec / sun_rate), 2)
[docs]
def get_nearest_bandpass_table(caltable_list, freq):
    """
    Pick the bandpass table whose central channel frequency is closest to ``freq``.

    Parameters
    ----------
    caltable_list : list
        List of bandpass table
    freq : float
        Frequency in MHz

    Returns
    -------
    str or None
        Name of the nearest bandpass table, or None when the list is empty
        or ``freq`` is None
    """
    if len(caltable_list) < 1:
        print("No caltable is provided.")
        return
    if freq is None:
        print("No frequency information is given.")
        return
    tables = np.array(caltable_list)
    central_freqs = np.array(
        [
            float(get_caltable_metadata(name)["Central channel frequency (MHz)"])
            for name in tables
        ]
    )
    return tables[np.argmin(np.abs(freq - central_freqs))]
[docs]
def get_nearest_gaincal_table(caltable_list, timestamp):
    """
    Pick the gaincal table whose mid time is closest to ``timestamp``.

    Parameters
    ----------
    caltable_list : list
        List of gaincal table
    timestamp : str
        Timestamp (format : 'YYYY/MM/DD/hh:mm:ss')

    Returns
    -------
    str or None
        Name of the nearest gaincal table, or None if no table/time is given
        or any lookup fails (the traceback is printed in that case)
    """
    if len(caltable_list) < 1:
        print("No caltable is provided.\n")
        return None
    if timestamp is None:
        print("No time information is given.\n")
        return None
    try:
        tables = np.array(caltable_list)
        mid_times = []
        for name in tables:
            meta = get_caltable_metadata(name)
            start_str, end_str = meta["Start time"], meta["End time"]
            start_mjd = timestamp_to_mjdsec(start_str, date_format=1)
            end_mjd = timestamp_to_mjdsec(end_str, date_format=1)
            mid_times.append((start_mjd + end_mjd) / 2.0)
        mid_times = np.array(mid_times)
        target_mjd = timestamp_to_mjdsec(timestamp, date_format=0)
        return tables[np.argmin(np.abs(target_mjd - mid_times))]
    except Exception:
        traceback.print_exc()
        print("Nearest caltable could not be found.\n")
        return None
[docs]
def uvrange_casa_to_quartical(msname, uvrange=""):
    """
    Convert a CASA-format uv-range in wavelengths to QuartiCal limits in metres.

    Supported forms are ``"a~b"``, ``">a"`` and ``"<b"``, optionally suffixed
    with ``lambda`` or ``klambda``. Values without a unit are treated as
    wavelengths. The conversion uses the wavelength at the mean frequency of
    spectral window 0 of ``msname``. A returned bound of 0.0 means that side
    is not limited (presumably QuartiCal's "no limit" convention -- confirm).

    Parameters
    ----------
    msname : str
        Measurement set
    uvrange : str
        UV-range in CASA format

    Returns
    -------
    float
        Minimum UV in meter (rounded to 0.1 m)
    float
        Maximum UV in meter (rounded to 0.1 m)

    An empty ``uvrange`` returns ``[0.0, 0.0]`` without opening the MS.
    """
    if uvrange == "":
        return [0.0, 0.0]
    msmd = msmetadata()
    msmd.open(msname)
    freq = msmd.meanfreq(0)
    msmd.close()
    wavelength = (3 * 10**8) / freq
    uvrange = uvrange.strip()
    unit_scale = 1.0
    if uvrange.endswith("klambda"):
        unit_scale = 1000.0
        uvrange = uvrange[: -len("klambda")]
    else:
        uvrange = uvrange.rstrip("lambda")
    if "~" in uvrange:
        minuv_l = float(uvrange.split("~")[0])
        maxuv_l = float(uvrange.split("~")[-1])
    elif ">" in uvrange:
        minuv_l = float(uvrange.split(">")[-1])
        maxuv_l = 0.0
    elif "<" in uvrange:
        minuv_l = 0.0
        # The value follows the "<"; the text before it is empty.
        maxuv_l = float(uvrange.split("<")[-1])
    else:
        minuv_l = 0.0
        maxuv_l = 0.0
    minuv_l *= unit_scale
    maxuv_l *= unit_scale
    return round(float(minuv_l * wavelength), 1), round(float(maxuv_l * wavelength), 1)
[docs]
def solint_in_float(solint):
    """
    Convert a solution interval to seconds.

    Parameters
    ----------
    solint : str or float
        Solution interval such as ``"10s"``, ``"2min"``, ``"1h"``, ``"30"``
        (seconds), or a plain number of seconds.

    Returns
    -------
    float or None
        Solution interval in seconds, or None if it cannot be parsed
        (e.g. CASA keywords such as ``"int"``).
    """
    if isinstance(solint, (int, float, np.integer, np.floating)) and not isinstance(
        solint, bool
    ):
        return float(solint)
    if not isinstance(solint, str):
        return None
    text = solint.strip()
    factor = 1.0
    for suffix, seconds_per_unit in (("min", 60.0), ("h", 3600.0), ("s", 1.0)):
        if text.endswith(suffix):
            text = text[: -len(suffix)]
            factor = seconds_per_unit
            break
    try:
        value = float(text)
    except ValueError:
        # Malformed strings map to None whatever their suffix.
        return None
    return value * factor
[docs]
def get_quartical_soltype(quartical_table):
    """
    List the solution types (zarr sub-groups) stored in a QuartiCal table.

    Parameters
    ----------
    quartical_table : str
        Quartical table

    Returns
    -------
    list
        Solution type names (hidden ``.``-prefixed groups excluded); an empty
        list if there are none or the table cannot be read
    """
    try:
        group = zarr.open_group(quartical_table, mode="r")
        return [name for name in group.group_keys() if not name.startswith(".")]
    except Exception:
        print("Could not read QuartiCal caltable.")
        return []
[docs]
def quartical_matrix_normalize(caltable, overwrite=False):
    """
    Matrix-normalise full-Jones QuartiCal solutions.

    The solutions are handled as stacks of 2x2 antenna Jones matrices ``g``,
    one stack per time, frequency and direction. For each stack,
    ``Y = (sum_ant g^H g)^-1 (sum_ant g^H)`` is computed, with NaN (flagged)
    antennas excluded from the sums. Each antenna solution is then replaced
    by ``g @ Y``.

    Slots whose NaN-ignoring sum is zero are left untouched. Flagged
    solutions are written back as NaN and the flags are kept. For the
    mathematical expression, see equation 21 of Kansabanik et al. 2022,
    ApJ, 932:110.

    Parameters
    ----------
    caltable : str
        Name of the full Jones QuartiCal caltable
    overwrite : bool, optional
        Overwrite the input caltable (if not, ``<caltable>.poldist`` is written)

    Returns
    -------
    str
        New caltable name (the input name if no solution is present)
    """
    caltable = caltable.rstrip("/")
    soltypes = get_quartical_soltype(caltable)
    if len(soltypes) == 0:
        print("No solution is present. Not performing any normalisation.")
        return caltable
    soltype = soltypes[0]
    gains = xds_from_zarr(f"{caltable}::{soltype}")
    gain_data = gains[0].gains.to_numpy()  # Shape: ntime, nchan, nant, ndir, npol
    gain_flag = gains[0].gain_flags.values.astype(bool)
    gain_data[gain_flag, :] = np.nan
    # Correlation axis (4 entries) -> 2x2 Jones matrices.
    gain_data = gain_data.reshape(*gain_data.shape[:-1], 2, 2)
    ntime, nchan, ndir = gain_data.shape[0], gain_data.shape[1], gain_data.shape[3]
    for t in range(ntime):
        for f in range(nchan):
            for d in range(ndir):
                g = gain_data[t, f, :, d, ...]  # (nant, 2, 2)
                if np.abs(np.nansum(g)) == 0:
                    continue  # all-NaN / all-zero slot: nothing to normalise
                gH = g.conj().transpose(0, 2, 1)
                gH_dot_g_sum_inv = inv(np.nansum(np.matmul(gH, g), axis=0))
                # Apply Y directly: computing X = inv(Y) and then inv(X) is a
                # redundant double inversion that can fail if Y is singular.
                Y = np.matmul(gH_dot_g_sum_inv, np.nansum(gH, axis=0))
                gain_data[t, f, :, d, ...] = np.matmul(g, Y)
    gain_data = gain_data.reshape(*gain_data.shape[:-2], 4)
    gain_flag = gain_flag.astype("int")
    gains[0].update(
        {
            "gain_flags": (
                ["gain_time", "gain_freq", "antenna", "direction"],
                gain_flag,
            )
        }
    )
    gains[0].update(
        {
            "gains": (
                ["gain_time", "gain_freq", "antenna", "direction", "correlation"],
                gain_data,
            )
        }
    )
    output_name = caltable if overwrite else f"{caltable}.poldist"
    if overwrite and os.path.exists(caltable):
        # Remove exactly this table; a shell glob ("caltable*") would also
        # delete unrelated paths sharing the prefix (e.g. "caltable.poldist").
        # NOTE(review): any dataset variable still lazily backed by this
        # store becomes unreadable once it is removed — confirm that only the
        # in-memory gains/gain_flags (and loaded coords) are written.
        shutil.rmtree(caltable)
    write_xds_list = xds_to_zarr(gains, f"{output_name}::{soltype}")
    dask.compute(write_xds_list)
    return output_name
[docs]
def scale_bandpass(bandpass_table, cal_attn, target_attn):
    """
    Scale a bandpass calibration table for a change in attenuation.

    A copy named ``<bandpass_table>.att<target_attn>`` is created (replacing
    any existing one). Its CPARAM column is multiplied by
    ``10 ** (-(target_attn - cal_attn) / 20)``, which is 1.0 when both
    attenuations are equal.

    Parameters
    ----------
    bandpass_table : str
        Input bandpass calibration table.
    cal_attn : float
        Calibrator attenuation
    target_attn : float
        Target attenuation

    Returns
    -------
    str
        Name of the output table.
    """
    bandpass_table = bandpass_table.rstrip("/")
    output_table = f"{bandpass_table}.att{target_attn}"
    if os.path.exists(output_table):
        shutil.rmtree(output_table)
    shutil.copytree(bandpass_table, output_table)
    scaling = 10 ** (-(target_attn - cal_attn) / 20.0)
    tb = table()
    tb.open(output_table, nomodify=False)
    gain = tb.getcol("CPARAM")
    # Silence RuntimeWarnings for this computation only, instead of
    # changing the process-wide warning filters.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=RuntimeWarning)
        gain *= scaling
    tb.putcol("CPARAM", gain)
    tb.flush()
    tb.close()
    return output_table