
careamist_v2

CAREamistV2

Main interface for training and predicting with CAREamics.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `work_dir` | `Path` | Working directory in which to save training outputs. |
| `config` | `NGConfiguration[AlgorithmConfig]` | CAREamics configuration. |
| `model` | `CAREamicsModule` | The PyTorch Lightning module to be trained and used for prediction. |
| `checkpoint_path` | `Path \| None` | Path to a checkpoint file from which the model and configuration may be loaded. |
| `trainer` | `Trainer` | The PyTorch Lightning `Trainer` used for training and prediction. |
| `callbacks` | `list[Callback]` | List of callbacks used during training. |
| `prediction_writer` | `PredictionWriterCallback` | Callback used to write predictions to disk during prediction. |
| `train_datamodule` | `CareamicsDataModule \| None` | The datamodule used for training, set after calling `train()`. |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `NGConfiguration[AlgorithmConfig] \| Path` | CAREamics configuration, or a path to a configuration file. See `careamics.config.ng_factories` for methods to build configurations. | `None` |
| `checkpoint_path` | `Path` | Path to a checkpoint file from which to load the model and configuration. | `None` |
| `bmz_path` | `Path` | Path to a BioImage Model Zoo archive from which to load the model and configuration. | `None` |
| `work_dir` | `Path \| str` | Working directory in which to save training outputs. If `None`, the current working directory is used. | `None` |
| `callbacks` | `list[Callback]` | List of callbacks to use during training. If `None`, no additional callbacks are used. Note that `ModelCheckpoint` and `EarlyStopping` callbacks are already defined in CAREamics and should only be modified through the training configuration (see `NGConfiguration` and `TrainingConfig`). | `None` |
| `enable_progress_bar` | `bool` | Whether to show the progress bar during training. | `True` |
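Exactly one of `config`, `checkpoint_path`, or `bmz_path` must be provided to the constructor. The snippet below is a minimal standalone sketch of that validation rule; the helper names are hypothetical and not part of the CAREamics API.

```python
# Sketch of the "exactly one source" rule enforced by CAREamistV2.__init__:
# a model can be built from a configuration, a checkpoint, or a BMZ archive,
# but never from several at once.
def count_provided(*sources) -> int:
    """Count how many of the mutually exclusive inputs were given."""
    return sum(s is not None for s in sources)

def validate_sources(config=None, checkpoint_path=None, bmz_path=None) -> None:
    """Raise if not exactly one input was provided (mirrors the documented rule)."""
    if count_provided(config, checkpoint_path, bmz_path) != 1:
        raise ValueError(
            "Exactly one of `config`, `checkpoint_path`, or `bmz_path` "
            "must be provided."
        )

validate_sources(config="config.yml")  # passes silently
try:
    validate_sources(config="config.yml", bmz_path="model.zip")
except ValueError as e:
    print("rejected:", e)
```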
Source code in src/careamics/careamist_v2.py
class CAREamistV2:
    """Main interface for training and predicting with CAREamics.

    Attributes
    ----------
    work_dir : Path
        Working directory in which to save training outputs.
    config : NGConfiguration[AlgorithmConfig]
        CAREamics configuration.
    model : CAREamicsModule
        The PyTorch Lightning module to be trained and used for prediction.
    checkpoint_path : Path | None
        Path to a checkpoint file from which model and configuration may be loaded.
    trainer : Trainer
        The PyTorch Lightning Trainer used for training and prediction.
    callbacks : list[Callback]
        List of callbacks used during training.
    prediction_writer : PredictionWriterCallback
        Callback used to write predictions to disk during prediction.
    train_datamodule : CareamicsDataModule | None
        The datamodule used for training, set after calling `train()`.


    Parameters
    ----------
    config : NGConfiguration[AlgorithmConfig] | Path, default=None
        CAREamics configuration, or a path to a configuration file. See
        `careamics.config.ng_factories` for methods to build configurations.
    checkpoint_path : Path, default=None
        Path to a checkpoint file from which to load the model and configuration.
    bmz_path : Path, default=None
        Path to a BioImage Model Zoo archive from which to load the model and
        configuration.
    work_dir : Path | str, default=None
        Working directory in which to save training outputs. If None, the current
        working directory will be used.
    callbacks : list of PyTorch Lightning Callbacks, default=None
        List of callbacks to use during training. If None, no additional callbacks
        will be used. Note that `ModelCheckpoint` and `EarlyStopping` callbacks are
        already defined in CAREamics and should only be modified through the
        training configuration (see NGConfiguration and TrainingConfig).
    enable_progress_bar : bool, default=True
        Whether to show the progress bar during training.
    """

    def __init__(
        self,
        config: NGConfiguration[AlgorithmConfig] | Path | None = None,
        *,
        checkpoint_path: Path | None = None,
        bmz_path: Path | None = None,
        work_dir: Path | str | None = None,
        callbacks: list[Callback] | None = None,
        enable_progress_bar: bool = True,
    ) -> None:
        """Constructor for CAREamistV2.

        Exactly one of `config`, `checkpoint_path`, or `bmz_path` must be provided.

        Parameters
        ----------
        config : NGConfiguration[AlgorithmConfig] | Path, default=None
            CAREamics configuration, or a path to a configuration file. See
            `careamics.config.ng_factories` for methods to build configurations.
            `config` is mutually exclusive with `checkpoint_path` and `bmz_path`.
        checkpoint_path : Path, default=None
            Path to a checkpoint file from which to load the model and configuration.
            `checkpoint_path` is mutually exclusive with `config` and `bmz_path`.
        bmz_path : Path, default=None
            Path to a BioImage Model Zoo archive from which to load the model and
            configuration. `bmz_path` is mutually exclusive with `config` and
            `checkpoint_path`.
        work_dir : Path | str, default=None
            Working directory in which to save training outputs. If None, the current
            working directory will be used.
        callbacks : list of PyTorch Lightning Callbacks, default=None
            List of callbacks to use during training. If None, no additional callbacks
            will be used. Note that `ModelCheckpoint` and `EarlyStopping` callbacks are
            already defined in CAREamics and should only be modified through the
            training configuration (see NGConfiguration and TrainingConfig).
        enable_progress_bar : bool, default=True
            Whether to show the progress bar during training.
        """
        self.checkpoint_path = checkpoint_path
        self.work_dir = self._resolve_work_dir(work_dir)
        self.config, self.model = self._load_model(config, checkpoint_path, bmz_path)

        self.config.training_config.lightning_trainer_config["enable_progress_bar"] = (
            enable_progress_bar
        )
        self.callbacks = self._define_callbacks(callbacks, self.config, self.work_dir)

        self.prediction_writer = PredictionWriterCallback(
            self.work_dir, enable_writing=False
        )

        experiment_loggers = self._create_loggers(
            self.config.training_config.logger,
            self.config.get_safe_experiment_name(),
            self.work_dir,
        )

        self.trainer = Trainer(
            callbacks=[self.prediction_writer, *self.callbacks],
            default_root_dir=self.work_dir,
            logger=experiment_loggers,
            **(self.config.training_config.lightning_trainer_config or {}),
        )

        self.train_datamodule: CareamicsDataModule | None = None

    def _load_model(
        self,
        config: NGConfiguration[AlgorithmConfig] | Path | None,
        checkpoint_path: Path | None,
        bmz_path: Path | None,
    ) -> tuple[NGConfiguration[AlgorithmConfig], CAREamicsModule]:
        """Load model.

        Parameters
        ----------
        config : NGConfiguration[AlgorithmConfig] | Path | None
            CAREamics configuration, or a path to a configuration file.
        checkpoint_path : Path | None
            Path to a checkpoint file from which to load the model and configuration.
        bmz_path : Path | None
            Path to a BioImage Model Zoo archive from which to load the model and
            configuration.

        Returns
        -------
        NGConfiguration[AlgorithmConfig]
            The loaded configuration.
        CAREamicsModule
            The loaded model.

        Raises
        ------
        ValueError
            If not exactly one of `config`, `checkpoint_path`, or `bmz_path` is
            provided.
        """
        n_inputs = sum(
            [config is not None, checkpoint_path is not None, bmz_path is not None]
        )
        if n_inputs != 1:
            raise ValueError(
                "Exactly one of `config`, `checkpoint_path`, or `bmz_path` "
                "must be provided."
            )
        if config is not None:
            return self._from_config(config)
        elif checkpoint_path is not None:
            return self._from_checkpoint(checkpoint_path)
        else:
            assert bmz_path is not None
            return self._from_bmz(bmz_path)

    @staticmethod
    def _from_config(
        config: NGConfiguration[AlgorithmConfig] | Path,
    ) -> tuple[NGConfiguration[AlgorithmConfig], CAREamicsModule]:
        """Create model from configuration.

        Parameters
        ----------
        config : NGConfiguration[AlgorithmConfig] | Path
            CAREamics configuration, or a path to a configuration file.

        Returns
        -------
        NGConfiguration[AlgorithmConfig]
            The loaded configuration if a path was provided, otherwise the original
            configuration.
        CAREamicsModule
            The created model.
        """
        if isinstance(config, Path):
            config = load_configuration_ng(config)
        assert not isinstance(config, Path)

        model = create_module(config.algorithm_config)
        return config, model

    @staticmethod
    def _from_checkpoint(
        checkpoint_path: Path,
    ) -> tuple[NGConfiguration[AlgorithmConfig], CAREamicsModule]:
        """Load checkpoint and configuration from checkpoint file.

        Parameters
        ----------
        checkpoint_path : Path
            Path to a checkpoint file from which to load the model and configuration.

        Returns
        -------
        NGConfiguration[AlgorithmConfig]
            The loaded configuration.
        CAREamicsModule
            The loaded model.
        """
        config = load_config_from_checkpoint(checkpoint_path)
        module = load_module_from_checkpoint(checkpoint_path)
        return config, module

    @staticmethod
    def _from_bmz(
        bmz_path: Path,
    ) -> tuple[NGConfiguration[AlgorithmConfig], CAREamicsModule]:
        """Load checkpoint and configuration from a BioImage Model Zoo archive.

        Parameters
        ----------
        bmz_path : Path
            Path to a BioImage Model Zoo archive from which to load the model and
            configuration.

        Returns
        -------
        NGConfiguration[AlgorithmConfig]
            The loaded configuration.
        CAREamicsModule
            The loaded model.

        Raises
        ------
        NotImplementedError
            Loading from BMZ is not implemented yet.
        """
        raise NotImplementedError("Loading from BMZ is not implemented yet.")

    @staticmethod
    def _resolve_work_dir(work_dir: str | Path | None) -> Path:
        """Resolve working directory.

        Parameters
        ----------
        work_dir : str | Path | None
            The working directory to resolve. If None, the current working directory
            will be used.

        Returns
        -------
        Path
            The resolved working directory.
        """
        if work_dir is None:
            work_dir = Path.cwd().resolve()
            logger.warning(
                f"No working directory provided. Using current working directory: "
                f"{work_dir}."
            )
        else:
            work_dir = Path(work_dir).resolve()
        return work_dir
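    # Illustrative behavior (annotation, not part of the source):
    #     _resolve_work_dir(None)        -> Path.cwd().resolve(), with a warning
    #     _resolve_work_dir("runs/exp1") -> Path("runs/exp1").resolve()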

    @staticmethod
    def _define_callbacks(
        callbacks: list[Callback] | None,
        config: NGConfiguration[AlgorithmConfig],
        work_dir: Path,
    ) -> list[Callback]:
        """Define callbacks for the training process.

        Parameters
        ----------
        callbacks : list[Callback] | None
            List of callbacks to use during training. If None, no additional callbacks
            will be used. Note that `ModelCheckpoint` and `EarlyStopping` callbacks are
            already defined in CAREamics and instantiated in this method.
        config : NGConfiguration[AlgorithmConfig]
            The CAREamics configuration, used to instantiate the callbacks.
        work_dir : Path
            The working directory, used as a parameter to the checkpointing callback.

        Returns
        -------
        list[Callback]
            The list of callbacks to use during training.

        Raises
        ------
        ValueError
            If `ModelCheckpoint` or `EarlyStopping` callbacks are included in the
            provided `callbacks` list, as these are already defined in CAREamics and
            should only be modified through the training configuration (see
            NGConfiguration and TrainingConfig).
        """
        callbacks: list[Callback] = [] if callbacks is None else callbacks
        for c in callbacks:
            if isinstance(c, (ModelCheckpoint, EarlyStopping)):
                raise ValueError(
                    "`ModelCheckpoint` and `EarlyStopping` callbacks are already "
                    "defined in CAREamics and should only be modified through the "
                    "training configuration (see TrainingConfig)."
                )

            if isinstance(c, (CareamicsCheckpointInfo, ProgressBarCallback)):
                raise ValueError(
                    "`CareamicsCheckpointInfo` and `ProgressBar` callbacks are defined "
                    "internally and should not be passed as callbacks."
                )

        checkpoint_callback = ModelCheckpoint(
            dirpath=work_dir / "checkpoints" / config.get_safe_experiment_name(),
            filename=f"{config.get_safe_experiment_name()}_{{epoch:02d}}_step_{{step}}_{{val_loss:.4f}}",
            **config.training_config.checkpoint_callback.model_dump(),
        )
        checkpoint_callback.CHECKPOINT_NAME_LAST = (
            f"{config.get_safe_experiment_name()}_last"
        )
        internal_callbacks: list[Callback] = [
            checkpoint_callback,
            CareamicsCheckpointInfo(
                config.version, config.get_safe_experiment_name(), config.training_config
            ),
        ]

        enable_progress_bar = config.training_config.lightning_trainer_config.get(
            "enable_progress_bar", True
        )
        if enable_progress_bar:
            internal_callbacks.append(ProgressBarCallback())

        if config.training_config.early_stopping_callback is not None:
            internal_callbacks.append(
                EarlyStopping(
                    **config.training_config.early_stopping_callback.model_dump()
                )
            )

        return internal_callbacks + callbacks
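    # Illustrative ordering (annotation, not part of the source): the returned
    # list places internal callbacks first, then any user-supplied ones, e.g.
    #     [ModelCheckpoint, CareamicsCheckpointInfo, ProgressBarCallback,
    #      EarlyStopping, *user_callbacks]
    # where ProgressBarCallback and EarlyStopping appear only when enabled in
    # the training configuration.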

    @staticmethod
    def _create_loggers(
        logger: str | None, experiment_name: str, work_dir: Path
    ) -> list[ExperimentLogger]:
        """Create loggers for the experiment.

        Parameters
        ----------
        logger : str | None
            Logger to use during training. If None, no logger will be used. Available
            loggers are defined in SupportedLogger.
        experiment_name : str
            Name of the experiment, used as a parameter to the loggers.
        work_dir : Path
            The working directory, used as a parameter to the loggers.
        """
        csv_logger = CSVLogger(name=experiment_name, save_dir=work_dir / "csv_logs")

        if logger is not None:
            logger = SupportedLogger(logger)

        match logger:
            case SupportedLogger.WANDB:
                return [
                    WandbLogger(name=experiment_name, save_dir=work_dir / "wandb_logs"),
                    csv_logger,
                ]
            case SupportedLogger.TENSORBOARD:
                return [
                    TensorBoardLogger(save_dir=work_dir / "tb_logs"),
                    csv_logger,
                ]
            case _:
                return [csv_logger]
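    # Illustrative dispatch (annotation, not part of the source): a CSV logger
    # is always created, and "wandb" or "tensorboard" add a second logger, e.g.
    #     _create_loggers("wandb", "exp", work_dir)       -> [WandbLogger, CSVLogger]
    #     _create_loggers("tensorboard", "exp", work_dir) -> [TensorBoardLogger, CSVLogger]
    #     _create_loggers(None, "exp", work_dir)          -> [CSVLogger]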

    # Two overloads:
    # - 1st for supported data types & using ReadFuncLoading
    # - 2nd for ImageStackLoading
    # Why:
    #   ImageStackLoading supports any type as input, but we want to tell most users
    #   that they are only allowed Path, str, ndarray or a sequence of these.
    #   The first overload will be displayed first by most code editors, so it is
    #   what most users will see.
    @overload
    def train(
        self,
        *,
        # BASIC PARAMS
        train_data: InputVar | None = None,
        train_data_target: InputVar | None = None,
        val_data: InputVar | None = None,
        val_data_target: InputVar | None = None,
        # ADVANCED PARAMS
        filtering_mask: InputVar | None = None,
        loading: ReadFuncLoading | None = None,
    ) -> None: ...

    @overload  # any data input is allowed for ImageStackLoading
    def train(
        self,
        *,
        # BASIC PARAMS
        train_data: Any | None = None,
        train_data_target: Any | None = None,
        val_data: Any | None = None,
        val_data_target: Any | None = None,
        # ADVANCED PARAMS
        filtering_mask: Any | None = None,
        loading: ImageStackLoading = ...,
    ) -> None: ...

    def train(
        self,
        *,
        # BASIC PARAMS
        train_data: Any | None = None,
        train_data_target: Any | None = None,
        val_data: Any | None = None,
        val_data_target: Any | None = None,
        # ADVANCED PARAMS
        filtering_mask: Any | None = None,
        loading: Loading = None,
    ) -> None:
        """Train the model on the provided data.

        The training data can be provided as arrays or paths.

        Parameters
        ----------
        train_data : pathlib.Path, str, numpy.ndarray, or sequence of these, optional
            Training data, by default None.
        train_data_target : pathlib.Path, str, numpy.ndarray, or sequence of these, optional
            Training target data, by default None.
        val_data : pathlib.Path, str, numpy.ndarray, or sequence of these, optional
            Validation data. If not provided, `data_config.n_val_patches` patches
            will be selected from the training data for validation.
        val_data_target : pathlib.Path, str, numpy.ndarray, or sequence of these, optional
            Validation target data, by default None.
        filtering_mask : pathlib.Path, str, numpy.ndarray, or sequence of these, optional
            Filtering mask for coordinate-based patch filtering, by default None.
        loading : ReadFuncLoading or ImageStackLoading, optional
            Strategy used to load the source data; see the `train` overloads for
            the input types accepted with each strategy. By default None.

        Raises
        ------
        ValueError
            If train_data is not provided.
        """
        if train_data is None:
            raise ValueError("Training data must be provided. Provide `train_data`.")

        if self.config.is_supervised() and train_data_target is None:
            raise ValueError(
                f"Training target data must be provided for supervised training (got "
                f"{self.config.get_algorithm_friendly_name()} algorithm). Provide "
                f"`train_data_target`."
            )

        if self.config.is_supervised() and val_data is not None and val_data_target is None:
            raise ValueError(
                f"Validation target data must be provided for supervised training (got "
                f"{self.config.get_algorithm_friendly_name()} algorithm). Provide "
                f"`val_data_target`."
            )

        datamodule = CareamicsDataModule(
            data_config=self.config.data_config,
            train_data=train_data,
            val_data=val_data,
            train_data_target=train_data_target,
            val_data_target=val_data_target,
            train_data_mask=filtering_mask,
            loading=loading,
        )

        self.train_datamodule = datamodule

        # set defaults (in case `stop_training` was called before)
        self.trainer.should_stop = False
        self.trainer.limit_val_batches = 1.0

        self.trainer.fit(
            self.model, datamodule=datamodule, ckpt_path=self.checkpoint_path
        )
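    # Illustrative usage (annotation; names are hypothetical, not part of the
    # source):
    #     careamist.train(train_data=noisy_images)                    # self-supervised
    #     careamist.train(train_data=noisy, train_data_target=clean)  # supervised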

    def _build_predict_datamodule(
        self,
        pred_data: Any,
        *,
        pred_data_target: Any | None = None,
        batch_size: int | None = None,
        tile_size: tuple[int, ...] | None = None,
        tile_overlap: tuple[int, ...] | None = (48, 48),
        axes: str | None = None,
        data_type: Literal["array", "tiff", "zarr", "czi", "custom"] | None = None,
        num_workers: int | None = None,
        channels: Sequence[int] | Literal["all"] | None = None,
        in_memory: bool | None = None,
        loading: Loading = None,
    ) -> CareamicsDataModule:
        dataloader_params: dict[str, Any] | None = None
        if num_workers is not None:
            dataloader_params = {"num_workers": num_workers}

        pred_data_config = self.config.data_config.convert_mode(
            new_mode="predicting",
            new_patch_size=tile_size,
            overlap_size=tile_overlap,
            new_batch_size=batch_size,
            new_data_type=data_type,
            new_dataloader_params=dataloader_params,
            new_axes=axes,
            new_channels=channels,
            new_in_memory=in_memory,
        )
        return CareamicsDataModule(
            data_config=pred_data_config,
            pred_data=pred_data,
            pred_data_target=pred_data_target,
            loading=loading,
        )

    # see comment on train func for a description of why we have these two overloads
    @overload  # constrained input data type for supported data or ReadFuncLoading
    def predict(
        self,
        # BASIC PARAMS
        pred_data: InputVar,
        *,
        batch_size: int | None = None,
        tile_size: tuple[int, ...] | None = None,
        tile_overlap: tuple[int, ...] | None = (48, 48),
        axes: str | None = None,
        data_type: Literal["array", "tiff", "zarr", "czi", "custom"] | None = None,
        # ADVANCED PARAMS
        num_workers: int | None = None,
        channels: Sequence[int] | Literal["all"] | None = None,
        in_memory: bool | None = None,
        loading: ReadFuncLoading | None = None,
    ) -> tuple[list[NDArray], list[str]]: ...

    @overload  # any data input is allowed for ImageStackLoading
    def predict(
        self,
        # BASIC PARAMS
        pred_data: Any,
        *,
        batch_size: int | None = None,
        tile_size: tuple[int, ...] | None = None,
        tile_overlap: tuple[int, ...] | None = (48, 48),
        axes: str | None = None,
        data_type: Literal["array", "tiff", "zarr", "czi", "custom"] | None = None,
        # ADVANCED PARAMS
        num_workers: int | None = None,
        channels: Sequence[int] | Literal["all"] | None = None,
        in_memory: bool | None = None,
        loading: ImageStackLoading = ...,
    ) -> tuple[list[NDArray], list[str]]: ...

    def predict(
        self,
        # BASIC PARAMS
        pred_data: InputVar,
        *,
        batch_size: int | None = None,
        tile_size: tuple[int, ...] | None = None,
        tile_overlap: tuple[int, ...] | None = (48, 48),
        axes: str | None = None,
        data_type: Literal["array", "tiff", "zarr", "czi", "custom"] | None = None,
        # ADVANCED PARAMS
        num_workers: int | None = None,
        channels: Sequence[int] | Literal["all"] | None = None,
        in_memory: bool | None = None,
        loading: Loading = None,
    ) -> tuple[list[NDArray], list[str]]:
        """
        Predict on data and return the predictions.

        Input can be a path to a data file, a list of paths, a numpy array, or a
        list of numpy arrays.

        If `data_type` and `axes` are not provided, the training configuration
        parameters will be used. If `tile_size` is not provided, prediction will
        be performed on the whole image.

        Note that if you are using a UNet model and tiling, the tile size must be
        divisible in every dimension by 2**d, where d is the depth of the model. This
        avoids artefacts arising from the broken shift invariance induced by the
        pooling layers of the UNet. Images smaller than the tile size in any spatial
        dimension will be automatically zero-padded.

        Parameters
        ----------
        pred_data : pathlib.Path, str, numpy.ndarray, or sequence of these
            Data to predict on. Can be a single item or a sequence of paths/arrays.
        batch_size : int, optional
            Batch size for prediction. If not provided, uses the training configuration
            batch size.
        tile_size : tuple of int, optional
            Size of the tiles to use for prediction. If not provided, prediction
            will be performed on the whole image.
        tile_overlap : tuple of int, default=(48, 48)
            Overlap between tiles, can be None.
        axes : str, optional
            Axes of the input data, by default None.
        data_type : {"array", "tiff", "czi", "zarr", "custom"}, optional
            Type of the input data.
        num_workers : int, optional
            Number of workers for the dataloader, by default None.
        channels : sequence of int or "all", optional
            Channels to use from the data. If None, uses the training configuration
            channels.
        in_memory : bool, optional
            Whether to load all data into memory. If None, uses the training
            configuration setting.
        loading : ReadFuncLoading or ImageStackLoading, optional
            Strategy used to load the source data; see the `predict` overloads
            for the input types accepted with each strategy. By default None.

        Returns
        -------
        tuple of (list of NDArray, list of str)
            Predictions made by the model and their source identifiers.

        Raises
        ------
        ValueError
            If tile overlap is not specified when tile_size is provided.
        """
        datamodule = self._build_predict_datamodule(
            pred_data,
            batch_size=batch_size,
            tile_size=tile_size,
            tile_overlap=tile_overlap,
            axes=axes,
            data_type=data_type,
            num_workers=num_workers,
            channels=channels,
            in_memory=in_memory,
            loading=loading,
        )

        predictions: list[ImageRegionData] = self.trainer.predict(
            model=self.model, datamodule=datamodule
        )  # type: ignore[assignment]
        tiled = tile_size is not None
        predictions_output, sources = convert_prediction(
            predictions, tiled=tiled, restore_shape=True
        )

        return predictions_output, sources
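    # Illustrative usage (annotation; names are hypothetical, not part of the
    # source):
    #     preds, sources = careamist.predict(pred_data="images/", tile_size=(256, 256))
    # For UNet models, tile sizes must be divisible by 2**depth in every
    # spatial dimension, as noted in the docstring above.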

    # see comment on train func for a description of why we have these two overloads
    @overload  # constrained input data type for supported data or ReadFuncLoading
    def predict_to_disk(
        self,
        # BASIC PARAMS
        pred_data: InputVar,
        *,
        pred_data_target: InputVar | None = None,
        prediction_dir: Path | str = "predictions",
        batch_size: int | None = None,
        tile_size: tuple[int, ...] | None = None,
        tile_overlap: tuple[int, ...] | None = (48, 48),
        axes: str | None = None,
        data_type: Literal["array", "tiff", "zarr", "czi", "custom"] | None = None,
        # ADVANCED PARAMS
        num_workers: int | None = None,
        channels: Sequence[int] | Literal["all"] | None = None,
        in_memory: bool | None = None,
        loading: ReadFuncLoading | None = None,
        # WRITE OPTIONS
        write_type: Literal["tiff", "zarr", "custom"] = "tiff",
        write_extension: str | None = None,
        write_func: WriteFunc | None = None,
        write_func_kwargs: dict[str, Any] | None = None,
    ) -> None: ...

    @overload  # any data input is allowed for ImageStackLoading
    def predict_to_disk(
        self,
        # BASIC PARAMS
        pred_data: Any,
        *,
        pred_data_target: Any | None = None,
        prediction_dir: Path | str = "predictions",
        batch_size: int | None = None,
        tile_size: tuple[int, ...] | None = None,
        tile_overlap: tuple[int, ...] | None = (48, 48),
        axes: str | None = None,
        data_type: Literal["array", "tiff", "zarr", "czi", "custom"] | None = None,
        # ADVANCED PARAMS
        num_workers: int | None = None,
        channels: Sequence[int] | Literal["all"] | None = None,
        in_memory: bool | None = None,
        loading: ImageStackLoading = ...,
        # WRITE OPTIONS
        write_type: Literal["tiff", "zarr", "custom"] = "tiff",
        write_extension: str | None = None,
        write_func: WriteFunc | None = None,
        write_func_kwargs: dict[str, Any] | None = None,
    ) -> None: ...

    def predict_to_disk(
        self,
        # BASIC PARAMS
        pred_data: Any,
        *,
        pred_data_target: Any | None = None,
        prediction_dir: Path | str = "predictions",
        batch_size: int | None = None,
        tile_size: tuple[int, ...] | None = None,
        tile_overlap: tuple[int, ...] | None = (48, 48),
        axes: str | None = None,
        data_type: Literal["array", "tiff", "zarr", "czi", "custom"] | None = None,
        # ADVANCED PARAMS
        num_workers: int | None = None,
        channels: Sequence[int] | Literal["all"] | None = None,
        in_memory: bool | None = None,
        loading: Loading = None,
        # WRITE OPTIONS
        write_type: Literal["tiff", "zarr", "custom"] = "tiff",
        write_extension: str | None = None,
        write_func: WriteFunc | None = None,
        write_func_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """
        Make predictions on the provided data and save outputs to files.

        Predictions are saved to `prediction_dir` (absolute paths are used as-is,
        relative paths are relative to `work_dir`). The directory structure matches
        the source directory.

        The file names of the predictions will match those of the source. If there is
        more than one sample within a file, the samples will be stacked along the sample
        dimension in the output file.

        If `data_type` and `axes` are not provided, the training configuration
        parameters will be used. If `tile_size` is not provided, prediction
        will be performed on whole images rather than in a tiled manner.

        Note that if you are using a UNet model and tiling, the tile size must be
        divisible in every dimension by 2**d, where d is the depth of the model. This
        avoids artefacts arising from the broken shift invariance induced by the
        pooling layers of the UNet. Images smaller than the tile size in any spatial
        dimension will be automatically zero-padded.

        Parameters
        ----------
        pred_data : pathlib.Path, str, numpy.ndarray, or sequence of these
            Data to predict on. Can be a single item or a sequence of paths/arrays.
        pred_data_target : pathlib.Path, str, numpy.ndarray, or sequence of these, optional
            Prediction data target, by default None.
        prediction_dir : Path | str, default="predictions"
            The path to save the prediction results to. If `prediction_dir` is an
            absolute path, it will be used as-is. If it is a relative path, it will
            be relative to the pre-set `work_dir`. If the directory does not exist it
            will be created.
        batch_size : int, optional
            Batch size for prediction. If not provided, uses the training configuration
            batch size.
        tile_size : tuple of int, optional
            Size of the tiles to use for prediction. If not provided, uses whole image
            strategy.
        tile_overlap : tuple of int, default=(48, 48)
            Overlap between tiles.
        axes : str, optional
            Axes of the input data, by default None.
        data_type : {"array", "tiff", "czi", "zarr", "custom"}, optional
            Type of the input data.
        num_workers : int, optional
            Number of workers for the dataloader, by default None.
        channels : sequence of int or "all", optional
            Channels to use from the data. If None, uses the training configuration
            channels.
        in_memory : bool, optional
            Whether to load all data into memory. If None, uses the training
            configuration setting.
        loading : ReadFuncLoading or ImageStackLoading, optional
            Custom loading specification for reading the input data. If None, the
            built-in loading for the given `data_type` is used.
        write_type : {"tiff", "zarr", "custom"}, default="tiff"
            The data type to save the predictions as; for "custom", `write_func`
            and `write_extension` must also be provided.
        write_extension : str, optional
            If a known `write_type` is selected this argument is ignored. For a custom
            `write_type` an extension to save the data with must be passed.
        write_func : WriteFunc, optional
            If a known `write_type` is selected this argument is ignored. For a custom
            `write_type` a function to save the data must be passed. See notes below.
        write_func_kwargs : dict of {str: any}, optional
            Additional keyword arguments to be passed to the save function.

        Raises
        ------
        ValueError
            If `write_type` is custom and `write_extension` is None.
        ValueError
            If `write_type` is custom and `write_func` is None.
        """
        if write_func_kwargs is None:
            write_func_kwargs = {}

        if Path(prediction_dir).is_absolute():
            write_dir = Path(prediction_dir)
        else:
            write_dir = self.work_dir / prediction_dir
        self.prediction_writer.dirpath = write_dir

        if write_type == "custom":
            if write_extension is None:
                raise ValueError(
                    "A `write_extension` must be provided for custom write types."
                )
            if write_func is None:
                raise ValueError(
                    "A `write_func` must be provided for custom write types."
                )
        elif write_type == "zarr" and tile_size is None:
            raise ValueError(
                "Writing prediction to Zarr is only supported with tiling. Please "
                "provide a value for `tile_size`, and optionally `tile_overlap`."
            )
        else:
            write_func = get_write_func(write_type)
            write_extension = SupportedData.get_extension(write_type)

        tiled = tile_size is not None
        self.prediction_writer.set_writing_strategy(
            write_type=write_type,
            tiled=tiled,
            write_func=write_func,
            write_extension=write_extension,
            write_func_kwargs=write_func_kwargs,
        )

        self.prediction_writer.enable_writing(True)

        try:
            datamodule = self._build_predict_datamodule(
                pred_data,
                pred_data_target=pred_data_target,
                batch_size=batch_size,
                tile_size=tile_size,
                tile_overlap=tile_overlap,
                axes=axes,
                data_type=data_type,
                num_workers=num_workers,
                channels=channels,
                in_memory=in_memory,
                loading=loading,
            )

            self.trainer.predict(
                model=self.model, datamodule=datamodule, return_predictions=False
            )

        finally:
            self.prediction_writer.enable_writing(False)

    def export_to_bmz(
        self,
        path_to_archive: Path | str,
        friendly_model_name: str,
        input_array: NDArray,
        authors: list[dict],
        general_description: str,
        data_description: str,
        covers: list[Path | str] | None = None,
        channel_names: list[str] | None = None,
        model_version: str = "0.2.0",
    ) -> None:
        """Export the model to the BioImage Model Zoo format.

        This method packages the current weights into a zip file that can be uploaded
        to the BioImage Model Zoo. The archive consists of the model weights, the model
        specifications and various files (inputs, outputs, README, env.yaml etc.).

        `path_to_archive` should point to a file with a ".zip" extension.

        `friendly_model_name` is the name used for the model in the BMZ specs
        and website; it should consist of letters, numbers, dashes, underscores and
        parentheses only.

        The input array must have dimensions matching the axes recorded in the
        configuration of the `CAREamistV2`.

        Parameters
        ----------
        path_to_archive : pathlib.Path or str
            Path in which to save the model, including file name, which should end with
            ".zip".
        friendly_model_name : str
            Name of the model as used in the BMZ specs; it should consist of letters,
            numbers, dashes, underscores and parentheses only.
        input_array : NDArray
            Input array used to validate the model and as example.
        authors : list of dict
            List of authors of the model.
        general_description : str
            General description of the model used in the BMZ metadata.
        data_description : str
            Description of the data the model was trained on.
        covers : list of pathlib.Path or str, default=None
            Paths to the cover images.
        channel_names : list of str, default=None
            Channel names.
        model_version : str, default="0.2.0"
            Version of the model.
        """
        output_patch, _ = self.predict(
            pred_data=input_array,
            data_type=SupportedData.ARRAY.value,
        )
        output = np.concatenate(output_patch, axis=0)
        input_array = reshape_array(input_array, self.config.data_config.axes)

        export_to_bmz(
            model=self.model,
            config=self.config,
            path_to_archive=path_to_archive,
            model_name=friendly_model_name,
            general_description=general_description,
            data_description=data_description,
            authors=authors,
            input_array=input_array,
            output_array=output,
            covers=covers,
            channel_names=channel_names,
            model_version=model_version,
        )

    def get_losses(self) -> dict[str, list]:
        """Return data that can be used to plot train and validation loss curves.

        Returns
        -------
        dict of str: list
            Dictionary containing losses for each epoch.
        """
        return read_csv_logger(
            self.config.get_safe_experiment_name(), self.work_dir / "csv_logs"
        )

    def stop_training(self) -> None:
        """Stop the training loop."""
        self.trainer.should_stop = True
        self.trainer.limit_val_batches = 0  # skip validation

__init__(config=None, *, checkpoint_path=None, bmz_path=None, work_dir=None, callbacks=None, enable_progress_bar=True) #

Constructor for CAREamistV2.

Exactly one of config, checkpoint_path, or bmz_path must be provided.

Parameters:

Name Type Description Default
config NGConfiguration[AlgorithmConfig] | Path

CAREamics configuration, or a path to a configuration file. See careamics.config.ng_factories for methods to build configurations. config is mutually exclusive with checkpoint_path and bmz_path.

None
checkpoint_path Path

Path to a checkpoint file from which to load the model and configuration. checkpoint_path is mutually exclusive with config and bmz_path.

None
bmz_path Path

Path to a BioImage Model Zoo archive from which to load the model and configuration. bmz_path is mutually exclusive with config and checkpoint_path.

None
work_dir Path | str

Working directory in which to save training outputs. If None, the current working directory will be used.

None
callbacks list of PyTorch Lightning Callbacks

List of callbacks to use during training. If None, no additional callbacks will be used. Note that ModelCheckpoint and EarlyStopping callbacks are already defined in CAREamics and should only be modified through the training configuration (see NGConfiguration and TrainingConfig).

None
enable_progress_bar bool

Whether to show the progress bar during training.

True
Source code in src/careamics/careamist_v2.py
def __init__(
    self,
    config: NGConfiguration[AlgorithmConfig] | Path | None = None,
    *,
    checkpoint_path: Path | None = None,
    bmz_path: Path | None = None,
    work_dir: Path | str | None = None,
    callbacks: list[Callback] | None = None,
    enable_progress_bar: bool = True,
) -> None:
    """Constructor for CAREamistV2.

    Exactly one of `config`, `checkpoint_path`, or `bmz_path` must be provided.

    Parameters
    ----------
    config : NGConfiguration[AlgorithmConfig] | Path, default=None
        CAREamics configuration, or a path to a configuration file. See 
        `careamics.config.ng_factories` for methods to build configurations. `config`
        is mutually exclusive with `checkpoint_path` and `bmz_path`.
    checkpoint_path : Path, default=None
        Path to a checkpoint file from which to load the model and configuration.
        `checkpoint_path` is mutually exclusive with `config` and `bmz_path`.
    bmz_path : Path, default=None
        Path to a BioImage Model Zoo archive from which to load the model and
        configuration. `bmz_path` is mutually exclusive with `config` and
        `checkpoint_path`.
    work_dir : Path | str, default=None
        Working directory in which to save training outputs. If None, the current
        working directory will be used.
    callbacks : list of PyTorch Lightning Callbacks, default=None
        List of callbacks to use during training. If None, no additional callbacks
        will be used. Note that `ModelCheckpoint` and `EarlyStopping` callbacks are
        already defined in CAREamics and should only be modified through the
        training configuration (see NGConfiguration and TrainingConfig).
    enable_progress_bar : bool, default=True
        Whether to show the progress bar during training.
    """
    self.checkpoint_path = checkpoint_path
    self.work_dir = self._resolve_work_dir(work_dir)
    self.config, self.model = self._load_model(config, checkpoint_path, bmz_path)

    self.config.training_config.lightning_trainer_config["enable_progress_bar"] = (
        enable_progress_bar
    )
    self.callbacks = self._define_callbacks(callbacks, self.config, self.work_dir)

    self.prediction_writer = PredictionWriterCallback(
        self.work_dir, enable_writing=False
    )

    experiment_loggers = self._create_loggers(
        self.config.training_config.logger,
        self.config.get_safe_experiment_name(),
        self.work_dir,
    )

    self.trainer = Trainer(
        callbacks=[self.prediction_writer, *self.callbacks],
        default_root_dir=self.work_dir,
        logger=experiment_loggers,
        **self.config.training_config.lightning_trainer_config or {},
    )

    self.train_datamodule: CareamicsDataModule | None = None

export_to_bmz(path_to_archive, friendly_model_name, input_array, authors, general_description, data_description, covers=None, channel_names=None, model_version='0.2.0') #

Export the model to the BioImage Model Zoo format.

This method packages the current weights into a zip file that can be uploaded to the BioImage Model Zoo. The archive consists of the model weights, the model specifications and various files (inputs, outputs, README, env.yaml etc.).

path_to_archive should point to a file with a ".zip" extension.

friendly_model_name is the name used for the model in the BMZ specs and website; it should consist of letters, numbers, dashes, underscores and parentheses only.

The input array must have dimensions matching the axes recorded in the configuration of the CAREamistV2.

Parameters:

Name Type Description Default
path_to_archive Path or str

Path in which to save the model, including file name, which should end with ".zip".

required
friendly_model_name str

Name of the model as used in the BMZ specs; it should consist of letters, numbers, dashes, underscores and parentheses only.

required
input_array NDArray

Input array used to validate the model and as example.

required
authors list of dict

List of authors of the model.

required
general_description str

General description of the model used in the BMZ metadata.

required
data_description str

Description of the data the model was trained on.

required
covers list of pathlib.Path or str

Paths to the cover images.

None
channel_names list of str

Channel names.

None
model_version str

Version of the model.

"0.2.0"
Source code in src/careamics/careamist_v2.py
def export_to_bmz(
    self,
    path_to_archive: Path | str,
    friendly_model_name: str,
    input_array: NDArray,
    authors: list[dict],
    general_description: str,
    data_description: str,
    covers: list[Path | str] | None = None,
    channel_names: list[str] | None = None,
    model_version: str = "0.2.0",
) -> None:
    """Export the model to the BioImage Model Zoo format.

    This method packages the current weights into a zip file that can be uploaded
    to the BioImage Model Zoo. The archive consists of the model weights, the model
    specifications and various files (inputs, outputs, README, env.yaml etc.).

    `path_to_archive` should point to a file with a ".zip" extension.

    `friendly_model_name` is the name used for the model in the BMZ specs
    and website; it should consist of letters, numbers, dashes, underscores and
    parentheses only.

    The input array must have dimensions matching the axes recorded in the
    configuration of the `CAREamistV2`.

    Parameters
    ----------
    path_to_archive : pathlib.Path or str
        Path in which to save the model, including file name, which should end with
        ".zip".
    friendly_model_name : str
        Name of the model as used in the BMZ specs; it should consist of letters,
        numbers, dashes, underscores and parentheses only.
    input_array : NDArray
        Input array used to validate the model and as example.
    authors : list of dict
        List of authors of the model.
    general_description : str
        General description of the model used in the BMZ metadata.
    data_description : str
        Description of the data the model was trained on.
    covers : list of pathlib.Path or str, default=None
        Paths to the cover images.
    channel_names : list of str, default=None
        Channel names.
    model_version : str, default="0.2.0"
        Version of the model.
    """
    output_patch, _ = self.predict(
        pred_data=input_array,
        data_type=SupportedData.ARRAY.value,
    )
    output = np.concatenate(output_patch, axis=0)
    input_array = reshape_array(input_array, self.config.data_config.axes)

    export_to_bmz(
        model=self.model,
        config=self.config,
        path_to_archive=path_to_archive,
        model_name=friendly_model_name,
        general_description=general_description,
        data_description=data_description,
        authors=authors,
        input_array=input_array,
        output_array=output,
        covers=covers,
        channel_names=channel_names,
        model_version=model_version,
    )
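The `authors` argument expects a list of plain dicts in the bioimage.io author format. A hypothetical example (field names such as `"name"`, `"affiliation"` and `"github_user"` are assumed from the BMZ model spec; substitute real author details):

```python
# Hypothetical author entries; field names follow the bioimage.io model spec.
authors = [
    {"name": "Jane Doe", "affiliation": "Example Institute"},
    {"name": "John Smith", "affiliation": "Example Lab", "github_user": "jsmith"},
]
```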

get_losses() #

Return data that can be used to plot train and validation loss curves.

Returns:

Type Description
dict of str: list

Dictionary containing losses for each epoch.

Source code in src/careamics/careamist_v2.py
def get_losses(self) -> dict[str, list]:
    """Return data that can be used to plot train and validation loss curves.

    Returns
    -------
    dict of str: list
        Dictionary containing losses for each epoch.
    """
    return read_csv_logger(
        self.config.get_safe_experiment_name(), self.work_dir / "csv_logs"
    )
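The dict-of-lists shape returned by `get_losses` can be reproduced from any CSV log with the standard library; a minimal sketch (illustrative only, not the actual `read_csv_logger`):

```python
import csv
import io


def parse_loss_csv(text: str) -> dict[str, list]:
    """Collect each CSV column into a list, e.g. train/val loss per epoch."""
    losses: dict[str, list] = {}
    for row in csv.DictReader(io.StringIO(text)):
        for key, value in row.items():
            losses.setdefault(key, []).append(float(value))
    return losses


log = "epoch,train_loss,val_loss\n0,0.9,1.0\n1,0.5,0.6\n"
curves = parse_loss_csv(log)
print(curves["train_loss"])  # [0.9, 0.5]
```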

predict(pred_data, *, batch_size=None, tile_size=None, tile_overlap=(48, 48), axes=None, data_type=None, num_workers=None, channels=None, in_memory=None, loading=None) #

predict(pred_data: InputVar, *, batch_size: int | None = None, tile_size: tuple[int, ...] | None = None, tile_overlap: tuple[int, ...] | None = (48, 48), axes: str | None = None, data_type: Literal['array', 'tiff', 'zarr', 'czi', 'custom'] | None = None, num_workers: int | None = None, channels: Sequence[int] | Literal['all'] | None = None, in_memory: bool | None = None, loading: ReadFuncLoading | None = None) -> tuple[list[NDArray], list[str]]
predict(pred_data: Any, *, batch_size: int | None = None, tile_size: tuple[int, ...] | None = None, tile_overlap: tuple[int, ...] | None = (48, 48), axes: str | None = None, data_type: Literal['array', 'tiff', 'zarr', 'czi', 'custom'] | None = None, num_workers: int | None = None, channels: Sequence[int] | Literal['all'] | None = None, in_memory: bool | None = None, loading: ImageStackLoading = ...) -> tuple[list[NDArray], list[str]]

Predict on data and return the predictions.

Input can be a path to a data file, a list of paths, a numpy array, or a list of numpy arrays.

If data_type and axes are not provided, the training configuration parameters will be used. If tile_size is not provided, prediction will be performed on the whole image.

Note that if you are using a UNet model and tiling, the tile size must be divisible in every dimension by 2**d, where d is the depth of the model. This avoids artefacts arising from the broken shift invariance induced by the pooling layers of the UNet. Images smaller than the tile size in any spatial dimension will be automatically zero-padded.

Parameters:

Name Type Description Default
pred_data pathlib.Path, str, numpy.ndarray, or sequence of these

Data to predict on. Can be a single item or a sequence of paths/arrays.

required
batch_size int

Batch size for prediction. If not provided, uses the training configuration batch size.

None
tile_size tuple of int

Size of the tiles to use for prediction. If not provided, prediction will be performed on the whole image.

None
tile_overlap tuple of int

Overlap between tiles, can be None.

(48, 48)
axes str

Axes of the input data, by default None.

None
data_type (array, tiff, czi, zarr, custom)

Type of the input data.

None
num_workers int

Number of workers for the dataloader, by default None.

None
channels sequence of int or "all"

Channels to use from the data. If None, uses the training configuration channels.

None
in_memory bool

Whether to load all data into memory. If None, uses the training configuration setting.

None
loading ReadFuncLoading or ImageStackLoading

Custom loading specification for reading the input data. If None, the built-in loading for the given data_type is used.

None

Returns:

Type Description
tuple of (list of NDArray, list of str)

Predictions made by the model and their source identifiers.

Raises:

Type Description
ValueError

If tile overlap is not specified when tile_size is provided.

Source code in src/careamics/careamist_v2.py
def predict(
    self,
    # BASIC PARAMS
    pred_data: InputVar,
    *,
    batch_size: int | None = None,
    tile_size: tuple[int, ...] | None = None,
    tile_overlap: tuple[int, ...] | None = (48, 48),
    axes: str | None = None,
    data_type: Literal["array", "tiff", "zarr", "czi", "custom"] | None = None,
    # ADVANCED PARAMS
    num_workers: int | None = None,
    channels: Sequence[int] | Literal["all"] | None = None,
    in_memory: bool | None = None,
    loading: Loading = None,
) -> tuple[list[NDArray], list[str]]:
    """
    Predict on data and return the predictions.

    Input can be a path to a data file, a list of paths, a numpy array, or a
    list of numpy arrays.

    If `data_type` and `axes` are not provided, the training configuration
    parameters will be used. If `tile_size` is not provided, prediction will
    be performed on the whole image.

    Note that if you are using a UNet model and tiling, the tile size must be
    divisible in every dimension by 2**d, where d is the depth of the model. This
    avoids artefacts arising from the broken shift invariance induced by the
    pooling layers of the UNet. Images smaller than the tile size in any spatial
    dimension will be automatically zero-padded.

    Parameters
    ----------
    pred_data : pathlib.Path, str, numpy.ndarray, or sequence of these
        Data to predict on. Can be a single item or a sequence of paths/arrays.
    batch_size : int, optional
        Batch size for prediction. If not provided, uses the training configuration
        batch size.
    tile_size : tuple of int, optional
        Size of the tiles to use for prediction. If not provided, prediction
        will be performed on the whole image.
    tile_overlap : tuple of int, default=(48, 48)
        Overlap between tiles, can be None.
    axes : str, optional
        Axes of the input data, by default None.
    data_type : {"array", "tiff", "czi", "zarr", "custom"}, optional
        Type of the input data.
    num_workers : int, optional
        Number of workers for the dataloader, by default None.
    channels : sequence of int or "all", optional
        Channels to use from the data. If None, uses the training configuration
        channels.
    in_memory : bool, optional
        Whether to load all data into memory. If None, uses the training
        configuration setting.
    loading : ReadFuncLoading or ImageStackLoading, optional
        Custom loading specification for reading the input data. If None, the
        built-in loading for the given `data_type` is used.

    Returns
    -------
    tuple of (list of NDArray, list of str)
        Predictions made by the model and their source identifiers.

    Raises
    ------
    ValueError
        If tile overlap is not specified when tile_size is provided.
    """
    datamodule = self._build_predict_datamodule(
        pred_data,
        batch_size=batch_size,
        tile_size=tile_size,
        tile_overlap=tile_overlap,
        axes=axes,
        data_type=data_type,
        num_workers=num_workers,
        channels=channels,
        in_memory=in_memory,
        loading=loading,
    )

    predictions: list[ImageRegionData] = self.trainer.predict(
        model=self.model, datamodule=datamodule
    )  # type: ignore[assignment]
    tiled = tile_size is not None
    predictions_output, sources = convert_prediction(
        predictions, tiled=tiled, restore_shape=True
    )

    return predictions_output, sources
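The UNet tiling constraint noted above (each tile dimension divisible by 2**d, where d is the model depth) can be checked up front; a minimal sketch:

```python
def valid_tile_size(tile_size: tuple[int, ...], depth: int) -> bool:
    # Each spatial dimension must be divisible by 2**depth so the UNet's
    # pooling layers divide the tile evenly at every level.
    factor = 2 ** depth
    return all(t % factor == 0 for t in tile_size)


print(valid_tile_size((128, 128), depth=3))  # True: 128 % 8 == 0
print(valid_tile_size((100, 100), depth=3))  # False: 100 % 8 != 0
```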

predict_to_disk(pred_data, *, pred_data_target=None, prediction_dir='predictions', batch_size=None, tile_size=None, tile_overlap=(48, 48), axes=None, data_type=None, num_workers=None, channels=None, in_memory=None, loading=None, write_type='tiff', write_extension=None, write_func=None, write_func_kwargs=None) #

predict_to_disk(pred_data: InputVar, *, pred_data_target: InputVar | None = None, prediction_dir: Path | str = 'predictions', batch_size: int | None = None, tile_size: tuple[int, ...] | None = None, tile_overlap: tuple[int, ...] | None = (48, 48), axes: str | None = None, data_type: Literal['array', 'tiff', 'zarr', 'czi', 'custom'] | None = None, num_workers: int | None = None, channels: Sequence[int] | Literal['all'] | None = None, in_memory: bool | None = None, loading: ReadFuncLoading | None = None, write_type: Literal['tiff', 'zarr', 'custom'] = 'tiff', write_extension: str | None = None, write_func: WriteFunc | None = None, write_func_kwargs: dict[str, Any] | None = None) -> None
predict_to_disk(pred_data: Any, *, pred_data_target: Any | None = None, prediction_dir: Path | str = 'predictions', batch_size: int | None = None, tile_size: tuple[int, ...] | None = None, tile_overlap: tuple[int, ...] | None = (48, 48), axes: str | None = None, data_type: Literal['array', 'tiff', 'zarr', 'czi', 'custom'] | None = None, num_workers: int | None = None, channels: Sequence[int] | Literal['all'] | None = None, in_memory: bool | None = None, loading: ImageStackLoading = ..., write_type: Literal['tiff', 'zarr', 'custom'] = 'tiff', write_extension: str | None = None, write_func: WriteFunc | None = None, write_func_kwargs: dict[str, Any] | None = None) -> None

Make predictions on the provided data and save outputs to files.

Predictions are saved to prediction_dir (absolute paths are used as-is, relative paths are relative to work_dir). The directory structure matches the source directory.

The file names of the predictions will match those of the source. If there is more than one sample within a file, the samples will be stacked along the sample dimension in the output file.

If data_type and axes are not provided, the training configuration parameters will be used. If tile_size is not provided, prediction will be performed on whole images rather than in a tiled manner.

Note that if you are using a UNet model and tiling, the tile size must be divisible in every dimension by 2**d, where d is the depth of the model. This avoids artefacts arising from the broken shift invariance induced by the pooling layers of the UNet. Images smaller than the tile size in any spatial dimension will be automatically zero-padded.

Parameters:

Name Type Description Default
pred_data pathlib.Path, str, numpy.ndarray, or sequence of these

Data to predict on. Can be a single item or a sequence of paths/arrays.

required
pred_data_target pathlib.Path, str, numpy.ndarray, or sequence of these

Prediction data target, by default None.

None
prediction_dir Path | str

The path to save the prediction results to. If prediction_dir is an absolute path, it will be used as-is. If it is a relative path, it will be relative to the pre-set work_dir. If the directory does not exist it will be created.

"predictions"
batch_size int

Batch size for prediction. If not provided, uses the training configuration batch size.

None
tile_size tuple of int

Size of the tiles to use for prediction. If not provided, uses whole image strategy.

None
tile_overlap tuple of int

Overlap between tiles.

(48, 48)
axes str

Axes of the input data, by default None.

None
data_type (array, tiff, czi, zarr, custom)

Type of the input data.

None
num_workers int

Number of workers for the dataloader, by default None.

None
channels sequence of int or "all"

Channels to use from the data. If None, uses the training configuration channels.

None
in_memory bool

Whether to load all data into memory. If None, uses the training configuration setting.

None
loading ReadFuncLoading or ImageStackLoading

Custom loading specification for reading the input data. If None, the built-in loading for the given data_type is used.

None
write_type (tiff, zarr, custom)

The data type to save the predictions as; for "custom", write_func and write_extension must also be provided.

"tiff"
write_extension str

If a known write_type is selected this argument is ignored. For a custom write_type an extension to save the data with must be passed.

None
write_func WriteFunc

If a known write_type is selected this argument is ignored. For a custom write_type a function to save the data must be passed. See notes below.

None
write_func_kwargs dict of {str: any}

Additional keyword arguments to be passed to the save function.

None

Raises:

Type Description
ValueError

If write_type is custom and write_extension is None.

ValueError

If write_type is custom and write_func is None.
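
The docstring below notes that, when tiling with a UNet, every tile dimension must be divisible by 2**d, where d is the model depth. A minimal standalone check (the depth value here is illustrative, not read from a CAREamics configuration):

```python
def is_valid_tile_size(tile_size: tuple[int, ...], depth: int) -> bool:
    """Check that every tile dimension is divisible by 2**depth, as
    required to avoid UNet pooling artefacts when tiling predictions."""
    factor = 2 ** depth
    return all(dim % factor == 0 for dim in tile_size)

# e.g. a depth-3 UNet needs tile dimensions divisible by 2**3 = 8
print(is_valid_tile_size((256, 256), depth=3))  # True
print(is_valid_tile_size((250, 256), depth=3))  # False, 250 % 8 != 0
```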

Source code in src/careamics/careamist_v2.py
def predict_to_disk(
    self,
    # BASIC PARAMS
    pred_data: Any,
    *,
    pred_data_target: Any | None = None,
    prediction_dir: Path | str = "predictions",
    batch_size: int | None = None,
    tile_size: tuple[int, ...] | None = None,
    tile_overlap: tuple[int, ...] | None = (48, 48),
    axes: str | None = None,
    data_type: Literal["array", "tiff", "zarr", "czi", "custom"] | None = None,
    # ADVANCED PARAMS
    num_workers: int | None = None,
    channels: Sequence[int] | Literal["all"] | None = None,
    in_memory: bool | None = None,
    loading: Loading = None,
    # WRITE OPTIONS
    write_type: Literal["tiff", "zarr", "custom"] = "tiff",
    write_extension: str | None = None,
    write_func: WriteFunc | None = None,
    write_func_kwargs: dict[str, Any] | None = None,
) -> None:
    """
    Make predictions on the provided data and save outputs to files.

    Predictions are saved to `prediction_dir` (absolute paths are used as-is,
    relative paths are relative to `work_dir`). The directory structure matches
    the source directory.

    The file names of the predictions will match those of the source. If there is
    more than one sample within a file, the samples will be stacked along the sample
    dimension in the output file.

    If `data_type` and `axes` are not provided, the training configuration
    parameters will be used. If `tile_size` is not provided, prediction
    will be performed on whole images rather than in a tiled manner.

    Note that if you are using a UNet model and tiling, the tile size must be
    divisible in every dimension by 2**d, where d is the depth of the model. This
    avoids artefacts arising from the broken shift invariance induced by the
    pooling layers of the UNet. Images smaller than the tile size in any spatial
    dimension will be automatically zero-padded.

    Parameters
    ----------
    pred_data : pathlib.Path, str, numpy.ndarray, or sequence of these
        Data to predict on. Can be a single item or a sequence of paths/arrays.
    pred_data_target : pathlib.Path, str, numpy.ndarray, or sequence of these, optional
        Prediction data target, by default None.
    prediction_dir : Path | str, default="predictions"
        The path to save the prediction results to. If `prediction_dir` is an
        absolute path, it will be used as-is. If it is a relative path, it will
        be relative to the pre-set `work_dir`. If the directory does not exist it
        will be created.
    batch_size : int, optional
        Batch size for prediction. If not provided, uses the training configuration
        batch size.
    tile_size : tuple of int, optional
        Size of the tiles to use for prediction. If not provided, uses whole image
        strategy.
    tile_overlap : tuple of int, default=(48, 48)
        Overlap between tiles.
    axes : str, optional
        Axes of the input data. If not provided, the training configuration
        axes are used.
    data_type : {"array", "tiff", "czi", "zarr", "custom"}, optional
        Type of the input data. If not provided, the training configuration
        data type is used.
    num_workers : int, optional
        Number of workers for the dataloader, by default None.
    channels : sequence of int or "all", optional
        Channels to use from the data. If None, uses the training configuration
        channels.
    in_memory : bool, optional
        Whether to load all data into memory. If None, uses the training
        configuration setting.
    loading : Loading, optional
        Loading configuration specifying how the source data should be read,
        by default None.
    write_type : {"tiff", "zarr", "custom"}, default="tiff"
        The file format to save predictions as. For "custom", `write_extension`
        and `write_func` must also be provided.
    write_extension : str, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` an extension to save the data with must be passed.
    write_func : WriteFunc, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` a function to save the data must be passed.
    write_func_kwargs : dict of {str: any}, optional
        Additional keyword arguments to be passed to the save function.

    Raises
    ------
    ValueError
        If `write_type` is custom and `write_extension` is None.
    ValueError
        If `write_type` is custom and `write_func` is None.
    """
    if write_func_kwargs is None:
        write_func_kwargs = {}

    if Path(prediction_dir).is_absolute():
        write_dir = Path(prediction_dir)
    else:
        write_dir = self.work_dir / prediction_dir
    self.prediction_writer.dirpath = write_dir

    if write_type == "custom":
        if write_extension is None:
            raise ValueError(
                "A `write_extension` must be provided for custom write types."
            )
        if write_func is None:
            raise ValueError(
                "A `write_func` must be provided for custom write types."
            )
    elif write_type == "zarr" and tile_size is None:
        raise ValueError(
            "Writing prediction to Zarr is only supported with tiling. Please "
            "provide a value for `tile_size`, and optionally `tile_overlap`."
        )
    else:
        write_func = get_write_func(write_type)
        write_extension = SupportedData.get_extension(write_type)

    tiled = tile_size is not None
    self.prediction_writer.set_writing_strategy(
        write_type=write_type,
        tiled=tiled,
        write_func=write_func,
        write_extension=write_extension,
        write_func_kwargs=write_func_kwargs,
    )

    self.prediction_writer.enable_writing(True)

    try:
        datamodule = self._build_predict_datamodule(
            pred_data,
            pred_data_target=pred_data_target,
            batch_size=batch_size,
            tile_size=tile_size,
            tile_overlap=tile_overlap,
            axes=axes,
            data_type=data_type,
            num_workers=num_workers,
            channels=channels,
            in_memory=in_memory,
            loading=loading,
        )

        self.trainer.predict(
            model=self.model, datamodule=datamodule, return_predictions=False
        )

    finally:
        self.prediction_writer.enable_writing(False)

stop_training() #

Stop the training loop.

Source code in src/careamics/careamist_v2.py
def stop_training(self) -> None:
    """Stop the training loop."""
    self.trainer.should_stop = True
    self.trainer.limit_val_batches = 0  # skip validation

train(*, train_data=None, train_data_target=None, val_data=None, val_data_target=None, filtering_mask=None, loading=None) #

train(*, train_data: InputVar | None = None, train_data_target: InputVar | None = None, val_data: InputVar | None = None, val_data_target: InputVar | None = None, filtering_mask: InputVar | None = None, loading: ReadFuncLoading | None = None) -> None
train(*, train_data: Any | None = None, train_data_target: Any | None = None, val_data: Any | None = None, val_data_target: Any | None = None, filtering_mask: Any | None = None, loading: ImageStackLoading = ...) -> None

Train the model on the provided data.

The training data can be provided as arrays or paths.

Parameters:

Name Type Description Default
train_data pathlib.Path, str, numpy.ndarray, or sequence of these

Training data, by default None.

None
train_data_target pathlib.Path, str, numpy.ndarray, or sequence of these

Training target data, by default None.

None
val_data pathlib.Path, str, numpy.ndarray, or sequence of these

Validation data. If not provided, data_config.n_val_patches patches will be selected from the training data for validation.

None
val_data_target pathlib.Path, str, numpy.ndarray, or sequence of these

Validation target data, by default None.

None
filtering_mask pathlib.Path, str, numpy.ndarray, or sequence of these

Filtering mask for coordinate-based patch filtering, by default None.

None
loading Loading

Loading configuration specifying how the source data should be read, by default None.

None

Raises:

Type Description
ValueError

If train_data is not provided.

Source code in src/careamics/careamist_v2.py
def train(
    self,
    *,
    # BASIC PARAMS
    train_data: Any | None = None,
    train_data_target: Any | None = None,
    val_data: Any | None = None,
    val_data_target: Any | None = None,
    # ADVANCED PARAMS
    filtering_mask: Any | None = None,
    loading: Loading = None,
) -> None:
    """Train the model on the provided data.

    The training data can be provided as arrays or paths.

    Parameters
    ----------
    train_data : pathlib.Path, str, numpy.ndarray, or sequence of these, optional
        Training data, by default None.
    train_data_target : pathlib.Path, str, numpy.ndarray, or sequence of these, optional
        Training target data, by default None.
    val_data : pathlib.Path, str, numpy.ndarray, or sequence of these, optional
        Validation data. If not provided, `data_config.n_val_patches` patches
        will be selected from the training data for validation.
    val_data_target : pathlib.Path, str, numpy.ndarray, or sequence of these, optional
        Validation target data, by default None.
    filtering_mask : pathlib.Path, str, numpy.ndarray, or sequence of these, optional
        Filtering mask for coordinate-based patch filtering, by default None.
    loading : Loading, optional
        Loading configuration specifying how the source data should be read,
        by default None.

    Raises
    ------
    ValueError
        If train_data is not provided.
    """
    if train_data is None:
        raise ValueError("Training data must be provided. Provide `train_data`.")

    if self.config.is_supervised() and train_data_target is None:
        raise ValueError(
            f"Training target data must be provided for supervised training (got "
            f"{self.config.get_algorithm_friendly_name()} algorithm). Provide "
            f"`train_data_target`."
        )

    if self.config.is_supervised() and val_data is not None and val_data_target is None:
        raise ValueError(
            f"Validation target data must be provided for supervised training (got "
            f"{self.config.get_algorithm_friendly_name()} algorithm). Provide "
            f"`val_data_target`."
        )

    datamodule = CareamicsDataModule(
        data_config=self.config.data_config,
        train_data=train_data,
        val_data=val_data,
        train_data_target=train_data_target,
        val_data_target=val_data_target,
        train_data_mask=filtering_mask,
        loading=loading,
    )

    self.train_datamodule = datamodule

    # set defaults (in case `stop_training` was called before)
    self.trainer.should_stop = False
    self.trainer.limit_val_batches = 1.0

    self.trainer.fit(
        self.model, datamodule=datamodule, ckpt_path=self.checkpoint_path
    )
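
For supervised training, `train_data` and `train_data_target` are typically parallel collections of files. A hedged sketch (a common convention, not a CAREamics function) of pairing noisy inputs with clean targets by sorted filename:

```python
from pathlib import Path

def paired_file_lists(input_dir, target_dir, ext=".tif"):
    """Pair input files with target files by sorted filename and verify
    the two directories contain matching names. Illustration only; not
    part of the careamics API."""
    inputs = sorted(Path(input_dir).glob(f"*{ext}"))
    targets = sorted(Path(target_dir).glob(f"*{ext}"))
    if [p.name for p in inputs] != [p.name for p in targets]:
        raise ValueError("Input and target file names do not match.")
    return inputs, targets
```

The resulting lists can then be passed as `train_data` and `train_data_target`.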