Skip to content

_projection_utils

Utility functions and models for intensity projection tasks.

DaskProjectionMethod

Bases: Enum

Projection method selection.

Choose which method to use for intensity projection along the Z axis.

ATTRIBUTE DESCRIPTION
MIP

Maximum intensity projection

MINIP

Minimum intensity projection

MEANIP

Mean intensity projection

SUMIP

Sum intensity projection

Source code in fractal_tasks_core/_projection_utils.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
class DaskProjectionMethod(Enum):
    """Projection method selection.

    Choose which method to use for intensity projection along the Z axis.
    Each member's value is a human-readable description of the method; use
    `abbreviation` for a short lowercase identifier and `apply` to run the
    projection on a Dask array.

    Attributes:
        MIP: Maximum intensity projection
        MINIP: Minimum intensity projection
        MEANIP: Mean intensity projection
        SUMIP: Sum intensity projection
    """

    MIP = "Maximum intensity projection"
    MINIP = "Minimum intensity projection"
    MEANIP = "Mean intensity projection"
    SUMIP = "Sum intensity projection"

    def apply(self, dask_array: da.Array, axis: int = 0) -> da.Array:
        """Apply the selected projection method to the given Dask array.

        The reduction removes `axis` from the result; it is not kept as a
        singleton dimension.

        Args:
            dask_array (dask.array.Array): The Dask array to project.
            axis (int): The axis along which to apply the projection.

        Returns:
            dask.array.Array: The resulting Dask array after applying the
                projection.
        """
        # Map each member to the function implementing its reduction.
        method_map = {
            DaskProjectionMethod.MIP: max_wrapper,
            DaskProjectionMethod.MINIP: min_wrapper,
            DaskProjectionMethod.MEANIP: mean_wrapper,
            DaskProjectionMethod.SUMIP: safe_sum,
        }
        # Call the appropriate method, passing in the dask_array explicitly
        return method_map[self](dask_array, axis=axis)

    @property
    def abbreviation(self) -> str:
        """Get the abbreviation of the projection method.

        Returns:
            str: The lowercase abbreviation of the projection method
                (e.g. "mip" for `MIP`).
        """
        abbrev_map = {
            DaskProjectionMethod.MIP: "mip",
            DaskProjectionMethod.MINIP: "minip",
            DaskProjectionMethod.MEANIP: "meanip",
            DaskProjectionMethod.SUMIP: "sumip",
        }
        return abbrev_map[self]

abbreviation: str property

Get the abbreviation of the projection method.

RETURNS DESCRIPTION
str

The abbreviation of the projection method.

TYPE: str

apply(dask_array, axis=0)

Apply the selected projection method to the given Dask array.

PARAMETER DESCRIPTION
dask_array

The Dask array to project.

TYPE: Array

axis

The axis along which to apply the projection.

TYPE: int DEFAULT: 0

RETURNS DESCRIPTION
Array

dask.array.Array: The resulting Dask array after applying the projection.

Source code in fractal_tasks_core/_projection_utils.py
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
def apply(self, dask_array: da.Array, axis: int = 0) -> da.Array:
    """Project ``dask_array`` along ``axis`` using this projection method.

    Args:
        dask_array (dask.array.Array): Array to be projected.
        axis (int): Axis that the projection reduces over.

    Returns:
        dask.array.Array: The projected array produced by the reducer
            associated with this method.
    """
    # Dispatch table: projection method -> function performing the reduction.
    reducers = {
        DaskProjectionMethod.MIP: max_wrapper,
        DaskProjectionMethod.MINIP: min_wrapper,
        DaskProjectionMethod.MEANIP: mean_wrapper,
        DaskProjectionMethod.SUMIP: safe_sum,
    }
    reducer = reducers[self]
    return reducer(dask_array, axis=axis)

InitArgsMIP

Bases: BaseModel

Init Args for MIP task.

ATTRIBUTE DESCRIPTION
origin_url

Path to the zarr_url with the 3D data

TYPE: str

method

Projection method to be used. See DaskProjectionMethod

TYPE: DaskProjectionMethod

overwrite

If True, overwrite the task output.

TYPE: bool

new_plate_name

Name of the new OME-Zarr HCS plate

TYPE: str

Source code in fractal_tasks_core/_projection_utils.py
274
275
276
277
278
279
280
281
282
283
284
285
286
287
class InitArgsMIP(BaseModel):
    """Init Args for MIP task.

    Validated arguments describing one intensity-projection job. `method`
    is the only optional field; the other three must be provided.

    Attributes:
        origin_url: Path to the zarr_url with the 3D data
        method: Projection method to be used. See `DaskProjectionMethod`
            (defaults to `DaskProjectionMethod.MIP`).
        overwrite: If `True`, overwrite the task output.
        new_plate_name: Name of the new OME-Zarr HCS plate
    """

    # Required: location of the source (3D) image.
    origin_url: str
    # Optional: defaults to maximum intensity projection.
    method: DaskProjectionMethod = DaskProjectionMethod.MIP
    # Required: no default value.
    overwrite: bool
    # Required: no default value.
    new_plate_name: str

_compute_new_shape(source_image)

Compute the new shape of the image after the projection.

The new shape is the same as the original one, except for the z-axis, which is set to 1.

RETURNS DESCRIPTION
tuple[int, ...]
  • new shape of the image
int
  • index of the z-axis in the original image
Source code in fractal_tasks_core/_projection_utils.py
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
def _compute_new_shape(source_image: Image) -> tuple[tuple[int, ...], int]:
    """Return the projected on-disk shape and the z-axis position.

    The projected shape equals the source shape except along the z-axis,
    whose length is set to 1.

    Args:
        source_image: Image that is going to be projected.

    Returns:
        A ``(new_shape, z_index)`` pair, where ``new_shape`` is the on-disk
        shape of the projected image and ``z_index`` is the position of the
        z-axis in the source image.

    Raises:
        ValueError: If the source image has no z-axis.
    """
    on_disk_shape = source_image.shape
    logger.info(f"Source {on_disk_shape=}")

    z_index = source_image.axes_handler.get_index("z")
    if z_index is None:
        # Without a z-axis there is nothing to project along.
        raise ValueError(
            "The input image does not contain a z-axis, "
            "projection is only supported for 3D images with a z-axis."
        )

    # Copy the source shape and collapse the z-axis to a single plane.
    dest_on_disk_shape = [*on_disk_shape]
    dest_on_disk_shape[z_index] = 1
    logger.info(f"Destination {dest_on_disk_shape=}")

    new_shape = tuple(dest_on_disk_shape)
    return new_shape, z_index

max_wrapper(dask_array, axis=0)

Perform a da.max on the dask_array.

PARAMETER DESCRIPTION
dask_array

The input Dask array.

TYPE: Array

axis

The axis along which to max the array. Defaults to 0.

TYPE: int DEFAULT: 0

RETURNS DESCRIPTION
Array

dask.array.Array: The result of the max

Source code in fractal_tasks_core/_projection_utils.py
79
80
81
82
83
84
85
86
87
88
89
90
def max_wrapper(dask_array: da.Array, axis: int = 0) -> da.Array:
    """Reduce ``dask_array`` to its maximum along ``axis``.

    Args:
        dask_array (dask.array.Array): Array to reduce.
        axis (int, optional): Axis to take the maximum over; it is removed
            from the output. Defaults to 0.

    Returns:
        dask.array.Array: Maximum values along ``axis``, in the input dtype.
    """
    projected = dask_array.max(axis=axis)
    return projected

mean_wrapper(dask_array, axis=0)

Perform a da.mean on the dask_array & cast it to its original dtype.

Without casting, the result can change dtype to e.g. float64

PARAMETER DESCRIPTION
dask_array

The input Dask array.

TYPE: Array

axis

The axis along which to mean the array. Defaults to 0.

TYPE: int DEFAULT: 0

RETURNS DESCRIPTION
Array

dask.array.Array: The result of the mean, cast back to the original dtype.

Source code in fractal_tasks_core/_projection_utils.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def mean_wrapper(dask_array: da.Array, axis: int = 0) -> da.Array:
    """Average ``dask_array`` along ``axis`` while keeping the input dtype.

    A plain mean can come back in a different dtype (e.g. float64 for an
    integer input). The result is therefore cast back to the input dtype.
    For integer dtypes that cast discards the fractional part, so a mean of
    2.5 is stored as 2.

    Args:
        dask_array (dask.array.Array): Array to average.
        axis (int, optional): Axis to average over; it is removed from the
            output. Defaults to 0.

    Returns:
        dask.array.Array: Mean values along ``axis``, cast to the input dtype.
    """
    target_dtype = dask_array.dtype
    averaged = da.mean(dask_array, axis=axis)
    return averaged.astype(target_dtype)

min_wrapper(dask_array, axis=0)

Perform a da.min on the dask_array.

PARAMETER DESCRIPTION
dask_array

The input Dask array.

TYPE: Array

axis

The axis along which to min the array. Defaults to 0.

TYPE: int DEFAULT: 0

RETURNS DESCRIPTION
Array

dask.array.Array: The result of the min

Source code in fractal_tasks_core/_projection_utils.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def min_wrapper(dask_array: da.Array, axis: int = 0) -> da.Array:
    """Reduce ``dask_array`` to its minimum along ``axis``.

    Args:
        dask_array (dask.array.Array): Array to reduce.
        axis (int, optional): Axis to take the minimum over; it is removed
            from the output. Defaults to 0.

    Returns:
        dask.array.Array: Minimum values along ``axis``, in the input dtype.
    """
    projected = dask_array.min(axis=axis)
    return projected

projection_core(*, input_zarr_url, output_zarr_url, method=DaskProjectionMethod.MIP, overwrite=False, attributes=None)

Perform intensity projection along Z axis with a chosen method.

Note: this task stores the output in a new zarr file.

PARAMETER DESCRIPTION
input_zarr_url

Path or url to the individual OME-Zarr image to be processed.

TYPE: str

output_zarr_url

Path or url to the output OME-Zarr image.

TYPE: str

method

Projection method to be used. See DaskProjectionMethod

TYPE: DaskProjectionMethod DEFAULT: MIP

overwrite

If True, overwrite the task output.

TYPE: bool DEFAULT: False

attributes

Additional attributes to be added to the output image.

TYPE: dict[str, Any] | None DEFAULT: None

Source code in fractal_tasks_core/_projection_utils.py
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
def projection_core(
    *,
    input_zarr_url: str,
    output_zarr_url: str,
    method: DaskProjectionMethod = DaskProjectionMethod.MIP,
    overwrite: bool = False,
    attributes: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Perform intensity projection along Z axis with a chosen method.

    Note: this task stores the output in a new zarr file.

    The steps are:
    1. Reduce the source image along its z-axis with `method`.
    2. Write the result into a new image derived from the source. Labels
       are not copied; tables are.
    3. Set the z-range of every ROI in the copied ROI tables that has one
       to (0, 1).

    Args:
        input_zarr_url: Path or url to the individual OME-Zarr image to be processed.
        output_zarr_url: Path or url to the output OME-Zarr image.
        method: Projection method to be used. See `DaskProjectionMethod`
        overwrite: If `True`, overwrite the task output.
        attributes: Additional attributes to be added to the output image.

    Returns:
        A dict whose ``"image_list_updates"`` list holds one entry for the
        output image, with these keys:
        - ``zarr_url``: the output URL.
        - ``origin``: the input URL.
        - ``attributes``: the given attributes, or ``{}`` if None.
        - ``types``: ``{"is_3D": False}``.

    Raises:
        ValueError: If the input image is 2D (or a 2D time series), or if it
            has no z-axis.
    """
    logger.info(f"{input_zarr_url=}")
    logger.info(f"{output_zarr_url=}")
    logger.info(f"{method=}")

    # Read image metadata
    original_ome_zarr = open_ome_zarr_container(input_zarr_url)
    original_image = original_ome_zarr.get_image()

    # Reject 2D inputs (including 2D time series) up front: projection
    # needs a z-axis to reduce over.
    if original_image.is_2d or original_image.is_2d_time_series:
        raise ValueError(
            "The input image is 2D, projection is only supported for 3D images."
        )

    # Compute the destination shape (z collapsed to 1) and the z-axis index
    dest_on_disk_shape, z_axis_index = _compute_new_shape(original_image)
    logger.info(f"New shape: {dest_on_disk_shape=}")

    # Create the new empty image, derived from the source. Labels are not
    # copied, but tables are; their ROIs are adjusted below. The image name
    # is the upper-cased method description, e.g.
    # "MAXIMUM INTENSITY PROJECTION".
    # NOTE(review): z_spacing is hard-coded to 1.0 for the single-plane
    # output. Confirm this is the intended spacing.
    ome_zarr_mip = original_ome_zarr.derive_image(
        store=output_zarr_url,
        name=method.value.upper(),
        shape=dest_on_disk_shape,
        pixelsize=original_image.pixel_size.yx,
        z_spacing=1.0,
        time_spacing=original_image.pixel_size.t,
        overwrite=overwrite,
        copy_labels=False,
        copy_tables=True,
    )
    logger.info(f"New Projection image created - {ome_zarr_mip=}")
    proj_image = ome_zarr_mip.get_image()

    # Process the image. The reduction drops the z-axis, so it is
    # re-inserted as a singleton to match dest_on_disk_shape before writing.
    source_dask = original_image.get_as_dask()
    dest_dask = method.apply(dask_array=source_dask, axis=z_axis_index)
    dest_dask = da.expand_dims(dest_dask, axis=z_axis_index)
    proj_image.set_array(dest_dask)
    proj_image.consolidate()

    # Edit the roi tables that were copied from the source. Every ROI that
    # has a z-range gets it set to (0, 1), matching the single output plane.
    # Each ROI is then re-added with overwrite=True.
    for roi_table_name in ome_zarr_mip.list_roi_tables():
        table = ome_zarr_mip.get_generic_roi_table(roi_table_name)

        for roi in table.rois():
            old_z_slice = roi.get("z")
            if old_z_slice is not None:
                roi = roi.update_slice("z", (0, 1))
            table.add(roi, overwrite=True)

        table.consolidate()
        logger.info(f"Table {roi_table_name} Projection done")

    # Generate image_list_updates
    attributes = attributes or {}
    image_list_update_dict = {
        "image_list_updates": [
            {
                "zarr_url": output_zarr_url,
                "origin": input_zarr_url,
                "attributes": attributes,
                "types": {"is_3D": False},
            }
        ]
    }
    return image_list_update_dict

safe_sum(dask_array, axis=0)

Perform a safe sum on a Dask array to avoid overflow.

Clips the result of da.sum & casts it to its original dtype. Dask.array already correctly handles promotion to uint32 or uint64 when necessary internally, but we want to ensure we clip the result.

PARAMETER DESCRIPTION
dask_array

The input Dask array.

TYPE: Array

axis

The axis along which to sum the array. Defaults to 0.

TYPE: int DEFAULT: 0

RETURNS DESCRIPTION
Array

dask.array.Array: The result of the sum, safely clipped and cast back to the original dtype.

Source code in fractal_tasks_core/_projection_utils.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def safe_sum(dask_array: da.Array, axis: int = 0) -> da.Array:
    """Perform a safe sum on a Dask array to avoid overflow.

    Sums along `axis`, clips the result to the representable range of the
    input dtype, and casts it back to that dtype.

    The sum itself promotes integer dtypes narrower than the platform
    integer to a wider accumulator (e.g. uint16 -> uint64), so intermediate
    sums of such inputs do not wrap. Clipping then saturates totals that do
    not fit the original dtype, rather than letting the final cast wrap them
    around. The lower clip bound is the dtype minimum, so negative sums of
    signed inputs are preserved.

    64-bit integer inputs have no wider accumulator, so an overflow of the
    sum itself is not prevented for them.

    Args:
        dask_array (dask.array.Array): The input Dask array (integer dtype).
            A NumPy array is also accepted.
        axis (int, optional): The axis along which to sum the array; it is
            removed from the output. Defaults to 0.

    Returns:
        dask.array.Array: The result of the sum, safely clipped and cast
            back to the original dtype.

    Raises:
        ValueError: If the input dtype is not an integer dtype.
    """
    # Determine the original dtype
    original_dtype = dask_array.dtype
    if not np.issubdtype(original_dtype, np.integer):
        raise ValueError(
            f"safe_sum only supports integer dtypes, got {original_dtype}. "
            "Use a different projection method for float arrays."
        )
    dtype_info = np.iinfo(original_dtype)

    # Method calls dispatch to the same reductions as da.sum / da.clip for
    # Dask arrays, and also work for NumPy arrays.
    result = dask_array.sum(axis=axis)

    # Saturate to the full range of the original dtype. Using 0 as the lower
    # bound would wipe out negative sums of signed-integer inputs.
    result = result.clip(dtype_info.min, dtype_info.max)

    # Cast back to the original dtype
    return result.astype(original_dtype)