class LCMultiChDloader(MultiChDloader):
def __init__(
self,
data_config: DatasetConfig,
fpath: str,
load_data_fn: Callable,
val_fraction=None,
test_fraction=None,
):
        # e.g. mode=padding_mode, constant_values=constant_value
        self._padding_kwargs = data_config.padding_kwargs
self._uncorrelated_channel_probab = data_config.uncorrelated_channel_probab
super().__init__(
data_config,
fpath,
load_data_fn=load_data_fn,
val_fraction=val_fraction,
test_fraction=test_fraction,
)
        if data_config.overlapping_padding_kwargs is not None:
            assert self._padding_kwargs == data_config.overlapping_padding_kwargs, (
                "During evaluation, overlapping_padding_kwargs must match "
                "padding_kwargs, since overlapping_padding_kwargs is only used "
                "when it is not None."
            )
            self._overlapping_padding_kwargs = data_config.overlapping_padding_kwargs
        else:
            self._overlapping_padding_kwargs = data_config.padding_kwargs
self.multiscale_lowres_count = data_config.multiscale_lowres_count
assert self.multiscale_lowres_count is not None
self._scaled_data = [self._data]
self._scaled_noise_data = [self._noise_data]
assert (
isinstance(self.multiscale_lowres_count, int)
and self.multiscale_lowres_count >= 1
)
assert isinstance(self._padding_kwargs, dict)
assert "mode" in self._padding_kwargs
for _ in range(1, self.multiscale_lowres_count):
shape = self._scaled_data[-1].shape
assert len(shape) == 4
new_shape = (shape[0], shape[1] // 2, shape[2] // 2, shape[3])
ds_data = resize(
self._scaled_data[-1].astype(np.float32), new_shape
).astype(self._scaled_data[-1].dtype)
            # NOTE: the asserts below are important. resize expects np.float32
            # input; other dtypes can produce wildly wrong values.
            assert (
                ds_data.max() / self._scaled_data[-1].max() < 5
            ), "Downsampled data has a much larger max than the original"
            assert (
                ds_data.max() / self._scaled_data[-1].max() > 0.2
            ), "Downsampled data has a much smaller max than the original"
self._scaled_data.append(ds_data)
            # Downsample the noise data in the same way.
if self._noise_data is not None:
noise_data = resize(self._scaled_noise_data[-1], new_shape)
self._scaled_noise_data.append(noise_data)
def reduce_data(
self, t_list=None, h_start=None, h_end=None, w_start=None, w_end=None
):
assert t_list is not None
assert h_start is None
assert h_end is None
assert w_start is None
assert w_end is None
self._data = self._data[t_list].copy()
self._scaled_data = [
self._scaled_data[i][t_list].copy() for i in range(len(self._scaled_data))
]
if self._noise_data is not None:
self._noise_data = self._noise_data[t_list].copy()
self._scaled_noise_data = [
self._scaled_noise_data[i][t_list].copy()
for i in range(len(self._scaled_noise_data))
]
self.N = len(t_list)
        # TODO: document where self._img_sz is first assigned (it must be set
        # before reduce_data() is called).
self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
print(
f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
)
def _init_msg(self):
msg = super()._init_msg()
msg += f" Pad:{self._padding_kwargs}"
if self._uncorrelated_channels:
msg += f" UncorrChProbab:{self._uncorrelated_channel_probab}"
return msg
    def _load_scaled_img(
        self, scaled_index, index: Union[int, Tuple[int, int]]
    ) -> Tuple[np.ndarray, ...]:
if isinstance(index, int):
idx = index
else:
idx, _ = index
patch_loc_list = self.idx_manager.get_patch_location_from_dataset_idx(idx)
nidx = patch_loc_list[0]
imgs = self._scaled_data[scaled_index][nidx]
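        # Split the channel axis: each channel becomes its own array with a
        # singleton leading dimension, i.e. (..., H, W, C) -> C arrays of
        # shape (1, ..., H, W).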
imgs = tuple([imgs[None, ..., i] for i in range(imgs.shape[-1])])
if self._noise_data is not None:
noisedata = self._scaled_noise_data[scaled_index][nidx]
noise = tuple([noisedata[None, ..., i] for i in range(noisedata.shape[-1])])
factor = np.sqrt(2) if self._input_is_sum else 1.0
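            # The low-res patches are only ever used to build the input, so
            # only the input-noise realization (index 0) is added to every
            # image channel.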
imgs = tuple([img + noise[0] * factor for img in imgs])
return imgs
def _crop_img(self, img: np.ndarray, patch_start_loc: Tuple):
"""
Here, h_start, w_start could be negative. That simply means we need to pick the content from 0. So,
the cropped image will be smaller than self._img_sz * self._img_sz
"""
max_len_vals = list(self.idx_manager.data_shape[1:-1])
max_len_vals[-2:] = img.shape[-2:]
return self._crop_img_with_padding(
img, patch_start_loc, max_len_vals=max_len_vals
)
def _get_img(self, index: int):
"""
Returns the primary patch along with low resolution patches centered on the primary patch.
"""
# Noise_tuples is populated when there is synthetic noise in training
# Should have similar type of noise with the noise model
# Starting with microsplit, dump the noise, use it instead as an augmentation if nessesary
img_tuples, noise_tuples = self._load_img(index)
assert self._img_sz is not None
h, w = img_tuples[0].shape[-2:]
if self._enable_random_cropping:
patch_start_loc = self._get_random_hw(h, w)
if self._5Ddata:
patch_start_loc = (
np.random.choice(img_tuples[0].shape[-3] - self._depth3D),
) + patch_start_loc
else:
patch_start_loc = self._get_deterministic_loc(index)
        # LC logic starts here: crop the highest-resolution image at the
        # chosen patch location.
cropped_img_tuples = [
self._crop_flip_img(img, patch_start_loc, False, False)
for img in img_tuples
]
cropped_noise_tuples = [
self._crop_flip_img(noise, patch_start_loc, False, False)
for noise in noise_tuples
]
patch_start_loc = list(patch_start_loc)
h_start, w_start = patch_start_loc[-2], patch_start_loc[-1]
h_center = h_start + self._img_sz // 2
w_center = w_start + self._img_sz // 2
allres_versions = {
i: [cropped_img_tuples[i]] for i in range(len(cropped_img_tuples))
}
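        # Walk down the pyramid: at each scale the patch center maps to
        # (h_center // 2, w_center // 2) while the crop size stays _img_sz, so
        # every low-res patch covers twice the field of view of the level
        # above. E.g. with _img_sz=64 and a full-res patch starting at
        # (100, 100), the center is (132, 132); at half resolution the center
        # becomes (66, 66) and the crop starts at (34, 34).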
for scale_idx in range(1, self.multiscale_lowres_count):
            # Load the image at the next lower resolution.
scaled_img_tuples = self._load_scaled_img(scale_idx, index)
h_center = h_center // 2
w_center = w_center // 2
h_start = h_center - self._img_sz // 2
w_start = w_center - self._img_sz // 2
patch_start_loc[-2:] = [h_start, w_start]
scaled_cropped_img_tuples = [
self._crop_flip_img(img, patch_start_loc, False, False)
for img in scaled_img_tuples
]
for ch_idx in range(len(img_tuples)):
allres_versions[ch_idx].append(scaled_cropped_img_tuples[ch_idx])
        output_img_tuples = tuple(
            np.concatenate(allres_versions[ch_idx])
            for ch_idx in range(len(img_tuples))
        )
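        # Each channel is now a stack along axis 0: index 0 is the
        # full-resolution crop, later indices are progressively lower
        # resolutions covering a wider field of view.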
return output_img_tuples, cropped_noise_tuples
def __getitem__(self, index: Union[int, Tuple[int, int]]):
img_tuples, noise_tuples = self._get_img(index)
if self._uncorrelated_channels:
            assert self._input_idx is None, (
                "Uncorrelated channels are not implemented when there is a "
                "separate input channel."
            )
if np.random.rand() < self._uncorrelated_channel_probab:
img_tuples_new = [None] * len(img_tuples)
img_tuples_new[0] = img_tuples[0]
for i in range(1, len(img_tuples)):
new_index = np.random.randint(len(self))
img_tuples_tmp, _ = self._get_img(new_index)
img_tuples_new[i] = img_tuples_tmp[i]
img_tuples = img_tuples_new
if self._is_train:
if self._empty_patch_replacement_enabled:
if np.random.rand() < self._empty_patch_replacement_probab:
img_tuples = self.replace_with_empty_patch(img_tuples)
if self._enable_rotation:
img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples)
        # If synthetic noise is present, add it to the input. The factor keeps
        # the computed input's noise level consistent: summing two independent
        # Gaussians gives sqrt(2) times the std of each, so the noise is
        # scaled by sqrt(2) when the input is a sum of channels.
if len(noise_tuples) > 0:
factor = np.sqrt(2) if self._input_is_sum else 1.0
input_tuples = []
for x in img_tuples:
                # Copy to avoid mutating the original image, which is later
                # used for the target.
                x = x.copy()
                # NOTE: the lower LC levels already had noise added in
                # _load_scaled_img, so only the highest-resolution level
                # (index 0) needs noise here.
                x[0] = x[0] + noise_tuples[0] * factor
input_tuples.append(x)
else:
input_tuples = img_tuples
        # Compute the input by summing / averaging the channels. alpha is the
        # weight applied to each channel when combining them; how best to
        # sample alpha is still an open research question.
inp, alpha = self._compute_input(input_tuples)
        # The target uses only the highest-resolution level of each channel.
        target_tuples = [img[:1] for img in img_tuples]
        # Add noise to the targets. noise_tuples[0] was consumed by the input,
        # so the remaining realizations are paired with the targets.
        if len(noise_tuples) >= 1:
            target_tuples = [
                x + noise for x, noise in zip(target_tuples, noise_tuples[1:])
            ]
target = self._compute_target(target_tuples, alpha)
norm_target = self.normalize_target(target)
output = [inp, norm_target]
if self._return_alpha:
output.append(alpha)
if isinstance(index, int):
return tuple(output)
_, grid_size = index
output.append(grid_size)
return tuple(output)
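

# Minimal usage sketch (illustrative only: the concrete DatasetConfig values
# and `my_load_fn` below are assumptions, not part of this module):
#
#   config = DatasetConfig(
#       ...,  # remaining fields as configured elsewhere
#       multiscale_lowres_count=3,
#       padding_kwargs={"mode": "reflect"},
#       overlapping_padding_kwargs=None,
#       uncorrelated_channel_probab=0.5,
#   )
#   dset = LCMultiChDloader(config, "/path/to/data.tif", load_data_fn=my_load_fn)
#   inp, norm_target = dset[0]  # plus alpha / grid_size depending on settings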