From 8c384c68fe0ea7d73655a6e58bdec46a78f2d868 Mon Sep 17 00:00:00 2001
From: Laurent Guerard <laurent.guerard@unibas.ch>
Date: Wed, 11 Dec 2024 16:51:57 +0100
Subject: [PATCH] WIP

Script is working but segmentation results aren't perfect, maybe keeping
the pre processing step is needed.
To continue
---
 1_identify_fibers.py | 261 ++++++++++++++++++++++++++++++++++---------
 1 file changed, 207 insertions(+), 54 deletions(-)

diff --git a/1_identify_fibers.py b/1_identify_fibers.py
index f33076f..ee5bc97 100755
--- a/1_identify_fibers.py
+++ b/1_identify_fibers.py
@@ -5,8 +5,7 @@
 # TODO: are the imports RoiManager and ResultsTable needed when using the services?
 from ij import IJ, WindowManager as wm
 from ij.plugin import Duplicator, RoiEnlarger, RoiScaler
-from trainableSegmentation import WekaSegmentation
-from de.biovoxxel.toolbox import Extended_Particle_Analyzer
+
 from ij.measure import ResultsTable
 
 # Bio-formats imports
@@ -16,29 +15,47 @@ from loci.plugins.in import ImporterOptions
 # python imports
 import time
 import os
+import sys
+
+# TrackMate imports
+from fiji.plugin.trackmate import Logger, Model, SelectionModel, Settings, TrackMate
+from fiji.plugin.trackmate.action import LabelImgExporter
+from fiji.plugin.trackmate.cellpose import CellposeDetectorFactory
+from fiji.plugin.trackmate.cellpose.CellposeSettings import PretrainedModel
+from fiji.plugin.trackmate.features import FeatureFilter
+from fiji.plugin.trackmate.providers import (
+    SpotAnalyzerProvider,
+    SpotMorphologyAnalyzerProvider,
+)
+from fiji.plugin.trackmate.tracking.jaqaman import SparseLAPTrackerFactory
+
+from java.lang import Double
+
+from ch.epfl.biop.ij2command import Labels2Rois
+
 
 #@ String (visibility=MESSAGE, value="<html><b> Welcome to Myosoft - identify fibers! </b></html>") msg1
-#@ File (label="Select directory with classifiers", style="directory") classifiers_dir
 #@ File (label="Select directory for output", style="directory") output_dir
-#@ File (label="Select image file", description="select your image")  path_to_image
+#@ File (label="Select image file", description="select your image") path_to_image
+#@ File(label="Cellpose environment folder", style="directory", description="Folder with the cellpose env") cellpose_dir
 #@ Boolean (label="close image after processing", description="tick this box when using batch mode", value=False) close_raw
 #@ String (visibility=MESSAGE, value="<html><b> Morphometric Gates </b></html>") msg2
 #@ Integer (label="Min Area [um²]", value=10) minAr
 #@ Integer (label="Max Area [um²]", value=6000) maxAr
-#@ Float (label="Min Circularity", value=0.5) minCir
-#@ Float (label="Max Circularity", value=1) maxCir
-#@ Float (label="Min solidity", value=0.0) minSol
-#@ Float (label="Max solidity", value=1) maxSol
+#@ Double (label="Min Circularity", value=0.5) minCir
+#@ Double (label="Max Circularity", value=1) maxCir
+#@ Double (label="Min solidity", value=0.0) minSol
+#@ Double (label="Max solidity", value=1) maxSol
 #@ Integer (label="Min perimeter [um]", value=5) minPer
 #@ Integer (label="Max perimeter [um]", value=300) maxPer
 #@ Integer (label="Min min ferret [um]", value=0.1) minMinFer
 #@ Integer (label="Max min ferret [um]", value=100) maxMinFer
 #@ Integer (label="Min ferret AR", value=0) minFAR
 #@ Integer (label="Max ferret AR", value=8) maxFAR
-#@ Float (label="Min roundess", value=0.2) minRnd
-#@ Float (label="Max roundess", value=1) maxRnd
+#@ Double (label="Min roundess", value=0.2) minRnd
+#@ Double (label="Max roundess", value=1) maxRnd
 #@ String (visibility=MESSAGE, value="<html><b> Expand ROIS to match fibers </b></html>") msg3
-#@ Float (label="ROI expansion [microns]", value=1) enlarge
+#@ Double (label="ROI expansion [microns]", value=1) enlarge
 #@ String (visibility=MESSAGE, value="<html><b> channel positions in the hyperstack </b></html>") msg5
 #@ Integer (label="Membrane staining channel number", style="slider", min=1, max=5, value=1) membrane_channel
 #@ Integer (label="Fiber staining (MHC) channel number (0=skip)", style="slider", min=0, max=5, value=3) fiber_channel
@@ -163,45 +180,190 @@ def get_threshold_from_method(imp, channel, method):
     return lower_thr, upper_thr
 
 
-def apply_weka_model(model_path, imp, tiles_per_dim):
-    """apply a pretrained WEKA model to an ImagePlus
+def run_tm(
+    implus,
+    channel_seg,
+    cellpose_env,
+    seg_model,
+    diam_seg,
+    channel_sec=0,
+    quality_thresh=[0,0],
+    intensity_thresh=[0,0],
+    circularity_thresh=[0,0],
+    perimeter_thresh=[0,0],
+    feret_thresh=[0,0],
+    area_thresh=[0,0],
+    crop_roi=None,
+    use_gpu=True,
+):
+    """
+    Function to run TrackMate on open data
 
     Parameters
     ----------
-    model_path : string
-        path to the model file
-    imp : ImagePlus
-        ImagePlus to apply the model to
-    tiles_per_dim : integer
-        tiles the imp to save RAM
+    implus : ImagePlus
+        ImagePlus on which to run the function
+    channel_seg : int
+        Channel of interest
+    cellpose_env : str
+        Path to the cellpose environment
+    seg_model : PretrainedModel
+        Model to use for the segmentation
+    diam_seg : float
+        Diameter to use for segmentation
+    channel_sec : int, optional
+        Secondary channel to use for segmentation, by default 0
+    quality_thresh : float, optional
+        Threshold for quality filtering, by default 0
+    intensity_thresh : float, optional
+        Threshold for intensity filtering, by default 0
+    circularity_thresh : float, optional
+        Threshold for circularity filtering, by default 0
+    perimeter_thresh : float, optional
+        Threshold for perimeter filtering, by default 0
+    feret_thresh : float, optional
+        Threshold for Feret filtering, by default 0
+    area_thresh : float, optional
+        Threshold for area filtering, by default 0
+    crop_roi : ROI, optional
+        ROI to crop on the image, by default None
+    use_gpu : bool, optional
+        Boolean to use GPU or not, by default True
 
     Returns
     -------
     ImagePlus
-        the result of the WEKA segmentation. One channel per class.
+        Label image with the segmented objects
     """
-    segmentator = WekaSegmentation()
-    segmentator.loadClassifier( model_path )
-    result = segmentator.applyClassifier( imp, [tiles_per_dim, tiles_per_dim], 0, True ) #ImagePlus imp, int[x,y,z] tilesPerDim, int numThreads (0=all), boolean probabilityMaps
-
-    return result
 
-
-def process_weka_result(imp):
-    """apply myosoft pre-processing steps for the imp after WEKA classification to prepare it
-    for ROI detection with the extended particle analyzer
+    # Get image dimensions and calibration
+    dims = implus.getDimensions()
+    cal = implus.getCalibration()
+
+    # If the image has more than one slice, adjust the dimensions
+    if implus.getNSlices() > 1:
+        implus.setDimensions(dims[2], dims[4], dims[3])
+
+    # Set ROI if provided
+    if crop_roi is not None:
+        implus.setRoi(crop_roi)
+
+    # Initialize TrackMate model
+    model = Model()
+    model.setLogger(Logger.IJTOOLBAR_LOGGER)
+
+    # Prepare settings for TrackMate
+    settings = Settings(implus)
+    settings.detectorFactory = CellposeDetectorFactory()
+
+    # Configure detector settings
+    settings.detectorSettings["TARGET_CHANNEL"] = channel_seg
+    settings.detectorSettings["OPTIONAL_CHANNEL_2"] = channel_sec
+    settings.detectorSettings["CELLPOSE_PYTHON_FILEPATH"] = os.path.join(
+        cellpose_env, "python.exe"
+    )
+    settings.detectorSettings["CELLPOSE_MODEL_FILEPATH"] = os.path.join(
+        os.environ["USERPROFILE"], ".cellpose", "models"
+    )
+    settings.detectorSettings["CELLPOSE_MODEL"] = seg_model
+    settings.detectorSettings["CELL_DIAMETER"] = diam_seg
+    settings.detectorSettings["USE_GPU"] = use_gpu
+    settings.detectorSettings["SIMPLIFY_CONTOURS"] = True
+
+    settings.initialSpotFilterValue = -1.0
+
+    # Add spot analyzers
+    spotAnalyzerProvider = SpotAnalyzerProvider(1)
+    spotMorphologyProvider = SpotMorphologyAnalyzerProvider(1)
+
+    for key in spotAnalyzerProvider.getKeys():
+        settings.addSpotAnalyzerFactory(spotAnalyzerProvider.getFactory(key))
+
+    for key in spotMorphologyProvider.getKeys():
+        settings.addSpotAnalyzerFactory(spotMorphologyProvider.getFactory(key))
+
+    # Apply spot filters based on thresholds
+    if any(quality_thresh):
+        settings = set_trackmate_filter(settings, "QUALITY", quality_thresh)
+    if any(intensity_thresh):
+        settings = set_trackmate_filter(settings, "MEAN_INTENSITY_CH" + str(channel_seg), intensity_thresh)
+    if any(circularity_thresh):
+        settings = set_trackmate_filter(settings, "CIRCULARITY", circularity_thresh)
+    if any(area_thresh):
+        settings = set_trackmate_filter(settings, "AREA", area_thresh)
+    if any(perimeter_thresh):
+        settings = set_trackmate_filter(settings, "PERIMETER", perimeter_thresh)
+    if any(feret_thresh):
+        settings = set_trackmate_filter(settings, "FERET", feret_thresh)
+
+
+    print(settings)
+
+    # Configure tracker
+    settings.trackerFactory = SparseLAPTrackerFactory()
+    settings.trackerSettings = settings.trackerFactory.getDefaultSettings()
+    # settings.addTrackAnalyzer(TrackDurationAnalyzer())
+    settings.trackerSettings["LINKING_MAX_DISTANCE"] = 3.0
+    settings.trackerSettings["GAP_CLOSING_MAX_DISTANCE"] = 3.0
+    settings.trackerSettings["MAX_FRAME_GAP"] = 2
+
+
+    # Initialize TrackMate with model and settings
+    trackmate = TrackMate(model, settings)
+    trackmate.computeSpotFeatures(True)
+    trackmate.computeTrackFeatures(False)
+
+    # Check input validity
+    if not trackmate.checkInput():
+        sys.exit(str(trackmate.getErrorMessage()))
+        return
+
+    # Process the data
+    if not trackmate.process():
+        if "[SparseLAPTracker] The spot collection is empty." in str(
+            trackmate.getErrorMessage()
+        ):
+            return IJ.createImage(
+                "Untitled",
+                "8-bit black",
+                implus.getWidth(),
+                implus.getHeight(),
+                implus.getNFrames(),
+            )
+        else:
+            sys.exit(str(trackmate.getErrorMessage()))
+            return
+
+    # Export the label image
+    # sm = SelectionModel(model)
+    exportSpotsAsDots = False
+    exportTracksOnly = False
+    label_imp = LabelImgExporter.createLabelImagePlus(
+        trackmate, exportSpotsAsDots, exportTracksOnly, False
+    )
+    label_imp.setDimensions(1, dims[3], dims[4])
+    label_imp.setCalibration(cal)
+    implus.setDimensions(dims[2], dims[3], dims[4])
+    return label_imp
+
+def set_trackmate_filter(settings, filter_name, filter_value):
+    """Sets a TrackMate spot filter with specified filter name and values.
 
     Parameters
     ----------
-    imp : ImagePlus
-        a single channel (= desired class) of the WEKA classification result imp
+    settings : Settings
+        TrackMate settings object to which the filter will be added.
+    filter_name : str
+        The name of the filter to be applied.
+    filter_value : list
+        A list containing two values for the filter. The first value is
+        applied as an above-threshold filter, and the second as a below-threshold filter.
     """
-    IJ.run(imp, "8-bit", "")
-    IJ.run(imp, "Median...", "radius=3")
-    IJ.run(imp, "Gaussian Blur...", "sigma=2")
-    IJ.run(imp, "Auto Threshold", "method=MaxEntropy")
-    IJ.run(imp, "Invert", "")
-
+    filter = FeatureFilter(filter_name, filter_value[0], True)
+    settings.addSpotFilter(filter)
+    filter = FeatureFilter(filter_name, filter_value[1], False)
+    settings.addSpotFilter(filter)
+    return settings
 
 def delete_channel(imp, channel_number):
     """delete a channel from target imp
@@ -514,10 +676,6 @@ print("output_dir: ", str(output_dir))
 if not os.path.exists( str(output_dir) ):
     os.makedirs( str(output_dir) )
 
-classifiers_dir = fix_ij_dirs(classifiers_dir)
-primary_model = classifiers_dir + "/" + "primary.model"
-secondary_model = classifiers_dir + "/" + "secondary_central_nuclei.model"
-
 # update the log for the user
 IJ.log( "Now working on " + str(raw_image_title) )
 if raw_image_calibration.scaled() == False:
@@ -536,19 +694,14 @@ IJ.log( "MHC positive fiber channel = " + str(fiber_channel) )
 IJ.log( "sub-tiling = " + str(tiling_factor) )
 IJ.log( " -- settings used -- ")
 
-# image (pre)processing and segmentation (-> ROIs)
-membrane = Duplicator().run(raw, membrane_channel, membrane_channel, 1, 1, 1, 1) # imp, firstC, lastC, firstZ, lastZ, firstT, lastT
-preprocess_membrane_channel(membrane)
-weka_result1 = apply_weka_model(primary_model, membrane, tiling_factor )
-delete_channel(weka_result1, 1)
-weka_result2 = apply_weka_model(secondary_model, weka_result1, tiling_factor )
-delete_channel(weka_result2, 1)
-weka_result2.setCalibration(raw_image_calibration)
-process_weka_result(weka_result2)
-IJ.saveAs(weka_result2, "Tiff", output_dir + "/" + raw_image_title + "_all_fibers_binary")
-eda_parameters = [minAr, maxAr, minPer, maxPer, minCir, maxCir, minRnd, maxRnd, minSol, maxSol, minFAR, maxFAR, minMinFer, maxMinFer]
-raw.show() # EPA will not work if no image is shown
-run_extended_particle_analyzer(weka_result2, eda_parameters)
+# image (pre)processing and segmentation (-> ROIs)# imp, firstC, lastC, firstZ, lastZ, firstT, lastT
+imp_result = run_tm(raw, membrane_channel, cellpose_dir.getPath(), PretrainedModel.CYTO2, 30.0, area_thresh=[minAr, maxAr], circularity_thresh=[minCir, maxCir],
+        perimeter_thresh=[minPer, maxPer],
+        # feret_thresh=[minMinFer, maxMinFer],
+    )
+IJ.saveAs(imp_result, "Tiff", output_dir + "/" + raw_image_title + "_all_fibers_binary")
+
+sys.exit()
 
 # modify rois
 rm.hide()
-- 
GitLab