masked_loading

Functions for masked loading of ROIs before/after processing.

_postprocess_output(*, modified_array, original_array, background)

Postprocess cellpose output, mainly to restore its original background.

NOTE: The pre/post-processing functions and the masked_loading_wrapper are currently meant to work as part of the cellpose_segmentation task, with the plan of then making them more flexible; see https://github.com/fractal-analytics-platform/fractal-tasks-core/issues/340.

PARAMETER DESCRIPTION
modified_array

The 3D (ZYX) array with the correct object data and wrong background data.

TYPE: ndarray

original_array

The 3D (ZYX) array with the wrong object data and correct background data.

TYPE: ndarray

background

The 3D (ZYX) boolean array that defines the background.

TYPE: ndarray

RETURNS DESCRIPTION
ndarray

The postprocessed array.
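As a minimal illustration with toy values (not taken from the library's tests), the background restoration performed by _postprocess_output comes down to a single boolean-mask assignment in NumPy. The standalone postprocess_output below is a hypothetical copy of that logic for experimentation.

```python
import numpy as np


def postprocess_output(modified_array, original_array, background):
    # Same logic as _postprocess_output: copy background voxels back from
    # the original array (modified_array is changed in place).
    modified_array[background] = original_array[background]
    return modified_array


# Toy ZYX arrays of shape (1, 1, 4); values are illustrative only
original = np.array([[[7, 7, 0, 0]]], dtype=np.uint32)  # existing labels
modified = np.array([[[1, 1, 2, 2]]], dtype=np.uint32)  # new segmentation
background = np.array([[[True, False, False, True]]])

result = postprocess_output(modified, original, background)
print(result.tolist())  # [[[7, 1, 2, 0]]]
```

Because boolean-mask assignment writes into modified_array, the returned array is the same object that was passed in.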

Source code in fractal_tasks_core/masked_loading.py
def _postprocess_output(
    *,
    modified_array: np.ndarray,
    original_array: np.ndarray,
    background: np.ndarray,
) -> np.ndarray:
    """
    Postprocess cellpose output, mainly to restore its original background.

    **NOTE**: The pre/post-processing functions and the
    masked_loading_wrapper are currently meant to work as part of the
    cellpose_segmentation task, with the plan of then making them more
    flexible; see
    https://github.com/fractal-analytics-platform/fractal-tasks-core/issues/340.

    Args:
        modified_array: The 3D (ZYX) array with the correct object data and
            wrong background data.
        original_array: The 3D (ZYX) array with the wrong object data and
            correct background data.
        background: The 3D (ZYX) boolean array that defines the background.

    Returns:
        The postprocessed array.
    """
    # Restore background
    modified_array[background] = original_array[background]
    return modified_array

_preprocess_input(image_array, *, region, current_label_path, ROI_table_path, ROI_positional_index)

Preprocess a four-dimensional cellpose input.

This involves:

  • Loading the masking label array for the appropriate ROI;
  • Extracting the appropriate label value from the ROI_table.obs dataframe;
  • Constructing the background mask, i.e. all voxels where the masking label differs from a specific label value;
  • Setting the background of image_array to 0;
  • Loading the current-label region, which will be needed in postprocessing to restore the background.
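The mask construction and background-zeroing steps listed above can be sketched with NumPy on toy arrays. This sketch skips all zarr/AnnData loading and assumes that masking_label_region and label_value are already in memory; the variable names mirror those used in the source code of _preprocess_input.

```python
import numpy as np

# Toy inputs: a 2-channel CZYX image and a ZYX masking label (e.g. organoids)
image_array = np.arange(1, 9, dtype=np.uint16).reshape(2, 1, 2, 2)
masking_label_region = np.array([[[0, 3], [3, 5]]], dtype=np.uint32)
label_value = 3  # the masking object this ROI corresponds to

# Background = every voxel that does NOT carry the selected label value
background_3D = masking_label_region != label_value

# Equivalent to the `.sum() == 0` presence check in the source code
if not (masking_label_region == label_value).any():
    raise ValueError(f"Label {label_value} is not present in the extracted ROI")

# Zero the background in each channel (same loop as in _preprocess_input)
for i in range(image_array.shape[0]):
    image_array[i, background_3D] = 0

print(image_array[0].tolist())  # [[[0, 2], [3, 0]]]
print(image_array[1].tolist())  # [[[0, 6], [7, 0]]]
```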

NOTE 1: This function relies on V1 of the Fractal table specifications, see https://fractal-analytics-platform.github.io/fractal-tasks-core/tables/.

NOTE 2: The pre/post-processing functions and the masked_loading_wrapper are currently meant to work as part of the cellpose_segmentation task, with the plan of then making them more flexible; see https://github.com/fractal-analytics-platform/fractal-tasks-core/issues/340.

Naming of variables refers to a two-step labeling, as in "first identify organoids, then look for nuclei inside each organoid":

  • "masking" refers to the labels that are used to identify the object vs background (e.g. the organoid labels); these labels already exist.
  • "current" refers to the labels that are currently being computed in the cellpose_segmentation task, e.g. the nuclear labels.

PARAMETER DESCRIPTION
image_array

The 4D CZYX array with image data for a specific ROI.

TYPE: ndarray

region

The ZYX indices of the ROI, in a form like (slice(0, 1), slice(1000, 2000), slice(1000, 2000)).

TYPE: tuple[slice, ...]

current_label_path

Path to the image used as current label, in a form like /somewhere/plate.zarr/A/01/0/labels/nuclei_in_organoids/0.

TYPE: str

ROI_table_path

Path of the AnnData table for the masking-label ROIs; this is used (together with ROI_positional_index) to extract label_value.

TYPE: str

ROI_positional_index

Index of the current ROI, which is used to extract label_value from ROI_table.obs.

TYPE: int

RETURNS DESCRIPTION
tuple[ndarray, ndarray, ndarray]

A tuple with three arrays: the preprocessed image array, the background mask, and the current-label region.
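Because region is a plain tuple of slices, it can index any ZYX array directly, and it can be extended with a leading channel slice for CZYX data. The sketch below also shows the int(float(...)) conversion that the source code applies to the ROI_table.obs value, which tolerates values stored as numbers or as numeric strings; the obs column here is a hypothetical Python list rather than an AnnData dataframe.

```python
import numpy as np

# A toy ZYX label array (values are illustrative)
labels = np.zeros((2, 4, 4), dtype=np.uint32)
labels[:, 1:3, 1:3] = 4

# `region` is a plain tuple of slices and indexes any ZYX array directly
region = (slice(0, 1), slice(1, 3), slice(0, 4))
print(labels[region].shape)  # (1, 2, 4)

# For a CZYX image, prepend a full slice over the channel axis
image = np.ones((3, 2, 4, 4))
print(image[(slice(None),) + region].shape)  # (3, 1, 2, 4)

# label_value extraction: int(float(...)) accepts ints, floats and numeric
# strings such as "4.0" (hypothetical obs column contents)
obs_column = ["2", "4.0", 7]
ROI_positional_index = 1
label_value = int(float(obs_column[ROI_positional_index]))
print(label_value)  # 4
```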

Source code in fractal_tasks_core/masked_loading.py
def _preprocess_input(
    image_array: np.ndarray,
    *,
    region: tuple[slice, ...],
    current_label_path: str,
    ROI_table_path: str,
    ROI_positional_index: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Preprocess a four-dimensional cellpose input.

    This involves:

    - Loading the masking label array for the appropriate ROI;
    - Extracting the appropriate label value from the `ROI_table.obs`
      dataframe;
    - Constructing the background mask, i.e. all voxels where the masking
      label differs from a specific label value;
    - Setting the background of `image_array` to `0`;
    - Loading the current-label region, which will be needed in
      postprocessing to restore the background.

    **NOTE 1**: This function relies on V1 of the Fractal table specifications,
    see
    https://fractal-analytics-platform.github.io/fractal-tasks-core/tables/.

    **NOTE 2**: The pre/post-processing functions and the
    masked_loading_wrapper are currently meant to work as part of the
    cellpose_segmentation task, with the plan of then making them more
    flexible; see
    https://github.com/fractal-analytics-platform/fractal-tasks-core/issues/340.

    Naming of variables refers to a two-step labeling, as in "first identify
    organoids, then look for nuclei inside each organoid":

    - `"masking"` refers to the labels that are used to identify the object
      vs background (e.g. the organoid labels); these labels already exist.
    - `"current"` refers to the labels that are currently being computed in
      the `cellpose_segmentation` task, e.g. the nuclear labels.

    Args:
        image_array: The 4D CZYX array with image data for a specific ROI.
        region: The ZYX indices of the ROI, in a form like
            `(slice(0, 1), slice(1000, 2000), slice(1000, 2000))`.
        current_label_path: Path to the image used as current label, in a form
            like `/somewhere/plate.zarr/A/01/0/labels/nuclei_in_organoids/0`.
        ROI_table_path: Path of the AnnData table for the masking-label ROIs;
            this is used (together with `ROI_positional_index`) to extract
            `label_value`.
        ROI_positional_index: Index of the current ROI, which is used to
            extract `label_value` from `ROI_table.obs`.

    Returns:
        A tuple with three arrays: the preprocessed image array, the
        background mask, and the current-label region.
    """

    logger.info(f"[_preprocess_input] {image_array.shape=}")
    logger.info(f"[_preprocess_input] {region=}")

    # Check that image data are 4D (CZYX) - FIXME issue 340
    if image_array.ndim != 4:
        raise ValueError(
            "_preprocess_input requires a 4D "
            f"image_array argument, but {image_array.shape=}"
        )

    # Load the ROI table and its metadata attributes
    ROI_table = ad.read_zarr(ROI_table_path)
    attrs = zarr.group(ROI_table_path).attrs
    logger.info(f"[_preprocess_input] {ROI_table_path=}")
    logger.info(f"[_preprocess_input] {attrs.asdict()=}")
    MaskingROITableAttrs(**attrs.asdict())
    label_relative_path = attrs["region"]["path"]
    column_name = attrs["instance_key"]

    # Check that ROI_table.obs has the right column and extract label_value
    if column_name not in ROI_table.obs.columns:
        raise ValueError(
            f'In _preprocess_input, "{column_name}" '
            f"missing in {ROI_table.obs.columns=}"
        )
    label_value = int(
        float(ROI_table.obs[column_name].iloc[ROI_positional_index])
    )

    # Load masking-label array (lazily)
    masking_label_path = str(
        Path(ROI_table_path).parent / label_relative_path / "0"
    )
    logger.info(f"{masking_label_path=}")
    masking_label_array = da.from_zarr(masking_label_path)
    logger.info(
        f"[_preprocess_input] {masking_label_path=}, "
        f"{masking_label_array.shape=}"
    )

    # Load current-label array (lazily)
    current_label_array = da.from_zarr(current_label_path)
    logger.info(
        f"[_preprocess_input] {current_label_path=}, "
        f"{current_label_array.shape=}"
    )

    # Load ROI data for current label array
    current_label_region = current_label_array[region].compute()

    # Load ROI data for masking label array, with or without upscaling
    if masking_label_array.shape != current_label_array.shape:
        logger.info("Upscaling of masking label is needed")
        lowres_region = convert_region_to_low_res(
            highres_region=region,
            highres_shape=current_label_array.shape,
            lowres_shape=masking_label_array.shape,
        )
        masking_label_region = masking_label_array[lowres_region].compute()
        masking_label_region = upscale_array(
            array=masking_label_region,
            target_shape=current_label_region.shape,
        )
    else:
        masking_label_region = masking_label_array[region].compute()

    # Check that all shapes match
    shapes = (
        masking_label_region.shape,
        current_label_region.shape,
        image_array.shape[1:],
    )
    if len(set(shapes)) > 1:
        raise ValueError(
            "Shape mismatch:\n"
            f"{current_label_region.shape=}\n"
            f"{masking_label_region.shape=}\n"
            f"{image_array.shape=}"
        )

    # Compute background mask
    background_3D = masking_label_region != label_value
    if (masking_label_region == label_value).sum() == 0:
        raise ValueError(
            f"Label {label_value} is not present in the extracted ROI"
        )

    # Set image background to zero
    n_channels = image_array.shape[0]
    for i in range(n_channels):
        image_array[i, background_3D] = 0

    return (image_array, background_3D, current_label_region)
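When the masking label is stored at a lower resolution than the current label, the code above converts region to low-resolution coordinates and then upscales the loaded patch. The following sketch reproduces that idea with two hypothetical helpers, region_to_low_res and upscale_nearest. They are not the library's convert_region_to_low_res and upscale_array, and they assume integer downsampling factors, explicit slice bounds divisible by those factors, and nearest-neighbour upscaling via np.repeat.

```python
import numpy as np


def region_to_low_res(highres_region, highres_shape, lowres_shape):
    # Hypothetical helper: assumes each axis shrinks by an integer factor
    # and that slice bounds are explicit multiples of that factor.
    lowres = []
    for sl, hi, lo in zip(highres_region, highres_shape, lowres_shape):
        factor = hi // lo
        lowres.append(slice(sl.start // factor, sl.stop // factor))
    return tuple(lowres)


def upscale_nearest(array, target_shape):
    # Hypothetical helper: nearest-neighbour upscaling by integer factors,
    # repeating each voxel along every axis.
    for axis, (n, target) in enumerate(zip(array.shape, target_shape)):
        array = np.repeat(array, target // n, axis=axis)
    return array


highres_shape = (1, 8, 8)  # shape of the current-label array
masking_lowres = np.array([[[1, 1, 2, 2],
                            [1, 1, 2, 2],
                            [3, 3, 4, 4],
                            [3, 3, 4, 4]]], dtype=np.uint32)  # (1, 4, 4)

region = (slice(0, 1), slice(4, 8), slice(0, 4))
low_region = region_to_low_res(region, highres_shape, masking_lowres.shape)
print(low_region)  # (slice(0, 1, None), slice(2, 4, None), slice(0, 2, None))

patch = upscale_nearest(masking_lowres[low_region], target_shape=(1, 4, 4))
print(patch.shape, np.unique(patch).tolist())  # (1, 4, 4) [3]
```

After upscaling, the patch has the same ZYX shape as the high-resolution region, which is what the subsequent shape check in _preprocess_input requires.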

masked_loading_wrapper(*, function, image_array, kwargs=None, use_masks, preprocessing_kwargs=None)

Wrap a function with optional pre/post-processing based on a masking label.

PARAMETER DESCRIPTION
function

The callable function to be wrapped.

TYPE: Callable

image_array

The image array to be preprocessed and then used as positional argument for function.

TYPE: ndarray

kwargs

Keyword arguments for function.

TYPE: Optional[dict] DEFAULT: None

use_masks

If False, the wrapper only calls function(image_array, **kwargs).

TYPE: bool

preprocessing_kwargs

Keyword arguments for the preprocessing function (see call signature of _preprocess_input()).

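TYPE: Optional[dict] DEFAULT: None

RETURNS DESCRIPTION
The output of function; if use_masks is True, its background is restored from the current-label region.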

Source code in fractal_tasks_core/masked_loading.py
def masked_loading_wrapper(
    *,
    function: Callable,
    image_array: np.ndarray,
    kwargs: Optional[dict] = None,
    use_masks: bool,
    preprocessing_kwargs: Optional[dict] = None,
):
    """
    Wrap a function with optional pre/post-processing based on a masking label.

    Args:
        function: The callable function to be wrapped.
        image_array: The image array to be preprocessed and then used as
            positional argument for `function`.
        kwargs: Keyword arguments for `function`.
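        use_masks: If `False`, the wrapper only calls
            `function(image_array, **kwargs)`.
        preprocessing_kwargs: Keyword arguments for the preprocessing function
            (see call signature of `_preprocess_input()`).

    Returns:
        The output of `function`; if `use_masks` is `True`, its background
        is restored from the current-label region.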
    """
    # Optional preprocessing
    if use_masks:
        preprocessing_kwargs = preprocessing_kwargs or {}
        (
            image_array,
            background_3D,
            current_label_region,
        ) = _preprocess_input(image_array, **preprocessing_kwargs)
    # Run function
    kwargs = kwargs or {}
    new_label_img = function(image_array, **kwargs)
    # Optional postprocessing
    if use_masks:
        new_label_img = _postprocess_output(
            modified_array=new_label_img,
            original_array=current_label_region,
            background=background_3D,
        )
    return new_label_img
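To see the whole preprocess, call, postprocess sequence without any zarr or AnnData files, here is a hypothetical in-memory analogue of masked_loading_wrapper with use_masks=True. The names masked_call and threshold_segment are illustrative; the masking label, label value and current label are passed in directly instead of being read from ROI_table_path and current_label_path.

```python
from typing import Callable, Optional

import numpy as np


def masked_call(
    function: Callable,
    image_array: np.ndarray,
    *,
    masking_label: np.ndarray,
    label_value: int,
    current_label: np.ndarray,
    kwargs: Optional[dict] = None,
) -> np.ndarray:
    # Hypothetical in-memory analogue of masked_loading_wrapper with
    # use_masks=True: no zarr/AnnData I/O, arrays are passed in directly.
    background = masking_label != label_value
    if not (masking_label == label_value).any():
        raise ValueError(f"Label {label_value} is not present in the ROI")
    image_array = image_array.copy()
    image_array[:, background] = 0  # preprocess: zero background, all channels
    output = function(image_array, **(kwargs or {}))
    output[background] = current_label[background]  # postprocess
    return output


def threshold_segment(img: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    # Toy "segmentation": label 1 wherever the first channel exceeds threshold
    return (img[0] > threshold).astype(np.uint32)


image = np.array([[[[0.9, 0.9], [0.1, 0.9]]]])  # CZYX = (1, 1, 2, 2)
organoids = np.array([[[5, 5], [5, 6]]])  # masking label
existing = np.array([[[0, 0], [0, 8]]], dtype=np.uint32)  # current label

out = masked_call(threshold_segment, image, masking_label=organoids,
                  label_value=5, current_label=existing,
                  kwargs={"threshold": 0.5})
print(out.tolist())  # [[[1, 1], [0, 8]]]
```

Unlike _preprocess_input, this sketch copies image_array before zeroing the background, so the caller's array is left untouched; the vectorized image_array[:, background] = 0 is equivalent to the per-channel loop in the source code.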