86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
class SubtomogramGenerator:
    """
    A class for generating subtomograms from a parent tomogram.

    Attributes:
        tomogram (Tomogram): The parent tomogram to sample from.
        annotations (List[Annotation]): The annotations from the parent
            tomogram.
        vol_shape (Tuple[int, int, int]): The shape of the volumes to be
            generated.
        pads (Tuple[int, int, int]): The per-axis margin kept between a
            requested point and the borders of a positive sample.
        gen (np.random.Generator): Random number generator for sampling.
    """

    def __init__(self, tomogram: 'Tomogram', *,
                 seed: Optional[int] = None) -> None:
        """
        Initializes a SubtomogramGenerator instance.

        The tomogram's ``load()`` method is called immediately and its
        ``annotations`` attribute is cached on the generator.

        Args:
            tomogram (Tomogram): The parent tomogram to sample from.
            seed (Optional[int]): Seed forwarded to
                ``np.random.default_rng``. ``None`` (the default) keeps the
                previous behaviour of seeding from OS entropy.
        """
        self.tomogram = tomogram
        self.tomogram.load()
        self.annotations = self.tomogram.annotations
        self.vol_shape = (64, 256, 256)
        self.pads = (8, 32, 32)
        self.gen = np.random.default_rng(seed)

    def set_vol_shape(self, new_vol_shape: tuple[int, int, int]) -> None:
        """
        Sets a new volume shape for the generator.

        Args:
            new_vol_shape (tuple[int, int, int]): The new volume shape.
        """
        self.vol_shape = new_vol_shape

    def _max_offsets(self) -> list[int]:
        """
        Returns the largest valid lower bound along every axis.

        Raises:
            ValueError: If the volume is larger than the tomogram along any
                axis, in which case no valid lower bound exists.
        """
        limits = [ts - vs
                  for ts, vs in zip(self.tomogram.shape, self.vol_shape)]
        if any(limit < 0 for limit in limits):
            raise ValueError(
                f"vol_shape {tuple(self.vol_shape)} does not fit inside "
                f"tomogram of shape {tuple(self.tomogram.shape)}"
            )
        return limits

    def _positive_offset_ranges(self, point) -> tuple[list[int], list[int]]:
        """
        Returns inclusive per-axis (lows, highs) lower-bound limits for a
        volume that should contain `point`.

        The limits are clipped to [0, tomogram extent - volume extent] so the
        volume never leaves the tomogram. When the point is too close to a
        tomogram edge for the padding to be honoured, the range collapses to
        the single offset nearest to satisfying it.
        """
        lows: list[int] = []
        highs: list[int] = []
        axes = zip(self._max_offsets(), self.vol_shape, point, self.pads)
        for limit, vs, pt, pad in axes:
            # A margin wider than half the volume cannot be met on both
            # sides at once, so cap it there.
            pad = min(pad, vs // 2)
            # ceil/floor keep the limits integral for non-integer points.
            low = int(np.ceil(pt - vs + pad))
            high = int(np.floor(pt - pad))
            low = min(max(low, 0), limit)
            high = min(max(high, 0), limit)
            lows.append(low)
            # Guard against an empty range after rounding.
            highs.append(max(low, high))
        return lows, highs

    def _draw_offsets(self, lows, highs) -> list:
        """
        Draws one integer per axis, uniformly from the inclusive range
        [low, high].
        """
        return [self.gen.integers(low, high, endpoint=True)
                for low, high in zip(lows, highs)]

    def positive_sample(self, point: Optional[np.ndarray] = None) -> 'Subtomogram':
        """
        Returns a random subtomogram positioned around the specified point.

        The lower bounds are drawn uniformly from every integer offset that
        keeps the point at least `pads` voxels away from the respective
        borders while the volume stays inside the tomogram. If the point is
        too close to a tomogram edge for that, the offset is clamped to the
        tomogram instead. If no point is given, a random annotation point
        from self.tomogram's annotations is selected.

        Args:
            point (Optional[np.ndarray]): The point to include in the
                subtomogram. Defaults to None.

        Returns:
            The newly created subtomogram.

        Raises:
            ValueError: If `vol_shape` is larger than the tomogram.
        """
        if point is None:
            # Pick a random annotation point from self.tomogram's annotations
            annotation = self.gen.choice(self.annotations)
            point = self.gen.choice(annotation.points)

        lows, highs = self._positive_offset_ranges(point)
        lower_bounds = self._draw_offsets(lows, highs)
        return Subtomogram(self.tomogram, lower_bounds, self.vol_shape)

    def negative_sample(self) -> 'Subtomogram':
        """
        Returns a random subtomogram that does not contain any points from the
        annotations.

        Random lower bounds are drawn until a volume without annotation
        points is found or the maximum number of attempts is reached.

        Returns:
            The newly created subtomogram.

        Raises:
            ValueError: If `vol_shape` is larger than the tomogram.
            RuntimeError: If unable to find a valid subtomogram without
                annotation points after 1000 attempts.
        """
        max_attempts = 1000
        limits = self._max_offsets()
        floors = [0] * len(limits)
        # The annotation points do not change between attempts, so fetch
        # them once instead of on every iteration.
        annotation_points = list(self.tomogram.annotation_points())

        for _ in range(max_attempts):
            lower_bounds = self._draw_offsets(floors, limits)
            # Shift each point into volume coordinates before the check.
            contains_annotation = any(
                _in_bounds(self.vol_shape, point - lower_bounds)
                for point in annotation_points
            )
            if not contains_annotation:
                return Subtomogram(self.tomogram, lower_bounds, self.vol_shape)

        raise RuntimeError("Failed to find a volume without an annotation")

    def find_annotation_points(self) -> List[np.ndarray]:
        """
        Returns a list of points that are present in the annotations.

        Returns:
            A list of annotation points, in annotation order.
        """
        return [point
                for annotation in self.annotations
                for point in annotation.points]
|