import numpy as np
from casatools import msmetadata, table
from .mwa_utils import get_bad_chans
##################################
# Imaging related
##################################
[docs]
def is_fft_good(n):
    """
    Whether this number is good for FFTW or not

    A number is considered good when it has no prime factors other than
    2, 3, 5 and 7 (i.e. it is 7-smooth).

    Parameters
    ----------
    n : int
        Number to test

    Returns
    -------
    bool
        True if ``n`` is a positive 7-smooth number, False otherwise
    """
    # Zero would never leave the division loop (0 % p == 0 forever), and
    # negative numbers can never reduce to 1, so reject them up front.
    if n < 1:
        return False
    for p in (2, 3, 5, 7):
        while n % p == 0:
            n //= p
    return n == 1
[docs]
def get_fft_size(n):
    """
    Give the best number larger than the given number for best FFT performance

    Starting from ``n``, the first 7-smooth number (see ``is_fft_good``) is
    searched for. Small results (< 128) are promoted to a power of two, and
    larger results are required to be even as well as 7-smooth.

    Parameters
    ----------
    n : int
        Minimum size

    Returns
    -------
    int
        Even, FFT-friendly size
    """
    # Sizes up to 1 (including zero / negative input, which previously never
    # terminated) map to the smallest even size.
    if n <= 1:
        return 2
    while True:
        if is_fft_good(n):
            if n < 128:
                # NOTE(review): 1 << n.bit_length() is the power of two strictly
                # above n, so an exact power of two (e.g. 64) is doubled.
                return 1 << n.bit_length()
            # Previously an odd good size was bumped by one, which can yield a
            # non-smooth number (129 -> 136 = 2**3 * 17); keep searching instead.
            if n % 2 == 0:
                return n
        n += 1
[docs]
def calc_sun_dia(freqMHz):
    """
    Estimate the apparent diameter of the Sun at a given frequency (White 2016)

    Parameters
    ----------
    freqMHz : float
        Observing frequency in MHz

    Returns
    -------
    float
        Solar diameter in arcmin, rounded to two decimals
    """
    ghz = freqMHz / 10**3  # MHz -> GHz
    frequency_term = 2.2 * ghz ** (-0.6)
    return round(32 + frequency_term, 2)
[docs]
def calc_maxuv(msname, chan_number=-1):
    """
    Calculate maximum UV

    Parameters
    ----------
    msname : str
        Name of the measurement set
    chan_number : int, optional
        Index into the channel frequencies of spectral window 0 used to convert
        metres to wavelengths (default: -1, the last channel)

    Returns
    -------
    float
        Maximum UV in meter
    float
        Maximum UV in wavelength
    """
    msmd = msmetadata()
    msmd.open(msname)
    try:
        freq = msmd.chanfreqs(0)[chan_number]
    finally:
        # Release the tool even if the channel lookup fails.
        msmd.close()
        msmd.done()
    wavelength = 299792458.0 / freq
    tb = table()
    tb.open(msname)
    try:
        uvw = tb.getcol("UVW")
    finally:
        tb.close()
    u, v = uvw[0], uvw[1]
    uv = np.sqrt(u**2 + v**2)
    # Zero-length baselines (autocorrelations) are excluded from the statistic.
    uv[uv == 0] = np.nan
    maxuv = np.nanmax(uv)
    return round(float(maxuv), 2), round(float(maxuv / wavelength), 2)
[docs]
def calc_minuv(msname, chan_number=-1):
    """
    Calculate minimum UV

    Parameters
    ----------
    msname : str
        Name of the measurement set
    chan_number : int, optional
        Index into the channel frequencies of spectral window 0 used to convert
        metres to wavelengths (default: -1, the last channel)

    Returns
    -------
    float
        Minimum (non-zero) UV in meter
    float
        Minimum (non-zero) UV in wavelength
    """
    msmd = msmetadata()
    msmd.open(msname)
    try:
        freq = msmd.chanfreqs(0)[chan_number]
    finally:
        # Release the tool even if the channel lookup fails.
        msmd.close()
        msmd.done()
    wavelength = 299792458.0 / freq
    tb = table()
    tb.open(msname)
    try:
        uvw = tb.getcol("UVW")
    finally:
        tb.close()
    u, v = uvw[0], uvw[1]
    uv = np.sqrt(u**2 + v**2)
    # Zero-length baselines (autocorrelations) must not be reported as minimum.
    uv[uv == 0] = np.nan
    minuv = np.nanmin(uv)
    return round(float(minuv), 2), round(float(minuv / wavelength), 2)
[docs]
def calc_uvtaper(msname):
    """
    Calculate UV-taper

    The projected baseline lengths of all rows are histogrammed with a bin
    width of half the baseline length (in metres) that corresponds to
    ``1.22 / solar diameter`` at the highest channel frequency. The taper is
    placed at the lower edge of the first bin whose count drops below 1% of
    the peak bin count; if no bin is that sparse, the outermost edge is used.

    Parameters
    ----------
    msname : str
        Measurement set

    Returns
    -------
    float
        UV-taper in lambda at highest frequency
    """
    tb = table()
    tb.open(msname)
    try:
        uvw = tb.getcol("UVW")
    finally:
        tb.close()
    msmd = msmetadata()
    msmd.open(msname)
    try:
        freqs = msmd.chanfreqs(0)
    finally:
        msmd.close()
        msmd.done()
    max_freq = np.nanmax(freqs)
    # Same speed-of-light value as the other helpers in this module.
    wavelength = 299792458.0 / max_freq
    sun_dia = np.deg2rad(calc_sun_dia(max_freq / 10**6) / 60.0)  # arcmin -> rad
    bin_size_lambda = 1.22 / sun_dia
    bin_size = (bin_size_lambda * wavelength) / 2.0  # in metres
    r = np.sqrt(uvw[0] ** 2 + uvw[1] ** 2)
    r_min, r_max = r.min(), r.max()
    # np.histogram needs at least two edges (one bin).
    n_bins = max(int(r_max / bin_size), 2)
    r_bins = np.linspace(r_min, r_max, n_bins)
    counts, edges = np.histogram(r, bins=r_bins)
    max_counts = np.nanmax(counts)
    sparse_bins = np.where(counts < 0.01 * max_counts)[0]
    # Previously an IndexError was raised when no bin fell below 1% of the peak.
    taper_edge = edges[sparse_bins[0]] if sparse_bins.size > 0 else edges[-1]
    uvtaper = taper_edge / wavelength
    return round(uvtaper, 0)
[docs]
def calc_field_of_view(msname, FWHM=True):
    """
    Calculate optimum field of view in arcsec.

    The field of view is ``factor * wavelength / D`` at the first channel of
    spectral window 0. ``D`` is the smallest dish diameter in the ANTENNA
    table, and ``factor`` is 1.22 (FWHM) or 2.04 (first null).

    Parameters
    ----------
    msname : str
        Measurement set name
    FWHM : bool, optional
        Upto FWHM, otherwise upto first null

    Returns
    -------
    float
        Field of view in arcsec
    """
    msmd = msmetadata()
    msmd.open(msname)
    try:
        freq = msmd.chanfreqs(0)[0]
    finally:
        # Release the tool even if the lookup fails (done() was missing before).
        msmd.close()
        msmd.done()
    tb = table()
    tb.open(msname + "/ANTENNA")
    try:
        # The smallest dish gives the widest beam.
        dish_dia = np.nanmin(tb.getcol("DISH_DIAMETER"))
    finally:
        tb.close()
    wavelength = 299792458.0 / freq
    # NOTE(review): for a uniformly illuminated circular aperture 1.22 lambda/D
    # is usually quoted as the first-null radius; confirm these factors.
    factor = 1.22 if FWHM else 2.04
    FOV = factor * wavelength / dish_dia
    fov_arcsec = np.rad2deg(FOV) * 3600  # In arcsecs
    return round(float(fov_arcsec), 2)
[docs]
def get_optimal_image_interval(
    msname,
    temporal_tol_factor=0.1,
    spectral_tol_factor=0.1,
    chan_range="",
    timestamp_range="",
    flag_central_chan=False,
    max_nchan=-1,
    max_ntime=-1,
):
    """
    Get optimal image spectral temporal interval such that total flux max-median in each chunk is within tolerance limit

    The data of the autocorrelations (or, if none exist, of the baseline
    between antenna 0 and 1) are reduced to a median amplitude spectrum and a
    median amplitude time series. Each is split into the fewest equal-length
    chunks for which ``(max - min) / median`` of every chunk stays within the
    tolerance.

    Parameters
    ----------
    msname : str
        Name of the measurement set
    temporal_tol_factor : float, optional
        Tolerance factor for temporal variation (default : 0.1, 10%)
    spectral_tol_factor : float, optional
        Tolerance factor for spectral variation (default : 0.1, 10%)
    chan_range : str, optional
        Channel range as "start,end" (Python slice, end exclusive)
    timestamp_range : str, optional
        Timestamp index range as "start,end" (Python slice, end exclusive)
    flag_central_chan : bool, optional
        Flag central channel
    max_nchan : int, optional
        Maxmium number of spectral chunk (<= 0 means no limit)
    max_ntime : int, optional
        Maximum number of temporal chunk (<= 0 means no limit)

    Returns
    -------
    int
        Number of temporal chunks (at least 1)
    int
        Number of spectral chunks (at least 1)
    """
    from casatools import ms as casamstool

    def parse_index_range(range_str):
        # "start,end" -> (start, end); the first and last comma fields are used.
        fields = range_str.split(",")
        return int(fields[0]), int(fields[-1])

    def parse_bad_chans(spw_chan_str):
        # CASA-style selection "0:a~b;c" -> [a, ..., b, c]; empty segments skipped.
        chan_list = []
        for segment in spw_chan_str.split(":", 1)[-1].split(";"):
            segment = segment.strip()
            if not segment:
                continue
            if "~" in segment:
                bounds = segment.split("~")
                chan_list.extend(range(int(bounds[0]), int(bounds[-1]) + 1))
            else:
                chan_list.append(int(segment))
        return chan_list

    def is_valid_chunk(chunk, tolerance):
        # Relative peak-to-peak spread of the chunk, normalised by its median.
        median_flux = np.nanmedian(chunk)
        if median_flux == 0:
            return False
        spread = np.nanmax(chunk) - np.nanmin(chunk)
        # abs(): a negative median would otherwise accept every chunk.
        return spread / abs(median_flux) <= tolerance

    def find_max_valid_chunk_length(fluxes, tolerance):
        # Largest window that tiles the series exactly with all chunks valid.
        n = len(fluxes)
        for window in range(n, 1, -1):  # Try from largest to smallest
            if n % window != 0:
                continue  # Require full windows only
            if all(
                is_valid_chunk(fluxes[start : start + window], tolerance)
                for start in range(0, n, window)
            ):
                return window
        return 1  # Minimum chunk size is 1 if nothing else is valid

    tb = table()
    tb.open(msname)
    try:
        uvw = tb.getcol("UVW")
    finally:
        tb.close()
    # Zero projected baseline length marks autocorrelation rows.
    has_autocorr = bool(np.any(np.sqrt(uvw[0] ** 2 + uvw[1] ** 2) == 0.0))

    bad_chans = get_bad_chans(msname, flag_central_chan=flag_central_chan)
    bad_chan_list = parse_bad_chans(bad_chans) if bad_chans != "" else []

    mstool = casamstool()
    mstool.open(msname)
    try:
        if has_autocorr:
            mstool.select({"uvdist": [0.0, 0.0]})
        else:
            mstool.select({"antenna1": 0, "antenna2": 1})
        data = mstool.getdata(["DATA"], ifraxis=True)["data"]
    finally:
        mstool.close()

    # Work on amplitudes: medians/max/min of complex visibilities use
    # lexicographic ordering, which is meaningless as a "flux" measure.
    data = np.abs(data)
    if bad_chan_list:
        data[:, bad_chan_list, ...] = np.nan

    chan_sel = slice(None)
    if chan_range != "":
        chan_sel = slice(*parse_index_range(chan_range))

    # Axis 1 is the channel axis; with 4-D data the last axis is time.
    if data.ndim == 4:
        spectra = np.nanmedian(data[:, chan_sel, ...], axis=(0, 2, 3))
        time_sel = slice(None)
        if timestamp_range != "":
            time_sel = slice(*parse_index_range(timestamp_range))
        t_series = np.nanmedian(data[..., time_sel], axis=(0, 1, 2))
        t_series = t_series[t_series != 0]
        t_chunksize = find_max_valid_chunk_length(t_series, temporal_tol_factor)
        # Never report zero chunks (e.g. when the series is empty).
        n_time_interval = max(len(t_series) // t_chunksize, 1)
        if max_ntime > 0:
            n_time_interval = min(n_time_interval, max_ntime)
    else:
        spectra = np.nanmedian(data[:, chan_sel, ...], axis=(0, 2))
        n_time_interval = 1
    spectra = spectra[spectra != 0]
    f_chunksize = find_max_valid_chunk_length(spectra, spectral_tol_factor)
    n_spectral_interval = max(len(spectra) // f_chunksize, 1)
    if max_nchan > 0:
        n_spectral_interval = min(n_spectral_interval, max_nchan)
    return n_time_interval, n_spectral_interval
[docs]
def calc_psf(msname, chan_number=-1):
    """
    Estimate the PSF size in arcsec from the longest baseline

    The PSF is taken as ``1.2 / maxuv`` radians, with ``maxuv`` in wavelengths
    as returned by ``calc_maxuv``.

    Parameters
    ----------
    msname : str
        Name of the measurement set
    chan_number : int, optional
        Channel index forwarded to ``calc_maxuv``

    Returns
    -------
    float
        PSF size in arcsec, rounded to two decimals
    """
    _, maxuv_lambda = calc_maxuv(msname, chan_number=chan_number)
    psf_rad = 1.2 / maxuv_lambda
    psf_arcsec = np.rad2deg(psf_rad) * 3600.0
    return round(float(psf_arcsec), 2)
[docs]
def calc_npix_in_psf(weight, robust=0.0):
    """
    Choose how many image pixels should sample one PSF

    Natural weighting uses 5 pixels and uniform weighting uses 3. Any other
    scheme is treated as Briggs: ``robust`` is clipped to [-1, 1] and mapped
    linearly from 3 (robust = -1) to 5 (robust = +1).

    Parameters
    ----------
    weight : str
        Image weighting scheme (case-insensitive)
    robust : float, optional
        Briggs weighting robust parameter (-1 to +1)

    Returns
    -------
    float
        Number of pixels in a PSF, rounded to a whole number
    """
    fixed_npix = {"NATURAL": 5.0, "UNIFORM": 3.0}
    scheme = weight.upper()
    if scheme in fixed_npix:
        npix = fixed_npix[scheme]
    else:
        clipped = np.clip(robust, -1.0, 1.0)
        npix = 3.0 + ((clipped + 1.0) / 2.0) * (5.0 - 3.0)
    return round(npix, 0)
[docs]
def calc_cellsize(msname, num_pixel_in_psf):
    """
    Derive the image pixel size from the PSF size

    Parameters
    ----------
    msname : str
        Name of the measurement set
    num_pixel_in_psf : float
        Number of pixels in one PSF

    Returns
    -------
    float
        Pixel size in arcsec, rounded to 0.1 arcsec
    """
    return round(calc_psf(msname) / num_pixel_in_psf, 1)
[docs]
def calc_multiscale_scales(msname, num_pixel_in_psf, chan_number=-1, max_scale=16):
    """
    Calculate multiscale scales

    The scales start at 0 and at one PSF (``num_pixel_in_psf`` pixels) and then
    double. They are capped at the largest usable scale: half of the angular
    scale ``1 / minuv`` (minimum non-zero baseline, in wavelengths), limited to
    ``max_scale`` arcmin. The cap is appended only if it is larger than the
    last scale, so the list is strictly increasing after 0.

    Parameters
    ----------
    msname : str
        Name of the measurement set
    num_pixel_in_psf : float
        Number of pixels in one PSF
    chan_number : int, optional
        Channel index forwarded to ``calc_psf`` and ``calc_minuv``
    max_scale : float, optional
        Maximum scale in arcmin

    Returns
    -------
    list
        Multiscale scales in pixel units

    Raises
    ------
    ValueError
        If ``num_pixel_in_psf`` is not positive
    """
    # A non-positive start scale would never reach the cap (infinite loop).
    if num_pixel_in_psf <= 0:
        raise ValueError("num_pixel_in_psf must be positive")
    psf = calc_psf(msname, chan_number=chan_number)
    _, minuv_l = calc_minuv(msname, chan_number=chan_number)
    # In arcmin, half of maximum scale
    max_interferometric_scale = 0.5 * np.rad2deg(1.0 / minuv_l) * 60.0
    max_interferometric_scale = min(max_scale, max_interferometric_scale)
    max_scale_pixel = int((max_interferometric_scale * 60.0) / (psf / num_pixel_in_psf))
    multiscale_scales = [0, num_pixel_in_psf]
    current_scale = num_pixel_in_psf
    while True:
        current_scale = current_scale * 2
        if current_scale >= max_scale_pixel:
            # Avoid duplicate or decreasing tails such as [0, 4, 4] or [0, 4, 3].
            if max_scale_pixel > multiscale_scales[-1]:
                multiscale_scales.append(max_scale_pixel)
            break
        multiscale_scales.append(current_scale)
    return multiscale_scales
[docs]
def get_multiscale_bias(freq, bias_min=0.6, bias_max=0.9, minfreq=100, maxfreq=200):
    """
    Get frequency dependent multiscale bias

    The bias is ``bias_min`` at or below ``minfreq`` and ``bias_max`` at or
    above ``maxfreq``. In between, it is interpolated linearly in log10
    frequency.

    Parameters
    ----------
    freq : float
        Frequency in MHz
    bias_min : float, optional
        Bias at and below ``minfreq``
    bias_max : float, optional
        Bias at and above ``maxfreq``
    minfreq : float, optional
        Lower end of the interpolation range in MHz (must be positive)
    maxfreq : float, optional
        Upper end of the interpolation range in MHz

    Returns
    -------
    float
        Multiscale bias parameter
    """
    if freq <= minfreq:
        return bias_min
    elif freq >= maxfreq:
        return bias_max
    # Interpolate over the caller's range; this previously used a hard-coded
    # 100-200 MHz span regardless of minfreq/maxfreq.
    logf = np.log10(freq)
    logf_min = np.log10(minfreq)
    logf_max = np.log10(maxfreq)
    frac = (logf - logf_min) / (logf_max - logf_min)
    return round(
        np.clip(bias_min + frac * (bias_max - bias_min), bias_min, bias_max), 3
    )