class EmptyPatchFetcher:
    """
    Find "empty" patches so that real content can be replaced with them.

    A patch is considered empty when its maximum value lies in
    ``[0, max_val_threshold)``. The dataset indices of all such patches are
    collected once at construction time and can then be drawn at random via
    :meth:`sample`.
    """

    def __init__(self, idx_manager, patch_size, data_frames, max_val_threshold=None):
        """
        Args:
            idx_manager: object exposing ``get_dataset_idx_from_location`` and
                ``get_location_from_dataset_idx``; used to translate patch
                locations into dataset indices.
            patch_size: side length of the square patches that are scanned.
            data_frames: array of shape ``(N, H, W)`` to search for empty patches.
            max_val_threshold: exclusive upper bound on a patch's maximum for it
                to count as empty. Must be provided; ``None`` raises ``ValueError``.
        """
        self._frames = data_frames
        self._idx_manager = idx_manager
        self._max_val_threshold = max_val_threshold
        self._idx_list = []
        self._patch_size = patch_size
        self._grid_size = 1
        self.set_empty_idx()
        print(f"[{self.__class__.__name__}] MaxVal:{self._max_val_threshold}")

    def compute_max(self, window):
        """
        Rolling maximum over every ``window x window`` patch of every frame.

        Returns:
            float64 array of shape ``(N, H - window, W - window)`` whose entry
            ``[n, h, w]`` is ``frames[n, h:h + window, w:w + window].max()``.

        Raises:
            ValueError: if ``window`` leaves no patch positions in the frames.
        """
        assert self._grid_size == 1
        n_frames, height, width = self._frames.shape
        # NOTE(review): start positions run over range(H - window) as in the
        # original code, so the last valid start (H - window) is never scanned;
        # kept because it is unclear whether the index manager accepts it.
        out_h, out_w = height - window, width - window
        if out_h <= 0 or out_w <= 0:
            raise ValueError(
                f"window={window} leaves no patch positions in frames of size "
                f"{height}x{width}"
            )

        # A rectangular max is separable: take the max over the window's rows
        # first, then over its columns. This needs O(H + W) vectorised slices
        # instead of O(H * W) per-patch reductions, and gives identical values.
        row_max = np.empty((n_frames, out_h, width), dtype=self._frames.dtype)
        for h in range(out_h):
            row_max[:, h, :] = self._frames[:, h : h + window, :].max(axis=1)

        max_data = np.empty((n_frames, out_h, out_w), dtype=np.float64)
        for w in range(out_w):
            max_data[:, :, w] = row_max[:, :, w : w + window].max(axis=2)
        return max_data

    def set_empty_idx(self):
        """
        Populate ``self._idx_list`` with the dataset indices of all empty patches.

        Raises:
            ValueError: if no ``max_val_threshold`` was configured, or if no
                patch satisfies ``0 <= max < max_val_threshold``.
        """
        if self._max_val_threshold is None:
            # Comparing against None would otherwise fail with an opaque TypeError.
            raise ValueError(
                f"[{self.__class__.__name__}] max_val_threshold must be set to "
                "find empty patches"
            )
        max_data = self.compute_max(self._patch_size)
        empty_loc = np.where(
            np.logical_and(max_data >= 0, max_data < self._max_val_threshold)
        )

        idx_list = []
        for n_idx, h_start, w_start in zip(*empty_loc):
            # The 4th location entry is fixed to 0 (presumably the channel
            # index — confirm against the index manager).
            idx = self._idx_manager.get_dataset_idx_from_location(
                (n_idx, h_start, w_start, 0)
            )
            # Sanity check: the index manager must round-trip the location.
            t, h, w, _ = self._idx_manager.get_location_from_dataset_idx(idx)
            assert h == h_start, f"{h} != {h_start}"
            assert w == w_start, f"{w} != {w_start}"
            assert t == n_idx, f"{t} != {n_idx}"
            idx_list.append(idx)

        if not idx_list:
            raise ValueError(
                f"[{self.__class__.__name__}] no empty patches found "
                f"(patch_size={self._patch_size}, "
                f"max_val_threshold={self._max_val_threshold})"
            )
        self._idx_list = np.array(idx_list)

    def sample(self):
        """Return ``(random empty-patch dataset index, grid size)``; grid size is 1."""
        return (np.random.choice(self._idx_list), self._grid_size)