././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1734442985.524757 nabu-2024.2.1/0000755000175000017500000000000014730277752012303 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1708524430.0 nabu-2024.2.1/LICENSE0000644000175000017500000000205214565401616013301 0ustar00pierrepierreMIT License Copyright (c) 2020-2024 ESRF Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1734442985.524757 nabu-2024.2.1/PKG-INFO0000644000175000017500000001066714730277752013412 0ustar00pierrepierreMetadata-Version: 2.1 Name: nabu Version: 2024.2.1 Summary: Nabu - Tomography software Author-email: Pierre Paleo , Henri Payno , Alessandro Mirone , Jérôme Lesaint Maintainer-email: Pierre Paleo License: MIT License Copyright (c) 2020-2024 ESRF Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
Project-URL: Homepage, https://gitlab.esrf.fr/tomotools/nabu Project-URL: Documentation, http://www.silx.org/pub/nabu/doc Project-URL: Repository, https://gitlab.esrf.fr/tomotools/nabu/-/releases Project-URL: Changelog, https://gitlab.esrf.fr/tomotools/nabu/-/blob/master/CHANGELOG.md Keywords: tomography,reconstruction,X-ray imaging,synchrotron radiation,High Performance Computing,Parallel geometry,Conebeam geometry,Helical geometry,Ring artefact correction,Geometric calibration Classifier: Development Status :: 5 - Production/Stable Classifier: Intended Audience :: Developers Classifier: Intended Audience :: Science/Research Classifier: Programming Language :: Python :: 3 Classifier: Programming Language :: Python :: 3.7 Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: 3.9 Classifier: Programming Language :: Python :: 3.10 Classifier: Environment :: Console Classifier: License :: OSI Approved :: MIT License Classifier: Operating System :: Unix Classifier: Operating System :: MacOS :: MacOS X Classifier: Operating System :: POSIX Classifier: Topic :: Scientific/Engineering :: Physics Classifier: Topic :: Scientific/Engineering :: Medical Science Apps. Requires-Python: >=3.7 Description-Content-Type: text/markdown Provides-Extra: full Provides-Extra: full_nocuda Provides-Extra: doc License-File: LICENSE # Nabu ESRF tomography processing software. ## Installation To install the development version: ```bash pip install [--user] git+https://gitlab.esrf.fr/tomotools/nabu.git ``` To install the stable version: ```bash pip install [--user] nabu ``` ## Usage Nabu can be used in several ways: - As a Python library, by features like `Backprojector`, `FlatField`, etc - As a standalone application with the command line interface - From Tomwer ([https://gitlab.esrf.fr/tomotools/tomwer/](https://gitlab.esrf.fr/tomotools/tomwer/)) To get quickly started, launch: ```bash nabu-config ``` Edit the generated configuration file (`nabu.conf`). Then: ```bash nabu nabu.conf --slice 500-600 ``` will reconstruct the slices 500 to 600, with processing steps depending on `nabu.conf` contents. ## Documentation The documentation can be found on the silx.org page ([https://www.silx.org/pub/nabu/doc](http://www.silx.org/pub/nabu/doc)). The latest documentation built by continuous integration can be found here: [https://tomotools.gitlab-pages.esrf.fr/nabu/](https://tomotools.gitlab-pages.esrf.fr/nabu/) ## Running the tests Once nabu is installed, running ```bash nabu-test ``` will execute all the tests. You can also specify specific module(s) to test, for example: ```bash nabu-test preproc misc ``` You can also provide more `pytest` options, for example increase verbosity with `-v`, exit at the first fail with `-x`, etc. Use `nabu-test --help` for displaying the complete options list. ## Nabu - what's in a name ? Nabu was the Mesopotamian god of literacy, rational arts, scribes and wisdom. ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/README.md0000644000175000017500000000311614315516747013562 0ustar00pierrepierre# Nabu ESRF tomography processing software. 
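A minimal sketch of using nabu as a Python library (see the Usage section below). The import path and `Backprojector` arguments shown here are assumptions for illustration, not taken from this README:

```python
# Hypothetical sketch: filtered back-projection of one sinogram with nabu.
# The module path, constructor signature and 'rot_center' keyword are assumed, not verified here.
import numpy as np
from nabu.reconstruction.fbp import Backprojector  # assumed import path

sinogram = np.random.rand(500, 2048).astype("float32")  # (n_angles, n_detector_pixels) dummy data
fbp = Backprojector(sinogram.shape, rot_center=1023.5)  # rot_center: center of rotation (assumed kwarg)
slice_image = fbp.fbp(sinogram)                          # reconstructed 2D slice
```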
## Installation To install the development version: ```bash pip install [--user] git+https://gitlab.esrf.fr/tomotools/nabu.git ``` To install the stable version: ```bash pip install [--user] nabu ``` ## Usage Nabu can be used in several ways: - As a Python library, using features like `Backprojector`, `FlatField`, etc. - As a standalone application with the command line interface - From Tomwer ([https://gitlab.esrf.fr/tomotools/tomwer/](https://gitlab.esrf.fr/tomotools/tomwer/)) To get started quickly, launch: ```bash nabu-config ``` Edit the generated configuration file (`nabu.conf`). Then: ```bash nabu nabu.conf --slice 500-600 ``` will reconstruct the slices 500 to 600, with processing steps depending on `nabu.conf` contents. ## Documentation The documentation can be found on the silx.org page ([https://www.silx.org/pub/nabu/doc](http://www.silx.org/pub/nabu/doc)). The latest documentation built by continuous integration can be found here: [https://tomotools.gitlab-pages.esrf.fr/nabu/](https://tomotools.gitlab-pages.esrf.fr/nabu/) ## Running the tests Once nabu is installed, running ```bash nabu-test ``` will execute all the tests. You can also specify specific module(s) to test, for example: ```bash nabu-test preproc misc ``` You can also provide more `pytest` options, for example increase verbosity with `-v`, exit at the first failure with `-x`, etc. Use `nabu-test --help` to display the complete list of options. ## Nabu - what's in a name? Nabu was the Mesopotamian god of literacy, rational arts, scribes and wisdom. nabu-2024.2.1/doc/conf.py # Configuration file for the Sphinx documentation builder. # # This file only contains a selection of the most common options. For a full # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # import os import sys sys.path.insert(0, os.path.abspath("../")) # -- Project information ----------------------------------------------------- project = "Nabu" copyright = "2019-2024, ESRF" author = "Pierre Paleo" # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. 
extensions = [ "myst_parser", "sphinx.ext.autosectionlabel", "sphinx.ext.napoleon", "sphinx.ext.autodoc", "sphinx.ext.mathjax", "sphinx.ext.viewcode", "nbsphinx", # 'sphinx.ext.autosummary', # 'sphinx.ext.doctest', # 'sphinx.ext.inheritance_diagram', ] # myst_commonmark_only = True # for myst suppress_warnings = [ "myst.header", # non-consecutive headers levels "autosectionlabel.*", # duplicate section names ] myst_heading_anchors = 3 myst_enable_extensions = [ # "amsmath", # # "colon_fence", # # "deflist", "dollarmath", # "html_admonition", # "html_image", # "linkify", # # "replacements", # # "smartquotes", # # "substitution" ] # # autosummary_generate = True autodoc_member_order = "bysource" # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # # from corlab_theme import get_theme_dir # html_theme = 'corlab_theme' from cloud_sptheme import get_theme_dir html_theme = "cloud" html_theme_path = [get_theme_dir()] # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] # These paths are either relative to html_static_path # or fully qualified paths (eg. https://...) html_css_files = [ "theme_overrides.css", ] html_theme_options = { # 'navigation_depth': -1, "max_width": "75%", "minimal_width": "720px", } # For mathjax mathjax_path = "javascript/MathJax-3.0.5/es5/tex-mml-chtml.js" """ # For recommonmark from recommonmark.transform import AutoStructify github_doc_root = 'https://github.com/rtfd/recommonmark/tree/master/doc/' def setup(app): app.add_config_value('recommonmark_config', { 'url_resolver': lambda url: github_doc_root + url, 'auto_toc_tree_section': 'Contents', 'enable_math': True, 'enable_inline_math': True, }, True) app.add_transform(AutoStructify) """ # Document __init__ autoclass_content = "both" from nabu import __version__ version = __version__ release = version master_doc = "index" # # nbsphinx # nbsphinx_allow_errors = True ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/doc/create_conf_doc.py0000755000175000017500000000216714315516747016527 0ustar00pierrepierre#!/usr/bin/env python import os from nabu.pipeline.fullfield.nabu_config import nabu_config def header(file_): content = "# Nabu configuration parameters\nThis file lists all the current configuration parameters available in the [configuration file](nabu_config_file.md)." 
print(content, file=file_) def generate(file_): def write(content): print(content, file=file_) header(file_) for section, values in nabu_config.items(): if section == "about": continue write("### %s\n" % section) for key, val in values.items(): if val["type"] == "unsupported": continue help_content = val["help"] if "---" in help_content: help_content = help_content.replace("--", "") write(help_content + "\n") write( "```ini\n%s = %s\n```" % (key, val["default"]) ) if __name__ == "__main__": fname = os.path.join( os.path.dirname(os.path.realpath(__file__)), "nabu_config_items.md" ) with open(fname, "w") as f: generate(f) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/doc/get_mathjax.py0000755000175000017500000000176414315516747015727 0ustar00pierrepierre#!/usr/bin/env python import os import tarfile import requests MATHJAX_PATH = "_static/javascript/MathJax-3.0.5" MATHJAX_URL = "http://www.silx.org/pub/nabu/static/MathJax-3.0.5.tar.lzma" def download_file(file_url, target_file_path): print("Downloading %s" % file_url) rep = requests.get(file_url) with open(target_file_path, "wb") as f: f.write(rep.content) def uncompress_file(compressed_file_path, target_directory): print("Uncompressing %s into %s" % (compressed_file_path, target_directory)) with tarfile.open(compressed_file_path) as f: f.extractall(path=target_directory) def main(): doc_path = os.path.dirname(os.path.realpath(__file__)) mathjax_path = os.path.join(doc_path, MATHJAX_PATH) if os.path.isdir(mathjax_path): return mathjax_comp_file = mathjax_path + ".tar.lzma" download_file(MATHJAX_URL, mathjax_comp_file) uncompress_file(mathjax_comp_file, os.path.dirname(mathjax_path)) if __name__ == "__main__": main() ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.4967566 nabu-2024.2.1/nabu/0000755000175000017500000000000014730277751013227 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734438198.0 nabu-2024.2.1/nabu/__init__.py0000644000175000017500000000041614730266466015342 0ustar00pierrepierre__version__ = "2024.2.1" __nabu_modules__ = [ "app", "cuda", "estimation", "io", "misc", "opencl", "pipeline", "processing", "preproc", "reconstruction", "resources", "thirdparty", "stitching", ] version = __version__ ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5007565 nabu-2024.2.1/nabu/app/0000755000175000017500000000000014730277752014010 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/app/__init__.py0000644000175000017500000000000014315516747016106 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/app/bootstrap.py0000644000175000017500000000630214654107202016363 0ustar00pierrepierrefrom os import path, environ from glob import glob from ..utils import get_folder_path from ..pipeline.config import generate_nabu_configfile, parse_nabu_config_file from ..pipeline.fullfield.nabu_config import nabu_config as default_fullfield_config from ..pipeline.helical.nabu_config import nabu_config as helical_fullfield_config from .utils import parse_params_values from .cli_configs import BootstrapConfig def bootstrap(): args = parse_params_values(BootstrapConfig, parser_description="Initialize a nabu configuration file") do_bootstrap = bool(args["bootstrap"]) no_comments = 
bool(args["nocomments"]) overwrite = bool(args["overwrite"]) if do_bootstrap: print( "The --bootstrap option is now the default behavior of the nabu-config command. This option is therefore not needed anymore." ) if path.isfile(args["output"]) and not (overwrite): rep = input("File %s already exists. Overwrite ? [y/N]" % args["output"]) if rep.lower() != "y": print("Stopping") exit(0) opts_level = args["level"] prefilled_values = {} template_name = args["template"] if template_name != "": prefilled_values = get_config_template(template_name, if_not_found="print") if prefilled_values is None: exit(0) opts_level = "advanced" if args["dataset"] != "": prefilled_values["dataset"] = {} user_dataset = args["dataset"] if not path.isabs(user_dataset): user_dataset = path.abspath(user_dataset) print("Warning: using absolute dataset path %s" % user_dataset) if not path.exists(user_dataset): print("Error: cannot find the file or directory %s" % user_dataset) exit(1) prefilled_values["dataset"]["location"] = user_dataset if args["helical"]: my_config = helical_fullfield_config else: my_config = default_fullfield_config generate_nabu_configfile( args["output"], my_config, comments=not (no_comments), options_level=opts_level, prefilled_values=prefilled_values, ) return 0 def get_config_template(template_name, if_not_found="raise"): def handle_not_found(msg): if if_not_found == "raise": raise FileNotFoundError(msg) elif if_not_found == "print": print(msg) templates_path = get_folder_path(path.join("resources", "templates")) custom_templates_path = environ.get("NABU_TEMPLATES_PATH", None) templates = glob(path.join(templates_path, "*.conf")) if custom_templates_path is not None: templates_custom = glob(path.join(custom_templates_path, "*.conf")) templates_custom += glob(path.join(custom_templates_path, "*.cfg")) templates = templates_custom + templates available_templates_names = [path.splitext(path.basename(fname))[0] for fname in templates] if template_name not in available_templates_names: handle_not_found("Unable to find template '%s'. Available are: %s" % (template_name, available_templates_names)) return fname = templates[available_templates_names.index(template_name)] return parse_nabu_config_file(fname) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1708524430.0 nabu-2024.2.1/nabu/app/bootstrap_stitching.py0000644000175000017500000000414414565401616020450 0ustar00pierrepierrefrom .cli_configs import BootstrapStitchingConfig from ..pipeline.config import generate_nabu_configfile from ..stitching.config import ( get_default_stitching_config, SECTIONS_COMMENTS as _SECTIONS_COMMENTS, INPUT_DATASETS_FIELD as _INPUT_DATASETS_FIELD, INPUTS_SECTION as _INPUTS_SECTION, ) from .utils import parse_params_values from tomoscan.factory import Factory from tomoscan.esrf.volume.utils import guess_volumes def guess_tomo_objects(my_str: str) -> tuple: """ try to find some tomo object from a string. 
The string can be either related to a volume or a scan, and can be an identifier or a file/folder path :param str my_str: string related to the tomo object :return: a tuple of tomo objects, each an instance of VolumeBase or TomoScanBase :rtype: tuple """ try: # create_tomo_object_from_identifier will raise an exception if the string does not match an identifier return (Factory.create_tomo_object_from_identifier(my_str),) except Exception: pass try: volumes = guess_volumes(my_str) except Exception: pass else: if len(volumes) > 0: return volumes try: return Factory.create_scan_objects(my_str) except Exception: return tuple() def bootstrap_stitching(): args = parse_params_values( BootstrapStitchingConfig, parser_description="Initialize a 'nabu-stitching' configuration file", ) prefilled_values = {} datasets_as_str = args.get("datasets", None) datasets = [] for dataset in datasets_as_str: datasets.extend(guess_tomo_objects(dataset)) if len(datasets) > 0: prefilled_values = { _INPUTS_SECTION: {_INPUT_DATASETS_FIELD: [dataset.get_identifier().to_str() for dataset in datasets]} } generate_nabu_configfile( fname=args["output"], default_config=get_default_stitching_config(args["stitching_type"]), comments=True, sections_comments=_SECTIONS_COMMENTS, options_level=args["level"], prefilled_values=prefilled_values, ) return 0 nabu-2024.2.1/nabu/app/cast_volume.py #!/usr/bin/env python # -*- coding: utf-8 -*- import argparse import os import sys import logging from argparse import RawTextHelpFormatter import numpy from silx.io.url import DataUrl from tomoscan.esrf.volume.utils import guess_volumes from tomoscan.factory import Factory from tomoscan.esrf.volume import ( EDFVolume, HDF5Volume, JP2KVolume, MultiTIFFVolume, TIFFVolume, ) from nabu.io.cast_volume import ( RESCALE_MAX_PERCENTILE, RESCALE_MIN_PERCENTILE, cast_volume, get_default_output_volume, ) from nabu.pipeline.params import files_formats from nabu.utils import convert_str_to_tuple from nabu.io.cast_volume import _min_max_from_histo _logger = logging.getLogger(__name__) def main(argv=None): if argv is None: argv = sys.argv _volume_url_helps = "\n".join( [ f"- {(volume.__name__).ljust(15)}: {volume.example_defined_from_str_identifier()}" for volume in ( EDFVolume, HDF5Volume, JP2KVolume, MultiTIFFVolume, TIFFVolume, ) ] ) volume_help = f"""To define a volume you can either provide: \n * an url (recommended way) - see details below \n * a path. For hdf5 and multitiff we expect a file path. For edf, tif and jp2k we expect a folder path. In this case we will try to deduce the Volume from it. \n url must be defined like: \n{_volume_url_helps} """ parser = argparse.ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter) parser.add_argument( "input_volume", help=f"input volume. {volume_help}", ) parser.add_argument( "--output-data-type", help="output data type. Valid values are numpy default type names like (uint8, uint16, int8, int16, int32, float32, float64)", default="uint16", ) parser.add_argument( "--output_volume", help=f"output volume. Must be provided if 'output_type' isn't. Must look like: \n{volume_help}", default=None, ) parser.add_argument( "--output_type", help=f"output type. Must be provided if 'output_volume' isn't. 
Valid values are {tuple(files_formats.keys())}", default=None, ) parser.add_argument( "--data_min", help=f"value to clamp to volume cast new min. Any lower value will also be clamp to this value.", default=None, ) parser.add_argument( "--data_max", help=f"value to clamp to volume cast new max. Any higher value will also be clamp to this value.", default=None, ) parser.add_argument( "--rescale_min_percentile", help=f"used to determine data_min if not provided. Expected as percentage. Default is {RESCALE_MIN_PERCENTILE}%%", default=RESCALE_MIN_PERCENTILE, ) parser.add_argument( "--rescale_max_percentile", help=f"used to determine data_max if not provided. Expected as percentage. Default is {RESCALE_MAX_PERCENTILE}%%", default=RESCALE_MAX_PERCENTILE, ) parser.add_argument( "--overwrite", dest="overwrite", action="store_true", default=False, help="Overwrite file or dataset if exists", ) parser.add_argument( "--compression-ratios", dest="compression_ratios", default=None, help="Define compression ratios for jp2k. Expected as a list like [20, 10, 1] for [quality layer 1, quality layer 2, quality layer 3]... Pass parameter to glymur. See https://glymur.readthedocs.io/en/latest/how_do_i.html#write-images-with-different-compression-ratios-for-different-layers for more details", ) parser.add_argument( "--histogram-url", dest="histogram_url", default=None, help="Provide url to the histogram - like: '/{path}/my_file.hdf5?path/to/my/data' with my_file.hdf5 is the file containing the histogram. Located under 'path'. And 'path/to/my/data' is the location of the HDF5 dataset", ) options = parser.parse_args(argv[1:]) # handle input volume if os.path.exists(options.input_volume): volumes = guess_volumes(options.input_volume) def is_not_histogram(vol_identifier): return not (hasattr(vol_identifier, "data_path") and vol_identifier.data_path.endswith("histogram")) volumes = tuple(filter(is_not_histogram, volumes)) if len(volumes) == 0: _logger.error(f"no valid volume found in {options.input_volume}") exit(1) elif len(volumes) > 1: _logger.error( f"found several volume from {options.input_volume}. Please provide one full url from {[volume.get_identifier() for volume in volumes]}" ) else: input_volume = volumes[0] else: try: input_volume = Factory.create_tomo_object_from_identifier(options.input_volume) except Exception as e: raise ValueError(f"Fail to build input volume from url {options.input_volume}") from e # handle output format output_format = files_formats.get(options.output_type, None) # handle output volume if options.output_volume is not None: # if an url is provided if ":" in options.output_volume: try: output_volume = Factory.create_tomo_object_from_identifier(options.output_volume) except Exception as e: raise ValueError(f"Fail to build output volume from {options.output_volume}") from e if output_format is not None: if not ( isinstance(output_volume, EDFVolume) and output_format == "edf" or isinstance(output_format, HDF5Volume) and output_format == "hdf5" or isinstance(output_format, JP2KVolume) and output_format == "jp2" or isinstance(output_format, (TIFFVolume, MultiTIFFVolume)) and output_format == "tiff" ): raise ValueError( "Requested 'output_type' and output volume url are incoherent. 
'output_type' is optional when url provided" ) else: path_extension = os.path.splitext(options.output_volume)[-1] if path_extension == "": # if a folder ha sbeen provided we try to create a volume from this path and the output format if output_format == "tiff": output_volume = TIFFVolume( folder=options.output_volume, ) elif output_format == "edf": output_volume = EDFVolume( folder=options.output_volume, ) elif output_format == "jp2": output_volume = JP2KVolume( folder=options.output_volume, ) else: raise ValueError( f"Unable to deduce an output volume from {options.output_volume} and output format {output_format}. Please provide an output_volume as an url" ) else: # if a file path_has been provided if path_extension.lower() in ("tif", "tiff") and output_format in ( None, "tiff", ): output_volume = MultiTIFFVolume( file_path=options.output_volume, ) elif path_extension.lower() in ( "h5", "nx", "nexus", "hdf", "hdf5", ) and output_format in (None, "hdf5"): output_volume = HDF5Volume( file_path=options.output_volume, data_path="volume", ) else: raise ValueError( f"Unable to deduce an output volume from {options.output_volume} and output format {output_format}. Please provide an output_volume as an url" ) elif options.output_type is None: raise ValueError("'output_type' or 'output_volume' is expected") else: output_volume = get_default_output_volume( input_volume=input_volume, output_type=output_format # pylint: disable=E0606 ) try: output_data_type = numpy.dtype(getattr(numpy, options.output_data_type)) except Exception as e: raise ValueError(f"Unable to get output data type from {options.output_data_type}") from e # get data_min and data_max data_min = options.data_min if data_min is not None: data_min = float(data_min) data_max = options.data_max if data_max is not None: data_max = float(data_max) # get rescale_min_percentile and rescale_min_percentile rescale_min_percentile = options.rescale_min_percentile def clean_percentiles_str(percentile): # remove ' char percentile = percentile.rstrip("'").lstrip("'") # remove " char percentile = percentile.rstrip('"').lstrip('"') # remove % char return percentile.rstrip("%") if isinstance(rescale_min_percentile, str): rescale_min_percentile = float(clean_percentiles_str(rescale_min_percentile)) rescale_max_percentile = options.rescale_max_percentile if isinstance(rescale_min_percentile, str): rescale_max_percentile = float(clean_percentiles_str(rescale_max_percentile)) assert rescale_min_percentile is not None, "rescale_min_percentile should be an int" assert rescale_max_percentile is not None, "rescale_max_percentile should be an int" # handle histogram and data_min, data_max if options.histogram_url is not None: if data_min is not None or data_max is not None: raise ValueError("Both histogram url and data min/max are provided. 
Don't know which one to take") else: if not options.histogram_url.startswith("silx:"): options.histogram_url = "silx:" + options.histogram_url histogram_url = DataUrl(path=options.histogram_url) data_min, data_max = _min_max_from_histo( url=histogram_url, rescale_min_percentile=rescale_min_percentile, rescale_max_percentile=rescale_max_percentile, ) # update output volume from options output_volume.overwrite = options.overwrite if options.compression_ratios is not None: output_volume.cratios = list([int(value) for value in convert_str_to_tuple(options.compression_ratios)]) # do volume casting cast_volume( input_volume=input_volume, output_volume=output_volume, output_data_type=output_data_type, data_min=data_min, data_max=data_max, rescale_min_percentile=rescale_min_percentile, rescale_max_percentile=rescale_max_percentile, ) exit(0) if __name__ == "__main__": main(sys.argv) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734019212.0 nabu-2024.2.1/nabu/app/cli_configs.py0000644000175000017500000005340414726604214016637 0ustar00pierrepierre# # Default configuration for CLI tools # # Default configuration for "bootstrap" command from nabu.stitching.definitions import StitchingType from nabu.pipeline.config_validators import str2bool from tomoscan.framereducer.method import ReduceMethod BootstrapConfig = { "bootstrap": { "help": "DEPRECATED, this is the default behavior. Bootstrap a configuration file from scratch.", "action": "store_const", "const": 1, }, "convert": { "help": "UNSUPPORTED. This option has no effect and will disappear. Convert a PyHST configuration file to a nabu configuration file.", "default": "", }, "output": { "help": "Output filename", "default": "nabu.conf", }, "nocomments": { "help": "Remove the comments in the configuration file (default: False)", "action": "store_const", "const": 1, }, "level": { "help": "Level of options to embed in the configuration file. Can be 'required', 'optional', 'advanced'.", "default": "optional", }, "dataset": { "help": "Pre-fill the configuration file with the dataset path.", "default": "", }, "template": { "help": "Use a template configuration file. Available are: id19_pag, id16_holo, id16_ctf, id16a_fluo, bm05_pag. You can also define your own templates via the NABU_TEMPLATES_PATH environment variable.", "default": "", }, "helical": {"help": "Prepare configuration file for helical", "default": 0, "required": False, "type": int}, "overwrite": { "help": "Whether to overwrite the output file if exists", "action": "store_const", "const": 1, }, } # Default configuration for "zsplit" command ZSplitConfig = { "input_file": { "help": "Input HDF5-Nexus file", "mandatory": True, }, "output_directory": { "help": "Output directory to write split files.", "mandatory": True, }, "loglevel": { "help": "Logging level. Can be 'debug', 'info', 'warning', 'error'. Default is 'info'.", "default": "info", }, "entry": { "help": "HDF5 entry to take in the input file. By default, the first entry is taken.", "default": "", }, "n_stages": { "help": "Number of expected stages (i.e different 'Z' values). By default it is inferred from the dataset.", "default": -1, "type": int, }, "use_virtual_dataset": { "help": "Whether to use virtual datasets for output file. Not using a virtual dataset duplicates data and thus results in big files ! However virtual datasets currently have performance issues. 
Default is False", "default": 0, "type": int, }, } # Default configuration for "histogram" command HistogramConfig = { "h5_file": { "help": "HDF5 file(s). It can be one or several paths to HDF5 files. You can specify entry for each file with /path/to/file.h5?entry0000", "mandatory": True, "nargs": "+", }, "output_file": { "help": "Output file (HDF5)", "mandatory": True, }, "bins": { "help": "Number of bins for histogram if they have to be computed. Default is one million.", "default": 1000000, "type": int, }, "chunk_size_slices": { "help": "If histogram are computed, specify the maximum subvolume size (in number of slices) for computing histogram.", "default": 100, "type": int, }, "chunk_size_GB": { "help": "If histogram are computed, specify the maximum subvolume size (in GibaBytes) for computing histogram.", "default": -1, "type": float, }, "loglevel": { "help": "Logging level. Can be 'debug', 'info', 'warning', 'error'. Default is 'info'.", "default": "info", }, } # Default configuration for "reconstruct" command ReconstructConfig = { "input_file": { "help": "Nabu input file", "default": "", "mandatory": True, }, "logfile": { "help": "Log file. Default is dataset_prefix_nabu.log", "default": "", }, "log_file": { "help": "Same as logfile. Deprecated, use --logfile instead.", "default": "", }, "slice": { "help": "Slice(s) indice(s) to reconstruct, in the format z1-z2. Default (empty) is the whole volume. This overwrites the configuration file start_z and end_z. You can also use --slice first, --slice last, --slice middle, and --slice all", "default": "", }, "gpu_mem_fraction": { "help": "Which fraction of GPU memory to use. Default is 0.9.", "default": 0.9, "type": float, }, "cpu_mem_fraction": { "help": "Which fraction of memory to use. Default is 0.9.", "default": 0.9, "type": float, }, "max_chunk_size": { "help": "Maximum chunk size to use.", "default": -1, "type": int, }, "phase_margin": { "help": "Specify an explicit phase margin to use when performing phase retrieval.", "default": -1, "type": int, }, "force_use_grouped_pipeline": { "help": "Force nabu to use the 'grouped' reconstruction pipeline - slower but should work for all big datasets.", "default": 0, "type": int, }, } MultiCorConfig = ReconstructConfig.copy() MultiCorConfig.update( { "cor": { "help": "Positions of the center of rotation. It must be a list of comma-separated scalars, or in the form start:stop:step, where start, stop and step can all be floating-point values.", "default": "", "mandatory": True, }, "slice": { "help": "Slice(s) indice(s) to reconstruct, in the format z1-z2. Default (empty) is the whole volume. This overwrites the configuration file start_z and end_z. You can also use --slice first, --slice last, --slice middle, and --slice all", "default": "", "mandatory": True, }, } ) GenerateInfoConfig = { "hist_file": { "help": "HDF5 file containing the histogram, either the reconstruction file or a dedicated histogram file.", "default": "", }, "hist_entry": { "help": "Histogram HDF5 entry. Defaults to the first available entry.", "default": "", }, "output": { "help": "Output file name", "default": "", "mandatory": True, }, "bliss_file": { "help": "HDF5 master file produced by BLISS", "default": "", }, "bliss_entry": { "help": "Entry in the HDF5 master file produced by BLISS. 
By default, take the first entry.", "default": "", }, "info_file": { "help": "Path to the .info file, in the case of a EDF dataset", "default": "", }, "edf_proj": { "help": "Path to a projection, in the case of a EDF dataset", "default": "", }, } RotateRadiosConfig = { "dataset": { "help": "Path to the dataset. Only HDF5 format is supported for now.", "default": "", "mandatory": True, }, "entry": { "help": "HDF5 entry. By default, the first entry is taken.", "default": "", }, "angle": { "help": "Rotation angle in degrees", "default": 0.0, "mandatory": True, "type": float, }, "center": { "help": "Rotation center, in the form (x, y) where x (resp. y) is the horizontal (resp. vertical) dimension, i.e along the columns (resp. lines). Default is (Nx/2 - 0.5, Ny/2 - 0.5).", "default": "", }, "output": { "help": "Path to the output file. Only HDF5 output is supported. In the case of HDF5 input, the output file will have the same structure.", "default": "", "mandatory": True, }, "loglevel": { "help": "Logging level. Can be 'debug', 'info', 'warning', 'error'. Default is 'info'.", "default": "info", }, "batchsize": { "help": "Size of the batch of images to process. Default is 100", "default": 100, "type": int, }, "use_cuda": { "help": "Whether to use Cuda if available", "default": "1", }, "use_multiprocessing": { "help": "Whether to use multiprocessing if available", "default": "1", }, } DFFConfig = { "dataset": { "help": "Path to the dataset.", "default": "", "mandatory": True, }, "entry": { "help": "HDF5 entry (for HDF5 datasets). By default, the first entry is taken.", "default": "", }, "flatfield": { "help": "Whether to perform flat-field normalization. Default is True.", "default": "1", "type": int, }, "sigma": { "default": 0.0, "help": "Enable high-pass filtering on double flatfield with this value of 'sigma'", "type": float, }, "output": { "help": "Path to the output file (HDF5).", "default": "", "mandatory": True, }, "loglevel": { "help": "Logging level. Can be 'debug', 'info', 'warning', 'error'. Default is 'info'.", "default": "info", }, "chunk_size": { "help": "Maximum number of lines to read in each projection in a single pass. Default is 100", "default": 100, "type": int, }, } CompareVolumesConfig = { "volume1": { "help": "Path to the first volume.", "default": "", "mandatory": True, }, "volume2": { "help": "Path to the first volume.", "default": "", "mandatory": True, }, "entry": { "help": "HDF5 entry. By default, the first entry is taken.", "default": "", }, "hdf5_path": { "help": "Full HDF5 path to the data. Default is /reconstruction/results/data", "default": "", }, "chunk_size": { "help": "Maximum number of images to read in each step. Default is 100.", "default": 100, "type": int, }, "stop_at": { "help": "Stop the comparison immediately when the difference exceeds this threshold. Default is to compare the full volumes.", "default": "1e-4", }, "statistics": { "help": "Compute statistics on the compared (sub-)volumes. Mind that in this case the command output will not be empty!", "default": 0, "type": int, }, } # Default configuration for "stitching" command StitchingConfig = { "input-file": { "help": "Nabu configuraiton file for stitching (can be obtain from nabu-stitching-boostrap command)", "default": "", "mandatory": True, }, "loglevel": { "help": "Logging level. Can be 'debug', 'info', 'warning', 'error'. Default is 'info'.", "default": "info", }, "--only-create-master-file": { "help": "Will create the master file with all sub files (volumes or scans). 
It expects the processing to be finished. It can happen if all slurm job have been submitted but you've been kicked out of the cluster of if you need to relaunch manually some failling job slurm for any reason", "default": False, "action": "store_true", }, } # Default configuration for "stitching-bootstrap" command BootstrapStitchingConfig = { "stitching-type": { "help": f"User can provide stitching type to filter some parameters. Must be in {StitchingType.values()}.", "default": None, }, "level": { "help": "Level of options to embed in the configuration file. Can be 'required', 'optional', 'advanced'.", "default": "optional", }, "output": { "help": "output file to store the configuration", "default": "stitching.conf", }, "datasets": { "help": "datasets to be stitched together", "default": tuple(), "nargs": "*", }, } ShrinkConfig = { "input_file": { "help": "Path to the NX file", "default": "", "mandatory": True, }, "output_file": { "help": "Path to the output NX file", "default": "", "mandatory": True, }, "entry": { "help": "HDF5 entry in the file. Default is to take the first entry.", "default": "", }, "binning": { "help": "Binning factor, in the form (bin_z, bin_x). Each image (projection, dark, flat) will be binned by this factor", "default": "", }, "subsampling": {"help": "Subsampling factor for projections (and metadata)", "default": ""}, "threads": { "help": "Number of threads to use for binning. Default is 1.", "default": 1, "type": int, }, } CompositeCorConfig = { "--filename_template": { "required": True, "help": """The filename template. It can optionally contain a segment equal to "X"*ndigits which will be replaced by the stage number if several stages are requested by the user""", }, "--entry_name": { "required": False, "help": "Optional. The entry_name. It defaults to entry0000", "default": "entry0000", }, "--num_of_stages": { "type": int, "required": False, "help": "Optional. How many stages. Example: from 0 to 43 -> --num_of_stages 44. It is optional. ", }, "--oversampling": { "type": int, "default": 4, "required": False, "help": "Oversampling in the research of the axis position. Defaults to 4 ", }, "--n_subsampling_y": { "type": int, "default": 10, "required": False, "help": "How many lines we are going to take from each radio. Defaults to 10.", }, "--theta_interval": { "type": float, "default": 5, "required": False, "help": "Angular step for composing the image. Default to 5", }, "--first_stage": {"type": int, "default": None, "required": False, "help": "Optional. The first stage. "}, "--output_file": { "type": str, "required": False, "help": "Optional. Where the list of cors will be written. Default is the filename postixed with cors.txt. If the output filename is postfixed with .json the output will be in json format", }, "--cor_options": { "type": str, "help": """the cor_options string used by Nabu. 
Example --cor_options "side='near'; near_pos = 300.0; near_width = 20.0" """, "required": True, }, } CreateDistortionMapHorizontallyMatchedFromPolyConfig = { "--nz": {"type": int, "help": "vertical dimension of the detector", "required": True}, "--nx": {"type": int, "help": "horizontal dimension of the detector", "required": True}, "--center_z": {"type": float, "help": "vertical position of the optical center", "required": True}, "--center_x": {"type": float, "help": "horizontal position of the optical center", "required": True}, "--c4": {"type": float, "help": "order 4 coefficient", "required": True}, "--c2": {"type": float, "help": "order 2 coefficient", "required": True}, "--target_file": {"type": str, "help": "The map output filename", "required": True}, "--axis_pos": { "type": float, "default": None, "help": "Optional argument. If given it will be corrected for use with the produced map. The value is printed, or given as return argument if the utility is used from a script", "required": False, }, "--loglevel": { "help": "Logging level. Can be 'debug', 'info', 'warning', 'error'. Default is 'info'.", "default": "info", }, } DiagToRotConfig = { "--diag_file": dict( required=True, help="The reconstruction file obtained by nabu-helical using the diag_zpro_run option", type=str ), "--entry_name": dict( required=False, help="entry_name. Defauls is entry0000", default="entry0000", ), "--near": dict( required=False, help="This is a relative offset respect to the center of the radios. The cor will be searched around the provided value. If not given the optinal parameter original_scan must be the original nexus file; and the estimated core will be taken there. The netry_name parameter also must be provided in this case", default=None, type=float, ), "--original_scan": dict( required=False, help="The original nexus file. Required only if near parameter is not given", default=None, type=str, ), "--entry_name": dict( required=False, help="The original nexus file entry name. Required only if near parameter is not given", default=None, type=str, ), "--near_width": dict( required=False, help="For the horizontal correlation, searching the cor. The radius around the near value", default=20, type=int, ), "--low_pass": dict( required=False, help="Data are filtered horizontally. details smaller than the provided value are filtered out. Default is 1( gaussian sigma)", default=1.0, type=float, ), "--high_pass": dict( required=False, help="Data are filtered horizontally. Bumps larger than the provided value are filtered out. Default is 10( gaussian sigma)", default=10, type=int, ), "--linear_interpolation": dict( required=False, help="If True(default) the cor will vary linearly with z_transl", default=True, type=str2bool ), "--use_l1_norm": dict( required=False, default=True, help="If false then a L2 norm will be used for the error metric, considering the overlaps, if true L1 norm will be considered", type=str2bool, ), "--cor_file": dict(required=True, help="The file where the information to correct the cor are written", type=str), } DiagToPixConfig = { "--diag_file": dict( required=True, help="The reconstruction file obtained by nabu-helical using the diag_zpro_run option", type=str ), "--entry_name": dict(required=False, help="entry_name. Defauls is entry0000", default="entry0000", type=str), "--search_radius_v": dict( required=False, help="For the vertical correlation, The maximal error in pixels of one turn respect to a contiguous one. 
Default is 20 ", default=20, type=int, ), "--nexus_target": dict( required=False, help="If given, the mentioned file will be edited with the proper pixel size, the proper COR, and corrected x_translations", default=None, type=str, ), "--nexus_source": dict( required=False, help="Optionaly given, used only if nexus_target has been give. The nexus file will be edited and written on nexus_target. Otherwise nexus_target is considered to be the source", default=None, type=str, ), } CorrectRotConfig = { "--cor_file": dict(required=True, help="The file produce by diag_to_rot", type=str), "--entry_name": dict(required=False, help="entry_name. Defauls is entry0000", default="entry0000", type=str), "--nexus_target": dict( required=True, help="The given file will be edited with the proper pixel size, the proper COR, and corrected x_translations", default=None, type=str, ), "--nexus_source": dict( required=True, help="The nexus file will be edited and written on nexus_target", default=None, type=str ), } ReduceDarkFlatConfig = { "dataset": {"help": "Dataset (NXtomo or EDF folder) to be treated", "mandatory": True}, "entry": { "dest": "entry", "help": "an entry can be specify in case of an NXtomo", "default": None, "required": False, }, "dark-method": { "help": f"Define the method to be used for computing darks. Valid methods are {ReduceMethod.values()}", "default": ReduceMethod.MEAN, "required": False, }, "flat-method": { "help": f"Define the method to be used for computing flats. Valid methods are {ReduceMethod.values()}", "default": ReduceMethod.MEDIAN, "required": False, }, "overwrite": { "dest": "overwrite", "action": "store_true", "default": False, "help": "Overwrite dark/flats if exists", "required": False, }, "debug": { "dest": "debug", "action": "store_true", "default": False, "help": "Set logging system in debug mode", "required": False, }, "output-reduced-flats-file": { "aliases": ("orfl",), "default": None, "help": "Where to save reduced flats. If not provided will be dump near the .nx file at {scan_prefix}_flats.hdf5", "required": False, }, "output-reduced-flats-data-path": { "aliases": ("output-reduced-flats-dp", "orfdp"), "default": None, "help": "Path in the output reduced flats file to save the dataset. If not provided will be saved at {entry}/flats/", "required": False, }, "output-reduced-darks-file": { "aliases": ("ordf",), "default": None, "help": "Where to save reduced dark. If not provided will be dump near the .nx file at {scan_prefix}_darks.hdf5", "required": False, }, "output-reduced-darks-data-path": { "aliases": ("output-reduced-darks-dp", "orddp"), "default": None, "help": "Path in the output reduced darks file to save the dataset. If not provided will be saved at {entry}/darks/", "required": False, }, } ShowReconstructionTimingsConfig = { "logfile": { "help": "Path to the log file.", "default": "", "mandatory": True, }, "cutoff": { "help": "Cut-off parameter. Timings below this value will be discarded. For a upper-bound cutoff, provide a value in the form 'low, high'", "default": None, }, "type": { "help": "How to display the result. Default is a pie chart. 
Possible values are: pie, bars, violin", "default": "pie", "type": str, }, } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682665866.0 nabu-2024.2.1/nabu/app/compare_volumes.py0000644000175000017500000000641514422670612017556 0ustar00pierrepierrefrom math import ceil from posixpath import join import numpy as np from tomoscan.io import HDF5File from .cli_configs import CompareVolumesConfig from ..utils import clip_circle from .utils import parse_params_values from ..io.utils import get_first_hdf5_entry, hdf5_entry_exists from ..io.reader import get_hdf5_dataset_shape def idx_1d_to_3d(idx, shape): nz, ny, nx = shape x = idx % nx idx2 = (idx - x) // nx y = idx2 % ny z = (idx2 - y) // ny return (z, y, x) def compare_volumes(fname1, fname2, h5_path, chunk_size, do_stats, stop_at_thresh, clip_outer_circle=False): result = None f1 = HDF5File(fname1, "r") f2 = HDF5File(fname2, "r") try: # Check that data is in the provided hdf5 path for fname in [fname1, fname2]: if not hdf5_entry_exists(fname, h5_path): result = "File %s do not has data in %s" % (fname, h5_path) return # Check shapes shp1 = get_hdf5_dataset_shape(fname1, h5_path) shp2 = get_hdf5_dataset_shape(fname2, h5_path) if shp1 != shp2: result = "Volumes do not have the same shape: %s vs %s" % (shp1, shp2) return # Compare volumes n_steps = ceil(shp1[0] / chunk_size) for i in range(n_steps): start = i * chunk_size end = min((i + 1) * chunk_size, shp1[0]) data1 = f1[h5_path][start:end, :, :] data2 = f2[h5_path][start:end, :, :] abs_diff = np.abs(data1 - data2) if clip_outer_circle: for j in range(abs_diff.shape[0]): abs_diff[j] = clip_circle(abs_diff[j], radius=0.9 * min(abs_diff.shape[1:])) coord_argmax = idx_1d_to_3d(np.argmax(abs_diff), data1.shape) if do_stats: mean = np.mean(abs_diff) std = np.std(abs_diff) maxabs = np.max(abs_diff) print("Chunk %d: mean = %e std = %e max = %e" % (i, mean, std, maxabs)) if stop_at_thresh is not None and abs_diff[coord_argmax] > stop_at_thresh: coord_argmax_absolute = (start + coord_argmax[0],) + coord_argmax[1:] result = "abs_diff[%s] = %e" % (coord_argmax_absolute, abs_diff[coord_argmax]) return except Exception as exc: result = "Error: %s" % (str(exc)) raise finally: f1.close() f2.close() return result def compare_volumes_cli(): args = parse_params_values( CompareVolumesConfig, parser_description="A command-line utility for comparing two volumes." 
) fname1 = args["volume1"] fname2 = args["volume2"] h5_path = args["hdf5_path"] if h5_path == "": entry = args["entry"].strip() or None if entry is None: entry = get_first_hdf5_entry(fname1) h5_path = join(entry, "reconstruction/results/data") do_stats = bool(args["statistics"]) chunk_size = args["chunk_size"] stop_at_thresh = args["stop_at"] or None if stop_at_thresh is not None: stop_at_thresh = float(stop_at_thresh) res = compare_volumes(fname1, fname2, h5_path, chunk_size, do_stats, stop_at_thresh) if res is not None: print(res) return 0 if __name__ == "__main__": compare_volumes_cli() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1730906677.0 nabu-2024.2.1/nabu/app/composite_cor.py0000644000175000017500000001202214712705065017215 0ustar00pierrepierreimport logging import os import sys import numpy as np import re from nabu.resources.dataset_analyzer import HDF5DatasetAnalyzer from nabu.pipeline.estimators import CompositeCOREstimator, estimate_cor from nabu.resources.nxflatfield import update_dataset_info_flats_darks from nabu.resources.utils import extract_parameters from nxtomo.application.nxtomo import NXtomo from .. import version from .cli_configs import CompositeCorConfig from .utils import parse_params_values from ..utils import DictToObj import json class NumpyArrayEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, np.ndarray): return obj.tolist() return json.JSONEncoder.default(self, obj) def main(user_args=None): "Application to extract with the composite cor finder the center of rotation for a scan or a series of scans" if user_args is None: user_args = sys.argv[1:] args_dict = parse_params_values( CompositeCorConfig, parser_description=main.__doc__, program_version="nabu " + version, user_args=user_args, ) composite_cor_entry_point(args_dict) # here we have been called by the cli. The return value 0 means OK return 0 def composite_cor_entry_point(args_dict): args = DictToObj(args_dict) if len(os.path.dirname(args.filename_template)) == 0: # To make sure that other utility routines can succesfully deal with path within the current directory args.filename_template = os.path.join(".", args.filename_template) args.filename_template = os.path.abspath(args.filename_template) if args.first_stage is not None: if args.num_of_stages is None: args.num_of_stages = 1 # if the first_stage parameter has been given then # we are using numbers to form the names of the files. # The filename must containe a XX..X substring which will be replaced pattern = re.compile("[X]+") ps = pattern.findall(args.filename_template) if len(ps) == 0: message = f""" You have specified the "first_stage" parameter, with an integer. 
Therefore the "filename_template" parameter is expected to containe a XX..X subsection but none was found in the passed parameter which is {args.filename_template} """ raise ValueError(message) ls = list(map(len, ps)) idx = np.argmax(ls) args.filename_template = args.filename_template.replace(ps[idx], "{i_stage:" + "0" + str(ls[idx]) + "d}") if args.num_of_stages is None: # this way it works also in the simple case where # only the filename is provided together with the cor options num_of_stages = 1 first_stage = 0 else: num_of_stages = args.num_of_stages first_stage = args.first_stage cor_list = [] for iz in range(first_stage, first_stage + num_of_stages): if args.num_of_stages is not None: nexus_name = args.filename_template.format(i_stage=iz) else: nexus_name = args.filename_template dataset_info = HDF5DatasetAnalyzer(nexus_name, extra_options={"h5_entry": args.entry_name}) update_dataset_info_flats_darks(dataset_info, flatfield_mode=1) ### JL start ### # Extract CoR parameters from configuration file try: cor_options = extract_parameters(args.cor_options, sep=";") except Exception as exc: msg = "Could not extract parameters from cor_options: %s" % (str(exc)) raise ValueError(msg) ### JL end ### # JL start ### #: next bit will be done in estimate # if "near_pos" not in args.cor_options: # scan = NXtomo() # scan.load(file_path=nexus_name, data_path=args.entry_name) # estimated_near = scan.instrument.detector.x_rotation_axis_pixel_position # # cor_options = args.cor_options + f" ; near_pos = {estimated_near} " # # else: # cor_options = args.cor_options ### JL end ### cor_finder = CompositeCOREstimator( dataset_info, oversampling=args.oversampling, theta_interval=args.theta_interval, n_subsampling_y=args.n_subsampling_y, take_log=True, spike_threshold=0.04, cor_options=cor_options, norm_order=1, ) cor_position = cor_finder.find_cor() cor_list.append(cor_position) cor_list = np.array(cor_list).T if args.output_file is not None: output_name = args.output_file else: output_name = os.path.splitext(args.filename_template)[0] + "_cors.txt" if output_name.endswith(".json"): with open(output_name, "w") as fp: json.dump( dict(rotation_axis_position=cor_list[0], rotation_axis_position_list=cor_list), fp, indent=4, cls=NumpyArrayEncoder, ) else: np.savetxt( output_name, cor_list, ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/app/correct_rot.py0000644000175000017500000000412214550227307016675 0ustar00pierrepierrefrom .. 
import version from os import environ import argparse import shutil import os import sys import re import h5py import numpy as np from ..resources.logger import LoggerOrPrint from .utils import parse_params_values from .cli_configs import CorrectRotConfig from silx.io.dictdump import h5todict from nxtomo.application.nxtomo import NXtomo import h5py from nabu.utils import DictToObj def main(user_args=None): """Applies the correction found by diag_to_rot to a nexus file""" if user_args is None: user_args = sys.argv[1:] args = DictToObj( parse_params_values( CorrectRotConfig, parser_description=main.__doc__, program_version="nabu " + version, user_args=user_args, ) ) # now we read the results of the diag_to_rot utility, they are in the cor_file parameter # of the cli cor_data = DictToObj(h5todict(args.cor_file, "/")) my_cor = cor_data.cor[0] # we will take my_cor as cor at the first angular position # and then we correct the x_translation at the other angles # We now load the nexus that we wish to correct nx_tomo = NXtomo().load(args.nexus_source, args.entry_name) # The cor_file that we use for correction # is providing us with the z_m that gives for each # cor position found in the cor array the corresponding value of # the translation along z (in meters) z_translation = nx_tomo.sample.z_translation.value z_translation = z_translation - z_translation[0] # now we interpolate to find the correction # for each position of the encoders cors = np.interp(z_translation, cor_data.z_m, cor_data.cor) # this is the correction x_correction = nx_tomo.instrument.detector.x_pixel_size.value * (cors - my_cor) # we are in meters here # and we apply it to the nexus that we have loaded nx_tomo.sample.x_translation = nx_tomo.sample.x_translation.value + x_correction # finally we write it to the corrected nexus file nx_tomo.save(file_path=args.nexus_target, data_path=args.entry_name, overwrite=True) return 0 ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/app/create_distortion_map_from_poly.py0000644000175000017500000001406114654107202023013 0ustar00pierrepierreimport sys import numpy as np import h5py from .. import version from ..utils import DictToObj from ..resources.logger import Logger from .cli_configs import CreateDistortionMapHorizontallyMatchedFromPolyConfig from .utils import parse_params_values def create_distortion_maps_entry_point(user_args=None): """This application builds two arrays. Let us call them map_x and map_z. Both are 2D arrays with shape given by (nz, nx). These maps are meant to be used to generate a corrected detector image, using them to obtain the pixel (i,j) of the corrected image by interpolating the raw data at position ( map_z(i,j), map_x(i,j) ). This map is determined by a user given polynomial P(rs) in the radial variable rs = sqrt( (z-center_z)**2 + (x-center_x)**2 ) / (nx/2) where center_z and center_x give the center around which the deformation is centered. The perfect position (zp,xp) , that would be observed on a perfect detector, of a photon observed at pixel (z,x) of the distorted detector is: (zp, xp) = (center_z, center_x) + P(rs) * ( z - center_z , x - center_x ) The polynomial is given by P(rs) = rs *(1 + c2 * rs**2 + c4 * rs**4) The map is rescaled and reshifted so that a perfect match is realised at the borders of a horizontal line passing by the center. 
This ensures coerence with the procedure of pixel size calibration which is performed moving a needle horizontally and reading the motor positions at the extreme positions. The maps are written in the target file, creating it as hdf5 file, in the datasets "/coords_source_x" "/coords_source_z" The URLs of these two maps can be used for the detector correction of type "map_xz" in the nabu configuration file as in this example [dataset] ... detector_distortion_correction = map_xz detector_distortion_correction_options = map_x="silx:./map_coordinates.h5?path=/coords_source_x" ; map_z="silx:./map_coordinates.h5?path=/coords_source_z" """ if user_args is None: user_args = sys.argv[1:] args_dict = parse_params_values( CreateDistortionMapHorizontallyMatchedFromPolyConfig, parser_description=create_maps_x_and_z.__doc__, program_version="nabu " + version, user_args=user_args, ) logger = Logger("create_distortion_maps", level=user_args["loglevel"], logfile="create_distortion_maps.log") coords_source_x, coords_source_z, new_axis_pos = create_maps_x_and_z(args_dict) # pylint: disable=E0606 with h5py.File(args_dict["target_file"], "w") as f: f["coords_source_x"] = coords_source_x f["coords_source_z"] = coords_source_z if new_axis_pos is not None: logger.info("New axis position at %e it was previously %e " % (new_axis_pos, args_dict["axis_pos"])) return 0 def create_maps_x_and_z(args_dict): """This method is meant for those applications which wants to use the functionalities of the poly2map entry point through a standar python API. The argument arg_dict must contain the keys that you can find in cli_configs.py: CreateDistortionMapHorizontallyMatchedFromPolyConfig Look at this files for variables and their meaning and defaults Parameters:: args_dict : dict a dictionary containing keys : center_x, center_z, nz, nx, c2, c4, axis_pos return: max_x, map_z, new_rot_pos """ args = DictToObj(args_dict) nz, nx = args.nz, args.nx center_x, center_z = (args.center_x, args.center_z) c4 = args.c4 c2 = args.c2 polynomial = np.poly1d([c4, 0, c2, 0, 1, 0.0]) # change of variable cofv = np.poly1d([1.0 / (nx / 2), 0]) polynomial = nx / 2 * polynomial(cofv) left_border = 0 - center_x right_border = nx - 1 - center_x def get_rescaling_shift(left_border, right_border, polynomial): dl = polynomial(left_border) dr = polynomial(right_border) rescaling = (dr - dl) / (right_border - left_border) shift = -left_border * rescaling + dl return rescaling, shift final_grid_rescaling, final_grid_shift = get_rescaling_shift(left_border, right_border, polynomial) coords_z, coords_x = np.indices([nz, nx]) coords_z = ((coords_z - center_z) * final_grid_rescaling).astype("d") coords_x = ((coords_x - center_x) * final_grid_rescaling + final_grid_shift).astype("d") distances_goal = np.sqrt(coords_z * coords_z + coords_x * coords_x) distances_unknown = distances_goal pp_deriv = polynomial.deriv() # iteratively finding the positions to interpolated at by newton method for i in range(10): errors = polynomial(distances_unknown) - distances_goal derivative = pp_deriv(distances_unknown) distances_unknown = distances_unknown - errors / derivative distances_data_sources = distances_unknown # avoid 0/0 distances_data_sources[distances_goal < 1] = 1 distances_goal[distances_goal < 1] = 1 coords_source_z = coords_z * distances_data_sources / distances_goal + center_z coords_source_x = coords_x * distances_data_sources / distances_goal + center_x new_axis_pos = None if args.axis_pos is not None: # here we search on the central line a point new_axis_pos # 
such that # coords_source_x[ i_central_y, nuova_axis_pos ] == axis_pos # n_y, n_x = coords_source_x.shape x_coordinates_central_line = coords_source_x[n_y // 2] if np.any(np.less_equal(np.diff(x_coordinates_central_line), 0)): message = """ Error in the coordinates map, the X coordinates are not monotonous on the central line """ raise ValueError(message) i1 = np.searchsorted(x_coordinates_central_line, args.axis_pos) i1 = np.clip(i1, 1, n_x - 1) i0 = i1 - 1 val0 = x_coordinates_central_line[i0] val1 = x_coordinates_central_line[i1] fract = (args.axis_pos - val0) / (val1 - val0) new_axis_pos = i0 * (1 - fract) + i1 * fract return coords_source_x, coords_source_z, new_axis_pos ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/app/diag_to_pix.py0000644000175000017500000003532614654107202016644 0ustar00pierrepierreimport os from multiprocessing import Pool import sys import numpy as np from scipy.ndimage import gaussian_filter import h5py from silx.io.dictdump import h5todict from nxtomo.application.nxtomo import NXtomo from .. import version from ..utils import DictToObj, get_available_threads from .utils import parse_params_values from .cli_configs import DiagToPixConfig from ..pipeline.estimators import oversample """ The operations here below rely on diag objects which are found in the result of a nab-helical run with the diag_zpro_run set to a number > 0 They are found in the configuration section of the nabu output, in several sequential dataset, with hdf5 dataset keys which are 0,1,2.... corresponding to all the z-windows ( chunck) for which we have collected contributions at different angles which are nothing else but pieces of ready to use preprocessed radio , and this for different angles. In other words redundant contributions conccurring at the same prepocessed radiography but comming from different moment of the helical scan are kept separate Forming pairs from contributions for same angle they should coincide where they both have signal ( part of them can be dark if the detector is out of view above or below) For each key there is a sequence of radio, the corresponding sequence of weights map, the corresponding z translation, and angles The number passed to diag_zpro_run object, is >1, and is interpreted by the diagnostic collection run as the number of wished collecte angles between 0 and 180. Lets call it num_0_180 The collection is not done here. Here we exploit the result of a previous collection to deduce, looking at the correlations, which correction we must bring to the pixel size An example of collection is this : |_____ diagnostics | | |__ 0 | |_ radios (4*num_0_180, chunky, chunkx) | | | |_ weights (4*num_0_180, chunky, chunkx) | | | |_ angles ( 4*num_0_180,) | |_ searched_rad ( 2*num_0_180,) these are all the searched angles between 0 and 360 in radians | |_ zmm_trans ( 4*num_0_180,) the z translation in mm | |_ zpix_transl ( 4*num_0_180,) the z translation in pix | |_ pixes_size_mm scalar """ def transform_images(diag, ovs): """Filter the radios, and oversample them along the vertical line. The method in general is similar to the composite cor finding. Several overlapping positions are used to match redundant contributions at different rotation stages ( theta and theta+360). But beforehand it is beneficial to remove low spatial frequencies. And we do oversampling on the fly. 
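 For instance, with ovs = [4, 1] (the value used by this application), each diagnostic
 radio is oversampled by a factor 4 along the vertical axis only; the integer shifts probed
 afterwards on the oversampled images then correspond to quarter-pixel steps of the original
 detector grid (this is an illustrative note added for clarity, derived from how main() below
 converts shifts back with best_shift = shift / oversample_factor).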
""" assert len(ovs) == 2, "oversampling must be specified for both vertical and horizontal dimension" diag.radios[:] = diag.radios / diag.weights diag.radios = [oversample((ima - gaussian_filter(ima, 20, mode="nearest")), ovs) for ima in diag.radios] diag.weights = [oversample(ima, ovs) for ima in diag.weights] def detailed_merit(diag, shift): # res will become the merit summed over all the pairs theta, theta+180 res = 0.0 # need to account for the weight also. So this will become the used weight for the pairs theta, theta+180 res_w = 0.0 ## The following two variables are very important information to be collected. ## On the the z translation over a 360 turn ## the other is the pixel size in mm. ## At the end of the script, the residual shift for perfect correlation ## will used to correct zpix_mm, doing a pro-rata with respect to ## the z observed translation over one turn observed_oneturn_total_shift_zpix_list = [] zpix_mm = None n_angles_2pi = len(diag.radios) // 2 # In accordance with the collection layout for diagnostics (diag_zpro_run parameter passed to nabu-helical) # there are n_angles_pi in [0,180[, and then again the same number of possibly valid radios # (check for nan in z translation) in [180,360[, [360,540[, 540,720[ # So we have len(diag.radios) // 2 in the range [0,360[ # because we have len(diag.radios) in [0,720[ detailed_merit_list = [] # one for each theta theta+360 pair detailed_weight_list = [] # one for each theta theta+360 pair for i in range(n_angles_2pi): # if we have something for both items of the pair, proceed if (not np.isnan(diag.zpix_transl[i])) and (not np.isnan(diag.zpix_transl[i + n_angles_2pi])): # because we have theta and theta + 360 zpix_mm = diag.pixel_size_mm add, add_w = merit( diag.radios[i], diag.radios[i + n_angles_2pi], diag.weights[i], diag.weights[i + n_angles_2pi], shift ) detailed_merit_list.append(add) detailed_weight_list.append(add_w) observed_oneturn_total_shift_zpix_list.append(diag.zpix_transl[i + n_angles_2pi] - diag.zpix_transl[i]) return detailed_merit_list, detailed_weight_list, observed_oneturn_total_shift_zpix_list, zpix_mm def merit(ima_a, ima_b, w_a, w_b, s): """A definition of the merit which accounts also for the data weight. calculates the merit for a given shift s. Comparison between a and b Considering signal ima and weight w """ if s == 0: # return - abs( (ima_a - ima_b) * w_a * w_b ).astype("d").mean(), (w_a * w_b).astype("d").mean() return (ima_a * ima_b * w_a * w_b).astype("d").sum(), (w_a * w_b).astype("d").sum() elif s > 0: # Keep the comment lines in case one wish to test L1 # pi = abs(ima_b[s:] - ima_a[:-s]) # pw = w_b[s:] * w_a[:-s] # return - ( pi * pw ).astype("d").mean(), (pw).astype("d").mean() pi = ima_b[s:] * ima_a[:-s] pw = w_b[s:] * w_a[:-s] return (pi * pw).astype("d").sum(), pw.astype("d").sum() else: # Keep the comment lines in case one wish to test L1 # pi = abs(ima_a[-s:] - ima_b[:s]) # pw = w_a[-s:] * w_b[:s] # return - ( pi * pw ).astype("d").mean(), pw.astype("d").mean() pi = ima_a[-s:] * ima_b[:s] pw = w_a[-s:] * w_b[:s] return (pi * pw).astype("d").sum(), pw.astype("d").sum() def build_total_merit_list(diag, oversample_factor, args): # calculats the merit at all the tested extra adjustment shifts. 
transform_images(diag, [oversample_factor, 1]) h_ima = diag.radios[0].shape[0] # search_radius_v = min(oversample_factor * args.search_radius_v, h_ima - 1) search_radius_v = oversample_factor * args.search_radius_v shift_s = [] for_all_shifts_detailed_merit_lists = [] for_all_shifts_detailed_weight_lists = [] observed_oneturn_total_shift_zpix_list, zpix_mm = None, None for shift in range(-search_radius_v, search_radius_v + 1): ( detailed_merit_list, detailed_weight_list, found_observed_oneturn_total_shift_zpix_list, found_zpix_mm, ) = detailed_merit(diag, shift) if found_zpix_mm is not None: # the following two lines do not depend on the shift. # The shift is what we do prior to a comparison f images # while the two items below are a properties of the scan # in particular they depend on z_translation and angles from bliss zpix_mm = found_zpix_mm observed_oneturn_total_shift_zpix_list = found_observed_oneturn_total_shift_zpix_list # The merit and weight are the result of comparison, they depend on the shift for_all_shifts_detailed_merit_lists.append(detailed_merit_list) for_all_shifts_detailed_weight_lists.append(detailed_weight_list) shift_s.append( shift / oversample_factor ) # shift_s is stored in original pixel units. Images were oversampled else: # here there is nothing to append, not correspondance was found pass # now transposition: we want for each pair theta, theta+360 a list which contains meritvalues for each adjustment shift # For each pair there is a list which runs over the shifts # Same thing for the weights for_all_pairs_detailed_merit_lists = zip(*for_all_shifts_detailed_merit_lists) for_all_pairs_detailed_weight_lists = zip(*for_all_shifts_detailed_weight_lists) return ( for_all_pairs_detailed_merit_lists, for_all_pairs_detailed_weight_lists, observed_oneturn_total_shift_zpix_list, zpix_mm, ) def main(user_args=None): """Analyse the diagnostics and correct the pixel size""" if user_args is None: user_args = sys.argv[1:] args = DictToObj( parse_params_values( DiagToPixConfig, parser_description=main.__doc__, program_version="nabu " + version, user_args=user_args, ) ) oversample_factor = 4 if args.nexus_source is None: args.nexus_source = args.nexus_target ## Read all the available diagnostics. ## Every key correspond to a chunk of the helical pipeline diag_url = os.path.join("/", args.entry_name, "reconstruction/configuration/diagnostics") diag_keys = [] with h5py.File(args.diag_file, "r") as f: diag_keys = list(f[diag_url].keys()) diag_keys = [diag_keys[i] for i in np.argsort(list(map(int, diag_keys)))] # The diag_keys are 0,1,2 ... corresponding to all the z-windows ( chunck) for which we have collected contributions at different angles # which are nothing else but pieces of ready to use preprocessed radio , and this for different angles. 
# Pairs should coincide where they both have signal ( part of them can be dark if the detector is out of view above or below) # For each key there is a sequence of radio, the corresponding sequence of weights map, the corresponding z translation, and angles zpix_mm = None observed_oneturn_total_shift_zpix = None argument_list = [ (DictToObj(h5todict(args.diag_file, os.path.join(diag_url, my_key))), oversample_factor, args) for my_key in diag_keys ] ncpus = get_available_threads() with Pool(processes=ncpus) as pool: all_res_plus_infos = pool.starmap(build_total_merit_list, argument_list) observed_oneturn_total_shift_zpix, zpix_mm = None, None # needs to flatten the result of pool.map for_all_pairs_detailed_merit_lists = [] for_all_pairs_detailed_weight_lists = [] observed_oneturn_total_shift_zpix_list = [] zpix_mm = None for ( tmp_merit_lists, tmp_weight_lists, tmp_observed_oneturn_total_shift_zpix_list, tmp_zpix_mm, ) in all_res_plus_infos: if tmp_zpix_mm is not None: # then each item of the composed list will be for a given pairs theta, theta+360 # and each such item is a list where each item is for a given probed shift for_all_pairs_detailed_merit_lists.extend(tmp_merit_lists) for_all_pairs_detailed_weight_lists.extend(tmp_weight_lists) observed_oneturn_total_shift_zpix_list.extend(tmp_observed_oneturn_total_shift_zpix_list) zpix_mm = tmp_zpix_mm if zpix_mm is None: message = "No overlapping was found" raise RuntimeError(message) if len(for_all_pairs_detailed_merit_lists) == 0: message = "No diag was found" raise RuntimeError(message) # Now an important search step: # We find for which pair of theta theta+360 the observed translation has the bigger absolute value. # Then the search for the optimum is performed for the readjustment shift in the # range (-search_radius_v, search_radius_v + 1) # considered as readjustmnet for the foud ideal pair which has exactly a translation equal to this maximal absolute observed translation # For all the others the readjustment is multiplied by the pro-rata factor # given by their smaller z-translation over the maximal one max_absolute_shift = abs(np.array(observed_oneturn_total_shift_zpix_list)).max() # gong to search for the best pixel size max_merit = None best_shift = None search_radius_v = oversample_factor * args.search_radius_v probed_shift_list = list(range(-search_radius_v, search_radius_v + 1)) for shift in range(-search_radius_v, search_radius_v + 1): total_sum = 0 total_weight = 0 for merit_list, weight_list, one_turn_shift in zip( for_all_pairs_detailed_merit_lists, for_all_pairs_detailed_weight_lists, observed_oneturn_total_shift_zpix_list, ): # sanity check assert len(merit_list) == len(probed_shift_list), " this should not happen" assert len(weight_list) == len(probed_shift_list), " this should not happen" # pro_rata shift my_shift = shift * (one_turn_shift / max_absolute_shift) # doing interpolation with search sorted i1 = np.searchsorted(probed_shift_list, my_shift) if i1 > 0 and i1 < len(probed_shift_list): i0 = i1 - 1 fract = (-my_shift + probed_shift_list[i1]) / (probed_shift_list[i1] - probed_shift_list[i0]) total_sum += fract * merit_list[i0] + (1 - fract) * merit_list[i1] total_weight += fract * weight_list[i0] + (1 - fract) * weight_list[i1] if total_weight == 0: # this avoid 0/0 = nan total_weight = 1 m = total_sum / total_weight if (max_merit is None) or ((not np.isnan(m)) and m > max_merit): max_merit = m best_shift = shift / oversample_factor print(" Best shift at ", best_shift) print( f" Over one turn the reference shift 
was {max_absolute_shift} pixels. But a residual shift of {best_shift} remains " ) # the formula below is already purged from the ovrsamplig factor. We did this when we recorded best_shift and the z shift # is registered when lloking at the z_translation and one does not fiddle aroud with the oversamplig at that moment zpix_mm = zpix_mm * (max_absolute_shift) / (max_absolute_shift - best_shift) print(f"Corrected zpix_mm = {zpix_mm}") if args.nexus_target is not None: nx_tomo = NXtomo().load(args.nexus_source, args.entry_name) nx_tomo.instrument.detector.x_pixel_size = zpix_mm * 1.0e-3 # pixel size must be provided in SI (meters) nx_tomo.instrument.detector.y_pixel_size = zpix_mm * 1.0e-3 # pixel size must be provided in SI (meters) nx_tomo.save(file_path=args.nexus_target, data_path=args.entry_name, overwrite=True) return 0 ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1730906677.0 nabu-2024.2.1/nabu/app/diag_to_rot.py0000644000175000017500000004111414712705065016646 0ustar00pierrepierreimport os import sys from multiprocessing import Pool import numpy as np from scipy.ndimage import gaussian_filter import h5py from silx.io.dictdump import h5todict from nxtomo.application.nxtomo import NXtomo from .. import version from ..utils import get_available_threads, DictToObj from ..pipeline.estimators import oversample from .utils import parse_params_values from .cli_configs import DiagToRotConfig """ The operations here below rely on diag objects which are found in the result of a nab-helical run with the diag_zpro_run set to a number > 0 They are found in the configuration section of the nabu output, in several sequential dataset, with hdf5 dataset keys which are 0,1,2.... corresponding to all the z-windows ( chunck) for which we have collected contributions at different angles which are nothing else but pieces of ready to use preprocessed radio , and this for different angles. In other words redundant contributions conccurring at the same prepocessed radiography but comming from different moment of the helical scan are kept separate. By forming pairs theta, theta +180 These Pairs should coincide on an overlapping region after flipping one, where they both have signal ( part of them can be dark if the detector is out of view above or below) For each key there is a sequence of radio, the corresponding sequence of weights map, the corresponding z translation, and angles The number passed to diag_zpro_run object, is >1, and is interpreted by the diagnostic collection run as the number of wished collected angles between 0 and 180. Lets call it num_0_180 The collection is not done here. Here we exploit the result of a previous collection to deduce, looking at the correlations, the cor An example of collection is this : |_____ diagnostics | | |__ 0 | |_ radios (4*num_0_180, chunky, chunkx) | | | |_ weights (4*num_0_180, chunky, chunkx) | | | |_ angles ( 4*num_0_180,) | |_ searched_rad ( 2*num_0_180,) these are all the searched angles between 0 and 360 in radians | |_ zmm_trans ( 4*num_0_180,) the z translation in mm | |_ zpix_transl ( 4*num_0_180,) the z translation in pix | |_ pixes_size_mm scalar Here we follow the evolution of the rotation angle along the scan. The final result can be left in its detailed form, giving the found cor at every analysed scan position, or the result of the interpolation, giving the cor at the two extremal position of z_translation. """ def transform_images(diag, args): """ Filter and transform the radios and the weights. 
Filter the radios, and oversample them along the horizontal line. The method in general is similar to the composite cor finding. Several overlapping positions are used to match redundant contributions at different rotation stages ( theta and theta+180). But beforehand it is beneficial to remove low spatial frequencies. And we do oversampling on the fly. Parameters: diag: object used member of diag are radios and weights args: object its member are the application parameters. Here we use only: low_pass, high_pass, ovs ( oversampling factor for the horisontal dimension ) """ diag.radios[:] = (diag.radios / diag.weights).astype("f") new_radios = [] for ima in diag.radios: ima = gaussian_filter(ima, [0, args.low_pass], mode="nearest") ima = ima - gaussian_filter(ima, [0, args.high_pass], mode="nearest") new_radios.append(ima) diag.radios = [oversample(ima, [1, args.ovs]).astype("f") for ima in new_radios] diag.weights = [oversample(ima, [1, args.ovs]).astype("f") for ima in diag.weights] def total_merit_list(arg_tuple): """ builds three lists : all_merits, all_energies, all_z_transl For every pair (theta, theta+180 ) add an item to the list which contains: for "all_merits" a list of merit, one for every overlap in the overlap_list argument, for "all_energies", same logic, but calculating the implied energy, implied in the calculation of the merit, for "all_z_transl" we add the averaged z_transl for the considered pair Parameters: diag: object used member of diag are radios, weights and zpix_transl args: object containing the application parameters. Its used members are ovs, high_pass, low_pass """ (diag, overlap_list, args) = arg_tuple orig_sy, ovsd_sx = diag.radios[0].shape all_merits = [] all_energies = [] all_z_transls = [] # the following two lines are in accordance with the nabu collection layout for diagos # there are n_angles_pi in [0,180[, and then again the same number of possibly valid radios # (check for nan in z translation) in [180,360[, [360,540[, 540,720[ n_angles_pi = len(diag.radios) // 4 n_angles_2pi = len(diag.radios) // 2 # check for (theta, theta+180 )pairs whose first radio of the pair in in [0,180[ or [360,540[ for i in list(range(n_angles_pi)) + list(range(n_angles_2pi, n_angles_2pi + n_angles_pi)): merits = [] energies = [] z_transl = [] if (not np.isnan(diag.zpix_transl[i])) and (not np.isnan(diag.zpix_transl[i + n_angles_pi])): radio1 = diag.radios[i] radio2 = diag.radios[i + n_angles_pi] weight1 = diag.weights[i] weight2 = diag.weights[i + n_angles_pi] for overlap in overlap_list: if overlap <= ovsd_sx: my_overlap = overlap my_radio1 = radio1 my_radio2 = radio2 my_weight1 = weight1 my_weight2 = weight2 else: my_overlap = ovsd_sx - (overlap - ovsd_sx) my_radio1 = np.fliplr(radio1) my_radio2 = np.fliplr(radio2) my_weight1 = np.fliplr(weight1) my_weight2 = np.fliplr(weight2) radio_common_left = np.fliplr(my_radio1[:, ovsd_sx - my_overlap :])[ :, : -(args.ovs * args.high_pass * 2) ] radio_common_right = my_radio2[:, ovsd_sx - my_overlap : -(args.ovs * args.high_pass * 2)] diff_common = radio_common_right - radio_common_left weight_common_left = np.fliplr(my_weight1[:, ovsd_sx - my_overlap :])[ :, : -(args.ovs * args.high_pass * 2) ] weight_common_right = my_weight2[:, ovsd_sx - my_overlap : -(args.ovs * args.high_pass * 2)] weight_common = weight_common_right * weight_common_left if args.use_l1_norm: merits.append(abs(diff_common * weight_common).astype("d").sum()) energies.append(abs(weight_common).astype("d").sum()) else: merits.append((diff_common * diff_common * 
weight_common).astype("d").sum()) energies.append( ( (radio_common_left * radio_common_left + radio_common_right * radio_common_right) * weight_common ) .astype("d") .sum() ) else: merits = [0] * (len(overlap_list)) energies = [0] * (len(overlap_list)) z_transl = 0.5 * (diag.zpix_transl[i] + diag.zpix_transl[i + n_angles_pi]) all_z_transls.append(z_transl) all_merits.append(merits) all_energies.append(energies) return all_merits, all_energies, all_z_transls def find_best_interpolating_line(args): (all_z_transl, index_overlap_list_a, index_overlap_list_b, all_energies, all_res) = args z_a = np.nanmin(all_z_transl) z_b = np.nanmax(all_z_transl) best_error = np.nan best_off_pair = None for index_ovlp_a in index_overlap_list_a: for index_ovlp_b in index_overlap_list_b: index_ovlps = np.interp(all_z_transl, [z_a, z_b], [index_ovlp_a, index_ovlp_b]) indexes = (np.arange(all_energies.shape[0]))[~np.isnan(index_ovlps)].astype("i") index_ovlps = index_ovlps[~np.isnan(index_ovlps)] index_ovlps = np.round_(index_ovlps).astype("i") diff_enes = all_res[(indexes, index_ovlps)] orig_enes = all_energies[(indexes, index_ovlps)] error = (diff_enes / (orig_enes + 1.0e-30)).astype("d").sum() if not (error > best_error): best_error = error best_error_pair = index_ovlp_a, index_ovlp_b return best_error, best_error_pair # pylint: disable=E0606 def main(user_args=None): """Find the cor as a function f z translation and write an hdf5 which contains interpolable tables. This file can be used subsequently with the correct-rot utility. """ if user_args is None: user_args = sys.argv[1:] args = DictToObj( parse_params_values( DiagToRotConfig, parser_description=main.__doc__, program_version="nabu " + version, user_args=user_args, ) ) if args.near is None: if args.original_scan is None: raise ValueError( "the parameter near was not provided but the original_scan parameter was not provided either" ) if args.entry_name is None: raise ValueError( "the parameter near was not provided but the entry_name parameter for the original scan was not provided either" ) scan = NXtomo() scan.load(file_path=args.original_scan, data_path=args.entry_name) args.near = scan.instrument.detector.x_rotation_axis_pixel_position else: pass args.ovs = 4 diag_url = os.path.join("/", args.entry_name, "reconstruction/configuration/diagnostics") diag_keys = [] with h5py.File(args.diag_file, "r") as f: diag_keys = list(f[diag_url].keys()) diag_keys = [diag_keys[i] for i in np.argsort(list(map(int, diag_keys)))] all_merits = [] all_energies = [] all_z_transls = [] arguments_for_multiprocessing = [] for i_key, my_key in enumerate(diag_keys): diag = DictToObj(h5todict(args.diag_file, os.path.join(diag_url, my_key))) args.original_shape = diag.radios[0].shape args.zpix_mm = diag.pixel_size_mm transform_images(diag, args) if i_key == 0: orig_sy, ovsd_sx = diag.radios[0].shape # already transformed here, ovsd_sx is expanded args.ovsd_sx = ovsd_sx overlap_min = max(4, ovsd_sx - 2 * args.ovs * (args.near + args.near_width)) overlap_max = min(2 * ovsd_sx - 4, ovsd_sx - 2 * args.ovs * (args.near - args.near_width)) overlap_list = list(range(int(overlap_min), int(overlap_max) + 1)) if overlap_min > overlap_max: message = f""" There is no safe search range in find_cor once the margins corresponding to the high_pass filter are discarded. May be the near value (which is the offset respect to the center of the image) is too big, or too negative, in short too close to the borders. 
""" raise ValueError(message) arguments_for_multiprocessing.append((diag, overlap_list, args)) # pylint: disable=E0606 ncpus = get_available_threads() with Pool(processes=ncpus) as pool: result_list = pool.map(total_merit_list, arguments_for_multiprocessing) for merits, energies, z_transls in result_list: all_z_transls.extend(z_transls) all_merits.extend(merits) all_energies.append(energies) n_pairings_with_data = 0 for en, me in zip(all_merits, all_energies): if np.any(me): n_pairings_with_data += 1 if args.linear_interpolation: if n_pairings_with_data < 2: message = f""" The diagnostics collection has probably been run over a too thin section of the scan or you scan does not allow to form pairs of theta, theta+360. I only found {n_pairings_with_data} pairings and this is not enough to do correlation + interpolation between sections """ raise RuntimeError(message) else: if n_pairings_with_data < 1: message = f""" The diagnostics collection has probably been run over a too thin section of the scan or you scan does not allow to form pairs of theta, theta+360. I only found {n_pairings_with_data} pairings """ raise RuntimeError(message) # all_merits, all_energies, all_z_transls = zip( result_list ) # merits, energies, z_transls = total_merit_list(diag, overlap_list, args) # all_z_transls.extend(z_transls) # all_merits.extend(merits) # all_energies.append(energies) all_merits = np.array(all_merits) all_energies = np.array(all_energies) all_merits.shape = -1, len(overlap_list) all_energies.shape = -1, len(overlap_list) if args.linear_interpolation: do_linear_interpolation(args, overlap_list, all_merits, all_energies, all_z_transls) else: do_height_by_height(args, overlap_list, all_merits, all_energies, all_z_transls) return 0 def do_height_by_height(args, overlap_list, all_diff, all_energies, all_z_transl): # now we find the best cor for each chunk, or nan if no overlap is found z_a = np.min(all_z_transl) z_b = np.max(all_z_transl) grouped_diff = {} grouped_energy = {} for diff, energy, z in zip(all_diff, all_energies, all_z_transl): found = z for key in grouped_diff.keys(): if abs(key - z) < 2.0: # these are in pixel units found = key break grouped_diff[found] = grouped_diff.get(found, np.zeros([len(overlap_list)], "f")) + diff grouped_energy[found] = grouped_energy.get(found, np.zeros([len(overlap_list)], "f")) + energy z_list = list(grouped_energy.keys()) z_list.sort() cor_list = [] for z in z_list: diff = grouped_diff[z] energy = grouped_energy[z] best_error = np.nan best_off = None if not np.isnan(z): for i_ovlp in range(len(overlap_list)): error = diff[i_ovlp] / (energy[i_ovlp] + 1.0e-30) if not (error > best_error): best_error = error best_off = i_ovlp if best_off is not None: offset = (args.ovsd_sx - overlap_list[best_off]) / args.ovs / 2 sy, sx = args.original_shape cor_abs = (sx - 1) / 2 + offset cor_list.append(cor_abs) else: cor_list.append(np.nan) else: # no overlap was available for that z cor_list.append(np.nan) with h5py.File(args.cor_file, "w") as f: my_mask = ~np.isnan(np.array(cor_list)) f["cor"] = np.array(cor_list)[my_mask] f["z_pix"] = np.array(z_list)[my_mask] f["z_m"] = (np.array(z_list)[my_mask]) * args.zpix_mm * 1.0e-3 def do_linear_interpolation(args, overlap_list, all_res, all_energies, all_z_transl): # now we consider all the linear regressions of the offset with z_transl ncpus = get_available_threads() index_overlap_list = np.arange(len(overlap_list)).astype("i") arguments_list = [ (all_z_transl, piece, index_overlap_list, all_energies, all_res) for piece in 
np.array_split(index_overlap_list, ncpus) ] with Pool(processes=ncpus) as pool: result_list = pool.map(find_best_interpolating_line, arguments_list) error_list = [tok[0] for tok in result_list] best_pos = np.argmin(error_list) best_error, best_error_pair = result_list[best_pos] # find the interpolated line i_ovlp_a, i_ovlp_b = best_error_pair offset_a = (args.ovsd_sx - overlap_list[i_ovlp_a]) / args.ovs / 2 offset_b = (args.ovsd_sx - overlap_list[i_ovlp_b]) / args.ovs / 2 sy, sx = args.original_shape cor_abs_a = (sx - 1) / 2 + offset_a cor_abs_b = (sx - 1) / 2 + offset_b z_a = np.nanmin(all_z_transl) z_b = np.nanmax(all_z_transl) with h5py.File(args.cor_file, "w") as f: f["cor"] = np.array([cor_abs_a, cor_abs_b]) f["z_pix"] = np.array([z_a, z_b]) f["z_m"] = np.array([z_a, z_b]) * args.zpix_mm * 1.0e-3 ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1723556968.0 nabu-2024.2.1/nabu/app/double_flatfield.py0000644000175000017500000001245414656662150017651 0ustar00pierrepierreimport numpy as np from ..preproc.double_flatfield import DoubleFlatField from ..preproc.flatfield import FlatField from ..io.writer import NXProcessWriter from ..resources.dataset_analyzer import analyze_dataset from ..resources.nxflatfield import update_dataset_info_flats_darks from ..resources.logger import Logger, LoggerOrPrint from .cli_configs import DFFConfig from .utils import parse_params_values class DoubleFlatFieldChunks: def __init__( self, dataset_path, output_file, chunk_size=100, sigma=None, do_flatfield=True, h5_entry=None, logger=None ): self.logger = LoggerOrPrint(logger) self.dataset_info = analyze_dataset(dataset_path, extra_options={"hdf5_entry": h5_entry}, logger=logger) self.chunk_size = min(chunk_size, self.dataset_info.radio_dims[-1]) self.do_flatfield = bool(do_flatfield) if self.do_flatfield: update_dataset_info_flats_darks(self.dataset_info, flatfield_mode=True) self.output_file = output_file self.sigma = sigma if sigma is not None and abs(sigma) > 1e-5 else None def _get_config(self): conf = { "dataset": self.dataset_info.location, "entry": self.dataset_info.hdf5_entry or None, "dff_sigma": self.sigma, "do_flatfield": self.do_flatfield, } return conf def _read_projections(self, chunk_size, start_idx=0): reader_kwargs = {"sub_region": (slice(None), slice(start_idx, start_idx + chunk_size), slice(None))} if self.dataset_info.kind == "edf": reader_kwargs = {"n_reading_threads": 4} self.reader = self.dataset_info.get_reader(**reader_kwargs) self.projections = self.reader.load_data() def _init_flatfield(self, start_z=None, end_z=None): if not self.do_flatfield: return chunk_size = end_z - start_z if start_z is not None else self.chunk_size self.flatfield = FlatField( (self.dataset_info.n_angles, chunk_size, self.dataset_info.radio_dims[0]), flats={k: arr[start_z:end_z, :] for k, arr in self.dataset_info.flats.items()}, darks={k: arr[start_z:end_z, :] for k, arr in self.dataset_info.darks.items()}, radios_indices=sorted(self.dataset_info.projections.keys()), ) def _apply_flatfield(self, start_z=None, end_z=None): if self.do_flatfield: self._init_flatfield(start_z=start_z, end_z=end_z) self.flatfield.normalize_radios(self.projections) def _init_dff(self): self.double_flatfield = DoubleFlatField( self.projections.shape, input_is_mlog=False, output_is_mlog=False, average_is_on_log=self.sigma is not None, sigma_filter=self.sigma, ) def compute_double_flatfield(self): """ Compute the double flatfield for the current dataset. 
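        In short (descriptive note): the map returned here is, approximately, the per-pixel mean
        of all (optionally flat-fielded) projections,

            DFF(y, x) = (1 / n_angles) * sum_i P_i(y, x)

        computed chunk by chunk along the vertical axis. When `sigma` is set, the averaging is
        done on the logarithm and a sigma-filter is applied (see the DoubleFlatField arguments
        in _init_dff). The resulting image is typically used downstream to attenuate ring artefacts.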
""" n_z = self.dataset_info.radio_dims[-1] chunk_size = self.chunk_size n_steps = n_z // chunk_size extra_step = bool(n_z % chunk_size) res = np.zeros(self.dataset_info.radio_dims[::-1]) for i in range(n_steps): self.logger.debug("Computing DFF batch %d/%d" % (i + 1, n_steps + int(extra_step))) subregion = (None, None, i * chunk_size, (i + 1) * chunk_size) self._read_projections(chunk_size, start_idx=i * chunk_size) self._apply_flatfield(start_z=i * chunk_size, end_z=(i + 1) * chunk_size) self._init_dff() dff = self.double_flatfield.compute_double_flatfield(self.projections, recompute=True) res[subregion[-2] : subregion[-1]] = dff[:] # Need to initialize objects with a different shape if extra_step: curr_idx = (i + 1) * self.chunk_size self.logger.debug("Computing DFF batch %d/%d" % (i + 2, n_steps + int(extra_step))) self._read_projections(n_z - curr_idx, start_idx=curr_idx) self._apply_flatfield(start_z=(i + 1) * chunk_size, end_z=n_z) self._init_dff() dff = self.double_flatfield.compute_double_flatfield(self.projections, recompute=True) res[curr_idx:] = dff[:] return res def write_double_flatfield(self, arr): """ Write the double flatfield image to a file """ writer = NXProcessWriter( self.output_file, entry=self.dataset_info.hdf5_entry or "entry", filemode="a", overwrite=True, ) writer.write(arr, "double_flatfield", config=self._get_config()) self.logger.info("Wrote %s" % writer.fname) def dff_cli(): args = parse_params_values( DFFConfig, parser_description="A command-line utility for computing the double flatfield of a dataset." ) logger = Logger("nabu_double_flatfield", level=args["loglevel"], logfile="nabu_double_flatfield.log") output_file = args["output"] dff = DoubleFlatFieldChunks( args["dataset"], output_file, chunk_size=args["chunk_size"], sigma=args["sigma"], do_flatfield=bool(args["flatfield"]), h5_entry=args["entry"] or None, logger=logger, ) dff_image = dff.compute_double_flatfield() dff.write_double_flatfield(dff_image) return 0 if __name__ == "__main__": dff_cli() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/app/generate_header.py0000644000175000017500000002132714550227307017460 0ustar00pierrepierreimport os import numpy as np from tomoscan.io import HDF5File from fabio.edfimage import EdfImage from ..io.utils import get_first_hdf5_entry from .utils import parse_params_values from .cli_configs import GenerateInfoConfig edf_header_hdf5_path = { # EDF Header "count_time": "/technique/scan/exposure_time", # ms in HDF5 "date": "/start_time", "energy": "/technique/scan/energy", "flip": "/technique/detector/flipping", # [x, y] in HDF5 "motors": "/instrument/positioners", "optic_used": "/technique/optic/magnification", } info_hdf5_path = { "Energy": "/technique/scan/energy", "Distance": "/technique/scan/sample_detector_distance", # mm in HDF5 and EDF # ~ "Prefix": TODO "Directory": "/technique/saving/path", "ScanRange": "/technique/scan/scan_range", "TOMO_N": "/technique/scan/tomo_n", "REF_ON": "/technique/scan/ref_on", "REF_N": "/technique/reference/ref_n", "DARK_N": "/technique/dark/dark_n", "Dim_1": "/technique/detector/size", # array "Dim_2": "/technique/detector/size", # array "Count_time": "/technique/scan/exposure_time", # ms in HDF5, s in EDF ? "Shutter_time": "/technique/scan/shutter_time", # ms in HDF5, s in EDF ? 
"Optic_used": "/technique/optic/magnification", "PixelSize": "/technique/detector/pixel_size", # array (microns) "Date": "/start_time", "Scan_Type": "/technique/scan/scan_type", "SrCurrent": "/instrument/machine/current", "Comment": "/technique/scan/comment", # empty ? } def decode_bytes(content): if isinstance(content, bytes): return content.decode() else: return content def simulate_edf_header(fname, entry, return_dict=False): edf_header = { "ByteOrder": "LowByteFirst", } with HDF5File(fname, "r") as fid: for edf_name, h5_path in edf_header_hdf5_path.items(): if edf_name == "motors": continue h5_path = entry + h5_path edf_header[edf_name] = decode_bytes(fid[h5_path][()]) h5_motors = decode_bytes(fid[entry + edf_header_hdf5_path["motors"]]) edf_header["motor_mne"] = list(h5_motors.keys()) edf_header["motor_pos"] = [v[()] for v in h5_motors.values()] # remove "scan_numbers" from motors try: idx = edf_header["motor_mne"].index("scan_numbers") edf_header["motor_mne"].pop(idx) edf_header["motor_pos"].pop(idx) except ValueError: pass # remove invalid values from motor_pos invalid_values = ["*DIS*"] try: for invalid_val in invalid_values: idx = edf_header["motor_pos"].index(invalid_val) edf_header["motor_mne"].pop(idx) edf_header["motor_pos"].pop(idx) except ValueError: pass # Format edf_header["flip"] = "" % (edf_header["flip"][0], edf_header["flip"][1]) edf_header["motor_mne"] = " ".join(edf_header["motor_mne"]) edf_header["motor_pos"] = " ".join(list(map(str, edf_header["motor_pos"]))) edf_header["count_time"] = edf_header["count_time"] * 1e-3 # HDF5: ms -> EDF: s if return_dict: return edf_header res = "" for k, v in edf_header.items(): res = res + "%s = %s ;\n" % (k, v) return res def simulate_info_file(fname, entry, return_dict=False): info_file_content = {} with HDF5File(fname, "r") as fid: for info_name, h5_path in info_hdf5_path.items(): h5_path = entry + h5_path info_file_content[info_name] = decode_bytes(fid[h5_path][()]) info_file_content["Dim_1"] = info_file_content["Dim_1"][0] info_file_content["Dim_2"] = info_file_content["Dim_2"][1] info_file_content["PixelSize"] = info_file_content["PixelSize"][0] info_file_content["Prefix"] = os.path.basename(fname) info_file_content["Col_end"] = info_file_content["Dim_1"] - 1 info_file_content["Col_beg"] = 0 info_file_content["Row_end"] = info_file_content["Dim_2"] - 1 info_file_content["Row_beg"] = 0 for what in ["Count_time", "Shutter_time"]: info_file_content[what] = info_file_content[what] * 1e-3 if return_dict: return info_file_content # Format res = "" for k, v in info_file_content.items(): k_s = str("%s=" % k) sep = " " * (24 - len(k_s)) res = res + k_s + sep + str("%s\n" % v) return res def get_hst_saturations(hist, bins, numels): aMin = bins[0] aMax = bins[-1] hist_sum = numels * 1.0 hist_cum = np.cumsum(hist) hist_cum_rev = np.cumsum(hist[::-1]) i_s1 = np.where(hist_cum > 0.00001 * hist_sum)[0][0] sat1 = aMin + i_s1 * (aMax - aMin) / (hist.size - 1) i_S1 = np.where(hist_cum > 0.002 * hist_sum)[0][0] Sat1 = aMin + i_S1 * (aMax - aMin) / (hist.size - 1) i_s2 = np.argwhere(hist_cum_rev > 0.00001 * hist_sum)[0][0] sat2 = aMin + (hist.size - 1 - i_s2) * (aMax - aMin) / (hist.size - 1) i_S2 = np.argwhere(hist_cum_rev > 0.002 * hist_sum)[0][0] Sat2 = aMin + (hist.size - 1 - i_S2) * (aMax - aMin) / (hist.size - 1) return sat1, sat2, Sat1, Sat2 def simulate_hst_vol_header(fname, entry=None, return_dict=False): with HDF5File(fname, "r") as fid: try: histogram_path = entry + "/histogram/results/data" histogram_config_path = entry + 
"/histogram/configuration" Nz, Ny, Nx = fid[histogram_config_path]["volume_shape"][()] vol_header_content = { "NUM_X": Nx, "NUM_Y": Ny, "NUM_Z": Nz, "BYTEORDER": "LOWBYTEFIRST", } hist = fid[histogram_path][()] except KeyError as err: print("Could not load histogram from %s: %s" % (fname, err)) return {} if return_dict else "" bins = hist[1] hist = hist[0] vmin = bins[0] vmax = bins[-1] + (bins[-1] - bins[-2]) s1, s2, S1, S2 = get_hst_saturations(hist, bins, Nx * Ny * Nz) vmin = bins[0] vmax = bins[-1] + (bins[-1] - bins[-2]) vol_header_content.update( { "ValMin": vmin, "ValMax": vmax, "s1": s1, "s2": s2, "S1": S1, "S2": S2, } ) if return_dict: return vol_header_content res = "" for k, v in vol_header_content.items(): res = res + "%s = %s\n" % (k, v) return res def format_as_info(d): res = "" for k, v in d.items(): k_s = str("%s=" % k) sep = " " * (24 - len(k_s)) res = res + k_s + sep + str("%s\n" % v) return res def generate_merged_info_file_content( hist_fname=None, hist_entry=None, bliss_fname=None, bliss_entry=None, first_edf_proj=None, info_file=None ): # EDF Header if first_edf_proj is None: edf_header = simulate_edf_header(bliss_fname, bliss_entry, return_dict=True) else: edf = EdfImage() edf.open(first_edf_proj) edf_header = edf.getHeader() # .info File if info_file is None: info_file_content = simulate_info_file(bliss_fname, bliss_entry, return_dict=True) info_file_content = format_as_info(info_file_content) else: with open(info_file, "r") as f: info_file_content = f.read() # .vol File vol_file_content = simulate_hst_vol_header(hist_fname, entry=hist_entry, return_dict=True) # res = format_as_info(edf_header) res += info_file_content res += format_as_info(vol_file_content) return res def str_or_none(s): if len(s.strip()) == 0: return None return s def generate_merged_info_file(): args = parse_params_values(GenerateInfoConfig, parser_description="Generate a .info file") hist_fname = str_or_none(args["hist_file"]) hist_entry = str_or_none(args["hist_entry"]) bliss_fname = str_or_none(args["bliss_file"]) bliss_entry = str_or_none(args["bliss_entry"]) first_edf_proj = str_or_none(args["edf_proj"]) info_file = str_or_none(args["info_file"]) if hist_fname is not None and hist_entry is None: hist_entry = get_first_hdf5_entry(hist_fname) if bliss_fname is not None and bliss_entry is None: bliss_entry = get_first_hdf5_entry(bliss_fname) if not ((bliss_fname is None) ^ (info_file is None)): print("Error: please provide either --bliss_file or --info_file") exit(1) if info_file is not None and first_edf_proj is None: print("Error: please provide also --edf_proj when using the EDF format") exit(1) content = generate_merged_info_file_content( hist_fname=hist_fname, hist_entry=hist_entry, bliss_fname=bliss_fname, bliss_entry=bliss_entry, first_edf_proj=first_edf_proj, info_file=info_file, ) with open(args["output"], "w") as f: f.write(content) return 0 ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1706619687.0 nabu-2024.2.1/nabu/app/histogram.py0000644000175000017500000001737214556171447016371 0ustar00pierrepierrefrom os import path import posixpath from silx.io.url import DataUrl from silx.io.dictdump import h5todict from ..utils import check_supported from ..io.utils import get_first_hdf5_entry, get_h5_value from ..io.writer import NXProcessWriter from ..processing.histogram import PartialHistogram, VolumeHistogram, hist_as_2Darray from ..processing.histogram_cuda import CudaVolumeHistogram from ..resources.logger import Logger, LoggerOrPrint from .utils 
import parse_params_values from .cli_configs import HistogramConfig class VolumesHistogram: """ A class for extracting or computing histograms of one or several volumes. """ available_backends = { "numpy": VolumeHistogram, "cuda": CudaVolumeHistogram, } def __init__( self, fnames, output_file, chunk_size_slices=100, chunk_size_GB=None, nbins=1e6, logger=None, backend="cuda" ): """ Initialize a VolumesHistogram object. Parameters ----------- fnames: list of str List of paths to HDF5 files. To specify an entry for each file name, use the "?" separator: /path/to/file.h5?entry0001 output_file: str Path to the output file write_histogram_if_computed: bool, optional Whether to write histograms that are computed to a file. Some volumes might be missing their histogram. In this case, the histogram is computed, and the result is written to a dedicated file in the same directory as 'output_file'. Default is True. """ self._get_files_and_entries(fnames) self.chunk_size_slices = chunk_size_slices self.chunk_size_GB = chunk_size_GB self.nbins = nbins self.logger = LoggerOrPrint(logger) self.output_file = output_file self._get_histogrammer_backend(backend) def _get_files_and_entries(self, fnames): res_fnames = [] res_entries = [] for fname in fnames: if "?" not in fname: entry = None else: fname, entry = fname.split("?") if entry == "": entry = None res_fnames.append(fname) res_entries.append(entry) self.fnames = res_fnames self.entries = res_entries def _get_histogrammer_backend(self, backend): check_supported(backend, self.available_backends.keys(), "histogram backend") self.VolumeHistogramClass = self.available_backends[backend] def _get_config_onevolume(self, fname, entry, data_shape): return { "chunk_size_slices": self.chunk_size_slices, "chunk_size_GB": self.chunk_size_GB, "bins": self.nbins, "filename": fname, "entry": entry, "volume_shape": data_shape, } def _get_config(self): conf = self._get_config_onevolume("", "", None) conf.pop("filename") conf.pop("entry") conf["filenames"] = self.fnames conf["entries"] = [entry if entry is not None else "None" for entry in self.entries] return conf def _write_histogram_onevolume(self, fname, entry, histogram, data_shape): output_file = ( path.join(path.dirname(self.output_file), path.splitext(path.basename(fname))[0]) + "_histogram" + path.splitext(fname)[1] ) self.logger.info("Writing histogram of %s into %s" % (fname, output_file)) writer = NXProcessWriter(output_file, entry, filemode="w", overwrite=True) writer.write( hist_as_2Darray(histogram), "histogram", config=self._get_config_onevolume(fname, entry, data_shape) ) def get_histogram_single_volume(self, fname, entry, write_histogram_if_computed=True, return_config=False): entry = entry or get_first_hdf5_entry(fname) hist_path = posixpath.join(entry, "histogram", "results", "data") hist_cfg_path = posixpath.join(entry, "histogram", "configuration") rec_path = posixpath.join(entry, "reconstruction", "results", "data") rec_url = DataUrl(file_path=fname, data_path=rec_path) hist = get_h5_value(fname, hist_path) config = None if hist is None: self.logger.info("No histogram found in %s, computing it" % fname) vol_histogrammer = self.VolumeHistogramClass( rec_url, chunk_size_slices=self.chunk_size_slices, chunk_size_GB=self.chunk_size_GB, nbins=self.nbins, logger=self.logger, ) hist = vol_histogrammer.compute_volume_histogram() if write_histogram_if_computed: self._write_histogram_onevolume(fname, entry, hist, vol_histogrammer.data_shape) else: if return_config: raise ValueError( "return_config must be set 
to True to get configuration for non-existing histograms" ) hist = hist_as_2Darray(hist) config = h5todict(path.splitext(fname)[0] + "_histogram" + path.splitext(fname)[1], path=hist_cfg_path) if return_config: return hist, config else: return hist def get_histogram(self, return_config=False): histograms = [] configs = [] for fname, entry in zip(self.fnames, self.entries): self.logger.info("Getting histogram for %s" % fname) hist, conf = self.get_histogram_single_volume(fname, entry, return_config=True) histograms.append(hist) configs.append(conf) self.logger.info("Merging histograms") histogrammer = PartialHistogram(method="fixed_bins_number", num_bins=self.nbins) hist = histogrammer.merge_histograms(histograms, dont_truncate_bins=True) if return_config: return hist, configs else: return hist def merge_histograms_configurations(self, configs): if configs is None or len(configs) == 0: return res_config = {"volume_shape": list(configs[0]["volume_shape"])} res_config["volume_shape"][0] = 0 for conf in configs: nz, ny, nx = conf["volume_shape"] res_config["volume_shape"][0] += nz res_config["volume_shape"] = tuple(res_config["volume_shape"]) return res_config def write_histogram(self, hist, config=None): self.logger.info("Writing final histogram to %s" % (self.output_file)) config = config or {} base_config = self._get_config() base_config.pop("volume_shape") config.update(base_config) writer = NXProcessWriter(self.output_file, "entry0000", filemode="w", overwrite=True) writer.write(hist_as_2Darray(hist), "histogram", config=config) def histogram_cli(): args = parse_params_values(HistogramConfig, parser_description="Extract/compute histogram of volume(s).") logger = Logger("nabu_histogram", level=args["loglevel"], logfile="nabu_histogram.log") output = args["output_file"].split("?")[0] if path.exists(output): logger.fatal("Output file %s already exists, not overwriting it" % output) exit(1) chunk_size_gb = float(args["chunk_size_GB"]) if chunk_size_gb <= 0: chunk_size_gb = None histogramer = VolumesHistogram( args["h5_file"], output, chunk_size_slices=int(args["chunk_size_slices"]), chunk_size_GB=chunk_size_gb, nbins=int(args["bins"]), logger=logger, ) hist, configs = histogramer.get_histogram(return_config=True) config = histogramer.merge_histograms_configurations(configs) histogramer.write_histogram(hist, config=config) return 0 if __name__ == "__main__": histogram_cli() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1711446150.0 nabu-2024.2.1/nabu/app/multicor.py0000644000175000017500000000772714600514206016215 0ustar00pierrepierrefrom os import remove import numpy as np from .. import version from .reconstruct import get_reconstructor from .cli_configs import MultiCorConfig from .utils import parse_params_values from ..utils import view_as_images_stack def get_user_cors(cors): """ From a user-provided str describing the centers of rotation, build a list. 
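    Examples (illustrative):
        "300.5"              -> [300.5]
        "300.5, 301:302:0.5" -> [300.5, 301.0, 301.5]
    Single values are kept as floats; start:stop:step ranges are expanded with
    numpy.arange (the stop value is excluded).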
""" cors = cors.strip("[()]") cors = cors.split(",") cors = [c.strip() for c in cors] cors_list = [] for c in cors: if ":" in c: if c.count(":") != 2: raise ValueError("Malformed range format for '%s': expected format start:stop:step" % c) start, stop, step = c.split(":") c_list = np.arange(float(start), float(stop), float(step)).tolist() else: c_list = [float(c)] cors_list.extend(c_list) return cors_list def main(): args = parse_params_values( MultiCorConfig, parser_description=f"Perform a tomographic reconstruction of a single slice using multiple centers of rotation", program_version="nabu " + version, ) reconstructor = get_reconstructor( args, # Put a dummy CoR to avoid crash in both full-FoV and extended-FoV. # It will be overwritten later by the user-defined CoRs overwrite_options={"reconstruction/rotation_axis_position": 10.0}, ) if reconstructor.delta_z > 1: raise ValueError("Only slice reconstruction can be used (have delta_z = %d)" % reconstructor.delta_z) reconstructor.reconstruct() # warm-up, spawn pipeline pipeline = reconstructor.pipeline file_prefix = pipeline.processing_options["save"]["file_prefix"] ##### # Remove the first reconstructed file (not used here) last_file = list(pipeline.writer.writer.browse_data_files())[-1] try: remove(last_file) except: pass ###### cors = get_user_cors(args["cor"]) all_recs = [] rec_instance = pipeline.reconstruction for cor in cors: # Re-configure with new CoR pipeline.processing_options["reconstruction"]["rotation_axis_position"] = cor pipeline.processing_options["save"]["file_prefix"] = file_prefix + "_%.03f" % cor pipeline._init_writer(create_subfolder=False, single_output_file_initialized=False) # Get sinogram into contiguous array # TODO Can't do memcpy2D ?! It used to work in cuda 11. # For now: transfer to host... 
not optimal sino = pipeline._d_radios[:, pipeline._d_radios.shape[1] // 2, :].get() # pylint: disable=E1136 if pipeline.process_config.do_halftomo: # re-initialize FBP object, because in half-tomography the output slice size is a function of CoR options = pipeline.processing_options["reconstruction"] rec_instance = pipeline.FBPClass( sino.shape, angles=options["angles"], rot_center=cor, filter_name=options["fbp_filter_type"] or "none", halftomo=options["enable_halftomo"], # slice_roi=self.process_config.rec_roi, padding_mode=options["padding_type"], extra_options={ "scale_factor": 1.0 / options["voxel_size_cm"][0], "axis_correction": options["axis_correction"], "centered_axis": options["centered_axis"], "clip_outer_circle": options["clip_outer_circle"], "filter_cutoff": options["fbp_filter_cutoff"], }, ) else: pipeline.reconstruction.reset_rot_center(cor) # Run reconstruction rec = rec_instance.fbp(sino) # if return_all_recs: # all_recs.append(rec) rec_3D = view_as_images_stack(rec) # writer wants 3D data # Write pipeline.writer.write_data(rec_3D) reconstructor.logger.info("Wrote %s" % pipeline.writer.fname) return 0 if __name__ == "__main__": main() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682665866.0 nabu-2024.2.1/nabu/app/nx_z_splitter.py0000644000175000017500000001156014422670612017257 0ustar00pierrepierreimport warnings from shutil import copy as copy_file from os import path from h5py import VirtualSource, VirtualLayout from tomoscan.io import HDF5File from ..resources.logger import Logger, LoggerOrPrint from ..io.utils import get_first_hdf5_entry from .cli_configs import ZSplitConfig from .utils import parse_params_values warnings.warn( "This command-line utility is intended as a temporary solution. Please do not rely too much on it.", Warning ) def _get_z_translations(fname, entry): z_path = path.join(entry, "sample", "z_translation") with HDF5File(fname, "r") as fid: z_transl = fid[z_path][:] return z_transl class NXZSplitter: def __init__(self, fname, output_dir, n_stages=None, entry=None, logger=None, use_virtual_dataset=False): self.fname = fname self._ext = path.splitext(fname)[-1] self.output_dir = output_dir self.n_stages = n_stages if entry is None: entry = get_first_hdf5_entry(fname) self.entry = entry self.logger = LoggerOrPrint(logger) self.use_virtual_dataset = use_virtual_dataset def _patch_nx_file(self, fname, mask): orig_fname = self.fname detector_path = path.join(self.entry, "instrument", "detector") sample_path = path.join(self.entry, "sample") with HDF5File(fname, "a") as fid: def patch_nx_entry(name): newval = fid[name][mask] del fid[name] fid[name] = newval detector_entries = [ path.join(detector_path, what) for what in ["count_time", "image_key", "image_key_control"] ] sample_entries = [ path.join(sample_path, what) for what in ["rotation_angle", "x_translation", "y_translation", "z_translation"] ] for what in detector_entries + sample_entries: self.logger.debug("Patching %s" % what) patch_nx_entry(what) # Patch "data" using a virtual dataset self.logger.debug("Patching data") data_path = path.join(detector_path, "data") if self.use_virtual_dataset: data_shape = fid[data_path].shape data_dtype = fid[data_path].dtype new_data_shape = (int(mask.sum()),) + data_shape[1:] vlayout = VirtualLayout(shape=new_data_shape, dtype=data_dtype) vsource = VirtualSource(orig_fname, name=data_path, shape=data_shape, dtype=data_dtype) vlayout[:] = vsource[mask, :, :] del fid[data_path] fid[detector_path].create_virtual_dataset("data", 
vlayout) if not (self.use_virtual_dataset): data_path = path.join(self.entry, "instrument", "detector", "data") with HDF5File(orig_fname, "r") as fid: data_arr = fid[data_path][mask, :, :] # Actually load data. Heavy ! with HDF5File(fname, "a") as fid: del fid[data_path] fid[data_path] = data_arr def z_split(self): """ Split a HDF5-NX file according to different z_translation. """ z_transl = _get_z_translations(self.fname, self.entry) different_z = set(z_transl) n_z = len(different_z) self.logger.info("Detected %d different z values: %s" % (n_z, str(different_z))) if n_z <= 1: raise ValueError("Detected only %d z-value. Stopping." % n_z) if self.n_stages is not None and self.n_stages != n_z: raise ValueError("Expected %d different stages, but I detected %d" % (self.n_stages, n_z)) masks = [(z_transl == z) for z in different_z] for i_z, mask in enumerate(masks): fname_curr_z = path.join( self.output_dir, path.splitext(path.basename(self.fname))[0] + str("_%06d" % i_z) + self._ext ) self.logger.info("Creating %s" % fname_curr_z) copy_file(self.fname, fname_curr_z) self._patch_nx_file(fname_curr_z, mask) def zsplit(): # Parse arguments args = parse_params_values( ZSplitConfig, parser_description="Split a HDF5-Nexus file according to z translation (z-series)" ) # Sanitize arguments fname = args["input_file"] output_dir = args["output_directory"] loglevel = args["loglevel"].upper() entry = args["entry"] if len(entry) == 0: entry = None n_stages = args["n_stages"] if n_stages < 0: n_stages = None use_virtual_dataset = bool(args["use_virtual_dataset"]) # Instantiate and execute logger = Logger("NX_z-splitter", level=loglevel, logfile="nxzsplit.log") nx_splitter = NXZSplitter( fname, output_dir, n_stages=n_stages, entry=entry, logger=logger, use_virtual_dataset=use_virtual_dataset ) nx_splitter.z_split() return 0 if __name__ == "__main__": zsplit() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/app/parse_reconstruction_log.py0000644000175000017500000001117414654107202021465 0ustar00pierrepierreimport numpy as np from os import path from datetime import datetime from ..utils import check_supported, convert_str_to_tuple from .utils import parse_params_values from .cli_configs import ShowReconstructionTimingsConfig try: import matplotlib.pyplot as plt __have_matplotlib__ = True except ImportError: __have_matplotlib__ = False steps_to_measure = [ "Reading data", "Applying flat-field", "Applying double flat-field", "Applying CCD corrections", "Rotating projections", "Performing phase retrieval", "Performing unsharp mask", "Taking logarithm", "Applying radios movements", "Normalizing sinograms", "Building sinograms", # deprecated "Removing rings on sinograms", "Reconstruction", "Computing histogram", "Saving data", ] def extract_timings_from_volume_reconstruction_lines(lines, separator=" - "): def extract_timestamp(line): timestamp = line.split(separator)[0] return datetime.strptime(timestamp, "%d-%m-%Y %H:%M:%S") def extract_current_step(line): return line.split(separator)[-1] current_step = extract_current_step(lines[0]) t1 = extract_timestamp(lines[0]) res = {} for line in lines[1:]: line = line.strip() if len(line.split(separator)) == 1: continue timestamp = line.strip().split(separator)[0] t2 = datetime.strptime(timestamp, "%d-%m-%Y %H:%M:%S") res.setdefault(current_step, []) res[current_step].append((t2 - t1).seconds) t1 = t2 current_step = extract_current_step(line) return res def parse_logfile(fname, separator=" - "): """ Returns 
------- timings: list of dict List of dictionaries: one dict per reconstruction in the log file. For each dict, the key is the pipeline step name, and the value is the list of timings for the different chunks. """ with open(fname, "r") as f: lines = f.readlines() start_text = "Going to reconstruct slices" end_text = "Merging reconstructions to" start_line = None rec_log_bounds = [] for i, line in enumerate(lines): if start_text in line: start_line = i if end_text in line: if start_line is None: raise ValueError("Could not find reconstruction start string indicator") rec_log_bounds.append((start_line, i)) rec_file_basename = path.basename(line.split(end_text)[-1]) results = [] for bounds in rec_log_bounds: start, end = bounds timings = {} res = extract_timings_from_volume_reconstruction_lines(lines[start:end], separator=separator) for step in steps_to_measure: if step in res: timings[step] = res[step] results.append(timings) return results def display_timings_pie(timings, reduce_function=None, cutoffs=None): reduce_function = reduce_function or np.median cutoffs = cutoffs or (0, np.inf) def _format_pie_text(pct, allvals): # https://matplotlib.org/stable/gallery/pie_and_polar_charts/pie_and_donut_labels.html absolute = int(np.round(pct / 100.0 * np.sum(allvals))) return f"{pct:.1f}%\n({absolute:d} s)" for run in timings: fig = plt.figure() pie_labels = [] pie_sizes = [] for step_name, step_timings in run.items(): t = reduce_function(step_timings) if t > cutoffs[0] and t < cutoffs[1]: # pie_labels.append(step_name) pie_labels.append(step_name + "\n(%d s)" % t) pie_sizes.append(t) ax = fig.subplots() # ax.pie(pie_sizes, labels=pie_labels, autopct=lambda pct: _format_pie_text(pct, pie_sizes)) # autopct='%1.1f%%') ax.pie(pie_sizes, labels=pie_labels, autopct="%1.1f%%") fig.show() input("Press any key to continue") def parse_reclog_cli(): args = parse_params_values( ShowReconstructionTimingsConfig, parser_description="Display reconstruction performances from a log file" ) if not (__have_matplotlib__): print("Need matplotlib to use this utility") exit(1) display_functions = { "pie": display_timings_pie, } logfile = args["logfile"] cutoff = args["cutoff"] display_type = args["type"] check_supported(display_type, display_functions.keys(), "Graphics display type") if cutoff is not None: cutoff = list(map(float, convert_str_to_tuple(cutoff))) timings = parse_logfile(logfile) display_functions[display_type](timings, cutoffs=cutoff) return 0 if __name__ == "__main__": parse_reclog_cli() exit(0) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1706619687.0 nabu-2024.2.1/nabu/app/prepare_weights_double.py0000644000175000017500000001256714556171447021117 0ustar00pierrepierreimport h5py import numpy as np from scipy.special import erf # pylint: disable=all import sys import os from scipy.ndimage import gaussian_filter from nxtomo.nxobject.nxdetector import ImageKey from nabu.resources.nxflatfield import update_dataset_info_flats_darks from nabu.resources.dataset_analyzer import HDF5DatasetAnalyzer from ..io.reader import load_images_from_dataurl_dict def main(argv=None): """auxiliary program that can be called to create default input detector profiles, for nabu helical, concerning the weights of the pixels and the "double flat" renormalisation denominator. The result is an hdf5 file that can be used as a "processes_file" in the nabu configuration and is used by nabu-helical. 
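For example, a hypothetical session (the scan file, entry and output names below are placeholders, not part of this distribution)::

    nabu-helical-prepare-weights-double my_scan.nx entry0000 my_maps.h5

The generated my_maps.h5 is then referenced through the "processes_file" option of the nabu-helical configuration file.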
In particular cases the user may have fancy masks and correction maps and will provide their own processes file, and will not need this. This code, and in particular the auxiliary function below (which tomwer can also use), provides a default construction of such maps. The double-flat is set to one and the weight is built on the basis of the flat fields from the dataset, with an apodisation on the borders which eliminates discontinuities in the contributions from the borders: above and below for the z-translations, and on the left or right border for half-tomo. The usage is :: nabu-helical-prepare-weights-double nexus_file_name entry_name The resulting file can then be used as the processes file in the nabu-helical configuration file """ if argv is None: argv = sys.argv[1:] if len(argv) not in [2, 3, 4, 5, 6]: message = f""" Usage: nabu-helical-prepare-weights-double nexus_file_name entry_name [target_file_name [transition_width_vertical [rotation_axis_position [transition_width_horizontal]]]] """ print(message) sys.exit(1) file_name = argv[0] if len(os.path.dirname(file_name)) == 0: # To make sure that other utility routines can successfully deal with paths within the current directory file_name = os.path.join(".", file_name) # There were still some problems with relative paths and how they are dealt with in nxtomomill # Better to use an absolute path file_name = os.path.abspath(file_name) entry_name = argv[1] process_file_name = "double.h5" dataset_info = HDF5DatasetAnalyzer(file_name, extra_options={"h5_entry": entry_name}) update_dataset_info_flats_darks(dataset_info, flatfield_mode=1) beam_profile = 0 my_flats = load_images_from_dataurl_dict(dataset_info.flats) for key, flat in my_flats.items(): beam_profile += flat beam_profile = beam_profile / len(list(dataset_info.flats.keys())) transition_width_vertical = 50.0 # The following two lines determine the horizontal transition # By default a transition on the right (corresponds to an axis close to the right border) rotation_axis_position = beam_profile.shape[1] - 200 transition_width_horizontal = 100.0 if len(argv) in [3, 4, 5, 6]: process_file_name = argv[2] if len(argv) in [4, 5, 6]: transition_width_vertical = float(argv[3]) if len(argv) in [5, 6]: rotation_axis_position = (beam_profile.shape[1] - 1) / 2 + float(argv[4]) if len(argv) in [6]: transition_width_horizontal = 2 * (float(argv[5])) create_heli_maps( beam_profile, process_file_name, entry_name, transition_width_vertical, rotation_axis_position, transition_width_horizontal, ) # here we have been called by the cli.
The return value 0 means OK return 0 def create_heli_maps( profile, process_file_name, entry_name, transition_width_vertical, rotation_axis_position, transition_width_horizontal, ): profile = profile / profile.max() profile = profile.astype("f") profile = gaussian_filter(profile, 10) def compute_border_v(L, m, w): x = np.arange(L) d = (x - L + m).astype("f") res_r = (1 - erf(d / w)) / 2 d = (x - m).astype("f") res_l = (1 + erf(d / w)) / 2 return res_r * res_l def compute_border_h(L, r, tw): if r > (L - 1) / 2: if tw > (L - r): tw = max(1.0, L - r) m = tw / 2 w = tw / 5 x = np.arange(L) d = (x - L + m).astype("f") res_r = (1 - erf(d / w)) / 2 return res_r else: if tw > r: tw = max(1.0, r) m = tw / 2 w = tw / 5 x = np.arange(L) d = (x - m).astype("f") res_l = (1 + erf(d / w)) / 2 return res_l with h5py.File(process_file_name, mode="a") as fd: path_weights = entry_name + "/weights_field/results/data" path_double = entry_name + "/double_flatfield/results/data" if path_weights in fd: del fd[path_weights] if path_double in fd: del fd[path_double] border = compute_border_h(profile.shape[1], rotation_axis_position, transition_width_horizontal) border_v = compute_border_v( profile.shape[0], round(transition_width_vertical / 2), transition_width_vertical / 5 ) fd[path_weights] = (profile * border) * border_v[:, None] fd[path_double] = np.ones_like(profile) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734019212.0 nabu-2024.2.1/nabu/app/reconstruct.py0000644000175000017500000001324014726604214016725 0ustar00pierrepierrefrom tomoscan.io import HDF5File from .. import version from ..utils import list_match_queries from ..pipeline.config import parse_nabu_config_file from ..pipeline.config_validators import convert_to_int from .cli_configs import ReconstructConfig from .utils import parse_params_values def update_reconstruction_start_end(conf_dict, user_indices): if len(user_indices) == 0: return rec_cfg = conf_dict["reconstruction"] err = None val_int, conv_err = convert_to_int(user_indices) if conv_err is None: start_z = user_indices end_z = user_indices else: if user_indices in ["first", "middle", "last"]: start_z = user_indices end_z = user_indices elif user_indices == "all": start_z = 0 end_z = -1 elif "-" in user_indices: try: start_z, end_z = user_indices.split("-") start_z = int(start_z) end_z = int(end_z) except Exception as exc: err = "Could not interpret slice indices '%s': %s" % (user_indices, str(exc)) else: err = "Could not interpret slice indices: %s" % user_indices if err is not None: print(err) exit(1) rec_cfg["start_z"] = start_z rec_cfg["end_z"] = end_z def get_log_file(arg_logfile, legacy_arg_logfile, forbidden=None): default_arg_val = "" # Compat. 
log_file --> logfile if legacy_arg_logfile != default_arg_val: logfile = legacy_arg_logfile else: logfile = arg_logfile # if forbidden is None: forbidden = [] for forbidden_val in forbidden: if logfile == forbidden_val: print("Error: --logfile argument cannot have the value %s" % forbidden_val) exit(1) if logfile == "": logfile = True return logfile def get_reconstructor(args, overwrite_options=None): # Imports are done here, otherwise "nabu --version" takes forever from ..pipeline.fullfield.processconfig import ProcessConfig from ..pipeline.fullfield.reconstruction import FullFieldReconstructor # logfile = get_log_file(args["logfile"], args["log_file"], forbidden=[args["input_file"]]) conf_dict = parse_nabu_config_file(args["input_file"]) update_reconstruction_start_end(conf_dict, args["slice"].strip()) if overwrite_options is not None: for option_key, option_val in overwrite_options.items(): opt_section, opt_name = option_key.split("/") conf_dict[opt_section][opt_name] = option_val proc = ProcessConfig(conf_dict=conf_dict, create_logger=logfile) logger = proc.logger logger.info("Going to reconstruct slices (%d, %d)" % (proc.rec_region["start_z"], proc.rec_region["end_z"])) # Get extra options extra_options = { "gpu_mem_fraction": args["gpu_mem_fraction"], "cpu_mem_fraction": args["cpu_mem_fraction"], "chunk_size": args["max_chunk_size"] if args["max_chunk_size"] > 0 else None, "margin": args["phase_margin"], "force_grouped_mode": bool(args["force_use_grouped_pipeline"]), } reconstructor = FullFieldReconstructor(proc, logger=logger, extra_options=extra_options) return reconstructor def list_hdf5_entries(fname): with HDF5File(fname, "r") as f: entries = list(f.keys()) return entries def main(): args = parse_params_values( ReconstructConfig, parser_description=f"Perform a tomographic reconstruction.", program_version="nabu " + version, ) # Get extra options extra_options = { "gpu_mem_fraction": args["gpu_mem_fraction"], "cpu_mem_fraction": args["cpu_mem_fraction"], "chunk_size": args["max_chunk_size"] if args["max_chunk_size"] > 0 else None, "margin": args["phase_margin"], "force_grouped_mode": bool(args["force_use_grouped_pipeline"]), } # logfile = get_log_file(args["logfile"], args["log_file"], forbidden=[args["input_file"]]) conf_dict = parse_nabu_config_file(args["input_file"]) update_reconstruction_start_end(conf_dict, args["slice"].strip()) # Imports are done here, otherwise "nabu --version" takes forever from ..pipeline.fullfield.processconfig import ProcessConfig from ..pipeline.fullfield.reconstruction import FullFieldReconstructor # hdf5_entries = conf_dict["dataset"].get("hdf5_entry", "").strip(",") # spit by coma and remove empty spaces hdf5_entries = [e.strip() for e in hdf5_entries.split(",")] # clear '/' at beginning of the entry (so both entry like 'entry0000' and '/entry0000' are handled) hdf5_entries = [e.lstrip("/") for e in hdf5_entries] if hdf5_entries != [""]: file_hdf5_entries = list_hdf5_entries(conf_dict["dataset"]["location"]) hdf5_entries = list_match_queries(file_hdf5_entries, hdf5_entries) if hdf5_entries == []: raise ValueError("No entry found matching pattern '%s'" % conf_dict["dataset"]["hdf5_entry"]) for hdf5_entry in hdf5_entries: if len(hdf5_entries) > 1: print("-" * 80) print("Processing entry: %s" % hdf5_entry) print("-" * 80) conf_dict["dataset"]["hdf5_entry"] = hdf5_entry proc = ProcessConfig(conf_dict=conf_dict, create_logger=logfile) # logger is in append mode logger = proc.logger logger.info("Going to reconstruct slices (%d, %d)" % 
(proc.rec_region["start_z"], proc.rec_region["end_z"])) R = FullFieldReconstructor(proc, logger=logger, extra_options=extra_options) proc = R.process_config R.reconstruct() R.finalize_files_saving() return 0 if __name__ == "__main__": main() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1723556968.0 nabu-2024.2.1/nabu/app/reconstruct_helical.py0000644000175000017500000001054214656662150020415 0ustar00pierrepierrefrom .. import version from ..resources.utils import is_hdf5_extension from ..pipeline.config import parse_nabu_config_file from ..pipeline.config_validators import convert_to_int from .cli_configs import ReconstructConfig from .utils import parse_params_values from .reconstruct import update_reconstruction_start_end, get_log_file def main_helical(): ReconstructConfig["dry_run"] = { "help": "Stops after printing some information on the reconstruction layout.", "default": 0, "type": int, } ReconstructConfig["diag_zpro_run"] = { "help": "Run the pipeline without reconstructing: only collect the contributing radio slices for angles theta+n*360. The given argument is the number of theta values in the interval [0, 180[. The same number is taken, if available, in [180, 360[, and the whole is repeated, if available, in [0, 360[, for a total of 4*diag_zpro_run possible extracted contributions.", "default": 0, "type": int, } args = parse_params_values( ReconstructConfig, parser_description=f"Perform a helical tomographic reconstruction", program_version="nabu " + version, ) # Imports are done here, otherwise "nabu --version" takes forever from ..pipeline.helical.processconfig import ProcessConfig from ..pipeline.helical.helical_reconstruction import HelicalReconstructorRegridded # # A crash with scikit-cuda happens only on the PPC64 platform if nvidia-persistenced is running. # On such machines, a warm-up has to be done. import platform if platform.machine() == "ppc64le": try: from silx.math.fft.cufft import CUFFT except: # can't catch narrower - cublasNotInitialized requires cublas ! CUFFT = None # logfile = get_log_file(args["logfile"], args["log_file"], forbidden=[args["input_file"]]) conf_dict = parse_nabu_config_file(args["input_file"]) update_reconstruction_start_end(conf_dict, args["slice"].strip()) proc = ProcessConfig(conf_dict=conf_dict, create_logger=logfile) logger = proc.logger if "tilt_correction" in proc.processing_steps: message = """ The rotate_projections step is activated. The helical pipelines are not yet suited for projection rotation; it will soon be implemented.
For the moment you should deactivate the rotation options in nabu.conf """ raise ValueError(message) # Determine which reconstructor to use reconstructor_cls = None phase_method = None if "phase" in proc.processing_steps: phase_method = proc.processing_options["phase"]["method"] # fix the reconstruction roi if not given if "reconstruction" in proc.processing_steps: rec_config = proc.processing_options["reconstruction"] rot_center = rec_config["rotation_axis_position"] Nx, Ny = proc.dataset_info.radio_dims if proc.nabu_config["reconstruction"]["auto_size"]: if 2 * rot_center > Nx: w = int(round(2 * rot_center)) else: w = int(round(2 * Nx - 2 * rot_center)) rec_config["start_x"] = int(round(rot_center - w / 2)) rec_config["end_x"] = int(round(rot_center + w / 2)) rec_config["start_y"] = rec_config["start_x"] rec_config["end_y"] = rec_config["end_x"] reconstructor_cls = HelicalReconstructorRegridded logger.debug("Using pipeline: %s" % reconstructor_cls.__name__) # Get extra options extra_options = { "gpu_mem_fraction": args["gpu_mem_fraction"], "cpu_mem_fraction": args["cpu_mem_fraction"], } extra_options.update( { ##### ??? "use_phase_margin": args["use_phase_margin"], "max_chunk_size": args["max_chunk_size"] if args["max_chunk_size"] > 0 else None, "phase_margin": args["phase_margin"], "dry_run": args["dry_run"], "diag_zpro_run": args["diag_zpro_run"], } ) R = reconstructor_cls(proc, logger=logger, extra_options=extra_options) R.reconstruct() if not R.dry_run: R.merge_data_dumps() if is_hdf5_extension(proc.nabu_config["output"]["file_format"]): R.merge_hdf5_reconstructions() R.merge_histograms() # here we have been called by the cli. The return value 0 means OK return 0 if __name__ == "__main__": main_helical() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1708524430.0 nabu-2024.2.1/nabu/app/reduce_dark_flat.py0000644000175000017500000001467214565401616017644 0ustar00pierrepierreimport sys import logging import argparse from typing import Optional from nabu.app.cli_configs import ReduceDarkFlatConfig from .utils import parse_params_values from .. 
import version from tomoscan.framereducer.method import ReduceMethod from tomoscan.scanbase import TomoScanBase from tomoscan.esrf.scan.edfscan import EDFTomoScan from tomoscan.factory import Factory from silx.io.url import DataUrl def _create_data_urls(output_file: Optional[str], output_data_path: Optional[str], name: str): """ util function to compute reduced Data and metadata url(s) This only handle the case of hdf5 outputs """ assert name in ("flats", "darks"), f"name is '{name}'" def get_data_paths(data_path: Optional[str]) -> tuple: """return (data_path, metadta_path)""" if data_path is None: return "{entry}/" + name + "/{index}", "{entry}/" + name elif not data_path.endswith("/{index}"): # we are not expecting useds to provide the index but only upstream part return data_path + "/{index}", data_path else: raise RuntimeError( "unhandled use case (/index provided) and don;t know where to set the data and the metadata" ) data_path, metadata_path = get_data_paths(output_data_path) output_file = output_file or "{scan_prefix}_" + f"{name}.hdf5" data_urls = [ DataUrl( file_path=output_file, data_path=data_path, scheme="silx", ), ] metadata_urls = [ DataUrl( file_path=output_file, data_path=metadata_path, scheme="silx", ), ] return data_urls, metadata_urls def reduce_dark_flat( scan: TomoScanBase, dark_method: ReduceMethod, flat_method: ReduceMethod, overwrite: bool = False, output_reduced_darks_file: Optional[str] = None, output_reduced_darks_data_path: Optional[str] = None, output_reduced_flats_file: Optional[str] = None, output_reduced_flats_data_path: Optional[str] = None, ) -> int: """ calculation of the darks / flats calling tomoscan utils function """ dark_method = ReduceMethod.from_value(dark_method) if dark_method is not None else None flat_method = ReduceMethod.from_value(flat_method) if flat_method is not None else None # 1. define url where to save the file ## 1.1 for darks if dark_method is None: reduced_darks_data_urls = () reduced_darks_metadata_urls = () elif output_reduced_darks_file is None and output_reduced_darks_data_path is None: # if no settings provided then take the default path (the idea is also to be more robust to future modifications) reduced_darks_data_urls = scan.REDUCED_DARKS_DATAURLS reduced_darks_metadata_urls = scan.REDUCED_DARKS_METADATAURLS elif isinstance(scan, EDFTomoScan): # simplification of the equation raise ValueError("reduce-dark-flat can only compute create dark-flats at default location for edf") else: reduced_darks_data_urls, reduced_darks_metadata_urls = _create_data_urls( output_file=output_reduced_darks_file, output_data_path=output_reduced_darks_data_path, name="darks", ) ## 1.2 for flats if flat_method is None: reduced_flats_data_urls = () reduced_flats_metadata_urls = () elif output_reduced_flats_file is None and output_reduced_flats_data_path is None: # if no settings provided then take the default path (the idea is also to be more robust to future modifications) reduced_flats_data_urls = scan.REDUCED_FLATS_DATAURLS reduced_flats_metadata_urls = scan.REDUCED_FLATS_METADATAURLS elif isinstance(scan, EDFTomoScan): # simplification of the equation raise ValueError("reduce-dark-flat can only compute create dark-flats at default location for edf") else: reduced_flats_data_urls, reduced_flats_metadata_urls = _create_data_urls( output_file=output_reduced_flats_file, output_data_path=output_reduced_flats_data_path, name="flats", ) # 2. 
compute and save darks / flats success = True ## 2.1 handle dark if dark_method is not None: try: reduced_darks, darks_metadata = scan.compute_reduced_darks( reduced_method=dark_method, overwrite=overwrite, return_info=True, ) except Exception as e: print(f"failed to create reduced darks. Error is {e}") success = False else: scan.save_reduced_darks( darks=reduced_darks, darks_infos=darks_metadata, output_urls=reduced_darks_data_urls, metadata_output_urls=reduced_darks_metadata_urls, overwrite=overwrite, ) ## 2.2 handle flats if flat_method is not None: try: reduced_flats, flats_metadata = scan.compute_reduced_flats( reduced_method=flat_method, overwrite=overwrite, return_info=True, ) except Exception as e: print(f"failed to create reduced flats. Error is {e}") success = False else: scan.save_reduced_flats( flats=reduced_flats, flats_infos=flats_metadata, output_urls=reduced_flats_data_urls, metadata_output_urls=reduced_flats_metadata_urls, overwrite=overwrite, ) return success def main(argv=None): """ Compute reduce dark(s) and flat(s) of a dataset """ if argv is None: argv = sys.argv[1:] args = parse_params_values( ReduceDarkFlatConfig, parser_description=main.__doc__, program_version="nabu " + version, user_args=argv, ) scan = Factory.create_scan_object(args["dataset"], entry=args["entry"]) exit( reduce_dark_flat( scan=scan, dark_method=args["dark_method"], flat_method=args["flat_method"], overwrite=args["overwrite"], output_reduced_darks_file=args["output_reduced_darks_file"], output_reduced_darks_data_path=args["output_reduced_darks_data_path"], output_reduced_flats_file=args["output_reduced_flats_file"], output_reduced_flats_data_path=args["output_reduced_flats_data_path"], ) ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/app/rotate.py0000644000175000017500000001445514550227307015660 0ustar00pierrepierreimport posixpath from os import path from math import ceil from shutil import copy from multiprocessing import cpu_count from multiprocessing.pool import ThreadPool import numpy as np from tomoscan.io import HDF5File from tomoscan.esrf.scan.nxtomoscan import NXtomoScan from ..io.utils import get_first_hdf5_entry from ..processing.rotation import Rotation from ..resources.logger import Logger, LoggerOrPrint from ..pipeline.config_validators import optional_tuple_of_floats_validator, boolean_validator from ..processing.rotation_cuda import CudaRotation, __has_pycuda__ from .utils import parse_params_values from .cli_configs import RotateRadiosConfig class HDF5ImagesStackRotation: def __init__( self, input_file, output_file, angle, center=None, entry=None, logger=None, batch_size=100, use_cuda=True, use_multiprocessing=True, ): self.logger = LoggerOrPrint(logger) self.use_cuda = use_cuda & __has_pycuda__ self.batch_size = batch_size self.use_multiprocessing = use_multiprocessing self._browse_dataset(input_file, entry) self._get_rotation(angle, center) self._init_output_dataset(output_file) def _browse_dataset(self, input_file, entry): self.input_file = input_file if entry is None or entry == "": entry = get_first_hdf5_entry(input_file) self.entry = entry self.dataset_info = NXtomoScan(input_file, entry=entry) def _get_rotation(self, angle, center): if self.use_cuda: self.logger.info("Using Cuda rotation") rot_cls = CudaRotation else: self.logger.info("Using skimage rotation") rot_cls = Rotation if self.use_multiprocessing: self.thread_pool = ThreadPool(processes=cpu_count() - 2) self.logger.info("Using multiprocessing 
with %d cores" % self.thread_pool._processes) self.rotation = rot_cls((self.dataset_info.dim_2, self.dataset_info.dim_1), angle, center=center, mode="edge") def _init_output_dataset(self, output_file): self.output_file = output_file copy(self.input_file, output_file) first_proj_url = self.dataset_info.projections[list(self.dataset_info.projections.keys())[0]] self.data_path = first_proj_url.data_path() dirname, basename = posixpath.split(self.data_path) self._data_path_dirname = dirname self._data_path_basename = basename def _rotate_stack_cuda(self, images, output): # pylint: disable=E1136 self.rotation.cuda_processing.allocate_array("tmp_images_stack", images.shape) self.rotation.cuda_processing.allocate_array("tmp_images_stack_rot", images.shape) d_in = self.rotation.cuda_processing.get_array("tmp_images_stack") d_out = self.rotation.cuda_processing.get_array("tmp_images_stack_rot") n_imgs = images.shape[0] d_in[:n_imgs].set(images) for j in range(n_imgs): self.rotation.rotate(d_in[j], output=d_out[j]) d_out[:n_imgs].get(ary=output[:n_imgs]) def _rotate_stack(self, images, output): if self.use_cuda: self._rotate_stack_cuda(images, output) elif self.use_multiprocessing: out_tmp = self.thread_pool.map(self.rotation.rotate, images) print(out_tmp[0]) output[:] = np.array(out_tmp, dtype="f") # list -> np array... consumes twice as much memory else: for j in range(images.shape[0]): output[j] = self.rotation.rotate(images[j]) def rotate_images(self, suffix="_rot"): data_path = self.data_path fid = HDF5File(self.input_file, "r") fid_out = HDF5File(self.output_file, "a") try: data_ptr = fid[data_path] n_images = data_ptr.shape[0] data_out_ptr = fid_out[data_path] # Delete virtual dataset in output file, create "data_rot" dataset del fid_out[data_path] fid_out[self._data_path_dirname].create_dataset( self._data_path_basename + suffix, shape=data_ptr.shape, dtype=data_ptr.dtype ) data_out_ptr = fid_out[data_path + suffix] # read by group of images to hide latency group_size = self.batch_size images_rot = np.zeros((group_size, data_ptr.shape[1], data_ptr.shape[2]), dtype="f") n_groups = ceil(n_images / group_size) for i in range(n_groups): self.logger.info("Processing radios group %d/%d" % (i + 1, n_groups)) i_min = i * group_size i_max = min((i + 1) * group_size, n_images) images = data_ptr[i_min:i_max, :, :].astype("f") self._rotate_stack(images, images_rot) data_out_ptr[i_min:i_max, :, :] = images_rot[: i_max - i_min, :, :].astype(data_ptr.dtype) finally: fid_out[self._data_path_dirname].move(posixpath.basename(data_path) + suffix, self._data_path_basename) fid_out[data_path].attrs["interpretation"] = "image" fid.close() fid_out.close() def rotate_cli(): args = parse_params_values( RotateRadiosConfig, parser_description="A command-line utility for performing a rotation on all the radios of a dataset.", ) logger = Logger("nabu_rotate", level=args["loglevel"], logfile="nabu_rotate.log") dataset_path = args["dataset"] h5_entry = args["entry"] output_file = args["output"] center = optional_tuple_of_floats_validator("", "", args["center"]) # pylint: disable=E1121 use_cuda = boolean_validator("", "", args["use_cuda"]) # pylint: disable=E1121 use_multiprocessing = boolean_validator("", "", args["use_multiprocessing"]) # pylint: disable=E1121 if path.exists(output_file): logger.fatal("Output file %s already exists, not overwriting it" % output_file) exit(1) h5rot = HDF5ImagesStackRotation( dataset_path, output_file, args["angle"], center=center, entry=h5_entry, logger=logger, 
batch_size=args["batchsize"], use_cuda=use_cuda, use_multiprocessing=use_multiprocessing, ) h5rot.rotate_images() return 0 if __name__ == "__main__": rotate_cli() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682665866.0 nabu-2024.2.1/nabu/app/shrink_dataset.py0000644000175000017500000000666314422670612017366 0ustar00pierrepierreimport os import posixpath from multiprocessing.pool import ThreadPool import numpy as np from silx.io.dictdump import dicttonx, nxtodict from ..misc.binning import binning as image_binning from ..io.utils import get_first_hdf5_entry from ..pipeline.config_validators import optional_tuple_of_floats_validator, optional_positive_integer_validator from .cli_configs import ShrinkConfig from .utils import parse_params_values def access_nested_dict(dict_, path, default=None): items = [s for s in path.split(posixpath.sep) if len(s) > 0] if len(items) == 1: return dict_.get(items[0], default) if items[0] not in dict_: return default return access_nested_dict(dict_[items[0]], posixpath.sep.join(items[1:])) def set_nested_dict_value(dict_, path, val): dirname, basename = posixpath.split(path) sub_dict = access_nested_dict(dict_, dirname) sub_dict[basename] = val def shrink_dataset(input_file, output_file, binning=None, subsampling=None, entry=None, n_threads=1): entry = entry or get_first_hdf5_entry(input_file) data_dict = nxtodict(input_file, path=entry, dereference_links=False) to_subsample = [ "control/data", "instrument/detector/count_time", "instrument/detector/data", "instrument/detector/image_key", "instrument/detector/image_key_control", "sample/rotation_angle", "sample/x_translation", "sample/y_translation", "sample/z_translation", ] detector_data = access_nested_dict(data_dict, "instrument/detector/data") if detector_data is None: raise ValueError("No data found in %s entry %s" % (input_file, entry)) if binning is not None: def _apply_binning(img_res_tuple): img, res = img_res_tuple res[:] = image_binning(img, binning) data_binned = np.zeros( (detector_data.shape[0], detector_data.shape[1] // binning[0], detector_data.shape[2] // binning[1]), detector_data.dtype, ) with ThreadPool(n_threads) as tp: tp.map(_apply_binning, zip(detector_data, data_binned)) detector_data = data_binned set_nested_dict_value(data_dict, "instrument/detector/data", data_binned) if subsampling is not None: for item_path in to_subsample: item_val = access_nested_dict(data_dict, item_path) if item_val is not None: set_nested_dict_value(data_dict, item_path, item_val[::subsampling]) dicttonx(data_dict, output_file, h5path=entry) def shrink_cli(): args = parse_params_values(ShrinkConfig, parser_description="Shrink a NX dataset") if not (os.path.isfile(args["input_file"])): print("No such file: %s" % args["input_file"]) exit(1) if os.path.isfile(args["output_file"]): print("Output file %s already exists, not overwriting it" % args["output_file"]) exit(1) binning = optional_tuple_of_floats_validator("", "binning", args["binning"]) # pylint: disable=E1121 if binning is not None: binning = tuple(map(int, binning)) subsampling = optional_positive_integer_validator("", "subsampling", args["subsampling"]) # pylint: disable=E1121 shrink_dataset( args["input_file"], args["output_file"], binning=binning, subsampling=subsampling, entry=args["entry"], n_threads=args["threads"], ) return 0 if __name__ == "__main__": shrink_cli() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731053186.0 
nabu-2024.2.1/nabu/app/stitching.py0000644000175000017500000001014214713343202016335 0ustar00pierrepierreimport logging from pprint import pformat from tqdm import tqdm from nabu.stitching.slurm_utils import split_stitching_configuration_to_slurm_job from .cli_configs import StitchingConfig from ..pipeline.config import parse_nabu_config_file from nabu.stitching.single_axis_stitching import stitching from nabu.stitching.utils.post_processing import StitchingPostProcAggregation from nabu.stitching.config import dict_to_config_obj from .utils import parse_params_values try: from sluurp.executor import submit except ImportError: has_sluurp = False else: has_sluurp = True _logger = logging.getLogger(__name__) def main(): args = parse_params_values( StitchingConfig, parser_description="Run stitching from a configuration file. Configuration can be obtain from `stitching-config`", ) logging.basicConfig(level=args["loglevel"].upper()) conf_dict = parse_nabu_config_file(args["input-file"], allow_no_value=True) stitching_config = dict_to_config_obj(conf_dict) assert stitching_config.axis is not None, "axis must be defined to know how to stitch" _logger.info(" when loaded axis is %s", stitching_config.axis) stitching_config.settle_inputs() if args["only_create_master_file"]: # option to ease creation of the master in the following cases: # * user has submitted all the job but has been kicked out of the cluster # * only a few slurm job for some random version (cluster update...) and user want to retrigger only those job and process the aggregation only. On those cases no need to redo it all. tomo_objs = [] for _, sub_config in split_stitching_configuration_to_slurm_job(stitching_config, yield_configuration=True): tomo_objs.append(sub_config.get_output_object().get_identifier().to_str()) post_processing = StitchingPostProcAggregation( existing_objs_ids=tomo_objs, stitching_config=stitching_config, ) post_processing.process() elif stitching_config.slurm_config.partition in (None, ""): # case 1: run locally _logger.info("run stitching locally with: %s", pformat(stitching_config.to_dict())) main_progress = tqdm(total=100, desc="stitching", leave=True) stitching(stitching_config, progress=main_progress) else: if not has_sluurp: raise ImportError( "sluurp not install. Please install it to distribute stitching on slurm (pip install slurm)" ) main_progress = tqdm(total=100, position=0, desc="stitching") # case 2: run on slurm # note: to speed up we could do shift research on pre processing and run it only once (if manual of course). 
Here it will be run for all part _logger.info(f"will distribute stitching") futures = {} # 2.1 launch jobs slurm_job_progress_bars: dict = {} for i_job, (job, sub_config) in enumerate( split_stitching_configuration_to_slurm_job(stitching_config, yield_configuration=True) ): _logger.info(f"submit job nb {i_job}: handles {sub_config.slices}") output_volume = sub_config.get_output_object().get_identifier().to_str() futures[output_volume] = submit(job, timeout=999999) # note on total=100: we only consider percentage in this case (providing advancement from slurm jobs) slurm_job_progress_bars[job] = tqdm( total=100, position=i_job + 1, desc=f" part {str(i_job).ljust(3)}", delay=0.5, # avoid to mess with terminal and (near) future logs bar_format="{l_bar}{bar}", # avoid using 'r_bar' as 'total' is set to 100 (percentage) leave=False, ) main_progress.n = 50 # 2.2 wait for future to be done and concatenate the result post_processing = StitchingPostProcAggregation( futures=futures, stitching_config=stitching_config, progress_bars=slurm_job_progress_bars, ) post_processing.process() exit(0) if __name__ == "__main__": main() ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5007565 nabu-2024.2.1/nabu/app/tests/0000755000175000017500000000000014730277752015152 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1700659359.0 nabu-2024.2.1/nabu/app/tests/__init__.py0000644000175000017500000000000014527400237017237 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1730363900.0 nabu-2024.2.1/nabu/app/tests/test_reduce_dark_flat.py0000644000175000017500000000524514710640774022043 0ustar00pierrepierreimport os import pytest from nabu.app.reduce_dark_flat import reduce_dark_flat ##### try: from tomoscan.tests.utils import NXtomoMockContext except ImportError: from tomoscan.test.utils import NXtomoMockContext @pytest.fixture(scope="function") def hdf5_scan(tmp_path): """simple fixture to create a scan and provide it to another function""" test_dir = tmp_path / "my_hdf5_scan" with NXtomoMockContext( scan_path=str(test_dir), n_proj=10, n_ini_proj=10, ) as scan: yield scan ###### @pytest.mark.parametrize("dark_method", (None, "first", "mean")) @pytest.mark.parametrize("flat_method", (None, "last", "median")) def test_reduce_dark_flat_hdf5(tmp_path, hdf5_scan, dark_method, flat_method): # noqa F811 """simply test output - processing is tested at tomoscan side""" # test with default url default_darks_path = os.path.join(hdf5_scan.path, hdf5_scan.get_dataset_basename() + "_darks.hdf5") default_flats_path = os.path.join(hdf5_scan.path, hdf5_scan.get_dataset_basename() + "_flats.hdf5") assert not os.path.exists(default_darks_path) assert not os.path.exists(default_flats_path) reduce_dark_flat( scan=hdf5_scan, dark_method=dark_method, flat_method=flat_method, ) if dark_method is not None: assert os.path.exists(default_darks_path) else: assert not os.path.exists(default_darks_path) if flat_method is not None: assert os.path.exists(default_flats_path) else: assert not os.path.exists(default_flats_path) # make sure if already exists and no overwrite fails if dark_method is not None or flat_method is not None: with pytest.raises(KeyError): reduce_dark_flat( scan=hdf5_scan, dark_method=dark_method, flat_method=flat_method, overwrite=False, ) # test with url provided by the user tuned_darks_path = os.path.join(tmp_path, "new_folder", "darks.hdf5") tuned_flats_path = 
os.path.join(tmp_path, "new_folder", "flats.hdf5") assert not os.path.exists(tuned_darks_path) assert not os.path.exists(tuned_flats_path) reduce_dark_flat( scan=hdf5_scan, dark_method=dark_method, flat_method=flat_method, output_reduced_darks_file=tuned_darks_path, output_reduced_flats_file=tuned_flats_path, ) if dark_method is not None: assert os.path.exists(tuned_darks_path) else: assert not os.path.exists(tuned_darks_path) if flat_method is not None: assert os.path.exists(tuned_flats_path) else: assert not os.path.exists(tuned_flats_path) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/app/utils.py0000644000175000017500000000223214550227307015510 0ustar00pierrepierrefrom argparse import ArgumentParser def parse_params_values(Params, parser_description=None, program_version=None, user_args=None): parser = ArgumentParser(description=parser_description) for param_name, vals in Params.items(): if param_name[0] != "-": # It would be better to use "required" and not to pop it. # required is an accepted keyword for argparse optional = not (vals.pop("mandatory", False)) if optional: param_name = "--" + param_name aliases = vals.pop("aliases", tuple()) if optional: aliases = tuple(["--" + alias for alias in aliases]) else: aliases = () parser.add_argument(param_name, *aliases, **vals) if program_version is not None: parser.add_argument("--version", "-V", action="version", version=program_version) args = parser.parse_args(args=user_args) args_dict = args.__dict__ return args_dict def parse_sections(sections): sections = sections.lower() if sections == "all": return None sections = sections.replace(" ", "").split(",") return sections ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/app/validator.py0000644000175000017500000000645014550227307016343 0ustar00pierrepierre#!/usr/bin/env python # -*- coding: utf-8 -*- import argparse import sys import os import h5py import tomoscan.validator from tomoscan.esrf.scan.nxtomoscan import NXtomoScan from tomoscan.esrf.scan.edfscan import EDFTomoScan def get_scans(path, entries: str): path = os.path.abspath(path) res = [] if EDFTomoScan.is_tomoscan_dir(path): res.append(EDFTomoScan(scan=path)) elif NXtomoScan.is_tomoscan_dir(path): if entries == "__all__": entries = NXtomoScan.get_valid_entries(path) for entry in entries: res.append(NXtomoScan(path, entry)) else: raise TypeError(f"{path} does not looks like a folder containing .EDF or a valid nexus file ") return res def main(): argv = sys.argv parser = argparse.ArgumentParser(description="Check if provided scan(s) seems valid to be reconstructed.") parser.add_argument("path", help="Data to validate (h5 file, edf folder)") parser.add_argument("entries", help="Entries to be validated (in the case of a h5 file)", nargs="*") parser.add_argument( "--ignore-dark", help="Do not check for dark", default=True, action="store_false", dest="check_dark", ) parser.add_argument( "--ignore-flat", help="Do not check for flat", default=True, action="store_false", dest="check_flat", ) parser.add_argument( "--no-phase-retrieval", help="Check scan energy, distance and pixel size", dest="check_phase_retrieval", default=True, action="store_false", ) parser.add_argument( "--check-nan", help="Check frames if contains any nan.", dest="check_nan", default=False, action="store_true", ) parser.add_argument( "--skip-links-check", "--no-link-check", help="Check frames dataset if have some broken links.", 
dest="check_vds", default=True, action="store_false", ) parser.add_argument( "--all-entries", help="Check all entries of the files (for HDF5 only for now)", default=False, action="store_true", ) parser.add_argument( "--extend", help="By default it only display items with issues. Extend will display them all", dest="only_issues", default=True, action="store_false", ) options = parser.parse_args(argv[1:]) if options.all_entries is True: entries = "__all__" else: if len(options.entries) == 0 and h5py.is_hdf5(options.path): entries = "__all__" else: entries = options.entries scans = get_scans(path=options.path, entries=entries) if len(scans) == 0: raise ValueError(f"No scan found from file:{options.path}, entries:{options.entries}") for scan in scans: validator = tomoscan.validator.ReconstructionValidator( scan=scan, check_phase_retrieval=options.check_phase_retrieval, check_values=options.check_nan, check_vds=options.check_vds, check_dark=options.check_dark, check_flat=options.check_flat, ) sys.stdout.write(validator.checkup(only_issues=options.only_issues)) return 0 if __name__ == "__main__": main() ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5007565 nabu-2024.2.1/nabu/cuda/0000755000175000017500000000000014730277752014144 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/cuda/__init__.py0000644000175000017500000000000014315516747016242 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/cuda/convolution.py0000644000175000017500000000036114550227307017064 0ustar00pierrepierrefrom ..processing.convolution_cuda import * from ..utils import deprecation_warning deprecation_warning( "nabu.cuda.convolution has been moved to nabu.processing.convolution_cuda", do_print=True, func_name="convolution_cuda", ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/cuda/fft.py0000644000175000017500000000030214550227307015257 0ustar00pierrepierrefrom ..processing.fft_cuda import * from ..utils import deprecation_warning deprecation_warning("nabu.cuda.fft has been moved to nabu.processing.fft_cuda", do_print=True, func_name="fft_cuda") ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/cuda/kernel.py0000644000175000017500000000730014654107202015761 0ustar00pierrepierreimport pycuda.gpuarray as garray from pycuda.compiler import SourceModule from ..processing.kernel_base import KernelBase from ..utils import catch_warnings # TODO use warnings.catch_warnings once python < 3.11 is dropped class CudaKernel(KernelBase): """ Helper class that wraps CUDA kernel through pycuda SourceModule. Parameters ----------- kernel_name: str Name of the CUDA kernel. filename: str, optional Path to the file name containing kernels definitions src: str, optional Source code of kernels definitions signature: str, optional Signature of kernel function. If provided, pycuda will not guess the types of kernel arguments, making the calls slightly faster. For example, a function acting on two pointers, an integer and a float32 has the signature "PPif". 
texrefs: list, optional List of texture references, if any automation_params: dict, optional Automation parameters, see below sourcemodule_kwargs: optional Extra arguments to provide to pycuda.compiler.SourceModule(), """ def __init__( self, kernel_name, filename=None, src=None, signature=None, texrefs=None, automation_params=None, silent_compilation_warnings=False, **sourcemodule_kwargs, ): super().__init__( kernel_name, filename=filename, src=src, automation_params=automation_params, silent_compilation_warnings=silent_compilation_warnings, ) self.compile_kernel_source(kernel_name, sourcemodule_kwargs) self.prepare(signature, texrefs) def compile_kernel_source(self, kernel_name, sourcemodule_kwargs): self.sourcemodule_kwargs = sourcemodule_kwargs self.kernel_name = kernel_name with catch_warnings(action=("ignore" if self.silent_compilation_warnings else None)): # pylint: disable=E1123 self.module = SourceModule(self.src, **self.sourcemodule_kwargs) self.func = self.module.get_function(kernel_name) def prepare(self, kernel_signature, texrefs): self.prepared = False self.kernel_signature = kernel_signature self.texrefs = texrefs or [] if kernel_signature is not None: self.func.prepare(self.kernel_signature, texrefs=self.texrefs) self.prepared = True def follow_device_arr(self, args): args = list(args) # Replace GPUArray with GPUArray.gpudata for i, arg in enumerate(args): if isinstance(arg, garray.GPUArray): args[i] = arg.gpudata return tuple(args) def get_last_kernel_time(self): """ Return the execution time (in seconds) of the last called kernel. The last called kernel should have been called with time_kernel=True. """ if self.last_kernel_time is not None: return self.last_kernel_time() else: return None def call(self, *args, **kwargs): grid, block, args, kwargs = self._prepare_call(*args, **kwargs) if self.prepared: func_call = self.func.prepared_call if "time_kernel" in kwargs: func_call = self.func.prepared_timed_call kwargs.pop("time_kernel") if "block" in kwargs: kwargs.pop("block") if "grid" in kwargs: kwargs.pop("grid") t = func_call(grid, block, *args, **kwargs) else: kwargs["block"] = block kwargs["grid"] = grid t = self.func(*args, **kwargs) self.last_kernel_time = t # list ? 
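# A minimal usage sketch (the kernel source, sizes and variable names below are hypothetical, not part of the original module; assumes numpy imported as np):
#   src = "__global__ void add_one(float* arr, int n) { int i = blockDim.x * blockIdx.x + threadIdx.x; if (i < n) arr[i] += 1.0f; }"
#   kern = CudaKernel("add_one", src=src, signature="Pi")
#   d_arr = garray.to_gpu(np.ones(1024, dtype=np.float32))
#   kern(d_arr, np.int32(1024), grid=(4, 1), block=(256, 1, 1))
# Depending on automation_params, grid/block may also be inferred automatically (see KernelBase).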
__call__ = call ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/cuda/medfilt.py0000644000175000017500000000033014550227307016125 0ustar00pierrepierrefrom ..processing.medfilt_cuda import * from ..utils import deprecation_warning deprecation_warning( "nabu.cuda.medfilt has been moved to nabu.processing.medfilt_cuda", do_print=True, func_name="medfilt_cuda" ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/cuda/padding.py0000644000175000017500000000033014550227307016107 0ustar00pierrepierrefrom ..processing.padding_cuda import * from ..utils import deprecation_warning deprecation_warning( "nabu.cuda.padding has been moved to nabu.processing.padding_cuda", do_print=True, func_name="padding_cuda" ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/cuda/processing.py0000644000175000017500000000513714654107202016663 0ustar00pierrepierrefrom ..utils import MissingComponentError from ..processing.processing_base import ProcessingBase from .utils import get_cuda_context, __has_pycuda__ if __has_pycuda__: import pycuda.driver as cuda import pycuda.gpuarray as garray from ..cuda.kernel import CudaKernel dev_attrs = cuda.device_attribute GPUArray = garray.GPUArray from pycuda.tools import dtype_to_ctype else: GPUArray = MissingComponentError("pycuda") dtype_to_ctype = MissingComponentError("pycuda") # NB: we must detach from a context before creating another context class CudaProcessing(ProcessingBase): array_class = GPUArray if __has_pycuda__ else None dtype_to_ctype = dtype_to_ctype def __init__(self, device_id=None, ctx=None, stream=None, cleanup_at_exit=True): """ Initialie a CudaProcessing instance. CudaProcessing is a base class for all CUDA-based processings. This class provides utilities for context/device management, and arrays allocation. Parameters ---------- device_id: int, optional ID of the cuda device to use (those of the `nvidia-smi` command). Ignored if ctx is not None. ctx: pycuda.driver.Context, optional Existing context to use. If provided, do not create a new context. stream: pycudacuda.driver.Stream, optional Cuda stream. If not provided, will use the default stream cleanup_at_exit: bool, optional Whether to clean-up the context at exit. Ignored if ctx is not None. 
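A minimal usage sketch (the device id, array name and kernel file below are hypothetical)::

    proc = CudaProcessing(device_id=0)
    d_arr = proc.allocate_array("d_arr", (2048, 2048))
    kern = proc.kernel("my_kernel", filename="my_kernels.cu", signature="Pii")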
""" super().__init__() if ctx is None: self.ctx = get_cuda_context(device_id=device_id, cleanup_at_exit=cleanup_at_exit) else: self.ctx = ctx self.stream = stream self.device = self.ctx.get_device() self.device_name = self.device.name() self.device_id = self.device.get_attribute(dev_attrs.MULTI_GPU_BOARD_GROUP_ID) # pylint: disable=E0606 def push_context(self): self.ctx.push() return self.ctx def pop_context(self): self.ctx.pop() def _allocate_array_mem(self, shape, dtype): return garray.zeros(shape, dtype) def kernel( self, kernel_name, filename=None, src=None, signature=None, texrefs=None, automation_params=None, **build_kwargs ): return CudaKernel( # pylint: disable=E0606 kernel_name, filename=filename, src=src, signature=signature, texrefs=texrefs, automation_params=automation_params, **build_kwargs, ) ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5047567 nabu-2024.2.1/nabu/cuda/src/0000755000175000017500000000000014730277752014733 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/cuda/src/ElementOp.cu0000644000175000017500000001602414550227307017146 0ustar00pierrepierre#include typedef pycuda::complex complex; // Generic operations #define OP_ADD 0 #define OP_SUB 1 #define OP_MUL 2 #define OP_DIV 3 // #ifndef GENERIC_OP #define GENERIC_OP OP_ADD #endif // arr2D *= arr1D (line by line, i.e along fast dim) __global__ void inplace_complex_mul_2Dby1D(complex* arr2D, complex* arr1D, int width, int height) { int x = blockDim.x * blockIdx.x + threadIdx.x; int y = blockDim.y * blockIdx.y + threadIdx.y; if ((x >= width) || (y >= height)) return; // This does not seem to work // Use cuCmulf of cuComplex.h ? //~ arr2D[y*width + x] *= arr1D[x]; size_t i = y*width + x; complex a = arr2D[i]; complex b = arr1D[x]; arr2D[i]._M_re = a._M_re * b._M_re - a._M_im * b._M_im; arr2D[i]._M_im = a._M_im * b._M_re + a._M_re * b._M_im; } __global__ void inplace_generic_op_2Dby2D(float* arr2D, float* arr2D_other, int width, int height) { int x = blockDim.x * blockIdx.x + threadIdx.x; int y = blockDim.y * blockIdx.y + threadIdx.y; if ((x >= width) || (y >= height)) return; uint i = y*width + x; #if GENERIC_OP == OP_ADD arr2D[i] += arr2D_other[i]; #elif GENERIC_OP == OP_SUB arr2D[i] -= arr2D_other[i]; #elif GENERIC_OP == OP_MUL arr2D[i] *= arr2D_other[i]; #elif GENERIC_OP == OP_DIV arr2D[i] /= arr2D_other[i]; #endif } // launched with (Nx, Ny, Nz) threads // does array3D[x, y, z] = op(array3D[x, y, z], array1D[x]) (in the "numpy broadcasting" sense) __global__ void inplace_generic_op_3Dby1D( float * array3D, float* array1D, int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz // input/output depth ) { uint x = blockDim.x * blockIdx.x + threadIdx.x; uint y = blockDim.y * blockIdx.y + threadIdx.y; uint z = blockDim.z * blockIdx.z + threadIdx.z; if ((x >= Nx) || (y >= Ny) || (z >= Nz)) return; size_t idx = ((z * Ny) + y)*Nx + x; #if GENERIC_OP == OP_ADD array3D[idx] += array1D[x]; #elif GENERIC_OP == OP_SUB array3D[idx] -= array1D[x]; #elif GENERIC_OP == OP_MUL array3D[idx] *= array1D[x]; #elif GENERIC_OP == OP_DIV array3D[idx] /= array1D[x]; #endif } // arr3D *= arr1D (along fast dim) __global__ void inplace_complex_mul_3Dby1D(complex* arr3D, complex* arr1D, int width, int height, int depth) { int x = blockDim.x * blockIdx.x + threadIdx.x; int y = blockDim.y * blockIdx.y + threadIdx.y; int z = blockDim.z * blockIdx.z + threadIdx.z; if ((x >= width) || (y 
>= height) || (z >= depth)) return; // This does not seem to work // Use cuCmulf of cuComplex.h ? //~ arr3D[(z*height + y)*width + x] *= arr1D[x]; size_t i = (z*height + y)*width + x; complex a = arr3D[i]; complex b = arr1D[x]; arr3D[i]._M_re = a._M_re * b._M_re - a._M_im * b._M_im; arr3D[i]._M_im = a._M_im * b._M_re + a._M_re * b._M_im; } // arr2D *= arr2D __global__ void inplace_complex_mul_2Dby2D(complex* arr2D_out, complex* arr2D_other, int width, int height) { int x = blockDim.x * blockIdx.x + threadIdx.x; int y = blockDim.y * blockIdx.y + threadIdx.y; if ((x >= width) || (y >= height)) return; size_t i = y*width + x; complex a = arr2D_out[i]; complex b = arr2D_other[i]; arr2D_out[i]._M_re = a._M_re * b._M_re - a._M_im * b._M_im; arr2D_out[i]._M_im = a._M_im * b._M_re + a._M_re * b._M_im; } // arr2D *= arr2D __global__ void inplace_complexreal_mul_2Dby2D(complex* arr2D_out, float* arr2D_other, int width, int height) { int x = blockDim.x * blockIdx.x + threadIdx.x; int y = blockDim.y * blockIdx.y + threadIdx.y; if ((x >= width) || (y >= height)) return; int i = y*width + x; complex a = arr2D_out[i]; float b = arr2D_other[i]; arr2D_out[i]._M_re *= b; arr2D_out[i]._M_im *= b; } /* Kernel used for CTF phase retrieval img_f = img_f * filter_num img_f[0, 0] -= mean_scale_factor * filter_num[0,0] img_f = img_f * filter_denom where mean_scale_factor = Nx*Ny */ __global__ void CTF_kernel( complex* image, float* filter_num, float* filter_denom, float mean_scale_factor, int Nx, int Ny ) { uint x = blockDim.x * blockIdx.x + threadIdx.x; uint y = blockDim.y * blockIdx.y + threadIdx.y; if ((x >= Nx) || (y >= Ny)) return; uint idx = y*Nx + x; image[idx] *= filter_num[idx]; if (idx == 0) image[idx] -= mean_scale_factor; image[idx] *= filter_denom[idx]; } #ifndef DO_CLIP_MIN #define DO_CLIP_MIN 0 #endif #ifndef DO_CLIP_MAX #define DO_CLIP_MAX 0 #endif // arr = -log(arr) __global__ void nlog(float* array, int Nx, int Ny, int Nz, float clip_min, float clip_max) { size_t x = blockDim.x * blockIdx.x + threadIdx.x; size_t y = blockDim.y * blockIdx.y + threadIdx.y; size_t z = blockDim.z * blockIdx.z + threadIdx.z; if ((x >= Nx) || (y >= Ny) || (z >= Nz)) return; size_t pos = (z*Ny + y)*Nx + x; float val = array[pos]; #if DO_CLIP_MIN val = fmaxf(val, clip_min); #endif #if DO_CLIP_MAX val = fminf(val, clip_max); #endif array[pos] = -logf(val); } // Reverse elements of a 2D array along "x", i.e: // arr = arr[:, ::-1] // launched with grid (Nx/2, Ny) __global__ void reverse2D_x(float* array, int Nx, int Ny) { uint x = blockDim.x * blockIdx.x + threadIdx.x; uint y = blockDim.y * blockIdx.y + threadIdx.y; if ((x >= Nx/2) || (y >= Ny)) return; uint pos = y*Nx + x; uint pos2 = y*Nx + (Nx - 1 - x); float tmp = array[pos]; array[pos] = array[pos2]; array[pos2] = tmp; } /** Generic mul-add kernel with possibly-complicated indexing. 
dst[DST_IDX] = fac_dst*dst[DST_IDX] + fac_other*other[OTHER_IDX] where DST_IDX = dst_start_row:dst_end_row, dst_start_col:dst_end_col OTHER_IDX = other_start_row:other_end_row, other_start_col:other_end_col Usage: mul_add(dst, other, dst_nx, other_nx, a, b, (x1, x2), (y1, y2), (x3, x4), (y3, y4)) */ __global__ void mul_add( float* dst, float* other, int dst_width, int other_width, float fac_dst, float fac_other, int2 dst_x_range, int2 dst_y_range, int2 other_x_range, int2 other_y_range ) { size_t x = blockDim.x * blockIdx.x + threadIdx.x; size_t y = blockDim.y * blockIdx.y + threadIdx.y; int x_start_dst = dst_x_range.x; int x_stop_dst = dst_x_range.y; int y_start_dst = dst_y_range.x; int y_stop_dst = dst_y_range.y; int x_start_other = other_x_range.x; int x_stop_other = other_x_range.y; int y_start_other = other_y_range.x; int y_stop_other = other_y_range.y; int operation_width = x_stop_dst - x_start_dst; // assumed == x_stop_other - x_start_other int operation_height = y_stop_dst - y_start_dst; // assumed == y_stop_other - y_start_other if ((x >= operation_width) || (y >= operation_height)) return; size_t idx_in_dst = (y + y_start_dst)*dst_width + (x + x_start_dst); size_t idx_in_other = (y + y_start_other)*other_width + (x + x_start_other); dst[idx_in_dst] = fac_dst * dst[idx_in_dst] + fac_other * other[idx_in_other]; } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1724839730.0 nabu-2024.2.1/nabu/cuda/src/backproj.cu0000644000175000017500000001416114663573462017064 0ustar00pierrepierre#ifndef SHARED_SIZE #define SHARED_SIZE 256 #endif #ifdef USE_TEXTURES texture tex_projections; #endif #ifdef CLIP_OUTER_CIRCLE inline __device__ int is_in_circle(int x, int y, float center_x, float center_y, int radius2) { return (((x - center_x)*(x - center_x) + (y - center_y)*(y - center_y)) <= radius2); } #endif /* Linear interpolation on a 2D array, horizontally. This will return arr[y][x] where y is an int (exact access) and x is a float (linear interp horizontally) */ static inline __device__ float linear_interpolation(float* arr, int Nx, float x, int y) { // check commented to gain a bit of speed - the check was done before function call // if (x < 0 || x >= Nx) return 0.0f; // texture address mode CLAMP_TO_EDGE int xm = (int) floorf(x); int xp = (int) ceilf(x); if ((xm == xp) || (xp >= Nx)) return arr[y*Nx+xm]; else return (arr[y*Nx+xm] * (xp - x)) + (arr[y*Nx+xp] * (x - xm)); } /** Implementation details ----------------------- This implementation uses two pre-computed arrays in global memory: cos(theta) -> d_cos -sin(theta) -> d_msin As the backprojection is voxel-driven, each thread will, at some point, need cos(theta) and -sin(theta) for *all* theta. Thus, we need to pre-fetch d_cos and d_msin in the fastest cached memory. Here we use the shared memory (faster than constant memory and texture). Each thread group will pre-fetch values from d_cos and d_msin to shared memory Initially, we fetched as much values as possible, ending up in a block of 1024 threads (32, 32). However, it turns out that performances are best with (16, 16) blocks. 
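Practical note (inferred from the code below, not from the original comment): since each thread accumulates up to four output pixels (the (xr, yr)/(xrp, yrp) scheme, offset by Gx and Gy), the kernel is expected to be launched with a grid spanning roughly half of the slice size in each direction, i.e. about (n_x/2) x (n_y/2) threads in total.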
**/ // Backproject one sinogram // One thread handles up to 4 pixels in the output slice __global__ void backproj( float* d_slice, #ifndef USE_TEXTURES float* d_sino, #endif int num_projs, int num_bins, float axis_position, int n_x, int n_y, float offset_x, float offset_y, float* d_cos, float* d_msin, #ifdef DO_AXIS_CORRECTION float* d_axis_corr, #endif float scale_factor ) { int x = blockDim.x * blockIdx.x + threadIdx.x; int y = blockDim.y * blockIdx.y + threadIdx.y; uint Gx = blockDim.x * gridDim.x; uint Gy = blockDim.y * gridDim.y; // (xr, yr) (xrp, yr) // (xr, yrp) (xrp, yrp) float xr = (x + offset_x) - axis_position, yr = (y + offset_y) - axis_position; float xrp = xr + Gx, yrp = yr + Gy; /*volatile*/ __shared__ float s_cos[SHARED_SIZE]; /*volatile*/ __shared__ float s_msin[SHARED_SIZE]; #ifdef DO_AXIS_CORRECTION /*volatile*/ __shared__ float s_axis[SHARED_SIZE]; float axcorr; #endif int next_fetch = 0; int tid = threadIdx.y * blockDim.x + threadIdx.x; float costheta, msintheta; float h1, h2, h3, h4; float sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f, sum4 = 0.0f; for (int proj = 0; proj < num_projs; proj++) { if (proj == next_fetch) { // Fetch SHARED_SIZE values to shared memory __syncthreads(); if (next_fetch + tid < num_projs) { s_cos[tid] = d_cos[next_fetch + tid]; s_msin[tid] = d_msin[next_fetch + tid]; #ifdef DO_AXIS_CORRECTION s_axis[tid] = d_axis_corr[next_fetch + tid]; #endif } next_fetch += SHARED_SIZE; __syncthreads(); } costheta = s_cos[proj - (next_fetch - SHARED_SIZE)]; msintheta = s_msin[proj - (next_fetch - SHARED_SIZE)]; #ifdef DO_AXIS_CORRECTION axcorr = s_axis[proj - (next_fetch - SHARED_SIZE)]; #endif float c1 = fmaf(costheta, xr, axis_position); // cos(theta)*xr + axis_pos float c2 = fmaf(costheta, xrp, axis_position); // cos(theta)*(xr + Gx) + axis_pos float s1 = fmaf(msintheta, yr, 0.0f); // -sin(theta)*yr float s2 = fmaf(msintheta, yrp, 0.0f); // -sin(theta)*(yr + Gy) h1 = c1 + s1; h2 = c2 + s1; h3 = c1 + s2; h4 = c2 + s2; #ifdef DO_AXIS_CORRECTION h1 += axcorr; h2 += axcorr; h3 += axcorr; h4 += axcorr; #endif #ifdef USE_TEXTURES if (h1 >= 0 && h1 < num_bins) sum1 += tex2D(tex_projections, h1 + 0.5f, proj + 0.5f); if (h2 >= 0 && h2 < num_bins) sum2 += tex2D(tex_projections, h2 + 0.5f, proj + 0.5f); if (h3 >= 0 && h3 < num_bins) sum3 += tex2D(tex_projections, h3 + 0.5f, proj + 0.5f); if (h4 >= 0 && h4 < num_bins) sum4 += tex2D(tex_projections, h4 + 0.5f, proj + 0.5f); #else if (h1 >= 0 && h1 < num_bins) sum1 += linear_interpolation(d_sino, num_bins, h1, proj); if (h2 >= 0 && h2 < num_bins) sum2 += linear_interpolation(d_sino, num_bins, h2, proj); if (h3 >= 0 && h3 < num_bins) sum3 += linear_interpolation(d_sino, num_bins, h3, proj); if (h4 >= 0 && h4 < num_bins) sum4 += linear_interpolation(d_sino, num_bins, h4, proj); #endif } int write_topleft = 1, write_topright = 1, write_botleft = 1, write_botright = 1; #ifdef CLIP_OUTER_CIRCLE float center_x = (n_x - 1)/2.0f, center_y = (n_y - 1)/2.0f; int radius2 = min(n_x/2, n_y/2); radius2 *= radius2; write_topleft = is_in_circle(x, y, center_x, center_y, radius2); write_topright = is_in_circle(x + Gx, y, center_x, center_y, radius2); write_botleft = is_in_circle(x, y + Gy, center_x, center_y, radius2); write_botright = is_in_circle(x + Gy, y + Gy, center_x, center_y, radius2); #endif // useful only if n_x < blocksize_x or n_y < blocksize_y if (x >= n_x) return; if (y >= n_y) return; // Pixels in top-left quadrant if (write_topleft) d_slice[y*(n_x) + x] = sum1 * scale_factor; // Pixels in top-right quadrant if ((Gx + x < n_x) 
&& (write_topright)) { d_slice[y*(n_x) + Gx + x] = sum2 * scale_factor; } if (Gy + y < n_y) { // Pixels in bottom-left quadrant if (write_botleft) d_slice[(y+Gy)*(n_x) + x] = sum3 * scale_factor; // Pixels in bottom-right quadrant if ((Gx + x < n_x) && (write_botright)) d_slice[(y+Gy)*(n_x) + Gx + x] = sum4 * scale_factor; } } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/cuda/src/backproj_polar.cu0000644000175000017500000000347314315516747020262 0ustar00pierrepierre#ifndef SHARED_SIZE #define SHARED_SIZE 256 #endif texture tex_projections; __global__ void backproj_polar( float* d_slice, int num_projs, int num_bins, float axis_position, int n_x, int n_y, int offset_x, int offset_y, float* d_cos, float* d_msin, float scale_factor ) { int i_r = offset_x + blockDim.x * blockIdx.x + threadIdx.x; int i_theta = offset_y + blockDim.y * blockIdx.y + threadIdx.y; float r = i_r - axis_position; float x = r * d_cos[i_theta]; float y = - r * d_msin[i_theta]; /*volatile*/ __shared__ float s_cos[SHARED_SIZE]; /*volatile*/ __shared__ float s_msin[SHARED_SIZE]; int next_fetch = 0; int tid = threadIdx.y * blockDim.x + threadIdx.x; float costheta, msintheta; float h1; float sum1 = 0.0f; for (int proj = 0; proj < num_projs; proj++) { if (proj == next_fetch) { // Fetch SHARED_SIZE values to shared memory __syncthreads(); if (next_fetch + tid < num_projs) { s_cos[tid] = d_cos[next_fetch + tid]; s_msin[tid] = d_msin[next_fetch + tid]; } next_fetch += SHARED_SIZE; __syncthreads(); } costheta = s_cos[proj - (next_fetch - SHARED_SIZE)]; msintheta = s_msin[proj - (next_fetch - SHARED_SIZE)]; float c1 = fmaf(costheta, x, axis_position); // cos(theta)*xr + axis_pos float s1 = fmaf(msintheta, y, 0.0f); // -sin(theta)*yr h1 = c1 + s1; if (h1 >= 0 && h1 < num_bins) sum1 += tex2D(tex_projections, h1 + 0.5f, proj + 0.5f); } // useful only if n_x < blocksize_x or n_y < blocksize_y if (i_r >= n_x) return; if (i_theta >= n_y) return; d_slice[i_theta*(n_x) + i_r] = sum1 * scale_factor; } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/cuda/src/boundary.h0000644000175000017500000000552214315516747016732 0ustar00pierrepierre#ifndef BOUNDARY_H #define BOUNDARY_H // Get the center index of the filter, // and the "half-Left" and "half-Right" lengths. // In the case of an even-sized filter, the center is shifted to the left. 
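// Example: hlen = 5 gives c = 2, hL = 2, hR = 2 ; hlen = 4 gives c = 1, hL = 1, hR = 2.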
#define GET_CENTER_HL(hlen){\ if (hlen & 1) {\ c = hlen/2;\ hL = c;\ hR = c;\ }\ else {\ c = hlen/2 - 1;\ hL = c;\ hR = c+1;\ }\ }\ // Boundary handling modes #define CONV_MODE_REFLECT 0 // cba|abcd|dcb #define CONV_MODE_NEAREST 1 // aaa|abcd|ddd #define CONV_MODE_WRAP 2 // bcd|abcd|abc #define CONV_MODE_CONSTANT 3 // 000|abcd|000 #ifndef USED_CONV_MODE #define USED_CONV_MODE CONV_MODE_NEAREST #endif #define CONV_PERIODIC_IDX_X int idx_x = gidx - c + jx; if (idx_x < 0) idx_x += Nx; if (idx_x >= Nx) idx_x -= Nx; #define CONV_PERIODIC_IDX_Y int idx_y = gidy - c + jy; if (idx_y < 0) idx_y += Ny; if (idx_y >= Ny) idx_y -= Ny; #define CONV_PERIODIC_IDX_Z int idx_z = gidz - c + jz; if (idx_z < 0) idx_z += Nz; if (idx_z >= Nz) idx_z -= Nz; // clamp not in cuda __device__ int clamp(int x, int min_, int max_) { return min(max(x, min_), max_); } #define CONV_NEAREST_IDX_X int idx_x = clamp((int) (gidx - c + jx), 0, Nx-1); #define CONV_NEAREST_IDX_Y int idx_y = clamp((int) (gidy - c + jy), 0, Ny-1); #define CONV_NEAREST_IDX_Z int idx_z = clamp((int) (gidz - c + jz), 0, Nz-1); #define CONV_REFLECT_IDX_X int idx_x = gidx - c + jx; if (idx_x < 0) idx_x = -idx_x-1; if (idx_x >= Nx) idx_x = Nx-(idx_x-(Nx-1)); #define CONV_REFLECT_IDX_Y int idx_y = gidy - c + jy; if (idx_y < 0) idx_y = -idx_y-1; if (idx_y >= Ny) idx_y = Ny-(idx_y-(Ny-1)); #define CONV_REFLECT_IDX_Z int idx_z = gidz - c + jz; if (idx_z < 0) idx_z = -idx_z-1; if (idx_z >= Nz) idx_z = Nz-(idx_z-(Nz-1)); #if USED_CONV_MODE == CONV_MODE_REFLECT #define CONV_IDX_X CONV_REFLECT_IDX_X #define CONV_IDX_Y CONV_REFLECT_IDX_Y #define CONV_IDX_Z CONV_REFLECT_IDX_Z #elif USED_CONV_MODE == CONV_MODE_NEAREST #define CONV_IDX_X CONV_NEAREST_IDX_X #define CONV_IDX_Y CONV_NEAREST_IDX_Y #define CONV_IDX_Z CONV_NEAREST_IDX_Z #elif USED_CONV_MODE == CONV_MODE_WRAP #define CONV_IDX_X CONV_PERIODIC_IDX_X #define CONV_IDX_Y CONV_PERIODIC_IDX_Y #define CONV_IDX_Z CONV_PERIODIC_IDX_Z #elif USED_CONV_MODE == CONV_MODE_CONSTANT #error "constant not implemented yet" #else #error "Unknown convolution mode" #endif // Image access patterns #define READ_IMAGE_1D_X input[(gidz*Ny + gidy)*Nx + idx_x] #define READ_IMAGE_1D_Y input[(gidz*Ny + idx_y)*Nx + gidx] #define READ_IMAGE_1D_Z input[(idx_z*Ny + gidy)*Nx + gidx] #define READ_IMAGE_2D_XY input[(gidz*Ny + idx_y)*Nx + idx_x] #define READ_IMAGE_2D_XZ input[(idx_z*Ny + gidy)*Nx + idx_x] #define READ_IMAGE_2D_YZ input[(idx_z*Ny + idx_y)*Nx + gidx] #define READ_IMAGE_3D_XYZ input[(idx_z*Ny + idx_y)*Nx + idx_x] #endif ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1723556963.0 nabu-2024.2.1/nabu/cuda/src/cone.cu0000644000175000017500000000577714656662143016227 0ustar00pierrepierre/* ----------------------------------------------------------------------- Copyright: 2010-2022, imec Vision Lab, University of Antwerp 2014-2022, CWI, Amsterdam Contact: astra@astra-toolbox.com Website: http://www.astra-toolbox.com/ This file is part of the ASTRA Toolbox. The ASTRA Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. The ASTRA Toolbox is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. 
You should have received a copy of the GNU General Public License along with the ASTRA Toolbox. If not, see . ----------------------------------------------------------------------- */ static const unsigned int g_anglesPerWeightBlock = 16; static const unsigned int g_detBlockU = 32; static const unsigned int g_detBlockV = 32; __global__ void devFDK_preweight(void* D_projData, unsigned int projPitch, unsigned int startAngle, unsigned int endAngle, float fSrcOrigin, float fDetOrigin, float fZShift, float fDetUSize, float fDetVSize, unsigned int iProjAngles, unsigned int iProjU, unsigned int iProjV) { float* projData = (float*)D_projData; int angle = startAngle + blockIdx.y * g_anglesPerWeightBlock + threadIdx.y; if (angle >= endAngle) return; const int detectorU = (blockIdx.x%((iProjU+g_detBlockU-1)/g_detBlockU)) * g_detBlockU + threadIdx.x; const int startDetectorV = (blockIdx.x/((iProjU+g_detBlockU-1)/g_detBlockU)) * g_detBlockV; int endDetectorV = startDetectorV + g_detBlockV; if (endDetectorV > iProjV) endDetectorV = iProjV; // We need the length of the central ray and the length of the ray(s) to // our detector pixel(s). const float fCentralRayLength = fSrcOrigin + fDetOrigin; const float fU = (detectorU - 0.5f*iProjU + 0.5f) * fDetUSize; const float fT = fCentralRayLength * fCentralRayLength + fU * fU; float fV = (startDetectorV - 0.5f*iProjV + 0.5f) * fDetVSize + fZShift; // Contributions to the weighting factors: // fCentralRayLength / fRayLength : the main FDK preweighting factor // fSrcOrigin / (fDetUSize * fCentralRayLength) // : to adjust the filter to the det width // pi / (2 * iProjAngles) : scaling of the integral over angles const float fW2 = fCentralRayLength / (fDetUSize * fSrcOrigin); const float fW = fCentralRayLength * fW2 * (M_PI / 2.0f) / (float)iProjAngles; for (int detectorV = startDetectorV; detectorV < endDetectorV; ++detectorV) { const float fRayLength = sqrtf(fT + fV * fV); const float fWeight = fW / fRayLength; projData[(detectorV*iProjAngles+angle)*projPitch+detectorU] *= fWeight; fV += fDetVSize; } }././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/cuda/src/convolution.cu0000644000175000017500000001627114315516747017651 0ustar00pierrepierre/* * Convolution (without textures) * Adapted from OpenCL code of the the silx project * */ #include "boundary.h" typedef unsigned int uint; /******************************************************************************/ /**************************** 1D Convolution **********************************/ /******************************************************************************/ // Convolution with 1D kernel along axis "X" (fast dimension) // Works for batched 1D on 2D and batched 2D on 3D, along axis "X". __global__ void convol_1D_X( float * input, float * output, float * filter, int L, // filter size int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz // input/output depth ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; uint gidz = blockDim.z * blockIdx.z + threadIdx.z; if ((gidx >= Nx) || (gidy >= Ny) || (gidz >= Nz)) return; int c, hL, hR; GET_CENTER_HL(L); float sum = 0.0f; for (int jx = 0; jx <= hR+hL; jx++) { CONV_IDX_X; // Get index "x" sum += READ_IMAGE_1D_X * filter[L-1 - jx]; } output[(gidz*Ny + gidy)*Nx + gidx] = sum; } // Convolution with 1D kernel along axis "Y" // Works for batched 1D on 2D and batched 2D on 3D, along axis "Y". 
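// Out-of-range indices are handled according to the boundary mode selected at compile time via USED_CONV_MODE (see boundary.h).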
__global__ void convol_1D_Y( float * input, float * output, float * filter, int L, // filter size int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz // input/output depth ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; uint gidz = blockDim.z * blockIdx.z + threadIdx.z; if ((gidx >= Nx) || (gidy >= Ny) || (gidz >= Nz)) return; int c, hL, hR; GET_CENTER_HL(L); float sum = 0.0f; for (int jy = 0; jy <= hR+hL; jy++) { CONV_IDX_Y; // Get index "y" sum += READ_IMAGE_1D_Y * filter[L-1 - jy]; } output[(gidz*Ny + gidy)*Nx + gidx] = sum; } // Convolution with 1D kernel along axis "Z" // Works for batched 1D on 2D and batched 2D on 3D, along axis "Z". __global__ void convol_1D_Z( float * input, float * output, float * filter, int L, // filter size int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz // input/output depth ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; uint gidz = blockDim.z * blockIdx.z + threadIdx.z; if ((gidx >= Nx) || (gidy >= Ny) || (gidz >= Nz)) return; int c, hL, hR; GET_CENTER_HL(L); float sum = 0.0f; for (int jz = 0; jz <= hR+hL; jz++) { CONV_IDX_Z; // Get index "z" sum += READ_IMAGE_1D_Z * filter[L-1 - jz]; } output[(gidz*Ny + gidy)*Nx + gidx] = sum; } /******************************************************************************/ /**************************** 2D Convolution **********************************/ /******************************************************************************/ // Convolution with 2D kernel // Works for batched 2D on 3D. __global__ void convol_2D_XY( float * input, float * output, float * filter, int Lx, // filter number of columns, int Ly, // filter number of rows, int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz // input/output depth ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; uint gidz = blockDim.z * blockIdx.z + threadIdx.z; if ((gidx >= Nx) || (gidy >= Ny) || (gidz >= Nz)) return; int c, hL, hR; GET_CENTER_HL(Lx); float sum = 0.0f; for (int jy = 0; jy <= hR+hL; jy++) { CONV_IDX_Y; // Get index "y" for (int jx = 0; jx <= hR+hL; jx++) { CONV_IDX_X; // Get index "x" sum += READ_IMAGE_2D_XY * filter[(Ly-1-jy)*Lx + (Lx-1 - jx)]; } } output[(gidz*Ny + gidy)*Nx + gidx] = sum; } // Convolution with 2D kernel // Works for batched 2D on 3D. __global__ void convol_2D_XZ( float * input, float * output, float * filter, int Lx, // filter number of columns, int Lz, // filter number of rows, int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz // input/output depth ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; uint gidz = blockDim.z * blockIdx.z + threadIdx.z; if ((gidx >= Nx) || (gidy >= Ny) || (gidz >= Nz)) return; int c, hL, hR; GET_CENTER_HL(Lx); float sum = 0.0f; for (int jz = 0; jz <= hR+hL; jz++) { CONV_IDX_Z; // Get index "z" for (int jx = 0; jx <= hR+hL; jx++) { CONV_IDX_X; // Get index "x" sum += READ_IMAGE_2D_XZ * filter[(Lz-1-jz)*Lx + (Lx-1 - jx)]; } } output[(gidz*Ny + gidy)*Nx + gidx] = sum; } // Convolution with 2D kernel // Works for batched 2D on 3D. 
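// As in the other kernels of this file, the filter is read in reverse order (filter[L-1 - j]), i.e. this is a true convolution rather than a correlation.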
__global__ void convol_2D_YZ( float * input, float * output, float * filter, int Ly, // filter number of columns, int Lz, // filter number of rows, int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz // input/output depth ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; uint gidz = blockDim.z * blockIdx.z + threadIdx.z; if ((gidx >= Nx) || (gidy >= Ny) || (gidz >= Nz)) return; int c, hL, hR; GET_CENTER_HL(Ly); float sum = 0.0f; for (int jz = 0; jz <= hR+hL; jz++) { CONV_IDX_Z; // Get index "z" for (int jy = 0; jy <= hR+hL; jy++) { CONV_IDX_Y; // Get index "y" sum += READ_IMAGE_2D_YZ * filter[(Lz-1-jz)*Ly + (Ly-1 - jy)]; } } output[(gidz*Ny + gidy)*Nx + gidx] = sum; } /******************************************************************************/ /**************************** 3D Convolution **********************************/ /******************************************************************************/ // Convolution with 3D kernel __global__ void convol_3D_XYZ( float * input, float * output, float * filter, int Lx, // filter number of columns, int Ly, // filter number of rows, int Lz, // filter number of rows, int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz // input/output depth ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; uint gidz = blockDim.z * blockIdx.z + threadIdx.z; if ((gidx >= Nx) || (gidy >= Ny) || (gidz >= Nz)) return; int c, hL, hR; GET_CENTER_HL(Lx); float sum = 0.0f; for (int jz = 0; jz <= hR+hL; jz++) { CONV_IDX_Z; // Get index "z" for (int jy = 0; jy <= hR+hL; jy++) { CONV_IDX_Y; // Get index "y" for (int jx = 0; jx <= hR+hL; jx++) { CONV_IDX_X; // Get index "x" sum += READ_IMAGE_3D_XYZ * filter[((Lz-1-jz)*Ly + (Ly-1-jy))*Lx + (Lx-1 - jx)]; } } } output[(gidz*Ny + gidy)*Nx + gidx] = sum; } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/cuda/src/dfi_fftshift.cu0000644000175000017500000000425414550227307017717 0ustar00pierrepierre#include #define BLOCK_SIZE 16 __global__ void dfi_cuda_swap_quadrants_complex(cufftComplex *input, cufftComplex *output, int dim_x) { int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; int idy = blockIdx.y * BLOCK_SIZE + threadIdx.y; const int dim_y = gridDim.y * blockDim.y; //a half of real length output[idy * dim_x + idx] = input[(dim_y + idy) * dim_x + idx + 1]; output[(dim_y + idy) * dim_x + idx] = input[idy * dim_x + idx + 1]; } __global__ void dfi_cuda_swap_quadrants_real(cufftReal *output) { int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; int idy = blockIdx.y * BLOCK_SIZE + threadIdx.y; const int dim_x = gridDim.x * blockDim.x; int dim_x2 = dim_x/2, dim_y2 = dim_x2; long sw_idx1, sw_idx2; sw_idx1 = idy * dim_x + idx; cufftReal temp = output[sw_idx1]; if (idx < dim_x2) { sw_idx2 = (dim_y2 + idy) * dim_x + (dim_x2 + idx); output[sw_idx1] = output[sw_idx2]; output[sw_idx2] = temp; } else { sw_idx2 = (dim_y2 + idy) * dim_x + (idx - dim_x2); output[sw_idx1] = output[sw_idx2]; output[sw_idx2] = temp; } } __global__ void swap_full_quadrants_complex(cufftComplex *output) { int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; int idy = blockIdx.y * BLOCK_SIZE + threadIdx.y; const int dim_x = gridDim.x * blockDim.x; int dim_x2 = dim_x/2, dim_y2 = dim_x2; long sw_idx1, sw_idx2; sw_idx1 = idy * dim_x + idx; cufftComplex temp = output[sw_idx1]; if (idx < dim_x2) { sw_idx2 = (dim_y2 + idy) * dim_x + 
(dim_x2 + idx); output[sw_idx1] = output[sw_idx2]; output[sw_idx2] = temp; } else { sw_idx2 = (dim_y2 + idy) * dim_x + (idx - dim_x2); output[sw_idx1] = output[sw_idx2]; output[sw_idx2] = temp; } } __global__ void dfi_cuda_crop_roi(cufftReal *input, int x, int y, int roi_x, int roi_y, int raster_size, float scale, cufftReal *output) { int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; int idy = blockIdx.y * BLOCK_SIZE + threadIdx.y; if (idx < roi_x && idy < roi_y) { output[idy * roi_x + idx] = input[(idy + y) * raster_size + (idx + x)] * scale; } } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/cuda/src/flatfield.cu0000644000175000017500000000401214402565210017174 0ustar00pierrepierre#ifndef N_FLATS #error "Please provide the N_FLATS variable" #endif #ifndef N_DARKS #error "Please provide the N_FLATS variable" #endif /** * In-place flat-field normalization with linear interpolation. * This kernel assumes that all the radios are loaded into memory * (although not necessarily the full radios images) * and in radios[x, y z], z in the radio index * * radios: 3D array * flats: 3D array * darks: 3D array * Nx: number of pixel horizontally in the radios * Nx: number of pixel vertically in the radios * Nx: number of radios * flats_indices: indices of flats to fetch for each radio * flats_weights: weights of flats for each radio * darks_indices: indices of darks, in sorted order **/ __global__ void flatfield_normalization( float* radios, float* flats, float* darks, int Nx, int Ny, int Nz, int* flats_indices, float* flats_weights ) { size_t x = blockDim.x * blockIdx.x + threadIdx.x; size_t y = blockDim.y * blockIdx.y + threadIdx.y; size_t z = blockDim.z * blockIdx.z + threadIdx.z; if ((x >= Nx) || (y >= Ny) || (z >= Nz)) return; size_t pos = (z*Ny+y)*Nx + x; float dark_val = 0.0f, flat_val = 1.0f; #if N_FLATS == 1 flat_val = flats[y*Nx + x]; #else int prev_idx = flats_indices[z*2 + 0]; int next_idx = flats_indices[z*2 + 1]; float w1 = flats_weights[z*2 + 0]; float w2 = flats_weights[z*2 + 1]; if (next_idx == -1) { flat_val = flats[(prev_idx*Ny+y)*Nx + x]; } else { flat_val = w1 * flats[(prev_idx*Ny+y)*Nx + x] + w2 * flats[(next_idx*Ny+y)*Nx + x]; } #endif #if (N_DARKS == 1) dark_val = darks[y*Nx + x]; #else // TODO interpolate between darks // Same as above... #error "N_DARKS > 1 is not supported yet" #endif float val = (radios[pos] - dark_val) / (flat_val - dark_val); #ifdef NAN_VALUE if (flat_val == dark_val) val = NAN_VALUE; #endif radios[pos] = val; } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/cuda/src/fourier_wavelets.cu0000644000175000017500000000104314550227307020636 0ustar00pierrepierre/** Damping kernel used in the Fourier-Wavelets sinogram destriping method. */ __global__ void kern_fourierwavelets(float2* sinoF, int Nx, int Ny, float wsigma) { int gidx = threadIdx.x + blockIdx.x*blockDim.x; int gidy = threadIdx.y + blockIdx.y*blockDim.y; int Nfft = Ny/2+1; if (gidx >= Nx || gidy >= Nfft) return; float m = gidy/wsigma; float factor = 1.0f - expf(-(m * m)/2); int tid = gidy*Nx + gidx; // do not forget the scale factor (here Ny) sinoF[tid].x *= factor; sinoF[tid].y *= factor; } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/cuda/src/halftomo.cu0000644000175000017500000000552614550227307017074 0ustar00pierrepierre/* Perform a "half tomography" sinogram conversion. 
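(In this acquisition scheme, the rotation axis is placed near one edge of the detector so as to almost double the reconstructed field of view.)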
A 360 degrees sinogram is converted to a 180 degrees sinogram with a field of view extended (at most) twice". * Parameters: * sinogram: the 360 degrees sinogram, shape (n_angles, n_x) * output: the 160 degrees sinogram, shape (n_angles/2, rotation_axis_position * 2) * weights: an array of weight, size n_x - rotation_axis_position */ __global__ void halftomo_kernel( float* sinogram, float* output, float* weights, int n_angles, int n_x, int rotation_axis_position ) { int x = blockDim.x * blockIdx.x + threadIdx.x; int y = blockDim.y * blockIdx.y + threadIdx.y; int n_a2 = (n_angles + 1) / 2; int d = n_x - rotation_axis_position; int n_x2 = 2 * rotation_axis_position; int r = rotation_axis_position; if ((x >= n_x2) || (y >= n_a2)) return; // output[:, :r - d] = sino[:n_a2, :r - d] if (x < r - d) { output[y * n_x2 + x] = sinogram[y * n_x + x]; } // output[:, r-d:r+d] = (1 - weights) * sino[:n_a2, r-d:] else if (x < r+d) { float w = weights[x - (r - d)]; output[y * n_x2 + x] = (1.0f - w) * sinogram[y*n_x + x] \ + w * sinogram[(n_a2 + y)*n_x + (n_x2 - 1 - x)]; } // output[:, nx:] = sino[n_a2:, ::-1][:, 2 * d :] = sino[n_a2:, -2*d-1:-n_x-1:-1] else { output[y * n_x2 + x] = sinogram[(n_a2 + y)*n_x + (n_x2 - 1 - x)]; } } /* Multiply in-place a 360 degrees sinogram with weights. This kernel is used to prepare a sinogram to be backprojected using half-tomography geometry. One of the sides (left or right) is multiplied with weights. For example, if "r" is the center of rotation near the right side: sinogram[:, -overlap_width:] *= weights where overlap_width = 2*(n_x - 1 - r) This can still be improved when the geometry has horizontal translations. In this case, we should have "start_x" and "end_x" as arrays of size n_angles, i.e one varying (start_x, end_x) per angle. Parameters ----------- * sinogram: array of size (n_angles, n_x): 360 degrees sinogram * weights: array of size (n_angles,): weights to apply on one side of the sinogram * n_angles: int: number of angles * n_x: int: horizontal size (number of pixels) of the sinogram * start_x: int: start x-position for applying the weights * end_x: int: end x-position for applying the weights (included!) 
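*
* Example (illustration only): with n_x = 2048 and a center of rotation at r = 2000,
* the overlap width is 2*(n_x - 1 - r) = 94, hence start_x = 1954 and end_x = 2047.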
*/ __global__ void halftomo_prepare_sinogram( float* sinogram, float* weights, int n_angles, int n_x, int start_x, int end_x ) { size_t x = blockDim.x * blockIdx.x + threadIdx.x; size_t i_angle = blockDim.y * blockIdx.y + threadIdx.y; if (x < start_x || x > end_x || i_angle >= n_angles) return; sinogram[i_angle * n_x + x] *= weights[x - start_x]; } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682665866.0 nabu-2024.2.1/nabu/cuda/src/helical_padding.cu0000644000175000017500000000756314422670612020354 0ustar00pierrepierre // see nabu/pipeline/helical/filtering.py for details __device__ float adjustment_by_integration( int my_rot, int two_rots , float *data, int y, int y_mirror, int Nx_padded, int integration_radius, float my_rot_float, float two_rots_float) { float sigma = integration_radius/3.0f; float sum_a=0.0; float sum_w_a = 0.0; float sum_w_b = 0.0; float sum_b=0.0; for(int my_ix = my_rot - integration_radius; my_ix <= my_rot + integration_radius ; my_ix++) { float d = (my_ix - my_rot_float) ; float w_a = exp( - ( d*d )/sigma/sigma/2.0f) ; sum_a += data[ y*Nx_padded + my_ix ] * w_a; sum_w_a+= w_a; int x_mirror = two_rots - my_ix ; d = (x_mirror - (two_rots_float -my_rot_float)) ; float w_b = exp( - ( d*d )/sigma/sigma/2.0f); sum_b += data[ y_mirror*Nx_padded + x_mirror ] * w_b; sum_w_b += w_b; } float adjustment = (sum_b/sum_w_b - sum_a/sum_w_b) ; return adjustment ; } __global__ void padding( float* data, int* mirror_indexes, #if defined(MIRROR_CONSTANT_VARIABLE_ROT_POS) || defined(MIRROR_EDGES_VARIABLE_ROT_POS) float *rot_axis_pos, #else float rot_axis_pos, #endif int Nx, int Ny, int Nx_padded, int pad_left_len, int pad_right_len #if defined(MIRROR_CONSTANT) || defined(MIRROR_CONSTANT_VARIABLE_ROT_POS) ,float pad_left_val, float pad_right_val #endif ) { int x = blockDim.x * blockIdx.x + threadIdx.x; int y = blockDim.y * blockIdx.y + threadIdx.y; if ((x >= Nx_padded) || (y >= Ny) || x < Nx) return; int idx = y*Nx_padded + x; int y_mirror = mirror_indexes[y]; int x_mirror =0 ; #if defined(MIRROR_CONSTANT_VARIABLE_ROT_POS) || defined(MIRROR_EDGES_VARIABLE_ROT_POS) float two_rots = rot_axis_pos[y] + rot_axis_pos[y_mirror]; float my_rot = rot_axis_pos[y]; #else float two_rots = 2*rot_axis_pos ; float my_rot = rot_axis_pos; #endif int two_rots_int = __float2int_rn(two_rots) ; int my_rot_int = __float2int_rn(my_rot); if( two_rots_int > Nx) { int integration_radius = min( 30, Nx-1 - max(my_rot_int, two_rots_int - my_rot_int ) ) ; x_mirror = two_rots_int - x ; if (x_mirror < 0 ) { #if defined(MIRROR_CONSTANT) || defined(MIRROR_CONSTANT_VARIABLE_ROT_POS) if( x < Nx_padded - pad_left_len) { data[idx] = pad_left_val; } else { data[idx] = pad_right_val; } #else if( x < Nx_padded - pad_left_len) { float adjustment = adjustment_by_integration( my_rot_int, two_rots_int , data, y, y_mirror, Nx_padded, integration_radius, my_rot, two_rots); data[idx] = data[y_mirror*Nx_padded + 0] - adjustment; } else { data[idx] = data[y*Nx_padded + 0]; } #endif } else { float adjustment = adjustment_by_integration( my_rot_int, two_rots_int , data, y, y_mirror, Nx_padded, integration_radius, my_rot, two_rots); data[idx] = data[y_mirror*Nx_padded + x_mirror]-adjustment; } } else { int integration_radius = min( 30, min(my_rot_int, two_rots_int - my_rot_int) -1) ; x_mirror = two_rots_int - (x - Nx_padded) ; if (x_mirror > Nx-1 ) { #if defined(MIRROR_CONSTANT) || defined(MIRROR_CONSTANT_VARIABLE_ROT_POS) if( x < Nx_padded - pad_left_len) { data[idx] = pad_left_val ; } else { data[idx] 
= pad_right_val; } #else if( x < Nx_padded - pad_left_len) { data[idx] = data[y*Nx_padded + Nx - 1 ]; } else { float adjustment = adjustment_by_integration( my_rot_int, two_rots_int , data, y, y_mirror, Nx_padded, integration_radius, my_rot, two_rots); data[idx] = data[y_mirror*Nx_padded + Nx-1]-adjustment; } #endif } else { float adjustment = adjustment_by_integration( my_rot_int, two_rots_int, data, y, y_mirror, Nx_padded, integration_radius, my_rot, two_rots); data[idx] = data[y_mirror*Nx_padded + x_mirror] - adjustment; } } return; } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731681010.0 nabu-2024.2.1/nabu/cuda/src/hierarchical_backproj.cu0000644000175000017500000002120114715655362021551 0ustar00pierrepierre/* """ Algorithm by Jonas Graetz. Submitted for publication. Please cite : reference to be added... """ */ __device__ float3 operator-(const float3& a, const float3& b) { return make_float3(a.x - b.x, a.y - b.y, a.z - b.z); } __device__ float dot(const float3& a, const float3& b) { return a.x * b.x + a.y * b.y + a.z * b.z; } __device__ float dot4(const float4& a, const float4& b) { return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w; } inline __device__ int is_in_circle(float x, float y, float center_x, float center_y, int radius2) { return (((x - center_x) * (x - center_x) + (y - center_y) * (y - center_y)) <= radius2); } __global__ void clip_outer_circle(float* slice, int ny, int nx) { const int tiy = threadIdx.y; const int bidy = blockIdx.y; int iy = (bidy * blockDim.y + tiy); const int tix = threadIdx.x; const int bidx = blockIdx.x; int ix = (bidx * blockDim.x + tix); float center_x = (nx - 1) / 2.0f, center_y = (ny - 1) / 2.0f; int radius2 = min(nx / 2, ny / 2); radius2 *= radius2; if (ix < nx && iy < ny) { if (!is_in_circle(ix, iy, center_x, center_y, radius2)) { slice[iy * nx + ix] = 0.0f; } } } __device__ float bilinear(float* data, int width, int height, float x, float y) { int ix0 = (int)floorf(x); int iy0 = (int)floorf(y); float fx = x - ix0; float fy = y - iy0; int ix1 = ix0 + 1; int iy1 = iy0 + 1; ix0 = min(width - 1, max(0, ix0)); ix1 = min(width - 1, max(0, ix1)); iy0 = min(height - 1, max(0, iy0)); iy1 = min(height - 1, max(0, iy1)); float v00 = data[iy0 * width + ix0]; float v01 = data[iy0 * width + ix1]; float v10 = data[iy1 * width + ix0]; float v11 = data[iy1 * width + ix1]; return (v00 * (1 - fx) + v01 * fx) * (1 - fy) + (v10 * (1 - fx) + v11 * fx) * fy; } __global__ void backprojector(float* bpsetups, float* gridTransforms, int reductionFactor, int grid_width, int grid_height, int ngrids, float* grids, int sino_width, int sino_nangles, float scale_factor, float* sinogram, int projectionOffset) { const int tix = threadIdx.x; const int bidx = blockIdx.x; int ix = (bidx * blockDim.x + tix); const int tiy = threadIdx.y; const int bidy = blockIdx.y; int iy = (bidy * blockDim.y + tiy); const int tiz = threadIdx.z; const int bidz = blockIdx.z; int iz = (bidz * blockDim.z + tiz); const int grid_px = ix; const int grid_py = iy; const int grid_i = iz; size_t grid_pos = (grid_i * ((size_t)grid_height) + grid_px) * ((size_t)grid_width) + grid_py; // if( grid_pos==0) grids[grid_pos] = grid_height; // if( grid_pos==1) grids[grid_pos] = grid_width; // if( grid_pos==2) grids[grid_pos] = ngrids; // if( grid_pos==3) printf(" CU %d sino_nangles %d sino_width %d\n", // grid_height* grid_width* ngrids , sino_nangles, sino_width ) ; if ((grid_px < grid_height) && (grid_py < grid_width) && (grid_i < ngrids)) { const float3 grid_t1 = 
make_float3(gridTransforms[grid_i * 6 + 0], gridTransforms[grid_i * 6 + 1], gridTransforms[grid_i * 6 + 2]); const float3 grid_t2 = make_float3(gridTransforms[grid_i * 6 + 3 + 0], gridTransforms[grid_i * 6 + 3 + 1], gridTransforms[grid_i * 6 + 3 + 2]); const float4 final_p = make_float4(0.f, grid_t1.x * grid_px + grid_t1.y * grid_py + grid_t1.z, grid_t2.x * grid_px + grid_t2.y * grid_py + grid_t2.z, 1.f); float val = 0.f; int setup_i = 0; for (int k = 0; k < reductionFactor; k++) { setup_i = grid_i * reductionFactor + k + projectionOffset; if (setup_i < sino_nangles) // although the sinogram itself could be // read beyond extent, this is not true // for the setups-array! { int bi = setup_i * 4 * 3; const float4 ph = make_float4(bpsetups[bi + 0], bpsetups[bi + 1], bpsetups[bi + 2], bpsetups[bi + 3]); bi += 8; const float4 pw = make_float4(bpsetups[bi + 0], bpsetups[bi + 1], bpsetups[bi + 2], bpsetups[bi + 3]); const float n = 1.f / dot4(final_p, pw); const float h = dot4(final_p, ph) * n; int ih0 = (int)floorf(h); int ih1 = ih0 + 1; float fh = h - ih0; size_t sino_pos = setup_i * ((size_t)sino_width) + ih0; if (ih0 >= 0 && ih0 < sino_width) { // if(sino_pos>= sino_width*sino_nangles) printf(" problema // 1\n"); val += sinogram[sino_pos] * (1 - fh); } if (ih1 >= 0 && ih1 < sino_width) { // if(sino_pos+1>= sino_width*sino_nangles) printf(" problema // 2 h ih0, ih1, sino_width , sino_nangles, setup_i %e %e %e // %d %d %d %d %d\n", ray.x, ray.y, ray.z, ih0, ih1, // sino_width , sino_nangles, setup_i); val += sinogram[sino_pos + 1] * fh; } } } size_t grid_pos = (grid_i * ((size_t)grid_height) + grid_px) * ((size_t)grid_width) + grid_py; grids[grid_pos] = scale_factor * val; } } __global__ void aggregator(int do_sum, const float* newGridTransforms, const float* prevGridInverseTransforms, const int reductionFactor, int new_grid_width, int new_grid_height, int new_ngrids, float* newGrids, int prev_grid_width, int prev_grid_height, int prev_ngrids, float* prevGrids) { const int tix = threadIdx.x; const int bidx = blockIdx.x; int ix = (bidx * blockDim.x + tix); const int tiy = threadIdx.y; const int bidy = blockIdx.y; int iy = (bidy * blockDim.y + tiy); const int tiz = threadIdx.z; const int bidz = blockIdx.z; int iz = (bidz * blockDim.z + tiz); const int new_grid_px = ix; const int new_grid_py = iy; const int new_grid_i = iz; if ((new_grid_px < new_grid_height) && (new_grid_py < new_grid_width) && (new_grid_i < new_ngrids)) { const float3 new_grid_t1 = make_float3(newGridTransforms[new_grid_i * 6 + 0], newGridTransforms[new_grid_i * 6 + 1], newGridTransforms[new_grid_i * 6 + 2]); const float3 new_grid_t2 = make_float3(newGridTransforms[new_grid_i * 6 + 3 + 0], newGridTransforms[new_grid_i * 6 + 3 + 1], newGridTransforms[new_grid_i * 6 + 3 + 2]); const float3 final_p = make_float3( new_grid_t1.x * new_grid_px + new_grid_t1.y * new_grid_py + new_grid_t1.z, new_grid_t2.x * new_grid_px + new_grid_t2.y * new_grid_py + new_grid_t2.z, 1.f); if (isnan(new_grid_t1.x)) { return; // inband-signaling for unused grids that shall be skipped } float val = 0.f; int prev_grid_i; float3 prev_grid_ti1, prev_grid_ti2; float3 prev_p_tex; for (int k = 0; k < reductionFactor; k++) { prev_grid_i = new_grid_i * reductionFactor + k; if (prev_grid_i < prev_ngrids) { prev_grid_ti1 = make_float3(prevGridInverseTransforms[prev_grid_i * 6 + 0], prevGridInverseTransforms[prev_grid_i * 6 + 1], prevGridInverseTransforms[prev_grid_i * 6 + 2]); prev_grid_ti2 = make_float3(prevGridInverseTransforms[prev_grid_i * 6 + 3 + 0], 
prevGridInverseTransforms[prev_grid_i * 6 + 3 + 1], prevGridInverseTransforms[prev_grid_i * 6 + 3 + 2]); if (isnan(prev_grid_ti1.x)) { break; } prev_p_tex = make_float3(dot(prev_grid_ti2, final_p), dot(prev_grid_ti1, final_p), (float)prev_grid_i); val += bilinear(prevGrids + prev_grid_i * ((size_t)prev_grid_height) * ((size_t)prev_grid_width), prev_grid_width, prev_grid_height, prev_p_tex.x, prev_p_tex.y); } } size_t new_grid_pos = (new_grid_i * ((size_t)new_grid_height) + new_grid_px) * ((size_t)new_grid_width) + new_grid_py; if (do_sum == 1) { newGrids[new_grid_pos] += val; } else { newGrids[new_grid_pos] = val; } } } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/cuda/src/histogram.cu0000644000175000017500000000145014315516747017260 0ustar00pierrepierretypedef unsigned int uint; __global__ void histogram( float * array, int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz, // input/output depth float arr_min, // array minimum value float arr_max, // array maximum value uint* hist, // histogram int nbins // histogram size (number of bins) ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; uint gidz = blockDim.z * blockIdx.z + threadIdx.z; if ((gidx >= Nx) || (gidy >= Ny) || (gidz >= Nz)) return; float val = array[(gidz*Ny + gidy)*Nx + gidx]; float bin_pos = nbins * ((val - arr_min) / (arr_max - arr_min)); uint bin_left = min((uint) bin_pos, nbins-1); atomicAdd(hist + bin_left, 1); } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/cuda/src/interpolation.cu0000644000175000017500000000150714315516747020155 0ustar00pierrepierretypedef unsigned int uint; // linear interpolation along "axis 0", where values outside of bounds are extrapolated __global__ void linear_interp_vertical(float* arr2D, float* out, int Nx, int Ny, float* x, float* x_new) { uint c = blockDim.x * blockIdx.x + threadIdx.x; uint i = blockDim.y * blockIdx.y + threadIdx.y; if ((c >= Nx) || (i >= Ny)) return; int extrapolate_side = x_new[0] > x[0] ? 1 : 0; // 0: left, 1: right float dx, dy; if (i == 0) extrapolate_side = 1; else if (i == Ny - 1) extrapolate_side = 0; int extrapolate_side_compl = 1 - extrapolate_side; dx = x[i+extrapolate_side] - x[i - extrapolate_side_compl]; dy = arr2D[(i + extrapolate_side) * Nx + c] - arr2D[(i - extrapolate_side_compl) * Nx + c]; out[i * Nx + c] = (dy / dx) * (x_new[i] - x[i]) + arr2D[i * Nx + c]; } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/cuda/src/medfilt.cu0000644000175000017500000000462114315516747016712 0ustar00pierrepierre#include "boundary.h" typedef unsigned int uint; #ifndef MEDFILT_X #define MEDFILT_X 3 #endif #ifndef MEDFILT_Y #define MEDFILT_Y 3 #endif #ifndef DO_THRESHOLD #define DO_THRESHOLD 0 #endif // General-purpose 2D (or batched 2D) median filter with a square footprint. 
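// The footprint size is MEDFILT_X x MEDFILT_Y (3x3 by default, overridable at compile time).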
// Boundary handling is customized via the USED_CONV_MODE macro (see boundary.h) // Most of the time is spent computing the median, so this kernel can be sped up by // - creating dedicated kernels for 3x3, 5x5 (see http://ndevilla.free.fr/median/median/src/optmed.c) // - Using a quickselect algorithm instead of sorting (see http://ndevilla.free.fr/median/median/src/quickselect.c) __global__ void medfilt2d( float * input, float * output, int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz, // input/output depth float threshold // threshold for thresholded median filter ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; uint gidz = blockDim.z * blockIdx.z + threadIdx.z; if ((gidx >= Nx) || (gidy >= Ny) || (gidz >= Nz)) return; int c, hL, hR; GET_CENTER_HL(MEDFILT_X); // Get elements in a 3x3 neighborhood float elements[MEDFILT_X*MEDFILT_Y] = {0}; for (int jy = 0; jy <= hR+hL; jy++) { CONV_IDX_Y; // Get index "y" for (int jx = 0; jx <= hR+hL; jx++) { CONV_IDX_X; // Get index "x" elements[jy*MEDFILT_Y+jx] = READ_IMAGE_2D_XY; } } // Sort the elements with insertion sort // TODO quickselect ? int i = 1, j; while (i < MEDFILT_X*MEDFILT_Y) { j = i; while (j > 0 && elements[j-1] > elements[j]) { float tmp = elements[j]; elements[j] = elements[j-1]; elements[j-1] = tmp; j--; } i++; } float median = elements[MEDFILT_X*MEDFILT_Y/2]; #if DO_THRESHOLD == 1 float out_val = 0.0f; uint idx = (gidz*Ny + gidy)*Nx + gidx; if (input[idx] >= median + threshold) out_val = median; else out_val = input[idx]; output[idx] = out_val; #elif DO_THRESHOLD == 2 float out_val = 0.0f; uint idx = (gidz*Ny + gidy)*Nx + gidx; if (fabsf(input[idx] - median) > threshold) out_val = median; else out_val = input[idx]; output[idx] = out_val; #else output[(gidz*Ny + gidy)*Nx + gidx] = median; #endif } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/cuda/src/normalization.cu0000644000175000017500000000375414550227307020152 0ustar00pierrepierretypedef unsigned int uint; /** * Chebyshev background removal. * This kernel does a degree 2 polynomial estimation of each line of an array, * and then subtracts the estimation from each line. * This process is done in-place. 
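* The fit uses the basis {1, x, 3*x^2 - 1} (orthogonal on (-1, 1)), with x rescaled to (-1, 1) along each row.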
*/ __global__ void normalize_chebyshev( float * array, int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz // input/output depth ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; uint gidz = blockDim.z * blockIdx.z + threadIdx.z; if ((gidx >= 1) || (gidy >= Ny) || (gidz >= Nz)) return; float ff0=0.0f, ff1=0.0f, ff2=0.0f; float sum0=0.0f, sum1=0.0f, sum2=0.0f; float f0, f1, f2, x; for (int j=0; j < Nx; j++) { uint pos = (gidz*Ny + gidy)*Nx + j; float arr_val = array[pos]; x = 2.0f*(j + 0.5f - Nx/2.0f)/Nx; f0 = 1.0f; f1 = x; f2 = (3.0f*x*x-1.0f); ff0 = ff0 + f0 * arr_val; ff1 = ff1 + f1 * arr_val; ff2 = ff2 + f2 * arr_val; sum0 += f0 * f0; sum1 += f1 * f1; sum2 += f2 * f2; } for (int j=0; j< Nx; j++) { uint pos = (gidz*Ny + gidy)*Nx + j; x = 2.0f*(j+0.5f-Nx/2.0f)/Nx; f0 = 1.0f; f1 = x; f2 = (3.0f*x*x-1.0f); array[pos] -= ff0*f0/sum0 + ff1*f1/sum1 + ff2*f2/sum2; } } // launched with (Nx, 1, Nz) threads __global__ void vertical_mean( float * array, float* output, int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz // input/output depth ) { uint x = blockDim.x * blockIdx.x + threadIdx.x; uint y = blockDim.y * blockIdx.y + threadIdx.y; uint z = blockDim.z * blockIdx.z + threadIdx.z; if ((x >= Nx) || (y >= 1) || (z >= Nz)) return; float m = 0.0f; for (uint i = 0; i < Ny; i++) { float s = array[(z * Ny + i) * Nx + x]; m += (s - m)/(i+1); } output[z * Nx + x] = (float) m; } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/cuda/src/padding.cu0000644000175000017500000001050414550227307016661 0ustar00pierrepierre#include typedef pycuda::complex complex; typedef unsigned int uint; /** This function padds in-place a 2D array with constant values. It is designed to leave the data in the "FFT layout", i.e the data is *not* in the center of the extended/padded data. In one dimension: <--------------- N0 ----------------> | original data | padded values | <----- N -------- ><---- Pl+Pr -----> N0: width of data Pl, Pr: left/right padding lengths ASSUMPTIONS: - data is already extended before padding (its size is Nx_padded * Ny_padded) - the original data lies in the top-left quadrant. 
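The kernel is expected to be launched with (at least) one thread per element of the padded array,
i.e. a grid covering Nx_padded x Ny_padded.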
**/ __global__ void padding_constant( float* data, int Nx, int Ny, int Nx_padded, int Ny_padded, int pad_left_len, int pad_right_len, int pad_top_len, int pad_bottom_len, float pad_left_val, float pad_right_val, float pad_top_val, float pad_bottom_val ) { int x = blockDim.x * blockIdx.x + threadIdx.x; int y = blockDim.y * blockIdx.y + threadIdx.y; if ((x >= Nx_padded) || (y >= Ny_padded)) return; int idx = y*Nx_padded + x; // data[s0:s0+Pd, :s1] = pad_bottom_val if ((Ny <= y) && (y < Ny+pad_bottom_len) && (x < Nx)) data[idx] = pad_bottom_val; // data[s0+Pd:s0+Pd+Pu, :s1] = pad_top_val else if ((Ny + pad_bottom_len <= y) && (y < Ny+pad_bottom_len+pad_top_len) && (x < Nx)) data[idx] = pad_top_val; // data[:, s1:s1+Pr] = pad_right_val else if ((Nx <= x) && (x < Nx+pad_right_len)) data[idx] = pad_right_val; // data[:, s1+Pr:s1+Pr+Pl] = pad_left_val else if ((Nx+pad_right_len <= x) && (x < Nx+pad_right_len+pad_left_len)) data[idx] = pad_left_val; // top-left quadrant else return; } __global__ void padding_edge( float* data, int Nx, int Ny, int Nx_padded, int Ny_padded, int pad_left_len, int pad_right_len, int pad_top_len, int pad_bottom_len ) { int x = blockDim.x * blockIdx.x + threadIdx.x; int y = blockDim.y * blockIdx.y + threadIdx.y; if ((x >= Nx_padded) || (y >= Ny_padded)) return; int idx = y*Nx_padded + x; // // This kernel can be optimized: // - Optimize the logic to use less comparisons // - Store the values data[0], data[s0-1, 0], data[0, s1-1], data[s0-1, s1-1] // into shared memory to read only once from global mem. // // data[s0:s0+Pd, :s1] = data[s0, :s1] if ((Ny <= y) && (y < Ny+pad_bottom_len) && (x < Nx)) data[idx] = data[(Ny-1)*Nx_padded+x]; // data[s0+Pd:s0+Pd+Pu, :s1] = data[0, :s1] else if ((Ny + pad_bottom_len <= y) && (y < Ny+pad_bottom_len+pad_top_len) && (x < Nx)) data[idx] = data[x]; // data[:s0, s1:s1+Pr] = data[:s0, s1] else if ((y < Ny) && (Nx <= x) && (x < Nx+pad_right_len)) data[idx] = data[y*Nx_padded + Nx-1]; // data[:s0, s1+Pr:s1+Pr+Pl] = data[:s0, 0] else if ((y < Ny) && (Nx+pad_right_len <= x) && (x < Nx+pad_right_len+pad_left_len)) data[idx] = data[y*Nx_padded]; // data[s0:s0+Pb, s1:s1+Pr] = data[s0-1, s1-1] else if ((Ny <= y && y < Ny + pad_bottom_len) && (Nx <= x && x < Nx + pad_right_len)) data[idx] = data[(Ny-1)*Nx_padded + Nx-1]; // data[s0:s0+Pb, s1+Pr:s1+Pr+Pl] = data[s0-1, 0] else if ((Ny <= y && y < Ny + pad_bottom_len) && (Nx+pad_right_len <= x && x < Nx + pad_right_len+pad_left_len)) data[idx] = data[(Ny-1)*Nx_padded]; // data[s0+Pb:s0+Pb+Pu, s1:s1+Pr] = data[0, s1-1] else if ((Ny+pad_bottom_len <= y && y < Ny + pad_bottom_len+pad_top_len) && (Nx <= x && x < Nx + pad_right_len)) data[idx] = data[Nx-1]; // data[s0+Pb:s0+Pb+Pu, s1+Pr:s1+Pr+Pl] = data[0, 0] else if ((Ny+pad_bottom_len <= y && y < Ny + pad_bottom_len+pad_top_len) && (Nx+pad_right_len <= x && x < Nx + pad_right_len+pad_left_len)) data[idx] = data[0]; // top-left quadrant else return; } __global__ void coordinate_transform( float* array_in, float* array_out, int* cols_inds, int* rows_inds, int Nx, int Nx_padded, int Ny_padded ) { uint x = blockDim.x * blockIdx.x + threadIdx.x; uint y = blockDim.y * blockIdx.y + threadIdx.y; if ((x >= Nx_padded) || (y >= Ny_padded)) return; uint idx = y*Nx_padded + x; int x2 = cols_inds[x]; int y2 = rows_inds[y]; array_out[idx] = array_in[y2*Nx + x2]; } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682665866.0 nabu-2024.2.1/nabu/cuda/src/proj.cu0000644000175000017500000000724614422670612016235 
0ustar00pierrepierretypedef unsigned int uint; #define M_PI_F 3.141592653589793 texture texSlice; __global__ void joseph_projector( float *d_Sino, int dimslice, int num_bins, float* angles_per_project, float axis_position, float* d_axis_corrections, int* d_beginPos, int* d_strideJoseph, int* d_strideLine, int num_projections, int dimrecx, int dimrecy, float offset_x, int josephnoclip, int normalize ) { uint tidx = threadIdx.x; uint bidx = blockIdx.x; uint tidy = threadIdx.y; uint bidy = blockIdx.y; float angle; float cos_angle, sin_angle ; __shared__ float corrections[16]; __shared__ int beginPos[16*2]; __shared__ int strideJoseph[16*2]; __shared__ int strideLine[16*2]; // thread will use corrections[tidy] // All are read by first warp int offset, OFFSET; switch(tidy) { case 0: corrections[tidx] = d_axis_corrections[bidy*16+tidx]; break; case 1: case 2: offset = 16 * (tidy - 1); OFFSET = dimrecy * (tidy - 1); beginPos[offset + tidx] = d_beginPos[OFFSET+ bidy*16 + tidx]; break; case 3: case 4: offset = 16 * (tidy - 3); OFFSET = dimrecy*(tidy - 3); strideJoseph[offset + tidx] = d_strideJoseph[OFFSET + bidy*16 + tidx]; break; case 5: case 6: offset = 16*(tidy-5); OFFSET = dimrecy*(tidy-5); strideLine[offset + tidx] = d_strideLine[OFFSET + bidy*16 + tidx]; break; } __syncthreads(); angle = angles_per_project[bidy*16+tidy]; cos_angle = cos(angle); sin_angle = sin(angle); if (fabs(cos_angle) > 0.70710678f) { if(cos_angle > 0) { cos_angle = cos(angle); sin_angle = sin(angle); } else { cos_angle = -cos(angle); sin_angle = -sin(angle); } } else { if (sin_angle > 0) { cos_angle = sin(angle); sin_angle = -cos(angle); } else { cos_angle = -sin(angle); sin_angle = cos(angle); } } float res=0.0f; float axis_corr = axis_position + corrections[tidy]; float axis = axis_position; float xpix = (bidx*16 + tidx) - offset_x; float posx = axis * (1.0f - sin_angle/cos_angle) + (xpix - axis_corr)/cos_angle; float shiftJ = sin_angle/cos_angle; float x1 = fminf(-sin_angle/cos_angle, 0.f); float x2 = fmaxf(-sin_angle/cos_angle, 0.f); float Area; Area = 1.0f/cos_angle; int stlA, stlB, stlAJ, stlBJ; stlA = strideLine[16 + tidy]; stlB = strideLine[tidy]; stlAJ = strideJoseph[16 + tidy]; stlBJ = strideJoseph[tidy]; int beginA = beginPos[16 + tidy]; int beginB = beginPos[tidy]; float add; int l; if(josephnoclip) { for(int j=0; j= 0.0f) * (x1 < (dimslice + 2)) * (x2 >= 0.0f) * (x2 < (dimslice + 2)); add = tex2D(texSlice, x1,x2); res += add * l; posx += shiftJ; } } if((bidy*16 + tidy) < num_projections && (bidx*16 + tidx) < num_bins) { res *= Area; if (normalize) res *= M_PI_F * 0.5f / num_projections; d_Sino[dimrecx*(bidy*16 + tidy) + (bidx*16 + tidx)] = res; } } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/cuda/src/rotation.cu0000644000175000017500000000122114315516747017116 0ustar00pierrepierretypedef unsigned int uint; texture tex_image; __global__ void rotate( float* output, int Nx, int Ny, float cos_angle, float sin_angle, float rotc_x, float rotc_y ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; if (gidx >= Nx || gidy >= Ny) return; float x = (gidx - rotc_x)*cos_angle - (gidy - rotc_y)*sin_angle; float y = (gidx - rotc_x)*sin_angle + (gidy - rotc_y)*cos_angle; x += rotc_x; y += rotc_y; float out_val = tex2D(tex_image, x + 0.5f, y + 0.5f); output[gidy * Nx + gidx] = out_val; } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 
nabu-2024.2.1/nabu/cuda/src/transpose.cu0000644000175000017500000000073414550227307017275 0ustar00pierrepierre#ifndef SRC_DTYPE #define SRC_DTYPE float #endif #ifndef DST_DTYPE #define DST_DTYPE float #endif #include __global__ void transpose(SRC_DTYPE* src, DST_DTYPE* dst, int src_width, int src_height) { // coordinates for "dst" uint x = blockDim.x * blockIdx.x + threadIdx.x; uint y = blockDim.y * blockIdx.y + threadIdx.y; if ((x >= src_height) || (y >= src_width)) return; dst[y*src_height + x] = (DST_DTYPE) src[x*src_width + y]; }././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5047567 nabu-2024.2.1/nabu/cuda/tests/0000755000175000017500000000000014730277752015306 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/cuda/tests/__init__.py0000644000175000017500000000000114315516747017405 0ustar00pierrepierre ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1730906677.0 nabu-2024.2.1/nabu/cuda/utils.py0000644000175000017500000002265614712705065015662 0ustar00pierrepierreimport atexit from math import ceil import numpy as np from ..resources.gpu import GPUDescription try: import pycuda import pycuda.driver as cuda from pycuda import gpuarray as garray from pycuda.tools import clear_context_caches from pycuda.compiler import get_nvcc_version as pycuda_get_nvcc_version __has_pycuda__ = True __pycuda_error_msg__ = None if pycuda.VERSION[0] < 2020: print("Error: need pycuda >= 2020.1") __has_pycuda__ = False except ImportError as err: __has_pycuda__ = False __pycuda_error_msg__ = str(err) try: import cupy __has_cupy__ = True except ImportError: __has_cupy__ = False def get_cuda_context(device_id=None, cleanup_at_exit=True): """ Create or get a CUDA context. """ current_ctx = cuda.Context.get_current() # If a context already exists, use this one # TODO what if the device used is different from device_id ? if current_ctx is not None: return current_ctx # Otherwise create a new context cuda.init() if device_id is None: device_id = 0 # Use the Context obtained by retaining the device's primary context, # which is the one used by the CUDA runtime API (ex. scikit-cuda). # Unlike Context.make_context(), the newly-created context is not made current. context = cuda.Device(device_id).retain_primary_context() context.push() # Register a clean-up function at exit def _finish_up(context): if context is not None: context.pop() context = None clear_context_caches() if cleanup_at_exit: atexit.register(_finish_up, context) return context def count_cuda_devices(): if cuda.Context.get_current() is None: cuda.init() return cuda.Device.count() def get_gpu_memory(device_id): """ Return the total memory (in GigaBytes) of a device. """ cuda.init() return cuda.Device(device_id).total_memory() / 1e9 def is_gpu_usable(): """ Test whether at least one Nvidia GPU is available. """ try: n_gpus = count_cuda_devices() except Exception as exc: # Fragile if exc.__str__() != "cuInit failed: no CUDA-capable device is detected": raise n_gpus = 0 res = n_gpus > 0 return res def detect_cuda_gpus(): """ Detect the available Nvidia CUDA GPUs on the current host. Returns -------- gpus: dict Dictionary where the key is the GPU ID, and the value is a `pycuda.driver.Device` object. error_msg: str In the case where there is an error, the message is returned in this item. Otherwise, it is a None object. 
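    Examples
    --------
    A minimal usage sketch (the printed fields are only an illustration)::

        gpus, error_msg = detect_cuda_gpus()
        if error_msg is None:
            for gpu_id, device in gpus.items():
                print(gpu_id, device.name())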
""" gpus = {} error_msg = None if not (__has_pycuda__): return {}, __pycuda_error_msg__ try: cuda.init() except Exception as exc: error_msg = str(exc) if error_msg is not None: return {}, error_msg try: n_gpus = cuda.Device.count() except Exception as exc: error_msg = str(exc) if error_msg is not None: return {}, error_msg for i in range(n_gpus): gpus[i] = cuda.Device(i) return gpus, None def collect_cuda_gpus(): """ Return a dictionary of GPU ids and brief description of each CUDA-compatible GPU with a few fields. """ gpus, error_msg = detect_cuda_gpus() if error_msg is not None: return None cuda_gpus = {} for gpu_id, gpu in gpus.items(): cuda_gpus[gpu_id] = GPUDescription(gpu).get_dict() return cuda_gpus def get_nvcc_version(nvcc_cmd="nvcc"): try: ver = "".join(pycuda_get_nvcc_version(nvcc_cmd)).split("release")[1].strip().split(" ")[0].strip(",") except: ver = None return ver def check_textures_availability(): """ Check whether Cuda textures can be used. The only limitation is pycuda which does not support texture objects. Textures references were deprecated, and removed from Cuda 12. """ nvcc_ver = get_nvcc_version() if nvcc_ver is None: return False # unknown - can't parse NVCC version for some reason nvcc_major = int(nvcc_ver.split(".")[0]) return nvcc_major < 12 """ pycuda/driver.py np.complex64: SIGNED_INT32, num_channels = 2 np.float64: SIGNED_INT32, num_channels = 2 np.complex128: array_format.SIGNED_INT32, num_channels = 4 double precision: pycuda-helpers.hpp: typedef float fp_tex_float; // --> float32 typedef int2 fp_tex_double; // --> float64 typedef uint2 fp_tex_cfloat; // --> complex64 typedef int4 fp_tex_cdouble; // --> complex128 """ def cuarray_format_to_dtype(cuarr_fmt): # reverse of cuda.dtype_to_array_format fmt = cuda.array_format mapping = { fmt.UNSIGNED_INT8: np.uint8, fmt.UNSIGNED_INT16: np.uint16, fmt.UNSIGNED_INT32: np.uint32, fmt.SIGNED_INT8: np.int8, fmt.SIGNED_INT16: np.int16, fmt.SIGNED_INT32: np.int32, fmt.FLOAT: np.float32, } if cuarr_fmt not in mapping: raise TypeError("Unknown format %s" % cuarr_fmt) return mapping[cuarr_fmt] def cuarray_shape_dtype(cuarray): desc = cuarray.get_descriptor_3d() shape = (desc.height, desc.width) if desc.depth > 0: shape = (desc.depth,) + shape dtype = cuarray_format_to_dtype(desc.format) return shape, dtype def get_shape_dtype(arr): if isinstance(arr, garray.GPUArray) or isinstance(arr, np.ndarray): return arr.shape, arr.dtype elif isinstance(arr, cuda.Array): return cuarray_shape_dtype(arr) else: raise ValueError("Unknown array type %s" % str(type(arr))) def copy_array(dst, src, check=False, src_dtype=None, dst_x_in_bytes=0, dst_y=0): """ Copy a source array to a destination array. Source and destination can be either numpy.ndarray, pycuda.Driver.Array, or pycuda.gpuarray.GPUArray. Parameters ----------- dst: pycuda.driver.Array or pycuda.gpuarray.GPUArray or numpy.ndarray Destination array. Its content will be overwritten by copy. src: pycuda.driver.Array or pycuda.gpuarray.GPUArray or numpy.ndarray Source array. check: bool, optional Whether to check src and dst shape and data type. 
""" shape_src, dtype_src = get_shape_dtype(src) shape_dst, dtype_dst = get_shape_dtype(dst) dtype_src = src_dtype or dtype_src if check: if shape_src != shape_dst: raise ValueError("shape_src != shape_dst : have %s and %s" % (str(shape_src), str(shape_dst))) if dtype_src != dtype_dst: raise ValueError("dtype_src != dtype_dst : have %s and %s" % (str(dtype_src), str(dtype_dst))) if len(shape_src) == 2: copy = cuda.Memcpy2D() h, w = shape_src elif len(shape_src) == 3: copy = cuda.Memcpy3D() d, h, w = shape_src copy.depth = d else: raise ValueError("Expected arrays with 2 or 3 dimensions") if isinstance(src, cuda.Array): copy.set_src_array(src) elif isinstance(src, garray.GPUArray): copy.set_src_device(src.gpudata) else: # numpy copy.set_src_host(src) if isinstance(dst, cuda.Array): copy.set_dst_array(dst) # Support offset (x, y) in target (for copying to texture) copy.dst_x_in_bytes = dst_x_in_bytes copy.dst_y = dst_y elif isinstance(dst, garray.GPUArray): copy.set_dst_device(dst.gpudata) else: # numpy copy.set_dst_host(dst) copy.width_in_bytes = copy.dst_pitch = w * np.dtype(dtype_src).itemsize copy.dst_height = copy.height = h # ?? if len(shape_src) == 2: copy(True) else: copy() ### def copy_big_gpuarray(dst, src, itemsize=4, checks=False): """ Copy a big `pycuda.gpuarray.GPUArray` into another. Transactions of more than 2**32 -1 octets fail, so are doing several partial copies of smaller arrays. """ d2h = isinstance(dst, np.ndarray) if checks: assert dst.dtype == src.dtype assert dst.shape == src.shape limit = 2**32 - 1 if np.prod(dst.shape) * itemsize < limit: if d2h: src.get(ary=dst) else: dst[:] = src[:] return def get_shape2(shape): shape2 = list(shape) while np.prod(shape2) * 4 > limit: shape2[0] //= 2 return tuple(shape2) shape2 = get_shape2(dst.shape) nz0 = dst.shape[0] nz = shape2[0] n_transfers = ceil(nz0 / nz) for i in range(n_transfers): zmax = min((i + 1) * nz, nz0) if d2h: src[i * nz : zmax].get(ary=dst[i * nz : zmax]) else: dst[i * nz : zmax] = src[i * nz : zmax] def replace_array_memory(arr, new_shape): """ Replace the underlying buffer data of a `pycuda.gpuarray.GPUArray`. This function is dangerous ! It should merely be used to clear memory, the array should not be used afterwise. 
""" arr.gpudata.free() arr.gpudata = arr.allocator(int(np.prod(new_shape) * arr.dtype.itemsize)) arr.shape = new_shape # TODO re-compute strides return arr def pycuda_to_cupy(arr_pycuda): arr_cupy_mem = cupy.cuda.UnownedMemory(arr_pycuda.ptr, arr_pycuda.size, arr_pycuda) arr_cupy_memptr = cupy.cuda.MemoryPointer(arr_cupy_mem, offset=0) return cupy.ndarray(arr_pycuda.shape, dtype=arr_pycuda.dtype, memptr=arr_cupy_memptr) # pylint: disable=E1123 def cupy_to_pycuda(arr_cupy): return garray.empty(arr_cupy.shape, arr_cupy.dtype, gpudata=arr_cupy.data.ptr) ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5047567 nabu-2024.2.1/nabu/estimation/0000755000175000017500000000000014730277752015404 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/estimation/__init__.py0000644000175000017500000000057314402565210017503 0ustar00pierrepierrefrom .alignment import AlignmentBase from .cor import ( CenterOfRotation, CenterOfRotationSlidingWindow, CenterOfRotationGrowingWindow, CenterOfRotationAdaptiveSearch, ) from .cor_sino import SinoCor from .distortion import estimate_flat_distortion from .focus import CameraFocus from .tilt import CameraTilt from .translation import DetectorTranslationAlongBeam ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1733732263.0 nabu-2024.2.1/nabu/estimation/alignment.py0000644000175000017500000005130114725523647017735 0ustar00pierrepierre# import math import logging import numpy as np from tqdm import tqdm from numpy.polynomial.polynomial import Polynomial from silx.math.medianfilter import medfilt2d import scipy.fft # pylint: disable=E0611 from ..utils import previouspow2 from ..misc import fourier_filters from ..resources.logger import LoggerOrPrint try: import matplotlib.pyplot as plt __have_matplotlib__ = True except ImportError: logging.getLogger(__name__).warning("Matplotlib not available. Plotting disabled") plt = None __have_matplotlib__ = False def progress_bar(x, verbose=True): if verbose: return tqdm(x) else: return x local_fftn = scipy.fft.rfftn local_ifftn = scipy.fft.irfftn class AlignmentBase: default_extra_options = {"blocking_plots": False} _default_cor_options = {} def __init__( self, vert_fft_width=False, horz_fft_width=False, verbose=False, logger=None, data_type=np.float32, extra_options=None, ): """ Alignment basic functions. Parameters ---------- vert_fft_width: boolean, optional If True, restrict the vertical size to a power of 2: >>> new_v_dim = 2 ** math.floor(math.log2(v_dim)) horz_fft_width: boolean, optional If True, restrict the horizontal size to a power of 2: >>> new_h_dim = 2 ** math.floor(math.log2(h_dim)) verbose: boolean, optional When True it will produce verbose output, including plots. data_type: `numpy.float32` Computation data type. """ self._init_parameters(vert_fft_width, horz_fft_width, verbose, logger, data_type, extra_options=extra_options) self._plot_windows = {} def _init_parameters(self, vert_fft_width, horz_fft_width, verbose, logger, data_type, extra_options=None): self.logger = LoggerOrPrint(logger) self.truncate_vert_pow2 = vert_fft_width self.truncate_horz_pow2 = horz_fft_width if verbose and not __have_matplotlib__: self.logger.warning("Matplotlib not available. 
Plotting disabled, despite being activated by user") verbose = False self.verbose = verbose self.data_type = data_type self.extra_options = self.default_extra_options.copy() self.extra_options.update(extra_options or {}) @staticmethod def _check_img_stack_size(img_stack: np.ndarray, img_pos: np.ndarray): shape_stack = np.squeeze(img_stack).shape shape_pos = np.squeeze(img_pos).shape if not len(shape_stack) == 3: raise ValueError( "A stack of 2-dimensional images is required. Shape of stack: %s" % (" ".join(("%d" % x for x in shape_stack))) ) if not len(shape_pos) == 1: raise ValueError( "Positions need to be a 1-dimensional array. Shape of the positions variable: %s" % (" ".join(("%d" % x for x in shape_pos))) ) if not shape_stack[0] == shape_pos[0]: raise ValueError( "The same number of images and positions is required." + " Shape of stack: %s, shape of positions variable: %s" % ( " ".join(("%d" % x for x in shape_stack)), " ".join(("%d" % x for x in shape_pos)), ) ) @staticmethod def _check_img_pair_sizes(img_1: np.ndarray, img_2: np.ndarray): shape_1 = np.squeeze(img_1).shape shape_2 = np.squeeze(img_2).shape if not len(shape_1) == 2: raise ValueError( "Images need to be 2-dimensional. Shape of image #1: %s" % (" ".join(("%d" % x for x in shape_1))) ) if not len(shape_2) == 2: raise ValueError( "Images need to be 2-dimensional. Shape of image #2: %s" % (" ".join(("%d" % x for x in shape_2))) ) if not np.all(shape_1 == shape_2): raise ValueError( "Images need to be of the same shape. Shape of image #1: %s, image #2: %s" % ( " ".join(("%d" % x for x in shape_1)), " ".join(("%d" % x for x in shape_2)), ) ) @staticmethod def refine_max_position_2d(f_vals: np.ndarray, fy=None, fx=None): """Computes the sub-pixel max position of the given function sampling. Parameters ---------- f_vals: numpy.ndarray Function values of the sampled points fy: numpy.ndarray, optional Vertical coordinates of the sampled points fx: numpy.ndarray, optional Horizontal coordinates of the sampled points Raises ------ ValueError In case position and values do not have the same size, or in case the fitted maximum is outside the fitting region. Returns ------- tuple(float, float) Estimated (vertical, horizontal) function max, according to the coordinates in fy and fx. """ if not (len(f_vals.shape) == 2): raise ValueError( "The fitted values should form a 2-dimensional array. Array of shape: [%s] was given." % (" ".join(("%d" % s for s in f_vals.shape))) ) if fy is None: fy_half_size = (f_vals.shape[0] - 1) / 2 fy = np.linspace(-fy_half_size, fy_half_size, f_vals.shape[0]) elif not (len(fy.shape) == 1 and np.all(fy.size == f_vals.shape[0])): raise ValueError( "Vertical coordinates should have the same length as values matrix. Sizes of fy: %d, f_vals: [%s]" % (fy.size, " ".join(("%d" % s for s in f_vals.shape))) ) if fx is None: fx_half_size = (f_vals.shape[1] - 1) / 2 fx = np.linspace(-fx_half_size, fx_half_size, f_vals.shape[1]) elif not (len(fx.shape) == 1 and np.all(fx.size == f_vals.shape[1])): raise ValueError( "Horizontal coordinates should have the same length as values matrix. Sizes of fx: %d, f_vals: [%s]" % (fx.size, " ".join(("%d" % s for s in f_vals.shape))) ) fy, fx = np.meshgrid(fy, fx, indexing="ij") fy = fy.flatten() fx = fx.flatten() coords = np.array([np.ones(f_vals.size), fy, fx, fy * fx, fy**2, fx**2]) coeffs = np.linalg.lstsq(coords.T, f_vals.flatten(), rcond=None)[0] # For a 1D parabola `f(x) = ax^2 + bx + c`, the vertex position is: # x_v = -b / 2a. 
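        # (for instance, f(x) = 3 + 4x - 2x^2, i.e. a = -2 and b = 4, peaks at x_v = -4 / (2 * (-2)) = 1).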
For a 2D parabola, the vertex position is: # (y, x)_v = - b / A, where: A = [[2 * coeffs[4], coeffs[3]], [coeffs[3], 2 * coeffs[5]]] b = coeffs[1:3] vertex_yx = np.linalg.lstsq(A, -b, rcond=None)[0] vertex_min_yx = [np.min(fy), np.min(fx)] vertex_max_yx = [np.max(fy), np.max(fx)] if np.any(vertex_yx < vertex_min_yx) or np.any(vertex_yx > vertex_max_yx): raise ValueError( "Fitted (y: {}, x: {}) positions are outside the input margins y: [{}, {}], and x: [{}, {}]".format( vertex_yx[0], vertex_yx[1], vertex_min_yx[0], vertex_max_yx[0], vertex_min_yx[1], vertex_max_yx[1], ) ) return vertex_yx @staticmethod def refine_max_position_1d(f_vals, fx=None, return_vertex_val=False, return_all_coeffs=False): """Computes the sub-pixel max position of the given function sampling. Parameters ---------- f_vals: numpy.ndarray Function values of the sampled points fx: numpy.ndarray, optional Coordinates of the sampled points return_vertex_val: boolean, option Enables returning the vertex values. Defaults to False. Raises ------ ValueError In case position and values do not have the same size, or in case the fitted maximum is outside the fitting region. Returns ------- float Estimated function max, according to the coordinates in fx. """ if not len(f_vals.shape) in (1, 2): raise ValueError( "The fitted values should be either one or a collection of 1-dimensional arrays. Array of shape: [%s] was given." % (" ".join(("%d" % s for s in f_vals.shape))) ) num_vals = f_vals.shape[0] if fx is None: fx_half_size = (num_vals - 1) / 2 fx = np.linspace(-fx_half_size, fx_half_size, num_vals) else: fx = np.squeeze(fx) if not (len(fx.shape) == 1 and np.all(fx.size == num_vals)): raise ValueError( "Base coordinates should have the same length as values array. Sizes of fx: %d, f_vals: %d" % (fx.size, num_vals) ) if len(f_vals.shape) == 1: # using Polynomial.fit, because supposed to be more numerically # stable than previous solutions (according to numpy). poly = Polynomial.fit(fx, f_vals, deg=2) coeffs = poly.convert().coef else: coords = np.array([np.ones(num_vals), fx, fx**2]) coeffs = np.linalg.lstsq(coords.T, f_vals, rcond=None)[0] # For a 1D parabola `f(x) = c + bx + ax^2`, the vertex position is: # x_v = -b / 2a. vertex_x = -coeffs[1, :] / (2 * coeffs[2, :]) if not return_all_coeffs: vertex_x = vertex_x[0] vertex_min_x = np.min(fx) vertex_max_x = np.max(fx) lower_bound_ok = vertex_min_x < vertex_x upper_bound_ok = vertex_x < vertex_max_x if not np.all(lower_bound_ok * upper_bound_ok): if len(f_vals.shape) == 1: message = "Fitted position {} is outide the input margins [{}, {}]".format( vertex_x, vertex_min_x, vertex_max_x ) else: message = "Fitted positions outside the input margins [{}, {}]: {} below and {} above".format( vertex_min_x, vertex_max_x, np.sum(1 - lower_bound_ok), np.sum(1 - upper_bound_ok), ) raise ValueError(message) if return_vertex_val: vertex_val = coeffs[0, :] + vertex_x * coeffs[1, :] / 2 return vertex_x, vertex_val else: return vertex_x @staticmethod def extract_peak_region_2d(cc, peak_radius=1, cc_vs=None, cc_hs=None): """ Extracts a region around the maximum value. Parameters ---------- cc: numpy.ndarray Correlation image. peak_radius: int, optional The l_inf radius of the area to extract around the peak. The default is 1. cc_vs: numpy.ndarray, optional The vertical coordinates of `cc`. The default is None. cc_hs: numpy.ndarray, optional The horizontal coordinates of `cc`. The default is None. Returns ------- f_vals: numpy.ndarray The extracted function values. 
fv: numpy.ndarray The vertical coordinates of the extracted values. fh: numpy.ndarray The horizontal coordinates of the extracted values. """ img_shape = np.array(cc.shape) # get pixel having the maximum value of the correlation array pix_max_corr = np.argmax(cc) pv, ph = np.unravel_index(pix_max_corr, img_shape) # select a n x n neighborhood for the sub-pixel fitting (with wrapping) pv = np.arange(pv - peak_radius, pv + peak_radius + 1) % img_shape[-2] ph = np.arange(ph - peak_radius, ph + peak_radius + 1) % img_shape[-1] # extract the (v, h) pixel coordinates fv = None if cc_vs is None else cc_vs[pv] fh = None if cc_hs is None else cc_hs[ph] # extract the correlation values pv, ph = np.meshgrid(pv, ph, indexing="ij") f_vals = cc[pv, ph] return (f_vals, fv, fh) @staticmethod def extract_peak_regions_1d(cc, axis=-1, peak_radius=1, cc_coords=None): """ Extracts a region around the maximum value. Parameters ---------- cc: numpy.ndarray Correlation image. axis: int, optional Find the max values along the specified direction. The default is -1. peak_radius: int, optional The l_inf radius of the area to extract around the peak. The default is 1. cc_coords: numpy.ndarray, optional The coordinates of `cc` along the selected axis. The default is None. Returns ------- f_vals: numpy.ndarray The extracted function values. fc_ax: numpy.ndarray The coordinates of the extracted values, along the selected axis. """ if len(cc.shape) == 1: cc = cc[None, ...] img_shape = np.array(cc.shape) if not (len(img_shape) == 2): raise ValueError( "The input image should be either a 1 or 2-dimensional array. Array of shape: [%s] was given." % (" ".join(("%d" % s for s in cc.shape))) ) other_axis = (axis + 1) % 2 # get pixel having the maximum value of the correlation array pix_max = np.argmax(cc, axis=axis) # select a n neighborhood for the many 1D sub-pixel fittings (with wrapping) p_ax_range = np.arange(-peak_radius, +peak_radius + 1) p_ax = (pix_max[None, :] + p_ax_range[:, None]) % img_shape[axis] p_ln = np.tile(np.arange(0, img_shape[other_axis])[None, :], [2 * peak_radius + 1, 1]) # extract the pixel coordinates along the axis fc_ax = None if cc_coords is None else cc_coords[p_ax.flatten()].reshape(p_ax.shape) # extract the correlation values if other_axis == 0: f_vals = cc[p_ln, p_ax] else: f_vals = cc[p_ax, p_ln] return (f_vals, fc_ax) def _determine_roi(self, img_shape, roi_yxhw): if roi_yxhw is None: # vertical and horizontal window sizes are reduced to a power of 2 # to accelerate fft if requested. Default is not. roi_yxhw = previouspow2(img_shape) if not self.truncate_vert_pow2: roi_yxhw[0] = img_shape[0] if not self.truncate_horz_pow2: roi_yxhw[1] = img_shape[1] roi_yxhw = np.array(roi_yxhw, dtype=np.intp) if len(roi_yxhw) == 2: # Convert centered 2-element roi into 4-element roi_yxhw = np.concatenate(((img_shape - roi_yxhw) // 2, roi_yxhw)) return roi_yxhw def _prepare_image( self, img, invalid_val=1e-5, roi_yxhw=None, median_filt_shape=None, low_pass=None, high_pass=None, ): """ Prepare and returns a cropped and filtered image, or array of filtered images if the input is an array of images. 
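        The processing steps are applied in this order: cropping to `roi_yxhw` (if given),
        replacement of NaN and infinite values by `invalid_val`, optional Fourier band-pass
        filtering (`low_pass` / `high_pass`), and finally optional median filtering.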
Parameters ---------- img: numpy.ndarray image or stack of images invalid_val: float value to be used in replacement of nan and inf values median_filt_shape: int or sequence of int the width or the widths of the median window low_pass: float or sequence of two floats Low-pass filter properties, as described in `nabu.misc.fourier_filters` high_pass: float or sequence of two floats High-pass filter properties, as described in `nabu.misc.fourier_filters` Returns ------- numpy.array_like The computed filter """ img = np.squeeze(img) # Removes singleton dimensions, but does a shallow copy img = np.ascontiguousarray(img, dtype=self.data_type) if roi_yxhw is not None: img = img[ ..., roi_yxhw[0] : roi_yxhw[0] + roi_yxhw[2], roi_yxhw[1] : roi_yxhw[1] + roi_yxhw[3], ] img = img.copy() img[np.isnan(img)] = invalid_val img[np.isinf(img)] = invalid_val if high_pass is not None or low_pass is not None: img_filter = fourier_filters.get_bandpass_filter( img.shape[-2:], cutoff_lowpass=low_pass, cutoff_highpass=high_pass, use_rfft=True, data_type=self.data_type, ) # fft2 and iff2 use axes=(-2, -1) by default img = local_ifftn(local_fftn(img, axes=(-2, -1)) * img_filter, axes=(-2, -1)).real if median_filt_shape is not None: img_shape = img.shape # expanding filter shape with ones, to cover the stack of images # but disabling inter-image filtering median_filt_shape = np.concatenate( ( np.ones((len(img_shape) - len(median_filt_shape),), dtype=np.intp), median_filt_shape, ) ) img = medfilt2d(img, kernel_size=median_filt_shape) return img def _transform_to_fft( self, img_1: np.ndarray, img_2: np.ndarray, padding_mode, axes=(-2, -1), low_pass=None, high_pass=None ): do_circular_conv = padding_mode is None or padding_mode == "wrap" img_shape = img_2.shape if not do_circular_conv: pad_size = np.ceil(np.array(img_shape) / 2).astype(np.intp) pad_array = [(0,)] * len(img_shape) for a in axes: pad_array[a] = (pad_size[a],) img_1 = np.pad(img_1, pad_array, mode=padding_mode) img_2 = np.pad(img_2, pad_array, mode=padding_mode) else: pad_size = None img_shape = img_2.shape # compute fft's of the 2 images img_fft_1 = local_fftn(img_1, axes=axes) img_fft_2 = local_fftn(img_2, axes=axes) if low_pass is not None or high_pass is not None: filt = fourier_filters.get_bandpass_filter( img_shape[-2:], cutoff_lowpass=low_pass, cutoff_highpass=high_pass, use_rfft=True, data_type=self.data_type, ) else: filt = None return img_fft_1, img_fft_2, filt, pad_size def _compute_correlation_fft( self, img_1: np.ndarray, img_2: np.ndarray, padding_mode, axes=(-2, -1), low_pass=None, high_pass=None ): img_fft_1, img_fft_2, filt, pad_size = self._transform_to_fft( img_1, img_2, padding_mode=padding_mode, axes=axes, low_pass=low_pass, high_pass=high_pass ) img_prod = img_fft_1 * np.conjugate(img_fft_2) if filt is not None: img_prod *= filt # inverse fft of the product to get cross_correlation of the 2 images cc = np.real(local_ifftn(img_prod, axes=axes)) if pad_size is not None: cc_shape = cc.shape cc = np.fft.fftshift(cc, axes=axes) slicing = [slice(None)] * len(cc_shape) for a in axes: slicing[a] = slice(pad_size[a], cc_shape[a] - pad_size[a]) cc = cc[tuple(slicing)] cc = np.fft.ifftshift(cc, axes=axes) return cc def _add_plot_window(self, fig, ax=None): self._plot_windows[fig.number] = {"figure": fig, "axes": ax} def close_plot_window(self, n, errors="raise"): """ Close a plot window. Applicable only if the class was instantiated with verbose=True. 
Parameters ---------- n: int Figure number to close errors: str, optional What to do with errors. It can be either "raise", "log" or "ignore". """ if not self.verbose: return if n not in self._plot_windows: msg = "Cannot close plot window number %d: no such window" % n if errors == "raise": raise ValueError(msg) elif errors == "log": self.logger.error(msg) fig_ax = self._plot_windows.pop(n) plt.close(fig_ax["figure"].number) # would also work with the object itself def close_last_plot_windows(self, n=1): """ Close the last "n" plot windows. Applicable only if the class was instanciated with verbose=True. Parameters ----------- n: int, optional Integer indicating how many plot windows should be closed. """ figs_nums = sorted(self._plot_windows.keys(), reverse=True) n = min(n, len(figs_nums)) for i in range(n): self.close_plot_window(figs_nums[i], errors="ignore") ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734019212.0 nabu-2024.2.1/nabu/estimation/cor.py0000644000175000017500000014235714726604214016545 0ustar00pierrepierreimport math import numpy as np from ..utils import deprecated_class, deprecation_warning, is_scalar from ..misc import fourier_filters from .alignment import AlignmentBase, plt, progress_bar, local_fftn, local_ifftn # three possible values for the validity check, which can optionally be returned by the find_shifts methods cor_result_validity = { "unknown": "unknown", "sound": "sound", "correct": "sound", "questionable": "questionable", } class CenterOfRotation(AlignmentBase): def find_shift( self, img_1: np.ndarray, img_2: np.ndarray, side=None, shift_axis: int = -1, roi_yxhw=None, median_filt_shape=None, padding_mode=None, peak_fit_radius=1, high_pass=None, low_pass=None, return_validity=False, return_relative_to_middle=None, ): """Find the Center of Rotation (CoR), given two images. This method finds the half-shift between two opposite images, by means of correlation computed in Fourier space. The output of this function, allows to compute motor movements for aligning the sample rotation axis. Given the following values: - L1: distance from source to motor - L2: distance from source to detector - ps: physical pixel size - v: output of this function displacement of motor = (L1 / L2 * ps) * v Parameters ---------- img_1: numpy.ndarray First image img_2: numpy.ndarray Second image, it needs to have been flipped already (e.g. using numpy.fliplr). shift_axis: int Axis along which we want the shift to be computed. Default is -1 (horizontal). roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional 4 elements vector containing: vertical and horizontal coordinates of first pixel, plus height and width of the Region of Interest (RoI). Or a 2 elements vector containing: plus height and width of the centered Region of Interest (RoI). Default is None -> deactivated. median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional Shape of the median filter window. Default is None -> deactivated. padding_mode: str in numpy.pad's mode list, optional Padding mode, which determines the type of convolution. If None or 'wrap' are passed, this resorts to the traditional circular convolution. If 'edge' or 'constant' are passed, it results in a linear convolution. Default is the circular convolution. All options are: None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean' | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap' peak_fit_radius: int, optional Radius size around the max correlation pixel, for sub-pixel fitting. 
Minimum and default value is 1. low_pass: float or sequence of two floats Low-pass filter properties, as described in `nabu.misc.fourier_filters` high_pass: float or sequence of two floats High-pass filter properties, as described in `nabu.misc.fourier_filters` return_validity: a boolean, defaults to false if set to True adds a second return value which may have three string values. These values are "unknown", "sound", "questionable". It will be "uknown" if the validation method is not implemented and it will be "sound" or "questionable" if it is implemented. Raises ------ ValueError In case images are not 2-dimensional or have different sizes. Returns ------- float Estimated center of rotation position from the center of the RoI in pixels. Examples -------- The following code computes the center of rotation position for two given images in a tomography scan, where the second image is taken at 180 degrees from the first. >>> radio1 = data[0, :, :] ... radio2 = np.fliplr(data[1, :, :]) ... CoR_calc = CenterOfRotation() ... cor_position = CoR_calc.find_shift(radio1, radio2) Or for noisy images: >>> cor_position = CoR_calc.find_shift(radio1, radio2, median_filt_shape=(3, 3)) """ # COMPAT. if return_relative_to_middle is None: deprecation_warning( "The current default behavior is to return the shift relative the the middle of the image. In a future release, this function will return the shift relative to the left-most pixel. To keep the current behavior, please use 'return_relative_to_middle=True'.", do_print=True, func_name="CenterOfRotation.find_shift", ) return_relative_to_middle = True # the kwarg above will be False by default in a future release # --- self._check_img_pair_sizes(img_1, img_2) if peak_fit_radius < 1: self.logger.warning("Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius) peak_fit_radius = 1 img_shape = img_2.shape roi_yxhw = self._determine_roi(img_shape, roi_yxhw) img_1 = self._prepare_image(img_1, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape) img_2 = self._prepare_image(img_2, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape) cc = self._compute_correlation_fft(img_1, img_2, padding_mode, high_pass=high_pass, low_pass=low_pass) img_shape = img_2.shape cc_vs = np.fft.fftfreq(img_shape[-2], 1 / img_shape[-2]) cc_hs = np.fft.fftfreq(img_shape[-1], 1 / img_shape[-1]) (f_vals, fv, fh) = self.extract_peak_region_2d(cc, peak_radius=peak_fit_radius, cc_vs=cc_vs, cc_hs=cc_hs) fitted_shifts_vh = self.refine_max_position_2d(f_vals, fv, fh) estimated_cor = fitted_shifts_vh[shift_axis] / 2.0 if is_scalar(side): near_pos = side - (img_1.shape[-1] - 1) / 2 if ( np.abs(near_pos - estimated_cor) / near_pos > 0.2 ): # For comparison, near_pos is RELATIVE to the middle of image (as estimated_cor is). validity_check_result = cor_result_validity["questionable"] else: validity_check_result = cor_result_validity["sound"] else: validity_check_result = cor_result_validity["unknown"] if not (return_relative_to_middle): estimated_cor += (img_1.shape[-1] - 1) / 2 if return_validity: return estimated_cor, validity_check_result else: return estimated_cor class CenterOfRotationSlidingWindow(CenterOfRotation): def find_shift( self, img_1: np.ndarray, img_2: np.ndarray, side="center", window_width=None, roi_yxhw=None, median_filt_shape=None, peak_fit_radius=1, high_pass=None, low_pass=None, return_validity=False, return_relative_to_middle=None, ): """Semi-automatically find the Center of Rotation (CoR), given two images or sinograms. 
Suitable for half-aquisition scan. This method finds the half-shift between two opposite images, by minimizing difference over a moving window. Parameters and usage is the same as CenterOfRotation, except for the following two parameters. Parameters ---------- side: string or float, optional Expected region of the CoR. Allowed values: 'left', 'center' or 'right'. Default is 'center' window_width: int, optional Width of window that will slide on the other image / part of the sinogram. Default is None. """ # COMPAT. if return_relative_to_middle is None: deprecation_warning( "The current default behavior is to return the shift relative the the middle of the image. In a future release, this function will return the shift relative to the left-most pixel. To keep the current behavior, please use 'return_relative_to_middle=True'.", do_print=True, func_name="CenterOfRotationSlidingWindow.find_shift", ) return_relative_to_middle = True # the kwarg above will be False by default in a future release # --- validity_check_result = cor_result_validity["unknown"] if side is None: raise ValueError("Side should be one of 'left', 'right', 'center' or a scalar. 'None' was given instead") self._check_img_pair_sizes(img_1, img_2) if peak_fit_radius < 1: self.logger.warning("Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius) peak_fit_radius = 1 img_shape = img_2.shape roi_yxhw = self._determine_roi(img_shape, roi_yxhw) img_1 = self._prepare_image( img_1, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape, high_pass=high_pass, low_pass=low_pass ) img_2 = self._prepare_image( img_2, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape, high_pass=high_pass, low_pass=low_pass ) img_shape = img_2.shape img_width = img_shape[-1] if isinstance(side, str): if window_width is None: if side == "center": window_width = round(img_width / 4.0 * 3.0) else: window_width = round(img_width / 10) window_shift = window_width // 2 window_width = window_shift * 2 + 1 if side == "right": win_2_start = 0 elif side == "left": win_2_start = img_width - window_width else: win_2_start = img_width // 2 - window_shift else: abs_pos = int(side + img_width // 2) window_fraction = 0.1 # Hard-coded ? window_width = round(window_fraction * img_width) window_shift = window_width // 2 window_width = window_shift * 2 + 1 win_2_start = np.clip(abs_pos - window_shift, 0, img_width // 2 - 1) win_2_start = img_width // 2 - 1 - win_2_start win_1_start_seed = 0 # number of pixels where the window will "slide". n = img_width - window_width win_2_end = win_2_start + window_width diffs_mean = np.zeros((n,), dtype=img_1.dtype) diffs_std = np.zeros((n,), dtype=img_1.dtype) for ii in progress_bar(range(n), verbose=self.verbose): win_1_start = win_1_start_seed + ii win_1_end = win_1_start + window_width img_diff = img_1[:, win_1_start:win_1_end] - img_2[:, win_2_start:win_2_end] diffs_abs = np.abs(img_diff) diffs_mean[ii] = diffs_abs.mean() diffs_std[ii] = diffs_abs.std() diffs_mean = diffs_mean.min() - diffs_mean win_ind_max = np.argmax(diffs_mean) diffs_std = diffs_std.min() - diffs_std if not win_ind_max == np.argmax(diffs_std): self.logger.warning( "Minimum mean difference and minimum std-dev of differences do not coincide. " + "This means that the validity of the found solution might be questionable." 
) validity_check_result = cor_result_validity["questionable"] else: validity_check_result = cor_result_validity["sound"] (f_vals, f_pos) = self.extract_peak_regions_1d(diffs_mean, peak_radius=peak_fit_radius) win_pos_max, win_val_max = self.refine_max_position_1d(f_vals, return_vertex_val=True) # Derive the COR if is_scalar(side): cor_h = -(win_2_start - (win_1_start_seed + win_ind_max + win_pos_max)) / 2.0 cor_pos = -(win_2_start - (win_1_start_seed + np.arange(n))) / 2.0 else: cor_h = -(win_2_start - (win_ind_max + win_pos_max)) / 2.0 cor_pos = -(win_2_start - np.arange(n)) / 2.0 if (side == "right" and win_ind_max == 0) or (side == "left" and win_ind_max == n): self.logger.warning("Sliding window width %d might be too large!" % window_width) if self.verbose: print("Lowest difference window: index=%d, range=[0, %d]" % (win_ind_max, n)) print("CoR tested for='%s', found at voxel=%g (from center)" % (side, cor_h)) f, ax = plt.subplots(1, 1) self._add_plot_window(f, ax=ax) ax.stem(cor_pos, diffs_mean, label="Mean difference") ax.stem(cor_h, win_val_max, linefmt="C1-", markerfmt="C1o", label="Best mean difference") ax.stem(cor_pos, -diffs_std, linefmt="C2-", markerfmt="C2o", label="Std-dev difference") ax.set_title("Window dispersions") plt.legend() plt.show(block=False) if not (return_relative_to_middle): cor_h += (img_width - 1) / 2.0 if return_validity: return cor_h, validity_check_result else: return cor_h class CenterOfRotationGrowingWindow(CenterOfRotation): def find_shift( self, img_1: np.ndarray, img_2: np.ndarray, side="all", min_window_width=11, roi_yxhw=None, median_filt_shape=None, padding_mode=None, peak_fit_radius=1, high_pass=None, low_pass=None, return_validity=False, return_relative_to_middle=None, ): """Automatically find the Center of Rotation (CoR), given two images or sinograms. Suitable for half-aquisition scan. This method finds the half-shift between two opposite images, by minimizing difference over a moving window. Usage and parameters are the same as CenterOfRotationSlidingWindow, except for the following parameter. Parameters ---------- min_window_width: int, optional Minimum window width that covers the common region of the two images / sinograms. Default is 11. """ # COMPAT. if return_relative_to_middle is None: deprecation_warning( "The current default behavior is to return the shift relative the the middle of the image. In a future release, this function will return the shift relative to the left-most pixel. To keep the current behavior, please use 'return_relative_to_middle=True'.", do_print=True, func_name="CenterOfRotationGrowingWindow.find_shift", ) return_relative_to_middle = True # the kwarg above will be False by default in a future release # --- validity_check_result = cor_result_validity["unknown"] self._check_img_pair_sizes(img_1, img_2) if peak_fit_radius < 1: self.logger.warning("Parameter peak_fit_radius should be at least 1, given: %d instead." 
% peak_fit_radius) peak_fit_radius = 1 img_shape = img_2.shape roi_yxhw = self._determine_roi(img_shape, roi_yxhw) img_1 = self._prepare_image( img_1, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape, high_pass=high_pass, low_pass=low_pass ) img_2 = self._prepare_image( img_2, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape, high_pass=high_pass, low_pass=low_pass ) img_shape = img_2.shape def window_bounds(mid_point, window_max_width=img_shape[-1]): return ( np.fmax(np.ceil(mid_point - window_max_width / 2), 0).astype(np.intp), np.fmin(np.ceil(mid_point + window_max_width / 2), img_shape[-1]).astype(np.intp), ) img_lower_half_size = np.floor(img_shape[-1] / 2).astype(np.intp) img_upper_half_size = np.ceil(img_shape[-1] / 2).astype(np.intp) if is_scalar(side): self.logger.error( "Passing a first CoR guess is not supported for CenterOfRotationGrowingWindow. Using side='all'." ) side = "all" if side.lower() == "right": win_1_mid_start = img_lower_half_size win_1_mid_end = np.floor(img_shape[-1] * 3 / 2).astype(np.intp) - min_window_width win_2_mid_start = -img_upper_half_size + min_window_width win_2_mid_end = img_upper_half_size elif side.lower() == "left": win_1_mid_start = -img_lower_half_size + min_window_width win_1_mid_end = img_lower_half_size win_2_mid_start = img_upper_half_size win_2_mid_end = np.ceil(img_shape[-1] * 3 / 2).astype(np.intp) - min_window_width elif side.lower() == "center": win_1_mid_start = 0 win_1_mid_end = img_shape[-1] win_2_mid_start = 0 win_2_mid_end = img_shape[-1] elif side.lower() == "all": win_1_mid_start = -img_lower_half_size + min_window_width win_1_mid_end = np.floor(img_shape[-1] * 3 / 2).astype(np.intp) - min_window_width win_2_mid_start = -img_upper_half_size + min_window_width win_2_mid_end = np.ceil(img_shape[-1] * 3 / 2).astype(np.intp) - min_window_width else: raise ValueError( "Side should be one of 'left', 'right', or 'center' or 'all'. '%s' was given instead" % side.lower() ) n1 = win_1_mid_end - win_1_mid_start n2 = win_2_mid_end - win_2_mid_start if not n1 == n2: raise ValueError( "Internal error: the number of window steps for the two images should be the same." + "Found the following configuration instead => Side: %s, #1: %d, #2: %d" % (side, n1, n2) ) diffs_mean = np.zeros((n1,), dtype=img_1.dtype) diffs_std = np.zeros((n1,), dtype=img_1.dtype) for ii in progress_bar(range(n1), verbose=self.verbose): win_1 = window_bounds(win_1_mid_start + ii) win_2 = window_bounds(win_2_mid_end - ii) img_diff = img_1[:, win_1[0] : win_1[1]] - img_2[:, win_2[0] : win_2[1]] diffs_abs = np.abs(img_diff) diffs_mean[ii] = diffs_abs.mean() diffs_std[ii] = diffs_abs.std() diffs_mean = diffs_mean.min() - diffs_mean win_ind_max = np.argmax(diffs_mean) diffs_std = diffs_std.min() - diffs_std if not win_ind_max == np.argmax(diffs_std): self.logger.warning( "Minimum mean difference and minimum std-dev of differences do not coincide. " + "This means that the validity of the found solution might be questionable." 
) validity_check_result = cor_result_validity["questionable"] else: validity_check_result = cor_result_validity["sound"] (f_vals, f_pos) = self.extract_peak_regions_1d(diffs_mean, peak_radius=peak_fit_radius) win_pos_max, win_val_max = self.refine_max_position_1d(f_vals, return_vertex_val=True) cor_h = (win_1_mid_start + (win_ind_max + win_pos_max) - img_upper_half_size) / 2.0 if (side.lower() == "right" and win_ind_max == 0) or (side.lower() == "left" and win_ind_max == n1): self.logger.warning("Minimum growing window width %d might be too large!" % min_window_width) if self.verbose: cor_pos = (win_1_mid_start + np.arange(n1) - img_upper_half_size) / 2.0 self.logger.info("Lowest difference window: index=%d, range=[0, %d]" % (win_ind_max, n1)) self.logger.info("CoR tested for='%s', found at voxel=%g (from center)" % (side, cor_h)) f, ax = plt.subplots(1, 1) self._add_plot_window(f, ax=ax) ax.stem(cor_pos, diffs_mean, label="Mean difference") ax.stem(cor_h, win_val_max, linefmt="C1-", markerfmt="C1o", label="Best mean difference") ax.stem(cor_pos, -diffs_std, linefmt="C2-", markerfmt="C2o", label="Std-dev difference") ax.set_title("Window dispersions") plt.show(block=False) if not (return_relative_to_middle): cor_h += (img_shape[-1] - 1) / 2.0 if return_validity: return cor_h, validity_check_result else: return cor_h class CenterOfRotationAdaptiveSearch(CenterOfRotation): """This adaptive method works by applying a gaussian which highlights, by apodisation, a region which can possibly contain the good center of rotation. The whole image is spanned during several applications of the apodisation. At each application the apodisation function, which is a gaussian, is moved to a new guess position. The lenght of the step, by which the gaussian is moved, and its sigma are obtained by multiplying the shortest distance from the left or right border with a self.step_fraction and self.sigma_fraction factors which ensure global overlapping. for each step a region around the CoR of each image is selected, and the regions of the two images are compared to calculate a cost function. The value of the cost function, at its minimum is used to select the best step at which the CoR is taken as final result. The option filtered_cost= True (default) triggers the filtering (according to low_pass and high_pass) of the two images which are used for he cost function. ( Note: the low_pass and high_pass options are used, if given, also without the filtered_cost option, by being passed to the base class CenterOfRotation ) """ sigma_fraction = 1.0 / 4.0 step_fraction = 1.0 / 6.0 def find_shift( self, img_1: np.ndarray, img_2: np.ndarray, roi_yxhw=None, median_filt_shape=None, padding_mode=None, high_pass=None, low_pass=None, margins=None, filtered_cost=True, return_validity=False, return_relative_to_middle=None, ): """Find the Center of Rotation (CoR), given two images. This method finds the half-shift between two opposite images, by means of correlation computed in Fourier space. A global search is done on on the detector span (minus a margin) without assuming centered scan conditions. Usage and parameters are the same as CenterOfRotation, except for the following parameters. Parameters ---------- margins: None or a couple of floats or ints if margins is None or in the form of (margin1,margin2) the search is done between margin1 and dim_x-1-margin2. If left to None then by default (margin1,margin2) = ( 10, 10 ). filtered_cost: boolean. True by default. 
It triggers the use of filtered images in the calculation of the cost function. """ # COMPAT. if return_relative_to_middle is None: deprecation_warning( "The current default behavior is to return the shift relative the the middle of the image. In a future release, this function will return the shift relative to the left-most pixel. To keep the current behavior, please use 'return_relative_to_middle=True'.", do_print=True, func_name="CenterOfRotationAdaptiveSearch.find_shift", ) return_relative_to_middle = True # the kwarg above will be False by default in a future release # --- validity_check_result = cor_result_validity["unknown"] self._check_img_pair_sizes(img_1, img_2) used_type = img_1.dtype roi_yxhw = self._determine_roi(img_1.shape, roi_yxhw) if filtered_cost and (low_pass is not None or high_pass is not None): img_filter = fourier_filters.get_bandpass_filter( img_1.shape[-2:], cutoff_lowpass=low_pass, cutoff_highpass=high_pass, use_rfft=True, data_type=self.data_type, ) # fft2 and iff2 use axes=(-2, -1) by default img_filtered_1 = local_ifftn(local_fftn(img_1, axes=(-2, -1)) * img_filter, axes=(-2, -1)).real img_filtered_2 = local_ifftn(local_fftn(img_2, axes=(-2, -1)) * img_filter, axes=(-2, -1)).real else: img_filtered_1 = img_1 img_filtered_2 = img_2 img_1 = self._prepare_image(img_1, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape) img_2 = self._prepare_image(img_2, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape) img_filtered_1 = self._prepare_image(img_filtered_1, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape) img_filtered_2 = self._prepare_image(img_filtered_2, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape) dim_radio = img_1.shape[1] if margins is None: lim_1, lim_2 = 10, dim_radio - 1 - 10 else: lim_1, lim_2 = margins lim_2 = dim_radio - 1 - lim_2 if lim_1 < 1: lim_1 = 1 if lim_2 > dim_radio - 2: lim_2 = dim_radio - 2 if lim_2 <= lim_1: message = ( "Image shape or cropped selection too small for global search." + " After removal of the margins the search limits collide." 
+ " The cropped size is %d\n" % (dim_radio) ) raise ValueError(message) found_centers = [] x_cor = lim_1 while x_cor < lim_2: tmp_sigma = ( min( (img_1.shape[1] - x_cor), (x_cor), ) * self.sigma_fraction ) tmp_x = (np.arange(img_1.shape[1]) - x_cor) / tmp_sigma apodis = np.exp(-tmp_x * tmp_x / 2.0) x_cor_rel = x_cor - (img_1.shape[1] // 2) img_1_apodised = img_1 * apodis try: cor_position = CenterOfRotation.find_shift( self, img_1_apodised.astype(used_type), img_2.astype(used_type), low_pass=low_pass, high_pass=high_pass, roi_yxhw=roi_yxhw, return_relative_to_middle=True, ) except ValueError as err: if "positions are outside the input margins" in str(err): x_cor = min(x_cor + x_cor * self.step_fraction, x_cor + (dim_radio - x_cor) * self.step_fraction) continue except: message = "Unexpected error from base class CenterOfRotation.find_shift in CenterOfRotationAdaptiveSearch.find_shift : {err}".format( err=err ) self.logger.error(message) raise p_1 = cor_position * 2 if cor_position < 0: p_2 = img_2.shape[1] + cor_position * 2 else: p_2 = -img_2.shape[1] + cor_position * 2 if abs(x_cor_rel - p_1 / 2) < abs(x_cor_rel - p_2 / 2): cor_position = p_1 / 2 else: cor_position = p_2 / 2 cor_in_img = img_1.shape[1] // 2 + cor_position tmp_sigma = ( min( (img_1.shape[1] - cor_in_img), (cor_in_img), ) * self.sigma_fraction ) M1 = int(round(cor_position + img_1.shape[1] // 2)) - int(round(tmp_sigma)) M2 = int(round(cor_position + img_1.shape[1] // 2)) + int(round(tmp_sigma)) piece_1 = img_filtered_1[:, M1:M2] piece_2 = img_filtered_2[:, img_1.shape[1] - M2 : img_1.shape[1] - M1] if piece_1.size and piece_2.size: piece_1 = piece_1 - piece_1.mean() piece_2 = piece_2 - piece_2.mean() energy = np.array(piece_1 * piece_1 + piece_2 * piece_2, "d").sum() diff_energy = np.array((piece_1 - piece_2) * (piece_1 - piece_2), "d").sum() cost = diff_energy / energy if not np.isnan(cost): if tmp_sigma * 2 > abs(x_cor_rel - cor_position): found_centers.append([cost, abs(x_cor_rel - cor_position), cor_position, energy]) x_cor = min(x_cor + x_cor * self.step_fraction, x_cor + (dim_radio - x_cor) * self.step_fraction) if len(found_centers) == 0: message = "Unable to find any valid CoR candidate in {my_class}.find_shift ".format( my_class=self.__class__.__name__ ) raise ValueError(message) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Now build the neigborhood of the minimum as a list of five elements: # the minimum in the middle of the two before, and the two after filtered_found_centers = [] for i in range(len(found_centers)): if i > 0: if abs(found_centers[i][2] - found_centers[i - 1][2]) < 0.5: filtered_found_centers.append(found_centers[i]) continue if i + 1 < len(found_centers): if abs(found_centers[i][2] - found_centers[i + 1][2]) < 0.5: filtered_found_centers.append(found_centers[i]) continue if len(filtered_found_centers): found_centers = filtered_found_centers min_choice = min(found_centers) index_min_choice = found_centers.index(min_choice) min_neighborood = [ found_centers[i][2] if (i >= 0 and i < len(found_centers)) else math.nan for i in range(index_min_choice - 2, index_min_choice + 2 + 1) ] score_right = 0 for i_pos in [3, 4]: if abs(min_neighborood[i_pos] - min_neighborood[2]) < 0.5: score_right += 1 else: break score_left = 0 for i_pos in [1, 0]: if abs(min_neighborood[i_pos] - min_neighborood[2]) < 0.5: score_left += 1 else: break if score_left + score_right >= 2: validity_check_result = cor_result_validity["sound"] else: self.logger.warning( "Minimum mean difference and minimum 
std-dev of differences do not coincide. " + "This means that the validity of the found solution might be questionable." ) validity_check_result = cor_result_validity["questionable"] # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # An informative message in case one wish to look at how it has gone informative_message = " ".join( ["CenterOfRotationAdaptiveSearch found this neighborood of the optimal position:"] + [str(t) if not math.isnan(t) else "N.A." for t in min_neighborood] ) self.logger.debug(informative_message) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # The return value is the optimum which had been placed in the middle of the neighborood cor_position = min_neighborood[2] if not (return_relative_to_middle): cor_position += (img_1.shape[-1] - 1) / 2.0 if return_validity: return cor_position, validity_check_result else: return cor_position __call__ = find_shift class CenterOfRotationOctaveAccurate(CenterOfRotation): """This is a Python implementation of Octave/fastomo3/accurate COR estimator. The Octave 'accurate' function is renamed `local_correlation`. The Nabu standard `find_shift` has the same API as the other COR estimators (sliding, growing...) """ def _cut(self, im, nrows, ncols, new_center_row=None, new_center_col=None): """Cuts a sub-matrix out of a larger matrix. Cuts in the center of the original matrix, except if new center is specified NO CHECKING of validity indices sub-matrix! Parameters ---------- im : array. Original matrix nrows : int Number of rows in the output matrix. ncols : int Number of columns in the output matrix. new_center_row : int Index of center row around which to cut (default: None, i.e. center) new_center_col : int Index of center column around which to cut (default: None, i.e. center) Returns ------- nrows x ncols array. Examples -------- im_roi = cut(im, 1024, 1024) -> cut center 1024x1024 pixels im_roi = cut(im, 1024, 1024, 600.5, 700.5) -> cut 1024x1024 pixels around pixels (600-601, 700-701) Author: P. Cloetens 2023-11-06 J. Lesaint * See octave-archive for the original Octave code. * 2023-11-06: Python implementation. Comparison seems OK. """ [n, m] = im.shape if new_center_row is None: new_center_row = (n + 1) / 2 if new_center_col is None: new_center_col = (m + 1) / 2 rb = int(np.round(0.5 + new_center_row - nrows / 2)) rb = int(np.round(new_center_row - nrows / 2)) re = int(nrows + rb) cb = int(np.round(0.5 + new_center_col - ncols / 2)) cb = int(np.round(new_center_col - ncols / 2)) ce = int(ncols + cb) return im[rb:re, cb:ce] def _checkifpart(self, rapp, rapp_hist): res = 0 for k in range(rapp_hist.shape[0]): if np.allclose(rapp, rapp_hist[k, :]): res = 1 return res return res def _interpolate(self, input, shift, mode="mean", interpolation_method="linear"): """Applies to the input a translation by a vector `shift`. Based on `scipy.ndimage.affine_transform` function. JL: This Octave function was initially used in the refine clause of the local_correlation (Octave find_shift). Since find_shift is always called with refine=False in Octave, refine is not implemented (see local_interpolation()) and this function becomes useless. Parameters ---------- input : array Array to which the translation is applied. shift : tuple, list or array of length 2. mode : str Type of padding applied to the unapplicable areas of the output image. Default `mean` is a constant padding with the mean of the input array. 
`mode` must belong to 'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap' See `scipy.ndimage.affine_transform` for details. interpolation_method : str or int. The interpolation is based on spline interpolation. Either 0, 1, 2, 3, 4 or 5: order of the spline interpolation functions. Or one among 'linear','cubic','pchip','nearest','spline' (Octave legacy). 'nearest' is equivalent to 0 'linear' is equivalent to 1 'cubic','pchip','spline' are equivalent to 3. """ admissible_modes = ( "reflect", "grid-mirror", "constant", "grid-constant", "nearest", "mirror", "grid-wrap", "wrap", ) admissible_interpolation_methods = ("linear", "cubic", "pchip", "nearest", "spline") from scipy.ndimage import affine_transform [s0, s1] = shift matrix = np.zeros([2, 3], dtype=float) matrix[0, 0] = 1.0 matrix[1, 1] = 1.0 matrix[:, 2] = [-s0, -s1] # JL: due to transf. convention diff in Octave and scipy (push fwd vs pull back) if interpolation_method == "nearest": order = 0 elif interpolation_method == "linear": order = 1 elif interpolation_method in ("pchip", "cubic", "spline"): order = 3 elif interpolation_method in (0, 1, 2, 3, 4, 5): order = interpolation_method else: raise ValueError( f"Interpolation method is {interpolation_method} and should either an integer between 0 (inc.) and 5 (inc.) or in {admissible_interpolation_methods}." ) if mode == "mean": mode = "constant" cval = input.mean() return affine_transform(input, matrix, mode=mode, order=order, cval=cval) elif mode not in admissible_modes: raise ValueError(f"Pad method is {mode} and should be in {admissible_modes}.") return affine_transform(input, matrix, mode=mode, order=order) def _local_correlation( self, z1, z2, maxsize=[5, 5], cor_estimate=[0, 0], refine=None, pmcc=False, normalize=True, ): """Returns the 2D shift in pixels between two images. It looks for a local optimum around the initial shift cor_estimate and within a window 'maxsize'. It uses variance of the difference of the normalized images or PMCC It adapts the shift estimate in case optimum is at the edge of the window If 'maxsize' is set to 0, it will only use approximate shift (+ refine possibly) Set 'cor_estimate' to allow for the use of any initial shift estimation. When not successful (stuck in loop or edge reached), returns [nan nan] Positive values corresponds to moving z2 to higher values of the index to compensate drift: interpolate(f)(z2, row, column) Parameters ---------- z1,z2 : 2D arrays. The two (sub)images to be compared. maxsize : 2-list. Default [5,5] Size of the search window. cor_estimate: Initial guess of the center of rotation. refine: Boolean or None (default is None) Wether the initial guess should be refined of not. pmcc: Boolean (default is False) Use Pearson correlation coefficient i.o. variance. normalize: Boolean (default is True) Set mean of each image to 1 if True. Returns ------- c = [row,column] (or [NaN,NaN] if unsuccessful.) 2007-01-05 P. Cloetens cloetens@esrf.eu * Initial revision 2023-11-10 J. Lesaint jerome.lesaint@esrf.fr * Python conversion. """ if type(maxsize) in (float, int): maxsize = [int(maxsize), int(maxsize)] elif type(maxsize) in (tuple, list): maxsize = [int(maxsize[0]), int(maxsize[1])] elif maxsize in ([], None, ""): maxsize = [5, 5] if refine is None: refine = np.allclose(maxsize, 0.0) if normalize: z1 /= np.mean(z1) z2 /= np.mean(z2) ##################################### # JL : seems useless since func is always called with a first approximate. 
## determination of approximative shift (manually or Fourier correlation) # if isinstance(cor_estimate,str): # if cor_estimate in ('fft','auto','fourier'): # padding_mode = None # cor_estimate = self._compute_correlation_fft( # z1, # z2, # padding_mode, # high_pass=self.high_pass, # low_pass=self.low_pass # ) # elif cor_estimate in ('manual','man','m'): # cor_estimate = None # # No ImageJ plugin here : # # rapp = ij_align(z1,z2) #################################### # check if refinement with realspace correlation is required # otherwise keep result as it is if np.allclose(maxsize, 0): shiftfound = 1 if refine: c = np.round(np.array(cor_estimate, dtype=int)) else: c = np.array(cor_estimate, dtype=int) else: shiftfound = 0 cor_estimate = np.round(np.array(cor_estimate, dtype=int)) rapp_hist = [] if np.sum(np.abs(cor_estimate) + 1 >= z1.shape): self.logger.debug(f"Approximate shift of [{cor_estimate[0]},{cor_estimate[1]}] is too large, setting [0 0]") cor_estimate = np.array([0, 0]) maxsize = np.minimum(maxsize, np.floor((np.array(z1.shape) - 1) / 2)).astype(int) maxsize = np.minimum(maxsize, np.array(z1.shape) - np.abs(cor_estimate) - 1).astype(int) while not shiftfound: # Set z1 region # Rationale: the (shift[0]+maxsize[0]:,shift[1]+maxsize[1]:) block of z1 should match # the (maxsize[0]:,maxisze[1]:)-upper-left corner of z2. # We first extract this z1 block. # Then, take moving z2-block according to maxsize. # Of course, care must be taken with borders, hence the various max,min calls. # Extract the reference block shape_ar = np.array(z1.shape) cor_ar = np.array(cor_estimate) maxsize_ar = np.array(maxsize) z1beg = np.maximum(cor_ar + maxsize_ar, np.zeros(2, dtype=int)) z1end = shape_ar + np.minimum(cor_ar - maxsize_ar, np.zeros(2, dtype=int)) z1p = z1[z1beg[0] : z1end[0], z1beg[1] : z1end[1]].flatten() # Build local correlations array. 
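            # cc[k, l] will hold the dissimilarity between the fixed z1 block and the z2 block
            # taken at the candidate offset indexed by (k, l): the sum of squared, mean-subtracted
            # differences by default, or minus the Pearson correlation coefficient when pmcc=True.
            # The minimizing index is refined to sub-pixel accuracy below with a parabolic fit.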
window_shape = (2 * int(maxsize[0]) + 1, 2 * int(maxsize[1]) + 1) cc = np.zeros(window_shape) # Prepare second block indices z2beg = (cor_ar + maxsize_ar > 0) * cc.shape + (cor_ar + maxsize_ar <= 0) * (shape_ar - z1end + z1beg) - 1 z2end = z2beg + z1end - z1beg if pmcc: std_z1p = z1p.std() if normalize == 2: z1p /= z1p.mean() for k in range(cc.shape[0]): for l in range(cc.shape[1]): if pmcc: z2p = z2[z2beg[0] - k : z2end[0] - k, z2beg[1] - l : z2end[1] - l].flatten() std_z2p = z2p.std() cc[k, l] = -np.cov(z1p, z2p, rowvar=True)[1, 0] / (std_z1p * std_z2p) else: if normalize == 2: z2p = z2[z2beg[0] - k : z2end[0] - k, z2beg[1] - l : z2end[1] - l].flatten() z2p /= z2p.mean() z2p -= z1p else: z2p = z2[z2beg[0] - k : z2end[0] - k, z2beg[1] - l : z2end[1] - l].flatten() z2p -= z1p cc[k, l] = ((z2p - z2p.mean()) ** 2).sum() # cc(k,l) = std(z1p./z2(z2beg(1)-k:z2end(1)-k,z2beg(2)-l:z2end(2)-l)(:)); c = np.unravel_index(np.argmin(cc, axis=None), shape=cc.shape) if not np.sum((c == 0) + (c == np.array(cc.shape) - 1)): # check that we are not at the edge of the region that was sampled x = np.array([-1, 0, 1]) tmp = self.refine_max_position_2d(cc[c[0] - 1 : c[0] + 2, c[1] - 1 : c[1] + 2], x, x) c += tmp shiftfound = True c += z1beg - z2beg rapp_hist = [] if not shiftfound: cor_estimate = c # Check that new shift estimate was not already done (avoid eternal loop) if self._checkifpart(cor_estimate, rapp_hist): self.logger.debug("Stuck in loop?") refine = True shiftfound = True c = np.array([np.nan, np.nan]) else: rapp_hist.append(cor_estimate) self.logger.debug(f"Changing shift estimate: {cor_estimate}") maxsize = np.minimum(maxsize, np.array(z1.shape) - np.abs(cor_estimate) - 1).astype(int) if (maxsize == 0).sum(): self.logger.debug("Edge of image reached") refine = False shiftfound = True c = np.array([np.nan, np.nan]) elif len(rapp_hist) > 0: self.logger.debug("\n") #################################### # refine result; useful when shifts are not integer values # JL: I don't understand why this refine step should be useful. # In Octave, from fastomo.m, refine is always set to False. # So this could be ignored. # I keep it for future use if it proves useful. # if refine: # if debug: # print('Refining solution ...') # z2n = self.interpolate(z2,c) # indices = np.ceil(np.abs(c)).astype(int) # z1p = np.roll(z1,((c>0) * (-1) * indices),[0,1]) # z1p = z1p[1:-indices[0]-1,1:-indices[1]-1].flatten() # z2n = np.roll(z2n,((c>0) * (-1) * indices),[0,1]) # z2n = z2n[:-indices[0],:-indices[1]] # ccrefine = np.zeros([3,3]) # [n2,m2] = z2n.shape # for k in range(3): # for l in range(3): # z2p = z1p - z2n[2-k:n2-k,2-l:m2-l].flatten() # ccrefine[k,l] = ((z2p - z2p.mean())**2).sum() # x = np.array([-1,0,1]) # crefine = self.refine_max_position_2d(ccrefine, x, x) # #crefine = min2par(ccrefine) # # Check if the refinement is effectively confined to subpixel # if (np.abs(crefine) >= 1).sum(): # self.logger.info("Problems refining result\n") # else: # c += crefine return c def find_shift( self, img_1, img_2, side="center", roi_yxhw=None, median_filt_shape=None, padding_mode=None, low_pass=0.01, high_pass=None, maxsize=[5, 5], refine=None, pmcc=False, normalize=True, limz=0.5, return_relative_to_middle=None, ): # COMPAT. if return_relative_to_middle is None: deprecation_warning( "The current default behavior is to return the shift relative the the middle of the image. In a future release, this function will return the shift relative to the left-most pixel. 
To keep the current behavior, please use 'return_relative_to_middle=True'.", do_print=True, func_name="CenterOfRotationOctaveAccurate.find_shift", ) return_relative_to_middle = True # the kwarg above will be False by default in a future release # --- self._check_img_pair_sizes(img_1, img_2) img_shape = img_2.shape roi_yxhw = self._determine_roi(img_shape, roi_yxhw) img_1 = self._prepare_image(img_1, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape) img_2 = self._prepare_image(img_2, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape) cc = self._compute_correlation_fft( img_1, img_2, padding_mode, high_pass=high_pass, low_pass=low_pass, ) # We use fftshift to deal more easily with negative shifts. # This has a cost of subtracting half the image shape afterward. shift = np.unravel_index(np.argmax(np.fft.fftshift(cc)), img_shape) shift -= np.array(img_shape) // 2 # The real "accurate" starts here (i.e. the octave findshift() func). if np.abs(shift[0]) > 10 * limz: # This is suspicious. We don't trust results of correlate. self.logger.warning(f"Pre-correlation yields {shift[0]} pixels vertical motion") self.logger.warning("We do not consider it.") shift = (0, 0) # Limit the size of region for comparison to cutsize in both directions. # Hard-coded? cutsize = img_shape[1] // 2 oldshift = np.round(shift).astype(int) if (img_shape[0] > cutsize) or (img_shape[1] > cutsize): im0 = self._cut(img_1, min(img_shape[0], cutsize), min(img_shape[1], cutsize)) im1 = self._cut( np.roll(img_2, oldshift, axis=(0, 1)), min(img_shape[0], cutsize), min(img_shape[1], cutsize) ) shift = oldshift + self._local_correlation( im0, im1, maxsize=maxsize, refine=refine, pmcc=pmcc, normalize=normalize, ) else: shift = self._local_correlation( img_1, img_2, maxsize=maxsize, cor_estimate=oldshift, refine=refine, pmcc=pmcc, normalize=normalize, ) if ((shift - oldshift) ** 2).sum() > 4: self.logger.warning(f"Pre-correlation ({oldshift}) and accurate correlation ({shift}) are not consistent.") self.logger.warning("Please check!!!") offset = shift[1] / 2 if np.abs(shift[0]) > limz: self.logger.debug("Verify alignment or sample motion.") self.logger.debug(f"Vertical motion: {shift[0]} pixels.") self.logger.debug(f"Offset?: {offset} pixels.") else: self.logger.debug(f"Offset?: {offset} pixels.") if not (return_relative_to_middle): offset += (img_shape[1] - 1) / 2 return offset # COMPAT. from .cor_sino import CenterOfRotationFourierAngles as CenterOfRotationFourierAngles0 CenterOfRotationFourierAngles = deprecated_class( "CenterOfRotationFourierAngles was moved to nabu.estimation.cor_sino", do_print=True )(CenterOfRotationFourierAngles0) # ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734019212.0 nabu-2024.2.1/nabu/estimation/cor_sino.py0000644000175000017500000004525414726604214017573 0ustar00pierrepierreimport numpy as np from scipy.signal import convolve2d from scipy.fft import rfft from ..utils import deprecation_warning, is_scalar from ..resources.logger import LoggerOrPrint try: from algotom.prep.calculation import find_center_vo, find_center_360 __have_algotom__ = True except ImportError: __have_algotom__ = False class SinoCor: """ This class has 2 methods: - overlap. Find a rough estimate of COR - accurate. Try to refine COR to 1/10 pixel """ def __init__(self, img_1, img_2, logger=None): """ """ self.logger = LoggerOrPrint(logger) self.sx = img_1.shape[1] # algorithm cannot accept odd number of projs. This is handled in the SinoCORFinder class.
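# Note (added comment): img_1 and img_2 are expected to be the two halves of a 360-degree
# ("half-acquisition") sinogram, with the second half already flipped horizontally by the caller.
# See for instance TestCoarseToFineSinoCor further below, which builds
# SinoCor(sino_first_half, np.fliplr(sino_second_half)) and then chains
# estimate_cor_coarse() (overlap) with estimate_cor_fine() (accurate).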
nproj2 = img_1.shape[0] # extract upper and lower part of sinogram, flipping H the upper part self.data1 = img_1 self.data2 = img_2 self.rcor_abs = round(self.sx / 2.0) self.cor_acc = round(self.sx / 2.0) # parameters for overlap sino - rough estimation # default sliding ROI is 20% of the width of the detector # the maximum size of ROI in the "right" case is 2*(self.sx - COR) # ex: 2048 pixels, COR= 2000, window_width should not exceed 96! self.window_width = round(self.sx / 5) @staticmethod def schift(mat, val): ker = np.zeros((3, 3)) s = 1.0 if val < 0: s = -1.0 val = s * val ker[1, 1] = 1 - val if s > 0: ker[1, 2] = val else: ker[1, 0] = val mat = convolve2d(mat, ker, mode="same") return mat def overlap(self, side="right", window_width=None): """ Compute COR by minimizing difference of circulating ROI - side: preliminary knowledge if the COR is on right or left - window_width: width of ROI that will slide on the other part of the sinogram by default, 20% of the width of the detector. """ if window_width is None: window_width = self.window_width if not (window_width & 1): window_width -= 1 # number of pixels where the window will "slide". n = self.sx - int(window_width) nr = range(n) dmax = 1000000000.0 imax = 0 # Should we do both right and left and take the minimum "diff" of the 2 ? # windows self.data2 moves over self.data1, measure the width of the histogram and retains the smaller one. if side == "right": for i in nr: imout = self.data1[:, n - i : n - i + window_width] - self.data2[:, 0:window_width] diff = imout.max() - imout.min() if diff < dmax: dmax = diff imax = i self.cor_abs = self.sx - (imax + window_width + 1.0) / 2.0 self.cor_rel = self.sx / 2 - (imax + window_width + 1.0) / 2.0 elif side == "left": for i in nr: imout = self.data1[:, i : i + window_width] - self.data2[:, self.sx - window_width : self.sx] diff = imout.max() - imout.min() if diff < dmax: dmax = diff imax = i self.cor_abs = (imax + window_width - 1.0) / 2 self.cor_rel = self.cor_abs - self.sx / 2.0 - 1 else: raise ValueError(f"Invalid side given ({side}). should be 'left' or 'right'") if imax < 1: self.logger.warning("sliding width %d seems too large!" % window_width) self.rcor_abs = round(self.cor_abs) return self.rcor_abs def accurate(self, neighborhood=7, shift_value=0.1): """ refine the calculation around COR integer pre-calculated value The search will be executed in the defined neighborhood Parameters ----------- neighborhood: int Parameter for accurate calculation in the vicinity of the rough estimate. It must be an odd number. 
Float shifts of 0.1 pixel (shift_value) will be performed over this number of pixels """ # define the H-size (odd) of the window one can use to find the best overlap moving finely over ng pixels if not (neighborhood & 1): neighborhood += 1 ng2 = int(neighborhood / 2) # pleft and pright are the number of pixels available on the left and the right of the cor position # to slide a window pleft = self.rcor_abs - ng2 pright = self.sx - self.rcor_abs - ng2 - 1 # the maximum window to slide is restricted by the smaller side if pleft > pright: p_sign = 1 xwin = 2 * (self.sx - self.rcor_abs - ng2) - 1 else: p_sign = -1 xwin = 2 * (self.rcor_abs - ng2) + 1 # Note that xwin is odd xc1 = self.rcor_abs - int(xwin / 2) xc2 = self.sx - self.rcor_abs - int(xwin / 2) - 1 im1 = self.data1[:, xc1 : xc1 + xwin] im2 = self.data2[:, xc2 : xc2 + xwin] pixs = p_sign * (np.arange(neighborhood) - ng2) diff0 = 1000000000.0 isfr = shift_value * np.arange(10) self.cor_acc = self.rcor_abs for pix in pixs: x0 = xc1 + pix for isf in isfr: if isf != 0: ims = self.schift(self.data1[:, x0 : x0 + xwin].copy(), -p_sign * isf) else: ims = self.data1[:, x0 : x0 + xwin] imout = ims - self.data2[:, xc2 : xc2 + xwin] diff = imout.max() - imout.min() if diff < diff0: self.cor_acc = self.rcor_abs + (pix + p_sign * isf) / 2.0 diff0 = diff return self.cor_acc # Aliases estimate_cor_coarse = overlap estimate_cor_fine = accurate class SinoCorInterface: """ A class that mimics the interface of CenterOfRotation, while calling SinoCor """ def __init__(self, logger=None, **kwargs): self._logger = logger def find_shift( self, img_1, img_2, side="right", window_width=None, neighborhood=7, shift_value=0.1, return_relative_to_middle=None, **kwargs, ): # COMPAT. if return_relative_to_middle is None: deprecation_warning( "The current default behavior is to return the shift relative to the middle of the image. In a future release, this function will return the shift relative to the left-most pixel. To keep the current behavior, please use 'return_relative_to_middle=True'.", do_print=True, func_name="CenterOfRotationCoarseToFine.find_shift", ) return_relative_to_middle = True # the kwarg above will be False by default in a future release # --- cor_finder = SinoCor(img_1, img_2, logger=self._logger) cor_finder.estimate_cor_coarse(side=side, window_width=window_width) cor = cor_finder.estimate_cor_fine(neighborhood=neighborhood, shift_value=shift_value) # offset will be added later - keep compatibility with result from AlignmentBase.find_shift() if return_relative_to_middle: return cor - (img_1.shape[1] - 1) / 2 else: return cor class CenterOfRotationFourierAngles: """This CoR estimation algorithm is proposed by V. Valls (BCU). It is based on the Fourier transform of the columns of the sinogram. It requires an initial guess of the CoR which is retrieved from dataset_info.dataset_scanner.x_rotation_axis_pixel_position. It is assumed to be in mm, and the pixel size in um. Options are (for the moment) hard-coded in the SinoCORFinder.cor_finder.extra_options dict. """ def __init__(self, *args, **kwargs): pass def _convert_from_fft_2_fftpack_format(self, f_signal, o_signal_length): """ Converts a scipy.fft.rfft into the (legacy) scipy.fftpack.rfft format. The fftpack.rfft returns an array (roughly) twice as long as fft.rfft, as the latter returns an array of complex numbers whereas the former returns an array with real and imaginary parts in consecutive spots in the array.
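For instance, for a signal of odd length 5, scipy.fft.rfft returns the three complex coefficients [c0, c1, c2], and the layout produced here is [c0.real, c1.real, c1.imag, c2.real, c2.imag].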
Parameters ---------- f_signal : array_like The output of scipy.fft.rfft(signal) o_signal_length : int Size of the original signal (before FT). Returns ------- out The rfft converted to the fftpack.rfft format (roughly twice as long). """ out = np.zeros(o_signal_length, dtype=np.float32) if o_signal_length % 2 == 0: out[0] = f_signal[0].real out[1::2] = f_signal[1:].real out[2::2] = f_signal[1:-1].imag else: out[0] = f_signal[0].real out[1::2] = f_signal[1:].real out[2::2] = f_signal[1:].imag return out def _freq_radio(self, sinos, ifrom, ito): size = (sinos.shape[0] + sinos.shape[0] % 2) // 2 fs = np.empty((size, sinos.shape[1])) for i in range(ifrom, ito): line = sinos[:, i] f_signal = rfft(line) f_signal = self._convert_from_fft_2_fftpack_format(f_signal, line.shape[0]) f = np.abs(f_signal[: (f_signal.size - 1) // 2 + 1]) f2 = np.abs(f_signal[(f_signal.size - 1) // 2 + 1 :][::-1]) if len(f) > len(f2): f[1:] += f2 else: f[0:] += f2 fs[:, i] = f with np.errstate(divide="ignore", invalid="ignore", under="ignore"): fs = np.log(fs) return fs def gaussian(self, p, x): return p[3] + p[2] * np.exp(-((x - p[0]) ** 2) / (2 * p[1] ** 2)) def tukey(self, p, x): pos, std, alpha, height, background = p alpha = np.clip(alpha, 0, 1) pi = np.pi inv_alpha = 1 - alpha width = std / (1 - alpha * 0.5) xx = (np.abs(x - pos) - (width * 0.5 * inv_alpha)) / (width * 0.5 * alpha) xx = np.clip(xx, 0, 1) return (0.5 + np.cos(pi * xx) * 0.5) * height + background def sinlet(self, p, x): std = p[1] * 2.5 lin = np.maximum(0, std - np.abs(p[0] - x)) * 0.5 * np.pi / std return p[3] + p[2] * np.sin(lin) def _px(self, detector_width, abs_pos, near_width, near_std, crop_around_cor, near_step): sym_range = None if abs_pos is not None: if crop_around_cor: sym_range = int(abs_pos - near_std * 2), int(abs_pos + near_std * 2) window = near_width if sym_range is not None: xx_from = max(window, sym_range[0]) xx_to = max(xx_from, min(detector_width - window, sym_range[1])) if xx_from == xx_to: sym_range = None if sym_range is None: xx_from = window xx_to = detector_width - window xx = np.arange(xx_from, xx_to, near_step) return xx def _symmetry_correlation(self, px, array, angles, window, shift_sino): if shift_sino: shift_index = np.argmin(np.abs(angles - np.pi)) - np.argmin(np.abs(angles - 0)) else: shift_index = None px_from = int(px[0]) px_to = int(np.ceil(px[-1])) f_coef = np.empty(len(px)) f_array = self._freq_radio(array, px_from - window, px_to + window) if shift_index is not None: shift_array = np.empty(array.shape, dtype=array.dtype) shift_array[0 : len(shift_array) - shift_index, :] = array[shift_index:, :] shift_array[len(shift_array) - shift_index :, :] = array[:shift_index, :] f_shift_array = self._freq_radio(shift_array, px_from - window, px_to + window) else: f_shift_array = f_array for j, x in enumerate(px): i = int(np.floor(x)) if x - i > 0.4: # TO DO : Specific to near_step = 0.5? 
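# Note (added comment): for each candidate position x, the column spectra of a window to its left are
# compared with the mirrored window to its right (taken from the pi-shifted sinogram when shift_sino
# is set). Around the true rotation axis the two sides should match, so the summed absolute
# difference accumulated in f_coef below is smallest there.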
f_left = f_array[:, i - window : i] f_right = f_shift_array[:, i + 1 : i + window + 1][:, ::-1] else: f_left = f_array[:, i - window : i] f_right = f_shift_array[:, i : i + window][:, ::-1] with np.errstate(divide="ignore", invalid="ignore"): f_coef[j] = np.sum(np.abs(f_left - f_right)) return f_coef def _cor_correlation(self, px, abs_pos, near_std, signal, near_weight, near_alpha): if abs_pos is not None: if signal == "sinlet": coef = self.sinlet((abs_pos, near_std, -near_weight, 1), px) elif signal == "gaussian": coef = self.gaussian((abs_pos, near_std, -near_weight, 1), px) elif signal == "tukey": coef = self.tukey((abs_pos, near_std * 2, near_alpha, -near_weight, 1), px) else: raise ValueError("Shape unsupported") else: coef = np.ones_like(px) return coef def find_shift( self, sino, angles=None, side="center", near_std=100, near_width=20, shift_sino=True, crop_around_cor=False, signal="tukey", near_weight=0.1, near_alpha=0.5, near_step=0.5, return_relative_to_middle=None, ): detector_width = sino.shape[1] # COMPAT. if return_relative_to_middle is None: deprecation_warning( "The current default behavior is to return the shift relative the the middle of the image. In a future release, this function will return the shift relative to the left-most pixel. To keep the current behavior, please use 'return_relative_to_middle=True'.", do_print=True, func_name="CenterOfRotationFourierAngles.find_shift", ) return_relative_to_middle = True # the kwarg above will be False by default in a future release # --- if angles is None: angles = np.linspace(0, 2 * np.pi, sino.shape[0], endpoint=True) increment = np.abs(angles[0] - angles[1]) if np.abs(angles[0] - angles[-1]) < (360 - 0.5) * np.pi / 180 - increment: raise ValueError("Not enough angles, estimator skipped") if is_scalar(side): abs_pos = side # COMPAT. elif side == "near": deprecation_warning( "side='near' is deprecated, please use side=", do_print=True, func_name="fourier_angles_near" ) abs_pos = detector_width // 2 ##. elif side == "center": abs_pos = detector_width // 2 elif side == "left": abs_pos = detector_width // 4 elif side == "right": abs_pos = detector_width * 3 // 4 else: raise ValueError(f"side '{side}' is not handled") px = self._px(detector_width, abs_pos, near_width, near_std, crop_around_cor, near_step) coef_f = self._symmetry_correlation( px, sino, angles, near_width, shift_sino, ) coef_p = self._cor_correlation(px, abs_pos, near_std, signal, near_weight, near_alpha) coef = coef_f * coef_p if len(px) > 0: cor = px[np.argmin(coef)] - (detector_width - 1) / 2 else: # raise ValueError ? cor = None if not (return_relative_to_middle): cor += (detector_width - 1) / 2 return cor __call__ = find_shift class CenterOfRotationVo: """ A wrapper around algotom 'find_center_vo' and 'find_center_360'. Nghia T. Vo, Michael Drakopoulos, Robert C. Atwood, and Christina Reinhard, "Reliable method for calculating the center of rotation in parallel-beam tomography," Opt. 
Express 22, 19078-19086 (2014) """ default_extra_options = {} def __init__(self, logger=None, verbose=False, extra_options=None): if not (__have_algotom__): raise ImportError("Need the 'algotom' package") self.extra_options = self.default_extra_options.copy() self.extra_options.update(extra_options or {}) def find_shift( self, sino, halftomo=False, is_360=False, win_width=100, side="center", search_width_fraction=0.1, step=0.25, radius=4, ratio=0.5, dsp=True, ncore=None, hor_drop=None, ver_drop=None, denoise=True, norm=True, use_overlap=False, return_relative_to_middle=None, ): # COMPAT. if return_relative_to_middle is None: deprecation_warning( "The current default behavior is to return the shift relative to the middle of the image. In a future release, this function will return the shift relative to the left-most pixel. To keep the current behavior, please use 'return_relative_to_middle=True'.", do_print=True, func_name="CenterOfRotationVo.find_shift", ) return_relative_to_middle = True # the kwarg above will be False by default in a future release # --- if halftomo: side_algotom = {"left": 0, "right": 1}.get(side, None) cor, _, _, _ = find_center_360( sino, win_width, side=side_algotom, denoise=denoise, norm=norm, use_overlap=use_overlap, ncore=ncore ) else: if is_360 and not (halftomo): # Take only one part of the sinogram and use "find_center_vo" - this works better in this case sino = sino[: sino.shape[0] // 2] sino_width = sino.shape[-1] search_width = int(search_width_fraction * sino_width) if side == "left": start, stop = 0, search_width elif side == "center": start, stop = sino_width // 2 - search_width, sino_width // 2 + search_width elif side == "right": start, stop = sino_width - search_width, sino_width elif is_scalar(side): # side is passed as an offset from the middle of detector side = side + (sino.shape[-1] - 1) / 2.0 start, stop = max(0, side - search_width), min(sino_width, side + search_width) else: raise ValueError("Expected 'side' to be 'left', 'center', 'right' or a scalar") cor = find_center_vo( sino, start=start, stop=stop, step=step, radius=radius, ratio=ratio, dsp=dsp, ncore=ncore, hor_drop=hor_drop, ver_drop=ver_drop, ) return cor if not (return_relative_to_middle) else cor - (sino.shape[1] - 1) / 2 ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/estimation/distortion.py0000644000175000017500000001113514402565210020146 0ustar00pierrepierreimport numpy as np import scipy.interpolate from .translation import DetectorTranslationAlongBeam from ..misc.filters import correct_spikes from ..resources.logger import LoggerOrPrint def estimate_flat_distortion( flat, image, tile_size=100, interpolation_kind="linear", padding_mode="edge", correction_spike_threshold=None, logger=None, ): """ Estimate the wavefront distortion on a flat image, from another image. Parameters ---------- flat: np.array The flat-field image to be corrected image: np.ndarray The image to correlate the flat against. tile_size: int The wavefront corrections are calculated by correlating the image to the flat, region by region. The regions are tiles of size tile_size interpolation_kind: "linear" or "cubic" The interpolation method to use padding_mode: string Padding mode. Must be valid for np.pad. When wavefront correction is applied, the corrections are first found for the tiles, which gives the shift at the center of each tile.
Then, to interpolate the corrections at the position of every pixel, one must also add the border of the extremal tiles. This is done by padding with a width of 1, and using the mode given by 'padding_mode'. correction_spike_threshold: float, optional By default it is None and no spike correction is performed on the shifts grid which is found by correlation. If set to a float, a spike removal will be applied using this threshold. Returns -------- coordinates: np.ndarray An array having dimensions (flat.shape[0], flat.shape[1], 2) where each coordinates[i,j] contains the coordinates of the position in the image "flat" which correlates to the pixel (i,j) in the image "image". """ logger = LoggerOrPrint(logger) starts_r = np.array(range(0, image.shape[0] - tile_size, tile_size)) starts_c = np.array(range(0, image.shape[1] - tile_size, tile_size)) cor1 = np.zeros([len(starts_r), len(starts_c)], np.float32) cor2 = np.zeros([len(starts_r), len(starts_c)], np.float32) shift_finder = DetectorTranslationAlongBeam() for ir, r in enumerate(starts_r): for ic, c in enumerate(starts_c): try: coeff_v, coeff_h, shifts_vh_per_img = shift_finder.find_shift( np.array([image[r : r + tile_size, c : c + tile_size], flat[r : r + tile_size, c : c + tile_size]]), np.array([0, 1]), return_shifts=True, low_pass=(1.0, 0.3), high_pass=(tile_size, tile_size * 0.3), ) cor1[ir, ic], cor2[ir, ic] = shifts_vh_per_img[1] except ValueError as e: if "positions are outside" in str(e): logger.debug(str(e)) cor1[ir, ic], cor2[ir, ic] = (0, 0) else: raise cor1[np.isnan(cor1)] = 0 cor2[np.isnan(cor2)] = 0 if correction_spike_threshold is not None: cor1 = correct_spikes(cor1, correction_spike_threshold) cor2 = correct_spikes(cor2, correction_spike_threshold) # TODO implement the previous spikes correction in CCDCorrection - median_clip # spikes_corrector = CCDCorrection(cor1.shape, median_clip_thresh=3, abs_diff=True, preserve_borders=True) # cor1 = spikes_corrector.median_clip_correction(cor1) # cor2 = spikes_corrector.median_clip_correction(cor2) cor1 = np.pad(cor1, ((1, 1), (1, 1)), mode=padding_mode) cor2 = np.pad(cor2, ((1, 1), (1, 1)), mode=padding_mode) hp = np.concatenate([[0.0], starts_c + tile_size * 0.5, [image.shape[1]]]) vp = np.concatenate([[0.0], starts_r + tile_size * 0.5, [image.shape[0]]]) h_ticks = np.arange(image.shape[1]).astype(np.float32) v_ticks = np.arange(image.shape[0]).astype(np.float32) spline_degree = {"linear": 1, "cubic": 3}[interpolation_kind] interpolator = scipy.interpolate.RectBivariateSpline(vp, hp, cor1, kx=spline_degree, ky=spline_degree) cor1 = interpolator(h_ticks, v_ticks) interpolator = scipy.interpolate.RectBivariateSpline(vp, hp, cor2, kx=spline_degree, ky=spline_degree) cor2 = interpolator(h_ticks, v_ticks) hh = np.arange(image.shape[1]).astype(np.float32) vv = np.arange(image.shape[0]).astype(np.float32) unshifted_v, unshifted_h = np.meshgrid(vv, hh, indexing="ij") shifted_v = unshifted_v - cor1 shifted_h = unshifted_h - cor2 coordinates = np.transpose(np.array([shifted_v, shifted_h]), axes=[1, 2, 0]) return coordinates ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734019212.0 nabu-2024.2.1/nabu/estimation/focus.py0000644000175000017500000004315514726604214017075 0ustar00pierrepierreimport numpy as np from scipy.fft import fftn from ..processing.azim import azimuthal_integration_skimage_stack, azimuthal_integration_imagej_stack, __have_skimage__ from .alignment import plt from .cor import CenterOfRotation class CameraFocus(CenterOfRotation): def
_check_position_jitter(self, img_pos): pos_diff = np.diff(img_pos) if np.any(pos_diff <= 0): self.logger.warning( "Image position regressed throughout scan! (negative movement for some image positions)" ) @staticmethod def _gradient(x, axes): d = [None] * len(axes) for ii in range(len(axes)): ind = -(ii + 1) padding = [(0, 0)] * len(x.shape) padding[ind] = (0, 1) temp_x = np.pad(x, padding, mode="constant") d[ind] = np.diff(temp_x, n=1, axis=ind) return np.stack(d, axis=0) @staticmethod def _compute_metric_value(data, metric, axes=(-2, -1)): if metric.lower() == "std": return np.std(data, axis=axes) / np.mean(data, axis=axes) elif metric.lower() == "grad": grad_data = CameraFocus._gradient(data, axes=axes) grad_mag = np.sqrt(np.sum(grad_data**2, axis=0)) return np.sum(grad_mag, axis=axes) elif metric.lower() == "psd": f_data = fftn(data, axes=axes, workers=4) f_data = np.fft.fftshift(f_data, axes=(-2, -1)) # octave-fasttomo3 uses |.|^2, probably with scaled FFT (norm="forward" in python), # but tests show that it's less accurate. f_data = np.abs(f_data) ai_func = azimuthal_integration_skimage_stack if __have_skimage__ else azimuthal_integration_imagej_stack az_data = ai_func(f_data, n_threads=4) max_vals = np.max(az_data, axis=0) az_data /= max_vals[None, :] return np.mean(az_data, axis=-1) else: raise ValueError("Unknown metric function %s" % metric) def find_distance( self, img_stack: np.ndarray, img_pos: np.array, metric="std", roi_yxhw=None, median_filt_shape=None, padding_mode=None, peak_fit_radius=1, high_pass=None, low_pass=None, ): """Find the focal distance of the camera system. This routine computes the motor position that corresponds to having the scintillator on the focal plane of the camera system. Parameters ---------- img_stack: numpy.ndarray A stack of images at different distances. img_pos: numpy.ndarray Position of the images along the translation axis metric: string, optional The property whose maximum occurs at the focal position. Defaults to 'std' (standard deviation). All options are: 'std' | 'grad' | 'psd' roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional 4 elements vector containing: vertical and horizontal coordinates of first pixel, plus height and width of the Region of Interest (RoI). Or a 2 elements vector containing the height and width of the centered Region of Interest (RoI). Default is None -> deactivated. median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional Shape of the median filter window. Default is None -> deactivated. padding_mode: str in numpy.pad's mode list, optional Padding mode, which determines the type of convolution. If None or 'wrap' are passed, this resorts to the traditional circular convolution. If 'edge' or 'constant' are passed, it results in a linear convolution. Default is the circular convolution. All options are: None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean' | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap' peak_fit_radius: int, optional Radius size around the max correlation pixel, for sub-pixel fitting. Minimum and default value is 1. low_pass: float or sequence of two floats Low-pass filter properties, as described in `nabu.misc.fourier_filters`. high_pass: float or sequence of two floats High-pass filter properties, as described in `nabu.misc.fourier_filters`. Returns ------- focus_pos: float Estimated position of the focal plane of the camera system. focus_ind: float Image index of the estimated position of the focal plane of the camera system (starting from 1!).
Examples -------- Given the focal stack associated to multiple positions of the camera focus motor called `img_stack`, and the associated positions `img_pos`, the following code computes the highest focus position: >>> focus_calc = alignment.CameraFocus() ... focus_pos, focus_ind = focus_calc.find_distance(img_stack, img_pos) where `focus_pos` is the corresponding motor position, and `focus_ind` is the associated image position (starting from 1). """ self._check_img_stack_size(img_stack, img_pos) self._check_position_jitter(img_pos) if peak_fit_radius < 1: self.logger.warning("Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius) peak_fit_radius = 1 num_imgs = img_stack.shape[0] img_shape = img_stack.shape[-2:] roi_yxhw = self._determine_roi(img_shape, roi_yxhw) img_stack = self._prepare_image( img_stack, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape, low_pass=low_pass, high_pass=high_pass, ) img_resp = self._compute_metric_value(img_stack, metric=metric, axes=(-2, -1)) # assuming images are equispaced! # focus_step = np.mean(np.abs(np.diff(img_pos))) focus_step = (img_pos[-1] - img_pos[0]) / (num_imgs - 1) img_inds = np.arange(num_imgs) (f_vals, f_pos) = self.extract_peak_regions_1d(img_resp, peak_radius=peak_fit_radius, cc_coords=img_inds) focus_ind, img_resp_max = self.refine_max_position_1d(f_vals, return_vertex_val=True, return_all_coeffs=True) focus_ind += f_pos[1, :] focus_pos = img_pos[0] + focus_step * focus_ind focus_ind += 1 if focus_pos.size == 1: focus_pos = focus_pos[0] if focus_ind.size == 1: focus_ind = focus_ind[0] if self.verbose: self.logger.info( "Fitted focus motor position:", focus_pos, "and corresponding image position:", focus_ind, ) f, ax = plt.subplots(1, 1) self._add_plot_window(f, ax=ax) ax.stem(img_pos, img_resp) ax.stem(focus_pos, img_resp_max, linefmt="C1-", markerfmt="C1o") ax.set_title("Images response (metric: %s)" % metric) plt.show(block=False) return focus_pos, focus_ind def _check_img_block_size(self, img_shape, regions_number, suggest_new_shape=True): img_shape = np.array(img_shape) new_shape = img_shape if not len(img_shape) == 2: raise ValueError( "Images need to be square 2-dimensional and with shape multiple of the number of assigned regions.\n" " Image shape: %s, regions number: %d" % (img_shape, regions_number) ) if not (img_shape[0] == img_shape[1] and np.all((np.array(img_shape) % regions_number) == 0)): new_shape = (img_shape // regions_number) * regions_number new_shape = np.fmin(new_shape, new_shape.min()) message = ( "Images need to be square 2-dimensional and with shape multiple of the number of assigned regions.\n" " Image shape: %s, regions number: %d. 
Cropping to image shape: %s" % (img_shape, regions_number, new_shape) ) if suggest_new_shape: self.logger.info(message) else: raise ValueError(message) return new_shape @staticmethod def _fit_plane(f_vals): f_vals_half_shape = (np.array(f_vals.shape) - 1) / 2 fy = np.linspace(-f_vals_half_shape[-2], f_vals_half_shape[-2], f_vals.shape[-2]) fx = np.linspace(-f_vals_half_shape[-1], f_vals_half_shape[-1], f_vals.shape[-1]) fy, fx = np.meshgrid(fy, fx, indexing="ij") coords = np.array([np.ones(f_vals.size), fy.flatten(), fx.flatten()]) return np.linalg.lstsq(coords.T, f_vals.flatten(), rcond=None)[0], fy, fx def find_scintillator_tilt( self, img_stack: np.ndarray, img_pos: np.array, regions_number=4, metric="std", roi_yxhw=None, median_filt_shape=None, padding_mode=None, peak_fit_radius=1, high_pass=None, low_pass=None, ): """Finds the scintillator tilt and focal distance of the camera system. This routine computes the mounting tilt of the scintillator and the motor position that corresponds to having the scintillator on the focal plain of the camera system. The input is supposed to be a stack of square images, whose sizes are multiples of the `regions_number` parameter. If images with a different size are passed, this function will crop the images. This also generates a warning. To suppress the warning, it is suggested to specify a ROI that satisfies those criteria (see examples). The computed tilts `tilt_vh` are in unit-length per pixel-size. To obtain the tilts it is necessary to divide by the pixel-size: >>> tilt_vh_deg = np.rad2deg(np.arctan(tilt_vh / pixel_size)) The correction to be applied is: >>> tilt_corr_vh_deg = - np.rad2deg(np.arctan(tilt_vh / pixel_size)) The legacy octave macros computed the approximation of these values in radians: >>> tilt_corr_vh_rad = - tilt_vh / pixel_size Note that `pixel_size` should be in the same unit scale as `img_pos`. Parameters ---------- img_stack: numpy.ndarray A stack of images at different distances. img_pos: numpy.ndarray Position of the images along the translation axis regions_number: int, optional The number of regions to subdivide the image into, along each direction. Defaults to 4. metric: string, optional The property, whose maximize occurs at the focal position. Defaults to 'std' (standard deviation). All options are: 'std' | 'grad' | 'psd' roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional 4 elements vector containing: vertical and horizontal coordinates of first pixel, plus height and width of the Region of Interest (RoI). Or a 2 elements vector containing: plus height and width of the centered Region of Interest (RoI). Default is None -> auto-suggest correct size. median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional Shape of the median filter window. Default is None -> deactivated. padding_mode: str in numpy.pad's mode list, optional Padding mode, which determines the type of convolution. If None or 'wrap' are passed, this resorts to the traditional circular convolution. If 'edge' or 'constant' are passed, it results in a linear convolution. Default is the circular convolution. All options are: None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean' | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap' peak_fit_radius: int, optional Radius size around the max correlation pixel, for sub-pixel fitting. Minimum and default value is 1. low_pass: float or sequence of two floats Low-pass filter properties, as described in `nabu.misc.fourier_filters`. 
high_pass: float or sequence of two floats High-pass filter properties, as described in `nabu.misc.fourier_filters`. Returns ------- focus_pos: float Estimated position of the focal plane of the camera system. focus_ind: float Image index of the estimated position of the focal plane of the camera system (starting from 1!). tilts_vh: tuple(float, float) Estimated scintillator tilts in the vertical and horizontal direction respectively per unit-length per pixel-size. Examples -------- Given the focal stack associated to multiple positions of the camera focus motor called `img_stack`, and the associated positions `img_pos`, the following code computes the highest focus position: >>> focus_calc = alignment.CameraFocus() ... focus_pos, focus_ind, tilts_vh = focus_calc.find_scintillator_tilt(img_stack, img_pos) ... tilt_corr_vh_deg = - np.rad2deg(np.arctan(tilt_vh / pixel_size)) or to keep compatibility with the old octave macros: >>> tilt_corr_vh_rad = - tilt_vh / pixel_size For non square images, or images with sizes that are not multiples of the `regions_number` parameter, and no ROI is being passed, this function will try to crop the image stack to the correct size. If you want to remove the warning message, it is suggested to set a ROI like the following: >>> regions_number = 4 ... img_roi = (np.array(img_stack.shape[1:]) // regions_number) * regions_number ... img_roi = np.fmin(img_roi, img_roi.min()) ... focus_calc = alignment.CameraFocus() ... focus_pos, focus_ind, tilts_vh = focus_calc.find_scintillator_tilt( ... img_stack, img_pos, roi_yxhw=img_roi, regions_number=regions_number) """ self._check_img_stack_size(img_stack, img_pos) self._check_position_jitter(img_pos) if peak_fit_radius < 1: self.logger.warning("Parameter peak_fit_radius should be at least 1, given: %d instead." 
% peak_fit_radius) peak_fit_radius = 1 num_imgs = img_stack.shape[0] img_shape = img_stack.shape[-2:] if roi_yxhw is None: # If no roi is being passed, we try to crop the images to the # correct size, if needed roi_yxhw = self._check_img_block_size(img_shape, regions_number, suggest_new_shape=True) roi_yxhw = self._determine_roi(img_shape, roi_yxhw) else: # If a roi is being passed, and the images don't have the correct # shape, we raise an error roi_yxhw = self._determine_roi(img_shape, roi_yxhw) self._check_img_block_size(roi_yxhw[2:], regions_number, suggest_new_shape=False) img_stack = self._prepare_image( img_stack, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape, low_pass=low_pass, high_pass=high_pass, ) img_shape = img_stack.shape[-2:] block_size = np.array(img_shape) / regions_number block_stack_size = np.array( [num_imgs, regions_number, block_size[-2], regions_number, block_size[-1]], dtype=np.intp, ) img_stack = np.reshape(img_stack, block_stack_size) img_resp = self._compute_metric_value(img_stack, metric=metric, axes=(-3, -1)) img_resp = np.reshape(img_resp, [num_imgs, -1]).transpose() # assuming images are equispaced focus_step = (img_pos[-1] - img_pos[0]) / (num_imgs - 1) img_inds = np.arange(num_imgs) (f_vals, f_pos) = self.extract_peak_regions_1d(img_resp, peak_radius=peak_fit_radius, cc_coords=img_inds) focus_inds = self.refine_max_position_1d(f_vals, return_all_coeffs=True) focus_inds += f_pos[1, :] focus_poss = img_pos[0] + focus_step * focus_inds # Fitting the plane focus_poss = np.reshape(focus_poss, [regions_number, regions_number]) coeffs, fy, fx = self._fit_plane(focus_poss) focus_pos, tg_v, tg_h = coeffs # The angular coefficient along x is the tilt around the y axis and vice-versa tilts_vh = np.array([tg_h, tg_v]) / block_size focus_ind = np.mean(focus_inds) + 1 if self.verbose: self.logger.info( "Fitted focus motor position:", focus_pos, "and corresponding image position:", focus_ind, ) self.logger.info("Fitted tilts (to be divided by pixel size, and converted to deg): (v, h) %s" % tilts_vh) fig = plt.figure() ax = fig.add_subplot(111, projection="3d") self._add_plot_window(fig, ax=ax) ax.plot_wireframe(fx, fy, focus_poss) regions_half_shape = (regions_number - 1) / 2 base_points = np.linspace(-regions_half_shape, regions_half_shape, regions_number) ax.plot( np.zeros((regions_number,)), base_points, np.polyval([tg_v, focus_pos], base_points), "C2", ) ax.plot( base_points, np.zeros((regions_number,)), np.polyval([tg_h, focus_pos], base_points), "C2", ) ax.scatter(0, 0, focus_pos, marker="o", c="C1") ax.set_title("Images std") plt.show(block=False) return focus_pos, focus_ind, tilts_vh ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5047567 nabu-2024.2.1/nabu/estimation/tests/0000755000175000017500000000000014730277752016546 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/estimation/tests/__init__.py0000644000175000017500000000000014315516747020644 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/estimation/tests/test_alignment.py0000644000175000017500000000467314402565210022130 0ustar00pierrepierreimport numpy as np import pytest from nabu.estimation.alignment import AlignmentBase @pytest.fixture(scope="class") def bootstrap_base(request): cls = request.cls cls.abs_tol = 2.5e-2 @pytest.mark.usefixtures("bootstrap_base") class TestAlignmentBase: 
def test_peak_fitting_2d_3x3(self): # Fit a 3 x 3 grid fy = np.linspace(-1, 1, 3) fx = np.linspace(-1, 1, 3) yy, xx = np.meshgrid(fy, fx, indexing="ij") peak_pos_yx = np.random.rand(2) * 1.6 - 0.8 f_vals = np.exp(-((yy - peak_pos_yx[0]) ** 2 + (xx - peak_pos_yx[1]) ** 2) / 100) fitted_peak_pos_yx = AlignmentBase.refine_max_position_2d(f_vals, fy, fx) message = ( "Computed peak position: (%f, %f) " % (*fitted_peak_pos_yx,) + " and real peak position (%f, %f) do not coincide." % (*peak_pos_yx,) + " Difference: (%f, %f)," % (*(fitted_peak_pos_yx - peak_pos_yx),) + " tolerance: %f" % self.abs_tol ) assert np.all(np.isclose(peak_pos_yx, fitted_peak_pos_yx, atol=self.abs_tol)), message def test_peak_fitting_2d_error_checking(self): # Fit a 3 x 3 grid fy = np.linspace(-1, 1, 3) fx = np.linspace(-1, 1, 3) yy, xx = np.meshgrid(fy, fx, indexing="ij") peak_pos_yx = np.random.rand(2) + 1.5 f_vals = np.exp(-((yy - peak_pos_yx[0]) ** 2 + (xx - peak_pos_yx[1]) ** 2) / 100) with pytest.raises(ValueError) as ex: AlignmentBase.refine_max_position_2d(f_vals, fy, fx) message = ( "Error should have been raised about the peak being fitted outside margins, " + "other error raised instead:\n%s" % str(ex.value) ) assert "positions are outside the input margins" in str(ex.value), message def test_extract_peak_regions_1d(self): img = np.random.randint(0, 10, size=(8, 8)) peaks_pos = np.argmax(img, axis=-1) peaks_val = np.max(img, axis=-1) cc_coords = np.arange(0, 8) ( found_peaks_val, found_peaks_pos, ) = AlignmentBase.extract_peak_regions_1d(img, axis=-1, cc_coords=cc_coords) message = ( "The found peak positions do not correspond to the expected peak positions:\n Expected: %s\n Found: %s" % ( peaks_pos, found_peaks_pos[1, :], ) ) assert np.all(peaks_val == found_peaks_val[1, :]), message ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734019212.0 nabu-2024.2.1/nabu/estimation/tests/test_cor.py0000644000175000017500000005640314726604214020742 0ustar00pierrepierreimport os import numpy as np import pytest import scipy.ndimage import h5py from nabu.testutils import utilstest, __do_long_tests__ from nabu.testutils import get_data from nabu.estimation.cor import ( CenterOfRotation, CenterOfRotationAdaptiveSearch, CenterOfRotationGrowingWindow, CenterOfRotationSlidingWindow, CenterOfRotationOctaveAccurate, ) from nabu.estimation.cor_sino import SinoCor, CenterOfRotationFourierAngles, CenterOfRotationVo, __have_algotom__ @pytest.fixture(scope="class") def bootstrap_cor(request): cls = request.cls cls.abs_tol = 0.2 cls.data, calib_data = get_cor_data_h5("test_alignment_cor.h5") cls.cor_gl_pix, cls.cor_hl_pix, cls.tilt_deg = calib_data @pytest.fixture(scope="class") def bootstrap_cor_win(request): cls = request.cls cls.abs_tol = 0.2 cls.data_ha_proj, cls.cor_ha_pr_pix = get_cor_win_proj_data_h5("ha_autocor_radios.npz") cls.data_ha_sino, cls.cor_ha_sn_pix = get_cor_win_sino_data_h5("halftomo_1_sino.npz") @pytest.fixture(scope="class") def bootstrap_cor_accurate(request): cls = request.cls cls.abs_tol = 0.2 cls.image_pair_stylo, cls.cor_pos_abs_stylo = get_cor_win_proj_data_h5("stylo_accurate.npz") cls.image_pair_blc12781, cls.cor_pos_abs_blc12781 = get_cor_win_proj_data_h5("blc12781_accurate.npz") @pytest.fixture(scope="class") def bootstrap_cor_fourier(request): cls = request.cls cls.abs_tol = 0.2 dataset_relpath = os.path.join("sino_bamboo_hercules_for_test.npz") dataset_downloaded_path = utilstest.getfile(dataset_relpath) a = np.load(dataset_downloaded_path) cls.sinos = a["sinos"] 
cls.angles = a["angles"] cls.true_cor = a["true_cor"] def get_cor_data_h5(*dataset_path): """ Get a dataset file from silx.org/pub/nabu/data dataset_args is a list describing a nested folder structures, ex. ["path", "to", "my", "dataset.h5"] """ dataset_relpath = os.path.join(*dataset_path) dataset_downloaded_path = utilstest.getfile(dataset_relpath) with h5py.File(dataset_downloaded_path, "r") as hf: data = hf["/entry/instrument/detector/data"][()] cor_global_pix = hf["/calibration/alignment/global/x_rotation_axis_pixel_position"][()][0] cor_highlow_pix = hf["/calibration/alignment/highlow/x_rotation_axis_pixel_position"][()][0] tilt_deg = hf["/calibration/alignment/highlow/z_camera_tilt"][()][0] return data, (cor_global_pix, cor_highlow_pix, tilt_deg) def get_cor_win_proj_data_h5(*dataset_path): """ Get a dataset file from silx.org/pub/nabu/data dataset_args is a list describing a nested folder structures, ex. ["path", "to", "my", "dataset.h5"] """ dataset_relpath = os.path.join(*dataset_path) dataset_downloaded_path = utilstest.getfile(dataset_relpath) data = np.load(dataset_downloaded_path) radios = np.stack((data["radio1"], data["radio2"]), axis=0) return radios, data["cor_pos"] def get_cor_win_sino_data_h5(*dataset_path): """ Get a dataset file from silx.org/pub/nabu/data dataset_args is a list describing a nested folder structures, ex. ["path", "to", "my", "dataset.h5"] """ dataset_relpath = os.path.join(*dataset_path) dataset_downloaded_path = utilstest.getfile(dataset_relpath) data = np.load(dataset_downloaded_path) sino_shape = data["sino"].shape sinos = np.stack((data["sino"][: sino_shape[0] // 2], data["sino"][sino_shape[0] // 2 :]), axis=0) return sinos, data["cor"] - sino_shape[1] / 2 @pytest.mark.usefixtures("bootstrap_cor") class TestCor: def test_cor_posx(self): radio1 = self.data[0, :, :] radio2 = np.fliplr(self.data[1, :, :]) CoR_calc = CenterOfRotation() cor_position = CoR_calc.find_shift(radio1, radio2, return_relative_to_middle=True) assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and real CoR %f do not coincide" % self.cor_gl_pix assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message # testing again with the validity return value cor_position, result_validity = CoR_calc.find_shift( radio1, radio2, return_validity=True, return_relative_to_middle=True ) assert np.isscalar(cor_position) message = ( "returned result_validity is %s " % result_validity + " while it should be unknown because the validity check is not yet implemented" ) assert result_validity == "unknown", message def test_noisy_cor_posx(self): radio1 = np.fmax(self.data[0, :, :], 0) radio2 = np.fmax(self.data[1, :, :], 0) radio1 = np.random.poisson(radio1 * 400) radio2 = np.random.poisson(np.fliplr(radio2) * 400) CoR_calc = CenterOfRotation() cor_position = CoR_calc.find_shift(radio1, radio2, median_filt_shape=(3, 3), return_relative_to_middle=True) assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and real CoR %f do not coincide" % self.cor_gl_pix assert np.isscalar(cor_position) assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message def test_noisyHF_cor_posx(self): """test with noise at high frequencies""" radio1 = self.data[0, :, :] radio2 = np.fliplr(self.data[1, :, :]) noise_level = radio1.max() / 16.0 noise_ima1 = np.random.normal(0.0, size=radio1.shape) * 
noise_level noise_ima2 = np.random.normal(0.0, size=radio2.shape) * noise_level noise_ima1 = noise_ima1 - scipy.ndimage.gaussian_filter(noise_ima1, 2.0) noise_ima2 = noise_ima2 - scipy.ndimage.gaussian_filter(noise_ima2, 2.0) radio1 = radio1 + noise_ima1 radio2 = radio2 + noise_ima2 CoR_calc = CenterOfRotation() cor_position = CoR_calc.find_shift(radio1, radio2, low_pass=(6.0, 0.3), return_relative_to_middle=True) assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and real CoR %f do not coincide" % self.cor_gl_pix assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message @pytest.mark.skipif(not (__do_long_tests__), reason="need environment variable NABU_LONG_TESTS=1") def test_half_tomo_cor_exp(self): """test the half_tomo algorithm on experimental data""" radios = get_data("ha_autocor_radios.npz") radio1 = radios["radio1"] radio2 = radios["radio2"] cor_pos = radios["cor_pos"] radio2 = np.fliplr(radio2) CoR_calc = CenterOfRotationAdaptiveSearch() cor_position = CoR_calc.find_shift( radio1, radio2, low_pass=1, high_pass=20, filtered_cost=True, return_relative_to_middle=True ) assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = ( "Computed CoR %f " % cor_position + " and real CoR %f should coincide when using the halftomo algorithm with half tomo data" % cor_pos ) assert np.isclose(cor_pos, cor_position, atol=self.abs_tol + 0.5), message @pytest.mark.skipif(not (__do_long_tests__), reason="need environment variable NABU_LONG_TESTS=1") def test_half_tomo_cor_exp_limited(self): """test the hal_tomo algorithm on experimental data and global search with limits""" radios = get_data("ha_autocor_radios.npz") radio1 = radios["radio1"] radio2 = radios["radio2"] cor_pos = radios["cor_pos"] radio2 = np.fliplr(radio2) CoR_calc = CenterOfRotationAdaptiveSearch() cor_position, result_validity = CoR_calc.find_shift( radio1, radio2, low_pass=1, high_pass=20, margins=(100, 10), filtered_cost=False, return_validity=True, return_relative_to_middle=True, ) assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = ( "Computed CoR %f " % cor_position + " and real CoR %f should coincide when using the halftomo algorithm with half tomo data" % cor_pos ) assert np.isclose(cor_pos, cor_position, atol=self.abs_tol + 0.5), message message = "returned result_validity is %s " % result_validity + " while it should be sound" assert result_validity == "sound", message def test_cor_posx_linear(self): radio1 = self.data[0, :, :] radio2 = np.fliplr(self.data[1, :, :]) CoR_calc = CenterOfRotation() cor_position = CoR_calc.find_shift(radio1, radio2, padding_mode="edge", return_relative_to_middle=True) assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and real CoR %f do not coincide" % self.cor_gl_pix assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message def test_error_checking_001(self): CoR_calc = CenterOfRotation() radio1 = self.data[0, :, :1:] radio2 = self.data[1, :, :] with pytest.raises(ValueError) as ex: CoR_calc.find_shift(radio1, radio2, return_relative_to_middle=True) message = "Error should have been raised about img #1 shape, other error raised instead:\n%s" % str(ex.value) assert "Images need to be 2-dimensional. 
Shape of image #1" in str(ex.value), message def test_error_checking_002(self): CoR_calc = CenterOfRotation() radio1 = self.data[0, :, :] radio2 = self.data with pytest.raises(ValueError) as ex: CoR_calc.find_shift(radio1, radio2, return_relative_to_middle=True) message = "Error should have been raised about img #2 shape, other error raised instead:\n%s" % str(ex.value) assert "Images need to be 2-dimensional. Shape of image #2" in str(ex.value), message def test_error_checking_003(self): CoR_calc = CenterOfRotation() radio1 = self.data[0, :, :] radio2 = self.data[1, :, 0:10] with pytest.raises(ValueError) as ex: CoR_calc.find_shift(radio1, radio2, return_relative_to_middle=True) message = ( "Error should have been raised about different image shapes, " + "other error raised instead:\n%s" % str(ex.value) ) assert "Images need to be of the same shape" in str(ex.value), message @pytest.mark.skipif(not (__do_long_tests__), reason="Need NABU_LONG_TESTS=1 for this test") @pytest.mark.usefixtures("bootstrap_cor", "bootstrap_cor_win") class TestCorWindowSlide: def test_proj_center_axis_lft(self): radio1 = self.data[0, :, :] radio2 = np.fliplr(self.data[1, :, :]) CoR_calc = CenterOfRotationSlidingWindow() cor_position = CoR_calc.find_shift( radio1, radio2, side="left", window_width=round(radio1.shape[-1] / 4.0 * 3.0), return_relative_to_middle=True, ) assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_gl_pix assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message cor_position, result_validity = CoR_calc.find_shift( radio1, radio2, side="left", window_width=round(radio1.shape[-1] / 4.0 * 3.0), return_validity=True, return_relative_to_middle=True, ) message = "returned result_validity is %s " % result_validity + " while it should be sound" assert result_validity == "sound", message def test_proj_center_axis_cen(self): radio1 = self.data[0, :, :] radio2 = np.fliplr(self.data[1, :, :]) CoR_calc = CenterOfRotationSlidingWindow() cor_position = CoR_calc.find_shift(radio1, radio2, side="center", return_relative_to_middle=True) assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_gl_pix assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message def test_proj_right_axis_rgt(self): radio1 = self.data_ha_proj[0, :, :] radio2 = np.fliplr(self.data_ha_proj[1, :, :]) CoR_calc = CenterOfRotationSlidingWindow() cor_position = CoR_calc.find_shift(radio1, radio2, side="right", return_relative_to_middle=True) assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_ha_pr_pix assert np.isclose(self.cor_ha_pr_pix, cor_position, atol=self.abs_tol), message def test_proj_left_axis_lft(self): radio1 = np.fliplr(self.data_ha_proj[0, :, :]) radio2 = self.data_ha_proj[1, :, :] CoR_calc = CenterOfRotationSlidingWindow() cor_position = CoR_calc.find_shift(radio1, radio2, side="left", return_relative_to_middle=True) assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % -self.cor_ha_pr_pix assert 
np.isclose(-self.cor_ha_pr_pix, cor_position, atol=self.abs_tol), message def test_sino_right_axis_rgt(self): sino1 = self.data_ha_sino[0, :, :] sino2 = np.fliplr(self.data_ha_sino[1, :, :]) CoR_calc = CenterOfRotationSlidingWindow() cor_position = CoR_calc.find_shift(sino1, sino2, side="right", return_relative_to_middle=True) assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_ha_sn_pix assert np.isclose(self.cor_ha_sn_pix, cor_position, atol=self.abs_tol * 5), message @pytest.mark.skipif(not (__do_long_tests__), reason="need NABU_LONG_TESTS for this test") @pytest.mark.usefixtures("bootstrap_cor", "bootstrap_cor_win") class TestCorWindowGrow: def test_proj_center_axis_cen(self): radio1 = self.data[0, :, :] radio2 = np.fliplr(self.data[1, :, :]) CoR_calc = CenterOfRotationGrowingWindow() cor_position = CoR_calc.find_shift(radio1, radio2, side="center", return_relative_to_middle=True) assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_gl_pix assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message def test_proj_right_axis_rgt(self): radio1 = self.data_ha_proj[0, :, :] radio2 = np.fliplr(self.data_ha_proj[1, :, :]) CoR_calc = CenterOfRotationGrowingWindow() cor_position = CoR_calc.find_shift(radio1, radio2, side="right", return_relative_to_middle=True) assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_ha_pr_pix assert np.isclose(self.cor_ha_pr_pix, cor_position, atol=self.abs_tol), message def test_proj_left_axis_lft(self): radio1 = np.fliplr(self.data_ha_proj[0, :, :]) radio2 = self.data_ha_proj[1, :, :] CoR_calc = CenterOfRotationGrowingWindow() cor_position = CoR_calc.find_shift(radio1, radio2, side="left", return_relative_to_middle=True) assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % -self.cor_ha_pr_pix assert np.isclose(-self.cor_ha_pr_pix, cor_position, atol=self.abs_tol), message cor_position, result_validity = CoR_calc.find_shift( radio1, radio2, side="left", return_validity=True, return_relative_to_middle=True ) message = "returned result_validity is %s " % result_validity + " while it should be sound" assert result_validity == "sound", message def test_proj_right_axis_all(self): radio1 = self.data_ha_proj[0, :, :] radio2 = np.fliplr(self.data_ha_proj[1, :, :]) CoR_calc = CenterOfRotationGrowingWindow() cor_position = CoR_calc.find_shift(radio1, radio2, side="all", return_relative_to_middle=True) assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_ha_pr_pix assert np.isclose(self.cor_ha_pr_pix, cor_position, atol=self.abs_tol), message def test_sino_right_axis_rgt(self): sino1 = self.data_ha_sino[0, :, :] sino2 = np.fliplr(self.data_ha_sino[1, :, :]) CoR_calc = CenterOfRotationGrowingWindow() cor_position = CoR_calc.find_shift(sino1, sino2, side="right", return_relative_to_middle=True) assert np.isscalar(cor_position), f"cor_position 
expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_ha_sn_pix assert np.isclose(self.cor_ha_sn_pix, cor_position, atol=self.abs_tol * 4), message @pytest.mark.usefixtures("bootstrap_cor_win") class TestCoarseToFineSinoCor: def test_coarse_to_fine(self): """ Test nabu.estimation.cor_sino.SinoCor """ sino_halftomo = np.vstack([self.data_ha_sino[0], self.data_ha_sino[1]]) sino_cor = SinoCor(self.data_ha_sino[0], np.fliplr(self.data_ha_sino[1])) cor_coarse = sino_cor.estimate_cor_coarse() assert np.isscalar(cor_coarse), f"cor_position expected to be a scalar, {type(cor_coarse)} returned" cor_fine = sino_cor.estimate_cor_fine() assert np.isscalar(cor_fine), f"cor_position expected to be a scale, {type(cor_fine)} returned" cor_ref = self.cor_ha_sn_pix + sino_halftomo.shape[-1] / 2.0 message = "Computed CoR %f " % cor_fine + " and expected CoR %f do not coincide" % cor_ref assert abs(cor_fine - cor_ref) < self.abs_tol * 2, message @pytest.mark.usefixtures("bootstrap_cor_accurate") class TestCorOctaveAccurate: def test_cor_accurate_positive_shift(self): detector_width = self.image_pair_stylo[0].shape[1] CoR_calc = CenterOfRotationOctaveAccurate() cor_position = CoR_calc.find_shift( self.image_pair_stylo[0], np.fliplr(self.image_pair_stylo[1]), "center", return_relative_to_middle=True ) cor_position = cor_position + detector_width / 2 assert np.isscalar(cor_position), f"cor_position expected to be a scalar, {type(cor_position)} returned" message = f"Computed CoR {cor_position} and expected CoR {self.cor_pos_abs_stylo} do not coincide." assert np.isclose(self.cor_pos_abs_stylo, cor_position, atol=self.abs_tol), message def test_cor_accurate_negative_shift(self): detector_width = self.image_pair_blc12781[0].shape[1] CoR_calc = CenterOfRotationOctaveAccurate() cor_position = CoR_calc.find_shift( self.image_pair_blc12781[0], np.fliplr(self.image_pair_blc12781[1]), "center", return_relative_to_middle=True, ) cor_position = cor_position + detector_width / 2 assert np.isscalar(cor_position), f"cor_position expected to be a scalar, {type(cor_position)} returned" message = f"Computed CoR {cor_position} and expected CoR {self.cor_pos_abs_blc12781} do not coincide." 
assert np.isclose(self.cor_pos_abs_blc12781, cor_position, atol=self.abs_tol), message @pytest.mark.usefixtures("bootstrap_cor_fourier", "bootstrap_cor_win") class TestCorFourierAngle: @pytest.mark.skip("Broken function") def test_sino_right_axis_with_near_pos(self): sino = np.vstack([self.data_ha_sino[0], self.data_ha_sino[1]]) start_angle = np.pi / 4 angles = np.linspace(start_angle, start_angle + 2 * np.pi, sino.shape[0]) CoR_calc = CenterOfRotationFourierAngles() cor_position = CoR_calc.find_shift( sino, angles, side="right", crop_around_cor=True, return_relative_to_middle=True ) # side=sino.shape[1]/2+740) assert np.isscalar(cor_position), f"cor_position expected to be a scalar, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_ha_sn_pix assert np.isclose(self.cor_ha_sn_pix, cor_position, atol=self.abs_tol * 3), message def test_sino_right_axis_with_near_pos_jl(self): CoR_calc = CenterOfRotationFourierAngles() cor_position = CoR_calc.find_shift( self.sinos, self.angles, side="right", crop_around_cor=True, return_relative_to_middle=True ) # side=sino.shape[1]/2+740) assert np.isscalar(cor_position), f"cor_position expected to be a scalar, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.true_cor assert np.isclose(self.true_cor, cor_position, atol=self.abs_tol * 3), message @pytest.fixture(scope="class") def bootstrap_vo_cor(request): cls = request.cls cls.tol = 1e-2 cls.test_sinograms = {name: get_data("sino_%s.npz" % name) for name in ["pencil", "coffee", "mousebrains"]} sino_bamboo = get_data("sino_bamboo_hercules_for_test.npz") cls.test_sinograms["bamboo_hercules"] = { "data": sino_bamboo["sinos"], # FIXME the test file needs to be re-generated, "true_cor" has an incorrect offset "cor": sino_bamboo["true_cor"] + (2560) / 2, } @pytest.mark.skipif(not (__have_algotom__), reason="need algotom for this test") @pytest.mark.usefixtures("bootstrap_vo_cor") class TestVoCOR: def _test_cor(self, dataset_name, tolerance=1e-2, **cor_options): cor_finder = CenterOfRotationVo() cor = cor_finder.find_shift( self.test_sinograms[dataset_name]["data"], return_relative_to_middle=False, **cor_options ) cor_ref = self.test_sinograms[dataset_name]["cor"] assert ( np.abs(cor - cor_ref) < tolerance ), "CoR estimation failed for %s: expected %.3f, got %.3f (tol = %.2e)" % ( dataset_name, cor_ref, cor, tolerance, ) def test_cor_180(self): self._test_cor("pencil", tolerance=0.6) def test_cor_180_more_complex(self): ...
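# Hedged usage sketch (not part of the original test suite): outside of pytest, the same
# estimator can be applied directly to a half-acquisition sinogram. Here `sinogram` is a
# placeholder (n_angles, width) array; only calls exercised by the tests in this module are used.
#
#   cor_finder = CenterOfRotationVo()
#   cor_absolute = cor_finder.find_shift(sinogram, halftomo=True, return_relative_to_middle=False)
#
# With return_relative_to_middle=True the estimate would instead be an offset from the middle
# of the detector, as in the other CoR tests of this module.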
def test_cor_360_halftomo(self): self._test_cor("bamboo_hercules", tolerance=0.1, halftomo=True) def test_cor_360_halftomo_hard(self): # This one is difficult self._test_cor("mousebrains", tolerance=2, halftomo=True) def test_cor_360_not_halftomo(self): self._test_cor("coffee", tolerance=0.5, halftomo=False, is_360=True) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734019212.0 nabu-2024.2.1/nabu/estimation/tests/test_focus.py0000644000175000017500000001013414726604214021265 0ustar00pierrepierreimport os import numpy as np import pytest import h5py from nabu.testutils import utilstest, __do_long_tests__ from nabu.estimation.focus import CameraFocus @pytest.fixture(scope="class") def bootstrap_fcs(request): cls = request.cls cls.abs_tol_dist = 1e-2 cls.abs_tol_tilt = 2.5e-4 ( cls.data, cls.img_pos, cls.pixel_size, (calib_data_std, calib_data_angle), ) = get_focus_data("test_alignment_focus.h5") ( cls.angle_best_ind, cls.angle_best_pos, cls.angle_tilt_v, cls.angle_tilt_h, ) = calib_data_angle cls.std_best_ind, cls.std_best_pos = calib_data_std def get_focus_data(*dataset_path): """ Get a dataset file from silx.org/pub/nabu/data dataset_args is a list describing a nested folder structures, ex. ["path", "to", "my", "dataset.h5"] """ dataset_relpath = os.path.join(*dataset_path) dataset_downloaded_path = utilstest.getfile(dataset_relpath) with h5py.File(dataset_downloaded_path, "r") as hf: data = hf["/entry/instrument/detector/data"][()] img_pos = hf["/entry/instrument/detector/distance"][()] pixel_size = np.mean( [ hf["/entry/instrument/detector/x_pixel_size"][()], hf["/entry/instrument/detector/y_pixel_size"][()], ] ) angle_best_ind = hf["/calibration/focus/angle/best_img"][()][0] angle_best_pos = hf["/calibration/focus/angle/best_pos"][()][0] angle_tilt_v = hf["/calibration/focus/angle/tilt_v_rad"][()][0] angle_tilt_h = hf["/calibration/focus/angle/tilt_h_rad"][()][0] std_best_ind = hf["/calibration/focus/std/best_img"][()][0] std_best_pos = hf["/calibration/focus/std/best_pos"][()][0] calib_data_angle = (angle_best_ind, angle_best_pos, angle_tilt_v, angle_tilt_h) calib_data_std = (std_best_ind, std_best_pos) return data, img_pos, pixel_size, (calib_data_std, calib_data_angle) @pytest.mark.skipif(not (__do_long_tests__), reason="need environment variable NABU_LONG_TESTS=1") @pytest.mark.usefixtures("bootstrap_fcs") class TestFocus: def test_find_distance(self): focus_calc = CameraFocus() focus_pos, focus_ind = focus_calc.find_distance(self.data, self.img_pos) message = ( "Computed focus motor position %f " % focus_pos + " and expected %f do not coincide" % self.std_best_pos ) assert np.isclose(self.std_best_pos, focus_pos, atol=self.abs_tol_dist), message message = "Computed focus image index %f " % focus_ind + " and expected %f do not coincide" % self.std_best_ind assert np.isclose(self.std_best_ind, focus_ind, atol=self.abs_tol_dist), message def test_find_scintillator_tilt(self): focus_calc = CameraFocus() focus_pos, focus_ind, tilts_vh = focus_calc.find_scintillator_tilt(self.data, self.img_pos) message = ( "Computed focus motor position %f " % focus_pos + " and expected %f do not coincide" % self.angle_best_pos ) assert np.isclose(self.angle_best_pos, focus_pos, atol=self.abs_tol_dist), message message = ( "Computed focus image index %f " % focus_ind + " and expected %f do not coincide" % self.angle_best_ind ) assert np.isclose(self.angle_best_ind, focus_ind, atol=self.abs_tol_dist), message expected_tilts_vh = np.squeeze(np.array([self.angle_tilt_v, 
self.angle_tilt_h])) computed_tilts_vh = -tilts_vh / (self.pixel_size / 1000) message = "Computed tilts %s and expected %s do not coincide" % ( computed_tilts_vh, expected_tilts_vh, ) assert np.all(np.isclose(computed_tilts_vh, expected_tilts_vh, atol=self.abs_tol_tilt)), message def test_size_determination(self): inp_shape = [2162, 2560] exp_shape = np.array([2160, 2160]) new_shape = CameraFocus()._check_img_block_size(inp_shape, 4, suggest_new_shape=True) message = "New suggested shape: %s and expected: %s do not coincide" % (new_shape, exp_shape) assert np.all(new_shape == exp_shape), message ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/estimation/tests/test_tilt.py0000644000175000017500000000317214402565210021117 0ustar00pierrepierreimport pytest import numpy as np from nabu.estimation.tilt import CameraTilt from nabu.estimation.tests.test_cor import bootstrap_cor try: import skimage.transform as skt # noqa: F401 __have_skimage__ = True except ImportError: __have_skimage__ = False @pytest.mark.usefixtures("bootstrap_cor") class TestCameraTilt: def test_1dcorrelation(self): radio1 = self.data[0, :, :] radio2 = np.fliplr(self.data[1, :, :]) tilt_calc = CameraTilt() cor_position, camera_tilt = tilt_calc.compute_angle(radio1, radio2) message = "Computed tilt %f " % camera_tilt + " and real tilt %f do not coincide" % self.tilt_deg assert np.isclose(self.tilt_deg, camera_tilt, atol=self.abs_tol), message message = "Computed CoR %f " % cor_position + " and real CoR %f do not coincide" % self.cor_hl_pix assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message @pytest.mark.skipif(not (__have_skimage__), reason="need scikit-image for this test") def test_fftpolar(self): radio1 = self.data[0, :, :] radio2 = np.fliplr(self.data[1, :, :]) tilt_calc = CameraTilt() cor_position, camera_tilt = tilt_calc.compute_angle(radio1, radio2, method="fft-polar") message = "Computed tilt %f " % camera_tilt + " and real tilt %f do not coincide" % self.tilt_deg assert np.isclose(self.tilt_deg, camera_tilt, atol=self.abs_tol), message message = "Computed CoR %f " % cor_position + " and real CoR %f do not coincide" % self.cor_hl_pix assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/estimation/tests/test_translation.py0000644000175000017500000000566714402565210022514 0ustar00pierrepierreimport os import numpy as np import pytest import h5py from nabu.testutils import utilstest from nabu.estimation.translation import DetectorTranslationAlongBeam import scipy.ndimage def get_alignxc_data(*dataset_path): """ Get a dataset file from silx.org/pub/nabu/data dataset_args is a list describing a nested folder structures, ex. 
["path", "to", "my", "dataset.h5"] """ dataset_relpath = os.path.join(*dataset_path) dataset_downloaded_path = utilstest.getfile(dataset_relpath) with h5py.File(dataset_downloaded_path, "r") as hf: data = hf["/entry/instrument/detector/data"][()] img_pos = hf["/entry/instrument/detector/distance"][()] unit_length_shifts_vh = [ hf["/calibration/alignxc/y_pixel_shift_unit"][()], hf["/calibration/alignxc/x_pixel_shift_unit"][()], ] all_shifts_vh = hf["/calibration/alignxc/yx_pixel_offsets"][()] return data, img_pos, (unit_length_shifts_vh, all_shifts_vh) @pytest.fixture(scope="class") def bootstrap_dtr(request): cls = request.cls cls.abs_tol = 1e-1 cls.data, cls.img_pos, calib_data = get_alignxc_data("test_alignment_alignxc.h5") cls.expected_shifts_vh, cls.all_shifts_vh = calib_data @pytest.mark.usefixtures("bootstrap_dtr") class TestDetectorTranslation: def test_alignxc(self): T_calc = DetectorTranslationAlongBeam() shifts_v, shifts_h, found_shifts_list = T_calc.find_shift(self.data, self.img_pos, return_shifts=True) message = "Computed shifts coefficients %s and expected %s do not coincide" % ( (shifts_v, shifts_h), self.expected_shifts_vh, ) assert np.all(np.isclose(self.expected_shifts_vh, [shifts_v, shifts_h], atol=self.abs_tol)), message message = "Computed shifts %s and expected %s do not coincide" % ( found_shifts_list, self.all_shifts_vh, ) assert np.all(np.isclose(found_shifts_list, self.all_shifts_vh, atol=self.abs_tol)), message def test_alignxc_synth(self): T_calc = DetectorTranslationAlongBeam() stack = np.zeros([4, 512, 512], "d") for i in range(4): stack[i, 200 - i * 10, 200 - i * 10] = 1 stack = scipy.ndimage.gaussian_filter(stack, [0, 10, 10.0]) * 100 x, y = np.meshgrid(np.arange(stack.shape[-1]), np.arange(stack.shape[-2])) for i in range(4): xc = x - (250 + i * 1.234) yc = y - (250 + i * 1.234 * 2) stack[i] += np.exp(-(xc * xc + yc * yc) * 0.5) shifts_v, shifts_h, found_shifts_list = T_calc.find_shift( stack, np.array([0.0, 1, 2, 3]), high_pass=1.0, return_shifts=True ) message = "Found shifts per units %s and reference %s do not coincide" % ( (shifts_v, shifts_h), (-1.234 * 2, -1.234), ) assert np.all(np.isclose((shifts_v, shifts_h), (-1.234 * 2, -1.234), atol=self.abs_tol)), message ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1730906677.0 nabu-2024.2.1/nabu/estimation/tilt.py0000644000175000017500000002102314712705065016721 0ustar00pierrepierreimport numpy as np from numpy.polynomial.polynomial import Polynomial, polyval from .alignment import medfilt2d, plt from .cor import CenterOfRotation try: import skimage.transform as skt __have_skimage__ = True except ImportError: __have_skimage__ = False class CameraTilt(CenterOfRotation): def compute_angle( self, img_1: np.ndarray, img_2: np.ndarray, method="1d-correlation", roi_yxhw=None, median_filt_shape=None, padding_mode=None, peak_fit_radius=1, high_pass=None, low_pass=None, ): """Find the camera tilt, given two opposite images. This method finds the tilt between the camera pixel columns and the rotation axis, by performing a 1-dimensional correlation between two opposite images. The output of this function, allows to compute motor movements for aligning the camera tilt. Parameters ---------- img_1: numpy.ndarray First image img_2: numpy.ndarray Second image, it needs to have been flipped already (e.g. using numpy.fliplr). method: str Tilt angle computation method. Default is "1d-correlation" (traditional). 
All options are: - "1d-correlation": fastest, but works best for small tilts - "fft-polar": slower, but works well on all ranges of tilts roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional 4 elements vector containing: vertical and horizontal coordinates of first pixel, plus height and width of the Region of Interest (RoI). Or a 2 elements vector containing: plus height and width of the centered Region of Interest (RoI). Default is None -> deactivated. median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional Shape of the median filter window. Default is None -> deactivated. padding_mode: str in numpy.pad's mode list, optional Padding mode, which determines the type of convolution. If None or 'wrap' are passed, this resorts to the traditional circular convolution. If 'edge' or 'constant' are passed, it results in a linear convolution. Default is the circular convolution. All options are: None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean' | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap' peak_fit_radius: int, optional Radius size around the max correlation pixel, for sub-pixel fitting. Minimum and default value is 1. low_pass: float or sequence of two floats Low-pass filter properties, as described in `nabu.misc.fourier_filters` high_pass: float or sequence of two floats High-pass filter properties, as described in `nabu.misc.fourier_filters` Raises ------ ValueError In case images are not 2-dimensional or have different sizes. Returns ------- cor_offset_pix: float Estimated center of rotation position from the center of the RoI in pixels. tilt_deg: float Estimated camera tilt angle in degrees. Examples -------- The following code computes the center of rotation position for two given images in a tomography scan, where the second image is taken at 180 degrees from the first. >>> radio1 = data[0, :, :] ... radio2 = np.fliplr(data[1, :, :]) ... tilt_calc = CameraTilt() ... cor_offset, camera_tilt = tilt_calc.compute_angle(radio1, radio2) Or for noisy images: >>> cor_offset, camera_tilt = tilt_calc.compute_angle(radio1, radio2, median_filt_shape=(3, 3)) """ self._check_img_pair_sizes(img_1, img_2) if peak_fit_radius < 1: self.logger.warning("Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius) peak_fit_radius = 1 img_shape = img_2.shape roi_yxhw = self._determine_roi(img_shape, roi_yxhw) img_1 = self._prepare_image(img_1, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape) img_2 = self._prepare_image(img_2, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape) if method.lower() == "1d-correlation": return self._compute_angle_1dcorrelation( img_1, img_2, padding_mode, peak_fit_radius=peak_fit_radius, high_pass=high_pass, low_pass=low_pass ) elif method.lower() == "fft-polar": if not __have_skimage__: raise ValueError( 'Camera tilt calculation using "fft-polar" is only available with scikit-image.' " Please install the package to use this option." ) return self._compute_angle_fftpolar( img_1, img_2, padding_mode, peak_fit_radius=peak_fit_radius, high_pass=high_pass, low_pass=low_pass ) else: raise ValueError('Invalid method: %s. 
Valid options are: "1d-correlation" | "fft-polar"' % method) def _compute_angle_1dcorrelation( self, img_1: np.ndarray, img_2: np.ndarray, padding_mode=None, peak_fit_radius=1, high_pass=None, low_pass=None ): cc = self._compute_correlation_fft( img_1, img_2, padding_mode, axes=(-1,), high_pass=high_pass, low_pass=low_pass ) img_shape = img_2.shape cc_h_coords = np.fft.fftfreq(img_shape[-1], 1 / img_shape[-1]) (f_vals, fh) = self.extract_peak_regions_1d(cc, peak_radius=peak_fit_radius, cc_coords=cc_h_coords) fitted_shifts_h = self.refine_max_position_1d(f_vals, return_all_coeffs=True) fitted_shifts_h += fh[1, :] # Computing tilt fitted_shifts_h = medfilt2d(fitted_shifts_h, kernel_size=3) half_img_size = (img_shape[-2] - 1) / 2 cc_v_coords = np.linspace(-half_img_size, half_img_size, img_shape[-2]) coeffs_h = Polynomial.fit(cc_v_coords, fitted_shifts_h, deg=1).convert().coef tilt_deg = np.rad2deg(-coeffs_h[1] / 2) cor_offset_pix = coeffs_h[0] / 2 if self.verbose: self.logger.info( "Fitted center of rotation (pixels):", cor_offset_pix, "and camera tilt (degrees):", tilt_deg, ) f, ax = plt.subplots(1, 1) self._add_plot_window(f, ax=ax) ax.plot(cc_v_coords, fitted_shifts_h) ax.plot(cc_v_coords, polyval(cc_v_coords, coeffs_h), "-C1") ax.set_title("Correlation peaks") plt.show(block=self.extra_options["blocking_plots"]) return cor_offset_pix, tilt_deg def _compute_angle_fftpolar( self, img_1: np.ndarray, img_2: np.ndarray, padding_mode=None, peak_fit_radius=1, high_pass=None, low_pass=None ): img_shape = img_2.shape img_fft_1, img_fft_2, filt, _ = self._transform_to_fft( img_1, img_2, padding_mode=padding_mode, axes=(-2, -1), low_pass=low_pass, high_pass=high_pass ) if filt is not None: img_fft_1 *= filt img_fft_2 *= filt # abs removes the translation component img_fft_1 = np.abs(np.fft.fftshift(img_fft_1, axes=(-2, -1))) img_fft_2 = np.abs(np.fft.fftshift(img_fft_2, axes=(-2, -1))) # transform to polar coordinates img_fft_1 = skt.warp_polar(img_fft_1, scaling="linear", output_shape=img_shape) img_fft_2 = skt.warp_polar(img_fft_2, scaling="linear", output_shape=img_shape) # only use half of the fft domain img_fft_1 = img_fft_1[..., : img_fft_1.shape[-2] // 2, :] img_fft_2 = img_fft_2[..., : img_fft_2.shape[-2] // 2, :] tilt_pix = self.find_shift(img_fft_1, img_fft_2, shift_axis=-2, return_relative_to_middle=True) tilt_deg = -(360 / img_shape[0]) * tilt_pix img_1 = skt.rotate(img_1, tilt_deg) img_2 = skt.rotate(img_2, -tilt_deg) cor_offset_pix = self.find_shift( img_1, img_2, padding_mode=padding_mode, peak_fit_radius=peak_fit_radius, high_pass=high_pass, low_pass=low_pass, return_relative_to_middle=True, ) if self.verbose: print( "Fitted center of rotation (pixels):", cor_offset_pix, "and camera tilt (degrees):", tilt_deg, ) return cor_offset_pix, tilt_deg ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1726234429.0 nabu-2024.2.1/nabu/estimation/translation.py0000644000175000017500000002060514671037475020317 0ustar00pierrepierreimport numpy as np from numpy.polynomial.polynomial import Polynomial, polyval from .alignment import AlignmentBase, plt class DetectorTranslationAlongBeam(AlignmentBase): def find_shift( self, img_stack: np.ndarray, img_pos: np.array, roi_yxhw=None, median_filt_shape=None, padding_mode=None, peak_fit_radius=1, high_pass=None, low_pass=None, return_shifts=False, use_adjacent_imgs=False, ): """Find the vertical and horizontal shifts for translations of the detector along the beam direction. 
These shifts are in pixels-per-unit-translation, and they are due to the misalignment of the translation stage, with respect to the beam propagation direction. To compute the vertical and horizontal tilt angles from the obtained `shift_pix`: >>> tilt_deg = np.rad2deg(np.arctan(shift_pix * pixel_size)) where `pixel_size` and and the input parameter `img_pos` have to be expressed in the same units. Parameters ---------- img_stack: numpy.ndarray A stack of images (usually 4) at different distances img_pos: numpy.ndarray Position of the images along the translation axis roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional 4 elements vector containing: vertical and horizontal coordinates of first pixel, plus height and width of the Region of Interest (RoI). Or a 2 elements vector containing: plus height and width of the centered Region of Interest (RoI). Default is None -> deactivated. median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional Shape of the median filter window. Default is None -> deactivated. padding_mode: str in numpy.pad's mode list, optional Padding mode, which determines the type of convolution. If None or 'wrap' are passed, this resorts to the traditional circular convolution. If 'edge' or 'constant' are passed, it results in a linear convolution. Default is the circular convolution. All options are: None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean' | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap' peak_fit_radius: int, optional Radius size around the max correlation pixel, for sub-pixel fitting. Minimum and default value is 1. low_pass: float or sequence of two floats Low-pass filter properties, as described in `nabu.misc.fourier_filters`. high_pass: float or sequence of two floats High-pass filter properties, as described in `nabu.misc.fourier_filters`. return_shifts: boolean, optional Adds a third returned argument, containing the pixel shifts of each image with respect to the first one in the stack. Defaults to False. use_adjacent_imgs: boolean, optional Compute correlation between adjacent images. It can be used when dealing with large shifts, to avoid overflowing the shift. This option allows to replicate the behavior of the reference function `alignxc.m` However, it is detrimental to shift fitting accuracy. Defaults to False. Returns ------- coeff_v: float Estimated vertical shift in pixel per unit-distance of the detector translation. coeff_h: float Estimated horizontal shift in pixel per unit-distance of the detector translation. shifts_vh: list, optional The pixel shifts of each image with respect to the first image in the stack. Activated if return_shifts is True. Examples -------- The following example creates a stack of shifted images, and retrieves the computed shift. Here we use a high-pass filter, due to the presence of some low-frequency noise component. >>> import numpy as np ... import scipy as sp ... import scipy.ndimage ... from nabu.preproc.alignment import DetectorTranslationAlongBeam ... ... tr_calc = DetectorTranslationAlongBeam() ... ... stack = np.zeros([4, 512, 512]) ... ... # Add low frequency spurious component ... for i in range(4): ... stack[i, 200 - i * 10, 200 - i * 10] = 1 ... stack = sp.ndimage.filters.gaussian_filter(stack, [0, 10, 10.0]) * 100 ... ... # Add the feature ... x, y = np.meshgrid(np.arange(stack.shape[-1]), np.arange(stack.shape[-2])) ... for i in range(4): ... xc = x - (250 + i * 1.234) ... yc = y - (250 + i * 1.234 * 2) ... stack[i] += np.exp(-(xc * xc + yc * yc) * 0.5) ... ... 
# Image translation along the beam ... img_pos = np.arange(4) ... ... # Find the shifts from the features ... shifts_v, shifts_h = tr_calc.find_shift(stack, img_pos, high_pass=1.0) ... print(shifts_v, shifts_h) >>> ( -2.47 , -1.236 ) and the following commands convert the shifts in angular tilts: >>> tilt_v_deg = np.rad2deg(np.arctan(shifts_v * pixel_size)) >>> tilt_h_deg = np.rad2deg(np.arctan(shifts_h * pixel_size)) To enable the legacy behavior of `alignxc.m` (correlation between adjacent images): >>> shifts_v, shifts_h = tr_calc.find_shift(stack, img_pos, use_adjacent_imgs=True) To plot the correlation shifts and the fitted straight lines for both directions: >>> tr_calc = DetectorTranslationAlongBeam(verbose=True) ... shifts_v, shifts_h = tr_calc.find_shift(stack, img_pos) """ self._check_img_stack_size(img_stack, img_pos) if peak_fit_radius < 1: self.logger.warning("Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius) peak_fit_radius = 1 num_imgs = img_stack.shape[0] img_shape = img_stack.shape[-2:] roi_yxhw = self._determine_roi(img_shape, roi_yxhw) img_stack = self._prepare_image(img_stack, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape) # do correlations ccs = [ self._compute_correlation_fft( img_stack[ii - 1 if use_adjacent_imgs else 0, ...], img_stack[ii, ...], padding_mode, high_pass=high_pass, low_pass=low_pass, ) for ii in range(1, num_imgs) ] img_shape = img_stack.shape[-2:] cc_vs = np.fft.fftfreq(img_shape[-2], 1 / img_shape[-2]) cc_hs = np.fft.fftfreq(img_shape[-1], 1 / img_shape[-1]) shifts_vh = np.zeros((num_imgs, 2)) for ii, cc in enumerate(ccs): (f_vals, fv, fh) = self.extract_peak_region_2d(cc, peak_radius=peak_fit_radius, cc_vs=cc_vs, cc_hs=cc_hs) shifts_vh[ii + 1, :] = self.refine_max_position_2d(f_vals, fv, fh) if use_adjacent_imgs: shifts_vh = np.cumsum(shifts_vh, axis=0) # Polynomial.fit is supposed to be more numerically stable than polyfit # (according to numpy) coeffs_v = Polynomial.fit(img_pos, shifts_vh[:, 0], deg=1).convert().coef coeffs_h = Polynomial.fit(img_pos, shifts_vh[:, 1], deg=1).convert().coef # In some cases (singular matrix ?) the output is [0] while in some other its [eps, eps]. if len(coeffs_v) == 1: coeffs_v = np.array([coeffs_v[0], coeffs_v[0]]) if len(coeffs_h) == 1: coeffs_h = np.array([coeffs_h[0], coeffs_h[0]]) if self.verbose: self.logger.info( "Fitted pixel shifts per unit-length: vertical = %f, horizontal = %f" % (coeffs_v[1], coeffs_h[1]) ) f, axs = plt.subplots(1, 2) self._add_plot_window(f, ax=axs) axs[0].scatter(img_pos, shifts_vh[:, 0]) axs[0].plot(img_pos, polyval(img_pos, coeffs_v), "-C1") axs[0].set_title("Vertical shifts") axs[1].scatter(img_pos, shifts_vh[:, 1]) axs[1].plot(img_pos, polyval(img_pos, coeffs_h), "-C1") axs[1].set_title("Horizontal shifts") plt.show(block=False) if return_shifts: return coeffs_v[1], coeffs_h[1], shifts_vh else: return coeffs_v[1], coeffs_h[1] ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1730906677.0 nabu-2024.2.1/nabu/estimation/utils.py0000644000175000017500000000053314712705065017110 0ustar00pierrepierreimport numpy as np def is_fullturn_scan(angles_rad, tol=None): """ Return True if the angles correspond to a full-turn (360 degrees) scan. 
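    Example (illustrative sketch; assumes the angles are expressed in radians and are roughly
    evenly spaced, so that the default tolerance derived from their spacing is meaningful):

    >>> angles = np.linspace(0, 2 * np.pi, 360, endpoint=False)
    >>> bool(is_fullturn_scan(angles))
    True
    >>> bool(is_fullturn_scan(angles[:180]))  # half-turn (180 degrees) scan
    False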
""" angles_rad = np.sort(angles_rad) if tol is None: tol = np.min(np.abs(np.diff(angles_rad))) * 1.1 return np.abs((angles_rad.max() - angles_rad.min()) - (2 * np.pi)) < tol ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5047567 nabu-2024.2.1/nabu/io/0000755000175000017500000000000014730277752013637 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/io/__init__.py0000644000175000017500000000017014654107202015731 0ustar00pierrepierrefrom .reader import NPReader, EDFReader, HDF5File, HDF5Loader, ChunkReader, Readers from .writer import NXProcessWriter ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/io/cast_volume.py0000644000175000017500000004054114654107202016521 0ustar00pierrepierreimport os from nabu.misc.utils import rescale_data from nabu.pipeline.params import files_formats from tomoscan.volumebase import VolumeBase from tomoscan.scanbase import TomoScanBase from tomoscan.esrf.volume import ( EDFVolume, HDF5Volume, JP2KVolume, MultiTIFFVolume, TIFFVolume, ) from tomoscan.io import HDF5File from silx.io.utils import get_data from silx.utils.enum import Enum as _Enum import numpy from silx.io.url import DataUrl from typing import Optional import logging _logger = logging.getLogger(__name__) __all__ = ["get_default_output_volume", "cast_volume"] _DEFAULT_OUTPUT_DIR = "vol_cast" RESCALE_MIN_PERCENTILE = 10 RESCALE_MAX_PERCENTILE = 90 def get_default_output_volume( input_volume: VolumeBase, output_type: str, output_dir: str = _DEFAULT_OUTPUT_DIR ) -> VolumeBase: """ For a given input volume and output type return output volume as an instance of VolumeBase :param VolumeBase intput_volume: volume for which we want to get the resulting output volume for a cast :param str output_type: output_type of the volume (edf, tiff, hdf5...) :param str output_dir: output dir to save the cast volume """ if not isinstance(input_volume, VolumeBase): raise TypeError(f"input_volume is expected to be an instance of {VolumeBase}") valid_file_formats = set(files_formats.values()) if not output_type in valid_file_formats: raise ValueError(f"output_type is not a valid value ({output_type}). 
Valid values are {valid_file_formats}") if isinstance(input_volume, (EDFVolume, TIFFVolume, JP2KVolume)): if output_type == "hdf5": file_path = os.path.join( input_volume.data_url.file_path(), output_dir, input_volume.get_volume_basename() + ".hdf5", ) volume = HDF5Volume( file_path=file_path, data_path="/volume", ) assert volume.get_identifier() is not None, "volume should be able to create an identifier" return volume elif output_type in ("tiff", "edf", "jp2"): if output_type == "tiff": Constructor = TIFFVolume elif output_type == "edf": Constructor = EDFVolume elif output_type == "jp2": Constructor = JP2KVolume return Constructor( # pylint: disable=E0601 folder=os.path.join( os.path.dirname(input_volume.data_url.file_path()), output_dir, ), volume_basename=input_volume.get_volume_basename(), ) else: raise NotImplementedError(f"output volume format {output_type} is not handled") elif isinstance(input_volume, (HDF5Volume, MultiTIFFVolume)): if output_type == "hdf5": data_file_parent_path, data_file_name = os.path.split(input_volume.data_url.file_path()) # replace extension: data_file_name = ".".join( [ os.path.splitext(data_file_name)[0], "hdf5", ] ) if isinstance(input_volume, HDF5Volume): data_data_path = input_volume.data_url.data_path() metadata_data_path = input_volume.metadata_url.data_path() try: data_path = os.path.commonprefix([data_data_path, metadata_data_path]) except Exception: data_path = "volume" else: data_data_path = HDF5Volume.DATA_DATASET_NAME metadata_data_path = HDF5Volume.METADATA_GROUP_NAME file_path = data_file_name data_path = "volume" volume = HDF5Volume( file_path=os.path.join(data_file_parent_path, output_dir, data_file_name), data_path=data_path, ) assert volume.get_identifier() is not None, "volume should be able to create an identifier" return volume elif output_type in ("tiff", "edf", "jp2"): if output_type == "tiff": Constructor = TIFFVolume elif output_type == "edf": Constructor = EDFVolume elif output_type == "jp2": Constructor = JP2KVolume file_parent_path, file_name = os.path.split(input_volume.data_url.file_path()) file_name = os.path.splitext(file_name)[0] return Constructor( folder=os.path.join( file_parent_path, output_dir, os.path.basename(file_name), ) ) else: raise NotImplementedError(f"output volume format {output_type} is not handled") else: raise NotImplementedError(f"input volume format {input_volume} is not handled") def cast_volume( input_volume: VolumeBase, output_volume: VolumeBase, output_data_type: numpy.dtype, data_min=None, data_max=None, scan: Optional[TomoScanBase] = None, rescale_min_percentile=RESCALE_MIN_PERCENTILE, rescale_max_percentile=RESCALE_MAX_PERCENTILE, save=True, store=False, ) -> VolumeBase: """ cast givent volume to output_volume of 'output_data_type' type :param VolumeBase input_volume: :param VolumeBase output_volume: :param numpy.dtype output_data_type: output data type :param number data_min: `data` min value to clamp to new_min. Any lower value will also be clamp to new_min. :param number data_max: `data` max value to clamp to new_max. Any hight value will also be clamp to new_max. :param TomoScanBase scan: source scan that produced input_volume. Can be used to find histogram for example. 
:param rescale_min_percentile: if `data_min` is None will set data_min to 'rescale_min_percentile' :param rescale_max_percentile: if `data_max` is None will set data_min to 'rescale_max_percentile' :param bool save: if True dump the slice on disk (one by one) :param bool store: if True once the volume is cast then set `output_volume.data` :return: output_volume with data and metadata set .. warning:: the created will volume will not be saved in this processing. If you want to save the cast volume you must do it yourself. .. note:: if you want to tune compression ratios (for jp2k) then please update the `cratios` attributes of the output volume """ if not isinstance(input_volume, VolumeBase): raise TypeError(f"input_volume is expected to be a {VolumeBase}. {type(input_volume)} provided") if not isinstance(output_volume, VolumeBase): raise TypeError(f"output_volume is expected to be a {VolumeBase}. {type(output_volume)} provided") try: output_data_type = numpy.dtype( output_data_type ) # User friendly API in case user provides np.uint16 e.g. (see issue #482) except Exception: pass if not isinstance(output_data_type, numpy.dtype): raise TypeError(f"output_data_type is expected to be a {numpy.dtype}. {type(output_data_type)} provided") # start processing # check for data_min and data_max if data_min is None or data_max is None: found_data_min, found_data_max = _try_to_find_min_max_from_histo( input_volume=input_volume, scan=scan, rescale_min_percentile=rescale_min_percentile, rescale_max_percentile=rescale_max_percentile, ) if found_data_min is None or found_data_max is None: _logger.warning("couldn't find histogram, recompute volume min and max values") data_min, data_max = input_volume.get_min_max() _logger.info(f"min and max found ({data_min} ; {data_max})") data_min = data_min if data_min is not None else found_data_min data_max = data_max if data_max is not None else found_data_max data = [] for input_slice, frame_dumper in zip( input_volume.browse_slices(), output_volume.data_file_saver_generator( input_volume.get_volume_shape()[0], data_url=output_volume.data_url, overwrite=output_volume.overwrite, ), ): if numpy.issubdtype(output_data_type, numpy.integer): new_min = numpy.iinfo(output_data_type).min new_max = numpy.iinfo(output_data_type).max output_slice = clamp_and_rescale_data( data=input_slice, new_min=new_min, new_max=new_max, data_min=data_min, data_max=data_max, rescale_min_percentile=rescale_min_percentile, rescale_max_percentile=rescale_max_percentile, ).astype(output_data_type) else: output_slice = input_slice.astype(output_data_type) if save: frame_dumper[:] = output_slice if store: # only keep data in cache if not dump to disk data.append(output_slice) if store: output_volume.data = numpy.asarray(data) # try also to append some metadata to it try: output_volume.metadata = input_volume.metadata or input_volume.load_metadata() except (OSError, KeyError): # if no metadata provided and or saved in disk or if some key are missing pass return output_volume def clamp_and_rescale_data( data: numpy.ndarray, new_min, new_max, data_min=None, data_max=None, rescale_min_percentile=RESCALE_MIN_PERCENTILE, rescale_max_percentile=RESCALE_MAX_PERCENTILE, ): """ rescale data to 'new_min', 'new_max' :param numpy.ndarray data: data to be rescaled :param dtype output_dtype: output dtype :param new_min: rescaled data new min (clamp min value) :param new_max: rescaled data new max (clamp max value) :param data_min: `data` min value to clamp to new_min. 
Any lower value will also be clamp to new_min. :param data_max: `data` max value to clamp to new_max. Any hight value will also be clamp to new_max. :param rescale_min_percentile: if `data_min` is None will set data_min to 'rescale_min_percentile' :param rescale_max_percentile: if `data_max` is None will set data_min to 'rescale_max_percentile' """ if data_min is None: data_min = numpy.percentile(data, rescale_min_percentile) if data_max is None: data_max = numpy.percentile(data, rescale_max_percentile) # rescale data rescaled_data = rescale_data(data, new_min=new_min, new_max=new_max, data_min=data_min, data_max=data_max) # clamp data rescaled_data[rescaled_data < new_min] = new_min rescaled_data[rescaled_data > new_max] = new_max return rescaled_data def find_histogram(volume: VolumeBase, scan: Optional[TomoScanBase] = None) -> Optional[DataUrl]: """ Look for histogram of the provided url. If found one return the DataUrl of the nabu histogram """ if not isinstance(volume, VolumeBase): raise TypeError(f"volume is expected to be an instance of {VolumeBase} not {type(volume)}") elif isinstance(volume, HDF5Volume): histogram_file = volume.data_url.file_path() if volume.url is not None: data_path = volume.url.data_path() if data_path.endswith("reconstruction"): data_path = "/".join( [ *data_path.split("/")[:-1], "histogram/results/data", ] ) else: data_path = "/".join((volume.url.data_path(), "histogram/results/data")) else: # TODO: FIXME: in some case (if the users provides the full data_url and if the 'DATA_DATASET_NAME' is not used we # will endup with an invalid data_path. Hope this case will not happen. Anyway this is a case that we can't handle.) # if trouble: check if data_path exists. If not raise an error saying this we can't find an histogram for this volume data_path = volume.data_url.data_path().replace(HDF5Volume.DATA_DATASET_NAME, "histogram/results/data") elif isinstance(volume, (EDFVolume, JP2KVolume, TIFFVolume, MultiTIFFVolume)): if isinstance(volume, (EDFVolume, JP2KVolume, TIFFVolume)): histogram_file = os.path.join( volume.data_url.file_path(), volume.get_volume_basename() + "_histogram.hdf5", ) if not os.path.exists(histogram_file): # legacy location legacy_histogram_file = os.path.join( volume.data_url.file_path(), volume.get_volume_basename() + "histogram.hdf5", ) if os.path.exists(legacy_histogram_file): # only overwrite if exists. Else keep the older one to get a clearer information histogram_file = legacy_histogram_file else: file_path, _ = os.path.splitext(volume.data_url.file_path()) histogram_file = file_path + "_histogram.hdf5" if scan is not None: data_path = getattr(scan, "entry/histogram/results/data", "entry/histogram/results/data") else: def get_file_entries(file_path: str) -> Optional[tuple]: if os.path.exists(file_path): with HDF5File(file_path, mode="r") as h5s: return tuple(h5s.keys()) else: return None # in the case we only know about the volume to cast. # in most of the cast the histogram.hdf5 file will only get a single entry. The exception could be # for HDF5 if the user save volumes into the same file. # we can find back the histogram entries = get_file_entries(histogram_file) if entries is not None and len(entries) == 1: data_path = "/".join((entries[0], "histogram/results/data")) else: # TODO: FIXME: how to get the entry name in every case ? # what to do if the histogram file has more than one entry. # one option could be to request the entry from the user... 
# or keep as today (in this case it will be recomputed) _logger.info("histogram file found but unable to find relevant histogram") return None else: raise NotImplementedError(f"volume {type(volume)} not handled") if not os.path.exists(histogram_file): _logger.info(f"{histogram_file} not found") return None with HDF5File(histogram_file, mode="r") as h5f: if not data_path in h5f: _logger.info(f"{data_path} in {histogram_file} not found") return None else: _logger.info(f"Found histogram {histogram_file}::/{data_path}") return DataUrl( file_path=histogram_file, data_path=data_path, scheme="silx", ) def _get_hst_saturations(hist, bins, rescale_min_percentile: numpy.float32, rescale_max_percentile: numpy.float32): hist_cum = numpy.cumsum(hist) bin_index_min = numpy.searchsorted(hist_cum, numpy.percentile(hist_cum, rescale_min_percentile)) bin_index_max = numpy.searchsorted(hist_cum, numpy.percentile(hist_cum, rescale_max_percentile)) return bins[bin_index_min], bins[bin_index_max] def _try_to_find_min_max_from_histo( input_volume: VolumeBase, rescale_min_percentile, rescale_max_percentile, scan=None ) -> tuple: """ util to interpret nabu histogram and deduce data_min and data_max to be used for rescaling a volume """ histogram_res_url = find_histogram(input_volume, scan=scan) if histogram_res_url is not None: return _min_max_from_histo( url=histogram_res_url, rescale_min_percentile=rescale_min_percentile, rescale_max_percentile=rescale_max_percentile, ) else: return None, None def _min_max_from_histo(url: DataUrl, rescale_min_percentile: int, rescale_max_percentile: int) -> tuple: try: histogram = get_data(url) except Exception as e: _logger.error(f"Fail to load histogram from {url.path()}. Reason is {e}") return None, None else: bins = histogram[1] hist = histogram[0] return _get_hst_saturations( hist, bins, numpy.float32(rescale_min_percentile), numpy.float32(rescale_max_percentile) ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1712153075.0 nabu-2024.2.1/nabu/io/detector_distortion.py0000644000175000017500000002674014603260763020302 0ustar00pierrepierreimport numpy as np from scipy import sparse class DetectorDistortionBase: """ """ def __init__(self, detector_full_shape_vh=(0, 0)): """This is the basis class. A simple identity transformation which has the only merit to show how it works.Reimplement this function to have more parameters for other transformations """ self._build_full_transformation(detector_full_shape_vh) def transform(self, source_data, do_full=False): """performs the transformation""" if do_full: result = self.transformation_matrix_full @ source_data.flat result.shape = source_data.shape else: result = self.transformation_matrix @ source_data.flat result.shape = self.target_shape return result def _build_full_transformation(self, detector_full_shape_vh): """A simple identity. 
Reimplement this function to have more parameters for other transformations """ self.detector_full_shape_vh = detector_full_shape_vh sz, sx = detector_full_shape_vh # A simple identity matrix in sparse coordinates format self.total_detector_npixs = detector_full_shape_vh[0] * detector_full_shape_vh[1] I_tmp = np.arange(self.total_detector_npixs) J_tmp = np.arange(self.total_detector_npixs) V_tmp = np.ones([self.total_detector_npixs], "f") coo_tmp = sparse.coo_matrix((V_tmp, (I_tmp, J_tmp)), shape=(sz * sx, sz * sx)) csr_tmp = coo_tmp.tocsr() ## The following arrays are kept for future usage ## when, according to the "sub_region" parameter of the moment, ## they will be used to extract a "slice" of them ## which will map an appropriate data region corresponding to "sub_region_source" ## to the target "sub_region" of the moment self.full_csr_data = csr_tmp.data self.full_csr_indices = csr_tmp.indices self.full_csr_indptr = csr_tmp.indptr ## This will be used to save time if the same sub_region argument is requested several time in a row self._status = None def get_adapted_subregion(self, sub_region_xz): if sub_region_xz is not None: start_x, end_x, start_z, end_z = sub_region_xz else: start_x = 0 end_x = None start_z = 0 end_z = None (start_x, end_x, start_z, end_z) = self.set_sub_region_transformation((start_x, end_x, start_z, end_z)) return (start_x, end_x, start_z, end_z) def set_sub_region_transformation(self, target_sub_region=None): """must return a source sub_region. It sets internally an object (a practical implementation would be a sparse matrice) which can be reused in further applications of "transform" method for transforming the source sub_region data into the target sub_region """ if target_sub_region is None: target_sub_region = (None, None, 0, None) if self._status is not None and self._status["target_sub_region"] == target_sub_region: return self._status["source_sub_region"] else: self._status = None return self._set_sub_region_transformation(target_sub_region) def set_full_transformation(self): self._set_sub_region_transformation(do_full=True) def get_actual_shapes_source_target(self): if self._status is None: return None, None else: return self._status["source_sub_region"], self._status["target_sub_region"] def _set_sub_region_transformation(self, target_sub_region=None, do_full=False): """to be reimplemented in the derived classes""" if target_sub_region is None or do_full: target_sub_region = (None, None, 0, None) (x_start, x_end, z_start, z_end) = target_sub_region if z_start is None: z_start = 0 if z_end is None: z_end = self.detector_full_shape_vh[0] if (x_start, x_end) not in [(None, None), (0, None), (0, self.detector_full_shape_vh[1])]: message = f""" In the base class DetectorDistortionBase only vertical slicing is accepted. 
The sub_region contained (x_start, x_end)={(x_start, x_end)} which would slice the full horizontal size which is {self.detector_full_shape_vh[1]} """ raise ValueError() x_start, x_end = 0, self.detector_full_shape_vh[1] row_ptr_start = z_start * self.detector_full_shape_vh[1] row_ptr_end = z_end * self.detector_full_shape_vh[1] indices_start = self.full_csr_indptr[row_ptr_start] indices_end = self.full_csr_indptr[row_ptr_end] indices_offset = self.full_csr_indptr[row_ptr_start] source_offset = target_sub_region[2] * self.detector_full_shape_vh[1] data_tmp = self.full_csr_data[indices_start:indices_end] indices_tmp = self.full_csr_indices[indices_start:indices_end] - source_offset indptr_tmp = self.full_csr_indptr[row_ptr_start : row_ptr_end + 1] - indices_offset target_size = (z_end - z_start) * self.detector_full_shape_vh[1] source_size = (z_end - z_start) * self.detector_full_shape_vh[1] tmp_transformation_matrix = sparse.csr_matrix( (data_tmp, indices_tmp, indptr_tmp), shape=(target_size, source_size) ) if do_full: self.transformation_matrix_full = tmp_transformation_matrix return None else: self.transformation_matrix = tmp_transformation_matrix self.target_shape = ((z_end - z_start), self.detector_full_shape_vh[1]) ## For the identity matrix the source and the target have the same size. ## The two following lines are trivial. ## For this identity transformation only the slicing of the appropriate part ## of the identity sparse matrix is slightly laborious. ## Practical case will be more complicated and source_sub_region ## will be in general larger than the target_sub_region self._status = { "target_sub_region": ((x_start, x_end, z_start, z_end)), "source_sub_region": ((x_start, x_end, z_start, z_end)), } return self._status["source_sub_region"] class DetectorDistortionMapsXZ(DetectorDistortionBase): def __init__(self, map_x, map_z): """ This class implements the distortion correction from the knowledge of two arrays, map_x and map_z. Pixel (i,j) of the corrected image is obtained by interpolating the raw data at position ( map_z(i,j), map_x(i,j) ). 
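        As an illustrative sketch (not from the original documentation), identity maps leave the
        image unchanged; `raw_image` below is a placeholder 2D float array:

        >>> map_z, map_x = np.indices(raw_image.shape).astype("f")
        >>> corrector = DetectorDistortionMapsXZ(map_x=map_x, map_z=map_z)
        >>> corrector.set_full_transformation()
        >>> corrected = corrector.transform(raw_image, do_full=True)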
Parameters: map_x : float 2D array map_z : float 2D array """ self._build_full_transformation(map_x, map_z) def _build_full_transformation(self, map_x, map_z): """ """ detector_full_shape_vh = map_x.shape if detector_full_shape_vh != map_z.shape: message = f""" map_x and map_z must have the same shape but the dimensions were {map_x.shape} and {map_z.shape} """ raise ValueError(message) coordinates = np.array([map_z, map_x]) # padding sz, sx = detector_full_shape_vh total_detector_npixs = sz * sx xs = np.clip(np.array(coordinates[1].flat), [[0]], [[sx - 1]]) zs = np.clip(np.array(coordinates[0].flat), [[0]], [[sz - 1]]) ix0s = np.floor(xs) ix1s = np.ceil(xs) fx = xs - ix0s iz0s = np.floor(zs) iz1s = np.ceil(zs) fz = zs - iz0s I_tmp = np.empty([4 * sz * sx], np.int64) J_tmp = np.empty([4 * sz * sx], np.int64) V_tmp = np.ones([4 * sz * sx], "f") I_tmp[:] = np.arange(sz * sx * 4) // 4 J_tmp[0::4] = iz0s * sx + ix0s J_tmp[1::4] = iz0s * sx + ix1s J_tmp[2::4] = iz1s * sx + ix0s J_tmp[3::4] = iz1s * sx + ix1s V_tmp[0::4] = (1 - fz) * (1 - fx) V_tmp[1::4] = (1 - fz) * fx V_tmp[2::4] = fz * (1 - fx) V_tmp[3::4] = fz * fx self.detector_full_shape_vh = detector_full_shape_vh coo_tmp = sparse.coo_matrix((V_tmp.astype("f"), (I_tmp, J_tmp)), shape=(sz * sx, sz * sx)) csr_tmp = coo_tmp.tocsr() self.full_csr_data = csr_tmp.data self.full_csr_indices = csr_tmp.indices self.full_csr_indptr = csr_tmp.indptr ## This will be used to save time if the same sub_region argument is requested several time in a row self._status = None def _set_sub_region_transformation( self, target_sub_region=( ( None, None, 0, 0, ), ), do_full=False, ): if target_sub_region is None or do_full: target_sub_region = (None, None, 0, None) (x_start, x_end, z_start, z_end) = target_sub_region if z_start is None: z_start = 0 if z_end is None: z_end = self.detector_full_shape_vh[0] if (x_start, x_end) not in [(None, None), (0, None), (0, self.detector_full_shape_vh[1])]: message = f""" In the base class DetectorDistortionRotation only vertical slicing is accepted. The sub_region contained (x_start, x_end)={(x_start, x_end)} which would slice the full horizontal size which is {self.detector_full_shape_vh[1]} """ raise ValueError() x_start, x_end = 0, self.detector_full_shape_vh[1] row_ptr_start = z_start * self.detector_full_shape_vh[1] row_ptr_end = z_end * self.detector_full_shape_vh[1] indices_start = self.full_csr_indptr[row_ptr_start] indices_end = self.full_csr_indptr[row_ptr_end] data_tmp = self.full_csr_data[indices_start:indices_end] target_offset = self.full_csr_indptr[row_ptr_start] indptr_tmp = self.full_csr_indptr[row_ptr_start : row_ptr_end + 1] - target_offset indices_tmp = self.full_csr_indices[indices_start:indices_end] iz_source = (indices_tmp) // self.detector_full_shape_vh[1] z_start_source = iz_source.min() z_end_source = iz_source.max() + 1 source_offset = z_start_source * self.detector_full_shape_vh[1] indices_tmp = indices_tmp - source_offset target_size = (z_end - z_start) * self.detector_full_shape_vh[1] source_size = (z_end_source - z_start_source) * self.detector_full_shape_vh[1] tmp_transformation_matrix = sparse.csr_matrix( (data_tmp, indices_tmp, indptr_tmp), shape=(target_size, source_size) ) if do_full: self.transformation_matrix_full = tmp_transformation_matrix return None else: self.transformation_matrix = tmp_transformation_matrix self.target_shape = ((z_end - z_start), self.detector_full_shape_vh[1]) ## For the identity matrix the source and the target have the same size. 
## The two following lines are trivial. ## For this identity transformation only the slicing of the appropriate part ## of the identity sparse matrix is slightly laborious. ## Practical case will be more complicated and source_sub_region ## will be in general larger than the target_sub_region self._status = { "target_sub_region": ((x_start, x_end, z_start, z_end)), "source_sub_region": ((x_start, x_end, z_start_source, z_end_source)), } return self._status["source_sub_region"] ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1730906677.0 nabu-2024.2.1/nabu/io/reader.py0000644000175000017500000011710514712705065015451 0ustar00pierrepierreimport os from threading import get_ident from math import ceil from multiprocessing.pool import ThreadPool from posixpath import sep as posix_sep, join as posix_join import numpy as np from silx.io import get_data from silx.io.dictdump import h5todict from tomoscan.io import HDF5File from .utils import get_compacted_dataslices, convert_dict_values, get_first_hdf5_entry from ..misc.binning import binning as image_binning from ..utils import ( check_supported, deprecated, deprecated_class, deprecation_warning, indices_to_slices, compacted_views, merge_slices, subsample_dict, get_3D_subregion, get_num_threads, get_shape_from_sliced_dims, get_size_from_sliced_dimension, safe_format, ) try: from fabio.edfimage import EdfImage except ImportError: EdfImage = None class Reader: """ Abstract class for various file readers. """ def __init__(self, sub_region=None): """ Parameters ---------- sub_region: tuple, optional Coordinates in the form (start_x, end_x, start_y, end_y), to read a subset of each frame. It can be used for Regions of Interest (ROI). Indices start at zero ! """ self._set_default_parameters(sub_region) def _set_default_parameters(self, sub_region): self._set_subregion(sub_region) def _set_subregion(self, sub_region): self.sub_region = sub_region if sub_region is not None: start_x, end_x, start_y, end_y = sub_region self.start_x = start_x self.end_x = end_x self.start_y = start_y self.end_y = end_y else: self.start_x = 0 self.end_x = None self.start_y = 0 self.end_y = None def get_data(self, data_url): """ Get data from a silx.io.url.DataUrl """ raise ValueError("Base class") def release(self): """ Release the file if needed. """ pass class NPReader(Reader): multi_load = True def __init__(self, sub_region=None, mmap=True): """ Reader for NPY/NPZ files. Mostly used for internal development. 
Please refer to the documentation of nabu.io.reader.Reader """ super().__init__(sub_region=sub_region) self._file_desc = {} self._set_mmap(mmap) def _set_mmap(self, mmap): self.mmap_mode = "r" if mmap else None def _open(self, data_url): file_path = data_url.file_path() file_ext = self._get_file_type(file_path) if file_ext == "npz": if file_path not in self._file_desc: self._file_desc[file_path] = np.load(file_path, mmap_mode=self.mmap_mode) data_ref = self._file_desc[file_path][data_url.data_path()] else: data_ref = np.load(file_path, mmap_mode=self.mmap_mode) return data_ref @staticmethod def _get_file_type(fname): if fname.endswith(".npy"): return "npy" elif fname.endswith(".npz"): return "npz" else: raise ValueError("Not a numpy file: %s" % fname) def get_data(self, data_url): data_ref = self._open(data_url) data_slice = data_url.data_slice() if data_slice is None: res = data_ref[self.start_y : self.end_y, self.start_x : self.end_x] else: res = data_ref[data_slice, self.start_y : self.end_y, self.start_x : self.end_x] return res def release(self): for fname, fdesc in self._file_desc.items(): if fdesc is not None: fdesc.close() self._file_desc[fname] = None def __del__(self): self.release() class EDFReader(Reader): multi_load = False # not implemented def __init__(self, sub_region=None): """ A class for reading series of EDF Files. Multi-frames EDF are not supported. """ if EdfImage is None: raise ImportError("Need fabio to use this reader") super().__init__(sub_region=sub_region) self._reader = EdfImage() self._first_fname = None def read(self, fname): if self._first_fname is None: self._first_fname = fname self._reader.read(fname) if self.sub_region is None: data = self._reader.data else: data = self._reader.fast_read_roi(fname, (slice(self.start_y, self.end_y), slice(self.start_x, self.end_x))) # self._reader.close() return data def get_data(self, data_url): return self.read(data_url.file_path()) class HDF5Reader(Reader): multi_load = True def __init__(self, sub_region=None): """ A class for reading a HDF5 File. 
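        A hedged usage sketch (the file name and HDF5 data path below are placeholders):

        >>> from silx.io.url import DataUrl
        >>> reader = HDF5Reader(sub_region=(None, None, 100, 200))  # keep rows 100-199 of each frame
        >>> url = DataUrl(file_path="projections.h5", data_path="/entry/data/data", data_slice=0)
        >>> frame = reader.get_data(url)
        >>> reader.release()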
""" super().__init__(sub_region=sub_region) self._file_desc = {} def _open(self, file_path): if file_path not in self._file_desc: self._file_desc[file_path] = HDF5File(file_path, "r", swmr=True) def get_data(self, data_url): file_path = data_url.file_path() self._open(file_path) h5dataset = self._file_desc[file_path][data_url.data_path()] data_slice = data_url.data_slice() if data_slice is None: res = h5dataset[self.start_y : self.end_y, self.start_x : self.end_x] else: res = h5dataset[data_slice, self.start_y : self.end_y, self.start_x : self.end_x] return res def release(self): for fname, fdesc in self._file_desc.items(): if fdesc is not None: try: fdesc.close() self._file_desc[fname] = None except Exception as exc: print("Error while closing %s: %s" % (fname, str(exc))) def __del__(self): self.release() class HDF5Loader: """ An alternative class to HDF5Reader where information is first passed at class instantiation """ def __init__(self, fname, data_path, sub_region=None, data_buffer=None, pre_allocate=True, dtype="f"): self.fname = fname self.data_path = data_path self._set_subregion(sub_region) if not ((data_buffer is not None) ^ (pre_allocate is True)): raise ValueError("Please provide either 'data_buffer' or 'pre_allocate'") self.data = data_buffer self._loaded = False self.expected_shape = get_hdf5_dataset_shape(fname, data_path, sub_region=sub_region) if pre_allocate: self.data = np.zeros(self.expected_shape, dtype=dtype) def _set_subregion(self, sub_region): self.sub_region = sub_region if sub_region is not None: start_z, end_z, start_y, end_y, start_x, end_x = sub_region self.start_x, self.end_x = start_x, end_x self.start_y, self.end_y = start_y, end_y self.start_z, self.end_z = start_z, end_z else: self.start_x, self.end_x = None, None self.start_y, self.end_y = None, None self.start_z, self.end_z = None, None def load_data(self, force_load=False, output=None): if self._loaded and not force_load: return self.data output = self.data if output is None else output with HDF5File(self.fname, "r") as fdesc: if output is None: output = fdesc[self.data_path][ self.start_z : self.end_z, self.start_y : self.end_y, self.start_x : self.end_x ] else: output[:] = fdesc[self.data_path][ self.start_z : self.end_z, self.start_y : self.end_y, self.start_x : self.end_x ] self._loaded = True return output @deprecated_class("ChunkReader is deprecated since 2024.2.0 and will be removed in a future version", do_print=True) class ChunkReader: """ A reader of chunk of images. """ def __init__( self, files, sub_region=None, detector_corrector=None, pre_allocate=True, data_buffer=None, convert_float=False, shape=None, dtype=None, binning=None, dataset_subsampling=None, num_threads=None, ): """ Initialize a "ChunkReader". A chunk is a stack of images. Parameters ---------- files: dict Dictionary where the key is the file/data index, and the value is a silx.io.url.DataUrl pointing to the data. The dict must contain only the files which shall be used ! Note: the shape and data type is infered from the first data file. sub_region: tuple, optional If provided, this must be a tuple in the form (start_x, end_x, start_y, end_y). Each image will be cropped to this region. This is used to specify a chunk of files. Each of the parameters can be None, in this case the default start and end are taken in each dimension. pre_allocate: bool Whether to pre-allocate data before reading. data_buffer: array-like, optional If `pre_allocate` is set to False, this parameter has to be provided. 
It is an array-like object which will hold the data. convert_float: bool Whether to convert data to float32, regardless of the input data type. shape: tuple, optional Shape of each image. If not provided, it is inferred from the first image in the collection. dtype: `numpy.dtype`, optional Data type of each image. If not provided, it is inferred from the first image in the collection. binning: int or tuple of int, optional Whether to bin the data. If multi-dimensional binning is done, the parameter must be in the form (binning_x, binning_y). Each image will be binned by these factors. dataset_subsampling: int or tuple, optional Subsampling factor when reading the images. If an integer `n` is provided, then one image out of `n` will be read. If a tuple of integers (step, begin) is given, the data is read as data[begin::step] num_threads: int, optional Number of threads to use for binning the data. Default is to use all available threads. This parameter has no effect when binning is disabled. Notes ------ The files are provided as a collection of `silx.io.DataURL`. The file type is inferred from the extension. Binning is different from subsampling. Using binning will not speed up the data retrieval (quite the opposite), since the whole (subregion of) data is read and then binning is performed. """ self.detector_corrector = detector_corrector self._get_reader_class(files) self.dataset_subsampling = dataset_subsampling self.num_threads = get_num_threads(num_threads) self._set_files(files) self._get_shape_and_dtype(shape, dtype, binning) self._set_subregion(sub_region) self._init_reader() self._loaded = False self.convert_float = convert_float if convert_float: self.out_dtype = np.float32 else: self.out_dtype = self.dtype if not ((data_buffer is not None) ^ (pre_allocate is True)): raise ValueError("Please provide either 'data_buffer' or 'pre_allocate'") self.files_data = data_buffer if data_buffer is not None: # overwrite out_dtype self.out_dtype = data_buffer.dtype if data_buffer.shape != self.chunk_shape: raise ValueError("Expected shape %s but got %s" % (self.shape, data_buffer.shape)) if pre_allocate: self.files_data = np.zeros(self.chunk_shape, dtype=self.out_dtype) if (self.binning is not None) and (np.dtype(self.out_dtype).kind in ["u", "i"]): raise ValueError( "Output datatype cannot be integer when using binning. Please set the 'convert_float' parameter to True or specify a 'data_buffer'." ) def _set_files(self, files): if len(files) == 0: raise ValueError("Expected at least one data file") self._files_begin_idx = 0 if isinstance(self.dataset_subsampling, (tuple, list)): self._files_begin_idx = self.dataset_subsampling[1] self.dataset_subsampling = self.dataset_subsampling[0] self.n_files = len(files) self.files = files self._sorted_files_indices = sorted(files.keys()) self._fileindex_to_idx = dict.fromkeys(self._sorted_files_indices) self._configure_subsampling() def _infer_file_type(self, files): fname = files[sorted(files.keys())[0]].file_path() ext = os.path.splitext(fname)[-1].replace(".", "") if ext not in Readers: raise ValueError("Unknown file format %s. 
Supported formats are: %s" % (ext, str(Readers.keys()))) return ext def _get_reader_class(self, files): ext = self._infer_file_type(files) reader_class = Readers[ext] self._reader_class = reader_class def _get_shape_and_dtype(self, shape, dtype, binning): if shape is None or dtype is None: shape, dtype = self._infer_shape_and_dtype() assert len(shape) == 2, "Expected the shape of an image (2-tuple)" self.shape_total = shape self.dtype = dtype self._set_binning(binning) def _configure_subsampling(self): dataset_subsampling = self.dataset_subsampling self.files_subsampled = self.files if dataset_subsampling is not None and dataset_subsampling > 1: self.files_subsampled = subsample_dict(self.files, dataset_subsampling) self.n_files = len(self.files_subsampled) if not (self._reader_class.multi_load): # 3D loading not supported for this reader. # Data is loaded frames by frame, so subsample directly self.files self.files = self.files_subsampled self._sorted_files_indices = sorted(self.files.keys()) self._fileindex_to_idx = dict.fromkeys(self._sorted_files_indices) def _infer_shape_and_dtype(self): self._reader_entire_image = self._reader_class() first_file_dataurl = self.files[self._sorted_files_indices[0]] first_file_data = self._reader_entire_image.get_data(first_file_dataurl) return first_file_data.shape, first_file_data.dtype def _set_subregion(self, sub_region): sub_region = sub_region or (None, None, None, None) start_x, end_x, start_y, end_y = sub_region if start_x is None: start_x = 0 if start_y is None: start_y = 0 if end_x is None: end_x = self.shape_total[1] if end_y is None: end_y = self.shape_total[0] self.sub_region = (start_x, end_x, start_y, end_y) self.shape = (end_y - start_y, end_x - start_x) if self.binning is not None: self.shape = (self.shape[0] // self.binning[1], self.shape[1] // self.binning[0]) self.chunk_shape = (self.n_files,) + self.shape if self.detector_corrector is not None: self.detector_corrector.set_sub_region_transformation(target_sub_region=self.sub_region) def _init_reader(self): # instantiate reader with user params if self.detector_corrector is not None: adapted_subregion = self.detector_corrector.get_adapted_subregion(self.sub_region) else: adapted_subregion = self.sub_region self.file_reader = self._reader_class(sub_region=adapted_subregion) def _set_binning(self, binning): if binning is None: self.binning = None return if np.isscalar(binning): binning = (binning, binning) else: assert len(binning) == 2, "Expected binning in the form (binning_x, binning_y)" if binning[0] == 1 and binning[1] == 1: self.binning = None return for b in binning: if int(b) != b: raise ValueError("Expected an integer number for binning values, but got %s" % binning) self.binning = binning def get_data(self, file_url): """ Get the data associated to a file url. 
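Examples
--------
A hypothetical sketch, assuming `reader` is an already-configured ChunkReader and `files`
is the dictionary of DataUrl passed at instantiation:

>>> url = files[sorted(files.keys())[0]]
>>> frame = reader.get_data(url)  # cropped to sub_region, then distortion-corrected and/or binned if configured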
""" arr = self.file_reader.get_data(file_url) if arr.ndim == 2: if self.detector_corrector is not None: arr = self.detector_corrector.transform(arr) if self.binning is not None: arr = image_binning(arr, self.binning[::-1]) else: if self.detector_corrector is not None: if self.detector_corrector is not None: _, ( src_x_start, src_x_end, src_z_start, src_z_end, ) = self.detector_corrector.get_actual_shapes_source_target() arr_target = np.empty([len(arr), src_z_end - src_z_start, src_x_end - src_x_start], "f") def apply_corrector(i_img_tuple): i, img = i_img_tuple arr_target[i] = self.detector_corrector.transform(img) with ThreadPool(self.num_threads) as tp: tp.map(apply_corrector, enumerate(arr)) arr = arr_target if self.binning is not None: nz = arr.shape[0] res = np.zeros((nz,) + image_binning(arr[0], self.binning[::-1]).shape, dtype="f") def apply_binning(img_res_tuple): img, res = img_res_tuple res[:] = image_binning(img, self.binning[::-1]) with ThreadPool(self.num_threads) as tp: tp.map(apply_binning, zip(arr, res)) arr = res return arr def _load_single(self): for i, fileidx in enumerate(self._sorted_files_indices): file_url = self.files[fileidx] self.files_data[i] = self.get_data(file_url) self._fileindex_to_idx[fileidx] = i def _load_multi(self): urls_compacted = get_compacted_dataslices( self.files, subsampling=self.dataset_subsampling, begin=self._files_begin_idx ) loaded = {} start_idx = 0 sorted_files_indices = sorted(urls_compacted.keys()) for idx in sorted_files_indices: url = urls_compacted[idx] url_str = str(url) is_loaded = loaded.get(url_str, False) if is_loaded: continue ds = url.data_slice() delta_z = ds.stop - ds.start if ds.step is not None and ds.step > 1: delta_z = ceil(delta_z / ds.step) end_idx = start_idx + delta_z self.files_data[start_idx:end_idx] = self.get_data(url) start_idx += delta_z loaded[url_str] = True def load_files(self, overwrite: bool = False): """ Load the files whose links was provided at class instantiation. Parameters ----------- overwrite: bool, optional Whether to force reloading the files if already loaded. """ if self._loaded and not (overwrite): raise ValueError("Radios were already loaded. 
Call load_files(overwrite=True) to force reloading") if self.file_reader.multi_load: self._load_multi() else: self._load_single() self._loaded = True load_data = load_files @property def data(self): return self.files_data class VolReaderBase: """ Base class with common code for data readers (EDFStackReader, NXTomoReader, etc) """ def __init__(self, *args, **kwargs): raise ValueError("Base class") def _set_subregion(self, sub_region): slice_angle, slice_z, slice_x = (None, None, None) if isinstance(sub_region, slice): # Assume selection is done only along dimension 0 slice_angle = sub_region slice_z = None slice_x = None if isinstance(sub_region, (tuple, list)): slice_angle, slice_z, slice_x = sub_region self.sub_region = (slice_angle or slice(None, None), slice_z or slice(None, None), slice_x or slice(None, None)) def _set_processing_function(self, processing_func, processing_func_args, processing_func_kwargs): self.processing_func = processing_func self._processing_func_args = processing_func_args or [] self._processing_func_kwargs = processing_func_kwargs or {} def _get_output(self, array): if array is not None: if array.shape != self.output_shape: raise ValueError("Expected output shape %s but got %s" % (self.output_shape, array.shape)) if array.dtype != self.output_dtype: raise ValueError("Expected output dtype '%s' but got '%s'" % (self.output_dtype, array.dtype)) output = array else: output = np.zeros(self.output_shape, dtype=self.output_dtype) return output def get_frames_indices(self): return np.arange(self.data_shape_total[0])[self._source_selection[0]] class NXTomoReader(VolReaderBase): image_key_path = "instrument/detector/image_key_control" multiple_frames_per_file = True def __init__( self, filename, data_path="{entry}/instrument/detector/data", sub_region=None, image_key=0, output_dtype=np.float32, processing_func=None, processing_func_args=None, processing_func_kwargs=None, ): """ Read a HDF5 file in NXTomo layout. Parameters ---------- filename: str Path to the file to read. data_path: str Path within the HDF5 file, eg. "entry/instrument/detector/data". Default is {entry}/data/data where {entry} is a magic keyword for the first entry. sub_region: slice or tuple, optional Region to select within the data, once the "image key" selection has been done. If None, all the data (corresponding to image_key) is selected. If slice(start, stop) is provided, the selection is done along dimension 0. Otherwise, it must be a 3-tuple of slices in the form (slice(start_angle, end_angle, step_angle), slice(start_z, end_z, step_z), slice(start_x, end_x, step_x)) Each of the parameters can be None, in this case the default start and end are taken in each dimension. output_dtype: numpy.dtype, optional Output data type if the data memory is allocated by this class. Default is float32. image_key: int, or None, optional Image type to read (see NXTomo documentation). 0 for projections, 1 for flat-field, 2 for dark field. If set to None, all the data will be read. processing_func: callable, optional Function to be called on each loaded stack of images. If provided, this function first argument must be the source buffer (3D array: stack of raw images), and the second argument must be the destination buffer (3D array, stack of output images). It can be None. Other parameters ---------------- The other parameters are passed to "processing_func" if this parameter is not None. 
""" self.filename = filename self.data_path = safe_format(data_path or "", entry=get_first_hdf5_entry(filename)) self._set_image_key(image_key) self._set_subregion(sub_region) self._get_shape() self._set_processing_function(processing_func, processing_func_args, processing_func_kwargs) self._get_source_selection() self._init_output(output_dtype) def _get_input_dtype(self): return get_hdf5_dataset_dtype(self.filename, self.data_path) def _init_output(self, output_dtype): output_shape = get_shape_from_sliced_dims(self.data_shape_total, self._source_selection) self.output_dtype = output_dtype self._output_shape_no_processing = output_shape if self.processing_func is not None: test_subvolume = np.zeros((1,) + output_shape[1:], dtype=self._get_input_dtype()) out = self.processing_func( test_subvolume, None, *self._processing_func_args, **self._processing_func_kwargs ) output_image_shape = out.shape[1:] output_shape = (output_shape[0],) + output_image_shape self.output_shape = output_shape self._tmp_dst_buffer = None def _set_image_key(self, image_key): self.image_key = image_key entry = get_entry_from_h5_path(self.data_path) image_key_path = posix_join(entry, self.image_key_path) with HDF5File(self.filename, "r") as f: image_key_val = f[image_key_path][()] idx = np.where(image_key_val == image_key)[0] if len(idx) == 0: raise FileNotFoundError("No frames found with image key = %d" % image_key) self._image_key_slices = indices_to_slices(idx) def _get_shape(self): # Shape of the total HDF5-NXTomo data (including darks and flats) self.data_shape_total = get_hdf5_dataset_shape(self.filename, self.data_path) # Shape of the data once the "image key" is selected n_imgs = self.data_shape_total[0] self.data_shape_imagekey = ( sum([get_size_from_sliced_dimension(n_imgs, slice_) for slice_ in self._image_key_slices]), ) + self.data_shape_total[1:] # Shape of the data after sub-regions are selected self.data_shape_subregion = get_shape_from_sliced_dims(self.data_shape_imagekey, self.sub_region) def _get_source_selection(self): if len(self._image_key_slices) == 1 and self.processing_func is None: # Simple case: # - One single chunk to load, i.e len(self._image_key_slices) == 1 # - No on-the-fly processing (binning, distortion correction, ...) # In this case, we can use h5py read_direct() to avoid extraneous memory consumption image_key_slice = self._image_key_slices[0] # merge image key selection and user selection (if any) self._source_selection = ( merge_slices(image_key_slice, self.sub_region[0] or slice(None, None)), ) + self.sub_region[1:] else: user_selection_dim0 = self.sub_region[0] indices = np.arange(self.data_shape_total[0]) data_selection_indices_axis0 = np.hstack( [indices[image_key_slice][user_selection_dim0] for image_key_slice in self._image_key_slices] ) self._source_selection = (data_selection_indices_axis0,) + self.sub_region[1:] def _get_temporary_buffer(self, convert_after_reading): if self._tmp_dst_buffer is None: shape = self._output_shape_no_processing dtype = self.output_dtype if not (convert_after_reading) else self._get_input_dtype() self._tmp_dst_buffer = np.zeros(shape, dtype=dtype) return self._tmp_dst_buffer def load_data(self, output=None, convert_after_reading=True): """ Read data. Parameters ----------- output: array-like, optional Destination 3D array that will hold the data. If provided, use this memory buffer instead of allocating the memory. Its shape must be compatible with the selection of 'sub_region' and 'image_key'. 
convert_after_reading: bool, optional Whether to do the dtype conversion (if any, e.g. uint16 to float32) after data reading. When using h5py's read_direct(), reading uint16 data directly into a float32 buffer is extremely slow, so the data type conversion should be done after reading. The drawback is that it requires more memory. """ output = self._get_output(output) convert_after_reading &= np.dtype(self.output_dtype) != np.dtype(self._get_input_dtype()) if convert_after_reading or self.processing_func is not None: dst_buffer = self._get_temporary_buffer(convert_after_reading) else: dst_buffer = output with HDF5File(self.filename, "r") as f: dptr = f[self.data_path] dptr.read_direct( dst_buffer, source_sel=self._source_selection, dest_sel=None, ) if self.processing_func is not None: self.processing_func(dst_buffer, output, *self._processing_func_args, **self._processing_func_kwargs) elif dst_buffer.ctypes.data != output.ctypes.data: output[:] = dst_buffer[:] # cast return output class NXDarksFlats: _reduce_func = { "median": np.median, "mean": np.mean, } def __init__(self, filename, **nxtomoreader_kwargs): nxtomoreader_kwargs.pop("image_key", None) self.darks_reader = NXTomoReader(filename, image_key=2, **nxtomoreader_kwargs) self.flats_reader = NXTomoReader(filename, image_key=1, **nxtomoreader_kwargs) self._raw_darks = None self._raw_flats = None def _get_raw_frames(self, what, force_reload=False, as_multiple_array=True): check_supported(what, ["darks", "flats"], "frame type") loaded_frames = getattr(self, "_raw_%s" % what) reader = getattr(self, "%s_reader" % what) if force_reload or loaded_frames is None: loaded_frames = reader.load_data() setattr(self, "_raw_%s" % what, loaded_frames) res = loaded_frames if as_multiple_array: slices_ = compacted_views(reader._image_key_slices) return [res[slice_] for slice_ in slices_] return res def _get_reduced_frames(self, what, method="mean", force_reload=False, as_dict=False): raw_frames = self._get_raw_frames(what, force_reload=force_reload, as_multiple_array=True) reduced_frames = [self._reduce_func[method](frames, axis=0) for frames in raw_frames] reader = getattr(self, "%s_reader" % what) if as_dict: return {k: v for k, v in zip([s.start for s in reader._image_key_slices], reduced_frames)} return reduced_frames def get_raw_darks(self, force_reload=False, as_multiple_array=True): return self._get_raw_frames("darks", force_reload=force_reload, as_multiple_array=as_multiple_array) def get_raw_flats(self, force_reload=False, as_multiple_array=True): return self._get_raw_frames("flats", force_reload=force_reload, as_multiple_array=as_multiple_array) def get_reduced_darks(self, method="mean", force_reload=False, as_dict=False): return self._get_reduced_frames("darks", method=method, force_reload=force_reload, as_dict=as_dict) def get_reduced_flats(self, method="median", force_reload=False, as_dict=False): return self._get_reduced_frames("flats", method=method, force_reload=force_reload, as_dict=as_dict) def get_raw_current(self, h5_path="{entry}/control/data"): h5_path = safe_format(h5_path, entry=self.flats_reader.data_path.split(posix_sep)[0]) with HDF5File(self.flats_reader.filename, "r") as f: current = f[h5_path][()] return {sl.start: current[sl] for sl in self.flats_reader._image_key_slices} def get_reduced_current(self, h5_path="{entry}/control/data", method="median"): current = self.get_raw_current(h5_path=h5_path) return {k: self._reduce_func[method](current[k]) for k in current.keys()} class EDFStackReader(VolReaderBase): multiple_frames_per_file = False def
__init__( self, filenames, sub_region=None, output_dtype=np.float32, n_reading_threads=1, processing_func=None, processing_func_args=None, processing_func_kwargs=None, ): self.filenames = filenames self.n_reading_threads = n_reading_threads self._set_subregion(sub_region) self._get_shape() self._set_processing_function(processing_func, processing_func_args, processing_func_kwargs) self._get_source_selection() self._init_output(output_dtype) def _get_input_dtype(self): return EDFReader().read(self.filenames[0]).dtype def _get_shape(self): first_filename = self.filenames[0] # Shape of the total data (without subregion selection) reader_all = EDFReader() first_frame_full = reader_all.read(first_filename) self.data_shape_total = (len(self.filenames),) + first_frame_full.shape self.input_dtype = first_frame_full.dtype self.data_shape_subregion = get_shape_from_sliced_dims( self.data_shape_total, self.sub_region ) # might fail if sub_region is not a slice ? def _init_output(self, output_dtype): output_shape = get_shape_from_sliced_dims(self.data_shape_total, self._source_selection) self.output_dtype = output_dtype if self.processing_func is not None: test_image = np.zeros(output_shape[1:], dtype=self._get_input_dtype()) out = self.processing_func(test_image, *self._processing_func_args, **self._processing_func_kwargs) output_image_shape = out.shape output_shape = (output_shape[0],) + output_image_shape self.output_shape = output_shape def _get_source_selection(self): self._source_selection = self.sub_region self._sub_region_xy_for_edf_reader = ( self.sub_region[2].start or 0, self.sub_region[2].stop or self.data_shape_total[2], self.sub_region[1].start or 0, self.sub_region[1].stop or self.data_shape_total[1], ) self.filenames_subsampled = self.filenames[self.sub_region[0]] def load_data(self, output=None): output = self._get_output(output) readers = {} def _init_reader_thread(): readers[get_ident()] = EDFReader(self._sub_region_xy_for_edf_reader) def _get_data(i_fname): i, fname = i_fname reader = readers[get_ident()] frame = reader.read(fname) if self.processing_func is not None: frame = self.processing_func(frame, *self._processing_func_args, **self._processing_func_kwargs) output[i] = frame with ThreadPool(self.n_reading_threads, initializer=_init_reader_thread) as tp: tp.map( _get_data, zip( range(len(self.filenames_subsampled)), self.filenames_subsampled, ), ) return output Readers = { "edf": EDFReader, "hdf5": HDF5Reader, "h5": HDF5Reader, "nx": HDF5Reader, "npz": NPReader, "npy": NPReader, } @deprecated( "Function load_images_from_dataurl_dict is deprecated and will be removed in a fugure version", do_print=True ) def load_images_from_dataurl_dict(data_url_dict, sub_region=None, dtype="f", binning=None): """ Load a dictionary of dataurl into numpy arrays. Parameters ---------- data_url_dict: dict A dictionary where the keys are integers (the index of each image in the dataset), and the values are numpy.ndarray (data_url_dict). sub_region: tuple, optional Tuple in the form (y_subregion, x_subregion) where xy_subregion is a tuple in the form slice(start, stop, step) Returns -------- res: dict A dictionary where the keys are the same as `data_url_dict`, and the values are numpy arrays. Notes ----- This function is used to load flats/darks. Usually, these are reduced flats/darks, meaning that 'data_url_dict' is a collection of a handful of files (less than 10). To load more frames, it would be best to use NXTomoReader / EDFStackReader. 
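Examples
--------
A hypothetical sketch, assuming `darks_urls` is a dictionary {index: DataUrl} pointing to
reduced dark frames:

>>> darks = load_images_from_dataurl_dict(darks_urls, sub_region=(slice(0, 100), slice(None)))
>>> # 'darks' is a dict {index: numpy.ndarray}, each frame cropped to its first 100 rows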
""" res = {} if sub_region is not None and not isinstance(sub_region[0], slice): if len(sub_region) == 4: # (start_y, end_y, start_x, end_x) deprecation_warning( "The parameter 'sub_region' was passed as (start_x, end_x, start_y, end_y). This is deprecated and will yield an error in the future. Please use the syntax ((start_z, end_z), (start_x, end_x))", do_print=True, func_name="load_images_from_dataurl_dict", ) sub_region = (slice(sub_region[2], sub_region[3]), slice(sub_region[0], sub_region[1])) else: # ((start_z, end_z), (start_x, end_x)) sub_region = tuple(slice(s[0], s[1]) for s in sub_region) for idx, data_url in data_url_dict.items(): frame = get_data(data_url) if sub_region is not None: frame = frame[sub_region[0], sub_region[1]] if binning is not None: frame = image_binning(frame, binning, out_dtype=dtype) res[idx] = frame return res def load_images_stack_from_hdf5(fname, h5_data_path, sub_region=None): """ Load a 3D dataset from a HDF5 file. Parameters ----------- fname: str File path h5_data_path: str Data path within the HDF5 file sub_region: tuple, optional Tuple indicating which sub-volume to load, in the form (xmin, xmax, ymin, ymax, zmin, zmax) where the 3D dataset has the python shape (N_Z, N_Y, N_X). This means that the data will be loaded as `data[zmin:zmax, ymin:ymax, xmin:xmax]`. """ xmin, xmax, ymin, ymax, zmin, zmax = get_3D_subregion(sub_region) with HDF5File(fname, "r") as f: d_ptr = f[h5_data_path] data = d_ptr[zmin:zmax, ymin:ymax, xmin:xmax] return data def get_hdf5_dataset_shape(fname, h5_data_path, sub_region=None): zmin, zmax, ymin, ymax, xmin, xmax = get_3D_subregion(sub_region) with HDF5File(fname, "r") as f: d_ptr = f[h5_data_path] shape = d_ptr.shape n_z, n_y, n_x = shape # perhaps there is more elegant res_shape = [] for n, bounds in zip([n_z, n_y, n_x], ((zmin, zmax), (ymin, ymax), (xmin, xmax))): res_shape.append(np.arange(n)[bounds[0] : bounds[1]].size) return tuple(res_shape) def get_hdf5_dataset_dtype(fname, h5_data_path): with HDF5File(fname, "r") as f: d_ptr = f[h5_data_path] dtype = d_ptr.dtype return dtype def get_entry_from_h5_path(h5_path): v = h5_path.split(posix_sep) return v[0] or v[1] def check_virtual_sources_exist(fname, data_path): with HDF5File(fname, "r") as f: if data_path not in f: print("No dataset %s in file %s" % (data_path, fname)) return False dptr = f[data_path] if not dptr.is_virtual: return True for vsource in dptr.virtual_sources(): vsource_fname = os.path.join(os.path.dirname(dptr.file.filename), vsource.file_name) if not os.path.isfile(vsource_fname): print("No such file: %s" % vsource_fname) return False elif not check_virtual_sources_exist(vsource_fname, vsource.dset_name): print("Error with virtual source %s" % vsource_fname) return False return True def import_h5_to_dict(h5file, h5path, asarray=False): """ Wrapper on top of silx.io.dictdump.dicttoh5 replacing "None" with None Parameters ----------- h5file: str File name h5path: str Path in the HDF5 file asarray: bool, optional Whether to convert each numeric value to an 0D array. Default is False. 
""" dic = h5todict(h5file, path=h5path, asarray=asarray) modified_dic = convert_dict_values(dic, {"None": None}, bytes_tostring=True) return modified_dic ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/io/reader_helical.py0000644000175000017500000001057514402565210017125 0ustar00pierrepierrefrom .reader import * class ChunkReaderHelical(ChunkReader): """implements reading by projections subsets""" def _set_subregion(self, sub_region): super()._set_subregion(sub_region) ########### # undo the chun_size setting of the base class # to avoid allocation of Tera bytes in the helical case self.chunk_shape = (1,) + self.shape def set_data_buffer(self, data_buffer, pre_allocate=False): if data_buffer is not None: # overwrite out_dtype self.files_data = data_buffer self.out_dtype = data_buffer.dtype if data_buffer.shape[1:] != self.shape: raise ValueError("Expected shape %s but got %s" % (self.shape, data_buffer.shape)) if pre_allocate: self.files_data = np.zeros(self.chunk_shape, dtype=self.out_dtype) if (self.binning is not None) and (np.dtype(self.out_dtype).kind in ["u", "i"]): raise ValueError( "Output datatype cannot be integer when using binning. Please set the 'convert_float' parameter to True or specify a 'data_buffer'." ) def get_binning(self): if self.binning is None: return 1, 1 else: return self.binning def _load_single(self, sub_total_prange_slice=slice(None, None)): if sub_total_prange_slice == slice(None, None): sorted_files_indices = self._sorted_files_indices else: sorted_files_indices = self._sorted_files_indices[sub_total_prange_slice] for i, fileidx in enumerate(sorted_files_indices): file_url = self.files[fileidx] self.files_data[i] = self.get_data(file_url) self._fileindex_to_idx[fileidx] = i def _apply_subsample_fact(self, t): if t is not None: t = t * self.dataset_subsampling return t def _load_multi(self, sub_total_prange_slice=slice(None, None)): if sub_total_prange_slice == slice(None, None): files_to_be_compacted_dict = self.files sorted_files_indices = self._sorted_files_indices else: if self.dataset_subsampling > 1: start, stop, step = list( map( self._apply_subsample_fact, [sub_total_prange_slice.start, sub_total_prange_slice.stop, sub_total_prange_slice.step], ) ) sub_total_prange_slice = slice(start, stop, step) sorted_files_indices = self._sorted_files_indices[sub_total_prange_slice] files_to_be_compacted_dict = dict( zip(sorted_files_indices, [self.files[idx] for idx in sorted_files_indices]) ) urls_compacted = get_compacted_dataslices(files_to_be_compacted_dict, subsampling=self.dataset_subsampling) loaded = {} start_idx = 0 for idx in sorted_files_indices: url = urls_compacted[idx] url_str = str(url) is_loaded = loaded.get(url_str, False) if is_loaded: continue ds = url.data_slice() delta_z = ds.stop - ds.start if ds.step is not None and ds.step > 1: delta_z //= ds.step end_idx = start_idx + delta_z self.files_data[start_idx:end_idx] = self.get_data(url) start_idx += delta_z loaded[url_str] = True def load_files(self, overwrite: bool = False, sub_total_prange_slice=slice(None, None)): """ Load the files whose links was provided at class instantiation. Parameters ----------- overwrite: bool, optional Whether to force reloading the files if already loaded. """ if self._loaded and not (overwrite): raise ValueError("Radios were already loaded. 
Call load_files(overwrite=True) to force reloading") if self.file_reader.multi_load: self._load_multi(sub_total_prange_slice) else: if self.dataset_subsampling > 1: assert ( False ), " in helica pipeline, load file _load_single has not yet been adapted to angular subsampling " self._load_single(sub_total_prange_slice) self._loaded = True load_data = load_files @property def data(self): return self.files_data ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5047567 nabu-2024.2.1/nabu/io/tests/0000755000175000017500000000000014730277752015001 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/io/tests/__init__.py0000644000175000017500000000000014315516747017077 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1713526109.0 nabu-2024.2.1/nabu/io/tests/test_cast_volume.py0000644000175000017500000002476714610452535020742 0ustar00pierrepierreimport numpy from nabu.io.cast_volume import ( cast_volume, clamp_and_rescale_data, find_histogram, get_default_output_volume, ) from tomoscan.esrf.volume import ( EDFVolume, HDF5Volume, JP2KVolume, MultiTIFFVolume, TIFFVolume, ) from nabu.io.writer import __have_jp2k__ from tomoscan.esrf.scan.edfscan import EDFTomoScan from tomoscan.esrf.scan.nxtomoscan import NXtomoScan import pytest import h5py import os from silx.io.url import DataUrl @pytest.mark.skipif(not __have_jp2k__, reason="need jp2k (glymur) for this test") def test_get_default_output_volume(): """ insure nabu.io.cast_volume is working properly """ with pytest.raises(TypeError): # test input_volume type get_default_output_volume(input_volume="dsad/dsad/", output_type="jp2") with pytest.raises(ValueError): # test output value get_default_output_volume(input_volume=EDFVolume(folder="test"), output_type="toto") # test edf to jp2 input_volume = EDFVolume( folder="/path/to/my_folder", ) output_volume = get_default_output_volume( input_volume=input_volume, output_type="jp2", ) assert isinstance(output_volume, JP2KVolume) assert output_volume.data_url.file_path() == "/path/to/vol_cast" assert output_volume.get_volume_basename() == "my_folder" # test hdf5 to tiff input_volume = HDF5Volume( file_path="/path/to/my_file.hdf5", data_path="entry0012", ) output_volume = get_default_output_volume( input_volume=input_volume, output_type="tiff", ) assert isinstance(output_volume, TIFFVolume) assert output_volume.data_url.file_path() == "/path/to/vol_cast/my_file" assert output_volume.get_volume_basename() == "my_file" # test Multitiff to hdf5 input_volume = MultiTIFFVolume( file_path="my_file.tiff", ) output_volume = get_default_output_volume( input_volume=input_volume, output_type="hdf5", ) assert isinstance(output_volume, HDF5Volume) assert output_volume.data_url.file_path() == "vol_cast/my_file.hdf5" assert output_volume.data_url.data_path() == "volume/" + HDF5Volume.DATA_DATASET_NAME assert output_volume.metadata_url.file_path() == "vol_cast/my_file.hdf5" assert output_volume.metadata_url.data_path() == "volume/" + HDF5Volume.METADATA_GROUP_NAME # test jp2 to hdf5 input_volume = JP2KVolume( folder="folder", volume_basename="basename", ) output_volume = get_default_output_volume( input_volume=input_volume, output_type="hdf5", ) assert isinstance(output_volume, HDF5Volume) assert output_volume.data_url.file_path() == "folder/vol_cast/basename.hdf5" assert output_volume.data_url.data_path() == f"/volume/{HDF5Volume.DATA_DATASET_NAME}" 
assert output_volume.metadata_url.file_path() == "folder/vol_cast/basename.hdf5" assert output_volume.metadata_url.data_path() == f"/volume/{HDF5Volume.METADATA_GROUP_NAME}" def test_find_histogram_hdf5_volume(tmp_path): """ test find_histogram function with hdf5 volume """ h5_file = os.path.join(tmp_path, "test_file") with h5py.File(h5_file, mode="w") as h5f: h5f.require_group("myentry/histogram/results/data") # if volume url provided then can find it assert find_histogram(volume=HDF5Volume(file_path=h5_file, data_path="myentry")) == DataUrl( file_path=h5_file, data_path="myentry/histogram/results/data", scheme="silx", ) assert find_histogram(volume=HDF5Volume(file_path=h5_file, data_path="entry")) == None def test_find_histogram_single_frame_volume(tmp_path): """ test find_histogram function with single frame volume TODO: improve: for now histogram file are created manually. If this can be more coupled with the "real" histogram generation it would be way better """ # create volume and histogram volume = EDFVolume( folder=tmp_path, volume_basename="volume", ) histogram_file = os.path.join(tmp_path, "volume_histogram.hdf5") with h5py.File(histogram_file, mode="w") as h5f: h5f.require_group("entry/histogram/results/data") # check behavior assert find_histogram(volume=volume) == DataUrl( file_path=histogram_file, data_path="entry/histogram/results/data", scheme="silx", ) assert find_histogram( volume=volume, scan=EDFTomoScan(scan=str(tmp_path)), ) == DataUrl( file_path=histogram_file, data_path="entry/histogram/results/data", scheme="silx", ) assert find_histogram( volume=volume, scan=NXtomoScan(scan=str(tmp_path), entry="entry"), ) == DataUrl( file_path=histogram_file, data_path="entry/histogram/results/data", scheme="silx", ) def test_find_histogram_multi_tiff_volume(tmp_path): """ test find_histogram function with multi tiff frame volume TODO: improve: for now histogram file are created manually. 
If this can be more coupled with the "real" histogram generation it would be way better """ # create volume and histogram tiff_file = os.path.join(tmp_path, "my_tiff.tif") volume = MultiTIFFVolume( file_path=tiff_file, ) histogram_file = os.path.join(tmp_path, "my_tiff_histogram.hdf5") with h5py.File(histogram_file, mode="w") as h5f: h5f.require_group("entry/histogram/results/data") # check behavior assert find_histogram(volume=volume) == DataUrl( file_path=histogram_file, data_path="entry/histogram/results/data", scheme="silx", ) assert find_histogram( volume=volume, scan=EDFTomoScan(scan=str(tmp_path)), ) == DataUrl( file_path=histogram_file, data_path="entry/histogram/results/data", scheme="silx", ) assert find_histogram( volume=volume, scan=NXtomoScan(scan=str(tmp_path), entry="entry"), ) == DataUrl( file_path=histogram_file, data_path="entry/histogram/results/data", scheme="silx", ) @pytest.mark.parametrize("input_dtype", (numpy.float32, numpy.float64, numpy.uint8, numpy.uint16)) def test_clamp_and_rescale_data(input_dtype): """ test 'rescale_data' function """ array = numpy.linspace( start=1, stop=100, num=100, endpoint=True, dtype=input_dtype, ).reshape(10, 10) rescaled_array = clamp_and_rescale_data( data=array, new_min=10, new_max=90, rescale_min_percentile=20, # provided to insure they will be ignored rescale_max_percentile=80, # provided to insure they will be ignored ) assert rescaled_array.min() == 10 assert rescaled_array.max() == 90 numpy.testing.assert_equal(rescaled_array.flatten()[0:10], numpy.array([10] * 10)) numpy.testing.assert_equal(rescaled_array.flatten()[90:100], numpy.array([90] * 10)) def test_cast_volume(tmp_path): """ test cast_volume """ raw_data = numpy.linspace( start=1, stop=100, num=100, endpoint=True, dtype=numpy.float64, ).reshape(1, 10, 10) volume_hdf5_file_path = os.path.join(tmp_path, "myvolume.hdf5") volume_hdf5 = HDF5Volume( file_path=volume_hdf5_file_path, data_path="myentry", data=raw_data, ) volume_edf = EDFVolume( folder=os.path.join(tmp_path, "volume_folder"), ) # test when no histogram existing cast_volume( input_volume=volume_hdf5, output_volume=volume_edf, output_data_type=numpy.dtype(numpy.uint16), rescale_min_percentile=10, rescale_max_percentile=90, save=True, store=True, ) # if percentiles 10 and 90 provided, no data_min and data_max then they will be computed from data min / max # append histogram with h5py.File(volume_hdf5_file_path, mode="a") as h5s: hist = numpy.array([20, 20, 20, 20, 20, 20]) bins = numpy.array([0, 20, 40, 60, 80, 100]) h5s["myentry/histogram/results/data"] = numpy.vstack((hist, bins)) # and test it again volume_edf.overwrite = True cast_volume( input_volume=volume_hdf5, output_volume=volume_edf, output_data_type=numpy.dtype(numpy.uint16), rescale_min_percentile=20, rescale_max_percentile=60, save=True, store=True, ) # test to cast the already cast volumes volume_tif = EDFVolume( folder=os.path.join(tmp_path, "second_volume_folder"), ) volume_tif.overwrite = True cast_volume( input_volume=volume_edf, output_volume=volume_tif, output_data_type=numpy.dtype(numpy.uint8), save=True, store=True, ) assert volume_tif.data.dtype == numpy.uint8 volume_tif.overwrite = False with pytest.raises(OSError): cast_volume( input_volume=volume_edf, output_volume=volume_tif, output_data_type=numpy.dtype(numpy.uint8), save=True, store=True, ) @pytest.mark.skipif(not __have_jp2k__, reason="need jp2k (glymur) for this test") def test_jp2k_compression_ratios(tmp_path): """ simple test to make sure the compression ratios are handled """ 
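# Sketch of the check below: cast the same floating-point volume into two JP2K volumes created
# with different 'cratios' values; if the compression ratios are really forwarded to the JP2K
# encoder, the decoded first frames of the two volumes should differ.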
import glymur raw_data = numpy.random.random( size=(1, 2048, 2048), ) raw_data *= 2048.000005 volume_hdf5_file_path = os.path.join(tmp_path, "myvolume.hdf5") volume_hdf5 = HDF5Volume( file_path=volume_hdf5_file_path, data_path="myentry", data=raw_data, ) volume_jp2k_ratios_0 = JP2KVolume( folder=os.path.join(tmp_path, "volume_folder"), cratios=(100, 10), ) volume_jp2k_ratios_1 = JP2KVolume( folder=os.path.join(tmp_path, "volume_folder_2"), cratios=(1000, 100), ) # test when no histogram existing cast_volume( input_volume=volume_hdf5, output_volume=volume_jp2k_ratios_0, output_data_type=numpy.dtype(numpy.uint16), save=True, store=True, ) cast_volume( input_volume=volume_hdf5, output_volume=volume_jp2k_ratios_1, output_data_type=numpy.dtype(numpy.uint16), save=True, store=True, ) # make sure the ratio have been taking into account frame_0 = glymur.Jp2k(next(volume_jp2k_ratios_0.browse_data_files())) frame_0.layer = 0 frame_1 = glymur.Jp2k(next(volume_jp2k_ratios_1.browse_data_files())) frame_1.layer = 0 assert not numpy.array_equal(frame_0, frame_1) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/io/tests/test_detector_distortion.py0000644000175000017500000001430514550227307022473 0ustar00pierrepierreimport pytest import numpy as np import scipy.ndimage from scipy import sparse from nabu.io.detector_distortion import DetectorDistortionBase from nabu.processing.rotation import Rotation, __have__skimage__ if __have__skimage__: import skimage @pytest.mark.skipif(not (__have__skimage__), reason="Need scikit-image for rotation") def test_detector_distortion(): image = scipy.ndimage.gaussian_filter(np.random.random([379, 1357]), 3.0) center_xz = ((image.shape[1] - 1) / 2, (image.shape[0] - 1) / 2) part_to_be_retrieved = image[100:279] rotated_image = skimage.transform.rotate(image, angle=5.0, center=center_xz[::]) corrector = DetectorDistortionRotation(detector_full_shape_vh=image.shape, center_xz=center_xz, angle_deg=-5.0) start_x, end_x, start_z, end_z = corrector.set_sub_region_transformation( target_sub_region=( None, None, 100, 279, ) ) source = rotated_image[start_z:end_z, start_x:end_x] retrieved = corrector.transform(source) diff = (retrieved - part_to_be_retrieved)[:, 20:-20] assert abs(diff).std() < 1e-3 class DetectorDistortionRotation(DetectorDistortionBase): """ """ def __init__(self, detector_full_shape_vh=(0, 0), center_xz=(0, 0), angle_deg=0.0): """This is the basis class. A simple identity transformation which has the only merit to show how it works.Reimplement this function to have more parameters for other transformations """ self._build_full_transformation(detector_full_shape_vh, center_xz, angle_deg) def _build_full_transformation(self, detector_full_shape_vh, center_xz, angle_deg): """A simple identity. 
Reimplement this function to have more parameters for other transformations """ indices = np.indices(detector_full_shape_vh) center_x, center_z = center_xz coordinates = (indices.T - np.array([center_z, center_x])).T c = np.cos(np.deg2rad(angle_deg)) s = np.sin(np.deg2rad(angle_deg)) rot_mat = np.array([[c, s], [-s, c]]) coordinates = np.tensordot(rot_mat, coordinates, axes=[1, 0]) # padding sz, sx = detector_full_shape_vh total_detector_npixs = sz * sx xs = np.clip(np.array(coordinates[1].flat) + center_x, [[0]], [[sx - 1]]) zs = np.clip(np.array(coordinates[0].flat) + center_z, [[0]], [[sz - 1]]) ix0s = np.floor(xs) ix1s = np.ceil(xs) fx = xs - ix0s iz0s = np.floor(zs) iz1s = np.ceil(zs) fz = zs - iz0s I_tmp = np.empty([4 * sz * sx], np.int64) J_tmp = np.empty([4 * sz * sx], np.int64) V_tmp = np.ones([4 * sz * sx], "f") I_tmp[:] = np.arange(sz * sx * 4) // 4 J_tmp[0::4] = iz0s * sx + ix0s J_tmp[1::4] = iz0s * sx + ix1s J_tmp[2::4] = iz1s * sx + ix0s J_tmp[3::4] = iz1s * sx + ix1s V_tmp[0::4] = (1 - fz) * (1 - fx) V_tmp[1::4] = (1 - fz) * fx V_tmp[2::4] = fz * (1 - fx) V_tmp[3::4] = fz * fx self.detector_full_shape_vh = detector_full_shape_vh coo_tmp = sparse.coo_matrix((V_tmp.astype("f"), (I_tmp, J_tmp)), shape=(sz * sx, sz * sx)) csr_tmp = coo_tmp.tocsr() self.full_csr_data = csr_tmp.data self.full_csr_indices = csr_tmp.indices self.full_csr_indptr = csr_tmp.indptr ## This will be used to save time if the same sub_region argument is requested several time in a row self._status = None def _set_sub_region_transformation( self, target_sub_region=( ( None, None, 0, 0, ), ), ): (x_start, x_end, z_start, z_end) = target_sub_region if z_start is None: z_start = 0 if z_end is None: z_end = self.detector_full_shape_vh[0] if (x_start, x_end) not in [(None, None), (0, None), (0, self.detector_full_shape_vh[1])]: message = f""" In the base class DetectorDistortionRotation only vertical slicing is accepted. The sub_region contained (x_start, x_end)={(x_start, x_end)} which would slice the full horizontal size which is {self.detector_full_shape_vh[1]} """ raise ValueError() x_start, x_end = 0, self.detector_full_shape_vh[1] row_ptr_start = z_start * self.detector_full_shape_vh[1] row_ptr_end = z_end * self.detector_full_shape_vh[1] indices_start = self.full_csr_indptr[row_ptr_start] indices_end = self.full_csr_indptr[row_ptr_end] data_tmp = self.full_csr_data[indices_start:indices_end] target_offset = self.full_csr_indptr[row_ptr_start] indptr_tmp = self.full_csr_indptr[row_ptr_start : row_ptr_end + 1] - target_offset indices_tmp = self.full_csr_indices[indices_start:indices_end] iz_source = (indices_tmp) // self.detector_full_shape_vh[1] z_start_source = iz_source.min() z_end_source = iz_source.max() + 1 source_offset = z_start_source * self.detector_full_shape_vh[1] indices_tmp = indices_tmp - source_offset target_size = (z_end - z_start) * self.detector_full_shape_vh[1] source_size = (z_end_source - z_start_source) * self.detector_full_shape_vh[1] self.transformation_matrix = sparse.csr_matrix( (data_tmp, indices_tmp, indptr_tmp), shape=(target_size, source_size) ) self.target_shape = ((z_end - z_start), self.detector_full_shape_vh[1]) ## For the identity matrix the source and the target have the same size. ## The two following lines are trivial. ## For this identity transformation only the slicing of the appropriate part ## of the identity sparse matrix is slightly laborious. 
## Practical case will be more complicated and source_sub_region ## will be in general larger than the target_sub_region self._status = { "target_sub_region": ((x_start, x_end, z_start, z_end)), "source_sub_region": ((x_start, x_end, z_start_source, z_end_source)), } return self._status["source_sub_region"] ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1723556968.0 nabu-2024.2.1/nabu/io/tests/test_readers.py0000644000175000017500000004326214656662150020043 0ustar00pierrepierrefrom math import ceil from tempfile import TemporaryDirectory from dataclasses import dataclass from tomoscan.io import HDF5File import pytest import numpy as np from nxtomo.application.nxtomo import ImageKey from tomoscan.esrf import EDFVolume from nabu.pipeline.reader import NXTomoReaderBinning from nabu.testutils import utilstest, __do_long_tests__, get_file from nabu.utils import indices_to_slices, merge_slices from nabu.io.reader import EDFStackReader, NXTomoReader, NXDarksFlats @dataclass class SimpleNXTomoDescription: n_darks: int = 0 n_flats1: int = 0 n_projs: int = 0 n_flats2: int = 0 n_align: int = 0 frame_shape: tuple = None dtype: np.dtype = np.uint16 @pytest.fixture(scope="class") def bootstrap_nx_reader(request): cls = request.cls cls.nx_fname = utilstest.getfile("dummy_nxtomo.nx") cls.nx_data_path = "entry/instrument/detector/data" cls.data_desc = SimpleNXTomoDescription( n_darks=10, n_flats1=11, n_projs=100, n_flats2=11, n_align=12, frame_shape=(11, 10), dtype=np.uint16 ) cls.projs_vals = np.arange(cls.data_desc.n_projs) + cls.data_desc.n_flats1 + cls.data_desc.n_darks cls.darks_vals = np.arange(cls.data_desc.n_darks) cls.flats1_vals = np.arange(cls.data_desc.n_darks, cls.data_desc.n_darks + cls.data_desc.n_flats1) cls.flats2_vals = np.arange(cls.data_desc.n_darks, cls.data_desc.n_darks + cls.data_desc.n_flats2) yield # teardown @pytest.mark.usefixtures("bootstrap_nx_reader") class TestNXReader: def test_incorrect_path(self): with pytest.raises(FileNotFoundError): reader = NXTomoReader("/invalid/path", self.nx_data_path) with pytest.raises(KeyError): reader = NXTomoReader(self.nx_fname, "/bad/data/path") def test_simple_reads(self): """ Test NXTomoReader with simplest settings """ reader1 = NXTomoReader(self.nx_fname, self.nx_data_path) data1 = reader1.load_data() assert data1.shape == (self.data_desc.n_projs,) + self.data_desc.frame_shape assert np.allclose(data1[:, 0, 0], self.projs_vals) def test_image_key(self): """ Test the data selection using "image_key". 
""" reader_projs = NXTomoReader(self.nx_fname, self.nx_data_path, image_key=ImageKey.PROJECTION.value) data = reader_projs.load_data() assert np.allclose(data[:, 0, 0], self.projs_vals) reader_darks = NXTomoReader(self.nx_fname, self.nx_data_path, image_key=ImageKey.DARK_FIELD.value) data_darks = reader_darks.load_data() assert np.allclose(data_darks[:, 0, 0], self.darks_vals) reader_flats = NXTomoReader(self.nx_fname, self.nx_data_path, image_key=ImageKey.FLAT_FIELD.value) data_flats = reader_flats.load_data() assert np.allclose(data_flats[:, 0, 0], np.concatenate([self.flats1_vals, self.flats2_vals])) def test_data_buffer_and_subregion(self): """ Test the "data_buffer" and "sub_region" parameters """ data_desc = self.data_desc def _check_correct_shape_succeeds(shape, sub_region, test_description=""): err_msg = "Something wrong with the following test:" + test_description data_buffer = np.zeros(shape, dtype="f") reader1 = NXTomoReader(self.nx_fname, self.nx_data_path, sub_region=sub_region) data1 = reader1.load_data(output=data_buffer) assert id(data1) == id(data_buffer), err_msg reader2 = NXTomoReader(self.nx_fname, self.nx_data_path, sub_region=sub_region) data2 = reader2.load_data() assert np.allclose(data1, data2), err_msg test_cases = [ { "description": "In the projections, read everything into the provided data buffer", "sub_region": None, "correct_shape": (data_desc.n_projs,) + data_desc.frame_shape, "wrong_shapes": [ (data_desc.n_projs - 1,) + data_desc.frame_shape, (data_desc.n_projs - 1,) + (999, 998), (data_desc.n_projs,) + (999, 998), ], }, { "description": "In the projections, select a subset along dimension 0 (i.e take only several full frames). The correct output shape is: data_total[image_key==0][slice(10, 30)].shape", "sub_region": slice(10, 30), "correct_shape": (20,) + data_desc.frame_shape, "wrong_shapes": [ (data_desc.n_projs,) + data_desc.frame_shape, (19,) + data_desc.frame_shape, ], }, { "description": "In the projections, read several rows of all images, i.e extract several sinograms. 
The correct output shape is: data_total[image_key==0][:, slice(start_z, end_z), :].shape", "sub_region": (None, slice(3, 7), None), "correct_shape": (data_desc.n_projs, 4, data_desc.frame_shape[-1]), "wrong_shapes": [], }, ] for test_case in test_cases: for wrong_shape in test_case["wrong_shapes"]: with pytest.raises(ValueError): data_buffer_wrong_shape = np.zeros(wrong_shape, dtype="f") reader = NXTomoReader( self.nx_fname, self.nx_data_path, sub_region=test_case["sub_region"], ) reader.load_data(output=data_buffer_wrong_shape) _check_correct_shape_succeeds(test_case["correct_shape"], test_case["sub_region"], test_case["description"]) def test_subregion_and_subsampling(self): data_desc = self.data_desc test_cases = [ { # Read one full image out of two in all projections "sub_region": (slice(None, None, 2), None, None), "expected_shape": (self.projs_vals[::2].size,) + data_desc.frame_shape, "expected_values": self.projs_vals[::2], }, { # Read one image fragment (several rows) out of two in all projections "sub_region": (slice(None, None, 2), slice(5, 8), None), "expected_shape": (self.projs_vals[::2].size, 3, data_desc.frame_shape[-1]), "expected_values": self.projs_vals[::2], }, ] for test_case in test_cases: reader = NXTomoReader(self.nx_fname, self.nx_data_path, sub_region=test_case["sub_region"]) data = reader.load_data() assert data.shape == test_case["expected_shape"] assert np.allclose(data[:, 0, 0], test_case["expected_values"]) def test_reading_with_binning_(self): from nabu.pipeline.reader import NXTomoReaderBinning reader_with_binning = NXTomoReaderBinning((2, 2), self.nx_fname, self.nx_data_path) data = reader_with_binning.load_data() assert data.shape == (self.data_desc.n_projs,) + tuple(n // 2 for n in self.data_desc.frame_shape) def test_reading_with_distortion_correction(self): from nabu.io.detector_distortion import DetectorDistortionBase from nabu.pipeline.reader import NXTomoReaderDistortionCorrection data_desc = self.data_desc # (start_x, end_x, start_y, end_y) sub_region_xy = (None, None, 1, 6) distortion_corrector = DetectorDistortionBase(detector_full_shape_vh=data_desc.frame_shape) distortion_corrector.set_sub_region_transformation(target_sub_region=sub_region_xy) adapted_subregion = distortion_corrector.get_adapted_subregion(sub_region_xy) sub_region = (slice(None, None), slice(*sub_region_xy[2:]), slice(*sub_region_xy[:2])) reader_distortion_corr = NXTomoReaderDistortionCorrection( distortion_corrector, self.nx_fname, self.nx_data_path, sub_region=sub_region, ) reader_distortion_corr.load_data() @pytest.mark.skipif(not (__do_long_tests__), reason="Need NABU_LONG_TESTS=1") def test_other_load_patterns(self): """ Other data read patterns that are sometimes used by ChunkedPipeline Test cases already done in check_correct_shape_succeeds(): - Read all frames in a provided buffer - Read a subset of all (full) projections - Read several rows of all projections (extract sinograms) """ data_desc = self.data_desc test_cases = [ { "description": "Select a subset along all dimensions. 
The correct output shape is data_total[image_key==0][slice_dim0, slice_dim1, slice_dim2].shape", "sub_region": (slice(10, 72, 2), slice(4, None), slice(2, 8)), "expected_shape": (31, 7, 6), "expected_values": self.projs_vals[slice(10, 72, 2)], }, { "description": "Select several rows in all images (i.e extract sinograms), with binning", "sub_region": (slice(None, None), slice(3, 7), slice(None, None)), "binning": (2, 2), "expected_shape": (data_desc.n_projs, 4 // 2, data_desc.frame_shape[-1] // 2), "expected_values": self.projs_vals[:], }, { "description": "Extract sinograms with binning + subsampling", "sub_region": (slice(None, None, 2), slice(1, 8), slice(None, None)), "binning": (2, 2), "expected_shape": (ceil(data_desc.n_projs / 2), 7 // 2, data_desc.frame_shape[-1] // 2), "expected_values": self.projs_vals[::2], }, ] for test_case in test_cases: binning = test_case.get("binning", None) reader_cls = NXTomoReader init_args = [self.nx_fname, self.nx_data_path] init_kwargs = {"sub_region": test_case["sub_region"]} if binning is not None: reader_cls = NXTomoReaderBinning init_args = [binning] + init_args reader = reader_cls(*init_args, **init_kwargs) data = reader.load_data() err_msg = "Something wrong with test: " + test_case["description"] assert data.shape == test_case["expected_shape"], err_msg assert np.allclose(data[:, 0, 0], test_case["expected_values"]), err_msg @pytest.fixture(scope="class") def bootstrap_edf_reader(request): cls = request.cls test_dir = utilstest.data_home cls._tmpdir = TemporaryDirectory(prefix="test_edf_stack_", dir=test_dir) cls.edf_dir = cls._tmpdir.name cls.n_projs = 100 cls.frame_shape = (11, 12) cls.projs_vals = np.arange(cls.n_projs, dtype=np.uint16) + 10 edf_vol = EDFVolume(folder=cls.edf_dir, volume_basename="edf_stack", overwrite=True) data_shape = (cls.n_projs,) + cls.frame_shape edf_vol.data = np.ones(data_shape, dtype=np.uint16) * cls.projs_vals.reshape(cls.n_projs, 1, 1) edf_vol.save_data() cls.filenames = list(edf_vol.browse_data_files()) yield cls._tmpdir.cleanup() @pytest.mark.usefixtures("bootstrap_edf_reader") class TestEDFReader: def test_read_all_frames(self): """ Simple test, read all the frames """ reader = EDFStackReader(self.filenames) data = reader.load_data() expected_shape = (self.n_projs,) + self.frame_shape assert data.shape == expected_shape assert np.allclose(data[:, 0, 0], self.projs_vals) buffer_correct = np.zeros(expected_shape, dtype=np.float32) reader.load_data(output=buffer_correct) buffer_incorrect_1 = np.zeros((99, 11, 12), dtype=np.float32) with pytest.raises(ValueError): reader.load_data(output=buffer_incorrect_1) buffer_incorrect_2 = np.zeros((100, 11, 12), dtype=np.uint16) with pytest.raises(ValueError): reader.load_data(output=buffer_incorrect_2) def test_subregions_1(self): test_cases = [ { "name": "read a handful of full frames", "sub_region": (slice(0, 48), slice(None, None), slice(None, None)), "expected_shape": (48,) + self.frame_shape, "expected_values": self.projs_vals[:48], }, { "name": "read several lines of all frames (i.e extract a singoram)", "sub_region": (slice(None, None), slice(0, 6), slice(None, None)), "expected_shape": (self.n_projs, 6, self.frame_shape[-1]), "expected_values": self.projs_vals, }, { "name": "read several lines of all frames (i.e extract a singoram), and a X-ROI", "sub_region": (slice(None, None), slice(3, 7), slice(2, 5)), "expected_shape": (self.n_projs, 4, 3), "expected_values": self.projs_vals, }, { "name": "read several lines of all frames (i.e extract a singoram), with angular 
subsampling", "sub_region": (slice(None, None, 2), slice(3, 7), slice(2, 5)), "expected_shape": (ceil(self.n_projs / 2), 4, 3), "expected_values": self.projs_vals[::2], }, ] for test_case in test_cases: reader = EDFStackReader(self.filenames, sub_region=test_case["sub_region"]) data = reader.load_data() err_msg = "Something wrong with test: %s" % (test_case["name"]) assert data.shape == test_case["expected_shape"], err_msg assert np.allclose(data[:, 0, 0], test_case["expected_values"]), err_msg @pytest.mark.skipif(not (__do_long_tests__), reason="Need NABU_LONG_TESTS=1") def test_reading_with_binning(self): from nabu.pipeline.reader import EDFStackReaderBinning reader_with_binning = EDFStackReaderBinning((2, 2), self.filenames) data = reader_with_binning.load_data() assert data.shape == (self.n_projs,) + tuple(n // 2 for n in self.frame_shape) @pytest.mark.skipif(not (__do_long_tests__), reason="Need NABU_LONG_TESTS=1") def test_reading_with_distortion_correction(self): from nabu.io.detector_distortion import DetectorDistortionBase from nabu.pipeline.reader import EDFStackReaderDistortionCorrection # (start_x, end_x, start_y, end_y) sub_region_xy = (None, None, 1, 6) distortion_corrector = DetectorDistortionBase(detector_full_shape_vh=self.frame_shape) distortion_corrector.set_sub_region_transformation(target_sub_region=sub_region_xy) adapted_subregion = distortion_corrector.get_adapted_subregion(sub_region_xy) sub_region = (slice(None, None), slice(*sub_region_xy[2:]), slice(*sub_region_xy[:2])) reader_distortion_corr = EDFStackReaderDistortionCorrection( distortion_corrector, self.filenames, sub_region=sub_region, ) reader_distortion_corr.load_data() def test_indices_to_slices(): slices1 = [slice(0, 4)] slices2 = [slice(11, 16)] slices3 = [slice(3, 5), slice(8, 20)] slices4 = [slice(2, 7), slice(18, 28), slice(182, 845)] idx = np.arange(1000) for slices in [slices1, slices2, slices3, slices4]: indices = np.hstack([idx[sl] for sl in slices]) slices_calculated = indices_to_slices(indices) assert slices_calculated == slices, "Expected indices_to_slices() to return %s, but got %s" % ( str(slices), str(slices_calculated), ) def test_merge_slices(): idx = np.arange(10000) rnd = lambda x: np.random.randint(1, high=x) n_tests = 10 for i in range(n_tests): start1 = rnd(1000) stop1 = start1 + rnd(1000) start2 = rnd(1000) stop2 = start2 + rnd(1000) step1 = rnd(4) step2 = rnd(4) slice1 = slice(start1, stop1, step1) slice2 = slice(start2, stop2, step2) assert np.allclose(idx[slice1][slice2], idx[merge_slices(slice1, slice2)]) @pytest.fixture(scope="class") def bootstrap_nxdkrf(request): cls = request.cls cls.nx_file_path = get_file("bamboo_reduced.nx") yield # teardown @pytest.mark.usefixtures("bootstrap_nxdkrf") class TestDKRFReader: def test_darks(self): dkrf_reader = NXDarksFlats(self.nx_file_path) darks = dkrf_reader.get_raw_darks(as_multiple_array=True) reduced_darks = dkrf_reader.get_reduced_darks(method="mean") actual_darks = [] with HDF5File(self.nx_file_path, "r") as f: actual_darks.append(f["entry0000/data/data"][slice(0, 1)]) assert len(darks) == len(actual_darks) for i in range(len(darks)): assert np.allclose(darks[i], actual_darks[i]) actual_reduced_darks = np.mean(actual_darks[i], axis=0) assert np.allclose(reduced_darks[i], actual_reduced_darks) assert np.allclose(list(dkrf_reader.get_reduced_darks(as_dict=True).keys()), [0]) def test_flats(self): dkrf_reader = NXDarksFlats(self.nx_file_path) flats = dkrf_reader.get_raw_flats(as_multiple_array=True) reduced_flats = 
dkrf_reader.get_reduced_flats(method="median") actual_flats = [] with HDF5File(self.nx_file_path, "r") as f: actual_flats.append(f["entry0000/data/data"][slice(1, 25 + 1)]) actual_flats.append(f["entry0000/data/data"][slice(526, 550 + 1)]) assert len(flats) == len(actual_flats) for i in range(len(flats)): assert np.allclose(flats[i], actual_flats[i]) actual_reduced_flats = np.median(actual_flats[i], axis=0) assert np.allclose(reduced_flats[i], actual_reduced_flats) assert np.allclose(list(dkrf_reader.get_reduced_flats(as_dict=True).keys()), [1, 526]) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/io/tests/test_writers.py0000644000175000017500000000610514654107202020076 0ustar00pierrepierrefrom os import path from tempfile import TemporaryDirectory import pytest import numpy as np from nabu.io.writer import NXProcessWriter from nabu.io.reader import import_h5_to_dict from nabu.testutils import get_data @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls._tmpdir = TemporaryDirectory(prefix="nabu_") cls.tempdir = cls._tmpdir.name cls.sino_data = get_data("mri_sino500.npz")["data"].astype(np.uint16) cls.data = cls.sino_data yield # teardown cls._tmpdir.cleanup() @pytest.fixture(scope="class") def bootstrap_h5(request): cls = request.cls cls._tmpdir = TemporaryDirectory(prefix="nabu_") cls.tempdir = cls._tmpdir.name cls.data = get_data("mri_sino500.npz")["data"].astype(np.uint16) cls.h5_config = { "key1": "value1", "some_int": 1, "some_float": 1.0, "some_dict": { "numpy_array": np.ones((5, 6), dtype="f"), "key2": "value2", }, } yield # teardown cls._tmpdir.cleanup() @pytest.mark.usefixtures("bootstrap_h5") class TestNXWriter: def test_write_simple(self): fname = path.join(self.tempdir, "sino500.h5") writer = NXProcessWriter(fname, entry="entry0000") writer.write(self.data, "test_write_simple") def test_write_with_config(self): fname = path.join(self.tempdir, "sino500_cfg.h5") writer = NXProcessWriter(fname, entry="entry0000") writer.write(self.data, "test_write_with_config", config=self.h5_config) def test_overwrite(self): fname = path.join(self.tempdir, "sino500_overwrite.h5") writer = NXProcessWriter(fname, entry="entry0000", overwrite=True) writer.write(self.data, "test_overwrite", config=self.h5_config) writer2 = NXProcessWriter(fname, entry="entry0001", overwrite=True) writer2.write(self.data, "test_overwrite", config=self.h5_config) # overwrite entry0000 writer3 = NXProcessWriter(fname, entry="entry0000", overwrite=True) new_data = self.data.copy() new_data += 1 new_config = self.h5_config.copy() new_config["key1"] = "modified value" writer3.write(new_data, "test_overwrite", config=new_config) res = import_h5_to_dict(fname, "/") assert "entry0000" in res assert "entry0001" in res assert np.allclose(res["entry0000"]["test_overwrite"]["results"]["data"], self.data + 1) rec_cfg = res["entry0000"]["test_overwrite"]["configuration"] assert rec_cfg["key1"] == "modified value" def test_no_overwrite(self): fname = path.join(self.tempdir, "sino500_no_overwrite.h5") writer = NXProcessWriter(fname, entry="entry0000", overwrite=False) writer.write(self.data, "test_no_overwrite") writer2 = NXProcessWriter(fname, entry="entry0000", overwrite=False) with pytest.raises((RuntimeError, OSError)) as ex: writer2.write(self.data, "test_no_overwrite") message = "Error should have been raised for trying to overwrite, but got the following: %s" % str(ex.value) ././@PaxHeader0000000000000000000000000000002600000000000010213 
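# Standalone sketch (not part of the test modules above) of the NXProcessWriter
# round-trip that TestNXWriter exercises: write an array and a configuration dict
# into an NXprocess group, then read everything back with import_h5_to_dict.
# The file name and the array content are illustrative.
import numpy as np
from nabu.io.writer import NXProcessWriter
from nabu.io.reader import import_h5_to_dict

def example_nxprocess_roundtrip(fname="/tmp/nabu_example_rec.h5"):
    data = np.arange(12, dtype=np.float32).reshape(3, 4)
    writer = NXProcessWriter(fname, entry="entry0000", overwrite=True)
    # write() returns the HDF5 path of the "results" group inside the NXprocess
    results_path = writer.write(data, "reconstruction", config={"key1": "value1"})
    content = import_h5_to_dict(fname, "/")
    assert np.allclose(content["entry0000"]["reconstruction"]["results"]["data"], data)
    return results_path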
xustar0022 mtime=1712301455.0 nabu-2024.2.1/nabu/io/utils.py0000644000175000017500000002200714603722617015344 0ustar00pierrepierreimport os from typing import Optional import contextlib import h5py import numpy as np from silx.io.url import DataUrl from tomoscan.volumebase import VolumeBase from tomoscan.esrf import EDFVolume, HDF5Volume, TIFFVolume, JP2KVolume, MultiTIFFVolume from tomoscan.io import HDF5File # This function might be moved elsewhere def get_compacted_dataslices(urls, subsampling=None, begin=0): """ Regroup urls to get the data more efficiently. Build a structure mapping files indices to information on how to load the data: `{indices_set: data_location}` where `data_location` contains contiguous indices. Parameters ----------- urls: dict Dictionary where the key is an integer and the value is a silx `DataUrl`. subsampling: int, optional Subsampling factor when reading the frames. If an integer `n` is provided, then one frame out of `n` will be read. Returns -------- merged_urls: dict Dictionary with the same keys as the `urls` parameter, and where the values are the corresponding `silx.io.url.DataUrl` with merged data_slice. """ subsampling = subsampling or 1 def _convert_to_slice(idx): if np.isscalar(idx): return slice(idx, idx + 1) # otherwise, assume already slice object return idx def is_contiguous_slice(slice1, slice2, step=1): if np.isscalar(slice1): slice1 = slice(slice1, slice1 + step) if np.isscalar(slice2): slice2 = slice(slice2, slice2 + step) return slice2.start == slice1.stop def merge_slices(slice1, slice2, step=1): return slice(slice1.start, slice2.stop, step) if len(urls) == 0: return urls sorted_files_indices = sorted(urls.keys()) # if begin > 0: # sorted_files_indices = sorted_files_indices[begin:] idx0 = sorted_files_indices[begin] first_url = urls[idx0] merged_indices = [[idx0]] # location = (file_path, data_path, slice) data_location = [[first_url.file_path(), first_url.data_path(), _convert_to_slice(first_url.data_slice())]] pos = 0 curr_fp, curr_dp, curr_slice = data_location[pos] skip_next = 0 for idx in sorted_files_indices[begin + 1 :]: if skip_next > 1: skip_next -= 1 continue url = urls[idx] next_slice = _convert_to_slice(url.data_slice()) if ( (url.file_path() == curr_fp) and (url.data_path() == curr_dp) and is_contiguous_slice(curr_slice, next_slice, step=subsampling) ): merged_indices[pos].append(idx) merged_slices = merge_slices(curr_slice, next_slice, step=subsampling) data_location[pos][-1] = merged_slices curr_slice = merged_slices skip_next = 0 else: # "jump" if begin > 0 and skip_next == 0: # Skip the "begin" next urls (first of a new block) skip_next = begin continue pos += 1 merged_indices.append([idx]) data_location.append([url.file_path(), url.data_path(), _convert_to_slice(url.data_slice())]) curr_fp, curr_dp, curr_slice = data_location[pos] # Format result res = {} for ind, dl in zip(merged_indices, data_location): res.update(dict.fromkeys(ind, DataUrl(file_path=dl[0], data_path=dl[1], data_slice=dl[2]))) return res def get_first_hdf5_entry(fname): with HDF5File(fname, "r") as fid: entry = list(fid.keys())[0] return entry def hdf5_entry_exists(fname, entry): with HDF5File(fname, "r") as fid: res = fid.get(entry, None) is not None return res def get_h5_value(fname, h5_path, default_ret=None): with HDF5File(fname, "r") as fid: try: val_ptr = fid[h5_path][()] except KeyError: val_ptr = default_ret return val_ptr def get_h5_str_value(dataset_ptr): """ Get a HDF5 field which can be bytes or str (depending on h5py version !). 
""" data = dataset_ptr[()] if isinstance(data, str): return data else: return bytes.decode(data) def create_dict_of_indices(images_stack, images_indices): """ From an image stack with the images indices, create a dictionary where each index is the image index, and the value is the corresponding image. Parameters ---------- images_stack: numpy.ndarray A 3D numpy array in the layout (n_images, n_y, n_x) images_indices: array or list of int Array containing the indices of images in the stack Examples -------- Given a simple array stack: >>> images_stack = np.arange(3*4*5).reshape((3,4,5)) ... images_indices = [2, 7, 1] ... create_dict_of_indices(images_stack, images_indices) ... # returns {2: array1, 7: array2, 1: array3} """ if images_stack.ndim != 3: raise ValueError("Expected a 3D array") if len(images_indices) != images_stack.shape[0]: raise ValueError("images_stack must have as many images as the length of images_indices") res = {} for i in range(len(images_indices)): res[images_indices[i]] = images_stack[i] return res def convert_dict_values(dic, val_replacements, bytes_tostring=False): """ Modify a dictionary to be able to export it with silx.io.dicttoh5 """ modified_dic = {} for key, value in dic.items(): if isinstance(key, int): # np.isscalar ? key = str(key) if isinstance(value, bytes) and bytes_tostring: value = bytes.decode(value.tostring()) if isinstance(value, dict): value = convert_dict_values(value, val_replacements, bytes_tostring=bytes_tostring) else: if isinstance(value, DataUrl): value = value.path() elif value.__hash__ is not None and value in val_replacements: value = val_replacements[value] modified_dic[key] = value return modified_dic class _BaseReader(contextlib.AbstractContextManager): def __init__(self, url: DataUrl): if not isinstance(url, DataUrl): raise TypeError("url should be an instance of DataUrl") if url.scheme() not in ("silx", "h5py"): raise ValueError("Valid scheme are silx and h5py") if url.data_slice() is not None: raise ValueError("Data slices are not managed. Data path should " "point to a bliss node (h5py.Group)") self._url = url self._file_handler = None def __exit__(self, *exc): return self._file_handler.close() class EntryReader(_BaseReader): """Context manager used to read a bliss node""" def __enter__(self): self._file_handler = HDF5File(self._url.file_path(), mode="r") if self._url.data_path() == "": entry = self._file_handler else: entry = self._file_handler[self._url.data_path()] if not isinstance(entry, h5py.Group): raise ValueError("Data path should point to a bliss node (h5py.Group)") return entry class DatasetReader(_BaseReader): """Context manager used to read a bliss node""" def __enter__(self): self._file_handler = HDF5File(self._url.file_path(), mode="r") entry = self._file_handler[self._url.data_path()] if not isinstance(entry, h5py.Dataset): raise ValueError("Data path ({}) should point to a dataset (h5py.Dataset)".format(self._url.path())) return entry # TODO: require some utils function to deduce type. And insure homogeneity. Might be moved in tomoscan ? 
def file_format_is_edf(file_format: str): return file_format.lower().lstrip(".") == "edf" def file_format_is_jp2k(file_format: str): return file_format.lower().lstrip(".") in ("jp2k", "jp2") def file_format_is_tiff(file_format: str): return file_format.lower().lstrip(".") in ("tiff", "tif") def file_format_is_hdf5(file_format: str): return file_format.lower().lstrip(".") in ("hdf5", "hdf", "nx", "nexus") def get_output_volume(location: str, file_prefix: Optional[str], file_format: str, multitiff=False) -> VolumeBase: # TODO: see strategy. what if user provide a .nx ... ? # this function should be more generic location, extension = os.path.splitext(location) if extension == "": extension = file_format if file_format_is_edf(extension): return EDFVolume(folder=location, volume_basename=file_prefix) elif file_format_is_jp2k(extension): return JP2KVolume(folder=location, volume_basename=file_prefix) elif file_format_is_hdf5(file_format=extension): if extension is None: if file_prefix is None: location = ".".join([location, extension]) else: location = os.path.join(location, ".".join([file_prefix, extension])) return HDF5Volume(file_path=location) elif file_format_is_tiff(extension): if multitiff: return MultiTIFFVolume(file_path=location) else: return TIFFVolume(folder=location, volume_basename=file_prefix) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1730906647.0 nabu-2024.2.1/nabu/io/writer.py0000644000175000017500000003702714712705027015525 0ustar00pierrepierrefrom glob import glob from pathlib import Path as pathlib_Path from os import path, getcwd, chdir from posixpath import join as posix_join from datetime import datetime import numpy as np from h5py import VirtualSource, VirtualLayout from silx.io.dictdump import dicttoh5 from silx.io.url import DataUrl try: from tomoscan.io import HDF5File except: from h5py import File as HDF5File from tomoscan.esrf import RawVolume from tomoscan.esrf.volume.jp2kvolume import has_glymur as __have_jp2k__ from .. import version as nabu_version from ..utils import merged_shape from .utils import convert_dict_values def get_datetime(): """ Function used by some writers to indicate the current date. """ return datetime.now().replace(microsecond=0).isoformat() class Writer: """ Base class for all writers. """ def __init__(self, fname): self.fname = fname def get_filename(self): return self.fname ################################################################################################### ## Nabu original code for NXProcessWriter - also works for non-3D data, does not depend on tomoscan ################################################################################################### def h5_write_object(h5group, key, value, overwrite=False, default_val=None): existing_val = h5group.get(key, default_val) if existing_val is not default_val: if not overwrite: raise OSError("Unable to create link (name already exists): %s" % h5group.name) else: h5group.pop(key) h5group[key] = value class NXProcessWriter(Writer): """ A class to write Nexus file with a processing result. """ def __init__(self, fname, entry=None, filemode="a", overwrite=False): """ Initialize a NXProcessWriter. Parameters ----------- fname: str Path to the HDF5 file. entry: str, optional Entry in the HDF5 file. 
Default is "entry" """ super().__init__(fname) self._set_entry(entry) self._filemode = filemode self.overwrite = overwrite def _set_entry(self, entry): self.entry = entry or "entry" data_path = posix_join("/", self.entry) self.data_path = data_path def write( self, result, process_name, processing_index=0, config=None, data_name="data", is_frames_stack=True, direct_access=True, ): """ Write the result in the current NXProcess group. Parameters ---------- result: numpy.ndarray Array containing the processing result process_name: str Name of the processing processing_index: int Index of the processing (in a pipeline) config: dict, optional Dictionary containing the configuration. """ swmr = self._filemode == "r" with HDF5File(self.fname, self._filemode, swmr=swmr) as fid: nx_entry = fid.require_group(self.data_path) if "NX_class" not in nx_entry.attrs: nx_entry.attrs["NX_class"] = "NXentry" nx_process = nx_entry.require_group(process_name) nx_process.attrs["NX_class"] = "NXprocess" metadata = { "program": "nabu", "version": nabu_version, "date": get_datetime(), "sequence_index": np.int32(processing_index), } for key, val in metadata.items(): h5_write_object(nx_process, key, val, overwrite=self.overwrite) if config is not None: export_dict_to_h5( config, self.fname, posix_join(nx_process.name, "configuration"), overwrite_data=True, mode="a" ) nx_process["configuration"].attrs["NX_class"] = "NXcollection" if isinstance(result, dict): results_path = posix_join(nx_process.name, "results") export_dict_to_h5(result, self.fname, results_path, overwrite_data=self.overwrite, mode="a") else: nx_data = nx_process.require_group("results") results_path = nx_data.name nx_data.attrs["NX_class"] = "NXdata" nx_data.attrs["signal"] = data_name results_data_path = posix_join(results_path, data_name) if self.overwrite and results_data_path in fid: del fid[results_data_path] if isinstance(result, VirtualLayout): nx_data.create_virtual_dataset(data_name, result) else: # assuming array-like nx_data[data_name] = result if is_frames_stack: nx_data[data_name].attrs["interpretation"] = "image" nx_data.attrs["signal"] = data_name # prepare the direct access plots if direct_access: nx_process.attrs["default"] = "results" if "default" not in nx_entry.attrs: nx_entry.attrs["default"] = posix_join(nx_process.name, "results") # Return the internal path to "results" return results_path class NXVolVolume(NXProcessWriter): """ An interface to NXProcessWriter with the same API than tomoscan.esrf.volume. NX files are written in two ways: 1. Partial files containing sub-volumes 2. Final volume: master file with virtual dataset pointing to partial files This class handles the first one, therefore expects the "start_index" parameter. In the case of HDF5, a sub-directory is creating to contain the partial files. In other words, if file_prefix="recons" and output_dir="/path/to/out": /path/to/out/recons.h5 # final master file /path/to/out/recons/ /path/to/out/recons/recons_00000.h5 /path/to/out/recons/recons_00100.h5 ... 
""" def __init__(self, **kwargs): # get parameters from kwargs passed to tomoscan XXVolume() folder = output_dir = kwargs.get("folder", None) volume_basename = file_prefix = kwargs.get("volume_basename", None) start_index = kwargs.get("start_index", None) overwrite = kwargs.get("overwrite", False) data_path = entry = kwargs.get("data_path", None) self._process_name = kwargs.get("process_name", "reconstruction") if any([param is None for param in [folder, volume_basename, start_index, entry]]): raise ValueError("Need the following parameters: folder, volume_basename, start_index, data_path") # # By default, a sub-folder is created so that partial volumes will be one folder below the master file # (see example above in class documentation) if kwargs.get("create_subfolder", True): output_dir = path.join(output_dir, file_prefix) if path.exists(output_dir): if not (path.isdir(output_dir)): raise ValueError("Unable to create directory %s: already exists and is not a directory" % output_dir) else: pathlib_Path(output_dir).mkdir(parents=True, exist_ok=True) # file_prefix += str("_%05d" % start_index) fname = path.join(output_dir, file_prefix + ".hdf5") super().__init__(fname, entry=entry, filemode="a", overwrite=overwrite) self.data = None self.metadata = None self.file_path = fname def save(self): if self.data is None: raise ValueError("Must set data first") self.write(self.data, self._process_name, config=self.metadata) def save_metadata(self): pass # already done def browse_data_files(self): return [self.fname] # COMPAT. LegacyNXProcessWriter = NXProcessWriter # ######################################################################################## ######################################################################################## ######################################################################################## def export_dict_to_h5(dic, h5file, h5path, overwrite_data=True, mode="a"): """ Wrapper on top of silx.io.dictdump.dicttoh5 replacing None with "None" Parameters ----------- dic: dict Dictionary containing the options h5file: str File name h5path: str Path in the HDF5 file overwrite_data: bool, optional Whether to overwrite data when writing HDF5. Default is True mode: str, optional File mode. Default is "a" (append). """ modified_dic = convert_dict_values( dic, {None: "None"}, ) update_mode = {True: "modify", False: "add"}[bool(overwrite_data)] return dicttoh5(modified_dic, h5file=h5file, h5path=h5path, update_mode=update_mode, mode=mode) def create_virtual_layout(files_or_pattern, h5_path, base_dir=None, axis=0, dtype="f"): """ Create a HDF5 virtual layout. Parameters ---------- files_or_pattern: str or list A list of file names, or a wildcard pattern. If a list is provided, it will not be sorted! This will have to be done before calling this function. h5_path: str Path inside the HDF5 input file(s) base_dir: str, optional Base directory when using relative file names. axis: int, optional Data axis to merge. Default is 0. 
""" prev_cwd = None if base_dir is not None: prev_cwd = getcwd() chdir(base_dir) if isinstance(files_or_pattern, str): files_list = glob(files_or_pattern) files_list.sort() else: # list files_list = files_or_pattern if files_list == []: raise ValueError("Nothing found as pattern %s" % files_or_pattern) virtual_sources = [] shapes = [] for fname in files_list: with HDF5File(fname, "r", swmr=True) as fid: shape = fid[h5_path].shape vsource = VirtualSource(fname, name=h5_path, shape=shape) virtual_sources.append(vsource) shapes.append(shape) total_shape = merged_shape(shapes, axis=axis) virtual_layout = VirtualLayout(shape=total_shape, dtype=dtype) start_idx = 0 for vsource, shape in zip(virtual_sources, shapes): n_imgs = shape[axis] # Perhaps there is more elegant if axis == 0: virtual_layout[start_idx : start_idx + n_imgs] = vsource elif axis == 1: virtual_layout[:, start_idx : start_idx + n_imgs, :] = vsource elif axis == 2: virtual_layout[:, :, start_idx : start_idx + n_imgs] = vsource else: raise ValueError("Only axis 0,1,2 are supported") # start_idx += n_imgs if base_dir is not None: chdir(prev_cwd) return virtual_layout def merge_hdf5_files( files_or_pattern, h5_path, output_file, process_name, output_entry=None, output_filemode="a", data_name="data", processing_index=0, config=None, base_dir=None, axis=0, overwrite=False, dtype="f", ): """ Parameters ----------- files_or_pattern: str or list A list of file names, or a wildcard pattern. If a list is provided, it will not be sorted! This will have to be done before calling this function. h5_path: str Path inside the HDF5 input file(s) output_file: str Path of the output file process_name: str Name of the process output_entry: str, optional Output HDF5 root entry (default is "/entry") output_filemode: str, optional File mode for output file. Default is "a" (append) processing_index: int, optional Processing index for the output file. Default is 0. config: dict, optional Dictionary describing the configuration needed to get the results. base_dir: str, optional Base directory when using relative file names. axis: int, optional Data axis to merge. Default is 0. overwrite: bool, optional Whether to overwrite already existing data in the final file. Default is False. 
""" if base_dir is not None: prev_cwd = getcwd() virtual_layout = create_virtual_layout(files_or_pattern, h5_path, base_dir=base_dir, axis=axis, dtype=dtype) nx_file = NXProcessWriter(output_file, entry=output_entry, filemode=output_filemode, overwrite=overwrite) nx_file.write( virtual_layout, process_name, processing_index=processing_index, config=config, data_name=data_name, is_frames_stack=True, ) # pylint: disable=E0606 if base_dir is not None and prev_cwd != getcwd(): chdir(prev_cwd) class HSTVolWriter(Writer): """ A writer to mimic PyHST2 ".vol" files """ def __init__(self, fname, append=False, **kwargs): super().__init__(fname) self.append = append self._vol_writer = RawVolume(fname, overwrite=True, append=append) self._hst_metadata = kwargs.get("hst_metadata", {}) def generate_metadata(self, data, **kwargs): n_z, n_y, n_x = data.shape metadata = { "NUM_X": n_x, "NUM_Y": n_y, "NUM_Z": n_z, "voxelSize": 40.0, "BYTEORDER": "LOWBYTEFIRST", "ValMin": kwargs.get("ValMin", 0.0), "ValMax": kwargs.get("ValMin", 1.0), "s1": 0.0, "s2": 0.0, "S1": 0.0, "S2": 0.0, } for key, default_val in metadata.items(): metadata[key] = kwargs.get(key, None) or self._hst_metadata.get(key, None) or default_val return metadata @staticmethod def sanitize_metadata(metadata): # To be fixed in RawVolume for what in ["NUM_X", "NUM_Y", "NUM_Z"]: metadata[what] = int(metadata[what]) for what in ["voxelSize", "ValMin", "ValMax", "s1", "s2", "S1", "S2"]: metadata[what] = float(metadata[what]) def write(self, data, *args, config=None, **kwargs): existing_metadata = self._vol_writer.load_metadata() new_metadata = self.generate_metadata(data) if len(existing_metadata) == 0 or not (self.append): # first write or append==False metadata = new_metadata else: # append write ; update metadata metadata = existing_metadata.copy() self.sanitize_metadata(metadata) metadata["NUM_Z"] += new_metadata["NUM_Z"] self._vol_writer.data = data self._vol_writer.metadata = metadata self._vol_writer.save() # Also save .xml self._vol_writer.save_metadata( url=DataUrl( scheme="lxml", file_path=self._vol_writer.metadata_url.file_path().replace(".info", ".xml"), ) ) class HSTVolVolume(HSTVolWriter): """ An interface to HSTVolWriter with the same API than tomoscan.esrf.volume. 
This is really not ideal, see nabu:#381 """ def __init__(self, **kwargs): file_path = kwargs.get("file_path", None) if file_path is None: raise ValueError("Missing mandatory 'file_path' parameter") super().__init__(file_path, append=kwargs.pop("append", False), **kwargs) self.data = None self.metadata = None self.data_url = self._vol_writer.data_url def save(self): if self.data is None: raise ValueError("Must set data first") self.write(self.data) def save_metadata(self): pass # already done for HST part - proper metadata is not supported def browse_data_files(self): return [self.fname] ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5087566 nabu-2024.2.1/nabu/misc/0000755000175000017500000000000014730277752014163 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/misc/__init__.py0000644000175000017500000000000014315516747016261 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/misc/binning.py0000644000175000017500000000555014402565210016147 0ustar00pierrepierreimport numpy as np from ..utils import deprecated def binning(img, bin_factor, out_dtype=np.float32): """ Bin an image by a factor of "bin_factor". Parameters ---------- bin_factor: tuple of int Binning factor in each axis. out_dtype: dtype, optional Output data type. Default is float32. Notes ----- If the image original size is not a multiple of the binning factor, the last items (in the considered axis) will be dropped. The resulting shape is (img.shape[0] // bin_factor[0], img.shape[1] // bin_factor[1]) """ s = img.shape n0, n1 = bin_factor shp = (s[0] - (s[0] % n0), s[1] - (s[1] % n1)) sub_img = img[: shp[0], : shp[1]] out_shp = (shp[0] // n0, shp[1] // n1) res = np.zeros(out_shp, dtype=out_dtype) for i in range(n0): for j in range(n1): res[:] += sub_img[i::n0, j::n1] res /= n0 * n1 return res def binning_n_alt(img, bin_factor, out_dtype=np.float32): """ Alternate, "clever" but slower implementation """ n0, n1 = bin_factor new_shape = tuple(s - (s % n) for s, n in zip(img.shape, bin_factor)) sub_img = img[: new_shape[0], : new_shape[1]] img_view_4d = sub_img.reshape((new_shape[0] // n0, n0, new_shape[1] // n1, n1)) return img_view_4d.astype(out_dtype).mean(axis=1).mean(axis=-1) # # COMPAT. # @deprecated("Please use binning()", do_print=True) def binning2(img, out_dtype=np.float32): return binning(img, (2, 2), out_dtype=out_dtype) @deprecated("Please use binning()", do_print=True) def binning2_horiz(img, out_dtype=np.float32): return binning(img, (1, 2), out_dtype=out_dtype) @deprecated("Please use binning()", do_print=True) def binning2_vertic(img, out_dtype=np.float32): return binning(img, (2, 1), out_dtype=out_dtype) @deprecated("Please use binning()", do_print=True) def binning3(img, out_dtype=np.float32): return binning(img, (3, 3), out_dtype=out_dtype) @deprecated("Please use binning()", do_print=True) def binning3_horiz(img, out_dtype=np.float32): return binning(img, (1, 3), out_dtype=out_dtype) @deprecated("Please use binning()", do_print=True) def binning3_vertic(img, out_dtype=np.float32): return binning(img, (3, 1), out_dtype=out_dtype) @deprecated("Please use binning()", do_print=True) def get_binning_function(binning_factor): """ Determine the binning function to use. 
""" binning_functions = { (2, 2): binning2, (2, 1): binning2_vertic, (1, 2): binning2_horiz, (3, 3): binning3, (3, 1): binning3_vertic, (1, 3): binning3_horiz, (2, 3): None, # was a limitation (3, 2): None, # was a limitation } if binning_factor not in binning_functions: raise ValueError("Could not get a function for binning factor %s" % binning_factor) return binning_functions[binning_factor] ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/misc/fftshift.py0000644000175000017500000000031514550227307016340 0ustar00pierrepierrefrom ..processing.fftshift import * from ..utils import deprecation_warning deprecation_warning( "nabu.misc.fftshift has been moved to nabu.processing.fftshift", do_print=True, func_name="fftshift" ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/misc/filters.py0000644000175000017500000000113414402565210016165 0ustar00pierrepierreimport numpy as np import scipy.signal def correct_spikes(image, threshold): """ Perform a conditional median filtering The filtering is done in-place, meaning that the array content is modified. Parameters ---------- image: numpy.ndarray Image to filter threshold: float Median filter threshold """ m_im = scipy.signal.medfilt2d(image) fixed_part = np.array(image[[0, 0, -1, -1], [0, -1, 0, -1]]) where = abs(image - m_im) > threshold image[where] = m_im[where] image[[0, 0, -1, -1], [0, -1, 0, -1]] = fixed_part return image ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/misc/fourier_filters.py0000644000175000017500000001307314402565210017725 0ustar00pierrepierre# -*- coding: utf-8 -*- """ Fourier filters. """ from functools import lru_cache import numpy as np import scipy.special as spspe @lru_cache(maxsize=10) def get_lowpass_filter(img_shape, cutoff_par=None, use_rfft=False, data_type=np.float64): """Computes a low pass filter using the erfc function. Parameters ---------- img_shape: tuple Shape of the image cutoff_par: float or sequence of two floats Position of the cut off in pixels, if a sequence is given the second float expresses the width of the transition region which is given as a fraction of the cutoff frequency. When only one float is given for this argument a gaussian is applied whose sigma is the parameter. When a sequence of two numbers is given then the filter is 1 ( no filtering) till the cutoff frequency while a smooth erfc transition to zero is done use_rfft: boolean, optional Creates a filter to be used with the result of a rfft type of Fourier transform. Defaults to False. data_type: `numpy.dtype`, optional Specifies the data type of the computed filter. 
It defaults to `numpy.float64` Raises ------ ValueError In case of malformed cutoff_par Returns ------- numpy.array_like The computed filter """ if cutoff_par is None: return 1 elif isinstance(cutoff_par, (int, float)): cutoff_pix = cutoff_par cutoff_trans_fact = None else: try: cutoff_pix, cutoff_trans_fact = cutoff_par except ValueError: raise ValueError( "Argument cutoff_par (which specifies the pass filter shape) must be either a scalar or a" " sequence of two scalars" ) if (not isinstance(cutoff_pix, (int, float))) or (not isinstance(cutoff_trans_fact, (int, float))): raise ValueError( "Argument cutoff_par (which specifies the pass filter shape) must be one number or a sequence" "of two numbers" ) coords = [np.fft.fftfreq(s, 1) for s in img_shape] coords = np.meshgrid(*coords, indexing="ij") r = np.sqrt(np.sum(np.array(coords, dtype=data_type) ** 2, axis=0)) if cutoff_trans_fact is not None: k_cut = 0.5 / cutoff_pix k_cut_width = k_cut * cutoff_trans_fact k_pos_rescaled = (r - k_cut) / k_cut_width res = spspe.erfc(k_pos_rescaled) / 2 else: res = np.exp(-(np.pi**2) * (r**2) * (cutoff_pix**2) * 2) # Making sure to force result to chosen data type res = res.astype(data_type) if use_rfft: slicelist = [slice(None)] * (len(res.shape) - 1) + [slice(0, res.shape[-1] // 2 + 1)] return res[tuple(slicelist)] else: return res def get_highpass_filter(img_shape, cutoff_par=None, use_rfft=False, data_type=np.float64): """Computes a high pass filter using the erfc function. Parameters ---------- img_shape: tuple Shape of the image cutoff_par: float or sequence of two floats Position of the cut off in pixels, if a sequence is given the second float expresses the width of the transition region which is given as a fraction of the cutoff frequency. When only one float is given for this argument a gaussian is applied whose sigma is the parameter, and the result is subtracted from 1 to obtain the high pass filter When a sequence of two numbers is given then the filter is 1 ( no filtering) above the cutoff frequency and then a smooth transition to zero is done for smaller frequency use_rfft: boolean, optional Creates a filter to be used with the result of a rfft type of Fourier transform. Defaults to False. data_type: `numpy.dtype`, optional Specifies the data type of the computed filter. It defaults to `numpy.float64` Raises ------ ValueError In case of malformed cutoff_par Returns ------- numpy.array_like The computed filter """ if cutoff_par is None: return 1 else: return 1 - get_lowpass_filter(img_shape, cutoff_par, use_rfft=use_rfft, data_type=data_type) def get_bandpass_filter(img_shape, cutoff_lowpass=None, cutoff_highpass=None, use_rfft=False, data_type=np.float64): """Computes a band pass filter using the erfc function. The cutoff structures should be formed as follows: - tuple of two floats: the first indicates the cutoff frequency, the second \ determines the width of the transition region, as fraction of the cutoff frequency. - one float -> it represents the sigma of a gaussian which acts as a filter or anti-filter (1 - filter). Parameters ---------- img_shape: tuple Shape of the image cutoff_lowpass: float or sequence of two floats Cutoff parameters for the low-pass filter cutoff_highpass: float or sequence of two floats Cutoff parameters for the high-pass filter use_rfft: boolean, optional Creates a filter to be used with the result of a rfft type of Fourier transform. Defaults to False. data_type: `numpy.dtype`, optional Specifies the data type of the computed filter. 
It defaults to `numpy.float64` Raises ------ ValueError In case of malformed cutoff_par Returns ------- numpy.array_like The computed filter """ return get_lowpass_filter( img_shape, cutoff_par=cutoff_lowpass, use_rfft=use_rfft, data_type=data_type ) * get_highpass_filter(img_shape, cutoff_par=cutoff_highpass, use_rfft=use_rfft, data_type=data_type) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/misc/histogram.py0000644000175000017500000000032114550227307016515 0ustar00pierrepierrefrom ..processing.histogram import * from ..utils import deprecation_warning deprecation_warning( "nabu.misc.histogram has been moved to nabu.processing.histogram", do_print=True, func_name="histogram" ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/misc/histogram_cuda.py0000644000175000017500000000035614550227307017521 0ustar00pierrepierrefrom ..processing.histogram_cuda import * from ..utils import deprecation_warning deprecation_warning( "nabu.misc.histogram_cuda has been moved to nabu.processing.histogram_cuda", do_print=True, func_name="histogram_cuda", ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/misc/kernel_base.py0000644000175000017500000000030214550227307016771 0ustar00pierrepierrefrom nabu.processing.kernel_base import KernelBase from ..utils import deprecated_class KernelBase = deprecated_class("KernelBase has been moved to nabu.processing", do_print=True)(KernelBase) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/misc/padding.py0000644000175000017500000000523114402565210016125 0ustar00pierrepierreimport math import numpy as np def pad_interpolate(im, padded_img_shape_vh, translation_vh=None, padding_mode="reflect"): """ This function produces a centered padded image and , optionally if translation_vh is set, performs a Fourier shift of the whole image. In case of translation, the image is first padded to a larger extent, that encompasses the final padded with plus a translation margin, and then translated, and final recut to the required padded width. The values are translated: if a feature appear at x in the original image it will appear at pad+translation+x in the final image. 
Parameters ------------ im: np.ndaray the input image translation_vh: a sequence of two float the vertical and horizontal shifts """ if translation_vh is not None: pad_extra = 2 ** (1 + np.ceil(np.log2(np.maximum(1, np.ceil(abs(np.array(translation_vh)))))).astype(np.int32)) else: pad_extra = [0, 0] origy, origx = im.shape rety = padded_img_shape_vh[0] + pad_extra[0] retx = padded_img_shape_vh[1] + pad_extra[1] xpad = [0, 0] xpad[0] = math.ceil((retx - origx) / 2) xpad[1] = retx - origx - xpad[0] ypad = [0, 0] ypad[0] = math.ceil((rety - origy) / 2) ypad[1] = rety - origy - ypad[0] y2 = origy - ypad[1] x2 = origx - xpad[1] if ypad[0] + 1 > origy or xpad[0] + 1 > origx or y2 < 1 or x2 < 1: raise ValueError("Too large padding for this reflect padding type") padded_im = np.pad(im, pad_width=((ypad[0], ypad[1]), (xpad[0], xpad[1])), mode=padding_mode) if translation_vh is not None: freqs_list = list(map(np.fft.fftfreq, padded_im.shape)) shifts_list = [np.exp(-2.0j * np.pi * freqs * trans) for (freqs, trans) in zip(freqs_list, translation_vh)] shifts_2D = shifts_list[0][:, None] * shifts_list[1][None, :] padded_im = np.fft.ifft2(np.fft.fft2(padded_im) * shifts_2D).real padded_im = recut(padded_im, padded_img_shape_vh) return padded_im def recut(im, new_shape_vh): """ This method implements a centered cut which reverts the centered padding applied in the present class. Parameters ----------- im: np.ndarray A 2D image. new_shape_vh: tuple The shape of the cutted image. Returns -------- The image cutted to new_shape_vh. """ new_shape_vh = np.array(new_shape_vh) old_shape_vh = np.array(im.shape) center_vh = (old_shape_vh - 1) / 2 start_vh = np.round(0.5 + center_vh - new_shape_vh / 2).astype(np.int32) end_vh = start_vh + new_shape_vh return im[start_vh[0] : end_vh[0], start_vh[1] : end_vh[1]] ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/misc/padding_base.py0000644000175000017500000000033014550227307017120 0ustar00pierrepierrefrom ..processing.padding_base import * from ..utils import deprecation_warning deprecation_warning( "nabu.misc.padding has been moved to nabu.processing.padding_base", do_print=True, func_name="padding_base" ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/misc/processing_base.py0000644000175000017500000000032614550227307017673 0ustar00pierrepierrefrom nabu.processing.processing_base import ProcessingBase from ..utils import deprecated_class ProcessingBase = deprecated_class("ProcessingBase has been moved to nabu.processing", do_print=True)(ProcessingBase) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/misc/rotation.py0000644000175000017500000000031514550227307016362 0ustar00pierrepierrefrom ..processing.rotation import * from ..utils import deprecation_warning deprecation_warning( "nabu.misc.rotation has been moved to nabu.processing.rotation", do_print=True, func_name="rotation" ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/misc/rotation_cuda.py0000644000175000017500000000034114550227307017355 0ustar00pierrepierrefrom ..processing.rotation_cuda import * from ..utils import deprecation_warning deprecation_warning( "nabu.misc.rotation_cuda has been moved to nabu.processing.rotation_cuda", do_print=True, func_name="rotation_cuda" ) ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 
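# Minimal sketch of the centered padding / recut pair defined in nabu.misc.padding
# above: pad a 2D image to a larger centered shape while applying a sub-pixel
# Fourier shift, then cut back to the original shape. Shapes and shift values are
# illustrative.
import numpy as np
from nabu.misc.padding import pad_interpolate, recut

img = np.random.rand(64, 60).astype(np.float32)
padded = pad_interpolate(img, (128, 128), translation_vh=(0.5, -1.25))  # shape (128, 128)
restored = recut(padded, img.shape)  # back to shape (64, 60)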
mtime=1734442985.5087566 nabu-2024.2.1/nabu/misc/tests/0000755000175000017500000000000014730277752015325 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/misc/tests/__init__.py0000644000175000017500000000000114315516747017424 0ustar00pierrepierre ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/misc/tests/test_binning.py0000644000175000017500000000353414402565210020350 0ustar00pierrepierrefrom itertools import product import numpy as np import pytest from nabu.misc.binning import * @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.data = np.arange(100 * 99, dtype=np.uint16).reshape(100, 99) cls.tol = 1e-5 @pytest.mark.usefixtures("bootstrap") class TestBinning: def testBinning(self): """ Test the general-purpose binning function with an image defined by its indices. The test "image" is an array where entry [i, j] equals i * Nx + j (where Nx = array.shape[1]). Let (b_i, b_j) be the binning factor along each dimension, then the binned array at position [p_i, p_j] is equal to 1/(b_i*b_j) * Sum(Sum(Nx*i + j, (j, p_j, p_j+b_j-1)), (i, p_i, p_i+b_i-1)) which happens to be equal to (Nx * b_i + 2*Nx * p_i - Nx + b_j + 2*p_j - 1) /2 """ def get_reference_binned_image(img_shape, bin_factor): # reference[p_i, p_j] = 0.5 * (Nx * b_i + 2*Nx * p_i - Nx + b_j + 2*p_j - 1) Ny, Nx = img_shape img_shape_reduced = tuple(s - (s % b) for s, b in zip(img_shape, bin_factor)) b_i, b_j = bin_factor inds_i, inds_j = np.indices(img_shape_reduced) p_i = inds_i[::b0, ::b1] p_j = inds_j[::b0, ::b1] return 0.5 * (Nx * b_i + 2 * Nx * p_i - Nx + b_j + 2 * p_j - 1) # Various test settings binning_factors = [2, 3, 4, 5, 6, 8, 10] n_items = [63, 64, 65, 66, 125, 128, 130] # Yep, that's 2401 tests... 
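        # The closed form used by get_reference_binned_image() follows from averaging
        # Nx*i + j over one (b_i x b_j) block whose top-left corner is at original
        # indices (p_i, p_j):
        #   mean = Nx*(p_i + (b_i - 1)/2) + p_j + (b_j - 1)/2
        #        = 0.5 * (Nx*b_i + 2*Nx*p_i - Nx + b_j + 2*p_j - 1)
        # Quick numeric check with Nx = 4 and 2x2 binning: the first block holds
        # {0, 1, 4, 5}, whose mean is 2.5, and the formula gives 0.5*(8 - 4 + 2 - 1) = 2.5.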
params = product(n_items, n_items, binning_factors, binning_factors) for s0, s1, b0, b1 in params: img = np.arange(s0 * s1).reshape((s0, s1)) ref = get_reference_binned_image(img.shape, (b0, b1)) res = binning(img, (b0, b1)) assert np.allclose(res, ref) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/misc/tests/test_interpolation.py0000644000175000017500000000452514550227307021622 0ustar00pierrepierreimport numpy as np import pytest from scipy.interpolate import interp1d from nabu.testutils import generate_tests_scenarios, get_data from nabu.utils import get_cuda_srcfile, updiv from nabu.cuda.utils import __has_pycuda__, get_cuda_context if __has_pycuda__: import pycuda.gpuarray as garray from nabu.cuda.kernel import CudaKernel img0 = get_data("brain_phantom.npz")["data"] scenarios = generate_tests_scenarios( { "image": [img0, img0[:, :511], img0[:511, :]], "x_bounds": [(180, 360), (0, 180), (50, 50 + 180)], "x_to_x_new": [0.1, -0.2, 0.3], } ) @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.tol = 1e-4 if __has_pycuda__: cls.ctx = get_cuda_context(cleanup_at_exit=False) yield if __has_pycuda__: cls.ctx.pop() @pytest.mark.usefixtures("bootstrap") class TestInterpolation: def _get_reference_interpolation(self, img, x, x_new): interpolator = interp1d(x, img, kind="linear", axis=0, fill_value="extrapolate", copy=True) ref = interpolator(x_new) return ref def _compare(self, res, img, x, x_new): ref = self._get_reference_interpolation(img, x, x_new) mae = np.max(np.abs(res - ref)) return mae # parametrize on a class method will use the same class, and launch this # method with different scenarios. @pytest.mark.skipif(not (__has_pycuda__), reason="need pycuda for this test") @pytest.mark.parametrize("config", scenarios) def test_cuda_interpolation(self, config): img = config["image"] Ny, Nx = img.shape xmin, xmax = config["x_bounds"] x = np.linspace(xmin, xmax, num=img.shape[0], endpoint=False, dtype="f") x_new = x + config["x_to_x_new"] d_img = garray.to_gpu(img) d_out = garray.zeros_like(d_img) d_x = garray.to_gpu(x) d_x_new = garray.to_gpu(x_new) cuda_interpolator = CudaKernel( "linear_interp_vertical", get_cuda_srcfile("interpolation.cu"), signature="PPiiPP" ) cuda_interpolator(d_img, d_out, Nx, Ny, d_x, d_x_new, grid=(updiv(Nx, 16), updiv(Ny, 16), 1), block=(16, 16, 1)) err = self._compare(d_out.get(), img, x, x_new) err_msg = str("Max error is too high for this configuration: %s" % str(config)) assert err < self.tol, err_msg ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/misc/transpose.py0000644000175000017500000000032114550227307016536 0ustar00pierrepierrefrom ..processing.transpose import * from ..utils import deprecation_warning deprecation_warning( "nabu.misc.transpose has been moved to nabu.processing.transpose", do_print=True, func_name="transpose" ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/misc/unsharp.py0000644000175000017500000000030314550227307016200 0ustar00pierrepierrefrom ..processing.unsharp import * from ..utils import deprecation_warning deprecation_warning("nabu.misc.unsharp has been moved to nabu.processing.unsharp", do_print=True, func_name="unsharp") ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/misc/unsharp_cuda.py0000644000175000017500000000033514550227307017201 
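# Small illustration of how the CUDA launch configuration in test_cuda_interpolation
# is built, assuming nabu.utils.updiv performs ceil-division (as its usage above
# suggests): one 16x16 thread block per 16x16 tile of the image, rounded up on each
# axis so the whole image is covered; out-of-range threads are discarded by the kernel.
from nabu.utils import updiv

Nx, Ny = 511, 512
block = (16, 16, 1)
grid = (updiv(Nx, 16), updiv(Ny, 16), 1)  # (32, 32, 1) -> 512 x 512 threads in total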
0ustar00pierrepierrefrom ..processing.unsharp_cuda import * from ..utils import deprecation_warning deprecation_warning( "nabu.misc.unsharp_cuda has been moved to nabu.processing.unsharp_cuda", do_print=True, func_name="unsharp_cuda" ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/misc/unsharp_opencl.py0000644000175000017500000000035614550227307017550 0ustar00pierrepierrefrom ..processing.unsharp_opencl import * from ..utils import deprecation_warning deprecation_warning( "nabu.misc.unsharp_opencl has been moved to nabu.processing.unsharp_opencl", do_print=True, func_name="unsharp_opencl", ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/misc/utils.py0000644000175000017500000000750214402565210015662 0ustar00pierrepierreimport numpy as np def rescale_data(data, new_min, new_max, data_min=None, data_max=None): if data_min is None: data_min = np.min(data) if data_max is None: data_max = np.max(data) return (new_max - new_min) / (data_max - data_min) * (data - data_min) + new_min def get_dtype_range(dtype, normalize_floats=False): if np.dtype(dtype).kind in ["u", "i"]: dtype_range = (np.iinfo(dtype).min, np.iinfo(dtype).max) else: if normalize_floats: dtype_range = (-1.0, 1.0) else: dtype_range = (np.finfo(dtype).min, np.finfo(dtype).max) return dtype_range def psnr(img1, img2): if img1.dtype != img2.dtype: raise ValueError("both images should have the same data type") dtype_range = get_dtype_range(img1.dtype, normalize_floats=True) dtype_range = dtype_range[-1] - dtype_range[0] if np.dtype(img1.dtype).kind in ["f", "c"]: img1 = rescale_data(img1, -1.0, 1.0) img2 = rescale_data(img2, -1.0, 1.0) mse = np.mean((img1.astype(np.float64) - img2) ** 2) return 10 * np.log10((dtype_range**2) / mse) # # silx.opencl.utils.ConvolutionInfos cannot be used as long as # silx.opencl instantiates the "ocl" singleton in __init__, # leaving opencl contexts all over the place in some cases. # # so for now: copypasta # class ConvolutionInfos(object): allowed_axes = { "1D": [None], "separable_2D_1D_2D": [None, (0, 1), (1, 0)], "batched_1D_2D": [(0,), (1,)], "separable_3D_1D_3D": [None, (0, 1, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0), (1, 0, 2), (0, 2, 1)], "batched_1D_3D": [(0,), (1,), (2,)], "batched_separable_2D_1D_3D": [(0,), (1,), (2,)], # unsupported (?) 
"2D": [None], "batched_2D_3D": [(0,), (1,), (2,)], "separable_3D_2D_3D": [ (1, 0), (0, 1), (2, 0), (0, 2), (1, 2), (2, 1), ], "3D": [None], } use_cases = { (1, 1): { "1D": { "name": "1D convolution on 1D data", "kernels": ["convol_1D_X"], }, }, (2, 2): { "2D": { "name": "2D convolution on 2D data", "kernels": ["convol_2D_XY"], }, }, (3, 3): { "3D": { "name": "3D convolution on 3D data", "kernels": ["convol_3D_XYZ"], }, }, (2, 1): { "separable_2D_1D_2D": { "name": "Separable (2D->1D) convolution on 2D data", "kernels": ["convol_1D_X", "convol_1D_Y"], }, "batched_1D_2D": { "name": "Batched 1D convolution on 2D data", "kernels": ["convol_1D_X", "convol_1D_Y"], }, }, (3, 1): { "separable_3D_1D_3D": { "name": "Separable (3D->1D) convolution on 3D data", "kernels": ["convol_1D_X", "convol_1D_Y", "convol_1D_Z"], }, "batched_1D_3D": { "name": "Batched 1D convolution on 3D data", "kernels": ["convol_1D_X", "convol_1D_Y", "convol_1D_Z"], }, "batched_separable_2D_1D_3D": { "name": "Batched separable (2D->1D) convolution on 3D data", "kernels": ["convol_1D_X", "convol_1D_Y", "convol_1D_Z"], }, }, (3, 2): { "separable_3D_2D_3D": { "name": "Separable (3D->2D) convolution on 3D data", "kernels": ["convol_2D_XY", "convol_2D_XZ", "convol_2D_YZ"], }, "batched_2D_3D": { "name": "Batched 2D convolution on 3D data", "kernels": ["convol_2D_XY", "convol_2D_XZ", "convol_2D_YZ"], }, }, } ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5087566 nabu-2024.2.1/nabu/opencl/0000755000175000017500000000000014730277752014510 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/opencl/__init__.py0000644000175000017500000000000014315516747016606 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/opencl/fft.py0000644000175000017500000000032014550227307015623 0ustar00pierrepierrefrom ..processing.fft_opencl import * from ..utils import deprecation_warning deprecation_warning( "nabu.opencl.fft has been moved to nabu.processing.fft_opencl", do_print=True, func_name="fft_opencl" ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/opencl/kernel.py0000644000175000017500000001200514654107202016323 0ustar00pierrepierreimport pyopencl.array as parray from pyopencl import Program, CommandQueue, kernel_work_group_info from ..utils import ( deprecation_warning, catch_warnings, ) # TODO use warnings.catch_warnings once python < 3.11 is dropped from ..processing.kernel_base import KernelBase class OpenCLKernel(KernelBase): """ Helper class that wraps OpenCL kernel through pyopencl. Parameters ----------- kernel_name: str Name of the OpenCL kernel. ctx: pyopencl.Context OpenCL context to use. queue: pyopencl.CommandQueue OpenCL queue to use. 
If provided, will use this queue's context instead of 'ctx' filename: str, optional Path to the file name containing kernels definitions src: str, optional Source code of kernels definitions automation_params: dict, optional Automation parameters, see below build_kwargs: optional Extra arguments to provide to pyopencl.Program.build(), """ def __init__( self, kernel_name, ctx, queue=None, filename=None, src=None, automation_params=None, silent_compilation_warnings=False, **build_kwargs, ): super().__init__( kernel_name, filename=filename, src=src, automation_params=automation_params, silent_compilation_warnings=silent_compilation_warnings, ) if queue is not None: self.ctx = queue.context self.queue = queue else: self.ctx = ctx self.queue = None self.compile_kernel_source(kernel_name, build_kwargs) self.get_kernel() def compile_kernel_source(self, kernel_name, build_kwargs): self.build_kwargs = build_kwargs self.kernel_name = kernel_name with catch_warnings(action=("ignore" if self.silent_compilation_warnings else None)): # pylint: disable=E1123 self.program = Program(self.ctx, self.src).build(**self.build_kwargs) def get_kernel(self): self.kernel = None for kern in self.program.all_kernels(): if kern.function_name == self.kernel_name: self.kernel = kern if self.kernel is None: raise ValueError( "Could not find a kernel with function name '%s'. Available are: %s" % (self.kernel_name, self.program.kernel_names) ) # overwrite parent method def guess_block_size(self, shape): device = self.ctx.devices[0] wg_max = device.max_work_group_size wg_multiple = self.kernel.get_work_group_info(kernel_work_group_info.PREFERRED_WORK_GROUP_SIZE_MULTIPLE, device) ndim = len(shape) # Try to have workgroup relatively well-balanced in all dimensions, # with more work items in x > y > z if ndim == 1: wg = (wg_max, 1, 1) else: w = (wg_max // wg_multiple, wg_multiple) wg = w if w[0] > w[1] else w[::-1] wg = wg + (1,) if ndim == 3: (wg[0] // 2, wg[1] // 4, 8) return wg def get_block_grid(self, *args, **kwargs): local_size = None global_size = block = None # COMPAT. 
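        # Backward-compatibility conversion done just below: CUDA-style launch
        # parameters are mapped onto the OpenCL convention, i.e.
        #   local_size  <- block
        #   global_size <- tuple(g * b for g, b in zip(grid, block))
        # For instance grid=(32, 32, 1) with block=(16, 16, 1) gives
        # global_size=(512, 512, 1).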
block = kwargs.pop("block", None) if block is not None: deprecation_warning("Please use 'local_size' instead of 'block'") grid = kwargs.pop("grid", None) if grid is not None: deprecation_warning("Please use 'global_size' instead of 'grid'") global_size = tuple(g * b for g, b in zip(grid, block)) # global_size = kwargs.pop("global_size", global_size) local_size = kwargs.pop("local_size", block) if global_size is None: raise ValueError("Need to define global_size for kernel '%s'" % self.kernel_name) if len(global_size) == 2 and local_size is not None and len(local_size) == 3: local_size = local_size[:-1] # TODO check that last dim is 1 self.last_block_size = local_size self.last_grid_size = global_size return local_size, global_size def follow_device_arr(self, args): args = list(args) for i, arg in enumerate(args): if isinstance(arg, parray.Array): args[i] = arg.data return tuple(args) def call(self, *args, **kwargs): if not isinstance(args[0], CommandQueue): queue = self.queue if queue is None: raise ValueError( "First argument must be a pyopencl queue - otherwise provide OpenCLKernel(..., queue=queue)" ) else: queue = args[0] args = args[1:] global_size, local_size, args, kwargs = self._prepare_call(*args, **kwargs) kwargs.pop("global_size", None) kwargs.pop("local_size", None) kwargs.pop("grid", None) kwargs.pop("block", None) return self.kernel(queue, global_size, local_size, *args, **kwargs) __call__ = call ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/opencl/memcpy.py0000644000175000017500000000235714550227307016352 0ustar00pierrepierreimport numpy as np from ..utils import get_opencl_srcfile from .kernel import OpenCLKernel from .processing import OpenCLProcessing class OpenCLMemcpy2D(OpenCLProcessing): """ A class for performing rectangular memory copies between pyopencl arrays. It will only work for float32 arrays! It was written as pyopencl.enqueue_copy is too cumbersome to use for buffers. 
""" def __init__(self, ctx=None, device_type="GPU", queue=None, **kwargs): super().__init__(ctx=ctx, device_type=device_type, queue=queue, **kwargs) self.memcpy2D = OpenCLKernel("cpy2d", self.ctx, filename=get_opencl_srcfile("ElementOp.cl")) def __call__(self, dst, src, transfer_shape_xy, dst_offset_xy=None, src_offset_xy=None, wait=True): if dst_offset_xy is None: dst_offset_xy = (0, 0) if src_offset_xy is None: src_offset_xy = (0, 0) evt = self.memcpy2D( self.queue, dst, src, np.int32(dst.shape[-1]), np.int32(src.shape[-1]), np.int32(dst_offset_xy), np.int32(src_offset_xy), np.int32(transfer_shape_xy), global_size=transfer_shape_xy, ) if wait: evt.wait() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/opencl/padding.py0000644000175000017500000000034014550227307016454 0ustar00pierrepierrefrom ..processing.padding_opencl import * from ..utils import deprecation_warning deprecation_warning( "nabu.opencl.padding has been moved to nabu.processing.padding_opencl", do_print=True, func_name="padding_opencl" ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734019212.0 nabu-2024.2.1/nabu/opencl/processing.py0000644000175000017500000000456014726604214017233 0ustar00pierrepierrefrom ..processing.processing_base import ProcessingBase from ..utils import MissingComponentError from .utils import get_opencl_context, __has_pyopencl__ if __has_pyopencl__: from .kernel import OpenCLKernel import pyopencl as cl import pyopencl.array as parray from pyopencl.tools import dtype_to_ctype OpenCLArray = parray.Array else: OpenCLArray = MissingComponentError("pyopencl") dtype_to_ctype = MissingComponentError("pyopencl") # pylint: disable=E0606 class OpenCLProcessing(ProcessingBase): array_class = OpenCLArray dtype_to_ctype = dtype_to_ctype def __init__(self, ctx=None, device_type="all", queue=None, profile=False, **kwargs): """ Initialie a OpenCLProcessing instance. Parameters ---------- ctx: pycuda.driver.Context, optional Existing context to use. If provided, do not create a new context. cleanup_at_exit: bool, optional Whether to clean-up the context at exit. Ignored if ctx is not None. """ super().__init__() if queue is not None: # re-use an existing queue. In this case the this instance is mostly for convenience ctx = queue.context if ctx is None: self.ctx = get_opencl_context(device_type=device_type, **kwargs) else: self.ctx = ctx if queue is None: queue_init_kwargs = {} if profile: queue_init_kwargs = {"properties": cl.command_queue_properties.PROFILING_ENABLE} queue = cl.CommandQueue(self.ctx, **queue_init_kwargs) self.queue = queue dev_types = { cl.device_type.CPU: "cpu", cl.device_type.GPU: "gpu", cl.device_type.ACCELERATOR: "accelerator", -1: "unknown", } self.device_type = dev_types.get(self.ctx.devices[0].type, "unknown") # TODO push_context, pop_context ? 
def _allocate_array_mem(self, shape, dtype): return parray.zeros(self.queue, shape, dtype) def kernel(self, kernel_name, filename=None, src=None, automation_params=None, **build_kwargs): return OpenCLKernel( kernel_name, None, queue=self.queue, filename=filename, src=src, automation_params=automation_params, **build_kwargs, ) ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5087566 nabu-2024.2.1/nabu/opencl/src/0000755000175000017500000000000014730277752015277 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682665866.0 nabu-2024.2.1/nabu/opencl/src/ElementOp.cl0000644000175000017500000000235214422670612017477 0ustar00pierrepierre#include typedef cfloat_t complex; __kernel void cpy2d( __global float* dst, __global float* src, int dst_width, int src_width, int2 dst_offset, int2 src_offset, int2 transfer_shape) { int gidx = get_global_id(0), gidy = get_global_id(1); if (gidx < transfer_shape.x && gidy < transfer_shape.y) { dst[(dst_offset.y + gidy)*dst_width + (dst_offset.x + gidx)] = src[(src_offset.y + gidy)*src_width + (src_offset.x + gidx)]; } } // arr2D *= arr1D (line by line, i.e along fast dim) __kernel void inplace_complex_mul_2Dby1D(__global complex* arr2D, __global complex* arr1D, int width, int height) { int x = get_global_id(0); int y = get_global_id(1); if ((x >= width) || (y >= height)) return; size_t i = y*width + x; arr2D[i] = cfloat_mul(arr2D[i], arr1D[x]); } // arr3D *= arr1D (along fast dim) __kernel void inplace_complex_mul_3Dby1D(__global complex* arr3D, __global complex* arr1D, int width, int height, int depth) { int x = get_global_id(0); int y = get_global_id(1); int z = get_global_id(2); if ((x >= width) || (y >= height) || (z >= depth)) return; size_t i = (z*height + y)*width + x; arr3D[i] = cfloat_mul(arr3D[i], arr1D[x]); } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/opencl/src/backproj.cl0000644000175000017500000001266714550227307017415 0ustar00pierrepierre#ifndef SHARED_SIZE #define SHARED_SIZE 256 #endif #ifdef CLIP_OUTER_CIRCLE static inline int is_in_circle(float x, float y, float center_x, float center_y, int radius2) { return (((x - center_x)*(x - center_x) + (y - center_y)*(y - center_y)) <= radius2); } #endif /* Linear interpolation on a 2D array, horizontally. 
This will return arr[y][x] where y is an int (exact access) and x is a float (linear interp horizontally) */ static inline float linear_interpolation(global float* arr, int Nx, float x, int y) { if (x < 0 || x >= Nx) return 0.0f; // texture address mode CLAMP_TO_EDGE int xm = (int) floor(x); int xp = (int) ceil(x); if ((xm == xp) || (xp >= Nx)) return arr[y*Nx+xm]; else return (arr[y*Nx+xm] * (xp - x)) + (arr[y*Nx+xp] * (x - xm)); } kernel void backproj( global float* d_slice, #ifdef USE_TEXTURES read_only image2d_t d_sino, #else global float* d_sino, #endif int num_projs, int num_bins, float axis_position, int n_x, int n_y, float offset_x, float offset_y, global float* d_cos, global float* d_msin, #ifdef DO_AXIS_CORRECTION global float* d_axis_corr, #endif float scale_factor, local float* shared2 // local mem ) { int x = get_global_id(0); int y = get_global_id(1); uint Gx = get_global_size(0); uint Gy = get_global_size(1); #ifdef USE_TEXTURES const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP_TO_EDGE | CLK_FILTER_LINEAR; #endif // (xr, yr) (xrp, yr) // (xr, yrp) (xrp, yrp) float xr = (x + offset_x) - axis_position, yr = (y + offset_y) - axis_position; float xrp = xr + Gx, yrp = yr + Gy; local float s_cos[SHARED_SIZE]; local float s_msin[SHARED_SIZE]; #ifdef DO_AXIS_CORRECTION local float s_axis[SHARED_SIZE]; float axcorr; #endif int next_fetch = 0; int tid = get_local_id(1) * get_local_size(0) + get_local_id(0); float costheta, msintheta; float h1, h2, h3, h4; float sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f, sum4 = 0.0f; for (int proj = 0; proj < num_projs; proj++) { if (proj == next_fetch) { // Fetch SHARED_SIZE values to shared memory barrier(CLK_LOCAL_MEM_FENCE); if (next_fetch + tid < num_projs) { s_cos[tid] = d_cos[next_fetch + tid]; s_msin[tid] = d_msin[next_fetch + tid]; #ifdef DO_AXIS_CORRECTION s_axis[tid] = d_axis_corr[next_fetch + tid]; #endif } next_fetch += SHARED_SIZE; barrier(CLK_LOCAL_MEM_FENCE); } costheta = s_cos[proj - (next_fetch - SHARED_SIZE)]; msintheta = s_msin[proj - (next_fetch - SHARED_SIZE)]; #ifdef DO_AXIS_CORRECTION axcorr = s_axis[proj - (next_fetch - SHARED_SIZE)]; #endif float c1 = fma(costheta, xr, axis_position); // cos(theta)*xr + axis_pos float c2 = fma(costheta, xrp, axis_position); // cos(theta)*(xr + Gx) + axis_pos float s1 = fma(msintheta, yr, 0.0f); // -sin(theta)*yr float s2 = fma(msintheta, yrp, 0.0f); // -sin(theta)*(yr + Gy) h1 = c1 + s1; h2 = c2 + s1; h3 = c1 + s2; h4 = c2 + s2; #ifdef DO_AXIS_CORRECTION h1 += axcorr; h2 += axcorr; h3 += axcorr; h4 += axcorr; #endif #ifdef USE_TEXTURES if (h1 >= 0 && h1 < num_bins) sum1 += read_imagef(d_sino, sampler, (float2) (h1 +0.5f,proj +0.5f)).x; if (h2 >= 0 && h2 < num_bins) sum2 += read_imagef(d_sino, sampler, (float2) (h2 +0.5f,proj +0.5f)).x; if (h3 >= 0 && h3 < num_bins) sum3 += read_imagef(d_sino, sampler, (float2) (h3 +0.5f,proj +0.5f)).x; if (h4 >= 0 && h4 < num_bins) sum4 += read_imagef(d_sino, sampler, (float2) (h4 +0.5f,proj +0.5f)).x; #else if (h1 >= 0 && h1 < num_bins) sum1 += linear_interpolation(d_sino, num_bins, h1, proj); if (h2 >= 0 && h2 < num_bins) sum2 += linear_interpolation(d_sino, num_bins, h2, proj); if (h3 >= 0 && h3 < num_bins) sum3 += linear_interpolation(d_sino, num_bins, h3, proj); if (h4 >= 0 && h4 < num_bins) sum4 += linear_interpolation(d_sino, num_bins, h4, proj); #endif } int write_topleft = 1, write_topright = 1, write_botleft = 1, write_botright = 1; #ifdef CLIP_OUTER_CIRCLE float center_x = (n_x - 1)/2.0f, center_y = (n_y - 1)/2.0f; int radius2 = 
min(n_x/2, n_y/2); radius2 *= radius2; write_topleft = is_in_circle(x, y, center_x, center_y, radius2); write_topright = is_in_circle(x + Gx, y, center_x, center_y, radius2); write_botleft = is_in_circle(x, y + Gy, center_x, center_y, radius2); write_botright = is_in_circle(x + Gy, y + Gy, center_x, center_y, radius2); #endif // useful only if n_x < blocksize_x or n_y < blocksize_y if (x >= n_x) return; if (y >= n_y) return; // Pixels in top-left quadrant if (write_topleft) d_slice[y*(n_x) + x] = sum1 * scale_factor; // Pixels in top-right quadrant if ((Gx + x < n_x) && (write_topright)) { d_slice[y*(n_x) + Gx + x] = sum2 * scale_factor; } if (Gy + y < n_y) { // Pixels in bottom-left quadrant if (write_botleft) d_slice[(y+Gy)*(n_x) + x] = sum3 * scale_factor; // Pixels in bottom-right quadrant if ((Gx + x < n_x) && (write_botright)) d_slice[(y+Gy)*(n_x) + Gx + x] = sum4 * scale_factor; } } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/opencl/src/fftshift.cl0000644000175000017500000000327614550227307017433 0ustar00pierrepierre#include #ifndef DTYPE #define DTYPE float #endif static inline void swap(global DTYPE* arr, size_t idx, size_t idx2) { DTYPE tmp = arr[idx]; arr[idx] = arr[idx2]; arr[idx2] = tmp; } /* In-place one-dimensional fftshift, along horizontal dimension. The array can be 1D or 2D. direction > 0 means fftshift, direction < 0 means ifftshift. It works for even-sized arrays. Odd-sized arrays need an additional step (see roll.cl: roll_forward_x) */ __kernel void fftshift_x_inplace( __global DTYPE* arr, int Nx, int Ny, int direction ) { int x = get_global_id(0), y = get_global_id(1); int shift = Nx / 2; if (x >= shift) return; // (i)fftshift on odd-sized arrays cannot be done in-place in one step - need another kernel after this one if ((Nx & 1) && (direction > 0)) shift++; size_t idx = y * Nx + x; size_t idx_out = y * Nx + ((x + shift) % Nx); swap(arr, idx, idx_out); } #ifdef DTYPE_OUT /* Out-of-place fftshift, possibly with type casting - useful for eg. fft(ifftshift(array)) */ __kernel void fftshift_x(global DTYPE* arr, global DTYPE_OUT* dst, int Nx, int Ny, int direction) { int x = get_global_id(0), y = get_global_id(1); if (x >= Nx || y >= Ny) return; int shift = Nx / 2; if ((Nx & 1) && (direction < 0)) shift++; size_t idx = y * Nx + x; size_t idx_out = y * Nx + ((x + shift) % Nx); DTYPE_OUT out_item; #ifdef CAST_TO_COMPLEX out_item = cfloat_new(arr[idx], 0); #else #ifdef CAST_TO_REAL out_item = cfloat_real(arr[idx]); #else out_item = (DTYPE_OUT) arr[idx]; #endif #endif dst[idx_out] = out_item; } #endif ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/opencl/src/halftomo.cl0000644000175000017500000000255214550227307017423 0ustar00pierrepierre/* Multiply in-place a 360 degrees sinogram with weights. This kernel is used to prepare a sinogram to be backprojected using half-tomography geometry. One of the sides (left or right) is multiplied with weights. For example, if "r" is the center of rotation near the right side: sinogram[:, -overlap_width:] *= weights where overlap_width = 2*(n_x - 1 - r) This can still be improved when the geometry has horizontal translations. In this case, we should have "start_x" and "end_x" as arrays of size n_angles, i.e one varying (start_x, end_x) per angle. 
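Worked example (hypothetical numbers, for illustration only): with n_x = 2048 and a center of
rotation r = 1900.5 near the right side, overlap_width = 2*(2048 - 1 - 1900.5) = 293, i.e the
weights are applied on columns start_x = 2048 - 293 = 1755 up to end_x = 2047 (end included).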
Parameters ----------- * sinogram: array of size (n_angles, n_x): 360 degrees sinogram * weights: array of size (n_angles,): weights to apply on one side of the sinogram * n_angles: int: number of angles * n_x: int: horizontal size (number of pixels) of the sinogram * start_x: int: start x-position for applying the weights * end_x: int: end x-position for applying the weights (included!) */ kernel void halftomo_prepare_sinogram( global float* sinogram, global float* weights, int n_angles, int n_x, int start_x, int end_x ) { uint x = get_global_id(0); uint i_angle = get_global_id(1); if (x < start_x || x > end_x || i_angle >= n_angles) return; sinogram[i_angle * n_x + x] *= weights[x - start_x]; } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/opencl/src/padding.cl0000644000175000017500000000071414550227307017216 0ustar00pierrepierre __kernel void coordinate_transform( __global float* array_in, __global float* array_out, __global int* cols_inds, __global int* rows_inds, int Nx, int Nx_padded, int Ny_padded ) { uint x = get_global_id(0); uint y = get_global_id(1); if ((x >= Nx_padded) || (y >= Ny_padded)) return; uint idx = y*Nx_padded + x; int x2 = cols_inds[x]; int y2 = rows_inds[y]; array_out[idx] = array_in[y2*Nx + x2]; } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/opencl/src/roll.cl0000644000175000017500000000466614550227307016572 0ustar00pierrepierre#include static inline void swap(global DTYPE* arr, size_t idx, size_t idx2) { DTYPE tmp = arr[idx]; arr[idx] = arr[idx2]; arr[idx2] = tmp; } /* This code should work but it not used yet. The first intent was to have an in-place fftshift for odd-sized arrays: fftshift_odd = fftshift_even followed by roll(-1) on second half of the array ifft_odd = fftshift_even followed by roll(1) on second half of the array Roll elements (as in numpy.roll(arr, 1)) of an array, in-place. Needs to be launched with a large horizontal work group. 
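For reference (illustrative): numpy.roll([a, b, c, d], 1) gives [d, a, b, c], i.e a circular
shift of one position to the right. roll_forward_x below applies this shift along the horizontal
dimension, row by row, on the part of each row starting at offset_x.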
*/ __kernel void roll_forward_x( __global DTYPE* array, int Nx, int Ny, int offset_x, __local DTYPE* shmem ) { int Nx_tot = Nx; if (offset_x > 0) { Nx_tot = Nx; Nx -= offset_x; } int x = get_global_id(0), y = get_global_id(1); if ((x >= Nx / 2) || (y >= Ny)) return; __global DTYPE* arr = array + y * Nx_tot + offset_x; int lid = get_local_id(0); int wg_size = get_local_size(0); int n_steps = (int) ceil((Nx - (Nx & 1)) * 1.0f / (2*wg_size)); DTYPE previous, current, write_on_first; int offset = 0; for (int step = 0; step < n_steps; step++) { int idx = 2*lid + 1; if (offset + idx >= Nx) break; previous = arr[offset + idx - 1]; current = arr[offset + idx]; arr[offset + idx] = previous; if ((step == n_steps - 1) && (offset + idx + 1 >= Nx - 1)) { if (Nx & 1) write_on_first = arr[offset + idx + 1]; else write_on_first = current; } barrier(CLK_LOCAL_MEM_FENCE); if ((step > 0) && (lid == 0)) arr[offset + idx - 1] = shmem[0]; if ((lid == wg_size - 1) && (step < n_steps - 1)) shmem[0] = current; else if (offset + idx + 1 <= Nx - 1) arr[offset + idx + 1] = current; if ((step == n_steps - 1) && (offset + idx + 1 >= Nx - 1)) arr[0] = write_on_first; barrier(CLK_LOCAL_MEM_FENCE); offset += 2 * wg_size; } } __kernel void revert_array_x( __global DTYPE* array, int Nx, int Ny, int offset_x ) { int x = get_global_id(0), y = get_global_id(1); int Nx_tot = Nx; if (offset_x > 0) { Nx_tot = Nx; Nx -= offset_x; } if ((x >= Nx / 2) || (y >= Ny)) return; size_t idx = y * Nx_tot + offset_x + x; size_t idx2 = y * Nx_tot + offset_x + (Nx - 1 - x); // Nx ? swap(array, idx, idx2); }././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/opencl/src/transpose.cl0000644000175000017500000000070214550227307017623 0ustar00pierrepierre#ifndef SRC_DTYPE #define SRC_DTYPE float #endif #ifndef DST_DTYPE #define DST_DTYPE float #endif #include __kernel void transpose(__global SRC_DTYPE* src, __global DST_DTYPE* dst, int src_width, int src_height) { // coordinates for "dst" uint x = get_global_id(0); uint y = get_global_id(1); if ((x >= src_height) || (y >= src_width)) return; dst[y*src_height + x] = (DST_DTYPE) src[x*src_width + y]; }././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5087566 nabu-2024.2.1/nabu/opencl/tests/0000755000175000017500000000000014730277752015652 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682665866.0 nabu-2024.2.1/nabu/opencl/tests/__init__.py0000644000175000017500000000000014422670612017737 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/opencl/utils.py0000644000175000017500000002363514550227307016222 0ustar00pierrepierreimport numpy as np from ..utils import check_supported try: import pyopencl as cl import pyopencl.array as parray __has_pyopencl__ = True __pyopencl_error_msg__ = None except ImportError as err: __has_pyopencl__ = False __pyopencl_error_msg__ = str(err) from ..resources.gpu import GPUDescription def get_opencl_devices( device_type, vendor=None, name=None, order_by="global_mem_size", exclude_platforms=None, exclude_vendors=None, prefer_GPU=True, ): """ List available OpenCL devices. Parameters ---------- device_type: str Type of device, can be "CPU", "GPU" or "all". vendor: str, optional Filter devices by vendor, eg. "NVIDIA" name: Filter devices by names, eg. "GeForce RTX 3080" order_by: str, optional Order results in decreasing order of this value. 
Default is to sort by global memory size. exclude_platforms: str, optional Platforms to exclude, eg. "Portable Computing Language" exclude_vendors: str, optional Vendors to be excluded prefer_GPU: bool, optional Whether to put GPUs on top of the returned list, regardless of the "order_by" parameter. This can be useful when sorting by global memory size, as CPUs often have a bigger memory size. Returns ------- devices: list of pyopencl.Device List of OpenCL devices matching the criteria, and ordered by the 'order_by' parameter. The list may be empty. """ exclude_platforms = exclude_platforms or [] exclude_vendors = exclude_vendors or [] dev_type = { "cpu": cl.device_type.CPU, "gpu": cl.device_type.GPU, "accelerator": cl.device_type.ACCELERATOR, "all": cl.device_type.ALL, "any": cl.device_type.ALL, } device_type = device_type.lower() check_supported(device_type, dev_type.keys(), "device_type") devices = [] for platform in cl.get_platforms(): if vendor is not None and vendor.lower() not in platform.vendor.lower(): continue if any(excluded_platform.lower() in platform.name.lower() for excluded_platform in exclude_platforms): continue if any(excluded_vendor.lower() in platform.vendor.lower() for excluded_vendor in exclude_vendors): continue for device in platform.get_devices(): if device.type & dev_type[device_type] == 0: continue if name is not None and name.lower() not in device.name.lower(): continue devices.append(device) if order_by is not None: devices = sorted(devices, key=lambda dev: getattr(dev, order_by), reverse=True) if prefer_GPU: # put GPUs devices on top of the list devices = [dev for dev in devices if dev.type & dev_type["gpu"] > 0] + [ dev for dev in devices if dev.type & dev_type["gpu"] == 0 ] return devices def usable_opencl_devices(): """ Test the available OpenCL platforms/devices. Returns -------- platforms: dict Dictionary where the key is the platform name, and the value is a list of `silx.opencl.common.Device` object. """ platforms = {} for platform in cl.get_platforms(): platforms[platform.name] = platform.get_devices() return platforms def detect_opencl_gpus(): """ Get the available OpenCL-compatible GPUs. Returns -------- gpus: dict Nested dictionary where the keys are OpenCL platform names, values are dictionary of GPU IDs and `silx.opencl.common.Device` object. error_msg: str In the case where there is an error, the message is returned in this item. Otherwise, it is a None object. 
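    Example (minimal sketch; the printed fields are illustrative):

        gpus, error_msg = detect_opencl_gpus()
        if error_msg is not None:
            print("No OpenCL GPU available: %s" % error_msg)
        else:
            for platform_name, platform_gpus in gpus.items():
                for gpu_id, gpu in platform_gpus.items():
                    print(platform_name, gpu_id, gpu.name)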
""" gpus = {} error_msg = None if not (__has_pyopencl__): return {}, __pyopencl_error_msg__ try: platforms = usable_opencl_devices() except Exception as exc: error_msg = str(exc) if error_msg is not None: return {}, error_msg for platform_name, devices in platforms.items(): for d_id, device in enumerate(devices): if device.type == cl.device_type.GPU: # and bool(device.available): if platform_name not in gpus: gpus[platform_name] = {} gpus[platform_name][d_id] = device return gpus, None def collect_opencl_gpus(): """ Return a dictionary of platforms and brief description of each OpenCL-compatible GPU with a few fields """ gpus, error_msg = detect_opencl_gpus() if error_msg is not None: return None opencl_gpus = {} for platform, gpus in gpus.items(): for gpu_id, gpu in gpus.items(): if platform not in opencl_gpus: opencl_gpus[platform] = {} opencl_gpus[platform][gpu_id] = GPUDescription(gpu, device_id=gpu_id).get_dict() opencl_gpus[platform][gpu_id]["platform"] = platform return opencl_gpus def collect_opencl_cpus(): """ Return a dictionary of platforms and brief description of each OpenCL-compatible CPU with a few fields """ opencl_cpus = {} platforms = usable_opencl_devices() for platform, devices in platforms.items(): if "cuda" in platform.lower(): continue opencl_cpus[platform] = {} for device_id, device in enumerate(devices): # device_id might be inaccurate if device.type != cl.device_type.CPU: continue opencl_cpus[platform][device_id] = GPUDescription(device).get_dict() opencl_cpus[platform][device_id]["platform"] = platform return opencl_cpus def get_opencl_context(device_type, **kwargs): """ Create an OpenCL context. Please refer to 'get_opencl_devices' documentation """ devices = get_opencl_devices(device_type, **kwargs) if devices == []: raise RuntimeError("No OpenCL device found for device_type='%s' and %s" % (device_type, str(kwargs))) return cl.Context([devices[0]]) def replace_array_memory(arr, new_shape): """ Replace the underlying buffer data of a `pyopencl.array.Array`. This function is dangerous ! It should merely be used to clear memory, the array should not be used afterwise. """ arr.data.release() arr.base_data = cl.Buffer(arr.context, cl.mem_flags.READ_WRITE, np.prod(new_shape) * arr.dtype.itemsize) arr.shape = new_shape # strides seems to be updated by pyopencl return arr def pick_opencl_cpu_platform(opencl_cpus): """ Pick the best OpenCL implementation for the opencl cpu. This function assume that there is only one opencl-enabled CPU on the current machine, but there might be several OpenCL implementations/vendors. Parameters ---------- opencl_cpus: dict Dictionary with the available opencl-enabled CPUs. Usually obtained with collect_opencl_cpus(). Returns ------- cpu: dict A dictionary describing the CPU. 
""" if len(opencl_cpus) == 0: raise ValueError("No CPU to pick") name2device = {} for platform, devices in opencl_cpus.items(): for device_id, device_desc in devices.items(): name2device.setdefault(device_desc["name"], []) name2device[device_desc["name"]].append(platform) if len(name2device) > 1: raise ValueError("Expected at most one CPU but got %d: %s" % (len(name2device), list(name2device.keys()))) cpu_name = list(name2device.keys())[0] platforms = name2device[cpu_name] # Several platforms for the same CPU res = opencl_cpus[platforms[0]] if len(platforms) > 1: if "intel" in cpu_name.lower(): for platform in platforms: if "intel" in platform.lower(): res = opencl_cpus[platform] # return res[list(res.keys())[0]] def allocate_texture(ctx, shape, support_1D=False): """ Allocate an OpenCL image ("texture"). Parameters ---------- ctx: OpenCL context OpenCL context shape: tuple of int Shape of the image. Note that pyopencl and OpenCL < 1.2 do not support 1D images, so 1D images are handled as 2D with one row support_1D: bool, optional force the image to be 1D if the shape has only one dim """ if len(shape) == 1 and not (support_1D): shape = (1,) + shape return cl.Image( ctx, cl.mem_flags.READ_ONLY | cl.mem_flags.USE_HOST_PTR, cl.ImageFormat(cl.channel_order.INTENSITY, cl.channel_type.FLOAT), hostbuf=np.zeros(shape[::-1], dtype=np.float32), ) def check_textures_availability(ctx): """ Check whether textures are supported on the current OpenCL context. """ try: dummy_texture = allocate_texture(ctx, (16, 16)) # Need to further access some attributes (pocl) dummy_height = dummy_texture.height textures_available = True del dummy_texture, dummy_height except (cl.RuntimeError, cl.LogicError): textures_available = False # Nvidia Fermi GPUs (compute capability 2.X) do not support opencl read_imagef # There is no way to detect this until a kernel is compiled try: cc = ctx.devices[0].compute_capability_major_nv textures_available &= cc >= 3 except (cl.LogicError, AttributeError): # probably not a Nvidia GPU pass # return textures_available def copy_to_texture(queue, dst_texture, src_array, dtype=np.float32): shape = src_array.shape if isinstance(src_array, parray.Array): return cl.enqueue_copy(queue, dst_texture, src_array.data, offset=0, origin=(0, 0), region=shape[::-1]) elif isinstance(src_array, np.ndarray): if not (src_array.flags["C_CONTIGUOUS"] and src_array.dtype == dtype): src_array = np.ascontiguousarray(src_array, dtype=dtype) return cl.enqueue_copy(queue, dst_texture, src_array, origin=(0, 0), region=shape[::-1]) else: raise ValueError("Unknown source array type") ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5087566 nabu-2024.2.1/nabu/pipeline/0000755000175000017500000000000014730277752015035 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/pipeline/__init__.py0000644000175000017500000000000014315516747017133 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/pipeline/config.py0000644000175000017500000002200514550227307016642 0ustar00pierrepierrefrom os import linesep from configparser import ConfigParser from ..utils import check_supported, deprecated # # option "type": # - required: always visible, user must provide a valid value # - optional: visible, but might be left blank # - advanced: optional and not visible by default # - unsupported: hidden (not implemented yet) 
_options_levels = { "required": 0, "optional": 1, "advanced": 2, "unsupported": 10, } class NabuConfigParser: @deprecated( "The class 'NabuConfigParser' is deprecated and will be removed in a future version. Please use parse_nabu_config_file instead", do_print=True, ) def __init__(self, fname): """ Nabu configuration file parser. Parameters ---------- fname: str File name of the configuration file """ parser = ConfigParser(inline_comment_prefixes=("#",)) # allow in-line comments with open(fname) as fid: file_content = fid.read() parser.read_string(file_content) self.parser = parser self.get_dict() self.file_content = file_content.split(linesep) def get_dict(self): # Is there an officially supported way to do this ? self.conf_dict = self.parser._sections return self.conf_dict def __str__(self): return self.conf_dict.__str__() def __repr__(self): return self.conf_dict.__repr__() def __getitem__(self, key): return self.conf_dict[key] def parse_nabu_config_file(fname, allow_no_value=False): """ Parse a configuration file and returns a dictionary. Parameters ---------- fname: str File name of the configuration file Returns ------- conf_dict: dict Dictionary with the configuration """ parser = ConfigParser( inline_comment_prefixes=("#",), # allow in-line comments allow_no_value=allow_no_value, ) with open(fname) as fid: file_content = fid.read() parser.read_string(file_content) conf_dict = parser._sections # Is there an officially supported way to do this ? return conf_dict def generate_nabu_configfile( fname, default_config, config=None, sections=None, sections_comments=None, comments=True, options_level=None, prefilled_values=None, ): """ Generate a nabu configuration file. Parameters ----------- fname: str Output file path. config: dict Configuration to save. If section and / or key missing will store the default value sections: list of str, optional Sections which should be included in the configuration file comments: bool, optional Whether to include comments in the configuration file options_level: str, optional Which "level" of options to embed in the file. Can be "required", "optional", "advanced". Default is "optional". 
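    Usage sketch (illustrative; 'my_default_config' stands for a configuration description dict
    mapping sections to options, where each option provides at least 'default', 'help' and 'type'):

        generate_nabu_configfile(
            "nabu.conf", my_default_config,
            config={"dataset": {"location": "/path/to/dataset.nx"}},
            options_level="advanced",
        )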
""" if options_level is None: options_level = "optional" if prefilled_values is None: prefilled_values = {} check_supported(options_level, list(_options_levels.keys()), "options_level") options_level = _options_levels[options_level] if config is None: config = {} if sections is None: sections = default_config.keys() def dump_help(fid, help_sequence): for help_line in help_sequence.split(linesep): content = "# %s" % (help_line) if help_line.strip() != "" else "" content = content + linesep fid.write(content) with open(fname, "w") as fid: for section, section_content in default_config.items(): if section not in sections: continue if section != "dataset": fid.write("%s%s" % (linesep, linesep)) fid.write("[%s]%s" % (section, linesep)) if sections_comments is not None and section in sections_comments: dump_help(fid, sections_comments[section]) for key, values in section_content.items(): if options_level < _options_levels[values["type"]]: continue if comments and values["help"].strip() != "": dump_help(fid, values["help"]) value = values["default"] if section in prefilled_values and key in prefilled_values[section]: value = prefilled_values[section][key] if section in config and key in config[section]: value = config[section][key] fid.write("%s = %s%s" % (key, value, linesep)) def _extract_nabuconfig_section(section, default_config): res = {} for key, val in default_config[section].items(): res[key] = val["default"] return res def _extract_nabuconfig_keyvals(default_config): res = {} for section in default_config.keys(): res[section] = _extract_nabuconfig_section(section, default_config) return res def get_default_nabu_config(default_config): """ Return a dictionary with the default nabu configuration. """ return _extract_nabuconfig_keyvals(default_config) def _handle_modified_key(key, val, section, default_config, renamed_keys): if val is not None: return key, val, section if key in renamed_keys and renamed_keys[key]["section"] == section: info = renamed_keys[key] print(info["message"]) print("This is deprecated since version %s and will result in an error in futures versions" % (info["since"])) section = info.get("new_section", section) if info["new_name"] == "": return None, None, section # deleted key val = default_config[section].get(info["new_name"], None) return info["new_name"], val, section else: return key, None, section # unhandled renamed/deleted key def validate_config(config, default_config, renamed_keys, errors="warn"): """ Validate a configuration dictionary against a "default" configuration dict. Parameters ---------- config: dict configuration dict to be validated default_config: dict Reference configuration. Missing keys/sections from 'config' will be updated with keys from this dictionary. errors: str, optional What to do when an unknonw key/section is encountered. 
Possible actions are: - "warn": throw a warning, continue the validation - "raise": raise an error and exit """ def error(msg): if errors == "raise": raise ValueError(msg) else: print("Error: %s" % msg) res_config = {} for section, section_content in config.items(): # Ignore the "other" section if section.lower() == "other": continue if section not in default_config: error("Unknown section [%s]" % section) continue res_config[section] = _extract_nabuconfig_section(section, default_config) res_config[section].update(section_content) for key, value in res_config[section].items(): opt = default_config[section].get(key, None) key, opt, section_updated = _handle_modified_key(key, opt, section, default_config, renamed_keys) if key is None: continue # deleted key if opt is None: error("Unknown option '%s' in section [%s]" % (key, section_updated)) continue validator = default_config[section_updated][key]["validator"] if section_updated not in res_config: # missing section - handled later continue res_config[section_updated][key] = validator(section_updated, key, value) # Handle sections missing in config for section in set(default_config.keys()) - set(res_config.keys()): res_config[section] = _extract_nabuconfig_section(section, default_config) for key, value in res_config[section].items(): validator = default_config[section][key]["validator"] res_config[section][key] = validator(section, key, value) return res_config validate_nabu_config = deprecated("validate_nabu_config is renamed validate_config", do_print=True)(validate_config) def overwrite_config(conf, overwritten_params): """ Overwrite a (validated) configuration with a new parameters dict. Parameters ---------- conf: dict Configuration dictionary, usually output from validate_config() overwritten_params: dict Configuration dictionary with the same layout, containing parameters to overwrite """ overwritten_params = overwritten_params or {} for section, params in overwritten_params.items(): if section not in conf: raise ValueError("Unknown section %s" % section) current_section = conf[section] for key in params.keys(): if key not in current_section: raise ValueError("Unknown parameter '%s' in section '%s'" % (key, section)) conf[section][key] = overwritten_params[section][key] # --- return conf ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1732264041.0 nabu-2024.2.1/nabu/pipeline/config_validators.py0000644000175000017500000003774114720040151021074 0ustar00pierrepierreimport os path = os.path from ..utils import check_supported, is_writeable from .params import * """ A validator is a function with - input: a value - output: the input value, or a modified input value - possibly raising exceptions in case of invalid value. """ # ------------------------------------------------------------------------------ # ---------------------------- Utils ------------------------------------------- # ------------------------------------------------------------------------------ def raise_error(section, key, msg=""): raise ValueError("Invalid value for %s/%s: %s" % (section, key, msg)) def validator(func): """ Common decorator for all validator functions. It modifies the signature of the decorated functions ! 
""" def wrapper(section, key, value): try: res = func(value) except AssertionError as e: raise_error(section, key, e) return res return wrapper def convert_to_int(val): val_int = 0 try: val_int = int(val) conversion_error = None except ValueError as exc: conversion_error = exc return val_int, conversion_error def convert_to_float(val): val_float = 0.0 try: val_float = float(val) conversion_error = None except ValueError as exc: conversion_error = exc return val_float, conversion_error def convert_to_bool(val): val_int, error = convert_to_int(val) res = None if not error: res = val_int > 0 else: if val.lower() in ["yes", "true", "y"]: res = True error = None if val.lower() in ["no", "false", "n"]: res = False error = None return res, error def str2bool(val): """This is an interface to convert_to_bool and it is meant to work as a class: in argparse interface the type argument can be set to float, int .. in general to a class. The argument value is then created, at parsing time, by typecasting the input string to the given class. A possibly occuring exception then trigger, in case, the display explanation provided by the argparse library. All what this methods does is simply trying to convert an argument into a bool, and return it, or generate an exception if there is a problem """ import argparse res, error = convert_to_bool(val) if error: raise argparse.ArgumentTypeError(error) else: return res def convert_to_bool_noerr(val): res, err = convert_to_bool(val) if err is not None: raise ValueError("Could not convert to boolean: %s" % str(val)) return res def name_range_checker(name, valid_names, descr, replacements=None): name = name.strip().lower() if replacements is not None and name in replacements: name = replacements[name] valid = name in valid_names assert valid, "Invalid %s '%s'. Available are %s" % (descr, name, str(valid_names)) return name # ------------------------------------------------------------------------------ # ---------------------------- Validators -------------------------------------- # ------------------------------------------------------------------------------ @validator def optional_string_validator(val): if len(val.strip()) == 0: return None return val @validator def file_name_validator(name): assert len(name) >= 1, "Name should be non-empty" return name @validator def file_location_validator(location): assert path.isfile(location), "location must be a file" return os.path.abspath(location) @validator def optional_file_location_validator(location): if len(location.strip()) > 0: assert path.isfile(location), "location must be a file" return os.path.abspath(location) return None @validator def optional_values_file_validator(location): if len(location.strip()) == 0: return None if path.splitext(location)[-1].strip() == "": # Assume path to h5 dataset. Validation is done later. 
if "://" not in location: location = "silx://" + os.path.abspath(location) else: # Assume plaintext file assert path.isfile(location), "Invalid file path" location = os.path.abspath(location) return location @validator def directory_location_validator(location): assert path.isdir(location), "location must be a directory" return os.path.abspath(location) @validator def optional_directory_location_validator(location): if len(location.strip()) > 0: assert is_writeable(location), "Directory must be writeable" return os.path.abspath(location) return None @validator def dataset_location_validator(location): if not (path.isdir(location)): assert ( path.isfile(location) and path.splitext(location)[-1].split(".")[-1].lower() in files_formats ), "Dataset location must be a directory or a HDF5 file" return os.path.abspath(location) @validator def directory_writeable_validator(location): assert is_writeable(location), "Directory must be writeable" return os.path.abspath(location) @validator def optional_output_directory_validator(location): if len(location.strip()) > 0: return directory_writeable_validator(location) return None @validator def optional_output_file_path_validator(location): if len(location.strip()) > 0: dirname, fname = path.split(location) assert os.access(dirname, os.W_OK), "Directory must be writeable" return os.path.abspath(location) return None @validator def integer_validator(val): val_int, error = convert_to_int(val) assert error is None, "number must be an integer" return val_int @validator def nonnegative_integer_validator(val): val_int, error = convert_to_int(val) assert error is None and val_int >= 0, "number must be a non-negative integer" return val_int @validator def positive_integer_validator(val): val_int, error = convert_to_int(val) assert error is None and val_int > 0, "number must be a positive integer" return val_int @validator def optional_positive_integer_validator(val): if len(val.strip()) == 0: return None val_int, error = convert_to_int(val) assert error is None and val_int > 0, "number must be a positive integer" return val_int @validator def nonzero_integer_validator(val): val_int, error = convert_to_int(val) assert error is None and val_int != 0, "number must be a non-zero integer" return val_int @validator def binning_validator(val): if val == "": val = "1" val_int, error = convert_to_int(val) assert error is None and val_int >= 0, "number must be a non-negative integer" return max(1, val_int) @validator def projections_subsampling_validator(val): val = val.strip() err_msg = "projections_subsampling: expected one positive integer or two integers in the format step:begin" if ":" not in val: val += ":0" step, begin = val.split(":") step_int, error1 = convert_to_int(step) begin_int, error2 = convert_to_int(begin) if error1 is not None or error2 is not None or step_int <= 0 or begin_int < 0: raise ValueError(err_msg) return step_int, begin_int @validator def optional_file_name_validator(val): if len(val) > 0: assert len(val) >= 1, "Name should be non-empty" assert path.basename(val) == val, "File name should not be a path (no '/')" return val return None @validator def boolean_validator(val): res, error = convert_to_bool(val) assert error is None, "Invalid boolean value" return res @validator def boolean_or_auto_validator(val): res, error = convert_to_bool(val) if error is not None: assert val.lower() == "auto", "Valid values are 0, 1 and auto" return val return res @validator def float_validator(val): val_float, error = convert_to_float(val) assert error is 
None, "Invalid number" return val_float @validator def optional_float_validator(val): if isinstance(val, float): return val elif len(val.strip()) >= 1: val_float, error = convert_to_float(val) assert error is None, "Invalid number" else: val_float = None return val_float @validator def optional_nonzero_float_validator(val): if isinstance(val, float): val_float = val elif len(val.strip()) >= 1: val_float, error = convert_to_float(val) assert error is None, "Invalid number" else: val_float = None if val_float is not None: if abs(val_float) < 1e-6: val_float = None return val_float @validator def optional_tuple_of_floats_validator(val): if len(val.strip()) == 0: return None err_msg = "Expected a tuple of two numbers, but got %s" % val try: res = tuple(float(x) for x in val.strip("()").split(",")) except Exception as exc: raise ValueError(err_msg) if len(res) != 2: raise ValueError(err_msg) return res @validator def cor_validator(val): val_float, error = convert_to_float(val) if error is None: return val_float if len(val.strip()) == 0: return None val = name_range_checker( val.lower(), set(cor_methods.values()), "center of rotation estimation method", replacements=cor_methods ) return val @validator def tilt_validator(val): val_float, error = convert_to_float(val) if error is None: return val_float if len(val.strip()) == 0: return None val = name_range_checker( val.lower(), set(tilt_methods.values()), "automatic detector tilt estimation method", replacements=tilt_methods ) return val @validator def slice_num_validator(val): val_int, error = convert_to_int(val) if error is None: return val_int else: assert val in [ "first", "middle", "last", ], "Expected start_z and end_z to be either a number or first, middle or last" return val @validator def generic_options_validator(val): if len(val.strip()) == 0: return None return val cor_options_validator = generic_options_validator @validator def cor_slice_validator(val): if len(val) == 0: return None val_int, error = convert_to_int(val) if error: supported = ["top", "first", "bottom", "last", "middle"] assert val in supported, "Invalid value, must be a number or one of %s" % supported return val else: return val_int @validator def flatfield_enabled_validator(val): return name_range_checker(val, set(flatfield_modes.values()), "flatfield mode", replacements=flatfield_modes) @validator def phase_method_validator(val): return name_range_checker( val, set(phase_retrieval_methods.values()), "phase retrieval method", replacements=phase_retrieval_methods ) @validator def detector_distortion_correction_validator(val): return name_range_checker( val, set(detector_distortion_correction_methods.values()), "detector_distortion_correction_methods", replacements=detector_distortion_correction_methods, ) @validator def unsharp_method_validator(val): return name_range_checker( val, set(unsharp_methods.values()), "unsharp mask method", replacements=phase_retrieval_methods ) @validator def padding_mode_validator(val): return name_range_checker(val, set(padding_modes.values()), "padding mode", replacements=padding_modes) @validator def reconstruction_method_validator(val): return name_range_checker( val, set(reconstruction_methods.values()), "reconstruction method", replacements=reconstruction_methods ) @validator def fbp_filter_name_validator(val): return name_range_checker( val, set(fbp_filters.values()), "FBP filter", replacements=fbp_filters, ) @validator def reconstruction_implementation_validator(val): return name_range_checker( val, 
set(reco_implementations.values()), "Reconstruction method implementation", replacements=reco_implementations, ) @validator def optimization_algorithm_name_validator(val): return name_range_checker( val, set(optim_algorithms.values()), "optimization algorithm name", replacements=iterative_methods ) @validator def output_file_format_validator(val): return name_range_checker(val, set(files_formats.values()), "output file format", replacements=files_formats) @validator def distribution_method_validator(val): val = name_range_checker( val, set(distribution_methods.values()), "workload distribution method", replacements=distribution_methods ) # TEMP. if val != "local": raise NotImplementedError("Computation method '%s' is not implemented yet" % val) # -- return val @validator def sino_normalization_validator(val): val = name_range_checker( val, set(sino_normalizations.values()), "sinogram normalization method", replacements=sino_normalizations ) return val @validator def sino_deringer_methods(val): val = name_range_checker( val, set(rings_methods.values()), "sinogram rings artefacts correction method", replacements=rings_methods, ) return val @validator def list_of_int_validator(val): ids = val.replace(",", " ").split() res = list(map(convert_to_int, ids)) err = list(filter(lambda x: x[1] is not None or x[0] < 0, res)) if err != []: raise ValueError("Could not convert to a list of GPU IDs: %s" % val) return list(set(map(lambda x: x[0], res))) @validator def list_of_shift_validator(values): ids = values.replace(" ", "").split(",") return [int(val) if val not in ("auto", "'auto'", '"auto"') else "auto" for val in ids] @validator def list_of_tomoscan_identifier(val): # TODO: insure those are valid tomoscan identifier return val @validator def resources_validator(val): val = val.strip() is_percentage = False if "%" in val: is_percentage = True val = val.replace("%", "") val_float, conversion_error = convert_to_float(val) assert conversion_error is None, str("Error while converting %s to float" % val) return (val_float, is_percentage) @validator def walltime_validator(val): # HH:mm:ss vals = val.strip().split(":") error_msg = "Invalid walltime format, expected HH:mm:ss" assert len(vals) == 3, error_msg hours, mins, secs = vals hours, err1 = convert_to_int(hours) mins, err2 = convert_to_int(mins) secs, err3 = convert_to_int(secs) assert err1 is None and err2 is None and err3 is None, error_msg err = hours < 0 or mins < 0 or mins > 59 or secs < 0 or secs > 59 assert err is False, error_msg return hours, mins, secs @validator def nonempty_string_validator(val): assert val != "", "Value cannot be empty" return val @validator def logging_validator(val): return name_range_checker(val, set(log_levels.values()), "logging level", replacements=log_levels) @validator def exclude_projections_validator(val): val = val.strip() if val == "": return None if path.isfile(val): # previous/default behavior return {"type": "indices", "file": val} if "=" not in val: raise ValueError( "exclude_projections: expected either 'angles=angles_file.txt' or 'indices=indices_file.txt' or 'angular_range=[a,b]'" ) excl_type, excl_val = val.split("=") excl_type = excl_type.strip() excl_val = excl_val.strip() check_supported(excl_type, exclude_projections_type.keys(), "exclude_projections type") if excl_type == "angular_range": def _get_range(range_val): for c in ["(", ")", "[", "]"]: range_val = range_val.replace(c, "") r_min, r_max = range_val.split(",") return (float(r_min), float(r_max)) return {"type": "angular_range", "range": 
_get_range(excl_val)} else: return {"type": excl_type, "file": excl_val} @validator def no_validator(val): return val ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1723556968.0 nabu-2024.2.1/nabu/pipeline/datadump.py0000644000175000017500000001550614656662150017212 0ustar00pierrepierrefrom os import path from ..resources.logger import LoggerOrPrint from .utils import get_subregion from .writer import WriterManager from ..io.reader import get_hdf5_dataset_shape try: import pycuda.gpuarray as garray __has_pycuda__ = True except: __has_pycuda__ = False class DataDumpManager: """ A helper class for managing data dumps, with the aim of saving/resuming the processing from a given step. """ def __init__(self, process_config, sub_region, margin=None, logger=None): """ Initialize a DataDump object. Parameters ----------- process_config: ProcessConfig ProcessConfig object sub_region: tuple of int Series of integers defining the sub-region being processed. The form is ((start_angle, end_angle), (start_z, end_z), (start_x, end_x)) margin: tuple of int, optional Margin, used when processing data, in the form ((up, down), (left, right)). Each item can be None. Using a margin means that a given chunk of data will eventually be cropped as `data[:, up:-down, left:-right]` logger: Logger, optional Logging object """ self.process_config = process_config self.processing_steps = process_config.processing_steps self.processing_options = process_config.processing_options self.dataset_info = process_config.dataset_info self._set_subregion_and_margin(sub_region, margin) self.logger = LoggerOrPrint(logger) self._configure_data_dumps() def _set_subregion_and_margin(self, sub_region, margin): self.sub_region = get_subregion(sub_region) self._z_sub_region = self.sub_region[1] self.z_min = self._z_sub_region[0] self.margin = get_subregion(margin, ndim=2) # ((U, D), (L, R)) self.margin_up = self.margin[0][0] or 0 self.start_index = self.z_min + self.margin_up self.delta_z = self._z_sub_region[-1] - self._z_sub_region[-2] self._grouped_processing = False iangle1, iangle2 = self.sub_region[0] if iangle1 != 0 or iangle2 < len(self.process_config.rotation_angles(subsampling=False)): self._grouped_processing = True self.start_index = self.sub_region[0][0] def _configure_dump(self, step_name, force_dump_to_fname=None): if force_dump_to_fname is not None: # Shortcut fname_full = force_dump_to_fname elif step_name in self.processing_steps: # Standard case if not self.processing_options[step_name].get("save", False): return fname_full = self.processing_options[step_name]["save_steps_file"] elif step_name == "sinogram" and self.process_config.dump_sinogram: # "sinogram" is a special keyword fname_full = self.process_config.dump_sinogram_file else: return # "fname_full" is the path to the final master file. # We also need to create partial files (in a sub-directory) fname, ext = path.splitext(fname_full) dirname, file_prefix = path.split(fname) self.data_dump[step_name] = WriterManager( dirname, file_prefix, file_format="hdf5", overwrite=True, start_index=self.start_index, logger=self.logger, metadata={ "process_name": step_name, "processing_index": 0, "config": { "processing_options": self.processing_options, # slow! 
"nabu_config": self.process_config.nabu_config, }, "entry": getattr(self.dataset_info.dataset_scanner, "entry", "entry"), }, ) def _configure_data_dumps(self): self.data_dump = {} for step_name in self.processing_steps: self._configure_dump(step_name) # sinogram is a special keyword: not in processing_steps, but guaranteed to be before sinogram generation if self.process_config.dump_sinogram: self._configure_dump("sinogram") def get_data_dump(self, step_name): """ Get information on where to write a given processing step. Parameters ---------- step_name: str Name of the processing step Returns ------- writer_configurator: WriterConfigurator An object with information on where to write the data for the given processing step. """ return self.data_dump.get(step_name, None) def get_read_dump_subregion(self): read_opts = self.processing_options["read_chunk"] if read_opts.get("process_file", None) is None: return None dump_start_z, dump_end_z = read_opts["dump_start_z"], read_opts["dump_end_z"] relative_start_z = self.z_min - dump_start_z relative_end_z = relative_start_z + self.delta_z # When using binning, every step after "read" results in smaller-sized data. # Therefore dumped data has shape (ceil(n_angles/subsampling), n_z//binning_z, n_x//binning_x) relative_start_z //= self.process_config.binning_z relative_end_z //= self.process_config.binning_z # (n_angles, n_z, n_x) subregion = (None, None, relative_start_z, relative_end_z, None, None) return subregion def _check_resume_from_step(self): read_opts = self.processing_options["read_chunk"] expected_radios_shape = get_hdf5_dataset_shape( read_opts["process_file"], read_opts["process_h5_path"], sub_region=self.get_read_dump_subregion(), ) # TODO check def dump_data_to_file(self, step_name, data, crop_margin=False): if step_name not in self.data_dump: return writer = self.data_dump[step_name] self.logger.info("Dumping data to %s" % writer.fname) if __has_pycuda__: if isinstance(data, garray.GPUArray): data = data.get() margin_up = self.margin[0][0] or None margin_down = self.margin[0][1] or None margin_down = -margin_down if margin_down is not None else None # pylint: disable=E1130 if crop_margin and (margin_up is not None or margin_down is not None): data = data[:, margin_up:margin_down, :] metadata = {"dump_sub_region": {"sub_region": self.sub_region, "margin": self.margin}} writer.write_data(data, metadata=metadata) def __repr__(self): res = "%s(%s, margin=%s)" % (self.__class__.__name__, str(self.sub_region), str(self.margin)) if len(self.data_dump) > 0: for step_name, writer_configurator in self.data_dump.items(): res += "\n- Dump %s to %s" % (step_name, writer_configurator.fname) return res ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1713526109.0 nabu-2024.2.1/nabu/pipeline/dataset_validator.py0000644000175000017500000002205714610452535021076 0ustar00pierrepierreimport os from ..resources.logger import LoggerOrPrint from ..utils import copy_dict_items from ..reconstruction.sinogram import get_extended_sinogram_width class DatasetValidatorBase: # this in the helical derived class will be False _check_also_z = True def __init__(self, nabu_config, dataset_info, logger=None): """ Perform a coupled validation of nabu configuration against dataset information. Check the consistency of these two structures, and modify them in-place. Parameters ---------- nabu_config: dict Dictionary containing the nabu configuration, usually got from `nabu.pipeline.config.validate_config()` It will be modified ! 
dataset_info: `DatasetAnalyzer` instance Structure containing information on the dataset to process. It will be modified ! """ self.nabu_config = nabu_config self.dataset_info = dataset_info self.logger = LoggerOrPrint(logger) self.rec_params = copy_dict_items(self.nabu_config["reconstruction"], self.nabu_config["reconstruction"].keys()) self._validate() def _validate(self): raise ValueError("Base class") @property def is_halftomo(self): do_halftomo = self.nabu_config["reconstruction"].get("enable_halftomo", False) if do_halftomo == "auto": do_halftomo = self.dataset_info.is_halftomo if do_halftomo is None: raise ValueError( "'enable_halftomo' was set to 'auto' but unable to get the information on field of view" ) return do_halftomo def _check_not_empty(self): if len(self.dataset_info.projections) == 0: msg = "Dataset seems to be empty (no projections)" self.logger.fatal(msg) raise ValueError(msg) if self.dataset_info.n_angles is None: msg = "Could not determine the number of projections. Please check the .info or HDF5 file" self.logger.fatal(msg) raise ValueError(msg) for dim_name, n in zip(["dim_1", "dim_2"], self.dataset_info.radio_dims): if n is None: msg = "Could not determine %s. Please check the .info file or HDF5 file" % dim_name self.logger.fatal(msg) raise ValueError(msg) @staticmethod def _convert_negative_idx(idx, last_idx): res = idx if idx < 0: res = last_idx + idx return res def _get_nx_ny(self, binning_factor=1): nx = self.dataset_info.radio_dims[0] // binning_factor if self.is_halftomo: cor = self._get_cor(binning_factor=binning_factor) nx = get_extended_sinogram_width(nx, cor) ny = nx return nx, ny def _get_cor(self, binning_factor=1): cor = self.dataset_info.axis_position if binning_factor >= 1: # Backprojector uses middle of pixel for coordinate indices. # This means that the leftmost edge of the leftmost pixel has coordinate -0.5. # When using binning with a factor 'b', the CoR has to adapted as # cor_binned = (cor + 0.5)/b - 0.5 cor = (cor + 0.5) / binning_factor - 0.5 return cor def _convert_negative_indices(self): """ Convert any negative index to the corresponding positive index. """ nx, nz = self.dataset_info.radio_dims ny = nx if self.is_halftomo: if self.dataset_info.axis_position is None: raise ValueError( "Cannot use rotation axis position in the middle of the detector when half tomo is enabled" ) nx, ny = self._get_nx_ny() what = ( ("start_x", nx), ("end_x", nx), ("start_y", ny), ("end_y", ny), ) if self._check_also_z: what = what + ( ("start_z", nz), ("end_z", nz), ) for key, upper_bound in what: val = self.rec_params[key] if isinstance(val, str): idx_mapping = { "first": 0, "middle": upper_bound // 2, # works on both start_ and end_ since the end_ index is included "last": upper_bound - 1, # upper bound is included in the user interface (contrarily to python) } res = idx_mapping[val] else: res = self._convert_negative_idx(self.rec_params[key], upper_bound) self.rec_params[key] = res self.rec_region = copy_dict_items(self.rec_params, [w[0] for w in what]) def _get_output_filename(self): # This function modifies nabu_config ! 
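        # Illustrative example (hypothetical paths): for a dataset located at "/data/sample/scan0001.h5",
        # an empty output location defaults to "/data/sample", and an empty file_prefix becomes
        # "scan0001_rec", so that the reconstruction file never overwrites the original dataset.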
opts = self.nabu_config["output"] dataset_path = self.nabu_config["dataset"]["location"] if opts["location"] == "" or opts["location"] is None: opts["location"] = os.path.dirname(dataset_path) if opts["file_prefix"] == "" or opts["file_prefix"] is None: if os.path.isfile(dataset_path): # hdf5 file_prefix = os.path.basename(dataset_path).split(".")[0] elif os.path.isdir(dataset_path): file_prefix = os.path.basename(dataset_path) else: raise ValueError("dataset location %s is neither a file or directory" % dataset_path) file_prefix += "_rec" # avoid overwriting dataset opts["file_prefix"] = file_prefix @staticmethod def _check_start_end_idx(start, end, n_elements, start_name="start_x", end_name="end_x"): assert start >= 0 and start < n_elements, "Invalid value %d for %s, must be >= 0 and < %d" % ( start, start_name, n_elements, ) assert end >= 0 and end < n_elements, "Invalid value for %d %s, must be >= 0 and < %d" % ( end, end_name, n_elements, ) assert start <= end, "Must have %s <= %s" % (start_name, end_name) def _handle_binning(self): """ Modify the dataset description/process config to handle binning and projections subsampling """ dataset_cfg = self.nabu_config["dataset"] self.binning = (dataset_cfg["binning"], dataset_cfg["binning_z"]) subsampling_factor, subsampling_start = dataset_cfg["projections_subsampling"] self.subsampling_factor = subsampling_factor or 1 self.subsampling_start = subsampling_start or 0 if self.binning != (1, 1): bin_x, bin_z = self.binning rec_cfg = self.rec_params # Update "start_xyz" rec_cfg["start_x"] //= bin_x rec_cfg["start_y"] //= bin_x rec_cfg["start_z"] //= bin_z # Update "end_xyz". Things are a little bit more complicated for several reasons: # - In the user interface (configuration file), end_xyz index is INCLUDED (contrarily to python). So there are +1, -1 all over the place. # - When using half tomography, n_x and n_y are less straightforward : 2*CoR(binning) instead of 2*CoR//binning # - delta = end - start [+1] should be a multiple of binning factor. This makes things much easier for processing pipeline. def ensure_multiple_of_binning(end, start, binning_factor): """ Update "end" so that end-start is a multiple of "binning_factor" Note that "end" is INCLUDED here (comes from user configuration) """ return end - ((end - start + 1) % binning_factor) end_z = ensure_multiple_of_binning(rec_cfg["end_z"], rec_cfg["start_z"], bin_z) rec_cfg["end_z"] = (end_z + 1) // bin_z - 1 nx_binned, ny_binned = self._get_nx_ny(binning_factor=bin_x) end_y = ensure_multiple_of_binning(rec_cfg["end_y"], rec_cfg["start_y"], bin_x) rec_cfg["end_y"] = min((end_y + 1) // bin_x - 1, ny_binned - 1) end_x = ensure_multiple_of_binning(rec_cfg["end_x"], rec_cfg["start_x"], bin_x) rec_cfg["end_x"] = min((end_x + 1) // bin_x - 1, nx_binned - 1) def _check_output_file(self): out_cfg = self.nabu_config["output"] out_fname = os.path.join(out_cfg["location"], out_cfg["file_prefix"] + out_cfg["file_format"]) if os.path.exists(out_fname): raise ValueError("File %s already exists" % out_fname) def _handle_processing_mode(self): mode = self.nabu_config["resources"]["method"] if mode == "preview": print( "Warning: the method 'preview' was selected. This means that the data volume will be binned so that everything fits in memory." 
) # TODO automatically compute binning/subsampling factors as a function of lowest memory (GPU) self.nabu_config["dataset"]["binning"] = 2 self.nabu_config["dataset"]["binning_z"] = 2 self.nabu_config["dataset"]["projections_subsampling"] = 2, 0 # TODO handle other modes ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/pipeline/detector_distortion_provider.py0000644000175000017500000000163014402565210023371 0ustar00pierrepierrefrom ..resources.utils import extract_parameters from ..io.detector_distortion import DetectorDistortionBase, DetectorDistortionMapsXZ import silx.io def DetectorDistortionProvider(detector_full_shape_vh=(0, 0), correction_type="", options=""): if correction_type == "identity": return DetectorDistortionBase(detector_full_shape_vh=detector_full_shape_vh) elif correction_type == "map_xz": options = options.replace("path=", "path_eq") user_params = extract_parameters(options) print(user_params, options) map_x = silx.io.get_data(user_params["map_x"].replace("path_eq", "path=")) map_z = silx.io.get_data(user_params["map_z"].replace("path_eq", "path=")) return DetectorDistortionMapsXZ(map_x=map_x, map_z=map_z) else: message = f""" Unknown correction type: {correction_type} requested """ raise ValueError(message) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734019212.0 nabu-2024.2.1/nabu/pipeline/estimators.py0000644000175000017500000011655214726604214017603 0ustar00pierrepierre""" nabu.pipeline.estimators: helper classes/functions to estimate parameters of a dataset (center of rotation, detector tilt, etc). """ import inspect import numpy as np import scipy.fft # pylint: disable=E0611 from silx.io import get_data import math from scipy import ndimage as nd from ..preproc.flatfield import FlatField from ..estimation.cor import ( CenterOfRotation, CenterOfRotationAdaptiveSearch, CenterOfRotationSlidingWindow, CenterOfRotationGrowingWindow, CenterOfRotationOctaveAccurate, ) from ..estimation.cor_sino import SinoCorInterface, CenterOfRotationFourierAngles, CenterOfRotationVo from ..estimation.tilt import CameraTilt from ..estimation.utils import is_fullturn_scan from ..resources.logger import LoggerOrPrint from ..resources.utils import extract_parameters from ..utils import check_supported, deprecation_warning, get_num_threads, is_int, is_scalar from ..resources.dataset_analyzer import get_radio_pair from ..processing.rotation import Rotation from ..preproc.ccd import Log, CCDFilter from ..misc import fourier_filters from .params import cor_methods, tilt_methods def estimate_cor(method, dataset_info, do_flatfield=True, cor_options=None, logger=None): """ High level function to compute the center of rotation (COR) Parameters ---------- method: name of the method to be used for computing the center of rotation dataset_info: `nabu.resources.dataset_analyzer.DatasetAnalyzer` Dataset information structure do_flatfield: If True apply flat field to compute the center of rotation cor_options: optional dictionary that can contain the following keys: * slice_idx: index of the slice to use for computing the sinogram (for sinogram based algorithms) * subsampling subsampling * radio_angles: angles of the radios to use (for radio based algorithms) logger: logging object """ logger = LoggerOrPrint(logger) cor_options = cor_options or {} check_supported(method, list(cor_methods.keys()), "COR estimation method") method = cor_methods[method] # Extract CoR parameters from configuration file if 
isinstance(cor_options, str): try: cor_options = extract_parameters(cor_options, sep=";") except Exception as exc: msg = "Could not extract parameters from cor_options: %s" % (str(exc)) logger.fatal(msg) raise ValueError(msg) elif isinstance(cor_options, dict): pass else: raise TypeError(f"cor_options_str is expected to be a dict or a str. {type(cor_options)} provided") # Dispatch. COR estimation is always expressed in absolute number of pixels (i.e. from the center of the first pixel column) if method in CORFinder.search_methods: cor_finder = CORFinder( method, dataset_info, do_flatfield=do_flatfield, cor_options=cor_options, radio_angles=cor_options.get("radio_angles", (0.0, np.pi)), logger=logger, ) estimated_cor = cor_finder.find_cor() elif method in SinoCORFinder.search_methods: cor_finder = SinoCORFinder( method, dataset_info, slice_idx=cor_options.get("slice_idx", "middle"), subsampling=cor_options.get("subsampling", 10), do_flatfield=do_flatfield, take_log=cor_options.get("take_log", True), cor_options=cor_options, logger=logger, ) estimated_cor = cor_finder.find_cor() else: composite_options = update_func_kwargs(CompositeCORFinder, cor_options) for what in ["cor_options", "logger"]: composite_options.pop(what, None) cor_finder = CompositeCORFinder( dataset_info, cor_options=cor_options, logger=logger, **composite_options, ) estimated_cor = cor_finder.find_cor() return estimated_cor class CORFinderBase: """ A base class for CoR estimators. It does common tasks like data reading, flatfield, etc. """ search_methods = {} def __init__(self, method, dataset_info, do_flatfield=True, cor_options=None, logger=None): """ Initialize a CORFinder object. Parameters ---------- dataset_info: `nabu.resources.dataset_analyzer.DatasetAnalyzer` Dataset information structure """ check_supported(method, self.search_methods, "CoR estimation method") self.method = method self.cor_options = cor_options or {} self.logger = LoggerOrPrint(logger) self.dataset_info = dataset_info self.do_flatfield = do_flatfield self.shape = dataset_info.radio_dims[::-1] self._get_lookup_side() self._init_cor_finder() def _get_lookup_side(self): """ Get the "initial guess" where the center-of-rotation (CoR) should be estimated. For example 'center' means that CoR search will be done near the middle of the detector, i.e center column. """ lookup_side = self.cor_options.get("side", None) self._lookup_side = lookup_side # User-provided scalar if not (isinstance(lookup_side, str)) and np.isscalar(lookup_side): return default_lookup_side = "right" if self.dataset_info.is_halftomo else "center" # By default in nabu config, side='from_file' meaning that we inspect the dataset information for CoR metadata if lookup_side == "from_file": initial_cor_pos = self.dataset_info.dataset_scanner.x_rotation_axis_pixel_position # relative pos in pixels if initial_cor_pos is None or initial_cor_pos == 0: self.logger.warning("Could not get an initial estimate for center of rotation in data file") lookup_side = default_lookup_side else: lookup_side = initial_cor_pos self._lookup_side = initial_cor_pos def _init_cor_finder(self): cor_finder_cls = self.search_methods[self.method]["class"] self.cor_finder = cor_finder_cls(verbose=False, logger=self.logger, extra_options=None) class CORFinder(CORFinderBase): """ Find the Center of Rotation with methods based on two (180-degrees opposed) radios. 
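For context, a minimal sketch of how this dispatcher might be called; `dataset_info` is assumed to have been built beforehand with nabu's dataset analyzer, and the chosen method and options are only illustrative.

```python
# Hypothetical usage, assuming `dataset_info` was obtained beforehand
# (e.g. with nabu.resources.dataset_analyzer):
cor = estimate_cor(
    "sliding-window",
    dataset_info,
    do_flatfield=True,
    cor_options={"side": "center"},
)
print("Estimated center of rotation (absolute pixels):", cor)
```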
""" search_methods = { "centered": { "class": CenterOfRotation, }, "global": { "class": CenterOfRotationAdaptiveSearch, "default_kwargs": {"low_pass": 1, "high_pass": 20}, }, "sliding-window": { "class": CenterOfRotationSlidingWindow, }, "growing-window": { "class": CenterOfRotationGrowingWindow, }, "octave-accurate": { "class": CenterOfRotationOctaveAccurate, }, } def __init__( self, method, dataset_info, do_flatfield=True, cor_options=None, logger=None, radio_angles=(0.0, np.pi) ): """ Initialize a CORFinder object. Parameters ---------- dataset_info: `nabu.resources.dataset_analyzer.DatasetAnalyzer` Dataset information structure radio_angles: angles to use to find the cor """ super().__init__(method, dataset_info, do_flatfield=do_flatfield, cor_options=cor_options, logger=logger) self._radio_angles = radio_angles self._init_radios() self._apply_flatfield() self._apply_tilt() # octave-accurate does not support half-acquisition scans, # but information on field of view is only known here with the "dataset_info" object. # Do the check here. if self.dataset_info.is_halftomo and method == "octave-accurate": raise ValueError("The CoR estimator 'octave-accurate' does not support half-acquisition scans") # def _init_radios(self): self.radios, self._radios_indices = get_radio_pair( self.dataset_info, radio_angles=self._radio_angles, return_indices=True ) def _apply_flatfield(self): if not (self.do_flatfield): return self.flatfield = FlatField( self.radios.shape, flats=self.dataset_info.flats, darks=self.dataset_info.darks, radios_indices=self._radios_indices, interpolation="linear", ) self.flatfield.normalize_radios(self.radios) def _apply_tilt(self): tilt = self.dataset_info.detector_tilt if tilt is None: return self.logger.debug("COREstimator: applying detector tilt correction of %f degrees" % tilt) rot = Rotation(self.shape, tilt) for i in range(self.radios.shape[0]): self.radios[i] = rot.rotate(self.radios[i]) def find_cor(self): """ Find the center of rotation. Returns ------- cor: float The estimated center of rotation for the current dataset. """ self.logger.info("Estimating center of rotation") # All find_shift() methods in self.search_methods have the same API with "img_1" and "img_2" cor_exec_kwargs = update_func_kwargs(self.cor_finder.find_shift, self.cor_options) cor_exec_kwargs["return_relative_to_middle"] = False # ----- FIXME ----- # 'self.cor_options' can contain 'side="from_file"', and we should not modify it directly # because it's entered by the user. # Either make a copy of self.cor_options, or change the inspect() mechanism if cor_exec_kwargs.get("side", None) == "from_file": cor_exec_kwargs["side"] = self._lookup_side or "center" # ------ if self._lookup_side is not None: cor_exec_kwargs["side"] = self._lookup_side self.logger.debug("%s.find_shift(%s)" % (self.cor_finder.__class__.__name__, str(cor_exec_kwargs))) shift = self.cor_finder.find_shift(self.radios[0], np.fliplr(self.radios[1]), **cor_exec_kwargs) return shift # alias COREstimator = CORFinder class SinoCORFinder(CORFinderBase): """ A class for finding Center of Rotation based on 360 degrees sinograms. This class handles the steps of building the sinogram from raw radios. 
""" search_methods = { "sino-coarse-to-fine": { "class": SinoCorInterface, }, "sino-sliding-window": { "class": CenterOfRotationSlidingWindow, }, "sino-growing-window": { "class": CenterOfRotationGrowingWindow, }, "fourier-angles": {"class": CenterOfRotationFourierAngles}, "vo": { "class": CenterOfRotationVo, }, } def __init__( self, method, dataset_info, do_flatfield=True, take_log=True, cor_options=None, logger=None, slice_idx="middle", subsampling=10, ): """ Initialize a SinoCORFinder object. Other parameters ---------------- The following keys can be set in cor_options. slice_idx: int or str Which slice index to take for building the sinogram. For example slice_idx=0 means that we extract the first line of each projection. Value can also be "first", "top", "middle", "last", "bottom". subsampling: int, float subsampling strategy when building sinograms. As building the complete sinogram from raw projections might be tedious, the reading is done with subsampling. A positive integer value means the subsampling step (i.e `projections[::subsampling]`). """ super().__init__(method, dataset_info, do_flatfield=do_flatfield, cor_options=cor_options, logger=logger) self._set_slice_idx(slice_idx) self._set_subsampling(subsampling) self._load_raw_sinogram() self._flatfield(do_flatfield) self._get_sinogram(take_log) def _check_360(self): if not is_fullturn_scan(self.dataset_info.rotation_angles): raise ValueError("Sinogram-based Center of Rotation estimation can only be used for 360 degrees scans") def _set_slice_idx(self, slice_idx): n_z = self.dataset_info.radio_dims[1] if isinstance(slice_idx, str): str_to_idx = {"top": 0, "first": 0, "middle": n_z // 2, "bottom": n_z - 1, "last": n_z - 1} check_supported(slice_idx, str_to_idx.keys(), "slice location") slice_idx = str_to_idx[slice_idx] self.slice_idx = slice_idx def _set_subsampling(self, subsampling): projs_idx = sorted(self.dataset_info.projections.keys()) self.subsampling = None if is_int(subsampling): if subsampling < 0: # Total number of angles raise NotImplementedError else: self.projs_indices = projs_idx[::subsampling] self.angles = self.dataset_info.rotation_angles[::subsampling] self.subsampling = subsampling else: # Angular step raise NotImplementedError() def _load_raw_sinogram(self): if self.slice_idx is None: raise ValueError("Unknow slice index") reader_kwargs = { "sub_region": (slice(None, None, self.subsampling), slice(self.slice_idx, self.slice_idx + 1), slice(None)) } if self.dataset_info.kind == "edf": reader_kwargs = {"n_reading_threads": get_num_threads()} self.data_reader = self.dataset_info.get_reader(**reader_kwargs) self._radios = self.data_reader.load_data() def _flatfield(self, do_flatfield): self.do_flatfield = bool(do_flatfield) if not self.do_flatfield: return flats = {k: arr[self.slice_idx : self.slice_idx + 1, :] for k, arr in self.dataset_info.flats.items()} darks = {k: arr[self.slice_idx : self.slice_idx + 1, :] for k, arr in self.dataset_info.darks.items()} flatfield = FlatField( self._radios.shape, flats, darks, radios_indices=self.projs_indices, ) flatfield.normalize_radios(self._radios) def _get_sinogram(self, take_log): sinogram = self._radios[:, 0, :].copy() if take_log: log = Log(self._radios.shape, clip_min=1e-6, clip_max=10.0) log.take_logarithm(sinogram) self.sinogram = sinogram @staticmethod def _split_sinogram(sinogram): n_a_2 = sinogram.shape[0] // 2 img_1, img_2 = sinogram[:n_a_2], sinogram[n_a_2:] # "Handle" odd number of projections if img_2.shape[0] > img_1.shape[0]: img_2 = img_2[:-1, :] # 
return img_1, img_2 def find_cor(self): self.logger.info("Estimating center of rotation") cor_exec_kwargs = update_func_kwargs(self.cor_finder.find_shift, self.cor_options) cor_exec_kwargs["return_relative_to_middle"] = False # FIXME # 'self.cor_options' can contain 'side="from_file"', and we should not modify it directly # because it's entered by the user. # Either make a copy of self.cor_options, or change the inspect() mechanism if cor_exec_kwargs["side"] == "from_file": cor_exec_kwargs["side"] = self._lookup_side or "center" # if self._lookup_side is not None: cor_exec_kwargs["side"] = self._lookup_side if self.method == "fourier-angles": cor_exec_args = [self.sinogram] cor_exec_kwargs["angles"] = self.dataset_info.rotation_angles elif self.method == "vo": cor_exec_args = [self.sinogram] cor_exec_kwargs["halftomo"] = self.dataset_info.is_halftomo cor_exec_kwargs["is_360"] = is_fullturn_scan(self.dataset_info.rotation_angles) else: # For these methods relying on find_shift() with two images, the sinogram needs to be split in two img_1, img_2 = self._split_sinogram(self.sinogram) cor_exec_args = [img_1, np.fliplr(img_2)] self.logger.debug("%s.find_shift(%s)" % (self.cor_finder.__class__.__name__, str(cor_exec_kwargs))) shift = self.cor_finder.find_shift(*cor_exec_args, **cor_exec_kwargs) return shift # alias SinoCOREstimator = SinoCORFinder class CompositeCORFinder(CORFinderBase): """ Class and method to prepare the sinogram and calculate the COR. The pseudo sinogram is built with shrunk radios taken every theta_interval degrees. Compared to the first implementation by Christian Nemoz: - gives the same result as the original octave script on the datasets tested so far - The meaning of parameter n_subsampling_y (alias subsampling_y) is now the number of lines which are taken from every radio. This is more meaningful in terms of amount of collected information because it does not depend on the radio size. Moreover this is what was done in the octave script - The spike_threshold has been added with default to 0.04 - The angular sampling is every 5 degrees by default, as is now also the case in the octave script - The optimal overlap is found by looping over the possible overlaps. After a first testing phase, this part, which is the time-consuming part, can be accelerated by several orders of magnitude without modifying the final result """ search_methods = { "composite-coarse-to-fine": { "class": CenterOfRotation, # Hack. Not used. Everything is done in the find_cor() func. } } _default_cor_options = {"low_pass": 0.4, "high_pass": 10, "side": "near", "near_pos": 0, "near_width": 40} def __init__( self, dataset_info, oversampling=4, theta_interval=5, n_subsampling_y=40, take_log=True, cor_options=None, spike_threshold=0.04, logger=None, norm_order=1, ): super().__init__( "composite-coarse-to-fine", dataset_info, do_flatfield=True, cor_options=cor_options, logger=logger ) if norm_order not in [1, 2]: raise ValueError( f""" the norm order (norm_order parameter) must be either 1 or 2.
You passed {norm_order} """ ) self.norm_order = norm_order self.dataset_info = dataset_info self.logger = LoggerOrPrint(logger) self.sx, self.sy = self.dataset_info.radio_dims default_cor_options = self._default_cor_options.copy() default_cor_options.update(self.cor_options) self.cor_options = default_cor_options # the algorithm can work for angular ranges larger than 1.2*pi # up to an arbitrarily number of turns as it is the case in helical scans self.spike_threshold = spike_threshold # the following line is necessary for multi-turns scan because the encoders is always # in the interval 0-360 self.unwrapped_rotation_angles = np.unwrap(self.dataset_info.rotation_angles) self.angle_min = self.unwrapped_rotation_angles.min() self.angle_max = self.unwrapped_rotation_angles.max() if (self.angle_max - self.angle_min) < 1.2 * np.pi: useful_span = None raise ValueError( f"""Sinogram-based Center of Rotation estimation can only be used for scans over more than 180 degrees. Your angular span was barely above 180 degrees, it was in fact {((self.angle_max - self.angle_min)/np.pi):.2f} x 180 and it is not considered to be enough by the discriminating condition which requires at least 1.2 half-turns """ ) else: useful_span = min(np.pi, (self.angle_max - self.angle_min) - np.pi) # readapt theta_interval accordingly if the span is smaller than pi if useful_span < np.pi: theta_interval = theta_interval * useful_span / np.pi self.take_log = take_log self.ovs = oversampling self.theta_interval = theta_interval target_sampling_y = np.round(np.linspace(0, self.sy - 1, n_subsampling_y + 2)).astype(int)[1:-1] if self.spike_threshold is not None: # take also one line below and on above for each line # to provide appropriate margin self.sampling_y = np.zeros([3 * len(target_sampling_y)], "i") self.sampling_y[0::3] = np.maximum(0, target_sampling_y - 1) self.sampling_y[2::3] = np.minimum(self.sy - 1, target_sampling_y + 1) self.sampling_y[1::3] = target_sampling_y self.ccd_correction = CCDFilter((len(self.sampling_y), self.sx), median_clip_thresh=self.spike_threshold) else: self.sampling_y = target_sampling_y self.nproj = self.dataset_info.n_angles my_condition = np.less(self.unwrapped_rotation_angles + np.pi, self.angle_max) * np.less( self.unwrapped_rotation_angles, self.angle_min + useful_span ) possibly_probed_angles = self.unwrapped_rotation_angles[my_condition] possibly_probed_indices = np.arange(len(self.unwrapped_rotation_angles))[my_condition] self.dproj = round(len(possibly_probed_angles) / np.rad2deg(useful_span) * self.theta_interval) self.probed_angles = possibly_probed_angles[:: self.dproj] self.probed_indices = possibly_probed_indices[:: self.dproj] self.absolute_indices = sorted(self.dataset_info.projections.keys()) my_flats = self.dataset_info.flats if my_flats is not None and len(list(my_flats.keys())): self.use_flat = True self.flatfield = FlatField( (len(self.absolute_indices), self.sy, self.sx), self.dataset_info.flats, self.dataset_info.darks, radios_indices=self.absolute_indices, ) else: self.use_flat = False self.sx, self.sy = self.dataset_info.radio_dims self.mlog = Log((1,) + (self.sy, self.sx), clip_min=1e-6, clip_max=10.0) self.rcor_abs = round(self.sx / 2.0) self.cor_acc = round(self.sx / 2.0) self.nprobed = len(self.probed_angles) # initialize sinograms and radios arrays self.sino = np.zeros([2 * self.nprobed * n_subsampling_y, (self.sx - 1) * self.ovs + 1], "f") self._loaded = False self.high_pass = self.cor_options["high_pass"] img_filter = fourier_filters.get_bandpass_filter( 
(self.sino.shape[0] // 2, self.sino.shape[1]), cutoff_lowpass=self.cor_options["low_pass"] * self.ovs, cutoff_highpass=self.high_pass * self.ovs, use_rfft=False, # rfft changes the image dimensions lenghts to even if odd data_type=np.float64, ) # we are interested in filtering only along the x dimension only img_filter[:] = img_filter[0] self.img_filter = img_filter def _oversample(self, radio): """oversampling in the horizontal direction""" if self.ovs == 1: return radio else: ovs_2D = [1, self.ovs] return oversample(radio, ovs_2D) def _get_cor_options(self, cor_options): default_dict = self._default_cor_options.copy() if self.dataset_info.is_halftomo: default_dict["side"] = "right" if cor_options is None or cor_options == "": cor_options = {} if isinstance(cor_options, str): try: cor_options = extract_parameters(cor_options, sep=";") except Exception as exc: msg = "Could not extract parameters from cor_options: %s" % (str(exc)) self.logger.fatal(msg) raise ValueError(msg) default_dict.update(cor_options) cor_options = default_dict self.cor_options = cor_options def get_radio(self, image_num): # radio_dataset_idx = self.absolute_indices[image_num] radio_dataset_idx = image_num data_url = self.dataset_info.projections[radio_dataset_idx] radio = get_data(data_url).astype(np.float64) if self.use_flat: self.flatfield.normalize_single_radio(radio, radio_dataset_idx, dtype=radio.dtype) if self.take_log: self.mlog.take_logarithm(radio) radio = radio[self.sampling_y] if self.spike_threshold is not None: self.ccd_correction.median_clip_correction(radio, output=radio) radio = radio[1::3] return radio def get_sino(self, reload=False): """ Build sinogram (composite image) from the radio files """ if self._loaded and not reload: return self.sino sorting_indexes = np.argsort(self.unwrapped_rotation_angles) sorted_all_angles = self.unwrapped_rotation_angles[sorting_indexes] sorted_angle_indexes = np.arange(len(self.unwrapped_rotation_angles))[sorting_indexes] irad = 0 for prob_a, prob_i in zip(self.probed_angles, self.probed_indices): radio1 = self.get_radio(self.absolute_indices[prob_i]) other_angle = prob_a + np.pi insertion_point = np.searchsorted(sorted_all_angles, other_angle) if insertion_point > 0 and insertion_point < len(sorted_all_angles): other_i_l = sorted_angle_indexes[insertion_point - 1] other_i_h = sorted_angle_indexes[insertion_point] radio_l = self.get_radio(self.absolute_indices[other_i_l]) radio_h = self.get_radio(self.absolute_indices[other_i_h]) f = (other_angle - sorted_all_angles[insertion_point - 1]) / ( sorted_all_angles[insertion_point] - sorted_all_angles[insertion_point - 1] ) radio2 = (1 - f) * radio_l + f * radio_h else: if insertion_point == 0: other_i = sorted_angle_indexes[0] elif insertion_point == len(sorted_all_angles): other_i = sorted_angle_indexes[insertion_point - 1] radio2 = self.get_radio(self.absolute_indices[other_i]) # pylint: disable=E0606 self.sino[irad : irad + radio1.shape[0], :] = self._oversample(radio1) self.sino[ irad + self.nprobed * radio1.shape[0] : irad + self.nprobed * radio1.shape[0] + radio1.shape[0], : ] = self._oversample(radio2) irad = irad + radio1.shape[0] self.sino[np.isnan(self.sino)] = 0.0001 # ? 
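The interpolation weight `f` computed in `get_sino` above is an ordinary linear interpolation between the two measured angles bracketing θ + π; a small numeric check (angles are hypothetical):

```python
import numpy as np

theta_lo, theta_hi = np.deg2rad(184.0), np.deg2rad(186.0)   # measured neighbours
other_angle = np.deg2rad(185.2)                              # theta + pi, not acquired exactly
f = (other_angle - theta_lo) / (theta_hi - theta_lo)         # 0.6
# radio2 = (1 - f) * radio_l + f * radio_h
```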
return self.sino def find_cor(self, reload=False): self.logger.info("Estimating center of rotation") self.logger.debug("%s.find_shift(%s)" % (self.__class__.__name__, self.cor_options)) self.sinogram = self.get_sino(reload=reload) dim_v, dim_h = self.sinogram.shape assert dim_v % 2 == 0, " this should not happen " dim_v = dim_v // 2 radio1 = self.sinogram[:dim_v] radio2 = self.sinogram[dim_v:] orig_sy, orig_ovsd_sx = radio1.shape radio1 = scipy.fft.ifftn( scipy.fft.fftn(radio1, axes=(-2, -1)) * self.img_filter, axes=(-2, -1) ).real # TODO: convolute only along x radio2 = scipy.fft.ifftn( scipy.fft.fftn(radio2, axes=(-2, -1)) * self.img_filter, axes=(-2, -1) ).real # TODO: convolute only along x tmp_sy, ovsd_sx = radio1.shape assert orig_sy == tmp_sy and orig_ovsd_sx == ovsd_sx, "this should not happen" cor_side = self.cor_options["side"] if cor_side == "center": overlap_min = max(round(ovsd_sx - ovsd_sx / 3), 4) overlap_max = min(round(ovsd_sx + ovsd_sx / 3), 2 * ovsd_sx - 4) elif cor_side == "right": overlap_min = max(4, self.ovs * self.high_pass * 3) overlap_max = ovsd_sx elif cor_side == "left": overlap_min = ovsd_sx overlap_max = min(2 * ovsd_sx - 4, 2 * ovsd_sx - self.ovs * self.ovs * self.high_pass * 3) elif cor_side == "all": overlap_min = max(4, self.ovs * self.high_pass * 3) overlap_max = min(2 * ovsd_sx - 4, 2 * ovsd_sx - self.ovs * self.ovs * self.high_pass * 3) elif is_scalar(cor_side): near_pos = cor_side near_width = self.cor_options["near_width"] overlap_min = max(4, ovsd_sx - 2 * self.ovs * (near_pos + near_width)) overlap_max = min(2 * ovsd_sx - 4, ovsd_sx - 2 * self.ovs * (near_pos - near_width)) # COMPAT. elif cor_side == "near": deprecation_warning( "using side='near' is deprecated, use side= instead", do_print=True, func_name="composite_near_pos", ) near_pos = self.cor_options["near_pos"] near_width = self.cor_options["near_width"] overlap_min = max(4, ovsd_sx - 2 * self.ovs * (near_pos + near_width)) overlap_max = min(2 * ovsd_sx - 4, ovsd_sx - 2 * self.ovs * (near_pos - near_width)) # --- else: raise ValueError("Invalid option 'side=%s'" % self.cor_options["side"]) if overlap_min > overlap_max: message = f""" There is no safe search range in find_cor once the margins corresponding to the high_pass filter are discarded. 
Try reducing the low_pass parameter in cor_options """ raise ValueError(message) self.logger.info( "looking for overlap from min %.2f and max %.2f\n" % (overlap_min / self.ovs, overlap_max / self.ovs) ) best_overlap = overlap_min best_error = np.inf blurred_radio1 = nd.gaussian_filter(abs(radio1), [0, self.high_pass]) blurred_radio2 = nd.gaussian_filter(abs(radio2), [0, self.high_pass]) for z in range(int(overlap_min), int(overlap_max) + 1): if z <= ovsd_sx: my_z = z my_radio1 = radio1 my_radio2 = radio2 my_blurred_radio1 = blurred_radio1 my_blurred_radio2 = blurred_radio2 else: my_z = ovsd_sx - (z - ovsd_sx) my_radio1 = np.fliplr(radio1) my_radio2 = np.fliplr(radio2) my_blurred_radio1 = np.fliplr(blurred_radio1) my_blurred_radio2 = np.fliplr(blurred_radio2) common_left = np.fliplr(my_radio1[:, ovsd_sx - my_z :])[:, : -int(math.ceil(self.ovs * self.high_pass * 2))] # adopt a 'safe' margin considering high_pass value (possibly float) common_right = my_radio2[:, ovsd_sx - my_z : -int(math.ceil(self.ovs * self.high_pass * 2))] common_blurred_left = np.fliplr(my_blurred_radio1[:, ovsd_sx - my_z :])[ :, : -int(math.ceil(self.ovs * self.high_pass * 2)) ] # adopt a 'safe' margin considering high_pass value (possibly float) common_blurred_right = my_blurred_radio2[:, ovsd_sx - my_z : -int(math.ceil(self.ovs * self.high_pass * 2))] if common_right.size == 0: continue error = self.error_metric(common_right, common_left, common_blurred_right, common_blurred_left) min_error = min(best_error, error) if min_error == error: best_overlap = z best_error = min_error # self.logger.debug( # "testing an overlap of %.2f pixels, actual best overlap is %.2f pixels over %d\r" # % (z / self.ovs, best_overlap / self.ovs, ovsd_sx / self.ovs), # ) offset = (ovsd_sx - best_overlap) / self.ovs / 2 cor_abs = (self.sx - 1) / 2 + offset return cor_abs def error_metric(self, common_right, common_left, common_blurred_right, common_blurred_left): if self.norm_order == 2: return self.error_metric_l2(common_right, common_left) elif self.norm_order == 1: return self.error_metric_l1(common_right, common_left, common_blurred_right, common_blurred_left) else: assert False, "this cannot happen" def error_metric_l2(self, common_right, common_left): common = common_right - common_left tmp = np.linalg.norm(common) norm_diff2 = tmp * tmp norm_right = np.linalg.norm(common_right) norm_left = np.linalg.norm(common_left) res = norm_diff2 / (norm_right * norm_left) return res def error_metric_l1(self, common_right, common_left, common_blurred_right, common_blurred_left): common = (common_right - common_left) / (common_blurred_right + common_blurred_left) res = abs(common).mean() return res def oversample(radio, ovs_s): """oversampling an image in arbitrary directions. The first and last point of each axis will still remain as extremal points of the new axis. """ result = np.zeros([(radio.shape[0] - 1) * ovs_s[0] + 1, (radio.shape[1] - 1) * ovs_s[1] + 1], "f") # Pre-initialisation: The original data falls exactly on the following strided positions in the new data array. result[:: ovs_s[0], :: ovs_s[1]] = radio for k in range(0, ovs_s[0]): # interpolation coefficient for axis 0 g = k / ovs_s[0] for i in range(0, ovs_s[1]): if i == 0 and k == 0: # this case subset was already exactly matched from before the present double loop, # in the pre-initialisation line. continue # interpolation coefficent for axis 1 f = i / ovs_s[1] # stop just a bit before cause we are not extending beyond the limits. 
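A worked example of the overlap-to-CoR conversion used at the end of `find_cor` above (all numbers hypothetical):

```python
sx, ovs = 2048, 4
ovsd_sx = (sx - 1) * ovs + 1                 # oversampled width: 8189
best_overlap = 7600                          # found by the error-metric scan, in oversampled pixels
offset = (ovsd_sx - best_overlap) / ovs / 2  # 73.625 original pixels
cor_abs = (sx - 1) / 2 + offset              # 1097.125, absolute CoR
```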
# If we are exacly on a vertical or horizontal original line, then no shift will be applied, # and we will exploit the equality f+(1-f)=g+(1-g)=1 adding twice the same contribution with # interpolation factors which become dummies pour le coup. stop0 = -ovs_s[0] if k else None stop1 = -ovs_s[1] if i else None # Once again, we exploit the g+(1-g)=1 equality start0 = ovs_s[0] if k else 0 start1 = ovs_s[1] if i else 0 # and what is done below makes clear the corundum above. result[k :: ovs_s[0], i :: ovs_s[1]] = (1 - g) * ( (1 - f) * result[0 : stop0 : ovs_s[0], 0 : stop1 : ovs_s[1]] + f * result[0 : stop0 : ovs_s[0], start1 :: ovs_s[1]] ) + g * ( (1 - f) * result[start0 :: ovs_s[0], 0 : stop1 : ovs_s[1]] + f * result[start0 :: ovs_s[0], start1 :: ovs_s[1]] ) return result # alias CompositeCOREstimator = CompositeCORFinder # Some heavily inelegant things going on here def get_default_kwargs(func): params = inspect.signature(func).parameters res = {} for param_name, param in params.items(): if param.default != inspect._empty: res[param_name] = param.default return res def update_func_kwargs(func, options): res_options = get_default_kwargs(func) for option_name, option_val in options.items(): if option_name in res_options: res_options[option_name] = option_val return res_options def get_class_name(class_object): return str(class_object).split(".")[-1].strip(">").strip("'").strip('"') class DetectorTiltEstimator: """ Helper class for detector tilt estimation. It automatically chooses the right radios and performs flat-field. """ default_tilt_method = "1d-correlation" # Given a tilt angle "a", the maximum deviation caused by the tilt (in pixels) is # N/2 * |sin(a)| where N is the number of pixels # We ignore tilts causing less than 0.25 pixel deviation: N/2*|sin(a)| < tilt_threshold tilt_threshold = 0.25 def __init__(self, dataset_info, do_flatfield=True, logger=None, autotilt_options=None): """ Initialize a detector tilt estimator helper. Parameters ---------- dataset_info: `dataset_info` object Data structure with the dataset information. do_flatfield: bool, optional Whether to perform flat field on radios. logger: `Logger` object, optional Logger object autotilt_options: dict, optional named arguments to pass to the detector tilt estimator class. 
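A quick check of the tilt threshold quoted in the class comment above (N/2 · |sin(a)| rule of thumb), with a hypothetical detector size:

```python
import numpy as np

n_pixels = 2048                     # hypothetical detector width
tilt_deg = 0.02
max_deviation = n_pixels / 2 * abs(np.sin(np.deg2rad(tilt_deg)))   # ~0.36 pixel
ignored = max_deviation < 0.25      # False: this tilt would be corrected
```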
""" self._set_params(dataset_info, do_flatfield, logger, autotilt_options) self.radios, self.radios_indices = get_radio_pair(dataset_info, radio_angles=(0.0, np.pi), return_indices=True) self._init_flatfield() self._apply_flatfield() def _set_params(self, dataset_info, do_flatfield, logger, autotilt_options): self.dataset_info = dataset_info self.do_flatfield = bool(do_flatfield) self.logger = LoggerOrPrint(logger) self._get_autotilt_options(autotilt_options) def _init_flatfield(self): if not (self.do_flatfield): return self.flatfield = FlatField( self.radios.shape, flats=self.dataset_info.flats, darks=self.dataset_info.darks, radios_indices=self.radios_indices, interpolation="linear", ) def _apply_flatfield(self): if not (self.do_flatfield): return self.flatfield.normalize_radios(self.radios) def _get_autotilt_options(self, autotilt_options): if autotilt_options is None: self.autotilt_options = None return try: autotilt_options = extract_parameters(autotilt_options) except Exception as exc: msg = "Could not extract parameters from autotilt_options: %s" % (str(exc)) self.logger.fatal(msg) raise ValueError(msg) self.autotilt_options = autotilt_options if "threshold" in autotilt_options: self.tilt_threshold = autotilt_options.pop("threshold") def find_tilt(self, tilt_method=None): """ Find the detector tilt. Parameters ---------- tilt_method: str, optional Which tilt estimation method to use. """ if tilt_method is None: tilt_method = self.default_tilt_method check_supported(tilt_method, set(tilt_methods.values()), "tilt estimation method") self.logger.info("Estimating detector tilt angle") autotilt_params = { "roi_yxhw": None, "median_filt_shape": None, "padding_mode": None, "peak_fit_radius": 1, "high_pass": None, "low_pass": None, } autotilt_params.update(self.autotilt_options or {}) self.logger.debug("%s(%s)" % ("CameraTilt", str(autotilt_params))) tilt_calc = CameraTilt() tilt_cor_position, camera_tilt = tilt_calc.compute_angle( self.radios[0], np.fliplr(self.radios[1]), method=tilt_method, **autotilt_params ) self.logger.info("Estimated detector tilt angle: %f degrees" % camera_tilt) # Ignore too small tilts max_deviation = np.max(self.dataset_info.radio_dims) * np.abs(np.sin(np.deg2rad(camera_tilt))) if self.dataset_info.is_halftomo: max_deviation *= 2 if max_deviation < self.tilt_threshold: self.logger.info( "Estimated tilt angle (%.3f degrees) results in %.2f maximum pixels shift, which is below threshold (%.2f pixel). Ignoring the tilt, no correction will be done." 
% (camera_tilt, max_deviation, self.tilt_threshold) ) camera_tilt = None return camera_tilt # alias TiltFinder = DetectorTiltEstimator ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5127568 nabu-2024.2.1/nabu/pipeline/fullfield/0000755000175000017500000000000014730277752017003 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/pipeline/fullfield/__init__.py0000644000175000017500000000000014315516747021101 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734019212.0 nabu-2024.2.1/nabu/pipeline/fullfield/chunked.py0000644000175000017500000011714514726604214020777 0ustar00pierrepierrefrom os import path from time import time from math import ceil import numpy as np from silx.io.url import DataUrl from ...utils import get_num_threads, remove_items_from_list from ...resources.logger import LoggerOrPrint from ...resources.utils import extract_parameters from ...misc.binning import binning as image_binning from ...io.reader import EDFStackReader, HDF5Loader, NXTomoReader from ...preproc.ccd import Log, CCDFilter from ...preproc.flatfield import FlatField from ...preproc.distortion import DistortionCorrection from ...preproc.shift import VerticalShift from ...preproc.double_flatfield import DoubleFlatField from ...preproc.phase import PaganinPhaseRetrieval from ...preproc.ctf import CTFPhaseRetrieval, GeoPars from ...reconstruction.sinogram import SinoNormalization from ...reconstruction.filtering import SinoFilter from ...reconstruction.mlem import __have_corrct__, MLEMReconstructor from ...processing.rotation import Rotation from ...reconstruction.rings import MunchDeringer, SinoMeanDeringer, VoDeringer from ...processing.unsharp import UnsharpMask from ...processing.histogram import PartialHistogram, hist_as_2Darray from ..utils import use_options, pipeline_step, get_subregion from ..reader import bin_image_stack, load_darks_flats from ..datadump import DataDumpManager from ..writer import WriterManager # For now we don't have a plain python/numpy backend for reconstruction try: from ...reconstruction.fbp_opencl import OpenCLBackprojector as Backprojector except: Backprojector = None class ChunkedPipeline: """ Pipeline for "regular" full-field tomography. Data is processed by chunks. A chunk consists in K contiguous lines of all the radios. In parallel geometry, a chunk of K radios lines gives K sinograms, and equivalently K reconstructed slices. """ backend = "numpy" FlatFieldClass = FlatField DoubleFlatFieldClass = DoubleFlatField CCDCorrectionClass = CCDFilter PaganinPhaseRetrievalClass = PaganinPhaseRetrieval CTFPhaseRetrievalClass = CTFPhaseRetrieval UnsharpMaskClass = UnsharpMask ImageRotationClass = Rotation VerticalShiftClass = VerticalShift MunchDeringerClass = MunchDeringer SinoMeanDeringerClass = SinoMeanDeringer VoDeringerClass = VoDeringer MLogClass = Log SinoNormalizationClass = SinoNormalization SinoFilterClass = SinoFilter FBPClass = Backprojector ConebeamClass = None # unsupported on CPU MLEMClass = MLEMReconstructor HBPClass = None # unsupported on CPU HistogramClass = PartialHistogram _default_extra_options = {} # These steps are skipped if the reconstruction is done in two stages. # The first stage will skip these steps, and the second stage will do these stages after merging sinograms. 
_reconstruction_steps = ["sino_rings_correction", "reconstruction", "save", "histogram"] def __init__( self, process_config, chunk_shape, margin=None, logger=None, use_grouped_mode=False, extra_options=None ): """ Initialize a "Chunked" pipeline. Parameters ---------- processing_config: `ProcessConfig` Process configuration. chunk_shape: tuple Shape of the chunk of data to process, in the form (n_angles, n_z, n_x). It has to account for possible cropping of the data, eg. [:, start_z:end_z, start_x:end_x] where start_xz and/or end_xz can be other than None. margin: tuple, optional Margin to use, in the form ((up, down), (left, right)). It is used for example when performing phase retrieval or a convolution-like operation: some extra data is kept to avoid boundaries issues. These boundaries are then discarded: the data volume is eventually cropped as `data[U:D, L:R]` where `((U, D), (L, R)) = margin` If not provided, no margin is applied. logger: `nabu.app.logger.Logger`, optional Logger class extra_options: dict, optional Advanced extra options. Notes ------ Using `margin` results in a lesser number of reconstructed slices. More specifically, if `margin = (V, H)`, then there will be `delta_z - 2*V` reconstructed slices (if the sub-region is in the middle of the volume) or `delta_z - V` reconstructed slices (if the sub-region is on top or bottom of the volume). """ self.logger = LoggerOrPrint(logger) self._set_params(process_config, chunk_shape, extra_options, margin, use_grouped_mode) self._init_pipeline() def _set_params(self, process_config, chunk_shape, extra_options, margin, use_grouped_mode): self.process_config = process_config self.dataset_info = self.process_config.dataset_info self.processing_steps = self.process_config.processing_steps.copy() self.processing_options = self.process_config.processing_options self._set_chunk_shape(chunk_shape, use_grouped_mode) self.set_subregion(None) self._set_margin(margin) self._set_extra_options(extra_options) self._callbacks = {} self._steps_name2component = {} self._steps_component2name = {} def _set_chunk_shape(self, chunk_shape, use_grouped_mode): if len(chunk_shape) != 3: raise ValueError("Expected chunk_shape to be a tuple of length 3 in the form (n_z, n_y, n_x)") self.chunk_shape = tuple(int(c) for c in chunk_shape) # cast to int, as numpy.int64 can make pycuda crash # TODO: sanity check (eg. compare to size of radios in dataset_info) ? # (n_a, n_z, n_x) self.radios_shape = ( ceil(self.chunk_shape[0] / self.process_config.subsampling_factor), self.chunk_shape[1] // self.process_config.binning[1], self.chunk_shape[2] // self.process_config.binning[0], ) self.n_angles = self.radios_shape[0] self.n_slices = self.radios_shape[1] self._grouped_processing = False if use_grouped_mode or self.chunk_shape[0] < len(self.process_config.rotation_angles(subsampling=False)): # TODO allow a certain tolerance in this case ? 
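The shape bookkeeping in `_set_chunk_shape` can be illustrated with hypothetical binning and subsampling factors:

```python
from math import ceil

chunk_shape = (3600, 100, 2560)     # (n_angles, n_z, n_x) as read from disk (hypothetical)
subsampling_factor = 2
binning = (2, 2)                    # (bin_x, bin_z)

radios_shape = (
    ceil(chunk_shape[0] / subsampling_factor),   # 1800 projections kept
    chunk_shape[1] // binning[1],                # 50 detector rows after vertical binning
    chunk_shape[2] // binning[0],                # 1280 columns after horizontal binning
)
```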
# Reconstruction is still possible (albeit less accurate) if delta is small self._grouped_processing = True self.logger.debug("Only a subset of angles is processed - Reconstruction will be skipped") self.processing_steps, _ = remove_items_from_list(self.processing_steps, self._reconstruction_steps) def _set_margin(self, margin): if margin is None: U, D, L, R = None, None, None, None else: ((U, D), (L, R)) = get_subregion(margin, ndim=2) # Replace "None" with zeros U, D, L, R = U or 0, D or 0, L or 0, R or 0 self.margin = ((U, D), (L, R)) self._margin_up = U self._margin_down = D self._margin_left = L self._margin_right = R self.use_margin = (U + D + L + R) > 0 self.n_recs = self.chunk_shape[1] - sum(self.margin[0]) self.radios_cropped_shape = (self.radios_shape[0], self.radios_shape[1] - U - D, self.radios_shape[2] - L - R) if self.use_margin: self.n_slices -= sum(self.margin[0]) def set_subregion(self, sub_region): """ Set the data volume sub-region to process. Note that processing margin, if any, is contained within the sub-region. Parameters ----------- sub_region: tuple Data volume sub-region, in the form ((start_a, end_a), (start_z, end_z), (start_x, end_x)) where the data volume has a layout (angles, Z, X) """ n_angles = self.dataset_info.n_angles n_x, n_z = self.dataset_info.radio_dims c_a, c_z, c_x = self.chunk_shape if sub_region is None: # By default, take the sub-region around central slice sub_region = ( (0, c_a), (n_z // 2 - c_z // 2, n_z // 2 - c_z // 2 + c_z), (n_x // 2 - c_x // 2, n_x // 2 - c_x // 2 + c_x), ) else: sub_region = get_subregion(sub_region, ndim=3) # check sub-region for i, start_end in enumerate(sub_region): start, end = start_end if start is not None and end is not None: if end - start != self.chunk_shape[i]: raise ValueError( "Invalid (start, end)=(%d, %d) for sub-region (dimension %d): chunk shape is %s, but %d-%d=%d != %d" % (start, end, i, str(self.chunk_shape), end, start, end - start, self.chunk_shape[i]) ) # self.logger.debug("Set sub-region to %s" % (str(sub_region))) self.sub_region = sub_region self._sub_region_xz = sub_region[2] + sub_region[1] self._radios_were_cropped = False def _set_extra_options(self, extra_options): self.extra_options = self._default_extra_options.copy() self.extra_options.update(extra_options or {}) # # Callbacks # def register_callback(self, step_name, callback): """ Register a callback for a pipeline processing step. Parameters ---------- step_name: str processing step name callback: callable A function. It will be executed once the processing step `step_name` is finished. The function takes only one argument: the class instance. 
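A hypothetical example of registering a callback on a processing step, following the API described below in `register_callback`; `pipeline` is assumed to be a `ChunkedPipeline` instance whose steps include "flatfield":

```python
def on_flatfield_done(pipeline):
    pipeline.logger.info("flat-field done for sub-region %s" % str(pipeline.sub_region))

pipeline.register_callback("flatfield", on_flatfield_done)
```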
""" if step_name not in self.processing_steps: raise ValueError("'%s' is not in processing steps %s" % (step_name, self.processing_steps)) if step_name in self._callbacks: self._callbacks[step_name].append(callback) else: self._callbacks[step_name] = [callback] # # Memory management # def _allocate_array(self, shape, dtype, name=None): return np.zeros(shape, dtype=dtype) def _allocate_recs(self, ny, nx, n_slices=None): n_slices = n_slices or self.n_slices self.recs = self._allocate_array((n_slices, ny, nx), "f", name="recs") # # Runtime attributes # @property def sub_region_xz(self): """ Return the currently processed sub-region in the form (start_x, end_x, start_z, end_z) """ return self._sub_region_xz @property def z_min(self): return self._sub_region_xz[2] @property def sino_shape(self): return self.process_config.sino_shape(binning=True, subsampling=True) @property def sinos_shape(self): return (self.n_slices,) + self.sino_shape def get_slice_start_index(self): return self.z_min + self._margin_up # # Pipeline initialization # def _init_pipeline(self): self._allocate_radios() self._init_data_dump() self._init_reader() self._init_flatfield() self._init_double_flatfield() self._init_ccd_corrections() self._init_radios_rotation() self._init_phase() self._init_unsharp() self._init_radios_movements() self._init_mlog() self._init_sino_normalization() self._init_sino_rings_correction() self._init_reconstruction() self._init_histogram() self._init_writer() def _allocate_radios(self): self.radios = np.zeros(self.radios_shape, dtype=np.float32) self.data = self.radios # alias def _init_data_dump(self): self._resume_from_step = self.processing_options["read_chunk"].get("step_name", None) self.datadump_manager = DataDumpManager( self.process_config, self.sub_region, margin=self.margin, logger=self.logger ) # When using "grouped processing", sinogram has to be dumped. # If it was not specified by user, force sinogram dump # Perhaps these lines should be moved directly to DataDumpManager. if self._grouped_processing and not self.process_config.dump_sinogram: sino_dump_fname = self.process_config.get_save_steps_file("sinogram") self.datadump_manager._configure_dump("sinogram", force_dump_to_fname=sino_dump_fname) self.logger.debug("Will dump sinogram to %s" % self.datadump_manager.data_dump["sinogram"].fname) def _init_reading_processing_function(self): # Some processing may be applied directly when reading data (eg. distortion correction, binning, ...) 
# Configure it here self._reader_processing_function = None self._reader_processing_function_args = None self._reader_processing_function_kwargs = None self._ff_processing_function = None self._ff_processing_function_args = None if self.process_config.binning is None or self.process_config.binning == (1, 1): return if self.dataset_info.kind == "nx": self._reader_processing_function = bin_image_stack self._reader_processing_function_kwargs = { "binning_factor": self.process_config.binning[::-1], "num_threads": get_num_threads(), } else: self._reader_processing_function = image_binning self._reader_processing_function_args = [self.process_config.binning[::-1]] # flat-field is read image-wise self._ff_processing_function = image_binning self._ff_processing_function_args = [self.process_config.binning[::-1]] @use_options("read_chunk", "chunk_reader") def _init_reader(self): options = self.processing_options["read_chunk"] process_file = options.get("process_file", None) if process_file is None: # Standard case - start pipeline from raw data self._init_reading_processing_function() subs_angles = None subs_z = None subs_x = None if self.process_config.subsampling_factor: subs_angles = self.process_config.subsampling_factor reader_sub_region = ( slice(*(self.sub_region[0]) + ((subs_angles,) if subs_angles else ())), slice(*(self.sub_region[1]) + ((subs_z,) if subs_z else ())), slice(*(self.sub_region[2]) + ((subs_x,) if subs_x else ())), ) other_reader_kwargs = { "output_dtype": np.float32, "processing_func": self._reader_processing_function, "processing_func_args": self._reader_processing_function_args, "processing_func_kwargs": self._reader_processing_function_kwargs, } if self.dataset_info.kind == "nx": self.chunk_reader = NXTomoReader( self.dataset_info.dataset_hdf5_url.file_path(), self.dataset_info.dataset_hdf5_url.data_path(), sub_region=reader_sub_region, image_key=0, **other_reader_kwargs, ) elif self.dataset_info.kind == "edf": files = [ self.dataset_info.projections[k].file_path() for k in sorted(self.dataset_info.projections.keys()) ] self.chunk_reader = EDFStackReader( files, sub_region=reader_sub_region, n_reading_threads=max(1, get_num_threads() // 2), **other_reader_kwargs, ) else: # Resume pipeline from dumped intermediate step self.chunk_reader = HDF5Loader( process_file, options["process_h5_path"], sub_region=self.datadump_manager.get_read_dump_subregion(), data_buffer=self.radios, pre_allocate=False, ) self._resume_from_step = options["step_name"] self.logger.debug( "Load subregion %s from file %s" % (str(self.chunk_reader.sub_region), self.chunk_reader.fname) ) @use_options("flatfield", "flatfield") def _init_flatfield(self): self._ff_options = self.processing_options["flatfield"].copy() # This won't work when resuming from a step (i.e before FF), because we rely on H5Loader() # which re-compacts the data. When data is re-compacted, we have to know the original radios positions. 
# These positions can be saved in the "file_dump" metadata, but it is not loaded for now # (the process_config object is re-built from scratch every time) self._ff_options["projs_indices"] = self.chunk_reader.get_frames_indices() if self._ff_options.get("normalize_srcurrent", False): a_start_idx, a_end_idx = self.sub_region[0] subs = self.process_config.subsampling_factor self._ff_options["radios_srcurrent"] = self._ff_options["radios_srcurrent"][a_start_idx:a_end_idx:subs] distortion_correction = None if self._ff_options["do_flat_distortion"]: self.logger.info("Flats distortion correction will be applied") self.FlatFieldClass = FlatField # no GPU implementation available, force this backend estimation_kwargs = {} estimation_kwargs.update(self._ff_options["flat_distortion_params"]) estimation_kwargs["logger"] = self.logger distortion_correction = DistortionCorrection( estimation_method="fft-correlation", estimation_kwargs=estimation_kwargs, correction_method="interpn" ) # Reduced darks/flats are loaded, but we have to crop them on the current sub-region # and possibly do apply some pre-processing (binning, distortion correction, ...) darks_flats = load_darks_flats( self.dataset_info, self.sub_region[1:], processing_func=self._ff_processing_function, processing_func_args=self._ff_processing_function_args, ) # FlatField parameter "radios_indices" must account for subsampling self.flatfield = self.FlatFieldClass( self.radios_shape, flats=darks_flats["flats"], darks=darks_flats["darks"], radios_indices=self._ff_options["projs_indices"], interpolation="linear", distortion_correction=distortion_correction, radios_srcurrent=self._ff_options["radios_srcurrent"], flats_srcurrent=self._ff_options["flats_srcurrent"], ) @use_options("double_flatfield", "double_flatfield") def _init_double_flatfield(self): options = self.processing_options["double_flatfield"] avg_is_on_log = options["sigma"] is not None result_url = None if options["processes_file"] not in (None, ""): result_url = DataUrl( file_path=options["processes_file"], data_path=(self.dataset_info.hdf5_entry or "entry") + "/double_flatfield/results/data", ) self.logger.info("Loading double flatfield from %s" % result_url.file_path()) if (self.n_angles < self.process_config.n_angles(subsampling=True)) and result_url is None: raise ValueError( "Cannot use double-flatfield when processing subset of radios. 
Please use the 'nabu-double-flatfield' command" ) self.double_flatfield = self.DoubleFlatFieldClass( self.radios_shape, result_url=result_url, sub_region=self.sub_region[1:], input_is_mlog=False, output_is_mlog=False, average_is_on_log=avg_is_on_log, sigma_filter=options["sigma"], log_clip_min=options["log_min_clip"], log_clip_max=options["log_max_clip"], ) @use_options("ccd_correction", "ccd_correction") def _init_ccd_corrections(self): options = self.processing_options["ccd_correction"] self.ccd_correction = self.CCDCorrectionClass( self.radios_shape[1:], median_clip_thresh=options["median_clip_thresh"] ) @use_options("tilt_correction", "projs_rot") def _init_radios_rotation(self): options = self.processing_options["tilt_correction"] center = options["center"] if center is None: nz, nx = self.radios_shape[1:] # after binning center_x = self.process_config.rotation_axis_position(binning=True) center_z = nz / 2 - 0.5 center = (center_x, center_z) center = (center[0], center[1] - self.z_min) self.projs_rot = self.ImageRotationClass( self.radios_shape[1:], options["angle"], center=center, mode="edge", reshape=False ) self._tmp_rotated_radio = self._allocate_array(self.radios_shape[1:], "f", name="tmp_rotated_radio") @use_options("radios_movements", "radios_movements") def _init_radios_movements(self): options = self.processing_options["radios_movements"] self._vertical_shifts = options["translation_movements"][:, 1] self.radios_movements = self.VerticalShiftClass(self.radios.shape, self._vertical_shifts) @use_options("phase", "phase_retrieval") def _init_phase(self): options = self.processing_options["phase"] if options["method"] == "CTF": translations_vh = getattr(self.dataset_info, "ctf_translations", None) geo_pars_params = options["ctf_geo_pars"].copy() geo_pars_params["logger"] = self.logger geo_pars = GeoPars(**geo_pars_params) self.phase_retrieval = self.CTFPhaseRetrievalClass( self.radios_shape[1:], geo_pars, options["delta_beta"], lim1=options["ctf_lim1"], lim2=options["ctf_lim2"], logger=self.logger, fft_num_threads=None, # TODO tune in advanced params of nabu config file use_rfft=True, normalize_by_mean=options["ctf_normalize_by_mean"], translation_vh=translations_vh, ) else: self.phase_retrieval = self.PaganinPhaseRetrievalClass( self.radios_shape[1:], distance=options["distance_m"], energy=options["energy_kev"], delta_beta=options["delta_beta"], pixel_size=options["pixel_size_m"], padding=options["padding_type"], # TODO tune in advanced params of nabu config file fft_num_threads=None, ) @use_options("unsharp_mask", "unsharp_mask") def _init_unsharp(self): options = self.processing_options["unsharp_mask"] self.unsharp_mask = self.UnsharpMaskClass( self.radios_shape[1:], options["unsharp_sigma"], options["unsharp_coeff"], mode="reflect", method=options["unsharp_method"], ) @use_options("take_log", "mlog") def _init_mlog(self): options = self.processing_options["take_log"] self.mlog = self.MLogClass( self.radios_shape, clip_min=options["log_min_clip"], clip_max=options["log_max_clip"] ) @use_options("sino_normalization", "sino_normalization") def _init_sino_normalization(self): options = self.processing_options["sino_normalization"] self.sino_normalization = self.SinoNormalizationClass( kind=options["method"], radios_shape=self.radios_cropped_shape, normalization_array=options["normalization_array"], ) @use_options("sino_rings_correction", "sino_deringer") def _init_sino_rings_correction(self): n_a, n_z, n_x = self.radios_cropped_shape sinos_shape = (n_z, n_a, n_x) options = 
self.processing_options["sino_rings_correction"] destriper_params = extract_parameters(options["user_options"]) if options["method"] == "munch": # TODO MunchDeringer does not have an API consistent with the other deringers fw_sigma = destriper_params.pop("sigma", 1.0) self.sino_deringer = self.MunchDeringerClass(fw_sigma, sinos_shape, **destriper_params) elif options["method"] == "vo": self.sino_deringer = self.VoDeringerClass(sinos_shape, **destriper_params) elif options["method"] == "mean-subtraction": self.sino_deringer = self.SinoMeanDeringerClass( sinos_shape, mode="subtract", fft_num_threads=None, **destriper_params ) elif options["method"] == "mean-division": self.sino_deringer = self.SinoMeanDeringerClass( sinos_shape, mode="divide", fft_num_threads=None, **destriper_params ) @use_options("reconstruction", "reconstruction") def _init_reconstruction(self): options = self.processing_options["reconstruction"] if options["method"] == "FBP" and self.FBPClass is None: raise ValueError("No usable FBP module was found") if options["method"] == "cone" and self.ConebeamClass is None: raise ValueError("No usable cone-beam module was found") if options["method"] == "mlem" and self.MLEMClass is None: raise ValueError("No usable MLEM module was found.") n_slices = self.n_slices if options["method"] in ["FBP", "HBP"]: # both have the same API rec_cls = self.HBPClass if options["method"] == "HBP" else self.FBPClass self.reconstruction = rec_cls( self.sinos_shape[1:], angles=options["angles"], rot_center=options["rotation_axis_position"], filter_name=options["fbp_filter_type"] or "none", halftomo=options["enable_halftomo"], slice_roi=self.process_config.rec_roi, padding_mode=options["padding_type"], extra_options={ "scale_factor": 1.0 / options["voxel_size_cm"][0], "axis_correction": options["axis_correction"], "centered_axis": options["centered_axis"], "clip_outer_circle": options["clip_outer_circle"], "outer_circle_value": options["outer_circle_value"], "filter_cutoff": options["fbp_filter_cutoff"], "hbp_legs": options["hbp_legs"], "hbp_reduction_steps": options["hbp_reduction_steps"], }, ) if options["method"] == "cone": n_slices = self.n_slices + sum(self.margin[0]) # For numerical stability, normalize all lengths with respect to detector pixel size pixel_size_m = self.dataset_info.pixel_size * 1e-6 source_sample_dist = options["source_sample_dist"] / pixel_size_m sample_detector_dist = options["sample_detector_dist"] / pixel_size_m self.reconstruction = self.ConebeamClass( # pylint: disable=E1102 (self.radios_shape[1],) + self.sino_shape, source_sample_dist, sample_detector_dist, angles=-options["angles"], rot_center=options["rotation_axis_position"], pixel_size=1, padding_mode=options["padding_type"], slice_roi=self.process_config.rec_roi, extra_options={ "scale_factor": 1.0 / options["voxel_size_cm"][0], "axis_correction": -options["axis_correction"] if options["axis_correction"] is not None else None, "clip_outer_circle": options["clip_outer_circle"], "outer_circle_value": options["outer_circle_value"], "filter_cutoff": options["fbp_filter_cutoff"], }, ) if options["method"] == "mlem" and options["implementation"] in (None, "corrct"): self.reconstruction = self.MLEMClass( # pylint: disable=E1102 (self.radios_shape[1],) + self.sino_shape, angles_rad=-options["angles"], # WARNING: mind the sign... shifts_uv=self.dataset_info.translations, # In config file, one line per proj, each line is (tu,tv). Corrct expects one col per proj and (tv,tu). 
cor=options["rotation_axis_position"], n_iterations=options["iterations"], extra_options={ "compute_shifts": False, "tomo_consistency": False, "v_min_for_v_shifts": 0, "v_max_for_v_shifts": None, "v_min_for_u_shifts": 0, "v_max_for_u_shifts": None, }, ) self._allocate_recs(*self.process_config.rec_shape, n_slices=n_slices) n_a, _, n_x = self.radios_cropped_shape self._tmp_sino = self._allocate_array((n_a, n_x), "f", name="tmp_sino") @use_options("histogram", "histogram") def _init_histogram(self): options = self.processing_options["histogram"] self.histogram = self.HistogramClass(method="fixed_bins_number", num_bins=options["histogram_bins"]) @use_options("save", "writer") def _init_writer(self, **extra_options): options = self.processing_options["save"] metadata = { "process_name": "reconstruction", "processing_index": 0, # TODO this one takes too much time to write, not useful for partial files # "processing_options": self.processing_options, # "nabu_config": self.process_config.nabu_config, "entry": getattr(self.dataset_info.dataset_scanner, "entry", "entry"), } writer_extra_options = { "jpeg2000_compression_ratio": options["jpeg2000_compression_ratio"], "float_clip_values": options["float_clip_values"], "tiff_single_file": options.get("tiff_single_file", False), "single_output_file_initialized": getattr( self.process_config, "single_output_file_initialized", False ), # COMPAT. "writer_initialized": getattr(self.process_config, "_writer_initialized", False), "raw_vol_metadata": {"voxelSize": self.dataset_info.pixel_size}, # legacy... } writer_extra_options.update(extra_options) self.writer = WriterManager( options["location"], options["file_prefix"], file_format=options["file_format"], overwrite=options["overwrite"], start_index=self.get_slice_start_index(), logger=self.logger, metadata=metadata, histogram=("histogram" in self.processing_steps), extra_options=writer_extra_options, ) # # Pipeline execution # @pipeline_step("chunk_reader", "Reading data") def _read_data(self): self.logger.debug("Region = %s" % str(self.sub_region)) t0 = time() self.chunk_reader.load_data(output=self.radios) el = time() - t0 self.logger.info("Read subvolume %s in %.2f s" % (str(self.radios.shape), el)) @pipeline_step("flatfield", "Applying flat-field") def _flatfield(self): self.flatfield.normalize_radios(self.radios) @pipeline_step("double_flatfield", "Applying double flat-field") def _double_flatfield(self, radios=None): if radios is None: radios = self.radios self.double_flatfield.apply_double_flatfield(radios) @pipeline_step("ccd_correction", "Applying CCD corrections") def _ccd_corrections(self, radios=None): if radios is None: radios = self.radios _tmp_radio = self._allocate_array(radios.shape[1:], "f", name="tmp_ccdcorr_radio") for i in range(radios.shape[0]): self.ccd_correction.median_clip_correction(radios[i], output=_tmp_radio) radios[i][:] = _tmp_radio[:] @pipeline_step("projs_rot", "Rotating projections") def _rotate_projections(self, radios=None): if radios is None: radios = self.radios tmp_radio = self._tmp_rotated_radio for i in range(radios.shape[0]): self.projs_rot.rotate(radios[i], output=tmp_radio) radios[i][:] = tmp_radio[:] @pipeline_step("phase_retrieval", "Performing phase retrieval") def _retrieve_phase(self): for i in range(self.radios.shape[0]): self.phase_retrieval.retrieve_phase(self.radios[i], output=self.radios[i]) @pipeline_step("unsharp_mask", "Performing unsharp mask") def _apply_unsharp(self): for i in range(self.radios.shape[0]): self.radios[i] = 
self.unsharp_mask.unsharp(self.radios[i]) @pipeline_step("mlog", "Taking logarithm") def _take_log(self): self.mlog.take_logarithm(self.radios) @pipeline_step("radios_movements", "Applying radios movements") def _radios_movements(self, radios=None): if radios is None: radios = self.radios self.radios_movements.apply_vertical_shifts(radios, list(range(radios.shape[0]))) def _crop_radios(self): if self.use_margin: self._orig_radios = self.radios if self.processing_options.get("reconstruction", {}).get("method", None) in ("cone",): return ((U, D), (L, R)) = self.margin self.logger.debug( "Cropping radios from %s to %s" % (str(self.radios_shape), str(self.radios_cropped_shape)) ) U, D, L, R = U or None, -D or None, L or None, -R or None self.radios = self.radios[:, U:D, L:R] # view self._radios_were_cropped = True @pipeline_step("sino_normalization", "Normalizing sinograms") def _normalize_sinos(self, radios=None): if radios is None: radios = self.radios sinos = radios.transpose((1, 0, 2)) self.sino_normalization.normalize(sinos) def _dump_sinogram(self): if self.datadump_manager is not None: self.datadump_manager.dump_data_to_file("sinogram", self.radios) @pipeline_step("sino_deringer", "Removing rings on sinograms") def _destripe_sinos(self): sinos = np.rollaxis(self.radios, 1, 0) # view self.sino_deringer.remove_rings(sinos) # TODO check it works with non-contiguous view @pipeline_step("reconstruction", "Reconstruction") def _reconstruct(self): """ Reconstruction for parallel geometry. For each target slice: get the corresponding sinogram, apply some processing, then reconstruct """ options = self.processing_options["reconstruction"] if options["method"] == "cone": self._reconstruct_cone() return if options["method"] == "mlem": self.recs = self._reconstruct_mlem() return for i in range(self.n_slices): self._tmp_sino[:] = self.radios[:, i, :] # copy into contiguous array self.reconstruction.fbp(self._tmp_sino, output=self.recs[i]) def _reconstruct_cone(self): """ This reconstructs the entire sinograms stack at once """ n_angles, n_z, n_x = self.radios.shape # FIXME # can't do a discontiguous single copy... sinos_contig = self._allocate_array((n_z, n_angles, n_x), np.float32, "sinos_cone") for i in range(n_z): sinos_contig[i] = self.radios[:, i, :] # --- # In principle radios are not cropped at this stage, # so self.sub_region[2][0] can be used instead of self.get_slice_start_index() instead of self.sub_region[2][0] z_min, z_max = self.sub_region_xz[2:] n_z_tot = self.process_config.radio_shape(binning=True)[0] self.reconstruction.reconstruct( # pylint: disable=E1101 sinos_contig, output=self.recs, relative_z_position=((z_min + z_max) / self.process_config.binning_z / 2) - n_z_tot / 2, ) def _reconstruct_mlem(self): """ This reconstructs the entire sinograms stack at once """ n_angles, n_z, n_x = self.radios.shape # FIXME # can't do a discontiguous single copy... # Initially done for Astra CB recons. 
But happens that MLEM Corrct also expects # data with this order (nb_rows, nb_angles, nb_cols) data_vwu = self._allocate_array((n_z, n_angles, n_x), np.float32, "sinos_mlem") for i in range(n_z): data_vwu[i] = self.radios[:, i, :] # --- return self.reconstruction.reconstruct( # pylint: disable=E1101 data_vwu, ) @pipeline_step("histogram", "Computing histogram") def _compute_histogram(self, data=None): if data is None: data = self.recs self.recs_histogram = self.histogram.compute_histogram(data) @pipeline_step("writer", "Saving data") def _write_data(self, data=None): if data is None and self.reconstruction is not None: data = self.recs if data is None: self.logger.info("No data to write") return self.writer.write_data(data) self.logger.info("Wrote %s" % self.writer.fname) self._write_histogram() self.process_config.single_output_file_initialized = True # COMPAT. self.process_config._writer_initialized = True def _write_histogram(self): if "histogram" not in self.processing_steps: return self.logger.info("Saving histogram") self.writer.write_histogram( hist_as_2Darray(self.recs_histogram), processing_index=1, config={ "file": path.basename(self.writer.fname), "bins": self.processing_options["histogram"]["histogram_bins"], }, ) def _process_finalize(self): if self.use_margin: self.radios = self._orig_radios def __repr__(self): res = "%s(%s, margin=%s)" % (self.__class__.__name__, str(self.chunk_shape), str(self.margin)) binning = self.process_config.binning subsampling = self.process_config.subsampling_factor if binning != (1, 1) or subsampling > 1: if binning != (1, 1): res += "\nImages binning: %s" % (str(binning)) if subsampling: res += "\nAngles subsampling: %d" % subsampling res += "\nRadios chunk: %s ---> %s" % (self.chunk_shape, self.radios_shape) if self.use_margin: res += "\nMargin: %s" % (str(self.margin)) res += "\nRadios chunk: %s ---> %s" % (str(self.radios_shape), str(self.radios_cropped_shape)) res += "\nCurrent subregion: %s" % (str(self.sub_region)) for step_name in self.processing_steps: res += "\n- %s" % (step_name) return res def _process_chunk(self): self._flatfield() self._double_flatfield() self._ccd_corrections() self._rotate_projections() self._retrieve_phase() self._apply_unsharp() self._take_log() self._radios_movements() self._crop_radios() self._normalize_sinos() self._destripe_sinos() self._dump_sinogram() self._reconstruct() self._compute_histogram() self._write_data() self._process_finalize() def _reset_reader_subregion(self): if self._resume_from_step is not None: self.chunk_reader._set_subregion(self.datadump_manager.get_read_dump_subregion()) self._init_data_dump() self._init_reader() def _reset_sub_region(self, sub_region): self.set_subregion(sub_region) self._reset_reader_subregion() self._init_flatfield() # reset flatfield self._init_writer() self._init_double_flatfield() self._init_data_dump() def process_chunk(self, sub_region): self._reset_sub_region(sub_region) self._read_data() self._process_chunk() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1732264041.0 nabu-2024.2.1/nabu/pipeline/fullfield/chunked_cuda.py0000644000175000017500000001333014720040151021746 0ustar00pierrepierrefrom ...preproc.ccd_cuda import CudaLog, CudaCCDFilter from ...preproc.flatfield_cuda import CudaFlatField from ...preproc.shift_cuda import CudaVerticalShift from ...preproc.double_flatfield_cuda import CudaDoubleFlatField from ...preproc.phase_cuda import CudaPaganinPhaseRetrieval from ...preproc.ctf_cuda import CudaCTFPhaseRetrieval from 
...reconstruction.sinogram_cuda import CudaSinoBuilder, CudaSinoNormalization from ...reconstruction.filtering_cuda import CudaSinoFilter from ...reconstruction.rings_cuda import CudaMunchDeringer, CudaSinoMeanDeringer, CudaVoDeringer from ...processing.unsharp_cuda import CudaUnsharpMask from ...processing.rotation_cuda import CudaRotation from ...processing.histogram_cuda import CudaPartialHistogram from ...reconstruction.fbp import Backprojector from ...reconstruction.hbp import HierarchicalBackprojector from ...reconstruction.cone import __have_astra__, ConebeamReconstructor from ...cuda.utils import get_cuda_context, __has_pycuda__, __pycuda_error_msg__ from ..utils import pipeline_step from .chunked import ChunkedPipeline if __has_pycuda__: import pycuda.gpuarray as garray if not (__have_astra__): ConebeamReconstructor = None class CudaChunkedPipeline(ChunkedPipeline): """ Cuda backend of ChunkedPipeline """ backend = "cuda" FlatFieldClass = CudaFlatField DoubleFlatFieldClass = CudaDoubleFlatField CCDCorrectionClass = CudaCCDFilter PaganinPhaseRetrievalClass = CudaPaganinPhaseRetrieval CTFPhaseRetrievalClass = CudaCTFPhaseRetrieval UnsharpMaskClass = CudaUnsharpMask ImageRotationClass = CudaRotation VerticalShiftClass = CudaVerticalShift MunchDeringerClass = CudaMunchDeringer SinoMeanDeringerClass = CudaSinoMeanDeringer VoDeringerClass = CudaVoDeringer MLogClass = CudaLog SinoBuilderClass = CudaSinoBuilder SinoNormalizationClass = CudaSinoNormalization SinoFilterClass = CudaSinoFilter FBPClass = Backprojector ConebeamClass = ConebeamReconstructor HBPClass = HierarchicalBackprojector HistogramClass = CudaPartialHistogram def __init__( self, process_config, chunk_shape, logger=None, extra_options=None, margin=None, use_grouped_mode=False, cuda_options=None, ): self._init_cuda(cuda_options) super().__init__( process_config, chunk_shape, logger=logger, extra_options=extra_options, use_grouped_mode=use_grouped_mode, margin=margin, ) self._allocate_array(self.radios.shape, "f", name="radios") self._determine_when_to_transfer_data_on_gpu() def _determine_when_to_transfer_data_on_gpu(self): # Decide when to transfer data to GPU. Normally it's right after reading the data, # But sometimes a part of the processing is done on CPU. 
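# Concretely, the only host-side case handled here is flat-field with distortion
# correction: when 'distortion_correction' is set (see the check just below), the
# radios stay on the host and are transferred to the GPU only after the flat-field step.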
self._when_to_transfer_radios_on_gpu = "read_data" if self.flatfield is not None and self.flatfield.distortion_correction is not None: self._when_to_transfer_radios_on_gpu = "flatfield" def _init_cuda(self, cuda_options): if not (__has_pycuda__): raise ImportError(__pycuda_error_msg__) cuda_options = cuda_options or {} self.ctx = get_cuda_context(**cuda_options) self._d_radios = None self._d_sinos = None self._d_recs = None def _allocate_array(self, shape, dtype, name=None): name = name or "tmp" # should be mandatory d_name = "_d_" + name d_arr = getattr(self, d_name, None) if d_arr is None: self.logger.debug("Allocating %s: %s" % (name, str(shape))) d_arr = garray.zeros(shape, dtype) # pylint: disable=E0606 setattr(self, d_name, d_arr) return d_arr def _transfer_radios_to_gpu(self): self.logger.debug("Transfering radios to GPU") self._d_radios.set(self.radios) self._h_radios = self.radios self.radios = self._d_radios def _process_finalize(self): self.radios = self._h_radios # # Pipeline execution (class specialization) # def _read_data(self): super()._read_data() if self._when_to_transfer_radios_on_gpu == "read_data": self._transfer_radios_to_gpu() def _flatfield(self): super()._flatfield() if self._when_to_transfer_radios_on_gpu == "flatfield": self._transfer_radios_to_gpu() def _reconstruct(self): super()._reconstruct() if "reconstruction" not in self.processing_steps: return if self.processing_options["reconstruction"]["method"] == "cone": ((U, D), (L, R)) = self.margin U, D = U or None, -D or None # not sure why slicing can't be done before get() self.recs = self.recs.get()[U:D, ...] elif self.processing_options["reconstruction"]["method"] == "mlem": pass else: self.recs = self.recs.get() def _write_data(self, data=None): super()._write_data(data=data) if "reconstruction" in self.processing_steps: self.recs = self._d_recs self.radios = self._h_radios @pipeline_step("histogram", "Computing histogram") def _compute_histogram(self, data=None): if data is None: data = self._d_recs self.recs_histogram = self.histogram.compute_histogram(data) def _dump_data_to_file(self, step_name, data=None): if data is None: data = self.radios if step_name not in self.datadump_manager.data_dump: return if isinstance(data, garray.GPUArray): data = data.get() self.datadump_manager.dump_data_to_file(step_name, data=data) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1723189827.0 nabu-2024.2.1/nabu/pipeline/fullfield/computations.py0000644000175000017500000002336214655345103022100 0ustar00pierrepierrefrom math import ceil from silx.image.tomography import get_next_power from ...utils import check_supported def estimate_required_memory( process_config, delta_z=None, delta_a=None, max_mem_allocation_GB=None, fft_plans=True, debug=False ): """ Estimate the memory (RAM) in Bytes needed for a reconstruction. Parameters ----------- process_config: `ProcessConfig` object Data structure with the processing configuration delta_z: int, optional How many lines are to be loaded in the each projection image. Default is to load lines [start_z:end_z] (where start_z and end_z come from the user configuration file) delta_a: int, optional How many (partial) projection images to load at the same time. Default is to load all the projection images. max_mem_allocation_GB: float, optional Maximum amount of memory in GB for one single array. Returns ------- required_memory: float Total required memory (in bytes). 
Raises ------ ValueError if one single-array allocation exceeds "max_mem_allocation_GB" Notes ----- pycuda <= 2022.1 cannot use arrays with more than 2**32 items (i.e 17.18 GB for float32). This was solved in more recent versions. """ def check_memory_limit(mem_GB, name): if max_mem_allocation_GB is None: return if mem_GB > max_mem_allocation_GB: raise ValueError( "Cannot allocate array '%s' %.3f GB > max_mem_allocation_GB = %.3f GB" % (name, mem_GB, max_mem_allocation_GB) ) dataset = process_config.dataset_info processing_steps = process_config.processing_steps # The "x" dimension (last axis) is always the image width, because # - Processing is never "cut" along this axis (we either split along frames, or images lines) # - Even if we want to reconstruct a vertical slice (i.e end_x - start_x == 1), # the tomography reconstruction will need the full information along this axis. # The only case where reading a x-subregion would be useful is to crop the initial data # (i.e knowing that a part of each image will be completely unused). This is not supported yet. Nx = process_config.radio_shape(binning=True)[-1] if delta_z is not None: Nz = delta_z // process_config.binning_z else: Nz = process_config.rec_delta_z # accounting for binning if delta_a is not None: Na = ceil(delta_a / process_config.subsampling_factor) else: Na = process_config.n_angles(subsampling=True) total_memory_needed = 0 # Read data # ---------- data_image_size = Nx * Nz * 4 data_volume_size = Na * data_image_size check_memory_limit(data_volume_size / 1e9, "projections") total_memory_needed += data_volume_size # CCD processing # --------------- if "flatfield" in processing_steps: # Flat-field is done in-place, but still need to load darks/flats n_darks = len(dataset.darks) n_flats = len(dataset.flats) total_memory_needed += (n_darks + n_flats) * data_image_size if "ccd_correction" in processing_steps: # CCD filter is "batched 2D" total_memory_needed += data_image_size # Phase retrieval # --------------- if "phase" in processing_steps: # Phase retrieval is done image-wise, so near in-place, but needs to allocate some memory: # filter with padded shape, radio_padded, radio_padded_fourier, and possibly FFT plan. # CTF phase retrieval uses "2 filters" (num and denom) but let's neglect this. Nx_p = get_next_power(2 * Nx) Nz_p = get_next_power(2 * Nz) img_size_real = Nx_p * Nz_p * 4 img_size_cplx = ((Nx_p * Nz_p) // 2 + 1) * 8 # assuming RFFT factor = 1 if fft_plans: factor = 2 total_memory_needed += (2 * img_size_real + img_size_cplx) * factor # Sinogram de-ringing # ------------------- if "sino_rings_correction" in processing_steps: method = process_config.processing_options["sino_rings_correction"]["method"] if method == "munch": # Process is done image-wise. # Needs one Discrete Wavelets transform and one FFT/IFFT plan for each scale factor = 2 if not (fft_plans) else 5.5 # approx! 
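# One sinogram is Na * Nx float32 values, i.e Nx * Na * 4 bytes; 'factor' approximates
# the extra memory taken by the wavelets decomposition and, if enabled, the FFT plans.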
total_memory_needed += (Nx * Na * 4) * factor elif method == "vo": # cupy-based implementation makes many calls to "scipy-like" functions, where the memory usage is not under control # TODO try to estimate this pass # Reconstruction # --------------- reconstructed_volume_size = 0 if "reconstruction" in processing_steps: rec_config = process_config.processing_options["reconstruction"] Nx_rec = rec_config["end_x"] - rec_config["start_x"] + 1 Ny_rec = rec_config["end_y"] - rec_config["start_y"] + 1 reconstructed_volume_size = Nz * Nx_rec * Ny_rec * 4 check_memory_limit(reconstructed_volume_size / 1e9, "reconstructions") total_memory_needed += reconstructed_volume_size if process_config.rec_params["method"] == "cone": # In cone-beam reconstruction, need both sinograms and reconstruction inside GPU. # That's big! total_memory_needed += 2 * data_volume_size if debug: print( "Mem for (delta_z=%s, delta_a=%s) ==> (Na=%d, Nz=%d, Nx=%d) : %.3f GB" % (delta_z, delta_a, Na, Nz, Nx, total_memory_needed / 1e9) ) return total_memory_needed def estimate_max_chunk_size( available_memory_GB, process_config, pipeline_part="all", n_rows=None, step=10, max_mem_allocation_GB=None, fft_plans=True, debug=False, ): """ Estimate the maximum size of the data chunk that can be loaded in memory. Parameters ---------- available_memory_GB: float available memory in Giga Bytes (GB - not GiB !). process_config: ProcessConfig ProcessConfig object pipeline_part: str Which pipeline part to consider. Possible options are: - "full": Account for all the pipeline steps (reading data all the way to reconstruction). - "radios": Consider only the processing steps on projection images (ignore sinogram-based steps and reconstruction) - "sinogram": Consider only the processing steps related to sinograms and reconstruction n_rows: int, optional How many lines to load in each projection. Only accounted for pipeline_part="radios". step: int, optional Step size when doing the iterative memory estimation max_mem_allocation_GB: float, optional Maximum size (in GB) for one single array. Returns ------- n_max: int If pipeline_par is "full" or "sinos": return the maximum number of lines that can be loaded in all the projections while fitting memory, i.e `data[:, 0:n_max, :]` If pipeline_part is "radios", return the maximum number of (partial) images that can be loaded while fitting memory, i.e `data[:, zmin:zmax, 0:n_max]` Notes ----- pycuda <= 2022.1 cannot use arrays with more than 2**32 items (i.e 17.18 GB for float32). This was solved in more recent versions. 
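Examples
--------
Illustrative sketch only: the memory figure and the ``process_config`` instance
below are placeholders, not values coming from this module::

    n_rows_max = estimate_max_chunk_size(
        available_memory_GB=64.0,
        process_config=process_config,
        pipeline_part="all",
    )
    # With pipeline_part="all", 'n_rows_max' is the largest number of detector
    # rows (delta_z) that can be processed at once within 64 GB of RAM.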
""" supported_pipeline_parts = ["all", "radios", "sinos"] check_supported(pipeline_part, supported_pipeline_parts, "pipeline_part") processing_steps_bak = process_config.processing_steps.copy() reconstruction_steps = ["sino_rings_correction", "reconstruction"] if pipeline_part == "all": # load lines from all the projections delta_a = None delta_z = 0 if pipeline_part == "radios": # order should not matter process_config.processing_steps = list(set(process_config.processing_steps) - set(reconstruction_steps)) # load lines from only a subset of projections delta_a = 0 delta_z = n_rows if pipeline_part == "sinos": process_config.processing_steps = [ step for step in process_config.processing_steps if step in reconstruction_steps ] # load lines from all the projections delta_a = None delta_z = 0 mem = 0 # pylint: disable=E0606, E0601 last_valid_delta_a = delta_a last_valid_delta_z = delta_z while True: try: mem = estimate_required_memory( process_config, delta_z=delta_z, delta_a=delta_a, max_mem_allocation_GB=max_mem_allocation_GB, fft_plans=fft_plans, debug=debug, ) except ValueError: # For very big dataset this function might return "0". # Either start at 1, or use a smaller step... break if mem / 1e9 > available_memory_GB: break if delta_a is not None and delta_a > process_config.n_angles(): break if delta_z is not None and delta_z > process_config.radio_shape()[0]: break last_valid_delta_a, last_valid_delta_z = delta_a, delta_z if pipeline_part == "radios": delta_a += step else: delta_z += step process_config.processing_steps = processing_steps_bak if pipeline_part != "radios": if mem / 1e9 < available_memory_GB: res = min(delta_z, process_config.radio_shape()[0]) else: res = last_valid_delta_z else: if mem / 1e9 < available_memory_GB: res = min(delta_a, process_config.n_angles()) else: res = last_valid_delta_a # Really not ideal. For very large dataset, "step" should be very small. # Otherwise we go from 0 -> OK to 10 -> not OK, and then retain 0... 
if res == 0: res = 1 # return res ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1723556968.0 nabu-2024.2.1/nabu/pipeline/fullfield/dataset_validator.py0000644000175000017500000000557714656662150023062 0ustar00pierrepierreimport os from ..dataset_validator import DatasetValidatorBase class FullFieldDatasetValidator(DatasetValidatorBase): def _validate(self): self._check_not_empty() self._convert_negative_indices() self._get_output_filename() self._check_can_do_flatfield() self._check_can_do_phase() self._check_can_do_reconstruction() self._check_slice_indices() self._handle_processing_mode() self._handle_binning() self._check_output_file() def _check_can_do_flatfield(self): if self.nabu_config["preproc"]["flatfield"]: darks = self.dataset_info.darks assert len(darks) > 0, "Need at least one dark to perform flat-field correction" flats = self.dataset_info.flats assert len(flats) > 0, "Need at least one flat to perform flat-field correction" def _check_slice_indices(self): nx, nz = self.dataset_info.radio_dims rec_params = self.rec_params if self.is_halftomo: ny, nx = self._get_nx_ny() what = (("start_x", "end_x", nx), ("start_y", "end_y", nx), ("start_z", "end_z", nz)) for start_name, end_name, numels in what: self._check_start_end_idx( rec_params[start_name], rec_params[end_name], numels, start_name=start_name, end_name=end_name ) def _check_can_do_phase(self): if self.nabu_config["phase"]["method"] is None: return self.dataset_info.check_defined_attribute("distance") self.dataset_info.check_defined_attribute("pixel_size") def _check_can_do_reconstruction(self): rec_options = self.nabu_config["reconstruction"] if rec_options["method"] is None: return self.dataset_info.check_defined_attribute("pixel_size") if rec_options["method"] == "cone": if rec_options["source_sample_dist"] is None: err_msg = "In cone-beam reconstruction, you have to provide 'source_sample_dist' in [reconstruction]" self.logger.fatal(err_msg) raise ValueError(err_msg) if rec_options["sample_detector_dist"] is None: if self.dataset_info.distance is None: err_msg = "Cone-beam reconstruction: 'sample_detector_dist' was not provided but could not be found in the dataset metadata either. Please provide 'sample_detector_dist'" self.logger.fatal(err_msg) raise ValueError(err_msg) self.logger.warning( "Cone-beam reconstruction: 'sample_detector_dist' not provided, will use the one in dataset metadata" ) if self.is_halftomo: err_msg = "Cone-beam reconstruction with half-acquisition is not supported yet" self.logger.fatal(err_msg) raise NotImplementedError(err_msg) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1732264041.0 nabu-2024.2.1/nabu/pipeline/fullfield/nabu_config.py0000644000175000017500000007620514720040151021615 0ustar00pierrepierrefrom ..config_validators import * nabu_config = { "dataset": { "location": { "default": "", "help": "Dataset location, either a directory or a HDF5-Nexus file.", "validator": dataset_location_validator, "type": "required", }, "hdf5_entry": { "default": "", "help": "Which entry to process in the data HDF5 file. Default is the first entry. It can be a comma-separated list of entries, and/or a wildcard (* for all entries, or things like entry???1).", "validator": optional_string_validator, "type": "advanced", }, "nexus_version": { "default": "1.0", "help": "Nexus version to use when browsing the HDF5 dataset. 
Default is 1.0.", "validator": float_validator, "type": "advanced", }, "darks_flats_dir": { "default": "", "help": "Path to a directory where XXX_flats.h5 and XXX_darks.h5 are to be found, where 'XXX' denotes the dataset basename. If these files are found, then reduced flats/darks will be loaded from them. Otherwise, reduced flats/darks will be saved to there once computed, either in the .nx directory, or in the output directory. Mind that the HDF5 entry corresponds to the one of the dataset.", "validator": optional_directory_location_validator, "type": "optional", }, "binning": { "default": "1", "help": "Binning factor in the horizontal dimension when reading the data.\nThe final slices dimensions will be divided by this factor.", "validator": binning_validator, "type": "advanced", }, "binning_z": { "default": "1", "help": "Binning factor in the vertical dimension when reading the data.\nThis results in a lesser number of reconstructed slices.", "validator": binning_validator, "type": "advanced", }, "projections_subsampling": { "default": "1", "help": "Projections subsampling factor: take one projection out of 'projection_subsampling'. The format can be an integer (take 1 projection out of N), or N:M (take 1 projection out of N, start with the projection number M)\nFor example: 2 (or 2:0) to reconstruct from even projections, 2:1 to reconstruct from odd projections.", "validator": projections_subsampling_validator, "type": "advanced", }, "exclude_projections": { "default": "", "help": "Projection to exclude from the reconstruction. It can be:\n-indices = exclude_projections_indices.txt : Path to a text file with one integer per line. Each corresponding projection INDEX will be ignored.\n-angles = exclude_projections_angles.txt : Path to a text file with angle in DEGREES, one per line. The corresponding angles will be ignored\n-angular_range = [a, b] : ignore angles belonging to angular range [a, b] in degrees, with b included.", "validator": exclude_projections_validator, "type": "advanced", }, "overwrite_metadata": { "default": "", "help": "Which metadata to overwrite, separated by a semicolon, and with units. Example: 'energy = 19 kev; pixel_size = 1.6 um'", "validator": no_validator, "type": "advanced", }, }, "preproc": { "flatfield": { "default": "1", "help": "How to perform flat-field normalization. The parameter value can be:\n - 1 or True: enabled.\n - 0 or False: disabled\n - forced or force-load: perform flatfield regardless of the dataset by attempting to load darks/flats\n - force-compute: perform flatfield, ignore all .h5 files containing already computed darks/flats.", "validator": flatfield_enabled_validator, "type": "required", }, "flat_distortion_correction_enabled": { "default": "0", "help": "Whether to correct for flat distortion. If activated, each radio is correlated with its corresponding flat, in order to determine and correct the flat distortion.", "validator": boolean_validator, "type": "advanced", }, "flat_distortion_params": { "default": "tile_size=100; interpolation_kind='linear'; padding_mode='edge'; correction_spike_threshold=None", "help": "Advanced parameters for flat distortion correction", "validator": optional_string_validator, "type": "advanced", }, "normalize_srcurrent": { "default": "1", "help": "Whether to normalize frames with Synchrotron Current. 
This can correct the effect of a beam refill not taken into account by flats.", "validator": boolean_validator, "type": "advanced", }, "ccd_filter_enabled": { "default": "0", "help": "Whether to enable the CCD hotspots correction.", "validator": boolean_validator, "type": "optional", }, "ccd_filter_threshold": { "default": "0.04", "help": "If ccd_filter_enabled = 1, a median filter is applied on the 3X3 neighborhood\nof every pixel. If a pixel value exceeds the median value more than this parameter,\nthen the pixel value is replaced with the median value.", "validator": float_validator, "type": "optional", }, "detector_distortion_correction": { "default": "", "help": "Apply coordinate transformation on the raw data, at the reading stage. Default (empty) is None. Available are: None, identity(for testing the pipeline), map_xz. This latter method requires two URLs being passed by detector_distortion_correction_options: map_x and map_z pointing to two 2D arrays containing the position where each pixel can be interpolated at in the raw data", "validator": detector_distortion_correction_validator, "type": "advanced", }, "detector_distortion_correction_options": { "default": "", "help": """Options for detector_distortion_correction. Example, for mapx_xz: detector_distortion_correction_options=map_x="silx:./dm.h5?path=/coords_source_x" ; map_z="silx:./dm.h5?path=/coords_source_z" Mind the semicolon separator (;). """, "validator": generic_options_validator, "type": "advanced", }, "double_flatfield_enabled": { "default": "0", "help": "Whether to enable the 'double flat-field' filetering for correcting rings artefacts.", "validator": boolean_validator, "type": "optional", }, "dff_sigma": { "default": "", "help": "Enable high-pass filtering on double flatfield with this value of 'sigma'", "validator": optional_float_validator, "type": "advanced", }, "take_logarithm": { "default": "1", "help": "Whether to take logarithm after flat-field and phase retrieval.", "validator": boolean_validator, "type": "required", }, "log_min_clip": { "default": "1e-6", "help": "After division by the FF, and before the logarithm, the is clipped to this minimum. Enabled only if take_logarithm=1", "validator": float_validator, "type": "advanced", }, "log_max_clip": { "default": "10.0", "help": "After division by the FF, and before the logarithm, the is clipped to this maximum. Enabled only if take_logarithm=1", "validator": float_validator, "type": "advanced", }, "sino_normalization": { "default": "", "help": "Sinogram normalization method. Available methods are: chebyshev, subtraction, division, none. Default is none (no normalization)", "validator": sino_normalization_validator, "type": "advanced", }, "sino_normalization_file": { "default": "", "help": "Path to the file when sino_normalization is either 'subtraction' or 'division'. To specify the path within a HDF5 file, the syntax is /path/to/file?path=/entry/data", "validator": no_validator, "type": "advanced", }, "processes_file": { "default": "", "help": "Path to the file where some operations should be stored for later use. By default it is 'xxx_nabu_processes.h5'", "validator": optional_output_file_path_validator, "type": "advanced", }, "sino_rings_correction": { "default": "", "help": "Sinogram rings removal method. Default (empty) is None. Available are: None, munch, vo, mean-subtraction, mean-division. 
See also: sino_rings_options", "validator": sino_deringer_methods, "type": "optional", }, "sino_rings_options": { "default": "", "help": "Options for sinogram rings correction methods. The parameters are separated by commas and passed as 'name=value'. Mind the semicolon separator (;). The default options are the following:\n-For munch: sigma=1.0 ; levels=10 ; padding=False\n-For vo: snr=3.0; la_size=51; sm_size=21; dim=1\n-For mean-subtraction and mean-division: filter_cutoff=(0, 30)", "validator": generic_options_validator, "type": "advanced", }, "rotate_projections_center": { "default": "", "help": "Center of rotation when 'tilt_correction' is non-empty. By default the center of rotation is the middle of each radio, i.e ((Nx-1)/2.0, (Ny-1)/2.0).", "validator": optional_tuple_of_floats_validator, "type": "advanced", }, "tilt_correction": { "default": "", "help": "Detector tilt correction. Default (empty) means no tilt correction.\nThe following values can be provided for automatic tilt estimation, in this case, the projection images are rotated by the found tilt value:\n - A scalar value: tilt correction angle in degrees\n - 1d-correlation: auto-detect tilt with the 1D correlation method (fastest, but works best for small tilts)\n - fft-polar: auto-detect tilt with polar FFT method (slower, but works well on all ranges of tilts)", "validator": tilt_validator, "type": "advanced", }, "autotilt_options": { "default": "", "help": "Options for methods computing automatically the detector tilt. The parameters are separated by commas and passed as 'name=value', for example: low_pass=1; high_pass=20. Mind the semicolon separator (;). Use 'value' ('') for values that are strings", "validator": generic_options_validator, "type": "advanced", }, }, "phase": { "method": { "default": "none", "help": "Phase retrieval method. Available are: Paganin, CTF, None", "validator": phase_method_validator, "type": "required", }, "delta_beta": { "default": "100.0", "help": "Single-distance phase retrieval related parameters\n----------------------------\ndelta/beta ratio for the Paganin/CTF method", "validator": float_validator, "type": "required", }, "unsharp_coeff": { "default": "0", "help": "Unsharp mask strength. The unsharped image is equal to\n UnsharpedImage = (1 + coeff)*originalPaganinImage - coeff * ConvolvedImage. Setting this coefficient to zero means that no unsharp mask will be applied.", "validator": float_validator, "type": "optional", }, "unsharp_sigma": { "default": "0", "help": "Standard deviation of the Gaussian filter when applying an unsharp mask\nafter the phase filtering. Disabled if set to 0.", "validator": float_validator, "type": "optional", }, "unsharp_method": { "default": "gaussian", "help": "Which type of unsharp mask filter to use. Available values are gaussian, laplacian and imagej. Default is gaussian.", "validator": unsharp_method_validator, "type": "optional", }, "padding_type": { "default": "edge", "help": "Padding type for the filtering step in Paganin/CTF. Available are: mirror, edge, zeros", "validator": padding_mode_validator, "type": "advanced", }, "ctf_geometry": { "default": "z1_v=None; z1_h=None; detec_pixel_size=None; magnification=True", "help": "Geometric parameters for CTF phase retrieval. 
Length units are in meters.", "validator": optional_string_validator, "type": "optional", }, "ctf_advanced_params": { "default": "length_scale=1e-5; lim1=1e-5; lim2=0.2; normalize_by_mean=True", "help": "Advanced parameters for CTF phase retrieval.", "validator": optional_string_validator, "type": "advanced", }, }, "reconstruction": { "method": { "default": "FBP", "help": "Reconstruction method. Possible values: FBP, HBP, cone, MLEM, none. If value is 'none', no reconstruction will be done.", "validator": reconstruction_method_validator, "type": "required", }, "implementation": { "default": "", "help": "Reconstruction method implementation. The same method can have several implementations. Can be 'nabu', 'corrct', 'astra'", "validator": reconstruction_implementation_validator, "type": "advanced", }, "angles_file": { "default": "", "help": "In the case you want to override the angles found in the files metadata. The angles are in degree.", "validator": optional_file_location_validator, "type": "optional", }, "rotation_axis_position": { "default": "sliding-window", "help": "Rotation axis position. It can be a number or the name of an estimation method (empty value means the middle of the detector).\nThe following methods are available to find automatically the Center of Rotation (CoR):\n - centered : a fast and simple auto-CoR method. It only works when the CoR is not far from the middle of the detector. It does not work for half-tomography.\n - global : a slow but robust auto-CoR.\n - sliding-window : semi-automatically find the CoR with a sliding window. You have to specify on which side the CoR is (left, center, right). Please see the 'cor_options' parameter.\n - growing-window : automatically find the CoR with a sliding-and-growing window. You can tune the option with the parameter 'cor_options'.\n - sino-coarse-to-fine: Estimate CoR from sinogram. Only works for 360 degrees scans.\n - composite-coarse-to-fine: Estimate CoR from composite multi-angle images. Only works for 360 degrees scans.\n - fourier-angles: Estimate CoR from sino based on an angular correlation analysis. You can tune the option with the parameter 'cor_options'.\n - octave-accurate: Legacy from octave accurate COR estimation algorithm. It first estimates the COR with global fourier-based correlation, then refines this estimation with local correlation based on the variance of the difference patches. You can tune the option with the parameter 'cor_options'.\n - vo: Method from Nghia Vo, based on double-wedge in sinogram Fourier transform (needs algotom python package)", "validator": cor_validator, "type": "required", }, "cor_options": { "default": "side='from_file'", "help": "Options for methods finding automatically the rotation axis position. The parameters are separated by commas and passed as 'name=value'.\nFor example: low_pass=1; high_pass=20. Mind the semicolon separator (;) and the '' for string values that are strings.\nIf 'side' is set, it is expected to be either:\n - 'from_file' (to pick the value in the NX file.)\n - or an relative CoR position in pixels (if so, it overrides the value in the NX file), \n or any of 'left', 'center', 'right', 'all', 'near'.\n The default value for 'side' is 'from_file'.", "validator": generic_options_validator, "type": "advanced", }, "cor_slice": { "default": "", "help": "Which slice to use for estimating the Center of Rotation (CoR). 
This parameter can be an integer or 'top', 'middle', 'bottom'.\nIf provided, the CoR will be estimated from the correspondig sinogram, and 'cor_options' can contain the parameter 'subsampling'.", "validator": cor_slice_validator, "type": "advanced", }, "axis_correction_file": { "default": "", "help": "In the case where the axis position is specified for each angle", "validator": optional_values_file_validator, "type": "advanced", }, "translation_movements_file": { "default": "", "help": "A file where each line describes the horizontal and vertical translations of the sample (or detector). The order is 'horizontal, vertical'.\nIt can be created from a numpy array saved with 'numpy.savetxt'", "validator": optional_values_file_validator, "type": "advanced", }, "angle_offset": { "default": "0", "help": "Use this if you want to obtain a rotated reconstructed slice. The angle is in degrees.", "validator": float_validator, "type": "advanced", }, "fbp_filter_type": { "default": "ramlak", "help": "Filter type for FBP method. Available are: none, ramlak, shepp-logan, cosine, hamming, hann, tukey, lanczos, hilbert", "validator": fbp_filter_name_validator, "type": "advanced", }, "fbp_filter_cutoff": { "default": "1.", "help": "Cut-off frequency for Fourier filter used in FBP, in normalized units. Default is the Nyquist frequency 1.0", "validator": float_validator, "type": "advanced", }, "source_sample_dist": { "default": "", "help": "In cone-beam geometry, distance (in meters) between the X-ray source and the center of the sample. Default is infinity.", "validator": optional_float_validator, "type": "advanced", }, "sample_detector_dist": { "default": "", "help": "In cone-beam geometry, distance (in meters) between the center of the sample and the detector. Default is read from the input dataset.", "validator": optional_float_validator, "type": "advanced", }, "padding_type": { "default": "edges", "help": "Padding type for FBP. Available are: zeros, edges", "validator": padding_mode_validator, "type": "optional", # put "advanced" with default value "edges" ? }, "enable_halftomo": { "default": "auto", "help": "Whether to enable half-acquisition. Default is auto. You can enable/disable it manually by setting 1 or 0.", "validator": boolean_or_auto_validator, "type": "optional", }, "clip_outer_circle": { "default": "0", "help": "Whether to mask voxels falling outside of the reconstruction region", "validator": boolean_validator, "type": "optional", }, "outer_circle_value": { "default": "0", "help": "If 'clip_outer_circle' is enabled, value of the voxels falling outside of the reconstruction region.", "validator": float_validator, "type": "optional", }, "centered_axis": { "default": "1", "help": "If set to true, the reconstructed region is centered on the rotation axis, i.e the center of the image will be the rotation axis position.", "validator": boolean_validator, "type": "optional", }, "hbp_reduction_steps": { "default": "2", "help": "How many reduction steps will be taken. At least 2. A Higher number may increase speed but may also increase the interpolation errors", "validator": nonnegative_integer_validator, "type": "advanced", }, "hbp_legs": { "default": "4", "help": "Increasing this parameter help matching the GPU memory size for big slices. Reconstruction by fragments of the whole images. 
For very large slices it can be useful to increase this number to fit the memory", "validator": nonnegative_integer_validator, "type": "advanced", }, "start_x": { "default": "0", "help": "\nParameters for sub-volume reconstruction. Indices start at 0, and upper bounds are INCLUDED!\n----------------------------------------------------------------\n(x, y) are the dimension of a slice, and (z) is the 'vertical' axis\nBy default, all the volume is reconstructed slice by slice, along the axis 'z'.", "validator": nonnegative_integer_validator, "type": "optional", }, "end_x": { "default": "-1", "help": "", "validator": integer_validator, "type": "optional", }, "start_y": { "default": "0", "help": "", "validator": nonnegative_integer_validator, "type": "optional", }, "end_y": { "default": "-1", "help": "", "validator": integer_validator, "type": "optional", }, "start_z": { "default": "0", "help": "", "validator": slice_num_validator, "type": "optional", }, "end_z": { "default": "-1", "help": "", "validator": slice_num_validator, "type": "optional", }, "iterations": { "default": "200", "help": "\nParameters for iterative algorithms\n------------------------------------\nNumber of iterations", "validator": nonnegative_integer_validator, "type": "advanced", }, "optim_algorithm": { "default": "chambolle-pock", "help": "Optimization algorithm for iterative methods", "validator": optimization_algorithm_name_validator, "type": "unsupported", }, "weight_tv": { "default": "1.0e-2", "help": "Total Variation regularization parameter for iterative methods", "validator": float_validator, "type": "unsupported", }, "preconditioning_filter": { "default": "1", "help": "Whether to enable 'filter preconditioning' for iterative methods", "validator": boolean_validator, "type": "unsupported", }, "positivity_constraint": { "default": "1", "help": "Whether to enforce a positivity constraint in the reconstruction.", "validator": boolean_validator, "type": "unsupported", }, }, "output": { "location": { "default": "", "help": "Directory where the output reconstruction is stored.", "validator": optional_directory_location_validator, "type": "required", }, "file_prefix": { "default": "", "help": "File prefix. Optional, by default it is inferred from the scanned dataset.", "validator": optional_file_name_validator, "type": "optional", }, "file_format": { "default": "hdf5", "help": "Output file format. Available are: hdf5, tiff, jp2, edf, vol", "validator": output_file_format_validator, "type": "optional", }, "overwrite_results": { "default": "1", "help": "What to do in the case where the output file exists.\nBy default, the output data is never overwritten and the process is interrupted if the file already exists.\nSet this option to 1 if you want to overwrite the output files.", "validator": boolean_validator, "type": "required", }, "tiff_single_file": { "default": "0", "help": "Whether to create a single large tiff file for the reconstructed volume.", "validator": boolean_validator, "type": "advanced", }, "jpeg2000_compression_ratio": { "default": "", "help": "Compression ratio for Jpeg2000 output.", "validator": optional_positive_integer_validator, "type": "advanced", }, "float_clip_values": { "default": "", "help": "Lower and upper bounds to use when converting from float32 to int. 
Floating point values are clipped to these (min, max) values before being cast to integer.", "validator": optional_tuple_of_floats_validator, "type": "advanced", }, }, "postproc": { "output_histogram": { "default": "0", "help": "Whether to compute a histogram of the volume.", "validator": boolean_validator, "type": "optional", }, "histogram_bins": { "default": "1000000", "help": "Number of bins for the output histogram. Default is one million. ", "validator": nonnegative_integer_validator, "type": "advanced", }, }, "resources": { "method": { "default": "local", "help": "Computations distribution method. It can be:\n - local: run the computations on the local machine\n - slurm: run the computations through SLURM\n - preview: reconstruct the slices/volume as quickly as possible, possibly doing some binning.", "validator": distribution_method_validator, "type": "required", }, "gpus": { "default": "1", "help": "Number of GPUs to use.", "validator": nonnegative_integer_validator, "type": "advanced", }, "gpu_id": { "default": "", "help": "For method = local only. List of GPU IDs to use. This parameter overwrites 'gpus'.\nIf left blank, exactly one GPU will be used, and the best one will be picked.", "validator": list_of_int_validator, "type": "advanced", }, "cpu_workers": { "default": "0", "help": "Number of 'CPU workers' for each GPU worker. It is discouraged to set this number to more than one. A value of -1 means exactly one CPU worker.", "validator": integer_validator, "type": "unsupported", }, "memory_per_node": { "default": "90%", "help": "RAM memory per computing node, either in GB or in percent of the AVAILABLE (!= total) node memory.\nIf several workers share the same node, their combined memory usage will not exceed this number.", "validator": resources_validator, "type": "unsupported", }, "threads_per_node": { "default": "100%", "help": "Number of threads to allocate on each node, either a number or a percentage of the available threads", "validator": resources_validator, "type": "unsupported", }, "queue": { "default": "gpu", "help": "\nParameters exclusive to the 'slurm' distribution method\n------------------------------------------------------\nName of the SLURM partition ('queue'). Full list is obtained with 'scontrol show partition'", "validator": nonempty_string_validator, "type": "unsupported", }, "walltime": { "default": "01:00:00", "help": "Time limit for the SLURM resource allocation, in the format Hours:Minutes:Seconds", "validator": walltime_validator, "type": "unsupported", }, }, "pipeline": { "save_steps": { "default": "", "help": "Save intermediate results. This is a list of comma-separated processing steps, for ex: flatfield, phase, sinogram.\nEach step generates a HDF5 file in the form name_file_prefix.hdf5 (ex. 'sinogram_file_prefix.hdf5')", "validator": optional_string_validator, "type": "optional", }, "resume_from_step": { "default": "", "help": "Resume the processing from a previously saved processing step. The corresponding file must exist in the output directory.", "validator": optional_string_validator, "type": "optional", }, "steps_file": { "default": "", "help": "File where the intermediate processing steps are written. By default it is empty, and intermediate processing steps are written in the same directory as the reconstructions, with a file prefix, ex. sinogram_mydataset.hdf5.", "validator": optional_output_file_path_validator, "type": "advanced", }, "verbosity": { "default": "2", "help": "Level of verbosity of the processing. 
0 = terse, 3 = much information.", "validator": logging_validator, "type": "optional", }, }, # This section will be removed in the future (for now it is deprecated) "about": {}, } renamed_keys = { "marge": { "section": "phase", "new_name": "margin", "since": "2020.2.0", "message": "Option 'marge' has been renamed 'margin' in [phase]", }, "overwrite_results": { "section": "about", "new_name": "overwrite_results", "new_section": "output", "since": "2020.3.0", "message": "Option 'overwrite_results' was moved from section [about] to section [output]", }, "nabu_config_version": { "section": "about", "new_name": "", "new_section": "about", "since": "2020.3.1", "message": "Option 'nabu_config_version' was removed.", }, "nabu_version": { "section": "about", "new_name": "", "new_section": "about", "since": "2021.1.0", "message": "Option 'nabu_config' was removed.", }, "verbosity": { "section": "about", "new_name": "verbosity", "new_section": "pipeline", "since": "2021.1.0", "message": "Option 'verbosity' was moved from section [about] to section [pipeline]", }, "flatfield_enabled": { "section": "preproc", "new_name": "flatfield", "since": "2021.2.0", "message": "Option 'flatfield_enabled' has been renamed 'flatfield' in [preproc]", }, "rotate_projections": { "section": "preproc", "new_name": "", "since": "2024.2.0", "message": "Option 'rotate_projections' removed as it was duplicate of 'tilt_correction'. Please use the latter with a scalar value.", }, } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734442905.0 nabu-2024.2.1/nabu/pipeline/fullfield/processconfig.py0000644000175000017500000010771714730277631022232 0ustar00pierrepierreimport os import posixpath import numpy as np from silx.io import get_data from silx.io.url import DataUrl from ...utils import copy_dict_items, compare_dicts from ...io.utils import hdf5_entry_exists, get_h5_value from ...io.reader import import_h5_to_dict from ...resources.utils import extract_parameters, get_values_from_file from ...resources.nxflatfield import update_dataset_info_flats_darks from ...resources.utils import get_quantities_and_units from ..estimators import estimate_cor from ..processconfig import ProcessConfigBase from .nabu_config import nabu_config, renamed_keys from .dataset_validator import FullFieldDatasetValidator from nxtomo.nxobject.nxdetector import ImageKey class ProcessConfig(ProcessConfigBase): """ A ProcessConfig object has these main fields: - dataset_info: information about the current dataset - nabu_config: configuration from the user side - processing_steps/processing_options: configuration "ready-to use" for underlying classes It is built from the following steps. (1a) parse config: (conf_fname or conf_dict) --> "nabu_config" (1b) browse dataset: (nabu_config or existing dataset_info) --> dataset_info (2) update_dataset_info_with_user_config - Update flats/darks - CoR (value or estimation method) # no estimation yet - rotation angles - translations files - user sino normalization (eg. 
subtraction etc) (3) estimations - tilt - CoR (4) coupled validation (5) build processing steps (6) configure checkpoints (save/resume) """ default_nabu_config = nabu_config config_renamed_keys = renamed_keys _use_horizontal_translations = True _all_processing_steps = [ "read_chunk", "flatfield", "ccd_correction", "double_flatfield", "tilt_correction", "phase", "unsharp_mask", "take_log", "radios_movements", # radios are cropped after this step, if needed "sino_normalization", "sino_rings_correction", "reconstruction", "histogram", "save", ] def _update_dataset_info_with_user_config(self): """ Update the 'dataset_info' (DatasetAnalyzer class instance) data structure with options from user configuration. """ self.logger.debug("Updating dataset information with user configuration") if self.dataset_info.kind == "nx": update_dataset_info_flats_darks( self.dataset_info, self.nabu_config["preproc"]["flatfield"], output_dir=self.nabu_config["output"]["location"], darks_flats_dir=self.nabu_config["dataset"]["darks_flats_dir"], ) elif self.dataset_info.kind == "edf": self.dataset_info.flats = self.dataset_info.get_reduced_flats() self.dataset_info.darks = self.dataset_info.get_reduced_darks() self.rec_params = self.nabu_config["reconstruction"] subsampling_factor, subsampling_start = self.nabu_config["dataset"]["projections_subsampling"] self.subsampling_factor = subsampling_factor or 1 self.subsampling_start = subsampling_start or 0 self._update_dataset_with_user_overwrites() self._get_rotation_axis_position() self._update_rotation_angles() self._get_translation_file("reconstruction", "translation_movements_file", "translations") self._get_user_sino_normalization() def _update_dataset_with_user_overwrites(self): user_overwrites = self.nabu_config["dataset"]["overwrite_metadata"].strip() if user_overwrites in ("", None): return possible_overwrites = {"pixel_size": 1e6, "distance": 1.0, "energy": 1.0} try: overwrites = get_quantities_and_units(user_overwrites) except ValueError: msg = ( "Something wrong in config file in 'overwrite_metadata': could not get quantities/units from '%s'. 
Please check that separators are ';' and that units are provided (separated by a space)" % user_overwrites ) self.logger.fatal(msg) raise ValueError(msg) for quantity, conversion_factor in possible_overwrites.items(): user_value = overwrites.pop(quantity, None) if user_value is not None: self.logger.info("Overwriting %s = %s" % (quantity, user_value)) user_value *= conversion_factor setattr(self.dataset_info, quantity, user_value) def _get_translation_file(self, config_section, config_key, dataset_info_attr, last_dim=2): transl_file = self.nabu_config[config_section][config_key] if transl_file in (None, ""): return translations = None if transl_file is not None and "://" not in transl_file: try: translations = get_values_from_file( transl_file, shape=(self.n_angles(subsampling=False), last_dim), any_size=True ).astype(np.float32) translations = translations[self.subsampling_start :: self.subsampling_factor] except ValueError: print("Something wrong with translation_movements_file %s" % transl_file) raise else: try: translations = get_data(transl_file) except: print("Something wrong with translation_movements_file %s" % transl_file) raise setattr(self.dataset_info, dataset_info_attr, translations) if self._use_horizontal_translations and translations is not None: # Horizontal translations are handled by "axis_correction" in backprojector horizontal_translations = translations[:, 0] if np.max(np.abs(horizontal_translations)) > 1e-3: self.dataset_info.axis_correction = horizontal_translations def _get_rotation_axis_position(self): super()._get_rotation_axis_position() rec_params = self.nabu_config["reconstruction"] axis_correction_file = rec_params["axis_correction_file"] axis_correction = None if axis_correction_file is not None: try: axis_correction = get_values_from_file( axis_correction_file, n_values=self.n_angles(subsampling=False), any_size=True, ).astype(np.float32) axis_correction = axis_correction[self.subsampling_start :: self.subsampling_factor] except ValueError: print("Something wrong with axis correction file %s" % axis_correction_file) raise self.dataset_info.axis_correction = axis_correction def _update_rotation_angles(self): rec_params = self.nabu_config["reconstruction"] n_angles = self.dataset_info.n_angles angles_file = rec_params["angles_file"] if angles_file is not None: try: angles = get_values_from_file(angles_file, n_values=n_angles, any_size=True) angles = np.deg2rad(angles) except ValueError: self.logger.fatal("Something wrong with angle file %s" % angles_file) raise self.dataset_info.rotation_angles = angles elif self.dataset_info.rotation_angles is None: angles_range_txt = "[0, 180[ degrees" if rec_params["enable_halftomo"]: angles_range_txt = "[0, 360] degrees" angles = np.linspace(0, 2 * np.pi, n_angles, True) else: angles = np.linspace(0, np.pi, n_angles, False) self.logger.warning( "No information was found on rotation angles. 
Using default %s (half tomo is %s)" % (angles_range_txt, {0: "OFF", 1: "ON"}[int(self.do_halftomo)]) ) self.dataset_info.rotation_angles = angles def _get_cor(self): cor = self.nabu_config["reconstruction"]["rotation_axis_position"] if isinstance(cor, str): # auto-CoR cor = estimate_cor( cor, self.dataset_info, do_flatfield=(self.nabu_config["preproc"]["flatfield"] is not False), cor_options=self.nabu_config["reconstruction"]["cor_options"], logger=self.logger, ) self.logger.info("Estimated center of rotation: %.3f" % cor) self.dataset_info.axis_position = cor def _get_user_sino_normalization(self): self._sino_normalization_arr = None norm = self.nabu_config["preproc"]["sino_normalization"] if norm not in ["subtraction", "division"]: return norm_path = "silx://" + self.nabu_config["preproc"]["sino_normalization_file"].strip() url = DataUrl(norm_path) try: norm_array = get_data(url) self._sino_normalization_arr = norm_array.astype("f") except (ValueError, OSError) as exc: error_msg = "Could not load sino_normalization_file %s. The error was: %s" % (norm_path, str(exc)) self.logger.error(error_msg) raise ValueError(error_msg) @property def do_halftomo(self): """ Return True if the current dataset is to be reconstructed using 'half-acquisition' setting. """ enable_halftomo = self.nabu_config["reconstruction"]["enable_halftomo"] is_halftomo_dataset = self.dataset_info.is_halftomo if enable_halftomo == "auto": if is_halftomo_dataset is None: raise ValueError( "enable_halftomo was set to 'auto', but information on field of view was not found. Please set either 0 or 1 for enable_halftomo" ) return is_halftomo_dataset return enable_halftomo def _coupled_validation(self): self.logger.debug("Doing coupled validation") self._dataset_validator = FullFieldDatasetValidator(self.nabu_config, self.dataset_info) # Not so ideal to propagate fields like this for what in ["rec_params", "rec_region", "binning"]: setattr(self, what, getattr(self._dataset_validator, what)) # # Attributes that combine dataset information and user overwrites (eg. binning). # Must be accessed after __init__() is done # @property def binning_x(self): return getattr(self, "binning", (1, 1))[0] @property def binning_z(self): return getattr(self, "binning", (1, 1))[1] @property def subsampling(self): return getattr(self, "subsampling_factor", None) def radio_shape(self, binning=False): n_x, n_z = self.dataset_info.radio_dims if binning: n_z //= self.binning_z n_x //= self.binning_x return (n_z, n_x) def n_angles(self, subsampling=False): rot_angles = self.dataset_info.rotation_angles if subsampling: rot_angles = rot_angles[self.subsampling_start :: self.subsampling_factor] return len(rot_angles) def radios_shape(self, binning=False, subsampling=False): n_z, n_x = self.radio_shape(binning=binning) n_a = self.n_angles(subsampling=subsampling) return (n_a, n_z, n_x) def rotation_axis_position(self, binning=False): cor = self.dataset_info.axis_position # might be None (default to the middle of detector) if cor is None and self.do_halftomo: raise ValueError("No information on center of rotation, cannot use half-tomography reconstruction") if cor is not None and binning: # Backprojector uses middle of pixel for coordinate indices. # This means that the leftmost edge of the leftmost pixel has coordinate -0.5. 
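            # (equivalently, pixel index i covers the interval [i - 0.5, i + 0.5[,
            # so the physical center of a detector with n_x pixels lies at (n_x - 1)/2)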
# When using binning with a factor 'b', the CoR has to adapted as # cor_binned = (cor + 0.5)/b - 0.5 cor = (cor + 0.5) / self.binning_x - 0.5 return cor def sino_shape(self, binning=False, subsampling=False): """ Return the shape of a sinogram image. Parameter --------- binning: bool Whether to account for image binning subsampling: bool Whether to account for projections subsampling """ n_a, _, n_x = self.radios_shape(binning=binning, subsampling=subsampling) return (n_a, n_x) def sinos_shape(self, binning=False, subsampling=False): n_z, _ = self.radio_shape(binning=binning) return (n_z,) + self.sino_shape(binning=binning, subsampling=subsampling) def projs_indices(self, subsampling=False): step = 1 if subsampling: step = self.subsampling or 1 return sorted(self.dataset_info.projections.keys())[::step] def rotation_angles(self, subsampling=False): start = 0 step = 1 if subsampling: start = self.subsampling_start step = self.subsampling_factor return self.dataset_info.rotation_angles[start::step] @property def rec_roi(self): """ Returns the reconstruction region of interest (ROI), accounting for binning in both dimensions. """ rec_params = self.rec_params # accounts for binning x_s, x_e = rec_params["start_x"], rec_params["end_x"] y_s, y_e = rec_params["start_y"], rec_params["end_y"] # Upper bound (end_xy) is INCLUDED in nabu config, hence the +1 here return (x_s, x_e + 1, y_s, y_e + 1) @property def rec_shape(self): # Accounts for binning! return tuple(np.diff(self.rec_roi)[::-2]) @property def rec_delta_z(self): # Accounts for binning! z_s, z_e = self.rec_params["start_z"], self.rec_params["end_z"] # Upper bound (end_xy) is INCLUDED in nabu config, hence the +1 here return z_e + 1 - z_s def is_before_radios_cropping(self, step): """ Return true if a given processing step happens before radios cropping """ if step == "sinogram": return False if step not in self._all_processing_steps: raise ValueError("Unknown step: '%s'. Available are: %s" % (step, self._all_processing_steps)) # sino_normalization return self._all_processing_steps.index(step) <= self._all_processing_steps.index("radios_movements") # # Build processing steps # # TODO update behavior and remove this function once GroupedPipeline cuda backend is implemented def get_radios_rotation_mode(self): """ Determine whether projections are to be rotated, and if so, when they are to be rotated. Returns ------- method: str or None Rotation method: one of the values of `nabu.resources.params.radios_rotation_mode` """ tilt = self.dataset_info.detector_tilt phase_method = self.nabu_config["phase"]["method"] do_ctf = phase_method == "CTF" do_pag = phase_method == "paganin" do_unsharp = ( self.nabu_config["phase"]["unsharp_method"] is not None and self.nabu_config["phase"]["unsharp_coeff"] > 0 ) if tilt is None: return None if do_ctf: return "full" # TODO "chunked" rotation is done only when using a "processing margin" # For now the processing margin is enabled only if phase or unsharp is enabled. 
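        # (i.e. without phase retrieval or unsharp masking there is no vertical processing
        # margin available to absorb a partial-radio rotation, hence the "full" fallback below)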
# We can either # - Enable processing margin if rotating projections is needed (more complicated to implement) # - Always do "full" rotation (simpler to implement, at the expense of performances) if do_pag or do_unsharp: return "chunk" else: return "full" def _build_processing_steps(self): nabu_config = self.nabu_config dataset_info = self.dataset_info binning = (nabu_config["dataset"]["binning"], nabu_config["dataset"]["binning_z"]) tasks = [] options = {} # # Dataset / Get data # # First thing to do is to get the data (radios or sinograms) # For now data is assumed to be on disk (see issue #66). tasks.append("read_chunk") options["read_chunk"] = { "sub_region": None, "binning": binning, "dataset_subsampling": nabu_config["dataset"]["projections_subsampling"], } # # Flat-field # if nabu_config["preproc"]["flatfield"]: tasks.append("flatfield") options["flatfield"] = { # Data reader handles binning/subsampling by itself, # but FlatField needs "real" indices (after binning/subsampling) "projs_indices": self.projs_indices(subsampling=False), "binning": binning, "do_flat_distortion": nabu_config["preproc"]["flat_distortion_correction_enabled"], "flat_distortion_params": extract_parameters(nabu_config["preproc"]["flat_distortion_params"]), } normalize_srcurrent = nabu_config["preproc"]["normalize_srcurrent"] radios_srcurrent = None flats_srcurrent = None if normalize_srcurrent: if ( dataset_info.projections_srcurrent is None or dataset_info.flats_srcurrent is None or len(dataset_info.flats_srcurrent) == 0 ): self.logger.error("Cannot do SRCurrent normalization: missing flats and/or projections SRCurrent") normalize_srcurrent = False else: radios_srcurrent = dataset_info.projections_srcurrent flats_srcurrent = dataset_info.flats_srcurrent options["flatfield"].update( { "normalize_srcurrent": normalize_srcurrent, "radios_srcurrent": radios_srcurrent, "flats_srcurrent": flats_srcurrent, } ) if len(dataset_info.darks) > 1: self.logger.warning("Cannot do flat-field with more than one reduced dark. 
Taking the first one.") dataset_info.darks = dataset_info.darks[sorted(dataset_info.darks.keys())[0]] # # Spikes filter # if nabu_config["preproc"]["ccd_filter_enabled"]: tasks.append("ccd_correction") options["ccd_correction"] = { "type": "median_clip", # only one available for now "median_clip_thresh": nabu_config["preproc"]["ccd_filter_threshold"], } # # Double flat field # if nabu_config["preproc"]["double_flatfield_enabled"]: tasks.append("double_flatfield") options["double_flatfield"] = { "sigma": nabu_config["preproc"]["dff_sigma"], "processes_file": nabu_config["preproc"]["processes_file"], "log_min_clip": nabu_config["preproc"]["log_min_clip"], "log_max_clip": nabu_config["preproc"]["log_max_clip"], } # # Radios rotation (do it here if possible) # if self.get_radios_rotation_mode() == "chunk": tasks.append("tilt_correction") options["tilt_correction"] = { "angle": dataset_info.detector_tilt, "center": nabu_config["preproc"]["rotate_projections_center"], "mode": "chunk", } # # # Phase retrieval # if nabu_config["phase"]["method"] is not None: tasks.append("phase") options["phase"] = copy_dict_items(nabu_config["phase"], ["method", "delta_beta", "padding_type"]) options["phase"].update( { "energy_kev": dataset_info.energy, "distance_cm": dataset_info.distance * 1e2, "distance_m": dataset_info.distance, "pixel_size_microns": dataset_info.pixel_size, "pixel_size_m": dataset_info.pixel_size * 1e-6, } ) if binning != (1, 1): options["phase"]["delta_beta"] /= binning[0] * binning[1] if options["phase"]["method"] == "CTF": self._get_ctf_parameters(options["phase"]) # # Unsharp # if ( nabu_config["phase"]["unsharp_method"] is not None and nabu_config["phase"]["unsharp_coeff"] > 0 and nabu_config["phase"]["unsharp_sigma"] > 0 ): tasks.append("unsharp_mask") options["unsharp_mask"] = copy_dict_items( nabu_config["phase"], ["unsharp_coeff", "unsharp_sigma", "unsharp_method"] ) # # -logarithm # if nabu_config["preproc"]["take_logarithm"]: tasks.append("take_log") options["take_log"] = copy_dict_items(nabu_config["preproc"], ["log_min_clip", "log_max_clip"]) # # Radios rotation (do it here if mode=="full") # if self.get_radios_rotation_mode() == "full": tasks.append("tilt_correction") options["tilt_correction"] = { "angle": dataset_info.detector_tilt, "center": nabu_config["preproc"]["rotate_projections_center"], "mode": "full", } # # Translation movements # translations = dataset_info.translations if translations is not None: if np.max(np.abs(translations[:, 1])) < 1e-5: self.logger.warning("No vertical translation greater than 1e-5 - disabling vertical shifts") # horizontal movements are handled in backprojector else: tasks.append("radios_movements") options["radios_movements"] = {"translation_movements": dataset_info.translations} # # Sinogram normalization (before half-tomo) # if nabu_config["preproc"]["sino_normalization"] is not None: tasks.append("sino_normalization") options["sino_normalization"] = { "method": nabu_config["preproc"]["sino_normalization"], "normalization_array": self._sino_normalization_arr, } # # Sinogram-based rings artefacts removal # if nabu_config["preproc"]["sino_rings_correction"]: tasks.append("sino_rings_correction") options["sino_rings_correction"] = { "method": nabu_config["preproc"]["sino_rings_correction"], "user_options": nabu_config["preproc"]["sino_rings_options"], } # # Reconstruction # if nabu_config["reconstruction"]["method"] is not None: tasks.append("reconstruction") # Iterative is not supported through configuration file for now. 
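            # The keys below are read from self.rec_params, i.e. the "reconstruction" section
            # after coupled validation (start/end bounds already account for binning).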
options["reconstruction"] = copy_dict_items( self.rec_params, [ "method", "implementation", "fbp_filter_type", "fbp_filter_cutoff", "padding_type", "start_x", "end_x", "start_y", "end_y", "start_z", "end_z", "centered_axis", "clip_outer_circle", "outer_circle_value", "source_sample_dist", "sample_detector_dist", "hbp_legs", "hbp_reduction_steps", ], ) rec_options = options["reconstruction"] rec_options["rotation_axis_position"] = self.rotation_axis_position(binning=True) rec_options["enable_halftomo"] = self.do_halftomo rec_options["axis_correction"] = dataset_info.axis_correction if dataset_info.axis_correction is not None: rec_options["axis_correction"] = rec_options["axis_correction"] rec_options["angles"] = np.array(self.rotation_angles(subsampling=True)) rec_options["angles"] += np.deg2rad(nabu_config["reconstruction"]["angle_offset"]) voxel_size = dataset_info.pixel_size * 1e-4 rec_options["voxel_size_cm"] = ( voxel_size, voxel_size, voxel_size, ) # pix size is in microns in dataset_info rec_options["iterations"] = nabu_config["reconstruction"]["iterations"] # x/y/z position information def get_mean_pos(position_array): if position_array is None: return None else: position_array = np.array(position_array) # filter only projections. Avoid getting noise position_array = position_array[ np.asarray(dataset_info.dataset_scanner.image_key_control) == ImageKey.PROJECTION.value ] return float(np.mean(position_array)) mean_positions_xyz = ( get_mean_pos(dataset_info.dataset_scanner.z_translation), get_mean_pos(dataset_info.dataset_scanner.y_translation), get_mean_pos(dataset_info.dataset_scanner.x_translation), ) if all([m is not None for m in mean_positions_xyz]): rec_options["position"] = mean_positions_xyz if rec_options["method"] == "cone" and rec_options["sample_detector_dist"] is None: rec_options["sample_detector_dist"] = self.dataset_info.distance # was checked to be not None earlier # New key rec_options["cor_estimated_auto"] = isinstance(nabu_config["reconstruction"]["rotation_axis_position"], str) # # Histogram # if nabu_config["postproc"]["output_histogram"]: tasks.append("histogram") options["histogram"] = copy_dict_items(nabu_config["postproc"], ["histogram_bins"]) # # Save # if nabu_config["output"]["location"] is not None: tasks.append("save") options["save"] = copy_dict_items(nabu_config["output"], list(nabu_config["output"].keys())) options["save"]["overwrite"] = nabu_config["output"]["overwrite_results"] self.processing_steps = tasks self.processing_options = options if set(self.processing_steps) != set(self.processing_options.keys()): raise ValueError("Something wrong with process_config: options do not correspond to steps") # Add check if set(self.processing_steps) != set(self.processing_options.keys()): raise ValueError("Something wrong when building processing steps") def _get_ctf_parameters(self, phase_options): dataset_info = self.dataset_info user_phase_options = self.nabu_config["phase"] ctf_geom = extract_parameters(user_phase_options["ctf_geometry"]) ctf_advanced_params = extract_parameters(user_phase_options["ctf_advanced_params"]) # z1_vh z1_v = ctf_geom["z1_v"] z1_h = ctf_geom["z1_h"] z1_vh = None if z1_h is None and z1_v is None: # parallel beam z1_vh = None elif (z1_v is None) ^ (z1_h is None): # only one is provided: source-sample distance z1_vh = z1_v or z1_h if z1_h is not None and z1_v is not None: # distance of the vertically focused source (horizontal line) and the horizontaly focused source (vertical line) # for KB mirrors z1_vh = (z1_v, z1_h) # 
pix_size_det pix_size_det = ctf_geom["detec_pixel_size"] or dataset_info.pixel_size * 1e-6 # wavelength wavelength = 1.23984199e-9 / dataset_info.energy phase_options["ctf_geo_pars"] = { "z1_vh": z1_vh, "z2": phase_options["distance_m"], "pix_size_det": pix_size_det, "wavelength": wavelength, "magnification": bool(ctf_geom["magnification"]), "length_scale": ctf_advanced_params["length_scale"], } phase_options["ctf_lim1"] = ctf_advanced_params["lim1"] phase_options["ctf_lim2"] = ctf_advanced_params["lim2"] phase_options["ctf_normalize_by_mean"] = ctf_advanced_params["normalize_by_mean"] def _configure_save_steps(self, steps_to_save=None): self.steps_to_save = [] self.dump_sinogram = False if steps_to_save is None: steps_to_save = self.nabu_config["pipeline"]["save_steps"] if steps_to_save in (None, ""): return steps_to_save = [s.strip() for s in steps_to_save.split(",")] for step in self.processing_steps: step = step.strip() if step in steps_to_save: self.processing_options[step]["save"] = True self.processing_options[step]["save_steps_file"] = self.get_save_steps_file(step_name=step) # "sinogram" is a special keyword, not explicitly in the processing steps if "sinogram" in steps_to_save: self.dump_sinogram = True self.dump_sinogram_file = self.get_save_steps_file(step_name="sinogram") self.steps_to_save = steps_to_save def _get_dump_file_and_h5_path(self): resume_from = self.resume_from_step process_file = self.get_save_steps_file(step_name=resume_from) if not os.path.isfile(process_file): self.logger.error("Cannot resume processing from step '%s': no such file %s" % (resume_from, process_file)) return None, None h5_entry = self.dataset_info.hdf5_entry or "entry" process_h5_path = posixpath.join(h5_entry, resume_from, "results/data") if not hdf5_entry_exists(process_file, process_h5_path): self.logger.error("Could not find data in %s in file %s" % (process_h5_path, process_file)) process_h5_path = None return process_file, process_h5_path def _configure_resume(self, resume_from=None): self.resume_from_step = None if resume_from is None: resume_from = self.nabu_config["pipeline"]["resume_from_step"] if resume_from in (None, ""): return resume_from = resume_from.strip(" ,;") self.resume_from_step = resume_from processing_steps = self.processing_steps # special case: resume from sinogram if resume_from == "sinogram": # disable up to 'reconstruction', not included if "sino_rings_correction" in processing_steps: # Sinogram destriping is done before building the half tomo sino. # Not sure if this is needed (i.e can we do before building the extended sino ?) 
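                # In practice, when resuming from "sinogram", the rings correction (and every
                # step after it) is re-applied on the sinograms loaded from disk.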
# TODO find a more elegant way idx = processing_steps.index("sino_rings_correction") else: idx = processing_steps.index("reconstruction") # elif resume_from in processing_steps: idx = processing_steps.index(resume_from) + 1 # disable up to resume_from, included else: msg = "Cannot resume processing from step '%s': no such step in the current configuration" % resume_from self.logger.error(msg) self.resume_from_step = None return # Get corresponding file and h5 path process_file, process_h5_path = self._get_dump_file_and_h5_path() if process_file is None or process_h5_path is None: self.resume_from_step = None return dump_info = self._check_dump_file(process_file, raise_on_error=False) if dump_info is None: self.logger.error("Cannot resume from step %s: cannot use file %s" % (resume_from, process_file)) self.resume_from_step = None return dump_start_z, dump_end_z = dump_info # Disable steps steps_to_disable = processing_steps[1:idx] self.logger.debug("Disabling steps %s" % str(steps_to_disable)) for step_name in steps_to_disable: processing_steps.remove(step_name) self.processing_options.pop(step_name) # Update configuration self.logger.info("Processing will be resumed from step '%s' using file %s" % (resume_from, process_file)) self._old_read_chunk = self.processing_options["read_chunk"] self.processing_options["read_chunk"] = { "process_file": process_file, "process_h5_path": process_h5_path, "step_name": resume_from, "dump_start_z": dump_start_z, "dump_end_z": dump_end_z, } # Dont dump a step if we resume from this step if resume_from in self.steps_to_save: self.logger.warning( "Processing is resumed from step '%s'. This step won't be dumped to a file" % resume_from ) self.steps_to_save.remove(resume_from) if resume_from == "sinogram": self.dump_sinogram = False else: if resume_from in self.processing_options: # should not happen self.processing_options[resume_from].pop("save") def _check_dump_file(self, process_file, raise_on_error=False): """ Return (start_z, end_z) on success Return None on failure """ # Ensure data in the file correspond to what is currently asked if self.resume_from_step is None: return None # Check dataset shape/start_z/end_z rec_cfg_h5_path = posixpath.join( self.dataset_info.hdf5_entry or "entry", self.resume_from_step, "configuration/processing_options/reconstruction", ) dump_start_z = get_h5_value(process_file, posixpath.join(rec_cfg_h5_path, "start_z")) dump_end_z = get_h5_value(process_file, posixpath.join(rec_cfg_h5_path, "end_z")) if dump_end_z < 0: dump_end_z += self.radio_shape(binning=False)[0] start_z, end_z = ( self.processing_options["reconstruction"]["start_z"], self.processing_options["reconstruction"]["end_z"], ) if not (dump_start_z <= start_z and end_z <= dump_end_z): msg = ( "File %s was built with start_z=%d, end_z=%d but current configuration asks for start_z=%d, end_z=%d" % (process_file, dump_start_z, dump_end_z, start_z, end_z) ) if not raise_on_error: self.logger.error(msg) return None self.logger.fatal(msg) raise ValueError(msg) # Check parameters other than reconstruction filedump_nabu_config = import_h5_to_dict( process_file, posixpath.join(self.dataset_info.hdf5_entry or "entry", self.resume_from_step, "configuration/nabu_config"), ) sections_to_ignore = ["reconstruction", "output", "pipeline"] for section in sections_to_ignore: filedump_nabu_config[section] = self.nabu_config[section] # special case of the "save_steps process" # filedump_nabu_config["pipeline"]["save_steps"] = self.nabu_config["pipeline"]["save_steps"] diff = 
compare_dicts(filedump_nabu_config, self.nabu_config) if diff is not None: msg = "Nabu configuration in file %s differ from the current one: %s" % (process_file, diff) if not raise_on_error: self.logger.error(msg) return None self.logger.fatal(msg) raise ValueError(msg) # return (dump_start_z, dump_end_z) def get_save_steps_file(self, step_name=None): if self.nabu_config["pipeline"]["steps_file"] not in (None, ""): return self.nabu_config["pipeline"]["steps_file"] nabu_save_options = self.nabu_config["output"] output_dir = nabu_save_options["location"] file_prefix = step_name + "_" + nabu_save_options["file_prefix"] fname = os.path.join(output_dir, file_prefix) + ".hdf5" return fname ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734442905.0 nabu-2024.2.1/nabu/pipeline/fullfield/reconstruction.py0000644000175000017500000011105614730277631022436 0ustar00pierrepierrefrom os import environ from os.path import join, isfile, basename, dirname from math import ceil import gc from psutil import virtual_memory from silx.io import get_data from silx.io.url import DataUrl from tomoscan.esrf.volume.singleframebase import VolumeSingleFrameBase from ... import version as nabu_version from ...utils import check_supported, subdivide_into_overlapping_segment from ...resources.logger import LoggerOrPrint from ...resources.utils import is_hdf5_extension from ...io.writer import merge_hdf5_files, NXProcessWriter from ...cuda.utils import collect_cuda_gpus, __has_pycuda__ from ...preproc.phase import compute_paganin_margin from ...processing.histogram import PartialHistogram, add_last_bin, hist_as_2Darray from .chunked import ChunkedPipeline from .computations import estimate_max_chunk_size if __has_pycuda__: from .chunked_cuda import CudaChunkedPipeline def variable_idxlen_sort(fname): return int(fname.split("_")[-1].split(".")[0]) class FullFieldReconstructor: """ A reconstructor spawns and manages Pipeline objects, depending on the current chunk/group size. """ _available_backends = ["cuda", "numpy"] _process_name = "reconstruction" default_advanced_options = { "gpu_mem_fraction": 0.9, "cpu_mem_fraction": 0.9, "chunk_size": None, "margin": None, "force_grouped_mode": False, } def __init__(self, process_config, logger=None, backend="cuda", extra_options=None, cuda_options=None): """ Initialize a Reconstructor object. This class is used for managing pipelines. Parameters ---------- process_config: ProcessConfig object Data structure with process configuration logger: Logger, optional logging object backend: str, optional Which backend to use. Available are: "cuda", "numpy". extra_options: dict, optional Dictionary with advanced options. Please see 'Other parameters' below cuda_options: dict, optional Dictionary with cuda options passed to `nabu.cuda.processing.CudaProcessing` Other parameters ----------------- Advanced options can be passed in the 'extra_options' dictionary. 
These can be: - cpu_mem_fraction: 0.9, - gpu_mem_fraction: 0.9, - chunk_size: None, - margin: None, - force_grouped_mode: False """ self.logger = LoggerOrPrint(logger) self.process_config = process_config self._set_extra_options(extra_options) self._get_reconstruction_range() self._get_resources() self._get_backend(backend, cuda_options) self._compute_margin() self._compute_max_chunk_size() self._get_pipeline_mode() self._build_tasks() self._do_histograms = self.process_config.nabu_config["postproc"]["output_histogram"] self._reconstruction_output_format_is_hdf5 = is_hdf5_extension( self.process_config.nabu_config["output"]["file_format"] ) self._histogram_merged = False self.pipeline = None self._histogram_merged = False # # Initialization # def _set_extra_options(self, extra_options): self.extra_options = self.default_advanced_options.copy() self.extra_options.update(extra_options or {}) self.gpu_mem_fraction = self.extra_options["gpu_mem_fraction"] self.cpu_mem_fraction = self.extra_options["cpu_mem_fraction"] def _get_reconstruction_range(self): rec_region = self.process_config.rec_region # without binning self.z_min = rec_region["start_z"] # In the user configuration, the upper bound is included: [start_z, end_z]. # In python syntax, the upper bound is not: [start_z, end_z[ self.z_max = rec_region["end_z"] + 1 self.delta_z = self.z_max - self.z_min # Cache some volume info self.n_angles = self.process_config.n_angles(subsampling=False) self.n_z, self.n_x = self.process_config.radio_shape(binning=False) def _get_resources(self): self.resources = {} self._get_gpu() self._get_memory() def _get_memory(self): vm = virtual_memory() self.resources["mem_avail_GB"] = vm.available / 1e9 # Account for other memory constraints. There might be a better way slurm_mem_constraint_MB = int(environ.get("SLURM_MEM_PER_NODE", 0)) if slurm_mem_constraint_MB > 0: self.resources["mem_avail_GB"] = slurm_mem_constraint_MB / 1e3 # self.cpu_mem = self.resources["mem_avail_GB"] * self.cpu_mem_fraction def _get_gpu(self): avail_gpus = collect_cuda_gpus() or {} self.resources["gpus"] = avail_gpus if len(avail_gpus) == 0: return # pick first GPU by default. TODO: handle user's nabu_config["resources"]["gpu_id"] self.resources["gpu_id"] = self._gpu_id = list(avail_gpus.keys())[0] def _get_backend(self, backend, cuda_options): self._pipeline_cls = ChunkedPipeline check_supported(backend, self._available_backends, "backend") if backend == "cuda": self.cuda_options = cuda_options if len(self.resources["gpus"]) == 0: # Not sure if an error should be raised in this case self.logger.error("Could not find any cuda device. Falling back on numpy/CPU backend.") backend = "numpy" else: self.gpu_mem = self.resources["gpus"][self._gpu_id]["memory_GB"] * self.gpu_mem_fraction if backend == "cuda": if not (__has_pycuda__): raise RuntimeError("pycuda not avilable") self._pipeline_cls = CudaChunkedPipeline # pylint: disable=E0606 self.backend = backend def _compute_max_chunk_size(self): """ Compute the maximum number of (partial) radios that can be processed in memory. Ideally, the processing is done by reading N lines of all the projections. This function estimates max_chunk_size = N_max, the maximum number of lines that can be read in all the projections while still fitting the memory. On the other hand, if a "processing margin" is needed (eg. phase retrieval), then we need to read at least N_min = 2 * margin_v + n_slices lines of each image. 
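        For example, with a vertical margin of 100 pixels and 50 slices to reconstruct,
        N_min = 2 * 100 + 50 = 250 rows have to be read from each projection.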
For large datasets, we have N_min > max_chunk_size, so we can't read lines from all the projections. """ user_chunk_size = self.extra_options["chunk_size"] if user_chunk_size is not None: self.chunk_size = user_chunk_size else: self.cpu_max_chunk_size = estimate_max_chunk_size( self.cpu_mem, self.process_config, pipeline_part="all", step=5 ) self.chunk_size = self.cpu_max_chunk_size if self.backend == "cuda": self.gpu_max_chunk_size = estimate_max_chunk_size( self.gpu_mem, self.process_config, pipeline_part="all", step=5, ) self.chunk_size = min(self.gpu_max_chunk_size, self.cpu_max_chunk_size) self.chunk_size = min(self.chunk_size, self.n_z) def _compute_max_group_size(self): """ Compute the maximum number of (partial) images that can be processed in memory """ # # "Group size" (i.e, how many radios can be processed in one pass for the first part of the pipeline) # user_group_size = self.extra_options.get("chunk_size", None) if user_group_size is not None: self.group_size = user_group_size else: self.cpu_max_group_size = estimate_max_chunk_size( self.cpu_mem, self.process_config, pipeline_part="radios", n_rows=min(2 * self._margin_v + self.delta_z, self.n_z), step=10, ) self.group_size = self.cpu_max_group_size if self.backend == "cuda": self.gpu_max_group_size = estimate_max_chunk_size( self.gpu_mem, self.process_config, pipeline_part="radios", n_rows=min(2 * self._margin_v + self.delta_z, self.n_z), step=10, ) self.group_size = min(self.gpu_max_group_size, self.cpu_max_group_size) self.group_size = min(self.group_size, self.n_angles) # # "sinos chunk size" (i.e, how many sinograms can be processed in one pass for the second part of the pipeline) # self.cpu_max_chunk_size_sinos = estimate_max_chunk_size( self.cpu_mem, self.process_config, pipeline_part="sinos", step=10, ) if self.backend == "cuda": self.gpu_max_chunk_size_sinos = estimate_max_chunk_size( self.gpu_mem, self.process_config, pipeline_part="sinos", step=5, ) self.chunk_size_sinos = min(self.gpu_max_chunk_size_sinos, self.cpu_max_chunk_size_sinos) self.chunk_size_sinos = min(self.chunk_size_sinos, self.delta_z) def _get_pipeline_mode(self): # "Pipeline mode" means we either process data chunks of type [:, delta_z, :] or [delta_theta, :, :]. # The first case is better in terms of performances and should be preferred. # However, in some cases, it's not possible to use it (eg. high "delta_z" because of margin) chunk_size_for_one_slice = 1 + 2 * self._margin_v # TODO ignore margin when resuming from sinogram ? chunk_is_too_small = False if chunk_size_for_one_slice > self.chunk_size: msg = str( "Margin is %d, so we need to process at least %d detector rows. 
However, the available memory enables to process only %d rows at once" % (self._margin_v, chunk_size_for_one_slice, self.chunk_size) ) chunk_is_too_small = True if self._margin_v > self.chunk_size // 3: n_slices = max(1, self.chunk_size - (2 * self._margin_v)) n_stages = ceil(self.delta_z / n_slices) if n_stages > 1: msg = str("Margin (%d) is too big for chunk size (%d)" % (self._margin_v, self.chunk_size)) chunk_is_too_small = True force_grouped_mode = self.extra_options.get("force_grouped_mode", False) if force_grouped_mode: msg = "Forcing grouped mode" if self.process_config.processing_options.get("phase", {}).get("method", None) == "CTF": force_grouped_mode = True msg = "CTF phase retrieval needs to process full radios" if (self.process_config.dataset_info.detector_tilt or 0) > 15: force_grouped_mode = True msg = "Radios rotation with a large angle needs to process full radios" if self.process_config.resume_from_step == "sinogram" and force_grouped_mode: self.logger.warning("Cannot use grouped-radios processing when resuming from sinogram") force_grouped_mode = False if not (chunk_is_too_small or force_grouped_mode): # Default case (preferred) self._pipeline_mode = "chunked" self.chunk_shape = (self.n_angles, self.chunk_size, self.n_x) else: # Fall-back mode (slower) self.logger.warning(msg) # pylint: disable=E0606 self._pipeline_mode = "grouped" self._compute_max_group_size() self.chunk_shape = (self.group_size, self.delta_z, self.n_x) self.logger.info("Using 'grouped' pipeline mode") # # "Margin" # def _compute_margin(self): user_margin = self.extra_options.get("margin", None) if self.process_config.resume_from_step == "sinogram": self.logger.debug("Margin not needed when resuming from sinogram") margin_v, margin_h = 0, 0 elif user_margin is not None and user_margin > 0: margin_v, margin_h = user_margin, user_margin self.logger.info("Using user-defined margin: %d" % user_margin) else: unsharp_margin = self._compute_unsharp_margin() phase_margin = self._compute_phase_margin() translations_margin = self._compute_translations_margin() cone_margin = self._compute_cone_overlap() rot_margin = self._compute_rotation_margin() # TODO radios rotation/movements margin_v = max(unsharp_margin[0], phase_margin[0], translations_margin[0], cone_margin[0], rot_margin[0]) margin_h = max(unsharp_margin[1], phase_margin[1], translations_margin[1], cone_margin[1], rot_margin[1]) if margin_v > 0: self.logger.info("Estimated margin: %d pixels" % margin_v) margin_v = min(margin_v, self.n_z) margin_h = min(margin_h, self.n_x) self._margin = margin_v, margin_h self._margin_v = margin_v def _compute_unsharp_margin(self): if "unsharp_mask" not in self.process_config.processing_steps: return (0, 0) opts = self.process_config.processing_options["unsharp_mask"] sigma = opts["unsharp_sigma"] # nabu uses cutoff = 4 cutoff = 4 gaussian_kernel_size = int(ceil(2 * cutoff * sigma + 1)) self.logger.debug("Unsharp mask margin: %d pixels" % gaussian_kernel_size) return (gaussian_kernel_size, gaussian_kernel_size) def _compute_phase_margin(self): phase_options = self.process_config.processing_options.get("phase", None) if phase_options is None: margin_v, margin_h = (0, 0) elif phase_options["method"] == "paganin": radio_shape = self.process_config.dataset_info.radio_dims[::-1] margin_v, margin_h = compute_paganin_margin( radio_shape, distance=phase_options["distance_m"], energy=phase_options["energy_kev"], delta_beta=phase_options["delta_beta"], pixel_size=phase_options["pixel_size_m"], 
padding=phase_options["padding_type"], ) elif phase_options["method"] == "CTF": # The whole projection has to be processed! margin_v = ceil( (self.n_z - self.delta_z) / 2 ) # not so elegant. Use a dedicated flag eg. _process_whole_image ? margin_h = 0 # unused for now else: margin_v, margin_h = (0, 0) return (margin_v, margin_h) def _compute_translations_margin(self): translations = self.process_config.dataset_info.translations if translations is None: return (0, 0) margin_h_v = [] for i in range(2): transl = translations[:, i] margin_h_v.append(ceil(max([transl.max(), (-transl).max()]))) self.logger.debug("Maximum vertical displacement: %d pixels" % margin_h_v[1]) return tuple(margin_h_v[::-1]) def _compute_cone_overlap(self): rec_cfg = self.process_config.processing_options.get("reconstruction", {}) rec_method = rec_cfg.get("method", None) if rec_method != "cone": return (0, 0) d1 = rec_cfg["source_sample_dist"] d2 = rec_cfg["sample_detector_dist"] n_z, _ = self.process_config.radio_shape(binning=True) delta_z = self.process_config.rec_delta_z # accounts_for_binning overlap = ceil(delta_z * d2 / (d1 + d2)) # sqrt(2) missing ? max_overlap = ceil(n_z * d2 / (d1 + d2)) # sqrt(2) missing ? return (max_overlap, 0) def _compute_rotation_margin(self): if "tilt_correction" in self.process_config.processing_steps: # Partial radios rotation yields too much error in single-slice mode # Forcing a big margin circumvents the problem # This is likely to trigger the 'grouped mode', but perhaps grouped mode should always be used when rotating radios nz, nx = self.process_config.radio_shape(binning=True) return nz // 3, nx // 3 else: return 0, 0 def _ensure_good_chunk_size_and_margin(self): """ Check that "chunk_size" and "margin" (if any) are a multiple of binning factor. 
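        For instance, with binning_z = 2, a chunk_size of 101 is truncated to 100
        and a margin of 3 is rounded up to 4.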
See nabu:!208 After that, all subregion lengths of _build_tasks_chunked() should be multiple of the binning factor, because "delta_z" itself was checked to be a multiple in DatasetValidator._handle_binning() """ bin_z = self.process_config.binning_z if bin_z <= 1: return self.chunk_size -= self.chunk_size % bin_z if self._margin_v > 0 and (self._margin_v % bin_z) > 0: self._margin_v += bin_z - (self._margin_v % bin_z) # i.e margin = ((margin % bin_z) + 1) * bin_z # # Tasks management # def _modify_processconfig_stage_1(self): # Modify the "process_config" object to dump sinograms proc = self.process_config self._old_steps_to_save = proc.steps_to_save.copy() if "sinogram" in proc.steps_to_save: return proc._configure_save_steps(self._old_steps_to_save + ["sinogram"]) def _undo_modify_processconfig_stage_1(self): self.process_config.steps_to_save = self._old_steps_to_save if "sinogram" not in self._old_steps_to_save: self.process_config.dump_sinogram = False def _modify_processconfig_stage_2(self): # Modify the "process_config" object to resume from sinograms proc = self.process_config self._old_resume_from = proc.resume_from_step self._old_proc_steps = proc.processing_steps.copy() self._old_proc_options = proc.processing_options.copy() proc._configure_resume(resume_from="sinogram") def _undo_modify_processconfig_stage_2(self): self.process_config.resume_from_step = self._old_resume_from self.process_config.processing_steps = self._old_proc_steps self.process_config.processing_options = self._old_proc_options def _build_tasks_grouped(self): tasks = [] segments = subdivide_into_overlapping_segment(self.n_angles, self.group_size, 0) for segment in segments: _, start_angle, end_angle, _ = segment z_min = max(self.z_min - self._margin_v, 0) z_max = min(self.z_max + self._margin_v, self.n_z) sub_region = ((start_angle, end_angle), (z_min, z_max), (0, self.chunk_shape[-1])) tasks.append({"sub_region": sub_region, "margin": (self.z_min - z_min, z_max - self.z_max)}) self.tasks = tasks # Build tasks for stage 2 (sinogram processing + reconstruction) segments = subdivide_into_overlapping_segment(self.delta_z, self.chunk_size_sinos, 0) self._sino_tasks = [] for segment in segments: _, start_z, end_z, _ = segment z_min = self.z_min + start_z z_max = min(self.z_min + end_z, self.n_z) sub_region = ((0, self.n_angles), (z_min, z_max), (0, self.chunk_shape[-1])) self._sino_tasks.append({"sub_region": sub_region, "margin": None}) def _build_tasks_chunked(self): margin_v = self._margin_v if self.chunk_size >= self.delta_z and self.z_min == 0 and self.z_max == self.n_z: # In this case we can do everything in a single stage, without margin self.tasks = [ { "sub_region": ((0, self.n_angles), (self.z_min, self.z_max), (0, self.chunk_shape[-1])), "margin": None, } ] return if self.chunk_size - (2 * margin_v) >= self.delta_z: # In this case we can do everything in a single stage n_slices = self.delta_z (margin_up, margin_down) = (min(margin_v, self.z_min), min(margin_v, self.n_z - self.z_max)) self.tasks = [ { "sub_region": ( (0, self.n_angles), (self.z_min - margin_up, self.z_max + margin_down), (0, self.chunk_shape[-1]), ), "margin": ((margin_up, margin_down), (0, 0)), } ] return # In this case there are at least two stages n_slices = self.chunk_size - (2 * margin_v) n_stages = ceil(self.delta_z / n_slices) self.tasks = [] curr_z_min = self.z_min curr_z_max = self.z_min + n_slices for i in range(n_stages): margin_up = min(margin_v, curr_z_min) margin_down = min(margin_v, max(self.n_z - curr_z_max, 0)) if 
curr_z_max + margin_down >= self.z_max: curr_z_max -= curr_z_max - (self.z_max + 0) margin_down = min(margin_v, max(self.n_z - 1 - curr_z_max, 0)) self.tasks.append( { "sub_region": ( (0, self.n_angles), (curr_z_min - margin_up, curr_z_max + margin_down), (0, self.chunk_shape[-1]), ), "margin": ((margin_up, margin_down), (0, 0)), } ) if curr_z_max == self.z_max: # No need for further tasks break curr_z_min += n_slices curr_z_max += n_slices def _build_tasks(self): if self._pipeline_mode == "grouped": self._build_tasks_grouped() else: self._ensure_good_chunk_size_and_margin() self._build_tasks_chunked() def _print_tasks_chunked(self): for task in self.tasks: margin_up, margin_down = task["margin"][0] s_u, s_d = task["sub_region"][1] print( "Top Margin: [%04d, %04d[ | Slices: [%04d, %04d[ | Bottom Margin: [%04d, %04d[" % (s_u, s_u + margin_up, s_u + margin_up, s_d - margin_down, s_d - margin_down, s_d) ) def _print_tasks(self): for task in self.tasks: margin_up, margin_down = task["margin"][0] s_u, s_d = task["sub_region"][1] print( "Top Margin: [%04d, %04d[ | Slices: [%04d, %04d[ | Bottom Margin: [%04d, %04d[" # pylint: disable=E1307 % (s_u, s_u + margin_up, s_u + margin_up, s_d - margin_down, s_d - margin_down, s_d) ) def _get_chunk_length(self, task): if self._pipeline_mode == "helical": (start_z, end_z) = task["sub_region"] return end_z - start_z else: (start_angle, end_angle), (start_z, end_z), _ = task["sub_region"] if self._pipeline_mode == "chunked": return end_z - start_z else: return end_angle - start_angle def _give_progress_info(self, task): self.logger.info("Processing sub-volume %s" % (str(task["sub_region"][:-1]))) # # Reconstruction # def _instantiate_pipeline(self, task): self.logger.debug("Creating a new pipeline object") chunk_shape = tuple(s[1] - s[0] for s in task["sub_region"]) args = [self.process_config, chunk_shape] kwargs = {} if self.backend == "cuda": kwargs["cuda_options"] = self.cuda_options kwargs["use_grouped_mode"] = self._pipeline_mode == "grouped" pipeline = self._pipeline_cls(*args, logger=self.logger, margin=task["margin"], **kwargs) self.pipeline = pipeline def _instantiate_pipeline_if_necessary(self, current_task, other_task): """ Instantiate a pipeline only if current_task has a different "delta z" than other_task """ if self.pipeline is None: self._instantiate_pipeline(current_task) return length_cur = self._get_chunk_length(current_task) length_other = self._get_chunk_length(other_task) if length_cur != length_other: self.logger.debug("Destroying pipeline instance and releasing memory") self._destroy_pipeline() self._instantiate_pipeline(current_task) def _destroy_pipeline(self): self.pipeline = None # Not elegant, but for now the only way to release Cuda memory gc.collect() def _reconstruct_chunked(self, tasks=None): self.results = {} self._histograms = {} self._data_dumps = {} tasks = tasks or self.tasks prev_task = tasks[0] for task in tasks: self.logger.info("Processing sub-volume %s" % (str(task["sub_region"]))) self._instantiate_pipeline_if_necessary(task, prev_task) self.pipeline.process_chunk(task["sub_region"]) task_key = self.pipeline.sub_region task_result = self.pipeline.writer.fname self.results[task_key] = task_result if self.pipeline.writer.histogram: self._histograms[task_key] = self.pipeline.writer.histogram_writer.fname if len(self.pipeline.datadump_manager.data_dump) > 0: self._data_dumps[task_key] = {} for step_name, writer in self.pipeline.datadump_manager.data_dump.items(): self._data_dumps[task_key][step_name] = writer.fname 
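            # Keep a reference to the previous task: the pipeline object is re-created only
            # when the chunk length changes (see _instantiate_pipeline_if_necessary), which
            # avoids destroying and re-creating the processing buffers for every sub-volume.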
prev_task = task def _reconstruct_grouped(self): self.results = {} # self._histograms = {} self._data_dumps = {} prev_task = self.tasks[0] # Stage 1: radios processing self._modify_processconfig_stage_1() for task in self.tasks: self.logger.info("Processing sub-volume %s" % (str(task["sub_region"]))) self._instantiate_pipeline_if_necessary(task, prev_task) self.pipeline.process_chunk(task["sub_region"]) task_key = self.pipeline.sub_region task_result = self.pipeline.datadump_manager.data_dump["sinogram"].fname self.results[task_key] = task_result if len(self.pipeline.datadump_manager.data_dump) > 0: self._data_dumps[task_key] = {} for step_name, writer in self.pipeline.datadump_manager.data_dump.items(): self._data_dumps[task_key][step_name] = writer.fname prev_task = task self.merge_data_dumps(axis=0) self._destroy_pipeline() self.logger.info("End of first stage of processing. Will now process sinograms saved on disk") self._undo_modify_processconfig_stage_1() # Stage 2: sinograms processing and reconstruction self._modify_processconfig_stage_2() self._pipeline_mode = "chunked" self._reconstruct_chunked(tasks=self._sino_tasks) self._pipeline_mode = "grouped" self._undo_modify_processconfig_stage_2() def reconstruct(self): if self._pipeline_mode == "chunked": self._reconstruct_chunked() else: self._reconstruct_grouped() # # Writing data # def get_relative_files(self, files=None): out_cfg = self.process_config.nabu_config["output"] if files is None: files = list(self.results.values()) try: files.sort(key=variable_idxlen_sort) except: self.logger.error( "Lexical ordering failed, falling back to default sort - it will fail for more than 10k projections" ) files.sort() local_files = [join(out_cfg["file_prefix"], basename(fname)) for fname in files] return local_files def _get_reconstruction_metadata(self, partial_volumes_files=None): metadata = { "nabu_config": self.process_config.nabu_config, "processing_options": self.process_config.processing_options, } if self._reconstruction_output_format_is_hdf5 and partial_volumes_files is not None: metadata[self._process_name + "_stages"] = { str(k): v for k, v in zip(self.results.keys(), partial_volumes_files) } if not (self._reconstruction_output_format_is_hdf5): metadata["process_info"] = { "process_name": "reconstruction", "processing_index": 0, "nabu_version": nabu_version, } return metadata def merge_hdf5_reconstructions( self, output_file=None, prefix=None, files=None, process_name=None, axis=0, merge_histograms=True, output_dir=None, ): """ Merge existing hdf5 files by creating a HDF5 virtual dataset. Parameters ---------- output_file: str, optional Output file name. If not given, the file prefix in section "output" of nabu config will be taken. """ out_cfg = self.process_config.nabu_config["output"] out_dir = output_dir or out_cfg["location"] prefix = prefix or "" # Prevent issue when out_dir is empty, which happens only if dataset/location is a relative path. # TODO this should be prevented earlier if out_dir is None or len(out_dir.strip()) == 0: out_dir = dirname(dirname(self.results[list(self.results.keys())[0]])) # if output_file is None: output_file = join(out_dir, prefix + out_cfg["file_prefix"]) + ".hdf5" if isfile(output_file): msg = str("File %s already exists" % output_file) if out_cfg["overwrite_results"]: msg += ". Overwriting as requested in configuration file" self.logger.warning(msg) else: msg += ". Set overwrite_results to True in [output] to overwrite existing files." 
self.logger.fatal(msg) raise ValueError(msg) local_files = files if local_files is None: local_files = self.get_relative_files() if local_files == []: self.logger.error("No files to merge") return entry = getattr(self.process_config.dataset_info.dataset_scanner, "entry", "entry") process_name = process_name or self._process_name h5_path = join(entry, *[process_name, "results", "data"]) # self.logger.info("Merging %ss to %s" % (process_name, output_file)) # # When dumping to disk an intermediate step taking place before cropping the radios, # 'start_z' and 'end_z' found in nabu config have to be augmented with margin_z. # (these values are checked in ProcessConfig._configure_resume()) # patched_start_end_z = False if ( self._margin_v > 0 and process_name != "reconstruction" and self.process_config.is_before_radios_cropping(process_name) and "reconstruction" in self.process_config.processing_steps ): user_rec_config = self.process_config.processing_options["reconstruction"] patched_start_end_z = True old_start_z = user_rec_config["start_z"] old_end_z = user_rec_config["end_z"] user_rec_config["start_z"] = max(0, old_start_z - self._margin_v) user_rec_config["end_z"] = min(self.n_z, old_end_z + self._margin_v) # merge_hdf5_files( local_files, h5_path, output_file, process_name, output_entry=entry, output_filemode="a", processing_index=0, config=self._get_reconstruction_metadata(local_files), base_dir=out_dir, axis=axis, overwrite=out_cfg["overwrite_results"], ) if merge_histograms: self.merge_histograms(output_file=output_file) if patched_start_end_z: user_rec_config["start_z"] = old_start_z user_rec_config["end_z"] = old_end_z return output_file def merge_histograms(self, output_file=None, force_merge=False): """ Merge the partial histograms """ if not (self._do_histograms): return if self._histogram_merged and not (force_merge): return self.logger.info("Merging histograms") masterfile_entry = getattr(self.process_config.dataset_info.dataset_scanner, "entry", "entry") masterfile_process_name = "histogram" # TODO don't hardcode process name output_entry = masterfile_entry out_cfg = self.process_config.nabu_config["output"] if output_file is None: output_file = ( join(dirname(list(self._histograms.values())[0]), out_cfg["file_prefix"] + "_histogram") + ".hdf5" ) local_files = self.get_relative_files(files=list(self._histograms.values())) # h5_path = join(masterfile_entry, *[masterfile_process_name, "results", "data"]) # try: files = sorted(self._histograms.values(), key=variable_idxlen_sort) except: self.logger.error( "Lexical ordering of histogram failed, falling back to default sort - it will fail for more than 10k projections" ) files = sorted(self._histograms.values()) data_urls = [] for fname in files: url = DataUrl(file_path=fname, data_path=h5_path, data_slice=None, scheme="silx") data_urls.append(url) histograms = [] for data_url in data_urls: h2D = get_data(data_url) histograms.append((h2D[0], add_last_bin(h2D[1]))) histograms_merger = PartialHistogram(method="fixed_bins_number", num_bins=histograms[0][0].size) merged_hist = histograms_merger.merge_histograms(histograms) rec_region = self.process_config.rec_region # Not sure if we should account for binning here. # (note that "rec_region" does not account for binning). 
# Anyway the histogram has little use when using binning volume_shape = ( rec_region["end_z"] - rec_region["start_z"] + 1, rec_region["end_y"] - rec_region["start_y"] + 1, rec_region["end_x"] - rec_region["start_x"] + 1, ) writer = NXProcessWriter(output_file, entry=output_entry, filemode="a", overwrite=True) writer.write( hist_as_2Darray(merged_hist), "histogram", processing_index=1, config={ "files": local_files, "bins": self.process_config.nabu_config["postproc"]["histogram_bins"], "volume_shape": volume_shape, }, is_frames_stack=False, direct_access=False, ) self._histogram_merged = True def merge_data_dumps(self, axis=1): # Collect in a dict where keys are step names (instead of task keys) dumps = {} for task_key, data_dumps in self._data_dumps.items(): for step_name, fname in data_dumps.items(): fname = join(basename(dirname(fname)), basename(fname)) if step_name not in dumps: dumps[step_name] = [fname] else: dumps[step_name].append(fname) # Merge HDF5 files for step_name, files in dumps.items(): dump_file = self.process_config.get_save_steps_file(step_name=step_name) self.merge_hdf5_reconstructions( output_file=dump_file, output_dir=dirname(dump_file), files=files, process_name=step_name, axis=axis, merge_histograms=False, ) def write_metadata_file(self): metadata = self._get_reconstruction_metadata() save_options = self.process_config.processing_options["save"] # Perhaps there is more elegant metadata_writer = VolumeSingleFrameBase( url=DataUrl(file_path=save_options["location"], data_path="/"), volume_basename=save_options["file_prefix"], overwrite=True, metadata=metadata, ) metadata_writer.save_metadata() def finalize_files_saving(self): """ Last step to save data. This will do several things: - Merge data dumps (which are always HDF5 files): create a master file for all data-dump sub-volumes - Merge HDF5 reconstruction (if output format is HDF5) - Create a "metadata file" (if output format is not HDF5) - Merge histograms (if output format is not HDF5) """ self.merge_data_dumps() if self._reconstruction_output_format_is_hdf5: self.merge_hdf5_reconstructions() else: self.merge_histograms() self.write_metadata_file() ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5127568 nabu-2024.2.1/nabu/pipeline/helical/0000755000175000017500000000000014730277752016436 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/pipeline/helical/__init__.py0000644000175000017500000000000014402565210020516 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1699603354.0 nabu-2024.2.1/nabu/pipeline/helical/dataset_validator.py0000644000175000017500000000120014523361632022462 0ustar00pierrepierrefrom ..fullfield.dataset_validator import * from ...utils import copy_dict_items class HelicalDatasetValidator(FullFieldDatasetValidator): """Allows more freedom in the choice of the slice indices""" # this in the fullfield base class is instead True _check_also_z = False def _check_slice_indices(self): """Slice indices can be far beyond what fullfield pipeline accepts, no check here, but Nabu expects that rec_region is initialised here""" what = ["start_x", "end_x", "start_y", "end_y", "start_z", "end_z"] self.rec_region = copy_dict_items(self.rec_params, what) return ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1708073564.0 
nabu-2024.2.1/nabu/pipeline/helical/fbp.py0000644000175000017500000001332514563621134017552 0ustar00pierrepierrefrom ...reconstruction.fbp import * from .filtering import HelicalSinoFilter from ...utils import convert_index class BackprojectorHelical(Backprojector): """This is the Backprojector derived class for helical reconstruction. The modifications are detailed here : * the backprojection is decoupled from the filtering. This allows, in the pipeline, for doing first a filtering using backprojector_object.sino_filter subobject, then calling backprojector_object.backprojection only after reweigthing the result of the filter_sino method of sino_filter subobject. * the angles can be resetted on the fly, and the class can work with a variable number of projection. As a matter of fact, with the helical_chunked_regridded.py pipeline, the reconstruction is done each time with the same set of angles, this because of the regridding mechanism. The feature might return useful in the future if alternative approachs are developed again. * """ def __init__(self, *args, **kwargs): """This became needed after the _d_sino allocation was removed from the base class""" super().__init__(*args, **kwargs) self._d_sino = self._processing.allocate_array("d_sino", self.sino_shape, "f") def set_custom_angles_and_axis_corrections(self, angles_rad, x_per_proj): """To arbitrarily change angles Parameters ========== angles_rad: array of floats one angle per each projection in radians x_per_proj: array of floats each entry is the axis shift for a projection, in pixels. """ self.n_provided_angles = len(angles_rad) self._axis_correction = np.zeros((1, self.n_angles), dtype=np.float32) self._axis_correction[0, : self.n_provided_angles] = -x_per_proj self.angles[: self.n_provided_angles] = angles_rad self._compute_angles_again() self.kern_proj_args[1] = self.n_provided_angles self.kern_proj_args[6] = self.offsets["x"] self.kern_proj_args[7] = self.offsets["y"] def _compute_angles_again(self): """to update angles dependent auxiliary arrays""" self.h_cos[0] = np.cos(self.angles).astype("f") self.h_msin[0] = (-np.sin(self.angles)).astype("f") self._d_msin[:] = self.h_msin[0] self._d_cos[:] = self.h_cos[0] if self._axis_correction is not None: self._d_axcorr[:] = self._axis_correction def _init_filter(self, filter_name): """To use the HelicalSinoFilter which is derived from SinoFilter with a slightly different padding scheme """ self.filter_name = filter_name self.sino_filter = HelicalSinoFilter( self.sino_shape, filter_name=self.filter_name, padding_mode=self.padding_mode, cuda_options={"ctx": self.cuda_processing.ctx}, ) def backprojection(self, sino, output=None): """Redefined here to do backprojection only, compare to the base class method.""" self._d_sino[:] = sino res = self.backproj(self._d_sino, output=output) return res def _init_geometry(self, sino_shape, slice_shape, angles, rot_center, slice_roi): """this is identical to _init_geometry of the base class with the exception that self.extra_options["centered_axis"] is not considered and as a consequence self.offsets is not set here and the one of _set_slice_roi is kept. 
""" if slice_shape is not None and slice_roi is not None: raise ValueError("slice_shape and slice_roi cannot be used together") self.sino_shape = sino_shape if len(sino_shape) == 2: n_angles, dwidth = sino_shape else: raise ValueError("Expected 2D sinogram") self.dwidth = dwidth self.rot_center = rot_center or (self.dwidth - 1) / 2.0 self._set_slice_shape( slice_shape, ) self.axis_pos = self.rot_center self._set_angles(angles, n_angles) self._set_slice_roi(slice_roi) self._set_axis_corr() def _set_slice_shape(self, slice_shape): """this is identical to the _set_slice_shape ofthe base class with the exception that n_y,n_x default to the largest possible reconstructible slice """ n_y = self.dwidth + abs(self.dwidth - 1 - self.rot_center * 2) n_x = self.dwidth + abs(self.dwidth - 1 - self.rot_center * 2) if slice_shape is not None: if np.isscalar(slice_shape): slice_shape = (slice_shape, slice_shape) n_y, n_x = slice_shape self.n_x = n_x self.n_y = n_y self.slice_shape = (n_y, n_x) def _set_slice_roi(self, slice_roi): """Automatically tune the offset to in all cases.""" self.slice_roi = slice_roi if slice_roi is None: off = -(self.dwidth - 1 - self.rot_center * 2) if off < 0: self.offsets = {"x": off, "y": off} else: self.offsets = {"x": 0, "y": 0} else: start_x, end_x, start_y, end_y = slice_roi # convert negative indices slice_width, _ = self.slice_shape off = min(0, -(self.dwidth - 1 - self.rot_center * 2)) if end_x < start_x: start_x = off end_x = off + slice_width if end_y < start_y: start_y = off end_y = off + slice_width self.slice_shape = (end_y - start_y, end_x - start_x) self.n_x = self.slice_shape[-1] self.n_y = self.slice_shape[-2] self.offsets = {"x": start_x, "y": start_y} ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1708073564.0 nabu-2024.2.1/nabu/pipeline/helical/filtering.py0000644000175000017500000002377214563621134020775 0ustar00pierrepierre# pylint: disable=too-many-arguments import numpy as np from ...utils import get_cuda_srcfile, updiv from ...reconstruction.filtering import get_next_power from ...reconstruction.filtering_cuda import CudaSinoFilter # pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-function-args class HelicalSinoFilter(CudaSinoFilter): def __init__( self, sino_shape, filter_name=None, padding_mode="zeros", extra_options=None, cuda_options=None, ): """Derived from nabu.reconstruction.filtering.SinoFilter It is used by helical_chunked_regridded pipeline. The shape of the processed sino, as a matter of fact which is due to the helical_chunked_regridded.py module which is using the here present class, is always, but not necessarily [nangles, nslices, nhorizontal] with nslices = 1. This because helical_chunked_regridded.py after a first preprocessing phase, always processes slices one by one. In helical_chunked_regridded .py, the call to the filter_sino method here contained is followed by the weight redistribution ( done by another module), which solves the HA problem, and the backprojection. The latter is performed by fbp.py or hbp.py The reason for having this class, derived from nabu.reconstruction.filtering.SinoFilter, is that the padding mechanism here implemented incorporates the padding with the available theta+180 projection on the half-tomo side. 
""" super().__init__( sino_shape, filter_name=filter_name, padding_mode=padding_mode, extra_options=extra_options, cuda_options=cuda_options, ) self._init_pad_kernel() def _check_array(self, arr): """ This class may work with an arbitrary number of projections. This is a consequence of the first implementation of the helical pipeline. In the first implementation the slices were reconstructed by backprojecting several turns, and the number of useful projections was different from the beginning or end, to the center of the scan. Now in helical_chunked_regridded.py the number of projections is fixed. The only relic left is that the present class may work with an arbitrary number of projections. """ if arr.dtype != np.float32: raise ValueError("Expected data type = numpy.float32") if arr.shape[1:] != self.sino_shape[1:]: raise ValueError("Expected sinogram shape %s, got %s" % (self.sino_shape, arr.shape)) def _init_pad_kernel(self): """The four possible padding kernels. The first two, compared to nabu.reconstruction.filtering.SinoFilter can work with an arbitrary number of projection. The latter two implement the padding with the available information from theta+180. """ self.kern_args = (self.d_sino_f, self.d_filter_f) self.kern_args += self.d_sino_f.shape[::-1] self._pad_mirror_edges_kernel = self.cuda.kernel( "padding", filename=get_cuda_srcfile("helical_padding.cu"), signature="PPfiiiii", options=[str("-DMIRROR_EDGES")], ) self._pad_mirror_constant_kernel = self.cuda.kernel( "padding", filename=get_cuda_srcfile("helical_padding.cu"), signature="PPfiiiiiff", options=[str("-DMIRROR_CONSTANT")], ) self._pad_mirror_edges_variable_rot_pos_kernel = self.cuda.kernel( "padding", filename=get_cuda_srcfile("helical_padding.cu"), signature="PPPiiiii", options=[str("-DMIRROR_EDGES_VARIABLE_ROT_POS")], ) self._pad_mirror_constant_variable_rot_pos_kernel = self.cuda.kernel( "padding", filename=get_cuda_srcfile("helical_padding.cu"), signature="PPPiiiiiff", options=[str("-DMIRROR_CONSTANT_VARIABLE_ROT_POS")], ) self.d_mirror_indexes = self.cuda.allocate_array( "d_mirror_indexes", (self.sino_padded_shape[-2],), dtype=np.int32 ) self.d_variable_rot_pos = self.cuda.allocate_array( "d_variable_rot_pos", (self.sino_padded_shape[-2],), dtype=np.float32 ) self._pad_edges_kernel = self.cuda.kernel( "padding_edge", filename=get_cuda_srcfile("padding.cu"), signature="Piiiiiiii" ) self._pad_block = (32, 32, 1) self._pad_grid = tuple([updiv(n, p) for n, p in zip(self.sino_padded_shape[::-1], self._pad_block)]) def _pad_sino(self, sino, mirror_indexes=None, rot_center=None): """redefined here to adapt the memory copy to the lenght of the sino argument which, in the general helical case may be varying """ if mirror_indexes is None: self._pad_sino_simple(sino) self.d_mirror_indexes[:] = np.zeros([len(self.d_mirror_indexes)], np.int32) self.d_mirror_indexes[: len(mirror_indexes)] = mirror_indexes.astype(np.int32) if np.isscalar(rot_center): argument_rot_center = rot_center tmp_pad_mirror_edges_kernel = self._pad_mirror_edges_kernel tmp_pad_mirror_constant_kernel = self._pad_mirror_constant_kernel else: self.d_variable_rot_pos[: len(rot_center)] = rot_center argument_rot_center = self.d_variable_rot_pos tmp_pad_mirror_edges_kernel = self._pad_mirror_edges_variable_rot_pos_kernel tmp_pad_mirror_constant_kernel = self._pad_mirror_constant_variable_rot_pos_kernel self.d_sino_padded[: len(sino), : self.dwidth] = sino[:] if self.padding_mode == "edges": tmp_pad_mirror_edges_kernel( self.d_sino_padded, self.d_mirror_indexes, 
argument_rot_center, self.dwidth, self.n_angles, self.dwidth_padded, self.pad_left, self.pad_right, grid=self._pad_grid, block=self._pad_block, ) else: tmp_pad_mirror_constant_kernel( self.d_sino_padded, self.d_mirror_indexes, argument_rot_center, self.dwidth, self.n_angles, self.dwidth_padded, self.pad_left, self.pad_right, 0.0, 0.0, grid=self._pad_grid, block=self._pad_block, ) def _pad_sino_simple(self, sino): if self.padding_mode == "edges": self.d_sino_padded[: len(sino), : self.dwidth] = sino[:] self._pad_edges_kernel( self.d_sino_padded, self.dwidth, self.n_angles, self.dwidth_padded, self.n_angles, self.pad_left, self.pad_right, 0, 0, grid=self._pad_grid, block=self._pad_block, ) else: # zeros self.d_sino_padded.fill(0) if self.ndim == 2: self.d_sino_padded[: len(sino), : self.dwidth] = sino[:] else: self.d_sino_padded[: len(sino), :, : self.dwidth] = sino[:] def filter_sino(self, sino, mirror_indexes=None, rot_center=None, output=None, no_output=False): """ Perform the sinogram siltering. redefined here to use also mirror data Parameters ---------- sino: numpy.ndarray or pycuda.gpuarray.GPUArray Input sinogram (2D or 3D) output: numpy.ndarray or pycuda.gpuarray.GPUArray, optional Output array. no_output: bool, optional If set to True, no copy is be done. The resulting data lies in self.d_sino_padded. """ self._check_array(sino) # copy2d/copy3d self._pad_sino(sino, mirror_indexes=mirror_indexes, rot_center=rot_center) # FFT self.fft.fft(self.d_sino_padded, output=self.d_sino_f) # multiply padded sinogram with filter in the Fourier domain self.mult_kernel(*self.kern_args) # TODO tune block size ? # iFFT self.fft.ifft(self.d_sino_f, output=self.d_sino_padded) # return if no_output: return self.d_sino_padded if output is None: res = np.zeros(self.sino_shape, dtype=np.float32) # can't do memcpy2d D->H ? (self.d_sino_padded[:, w]) I have to get() sino_ref = self.d_sino_padded.get() else: res = output sino_ref = self.d_sino_padded if self.ndim == 2: res[:] = sino_ref[:, : self.dwidth] else: res[:] = sino_ref[:, :, : self.dwidth] return res def _calculate_shapes(self, sino_shape): self.ndim = len(sino_shape) if self.ndim == 2: n_angles, dwidth = sino_shape n_sinos = 1 elif self.ndim == 3: n_sinos, n_angles, dwidth = sino_shape else: raise ValueError("Invalid sinogram number of dimensions") self.sino_shape = sino_shape self.n_angles = n_angles self.dwidth = dwidth # int() is crucial here ! Otherwise some pycuda arguments (ex. 
memcpy2D) # will not work with numpy.int64 (as for 2018.X) ### the original get_next_power used in nabu gives a lower ram footprint self.dwidth_padded = 2 * int(get_next_power(self.dwidth)) self.sino_padded_shape = (n_angles, self.dwidth_padded) if self.ndim == 3: self.sino_padded_shape = (n_sinos,) + self.sino_padded_shape sino_f_shape = list(self.sino_padded_shape) sino_f_shape[-1] = sino_f_shape[-1] // 2 + 1 self.sino_f_shape = tuple(sino_f_shape) # self.pad_left = (self.dwidth_padded - self.dwidth) // 2 self.pad_right = self.dwidth_padded - self.dwidth - self.pad_left ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1707838209.0 nabu-2024.2.1/nabu/pipeline/helical/gridded_accumulator.py0000644000175000017500000006330414562705401023005 0ustar00pierrepierrefrom ...preproc.flatfield import FlatFieldArrays import numpy as np from scipy import ndimage as nd import math class GriddedAccumulator: nominal_current = 0.2 def __init__( self, gridded_radios, gridded_weights, diagnostic_radios, diagnostic_weights, diagnostic_angles, diagnostic_searched_angles_rad_clipped, diagnostic_zpix_transl, diag_zpro_run=0, dark=None, flat_indexes=None, flats=None, weights=None, double_flat=None, radios_srcurrent=None, flats_srcurrent=None, ): """ This class creates, for a selected volume slab, a standard set of radios from an helical dataset. Parameters ========== gridded_radios : 3D np.array this is the stack of new radios which will be resynthetised, by this class, for a selected slab. The object is initialised with this array, and this array will accumulate, during subsequent calls to method extract_preprocess_with_flats, the sum of the transformed contributions obtained from the arguments of the mentioned method (extract_preprocess_with_flats). gridded_weights : 3d np.array same shape as gridded_radios, but it will accumulate the weights, during calls to extract_preprocess_with_flats diag_zpro_run: int if > 0 then only the diagnostics are filled, and no accumulation is done diagnostic_searched_angles_rad_clipped: the angles between 0 and 2pi. The contributions to diagnostic will be searched for these angles plus for the same angles + 2pi ( following turn) diagnostic_radios : 3d np.array, a stack composed of each radio must have the same size as a radio of the gridded_radios argument. During the calls to extract_preprocess_with_flats methods, the radios will collect the transformed data for the angles given by diagnostic_searched_angles_rad_clipped and redundancy diagnostic_weights: 3d np.array a stack composed of two radios Same shape as diagnostic_radios. The weigths for diagnostic radios ( will be zero on pixel where no data is available, or where the weight is null) diagnostic_angles : 1D np.array Must have shape==(2*len(diagnostic_searched_angles_rad_clipped),). The entries will be filled with the angles at which the contributions to diagnostic_radios have been summed. diagnostic_zpix_transl: 1D np.array same as for diagnostic_angles, but for vertical translation in pixels dark: None or 2D np.array must have the shape of the detector ( generally larger that a radio of gridded_radios) If given, the dark will be subtracted from data and flats. radios_srcurrent: 1D np.array the machine current for every radio flats_srcurrent: 1D np.array the machine current for every flat flat_indexes: None or a list of integers the projection index corresponding to the flats flats : None or 3D np.array the stack of flats. 
Each flat must have the shape of the detector (generally larger that a radio of gridded_radios) The flats, if given, are subtracted of the dark, if given, and the result is used to normalise the data. weights : None or 2D np.array If not given each data pixel will be considered with unit weight. If given it must have the same shape as the detector. double_flat = None or 2D np.array If given, the double flat will be applied (division by double_flat) Must have the same shape as the detector. """ self.diag_zpro_run = diag_zpro_run self.gridded_radios = gridded_radios self.gridded_weights = gridded_weights self.diagnostic_radios = diagnostic_radios self.diagnostic_weights = diagnostic_weights self.diagnostic_angles = diagnostic_angles self.diagnostic_zpix_transl = diagnostic_zpix_transl self.diagnostic_searched_angles_rad_clipped = diagnostic_searched_angles_rad_clipped self.dark = dark self.radios_srcurrent = radios_srcurrent self.flats_srcurrent = flats_srcurrent self.flat_indexes = flat_indexes self.flat_indexes_reverse_map = dict( [(global_index, local_index) for (local_index, global_index) in enumerate(flat_indexes)] ) self.flats = flats self.weights = weights self.double_flat = double_flat def extract_preprocess_with_flats( self, subchunk_slice, subchunk_file_indexes, chunk_info, subr_start_end, dtasrc_start_end, data_raw, radios_angular_range_slicing, ): """ This functions is meant to be called providing, each time, a subset of the data which are needed to reconstruct a chunk (to reconstruct a slab). When all the necessary data have flown through the subsequent calls to this method, the accumulators are ready. Parameters ========== subchunk_slice: an object of the python class "slice" this slice slices the angular domain which corresponds to the useful projections which are useful for the chunk, and whose informations are contained in the companion argument "chunk_info" Such slicing correspond to the angular subset, for which we are providing data_raw subchunk_file_indexes: a sequence of integers. they correspond to the projection numbers from which the data in data_raw are coming. They are used to interpolate between falt fields chunk_info: an object returned by the get_chunk_info of the SpanStrategy class this object must have the following members, which relates to the wanted chunk angle_index_span: a pair of integers indicating the start and the end of useful angles in the array of all the scan angle self.projection_angles_deg span_v: a pair of two integers indicating the start and end of the span relatively to the lowest value of array self.total_view_heights integer_shift_v: an array, containing for each one of the useful projections of the span, the integer part of vertical shift to be used in cropping, fract_complement_to_integer_shift_v : the fractional remainder for cropping. z_pix_per_proj: an array, containing for each to be used projection of the span the vertical shift x_pix_per_proj: ....the horizontal shift angles_rad : an array, for each useful projection of the chunk the angle in radian subr_start_end: a pair of integers the start height, and the end height, of the slab for which we are collecting the data. The number are given with the same logic as for member span_v of the chunk_info. Without the phase margin, when the phase margin is zero, hey would correspond exactly to the start and end, vertically, of the reconstructed slices. dtasrc_start_end: a pair of integers This number are relative to the detector ( they are detector line indexes). 
They indicated, vertically, the detector portion the data_raw data correspond to data_raw: np.array 3D the data which correspond to a limited detector stripe and a limited angular subset radios_angular_range_slicing: my_subsampled_indexes is important in order to compare the radios positions with respect to the flat position, and these position are given as the sequential acquisition number which counts everything ( flats, darks, radios ) Insteqd, in order to access array which spans only the radios, we need to have an idea of where we are. this is provided by radios_angular_range_slicing which addresses the radios domain """ # the object below is going to containing some auxiliary variable that are use to reframe the data. # This object is used to pass in a compact way such informations to different methods. # The informations are extracted from chunk info reframing_infos = self._ReframingInfos( chunk_info, subchunk_slice, subr_start_end, dtasrc_start_end, subchunk_file_indexes ) # give the proper dimensioning to an auxiliary stack which will contain the reframed data extracted # from the treatement of the sub-chunk radios_subset = np.zeros( [data_raw.shape[0], subr_start_end[1] - subr_start_end[0], data_raw.shape[2]], np.float32 ) # ... and in the same way we dimension the container for the associated reframed weights. radios_weights_subset = np.zeros( [data_raw.shape[0], subr_start_end[1] - subr_start_end[0], data_raw.shape[2]], np.float32 ) # extraction of the data self._extract_preprocess_with_flats( data_raw, reframing_infos, chunk_info, radios_subset, radios_angular_range_slicing ) if self.weights is not None: # ... and, if required, extraction of the associated weights wdata_read = self.weights.data[reframing_infos.dtasrc_start_z : reframing_infos.dtasrc_end_z] self._extract_preprocess_with_flats( wdata_read, reframing_infos, chunk_info, radios_weights_subset, it_is_weight=True ) else: radios_weights_subset[:] = 1.0 # and the remaining part is a simple projection over the accumulators, for # the data and for the weights my_angles = chunk_info.angles_rad[subchunk_slice] n_gridded_angles = self.gridded_radios.shape[0] my_i_float = my_angles * (n_gridded_angles / (2 * math.pi)) tmp_i_rounded = np.floor(my_i_float).astype(np.int32) my_epsilon = my_i_float - tmp_i_rounded my_i0 = np.mod(tmp_i_rounded, n_gridded_angles) my_i1 = np.mod(my_i0 + 1, n_gridded_angles) if self.diag_zpro_run: # these are used only when collection the diagnostics # an estimation of the angular step my_angle_step_rad = abs(np.diff(chunk_info.angles_rad[subchunk_slice]).mean()) my_angles_02pi = np.mod(my_angles, 2 * np.pi) # bins are delimited by ticks ticks = np.empty(2 * len(self.diagnostic_searched_angles_rad_clipped), "f") ticks[::2] = self.diagnostic_searched_angles_rad_clipped - my_angle_step_rad / 2 ticks[1::2] = ticks[::2] + my_angle_step_rad for i0, epsilon, i1, data, weight, original_angle, original_zpix_transl in zip( my_i0, my_epsilon, my_i1, radios_subset, radios_weights_subset, chunk_info.angles_rad[subchunk_slice], chunk_info.z_pix_per_proj[subchunk_slice], ): data_token = data * weight if not self.diag_zpro_run: self.gridded_radios[i0] += data_token * (1 - epsilon) self.gridded_radios[i1] += data_token * epsilon self.gridded_weights[i0] += weight * (1 - epsilon) self.gridded_weights[i1] += weight * epsilon # building the intervals around the diagnostic angles if self.diag_zpro_run: my_i0 = np.searchsorted(ticks, my_angles_02pi) for i0, a02pi, data, weight, original_angle, original_zpix_transl in 
zip( my_i0, my_angles_02pi, radios_subset, radios_weights_subset, chunk_info.angles_rad[subchunk_slice], chunk_info.z_pix_per_proj[subchunk_slice], ): if i0 % 2 == 0: # not in an intervals continue # There is a contribution to the first regridded radio ( the one indexed by 0) # We build two diagnostics for the contributions to this radio. # The first for the first pass (i_diag=0) # The second for the second pass if any (i_diag=1) # To discriminate we introduce # An angular margin beyond which we know that a possible contribution # is coming from another turn safe_angular_margin = 3.14 / 40 my_dist = abs( a02pi - self.diagnostic_searched_angles_rad_clipped[(i0 - 1) // 2 : (i0 - 1) // 2 + 1].mean() ) # print(" i0 ", i0, " original_angle ", original_angle, " a02pi" , a02pi, " my_dist " , my_dist) # consider fist pass and second possible pass. There might be further passes which we dont consider here i_diag_list = [(i0 - 1) // 2, (i0 - 1) // 2 + len(self.diagnostic_searched_angles_rad_clipped)] for i_redundancy, i_diag in enumerate(i_diag_list): # print("IRED ", i_redundancy) if i_redundancy: # to avoid, in z_stages with >360 range for one single stage, to fill the second items which should instead be filled by another stage. if abs(original_zpix_transl - self.diagnostic_zpix_transl[i_diag_list[0]]) < 2.0: # print( " >>>>>> stesso z" , i_redundancy ) continue if np.isnan(self.diagnostic_angles[i_diag]) or ( abs(original_angle) < abs(self.diagnostic_angles[i_diag] + safe_angular_margin) ): # we are searching for the first contributions ( the one at the lowest angle) # for the two diagnostics. With the constraint that the second is at an higher angle # than the first. So if we are here this means that we have found an occurrence with # lower angle and we discard what we could have previously found. 
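# A contribution at a lower angle than the stored one (or the first contribution, if the slot is still NaN)
# has been found for this diagnostic slot: reset its accumulators and record the new angle and vertical
# translation before accumulating below.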
self.diagnostic_radios[i_diag][:] = 0 self.diagnostic_weights[i_diag][:] = 0 self.diagnostic_angles[i_diag] = original_angle self.diagnostic_zpix_transl[i_diag] = original_zpix_transl else: continue if abs(my_dist) <= my_angle_step_rad * 1.1: factor = 1 - abs(my_dist) / (my_angle_step_rad) self.diagnostic_radios[i_diag] += data_token * factor self.diagnostic_weights[i_diag] += weight * factor break else: pass class _ReframingInfos: def __init__(self, chunk_info, subchunk_slice, subr_start_end, dtasrc_start_end, subchunk_file_indexes): self.subchunk_file_indexes = subchunk_file_indexes my_integer_shifts_v = chunk_info.integer_shift_v[subchunk_slice] self.fract_complement_shifts_v = chunk_info.fract_complement_to_integer_shift_v[subchunk_slice] self.x_shifts_list = chunk_info.x_pix_per_proj[subchunk_slice] subr_start_z, subr_end_z = subr_start_end self.subr_start_z_list = subr_start_z - my_integer_shifts_v self.subr_end_z_list = subr_end_z - my_integer_shifts_v + 1 self.dtasrc_start_z, self.dtasrc_end_z = dtasrc_start_end floating_start_z = self.subr_start_z_list.min() floating_end_z = self.subr_end_z_list.max() self.floating_subregion = None, None, floating_start_z, floating_end_z def _extract_preprocess_with_flats( self, data_raw, reframing_infos, chunk_info, output, radios_angular_range_slicing=None, it_is_weight=False ): if not it_is_weight: if self.dark is not None: data_raw = data_raw - self.dark[reframing_infos.dtasrc_start_z : reframing_infos.dtasrc_end_z] if self.flats is not None: for i, idx in enumerate(reframing_infos.subchunk_file_indexes): flat = self._get_flat(idx, slice(reframing_infos.dtasrc_start_z, reframing_infos.dtasrc_end_z)) if self.dark is not None: flat = flat - self.dark[reframing_infos.dtasrc_start_z : reframing_infos.dtasrc_end_z] if self.radios_srcurrent is not None: factor = self.nominal_current / self.radios_srcurrent[radios_angular_range_slicing.start + i] else: factor = 1 data_raw[i] = data_raw[i] * factor / flat if self.double_flat is not None: data_raw = data_raw / self.double_flat[reframing_infos.dtasrc_start_z : reframing_infos.dtasrc_end_z] if it_is_weight: # for the weight, the detector weights, depends on the detector portion, # the one corresponding to dtasrc_start_end, and is the same across # the subchunk first index ( the projection index) take_data_from_this = [data_raw] * len(reframing_infos.subr_start_z_list) else: take_data_from_this = data_raw for data_read, list_subr_start_z, list_subr_end_z, fract_shift, x_shift, data_target in zip( take_data_from_this, reframing_infos.subr_start_z_list, reframing_infos.subr_end_z_list, reframing_infos.fract_complement_shifts_v, reframing_infos.x_shifts_list, output, ): _fill_in_chunk_by_shift_crop_data( data_target, data_read, fract_shift, list_subr_start_z, list_subr_end_z, reframing_infos.dtasrc_start_z, reframing_infos.dtasrc_end_z, x_shift=x_shift, extension_padding=(not it_is_weight), ) def _get_flat(self, idx, slice_y=slice(None, None), slice_x=slice(None, None), dtype=np.float32): prev_next = FlatFieldArrays.get_previous_next_indices(self.flat_indexes, idx) if len(prev_next) == 1: # current index corresponds to an acquired flat flat_data = self.flats[self.flat_indexes_reverse_map[prev_next[0]]][slice_y, slice_x] else: # interpolate prev_idx, next_idx = prev_next n_prev = self.flat_indexes_reverse_map[prev_idx] n_next = self.flat_indexes_reverse_map[next_idx] flat_data_prev = self.flats[n_prev][slice_y, slice_x] flat_data_next = self.flats[n_next][slice_y, slice_x] if self.flats_srcurrent is not None: 
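# Machine-current normalisation: each (dark-subtracted) flat is rescaled to the nominal current before the
# interpolation; the radios are rescaled with the same convention in _extract_preprocess_with_flats.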
prev_current_factor = self.nominal_current / self.flats_srcurrent[n_prev] next_current_factor = self.nominal_current / self.flats_srcurrent[n_next] else: prev_current_factor = 1 next_current_factor = 1 delta = next_idx - prev_idx w1 = 1 - (idx - prev_idx) / delta w2 = 1 - (next_idx - idx) / delta if self.dark is not None: dark = self.dark[slice_y, slice_x] else: dark = 0 flat_data = ( dark + w1 * (flat_data_prev - dark) * prev_current_factor + w2 * (flat_data_next - dark) * next_current_factor ) if flat_data.dtype != dtype: flat_data = np.ascontiguousarray(flat_data, dtype=dtype) return flat_data def _fill_in_chunk_by_shift_crop_data( data_target, data_read, fract_shift, my_subr_start_z, my_subr_end_z, dtasrc_start_z, dtasrc_end_z, x_shift=0.0, extension_padding=True, ): data_read_precisely_shifted = nd.shift(data_read, (-fract_shift, x_shift), order=1, mode="nearest")[:-1] target_central_slicer, dtasrc_central_slicer = overlap_logic( my_subr_start_z, my_subr_end_z - 1, dtasrc_start_z, dtasrc_end_z - 1 ) if None not in [target_central_slicer, dtasrc_central_slicer]: data_target[target_central_slicer] = data_read_precisely_shifted[dtasrc_central_slicer] target_lower_slicer, target_upper_slicer = padding_logic( my_subr_start_z, my_subr_end_z - 1, dtasrc_start_z, dtasrc_end_z - 1 ) if extension_padding: if target_lower_slicer is not None: data_target[target_lower_slicer] = data_read_precisely_shifted[0] if target_upper_slicer is not None: data_target[target_upper_slicer] = data_read_precisely_shifted[-1] else: if target_lower_slicer is not None: data_target[target_lower_slicer] = 1.0e-6 if target_upper_slicer is not None: data_target[target_upper_slicer] = 1.0e-6 def overlap_logic(subr_start_z, subr_end_z, dtasrc_start_z, dtasrc_end_z): """determines the useful lines which can be transferred from the dtasrc_start_z:dtasrc_end_z range targeting the range subr_start_z: subr_end_z .................. """ t_h = subr_end_z - subr_start_z s_h = dtasrc_end_z - dtasrc_start_z my_start = max(0, dtasrc_start_z - subr_start_z) my_end = min(t_h, dtasrc_end_z - subr_start_z) if my_start >= my_end: return None, None target_central_slicer = slice(my_start, my_end) my_start = max(0, subr_start_z - dtasrc_start_z) my_end = min(s_h, subr_end_z - dtasrc_start_z) assert my_start < my_end, "Overlap_logic logic error" dtasrc_central_slicer = slice(my_start, my_end) return target_central_slicer, dtasrc_central_slicer def padding_logic(subr_start_z, subr_end_z, dtasrc_start_z, dtasrc_end_z): """.......... and the missing ranges which possibly could be obtained by extension padding""" t_h = subr_end_z - subr_start_z s_h = dtasrc_end_z - dtasrc_start_z if dtasrc_start_z <= subr_start_z: target_lower_padding = None else: target_lower_padding = slice(0, dtasrc_start_z - subr_start_z) if dtasrc_end_z >= subr_end_z: target_upper_padding = None else: target_upper_padding = slice(dtasrc_end_z - subr_end_z, None) return target_lower_padding, target_upper_padding def get_reconstruction_space(span_info, min_scanwise_z, end_scanwise_z, phase_margin_pix): """Utility function, so far used only by the unit test, which, given the span_info object, creates the auxiliary collection arrays and initialises the my_z_min, my_z_end variable keeping into account the scan direction and the min_scanwise_z, end_scanwise_z input arguments Parameters ========== span_info: SpanStrategy min_scanwise_z: int non negative number, where zero indicates the first feaseable slice doable scanwise. 
Indicates the first (scanwise) requested slice to be reconstructed end_scanwise_z: int non negative number, where zero indicates the first feaseable slice doable scanwise. Indicates the end (scanwise) slice which delimity the to be reconstructed requested slab. """ detector_z_start, detector_z_end = (span_info.get_doable_span()).view_heights_minmax if span_info.z_pix_per_proj[-1] > span_info.z_pix_per_proj[0]: my_z_min = detector_z_start + min_scanwise_z my_z_end = detector_z_start + end_scanwise_z else: my_z_min = detector_z_end - (end_scanwise_z - 1) my_z_end = detector_z_end - (min_scanwise_z + 1) # while the raw dataset may have non uniform angular step # the regridded dataset will have a constant step. # We evaluate here below the number of angles for the # regridded dataset, estimating a meaningul angular step representative # of the raw data my_angle_step = abs(np.diff(span_info.projection_angles_deg).mean()) n_gridded_angles = int(round(360.0 / my_angle_step)) radios_h = phase_margin_pix + (my_z_end - my_z_min) + phase_margin_pix # the accumulators gridded_radios = np.zeros([n_gridded_angles, radios_h, span_info.detector_shape_vh[1]], np.float32) gridded_cumulated_weights = np.zeros([n_gridded_angles, radios_h, span_info.detector_shape_vh[1]], np.float32) # this utility function is meant for testing the reconstruction only, not the diagnostic collection. # However we build diagnostic targets all the same to feed something through the API # which contemplates the diagnostics. So that the unit test runs correctly diagnostic_radios = np.zeros((4,) + gridded_radios.shape[1:], np.float32) diagnostic_weights = np.zeros((4,) + gridded_radios.shape[1:], np.float32) diagnostic_proj_angle = np.zeros([4], "f") diagnostic_searched_angles_rad_clipped = (0.5 + np.arange(2)) * (2 * np.pi / (2)) diagnostic_zpix_transl = np.zeros([4], "f") gridded_angles_rad = np.arange(n_gridded_angles) * 2 * np.pi / n_gridded_angles gridded_angles_deg = np.rad2deg(gridded_angles_rad) res = type( "on_the_fly_class_for_reconstruction_room_in_gridded_accumulator.py", (object,), { "my_z_min": my_z_min, "my_z_end": my_z_end, "gridded_radios": gridded_radios, "gridded_cumulated_weights": gridded_cumulated_weights, "diagnostic_radios": diagnostic_radios, "diagnostic_weights": diagnostic_weights, "diagnostic_proj_angle": diagnostic_proj_angle, "diagnostic_searched_angles_rad_clipped": diagnostic_searched_angles_rad_clipped, "diagnostic_zpix_transl": diagnostic_zpix_transl, "gridded_angles_rad": gridded_angles_rad, "gridded_angles_deg": gridded_angles_deg, }, ) return res ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/pipeline/helical/helical_chunked_regridded.py0000644000175000017500000020554214654107202024116 0ustar00pierrepierre# pylint: skip-file from os import path import numpy as np import math from silx.image.tomography import get_next_power from scipy import ndimage as nd import h5py import silx.io import copy from silx.io.url import DataUrl from ...resources.logger import LoggerOrPrint from ...resources.utils import is_hdf5_extension from ...io.reader_helical import ChunkReaderHelical, get_hdf5_dataset_shape from ...preproc.flatfield_variable_region import FlatFieldDataVariableRegionUrls as FlatFieldDataHelicalUrls from ...preproc.distortion import DistortionCorrection from ...preproc.shift import VerticalShift from ...preproc.double_flatfield_variable_region import DoubleFlatFieldVariableRegion as DoubleFlatFieldHelical from ...preproc.phase 
import PaganinPhaseRetrieval from ...reconstruction.sinogram import SinoBuilder from ...processing.unsharp import UnsharpMask from ...processing.histogram import PartialHistogram, hist_as_2Darray from ..utils import use_options, pipeline_step from ..detector_distortion_provider import DetectorDistortionProvider from .utils import ( WriterConfiguratorHelical as WriterConfigurator, ) # .utils is the same as ..utils but internally we retouch the key associated to "tiffwriter" of Writers to # point to our class which can write tiff with names indexed by the z height above the sample stage in millimiters from numpy.lib.stride_tricks import sliding_window_view from ...misc.binning import get_binning_function from .helical_utils import find_mirror_indexes try: import nabuxx GriddedAccumulator = nabuxx.gridded_accumulator.GriddedAccumulator CCDFilter = nabuxx.ccd.CCDFilter Log = nabuxx.ccd.LogFilter cxx_paganin = nabuxx.paganin except: logger_tmp = LoggerOrPrint(None) logger_tmp.info( "Nabuxx not available. Loading python implementation for gridded_accumulator, Log, CCDFilter, paganin" ) from . import gridded_accumulator GriddedAccumulator = gridded_accumulator.GriddedAccumulator from ...preproc.ccd import Log, CCDFilter cxx_paganin = None # For now we don't have a plain python/numpy backend for reconstruction Backprojector = None class HelicalChunkedRegriddedPipeline: """ Pipeline for "helical" full or half field tomography. Data is processed by chunks. A chunk consists in K+-1 contiguous lines of all the radios which are read at variable height following the translations """ extra_marge_granularity = 4 """ This offers extra reading space to be able to read the redundant part which might be sligtly larger and or require extra border for interpolation """ FlatFieldClass = FlatFieldDataHelicalUrls DoubleFlatFieldClass = DoubleFlatFieldHelical CCDFilterClass = CCDFilter MLogClass = Log PaganinPhaseRetrievalClass = PaganinPhaseRetrieval UnsharpMaskClass = UnsharpMask VerticalShiftClass = VerticalShift SinoBuilderClass = SinoBuilder FBPClass = Backprojector HBPClass = None HistogramClass = PartialHistogram regular_accumulator = None def __init__( self, process_config, sub_region, logger=None, extra_options=None, phase_margin=None, reading_granularity=10, span_info=None, diag_zpro_run=0, ): """ Initialize a "HelicalChunked" pipeline. Parameters ---------- process_config: `nabu.resources.processcinfig.ProcessConfig` Process configuration. sub_region: tuple Sub-region to process in the volume for this worker, in the format `(start_x, end_x, start_z, end_z)`. logger: `nabu.app.logger.Logger`, optional Logger class extra_options: dict, optional Advanced extra options. phase_margin: tuple, optional Margin to use when performing phase retrieval, in the form ((up, down), (left, right)). See also the documentation of PaganinPhaseRetrieval. If not provided, no margin is applied. reading_granularity: int The data angular span which needs to be read for a reconstruction is read step by step, reading each time a maximum of reading_granularity radios, and doing the preprocessing till phase retrieval for each of these angular groups Notes ------ Using a `phase_margin` results in a lesser number of reconstructed slices. More specifically, if `phase_margin = (V, H)`, then there will be `chunk_size - 2*V` reconstructed slices (if the sub-region is in the middle of the volume) or `chunk_size - V` reconstructed slices (if the sub-region is on top or bottom of the volume). 
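For example, with `chunk_size = 100` and `phase_margin = ((10, 10), (0, 0))`, a sub-region located in the middle of the volume yields 100 - 2*10 = 80 reconstructed slices.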
""" self.span_info = span_info self.reading_granularity = reading_granularity self.logger = LoggerOrPrint(logger) self._set_params(process_config, sub_region, extra_options, phase_margin, diag_zpro_run) self._init_pipeline() def _set_params(self, process_config, sub_region, extra_options, phase_margin, diag_zpro_run): self.diag_zpro_run = diag_zpro_run self.process_config = process_config self.dataset_info = self.process_config.dataset_info self.processing_steps = self.process_config.processing_steps.copy() self.processing_options = self.process_config.processing_options sub_region = self._check_subregion(sub_region) self.chunk_size = sub_region[-1] - sub_region[-2] self.radios_buffer = None self._set_detector_distortion_correction() self.set_subregion(sub_region) self._set_phase_margin(phase_margin) self._set_extra_options(extra_options) self._callbacks = {} self._steps_name2component = {} self._steps_component2name = {} self._data_dump = {} self._resume_from_step = None @staticmethod def _check_subregion(sub_region): if len(sub_region) < 4: assert len(sub_region) == 2, " at least start_z and end_z are required in subregion" sub_region = (None, None) + sub_region if None in sub_region[-2:]: raise ValueError("Cannot set z_min or z_max to None") return sub_region def _set_extra_options(self, extra_options): if extra_options is None: extra_options = {} advanced_options = {} advanced_options.update(extra_options) self.extra_options = advanced_options def _set_phase_margin(self, phase_margin): if phase_margin is None: phase_margin = ((0, 0), (0, 0)) self._phase_margin_up = phase_margin[0][0] self._phase_margin_down = phase_margin[0][1] self._phase_margin_left = phase_margin[1][0] self._phase_margin_right = phase_margin[1][1] def set_subregion(self, sub_region): """ Set a sub-region to process. Parameters ---------- sub_region: tuple Sub-region to process in the volume, in the format `(start_x, end_x, start_z, end_z)` or `(start_z, end_z)`. """ sub_region = self._check_subregion(sub_region) dz = sub_region[-1] - sub_region[-2] if dz != self.chunk_size: raise ValueError( "Class was initialized for chunk_size = %d but provided sub_region has chunk_size = %d" % (self.chunk_size, dz) ) self.sub_region = sub_region self.z_min = sub_region[-2] self.z_max = sub_region[-1] def _compute_phase_kernel_margin(self): """ Get the "margin" to pass to classes like PaganinPhaseRetrieval. In order to have a good accuracy for filter-based phase retrieval methods, we need to load extra data around the edges of each image. Otherwise, a default padding type is applied. 
""" if not (self.use_radio_processing_margin): self._phase_margin = None return up_margin = self._phase_margin_up down_margin = self._phase_margin_down # Horizontal margin is not implemented left_margin, right_margin = (0, 0) self._phase_margin = ((up_margin, down_margin), (left_margin, right_margin)) @property def use_radio_processing_margin(self): return ("phase" in self.processing_steps) or ("unsharp_mask" in self.processing_steps) def _get_phase_margin(self): if not (self.use_radio_processing_margin): return ((0, 0), (0, 0)) return self._phase_margin @property def phase_margin(self): """ Return the margin for phase retrieval in the form ((up, down), (left, right)) """ return self._get_phase_margin() def _get_process_name(self, kind="reconstruction"): # In the future, might be something like "reconstruction-" if kind == "reconstruction": return "reconstruction" elif kind == "histogram": return "histogram" return kind def _configure_dump(self, step_name): if step_name not in self.processing_steps: if step_name == "sinogram" and self.process_config._dump_sinogram: fname_full = self.process_config._dump_sinogram_file else: return else: if not self.processing_options[step_name].get("save", False): return fname_full = self.processing_options[step_name]["save_steps_file"] fname, ext = path.splitext(fname_full) dirname, file_prefix = path.split(fname) output_dir = path.join(dirname, file_prefix) file_prefix += str("_%06d" % self._get_image_start_index()) self.logger.info("omitting config in data_dump because of too slow nexus writer ") self._data_dump[step_name] = WriterConfigurator( output_dir, file_prefix, file_format="hdf5", overwrite=True, logger=self.logger, nx_info={ "process_name": step_name, "processing_index": 0, # TODO # "config": {"processing_options": self.processing_options, "nabu_config": self.process_config.nabu_config}, "config": None, "entry": getattr(self.dataset_info.dataset_scanner, "entry", None), }, ) def _configure_data_dumps(self): self.process_config._configure_save_steps() for step_name in self.processing_steps: self._configure_dump(step_name) # sinogram is a special keyword: not in processing_steps, but guaranteed to be before sinogram generation if self.process_config._dump_sinogram: self._configure_dump("sinogram") # # Callbacks # def register_callback(self, step_name, callback): """ Register a callback for a pipeline processing step. Parameters ---------- step_name: str processing step name callback: callable A function. It will be executed once the processing step `step_name` is finished. The function takes only one argument: the class instance. """ if step_name not in self.processing_steps: raise ValueError("'%s' is not in processing steps %s" % (step_name, self.processing_steps)) if step_name in self._callbacks: self._callbacks[step_name].append(callback) else: self._callbacks[step_name] = [callback] # # Overwritten in inheriting classes # def _get_shape(self, step_name): """ Get the shape to provide to the class corresponding to step_name. 
""" if step_name == "flatfield": shape = self.radios_subset.shape elif step_name == "double_flatfield": shape = self.radios_subset.shape elif step_name == "phase": shape = self.radios_subset.shape[1:] elif step_name == "ccd_correction": shape = self.gridded_radios.shape[1:] elif step_name == "unsharp_mask": shape = self.radios_subset.shape[1:] elif step_name == "take_log": shape = self.radios.shape elif step_name == "radios_movements": shape = self.radios.shape elif step_name == "sino_normalization": shape = self.radios.shape elif step_name == "sino_normalization_slim": shape = self.radios.shape[:1] + (1,) + self.radios.shape[2:] elif step_name == "one_sino_slim": shape = self.radios.shape[:1] + self.radios.shape[2:] elif step_name == "build_sino": shape = self.radios.shape[:1] + (1,) + self.radios.shape[2:] elif step_name == "reconstruction": shape = self.sino_builder.output_shape[1:] else: raise ValueError("Unknown processing step %s" % step_name) self.logger.debug("Data shape for %s is %s" % (step_name, str(shape))) return shape def _allocate_array(self, shape, dtype, name=None): """this function can be redefined in the derived class which is dedicated to gpu and will return gpu garrays """ return _cpu_allocate_array(shape, dtype, name=name) def _cpu_allocate_array(self, shape, dtype, name=None): """For objects used in the pre-gpu part. They will be always on CPU even in the derived class""" result = np.zeros(shape, dtype=dtype) return result def _allocate_sinobuilder_output(self): return self._cpu_allocate_array(self.sino_builder.output_shape, "f", name="sinos") def _allocate_recs(self, ny, nx): self.n_slices = self.gridded_radios.shape[1] if self.use_radio_processing_margin: self.n_slices -= sum(self.phase_margin[0]) self.recs = self._allocate_array((1, ny, nx), "f", name="recs") self.recs_stack = self._cpu_allocate_array((self.n_slices, ny, nx), "f", name="recs_stack") def _reset_memory(self): pass def _get_read_dump_subregion(self): read_opts = self.processing_options["read_chunk"] if read_opts.get("process_file", None) is None: return None dump_start_z, dump_end_z = read_opts["dump_start_z"], read_opts["dump_end_z"] relative_start_z = self.z_min - dump_start_z relative_end_z = relative_start_z + self.chunk_size # (n_angles, n_z, n_x) subregion = (None, None, relative_start_z, relative_end_z, None, None) return subregion def _check_resume_from_step(self): if self._resume_from_step is None: return read_opts = self.processing_options["read_chunk"] expected_radios_shape = get_hdf5_dataset_shape( read_opts["process_file"], read_opts["process_h5_path"], sub_region=self._get_read_dump_subregion(), ) # TODO check def _init_reader_finalize(self): """ Method called after _init_reader """ self._check_resume_from_step() self._compute_phase_kernel_margin() self._allocate_reduced_gridded_and_subset_radios() def _allocate_reduced_gridded_and_subset_radios(self): shp_h = self.chunk_reader.data.shape[-1] sliding_window_size = self.chunk_size if sliding_window_size % 2 == 0: sliding_window_size += 1 sliding_window_radius = (sliding_window_size - 1) // 2 if sliding_window_radius == 0: n_projs_max = (self.span_info.sunshine_ends - self.span_info.sunshine_starts).max() else: padded_starts = self.span_info.sunshine_starts padded_ends = self.span_info.sunshine_ends padded_starts = np.concatenate( [[padded_starts[0]] * sliding_window_radius, padded_starts, [padded_starts[-1]] * sliding_window_radius] ) starts = sliding_window_view(padded_starts, sliding_window_size).min(axis=-1) padded_ends = 
np.concatenate( [[padded_ends[0]] * sliding_window_radius, padded_ends, [padded_ends[-1]] * sliding_window_radius] ) ends = sliding_window_view(padded_ends, sliding_window_size).max(axis=-1) n_projs_max = (ends - starts).max() ((up_margin, down_margin), (left_margin, right_margin)) = self.phase_margin (start_x, end_x, start_z, end_z) = self.sub_region ## and now the gridded ones my_angle_step = abs(np.diff(self.span_info.projection_angles_deg).mean()) self.n_gridded_angles = int(round(360.0 / my_angle_step)) self.my_angles_rad = np.arange(self.n_gridded_angles) * 2 * np.pi / self.n_gridded_angles my_angles_deg = np.rad2deg(self.my_angles_rad) self.mirror_angle_relative_indexes = find_mirror_indexes(my_angles_deg) if "read_chunk" not in self.processing_steps: raise ValueError("Cannot proceed without reading data") r_shp_v, r_shp_h = self.whole_radio_shape (subr_start_x, subr_end_x, subr_start_z, subr_end_z) = self.sub_region subradio_shape = subr_end_z - subr_start_z, r_shp_h ### these radios are for diagnostic of the translations ( they will be optionally written, for being further used ## by correlation techniques ). Two radios for the first two pass over the first gridded angles if self.diag_zpro_run: # 2 for the redundancy, 2 for +180 mirror ndiag = 2 * 2 * self.diag_zpro_run else: ndiag = 2 * 2 self.diagnostic_searched_angles_rad_clipped = ( (0.5 + np.arange(ndiag // 2)) * (2 * np.pi / (ndiag // 2)) ).astype("f") self.diagnostic_radios = np.zeros((ndiag,) + subradio_shape, np.float32) self.diagnostic_weights = np.zeros((ndiag,) + subradio_shape, np.float32) self.diagnostic_proj_angle = np.zeros([ndiag], "f") self.diagnostic_zpix_transl = np.zeros([ndiag], "f") self.diagnostic_zmm_transl = np.zeros([ndiag], "f") self.diagnostic = { "radios": self.diagnostic_radios, "weights": self.diagnostic_weights, "angles": self.diagnostic_proj_angle, "zpix_transl": self.diagnostic_zpix_transl, "zmm_trans": self.diagnostic_zmm_transl, "pixel_size_mm": self.span_info.pix_size_mm, "searched_rad": self.diagnostic_searched_angles_rad_clipped, } ## ------- if self.diag_zpro_run == 0: self.gridded_radios = np.zeros((self.n_gridded_angles,) + subradio_shape, np.float32) self.gridded_cumulated_weights = np.zeros((self.n_gridded_angles,) + subradio_shape, np.float32) else: # only diagnostic will be cumulated. No need to keep the full size for diagnostic runs. 
# The gridder is initialised passing also the two buffer below, # and the two first dimensions are used to allocate auxiliaries, # so we shorten only the last dimension, but this is already a good cut self.gridded_radios = np.zeros((self.n_gridded_angles,) + (subradio_shape[0], 2), np.float32) self.gridded_cumulated_weights = np.zeros((self.n_gridded_angles,) + (subradio_shape[0], 2), np.float32) self.radios_subset = np.zeros((self.reading_granularity,) + subradio_shape, np.float32) self.radios_weights_subset = np.zeros((self.reading_granularity,) + subradio_shape, np.float32) if not self.diag_zpro_run: self.radios = np.zeros( (self.n_gridded_angles,) + ((end_z - down_margin) - (start_z + up_margin), shp_h), np.float32 ) else: # place holder self.radios = np.zeros((self.n_gridded_angles,) + (1, 1), np.float32) self.radios_weights = np.zeros_like(self.radios) self.radios_slim = self._allocate_array(self._get_shape("one_sino_slim"), "f", name="radios_slim") def _process_finalize(self): """ Method called once the pipeline has been executed """ pass def _get_slice_start_index(self): return self.z_min + self._phase_margin_up _get_image_start_index = _get_slice_start_index # # Pipeline initialization # def _reset_diagnostics(self): self.diagnostic_radios[:] = 0 self.diagnostic_weights[:] = 0 self.diagnostic_zpix_transl[:] = 0 self.diagnostic_zmm_transl[:] = 0 self.diagnostic_proj_angle[:] = np.nan def _init_pipeline(self): self._get_size_of_a_raw_radio() self._init_reader() self._init_flatfield() self._init_double_flatfield() self._init_weights_field() self._init_ccd_corrections() self._init_phase() self._init_unsharp() self._init_mlog() self._init_sino_normalization() self._init_sino_builder() self._prepare_reconstruction() self._init_reconstruction() self._init_histogram() self._init_writer() self._configure_data_dumps() self._configure_regular_accumulator() def _set_detector_distortion_correction(self): if self.process_config.nabu_config["preproc"]["detector_distortion_correction"] is None: self.detector_corrector = None else: self.detector_corrector = DetectorDistortionProvider( detector_full_shape_vh=self.process_config.dataset_info.radio_dims[::-1], correction_type=self.process_config.nabu_config["preproc"]["detector_distortion_correction"], options=self.process_config.nabu_config["preproc"]["detector_distortion_correction_options"], ) def _configure_regular_accumulator(self): ## # keeping these freshly numpyed objects referenced by self # ensures that their buffer info, conserved by c++ implementation of GriddedAccumulator # will always point to existing data, which could otherwise be garbage collected by python # if self.process_config.nabu_config["preproc"]["normalize_srcurrent"]: self.radios_srcurrent = np.array(self.dataset_info.projections_srcurrent, "f") self.flats_srcurrent = np.array(self.dataset_info.flats_srcurrent, "f") else: self.radios_srcurrent = None self.flats_srcurrent = None self.regular_accumulator = GriddedAccumulator( gridded_radios=self.gridded_radios, gridded_weights=self.gridded_cumulated_weights, diagnostic_radios=self.diagnostic_radios, diagnostic_weights=self.diagnostic_weights, diagnostic_angles=self.diagnostic_proj_angle, diagnostic_zpix_transl=self.diagnostic_zpix_transl, diagnostic_searched_angles_rad_clipped=self.diagnostic_searched_angles_rad_clipped, dark=self.flatfield.get_dark(), flat_indexes=self.flatfield._sorted_flat_indices, flats=self.flatfield.flats_stack, weights=self.weights_field.data, double_flat=self.double_flatfield.data, 
diag_zpro_run=self.diag_zpro_run, radios_srcurrent=self.radios_srcurrent, flats_srcurrent=self.flats_srcurrent, ) return def _get_size_of_a_raw_radio(self): """Once for all we find the shape of a radio. This information will be used in other parts of the code when allocating bunch of data holders """ if "read_chunk" not in self.processing_steps: raise ValueError("Cannot proceed without reading data") options = self.processing_options["read_chunk"] here_a_file = next(iter(options["files"].values())) here_a_radio = silx.io.get_data(here_a_file) binning_x, binning_z = self._get_binning() if (binning_z, binning_x) != (1, 1): binning_function = get_binning_function((binning_z, binning_x)) here_a_radio = binning_function(here_a_radio) self.whole_radio_shape = here_a_radio.shape return self.whole_radio_shape @use_options("read_chunk", "chunk_reader") def _init_reader(self): if "read_chunk" not in self.processing_steps: raise ValueError("Cannot proceed without reading data") options = self.processing_options["read_chunk"] assert options.get("process_file", None) is None, "Resume not yet implemented in helical pipeline" # dummy initialisation, it will be _set_subregion'ed and set_data_buffer'ed in the loops self.chunk_reader = ChunkReaderHelical( options["files"], sub_region=None, # setting of subregion will be already done by calls to set_subregion detector_corrector=self.detector_corrector, convert_float=True, binning=options["binning"], dataset_subsampling=options["dataset_subsampling"], data_buffer=None, pre_allocate=True, ) self._init_reader_finalize() @use_options("flatfield", "flatfield") def _init_flatfield(self, shape=None): if shape is None: shape = self._get_shape("flatfield") options = self.processing_options["flatfield"] distortion_correction = None if options["do_flat_distortion"]: self.logger.info("Flats distortion correction will be applied") estimation_kwargs = {} estimation_kwargs.update(options["flat_distortion_params"]) estimation_kwargs["logger"] = self.logger distortion_correction = DistortionCorrection( estimation_method="fft-correlation", estimation_kwargs=estimation_kwargs, correction_method="interpn" ) self.flatfield = self.FlatFieldClass( shape, flats=self.dataset_info.flats, darks=self.dataset_info.darks, radios_indices=options["projs_indices"], interpolation="linear", distortion_correction=distortion_correction, radios_srcurrent=options["radios_srcurrent"], flats_srcurrent=options["flats_srcurrent"], detector_corrector=self.detector_corrector, ## every flat will be read at a different heigth ### sub_region=self.sub_region, binning=options["binning"], convert_float=True, ) def _get_binning(self): options = self.processing_options["read_chunk"] binning = options["binning"] if binning is None: return 1, 1 else: return binning def _init_double_flatfield(self): options = self.processing_options["double_flatfield"] binning_x, binning_z = self._get_binning() result_url = None self.double_flatfield = None if options["processes_file"] not in (None, ""): file_path = options["processes_file"] data_path = (self.dataset_info.hdf5_entry or "entry") + "/double_flatfield/results/data" if path.exists(file_path) and (data_path in h5py.File(file_path, "r")): result_url = DataUrl(file_path=file_path, data_path=data_path) self.logger.info("Loading double flatfield from %s" % result_url.file_path()) self.double_flatfield = self.DoubleFlatFieldClass( self._get_shape("double_flatfield"), result_url=result_url, binning_x=binning_x, binning_z=binning_z, 
detector_corrector=self.detector_corrector, ) def _init_weights_field(self): options = self.processing_options["double_flatfield"] result_url = None binning_x, binning_z = self.chunk_reader.get_binning() self.weights_field = None if options["processes_file"] not in (None, ""): file_path = options["processes_file"] data_path = (self.dataset_info.hdf5_entry or "entry") + "/weights_field/results/data" if path.exists(file_path) and (data_path in h5py.File(file_path, "r")): result_url = DataUrl(file_path=file_path, data_path=data_path) self.logger.info("Loading weights_field from %s" % result_url.file_path()) self.weights_field = self.DoubleFlatFieldClass( self._get_shape("double_flatfield"), result_url=result_url, binning_x=binning_x, binning_z=binning_z ) def _init_ccd_corrections(self): if "ccd_correction" not in self.processing_steps: return options = self.processing_options["ccd_correction"] median_clip_thresh = options["median_clip_thresh"] self.ccd_correction = self.CCDFilterClass( self._get_shape("ccd_correction"), median_clip_thresh=median_clip_thresh ) @use_options("phase", "phase_retrieval") def _init_phase(self): options = self.processing_options["phase"] # If unsharp mask follows phase retrieval, then it should be done # before cropping to the "inner part". # Otherwise, crop the data just after phase retrieval. if "unsharp_mask" in self.processing_steps: margin = None else: margin = self._phase_margin self.phase_retrieval = self.PaganinPhaseRetrievalClass( self._get_shape("phase"), distance=options["distance_m"], energy=options["energy_kev"], delta_beta=options["delta_beta"], pixel_size=options["pixel_size_m"], padding=options["padding_type"], margin=margin, fft_num_threads=True, # TODO tune in advanced params of nabu config file ) if self.phase_retrieval.use_fftw: self.logger.debug( "PaganinPhaseRetrieval using FFTW with %d threads" % self.phase_retrieval.fftw.num_threads ) ##@use_options("unsharp_mask", "unsharp_mask") def _init_unsharp(self): if "unsharp_mask" not in self.processing_steps: self.unsharp_mask = None self.unsharp_sigma = 0.0 self.unsharp_coeff = 0.0 self.unsharp_method = "log" else: options = self.processing_options["unsharp_mask"] self.unsharp_sigma = options["unsharp_sigma"] self.unsharp_coeff = options["unsharp_coeff"] self.unsharp_method = options["unsharp_method"] self.unsharp_mask = self.UnsharpMaskClass( self._get_shape("unsharp_mask"), options["unsharp_sigma"], options["unsharp_coeff"], mode="reflect", method=options["unsharp_method"], ) def _init_mlog(self): options = self.processing_options["take_log"] self.mlog = self.MLogClass( self._get_shape("take_log"), clip_min=options["log_min_clip"], clip_max=options["log_max_clip"] ) @use_options("sino_normalization", "sino_normalization") def _init_sino_normalization(self): options = self.processing_options["sino_normalization"] self.sino_normalization = self.SinoNormalizationClass( kind=options["method"], radios_shape=self._get_shape("sino_normalization_slim"), ) def _init_sino_builder(self): options = self.processing_options["reconstruction"] ## build_sino class disappeared disappeared self.sino_builder = self.SinoBuilderClass( radios_shape=self._get_shape("build_sino"), rot_center=options["rotation_axis_position"], halftomo=False, ) self._sinobuilder_copy = False self._sinobuilder_output = None self.sinos = None # this should be renamed, as it could be confused with _init_reconstruction. What about _get_reconstruction_array ? 
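# _prepare_reconstruction (below) derives the reconstruction ROI from the configured start/end_x and start/end_y
# and allocates the output arrays; for diagnostic (diag_zpro_run) runs a dummy 1x1 region is allocated instead.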
@use_options("reconstruction", "reconstruction") def _prepare_reconstruction(self): options = self.processing_options["reconstruction"] x_s, x_e = options["start_x"], options["end_x"] y_s, y_e = options["start_y"], options["end_y"] if not self.diag_zpro_run: self._rec_roi = (x_s, x_e + 1, y_s, y_e + 1) self._allocate_recs(y_e - y_s + 1, x_e - x_s + 1) else: ## Dummy 1x1 place holder self._rec_roi = (x_s, x_s + 1, y_s, y_s + 1) self._allocate_recs(y_s - y_s + 1, x_s - x_s + 1) @use_options("reconstruction", "reconstruction") def _init_reconstruction(self): options = self.processing_options["reconstruction"] if self.sino_builder is None: raise ValueError("Reconstruction cannot be done without build_sino") if self.FBPClass is None: raise ValueError("No usable FBP module was found") rot_center = options["rotation_axis_position"] start_y, end_y, start_x, end_x = self._rec_roi if self.HBPClass is not None and self.process_config.nabu_config["reconstruction"]["use_hbp"]: fan_source_distance_meters = self.process_config.nabu_config["reconstruction"]["fan_source_distance_meters"] self.reconstruction_hbp = self.HBPClass( self._get_shape("one_sino_slim"), slice_shape=(end_y - start_y, end_x - start_x), angles=self.my_angles_rad, rot_center=rot_center, extra_options={"axis_correction": np.zeros(self.radios.shape[0], "f")}, axis_source_meters=fan_source_distance_meters, voxel_size_microns=options["voxel_size_cm"][0] * 1.0e4, scale_factor=2.0 / options["voxel_size_cm"][0], clip_outer_circle=options["clip_outer_circle"], ) else: self.reconstruction_hbp = None self.reconstruction = self.FBPClass( self._get_shape("reconstruction"), angles=np.zeros(self.radios.shape[0], "f"), rot_center=rot_center, filter_name=options["fbp_filter_type"], slice_roi=self._rec_roi, # slice_shape = ( end_y-start_y, end_x- start_x ), scale_factor=2.0 / options["voxel_size_cm"][0], padding_mode=options["padding_type"], extra_options={ "scale_factor": 2.0 / options["voxel_size_cm"][0], "axis_correction": np.zeros(self.radios.shape[0], "f"), "clip_outer_circle": options["clip_outer_circle"], }, # "padding_mode": options["padding_type"], }, ) my_options = self.process_config.nabu_config["reconstruction"] if my_options["axis_to_the_center"]: x_s, x_ep1, y_s, y_ep1 = self._rec_roi off_x = -int(round((x_s + x_ep1 - 1) / 2.0 - rot_center)) off_y = -int(round((y_s + y_ep1 - 1) / 2.0 - rot_center)) self.reconstruction.offsets = {"x": off_x, "y": off_y} if options["fbp_filter_type"] is None: self.reconstruction.fbp = self.reconstruction.backproj @use_options("histogram", "histogram") def _init_histogram(self): options = self.processing_options["histogram"] self.histogram = self.HistogramClass(method="fixed_bins_number", num_bins=options["histogram_bins"]) self.histo_stack = [] @use_options("save", "writer") def _init_writer(self, chunk_info=None): options = self.processing_options["save"] file_prefix = options["file_prefix"] output_dir = path.join(options["location"], file_prefix) nx_info = None self._hdf5_output = is_hdf5_extension(options["file_format"]) if chunk_info is not None: d_v, d_h = self.process_config.dataset_info.radio_dims[::-1] h_rels = self._get_slice_start_index() + np.arange(chunk_info.span_v[1] - chunk_info.span_v[0]) fact_mm = self.process_config.dataset_info.pixel_size * 1.0e-3 heights_mm = ( fact_mm * (-self.span_info.z_pix_per_proj[0] + (d_v - 1) / 2 - h_rels) - self.span_info.z_offset_mm ) else: heights_mm = None if self._hdf5_output: fname_start_index = None file_prefix += str("_%06d" % 
self._get_slice_start_index()) entry = getattr(self.dataset_info.dataset_scanner, "entry", None) nx_info = { "process_name": self._get_process_name(), "processing_index": 0, "config": { "processing_options": self.processing_options, "nabu_config": self.process_config.nabu_config, }, "entry": entry, } self._histogram_processing_index = nx_info["processing_index"] + 1 elif options["file_format"] in ["tif", "tiff", "edf"]: fname_start_index = self._get_slice_start_index() self._histogram_processing_index = 1 self._writer_configurator = WriterConfigurator( output_dir, file_prefix, file_format=options["file_format"], overwrite=options["overwrite"], start_index=fname_start_index, heights_above_stage_mm=heights_mm, logger=self.logger, nx_info=nx_info, write_histogram=("histogram" in self.processing_steps), histogram_entry=getattr(self.dataset_info.dataset_scanner, "entry", "entry"), ) self.writer = self._writer_configurator.writer self._writer_exec_args = self._writer_configurator._writer_exec_args self._writer_exec_kwargs = self._writer_configurator._writer_exec_kwargs self.histogram_writer = self._writer_configurator.get_histogram_writer() def _apply_expand_fact(self, t): if t is not None: t = t * self.chunk_reader.dataset_subsampling return t def _expand_slice(self, subchunk_slice): start, stop, step = subchunk_slice.start, subchunk_slice.stop, subchunk_slice.step if step is None: step = 1 start, stop, step = list(map(self._apply_expand_fact, [start, stop, step])) result_slice = slice(start, stop, step) return result_slice def _read_data_and_apply_flats(self, sub_total_prange_slice, subchunk_slice, chunk_info): my_integer_shifts_v = chunk_info.integer_shift_v[subchunk_slice] fract_complement_shifts_v = chunk_info.fract_complement_to_integer_shift_v[subchunk_slice] x_shifts_list = chunk_info.x_pix_per_proj[subchunk_slice] (subr_start_x, subr_end_x, subr_start_z, subr_end_z) = self.sub_region subr_start_z_list = subr_start_z - my_integer_shifts_v subr_end_z_list = subr_end_z - my_integer_shifts_v + 1 self._reset_reader_subregion((None, None, subr_start_z_list.min(), subr_end_z_list.max())) dtasrc_start_x, dtasrc_end_x, dtasrc_start_z, dtasrc_end_z = self.trimmed_floating_subregion if self.diag_zpro_run: searched_angles = self.diagnostic_searched_angles_rad_clipped these_angles = chunk_info.angles_rad[subchunk_slice] if len(these_angles) > 1: # these_angles are the projection angles # if no diagnostic angle falls close to them we skip to the next angular subchunk # (here slice refers to angular slicing) # We like hdf5 but we that is not a reason to read them all the time, so we spare time a_step = abs(these_angles[1:] - these_angles[:-1]).mean() distance = abs(np.mod(these_angles, np.pi * 2) - searched_angles[:, None]).min() distance_l = abs(np.mod(these_angles, np.pi * 2) - searched_angles[:, None] - a_step).min() distance_h = abs(np.mod(these_angles, np.pi * 2) - searched_angles[:, None] + a_step).min() distance = np.array([distance, distance_h, distance_l]).min() if distance > 2 * a_step: return self.chunk_reader.load_data(overwrite=True, sub_total_prange_slice=sub_total_prange_slice) if self.chunk_reader.dataset_subsampling > 1: radios_angular_range_slicing = self._expand_slice(sub_total_prange_slice) else: radios_angular_range_slicing = sub_total_prange_slice my_subsampled_indexes = self.chunk_reader._sorted_files_indices[radios_angular_range_slicing] data_raw = self.chunk_reader.data[: len(my_subsampled_indexes)] self.regular_accumulator.extract_preprocess_with_flats( subchunk_slice, 
my_subsampled_indexes, # these are indexes pointing within the global domain sequence which is composed of darks flats radios chunk_info, np.array((subr_start_z, subr_end_z), "i"), np.array((dtasrc_start_z, dtasrc_end_z), "i"), data_raw, radios_angular_range_slicing, # my_subsampled_indexes is important in order to compare the # radios positions with respect to the flat position, and these position # are given as the sequential acquisition number which counts everything ( flats, darks, radios ) # Insteqd, in order to access array which spans only the radios, we need to have an idea of where we are. # this is provided by radios_angular_range_slicing which addresses the radios domain ) def binning_expanded(self, region): binning_x, binning_z = self.chunk_reader.get_binning() binnings = [binning_x] * 2 + [binning_z] * 2 res = [None if tok is None else tok * fact for tok, fact in zip(region, binnings)] return res def _reset_reader_subregion(self, floating_subregion): if self._resume_from_step is None: binning_x, binning_z = self.chunk_reader.get_binning() start_x, end_x, start_z, end_z = floating_subregion trimmed_start_z = max(0, start_z) trimmed_end_z = min(self.whole_radio_shape[0], end_z) my_buffer_height = trimmed_end_z - trimmed_start_z if self.radios_buffer is None or my_buffer_height > self.safe_buffer_height: self.safe_buffer_height = end_z - start_z assert ( self.safe_buffer_height >= my_buffer_height ), "This should always be true, if not contact the developer" self.radios_buffer = None self.radios_buffer = np.zeros( (self.reading_granularity + self.extra_marge_granularity,) + (self.safe_buffer_height, self.whole_radio_shape[1]), np.float32, ) self.trimmed_floating_subregion = start_x, end_x, trimmed_start_z, trimmed_end_z self.chunk_reader._set_subregion(self.binning_expanded(self.trimmed_floating_subregion)) self.chunk_reader._init_reader() self.chunk_reader._loaded = False self.chunk_reader.set_data_buffer(self.radios_buffer[:, :my_buffer_height, :], pre_allocate=False) else: message = "Resume not yet implemented in helical pipeline" raise RuntimeError(message) def _ccd_corrections(self, radios=None): if radios is None: radios = self.gridded_radios if hasattr(self.ccd_correction, "median_clip_correction_multiple_images"): self.ccd_correction.median_clip_correction_multiple_images(radios) else: _tmp_radio = self._cpu_allocate_array(radios.shape[1:], "f", name="tmp_ccdcorr_radio") for i in range(radios.shape[0]): self.ccd_correction.median_clip_correction(radios[i], output=_tmp_radio) radios[i][:] = _tmp_radio[:] def _retrieve_phase(self): if "unsharp_mask" in self.processing_steps: for i in range(self.gridded_radios.shape[0]): self.gridded_radios[i] = self.phase_retrieval.apply_filter(self.gridded_radios[i]) else: for i in range(self.gridded_radios.shape[0]): self.radios[i] = self.phase_retrieval.apply_filter(self.gridded_radios[i]) def _nophase_put_to_radios(self, target, source): ((up_margin, down_margin), (left_margin, right_margin)) = self.phase_margin zslice = slice(up_margin or None, -down_margin or None) xslice = slice(left_margin or None, -right_margin or None) for i in range(target.shape[0]): target[i] = source[i][zslice, xslice] def _apply_unsharp(): ((up_margin, down_margin), (left_margin, right_margin)) = self._phase_margin zslice = slice(up_margin or None, -down_margin or None) xslice = slice(left_margin or None, -right_margin or None) for i in range(self.radios.shape[0]): self.radios[i] = self.unsharp_mask.unsharp(self.gridded_radios[i])[zslice, xslice] def 
_take_log(self): self.mlog.take_logarithm(self.radios) @pipeline_step("sino_normalization", "Normalizing sinograms") def _normalize_sinos(self, radios=None): if radios is None: radios = self.radios sinos = radios.transpose((1, 0, 2)) self.sino_normalization.normalize(sinos) def _dump_sinogram(self, radios=None): if radios is None: radios = self.radios self._dump_data_to_file("sinogram", data=radios) @pipeline_step("sino_builder", "Building sinograms") def _build_sino(self): self.sinos = self.radios_slim def _filter(self): rot_center = self.processing_options["reconstruction"]["rotation_axis_position"] self.reconstruction.sino_filter.filter_sino( self.radios_slim, mirror_indexes=self.mirror_angle_relative_indexes, rot_center=rot_center, output=self.radios_slim, ) def _build_sino(self): self.sinos = self.radios_slim def _reconstruct(self, sinos=None, chunk_info=None, i_slice=0): if sinos is None: sinos = self.sinos use_hbp = self.process_config.nabu_config["reconstruction"]["use_hbp"] if not use_hbp: if i_slice == 0: self.reconstruction.set_custom_angles_and_axis_corrections( self.my_angles_rad, np.zeros_like(self.my_angles_rad) ) self.reconstruction.backprojection(sinos, output=self.recs[0]) self.recs[0].get(self.recs_stack[i_slice]) else: if self.reconstruction_hbp is None: raise ValueError("You requested the hierchical backprojector but the module could not be imported") self.reconstruction_hbp.backprojection(sinos, output=self.recs_stack[i_slice]) def _compute_histogram(self, data=None, i_slice=None, num_slices=None): if self.histogram is None: return if data is None: data = self.recs my_histo = self.histogram.compute_histogram(data.ravel()) self.histo_stack.append(my_histo) if i_slice == num_slices - 1: self.recs_histogram = self.histogram.merge_histograms(self.histo_stack) self.histo_stack.clear() def _write_data(self, data=None): if data is None: data = self.recs_stack my_kw_args = copy.copy(self._writer_exec_kwargs) if "config" in my_kw_args: self.logger.info( "omitting config in writer because of too slow nexus writer. 
Just writing the diagnostics, if any " ) # diagnostic are saved here, with the Nabu mechanism for config self.diagnostic_zpix_transl[:] = np.interp( self.diagnostic_proj_angle, np.deg2rad(self.span_info.projection_angles_deg_internally), self.span_info.z_pix_per_proj, ) self.diagnostic_zmm_transl[:] = self.diagnostic_zpix_transl * self.span_info.pix_size_mm my_kw_args["config"] = self.diagnostic self.writer.write(data, *self._writer_exec_args, **my_kw_args) self.logger.info("Wrote %s" % self.writer.get_filename()) self._write_histogram() def _write_histogram(self): if "histogram" not in self.processing_steps: return self.logger.info("Saving histogram") self.histogram_writer.write( hist_as_2Darray(self.recs_histogram), self._get_process_name(kind="histogram"), processing_index=self._histogram_processing_index, config={ "file": path.basename(self.writer.get_filename()), "bins": self.processing_options["histogram"]["histogram_bins"], }, ) def _dump_data_to_file(self, step_name, data=None): if step_name not in self._data_dump: return self.logger.info(f"DUMP step_name={step_name}") if data is None: data = self.radios writer = self._data_dump[step_name] self.logger.info("Dumping data to %s" % writer.fname) writer.write_data(data) def balance_weights(self): options = self.processing_options["reconstruction"] rot_center = options["rotation_axis_position"] self.radios_weights[:] = rebalance(self.radios_weights, self.my_angles_rad, rot_center) # When standard scans are incomplete, due to motors errors, some angular range # is missing short of 360 degrees. # The weight accounting correctly deal with it, but still the padding # procedure with theta+180 data may fall on empty data # and this may cause problems, coming from the ramp filter, # in half tomo. # To correct this we complete with what we have at hand from the nearest # non empty data # to_be_filled = [] for i in range(len(self.radios_weights) - 1, 0, -1): if self.radios_weights[i].sum(): break to_be_filled.append(i) for i in to_be_filled: self.radios[i] = self.radios[to_be_filled[-1] - 1] def _post_primary_data_reduction(self, i_slice): """This will be used in the derived class to transfer data to gpu""" self.radios_slim[:] = self.radios[:, i_slice, :] def process_chunk(self, sub_region=None): self._private_process_chunk(sub_region=sub_region) self._process_finalize() def _private_process_chunk(self, sub_region=None): assert sub_region is not None, "sub_region argument is mandatory in helical pipeline" # Every chunk has its diagnostic, that is good to follow the trends in helical scans # or zstages self._reset_diagnostics() self.set_subregion(sub_region) (subr_start_x, subr_end_x, subr_start_z, subr_end_z) = self.sub_region span_v = subr_start_z + self._phase_margin_up, subr_end_z - self._phase_margin_down chunk_info = self.span_info.get_chunk_info(span_v) self._reset_memory() self._init_writer(chunk_info) self._configure_data_dumps() proj_num_start, proj_num_end = chunk_info.angle_index_span n_granularity = self.reading_granularity pnum_start_list = list(np.arange(proj_num_start, proj_num_end, n_granularity)) pnum_end_list = pnum_start_list[1:] + [proj_num_end] my_first_pnum = proj_num_start if self.diag_zpro_run == 0: # It may seem anodine, but setting a huge vector to zero # takes time. # In diagnostic collection mode we can spare it. 
On the other hand nothing has would be allocated for the data # in such case self.gridded_cumulated_weights[:] = 0 self.gridded_radios[:] = 0 for pnum_start, pnum_end in zip(pnum_start_list, pnum_end_list): start_in_chunk = pnum_start - my_first_pnum end_in_chunk = pnum_end - my_first_pnum self._read_data_and_apply_flats( slice(pnum_start, pnum_end), slice(start_in_chunk, end_in_chunk), chunk_info ) if not self.diag_zpro_run: # when we collect diagnostics we dont collect all the data # so there would be nothing to process here self.gridded_radios[:] /= self.gridded_cumulated_weights self.correct_for_missing_angles() linea = self.gridded_cumulated_weights.sum(axis=(1, 2)) i_zero_list = np.where(linea == 0)[0] for i_zero in i_zero_list: if i_zero > linea.shape[0] // 2: direction = -1 else: direction = 1 i = i_zero while ((i >= 0 and direction == -1) or ((i < linea.shape[0] - 1) and direction == 1)) and linea[i] == 0: i += direction if linea[i]: self.gridded_radios[i_zero] = self.gridded_radios[i] self.gridded_cumulated_weights[i_zero] = self.gridded_cumulated_weights[i] if "flatfield" in self._data_dump: paganin_margin = self._phase_margin_up if paganin_margin: data_to_dump = self.gridded_radios[:, paganin_margin:-paganin_margin, :] else: data_to_dump = self.gridded_radios self._dump_data_to_file("flatfield", data_to_dump) if self.process_config.nabu_config["pipeline"]["skip_after_flatfield_dump"]: return if "ccd_correction" in self.processing_steps: self._ccd_corrections() if cxx_paganin is None: if ("phase" in self.processing_steps) or ("unsharp_mask" in self.processing_steps): self._retrieve_phase() if "unsharp_mask" in self.processing_steps: self._apply_unsharp() else: self._nophase_put_to_radios(self.radios, self.gridded_radios) else: if "phase" in self.processing_steps: pr = self.phase_retrieval paganin_l_micron = math.sqrt(pr.wavelength_micron * pr.distance_micron * pr.delta_beta * math.pi) cxx_paganin.paganin_pyhst( data_raw=self.gridded_radios, output=self.radios, num_of_threads=-1, paganin_marge=self._phase_margin_up, paganin_l_micron=paganin_l_micron / pr.pixel_size_micron, image_pixel_size_y=1.0, image_pixel_size_x=1.0, unsharp_sigma=self.unsharp_sigma, unsharp_coeff=self.unsharp_coeff, unsharp_LoG=int((self.unsharp_method == "log")), ) else: self._nophase_put_to_radios(self.radios, self.gridded_radios) self.logger.info(" LOG ") self._nophase_put_to_radios(self.radios_weights, self.gridded_cumulated_weights) # print( " processing steps ", self.processing_steps ) # ['read_chunk', 'flatfield', 'double_flatfield', 'take_log', 'reconstruction', 'save'] if "take_log" in self.processing_steps: self._take_log() self.logger.info(" BALANCE ") self.balance_weights() num_slices = self.radios.shape[1] self.logger.info(" NORMALIZE") self._normalize_sinos() self._dump_sinogram() if "reconstruction" in self.processing_steps: if not self.diag_zpro_run: # otherwise, when collecting diagnostic, we are not interested in the remaining steps # on the other hand there would be nothing to process because only diagnostics have been collected for i_slice in range(num_slices): self._post_primary_data_reduction(i_slice) # charge on self.radios_slim self._filter() self.apply_weights(i_slice) self._build_sino() self._reconstruct(chunk_info=chunk_info, i_slice=i_slice) self._compute_histogram(i_slice=i_slice, num_slices=num_slices) self._write_data() def apply_weights(self, i_slice): """radios_slim is on gpu""" n_provided_angles = self.radios_slim.shape[0] for first_angle_index in range(0, n_provided_angles, 
self.num_weight_radios_per_app): end_angle_index = min(n_provided_angles, first_angle_index + self.num_weight_radios_per_app) self._d_radios_weights[: end_angle_index - first_angle_index].set( self.radios_weights[first_angle_index:end_angle_index, i_slice] ) self.radios_slim[first_angle_index:end_angle_index] *= self._d_radios_weights[ : end_angle_index - first_angle_index ] def correct_for_missing_angles(self): """For non helical scan, the rotation is often incomplete ( < 360) here we complement the missing angles """ linea = self.gridded_cumulated_weights.sum(axis=(1, 2)) i_zero_list = np.where(linea == 0)[0] for i_zero in i_zero_list: if i_zero > linea.shape[0] // 2: direction = -1 else: direction = 1 i = i_zero while ((i >= 0 and direction == -1) or ((i < linea.shape[0] - 1) and direction == 1)) and linea[i] == 0: i += direction if linea[i]: self.gridded_radios[i_zero] = self.gridded_radios[i] self.gridded_cumulated_weights[i_zero] = self.gridded_cumulated_weights[i] @classmethod def estimate_required_memory( cls, process_config, reading_granularity=None, chunk_size=None, margin_v=0, span_info=None, diag_zpro_run=0 ): """ Estimate the memory (RAM) needed for a reconstruction. Parameters ----------- process_config: `ProcessConfig` object Data structure with the processing configuration chunk_size: int, optional Size of a "radios chunk", i.e "delta z". A radios chunk is a 3D array of shape (n_angles, chunk_size, n_x) If set to None, then chunk_size = n_z Notes ----- It seems that Cuda does not allow allocating and/or transferring more than 16384 MiB (17.18 GB). If warn_from_GB is not None, then the result is in the form (estimated_memory_GB, warning) where warning is a boolean indicating wheher memory allocation/transfer might be problematic. """ dataset = process_config.dataset_info nabu_config = process_config.nabu_config processing_steps = process_config.processing_steps Nx, Ny = dataset.radio_dims total_memory_needed = 0 # Read data # ---------- # gridded part tmp_angles_deg = np.rad2deg(process_config.processing_options["reconstruction"]["angles"]) tmp_my_angle_step = abs(np.diff(tmp_angles_deg).mean()) my_angle_step = abs(np.diff(span_info.projection_angles_deg).mean()) n_gridded_angles = int(round(360.0 / my_angle_step)) binning_z = nabu_config["dataset"]["binning_z"] projections_subsampling = nabu_config["dataset"]["projections_subsampling"] if not diag_zpro_run: # the gridded target total_memory_needed += Nx * (2 * margin_v + chunk_size) * n_gridded_angles * 4 # the gridded weights total_memory_needed += Nx * (2 * margin_v + chunk_size) * n_gridded_angles * 4 # the read grain total_memory_needed += ( (reading_granularity + cls.extra_marge_granularity) * (2 * margin_v + chunk_size + 2) * Nx * 4 ) total_memory_needed += ( (reading_granularity + cls.extra_marge_granularity) * (2 * margin_v + chunk_size + 2) * Nx * 4 ) # the preprocessed radios, their weigth and the buffer used for balancing ( total of three buffer of the same size plus mask plus temporary) total_memory_needed += 5 * (Nx * (chunk_size) * n_gridded_angles) * 4 if "flatfield" in processing_steps: # Flat-field is done in-place, but still need to load darks/flats n_darks = len(dataset.darks) n_flats = len(dataset.flats) darks_size = n_darks * Nx * (2 * margin_v + chunk_size) * 2 # uint16 flats_size = n_flats * Nx * (2 * margin_v + chunk_size) * 4 # f32 total_memory_needed += darks_size + flats_size if "ccd_correction" in processing_steps: total_memory_needed += Nx * (2 * margin_v + chunk_size) * 4 # Phase retrieval # 
--------------- if "phase" in processing_steps: # Phase retrieval is done image-wise, so nearly in-place, but needs to # allocate some images, fft plans, and so on Nx_p = get_next_power(2 * Nx) Ny_p = get_next_power(2 * (2 * margin_v + chunk_size)) img_size_real = 2 * 4 * Nx_p * Ny_p img_size_cplx = 2 * 8 * ((Nx_p * Ny_p) // 2 + 1) total_memory_needed += 2 * img_size_real + 3 * img_size_cplx # Reconstruction # --------------- reconstructed_volume_size = 0 if "reconstruction" in processing_steps and not diag_zpro_run: ## radios_slim is used to process one slice at a time. It will be on the gpu ## and cannot be reduced further, therefore no need to estimate it. ## Either it passes or it does not. #### if radios_and_sinos: #### total_memory_needed += data_volume_size # radios + sinos rec_config = process_config.processing_options["reconstruction"] Nx_rec = rec_config["end_x"] - rec_config["start_x"] + 1 Ny_rec = rec_config["end_y"] - rec_config["start_y"] + 1 Nz_rec = chunk_size // binning_z ## the volume is used to reconstruct for each chunk reconstructed_volume_size = Nx_rec * Ny_rec * Nz_rec * 4 # float32 total_memory_needed += reconstructed_volume_size return total_memory_needed # target_central_slicer, source_central_slicer = overlap_logic( subr_start_z, subr_end_z, dtasrc_start_z, dtasrc_end_z ) def overlap_logic(subr_start_z, subr_end_z, dtasrc_start_z, dtasrc_end_z): """Determines the useful lines which can be transferred from the dtasrc_start_z:dtasrc_end_z range of the freshly read data into the target range subr_start_z:subr_end_z. Returns the pair (target_central_slicer, dtasrc_central_slicer), or (None, None) if there is no overlap. """ t_h = subr_end_z - subr_start_z s_h = dtasrc_end_z - dtasrc_start_z my_start = max(0, dtasrc_start_z - subr_start_z) my_end = min(t_h, dtasrc_end_z - subr_start_z) if my_start >= my_end: return None, None target_central_slicer = slice(my_start, my_end) my_start = max(0, subr_start_z - dtasrc_start_z) my_end = min(s_h, subr_end_z - dtasrc_start_z) assert my_start < my_end, "Overlap_logic logic error" dtasrc_central_slicer = slice(my_start, my_end) return target_central_slicer, dtasrc_central_slicer
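# Worked example (illustrative only, with arbitrary numbers): for a target slab spanning absolute detector rows [10, 20) and freshly read data covering rows [15, 30),
#     overlap_logic(10, 20, 15, 30)  returns  (slice(5, 10), slice(0, 5))
# i.e. target[5:10] receives source[0:5] (absolute rows 15..19), while
#     padding_logic(10, 20, 15, 30)  returns  (slice(0, 5), None)
# i.e. the first five target rows (absolute rows 10..14) are not covered by the read data and can only be filled by extension padding.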
def padding_logic(subr_start_z, subr_end_z, dtasrc_start_z, dtasrc_end_z): """Determines the target ranges which are not covered by the freshly read data and which possibly could be obtained by extension padding. Returns the pair (target_lower_padding, target_upper_padding), where each item is either a slice or None. """ t_h = subr_end_z - subr_start_z s_h = dtasrc_end_z - dtasrc_start_z if dtasrc_start_z <= subr_start_z: target_lower_padding = None else: target_lower_padding = slice(0, dtasrc_start_z - subr_start_z) if dtasrc_end_z >= subr_end_z: target_upper_padding = None else: target_upper_padding = slice(dtasrc_end_z - subr_end_z, None) return target_lower_padding, target_upper_padding def _fill_in_chunk_by_shift_crop_data( data_target, data_read, fract_shit, my_subr_start_z, my_subr_end_z, dtasrc_start_z, dtasrc_end_z, x_shift=0.0, extension_padding=True, ): """Given a freshly read cube of data, dispatches every slice to its proper vertical position and proper radio by shifting, cropping, and extending if necessary.""" data_read_precisely_shifted = nd.interpolation.shift(data_read, (-fract_shit, x_shift), order=1, mode="nearest")[ :-1 ] target_central_slicer, dtasrc_central_slicer = overlap_logic( my_subr_start_z, my_subr_end_z - 1, dtasrc_start_z, dtasrc_end_z - 1 ) if None not in [target_central_slicer, dtasrc_central_slicer]: data_target[target_central_slicer] = data_read_precisely_shifted[dtasrc_central_slicer] target_lower_slicer, target_upper_slicer = padding_logic( my_subr_start_z, my_subr_end_z - 1, dtasrc_start_z, dtasrc_end_z - 1 ) if extension_padding: if target_lower_slicer is not None: data_target[target_lower_slicer] = data_read_precisely_shifted[0] if target_upper_slicer is not None: data_target[target_upper_slicer] = data_read_precisely_shifted[-1] else: if target_lower_slicer is not None: data_target[target_lower_slicer] = 1.0e-6 if target_upper_slicer is not None: data_target[target_upper_slicer] = 1.0e-6 def shift(arr, shift, fill_value=0.0): """trivial horizontal shift.
Contrarily to scipy.ndimage.interpolation.shift, this shift does not cut the tails abruptly, but by interpolation """ result = np.zeros_like(arr) num1 = int(math.floor(shift)) num2 = num1 + 1 partition = shift - num1 for num, factor in zip([num1, num2], [(1 - partition), partition]): if num > 0: result[:, :num] += fill_value * factor result[:, num:] += arr[:, :-num] * factor elif num < 0: result[:, num:] += fill_value * factor result[:, :num] += arr[:, -num:] * factor else: result[:] += arr * factor return result def rebalance(radios_weights, angles, ax_pos): """rebalance the weights, within groups of equivalent (up to multiple of 180), data pixels""" balanced = np.zeros_like(radios_weights) n_span = int(math.ceil(angles[-1] - angles[0]) / np.pi) center = (radios_weights.shape[-1] - 1) / 2 nloop = balanced.shape[0] for i in range(nloop): w_res = balanced[i] angle = angles[i] for i_half_turn in range(-n_span - 1, n_span + 2): if i_half_turn == 0: w_res[:] += radios_weights[i] continue shifted_angle = angle + i_half_turn * np.pi insertion_index = np.searchsorted(angles, shifted_angle) if insertion_index in [0, angles.shape[0]]: if insertion_index == 0: continue else: if shifted_angle > 2 * np.pi: continue myimage = radios_weights[-1] else: partition = (shifted_angle - angles[insertion_index - 1]) / ( angles[insertion_index] - angles[insertion_index - 1] ) myimage = (1.0 - partition) * radios_weights[insertion_index - 1] + partition * radios_weights[ insertion_index ] if i_half_turn % 2 == 0: w_res[:] += myimage else: myimage = np.fliplr(myimage) w_res[:] += shift(myimage, (2 * ax_pos - 2 * center)) mask = np.equal(0, radios_weights) balanced[:] = radios_weights / balanced balanced[mask] = 0 return balanced ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/pipeline/helical/helical_chunked_regridded_cuda.py0000644000175000017500000000745714654107202025117 0ustar00pierrepierreimport numpy as np from ...preproc.shift_cuda import CudaVerticalShift from ...reconstruction.sinogram_cuda import CudaSinoBuilder, CudaSinoNormalization from ...processing.histogram_cuda import CudaPartialHistogram from .fbp import BackprojectorHelical try: from ...reconstruction.hbp import HierarchicalBackprojector # pylint: disable=E0401,E0611 print("Successfully imported hbp") except: HierarchicalBackprojector = None from ...cuda.utils import get_cuda_context, __has_pycuda__, __pycuda_error_msg__ from .helical_chunked_regridded import HelicalChunkedRegriddedPipeline if __has_pycuda__: import pycuda.gpuarray as garray # pylint: disable=E0606 class CudaHelicalChunkedRegriddedPipeline(HelicalChunkedRegriddedPipeline): """ Cuda backend of HelicalChunkedPipeline """ VerticalShiftClass = CudaVerticalShift SinoBuilderClass = CudaSinoBuilder FBPClass = BackprojectorHelical HBPClass = HierarchicalBackprojector HistogramClass = CudaPartialHistogram SinoNormalizationClass = CudaSinoNormalization def __init__( self, process_config, sub_region, logger=None, extra_options=None, phase_margin=None, cuda_options=None, reading_granularity=10, span_info=None, num_weight_radios_per_app=1000, diag_zpro_run=0, ): self._init_cuda(cuda_options) super().__init__( process_config, sub_region, logger=logger, extra_options=extra_options, phase_margin=phase_margin, reading_granularity=reading_granularity, span_info=span_info, diag_zpro_run=diag_zpro_run, ) self._register_callbacks() self.num_weight_radios_per_app = num_weight_radios_per_app def _init_cuda(self, cuda_options): if not 
(__has_pycuda__): raise ImportError(__pycuda_error_msg__) cuda_options = cuda_options or {} self.ctx = get_cuda_context(**cuda_options) self._d_radios = None self._d_radios_weights = None self._d_sinos = None self._d_recs = None def _allocate_array(self, shape, dtype, name=None): name = name or "tmp" # should be mandatory d_name = "_d_" + name d_arr = getattr(self, d_name, None) if d_arr is None: self.logger.debug("Allocating %s: %s" % (name, str(shape))) d_arr = garray.zeros(shape, dtype) setattr(self, d_name, d_arr) return d_arr def _process_finalize(self): pass def _post_primary_data_reduction(self, i_slice): self._allocate_array((self.num_weight_radios_per_app,) + self.radios_slim.shape[1:], "f", name="radios_weights") if self.process_config.nabu_config["reconstruction"]["angular_tolerance_steps"]: self.radios[:, i_slice, :][np.isnan(self.radios[:, i_slice, :])] = 0 self.radios_slim.set(self.radios[:, i_slice, :]) def _register_callbacks(self): pass # # Pipeline execution (class specialization) # def _compute_histogram(self, data=None, i_slice=None, num_slices=None): if self.histogram is None: return if data is None: data = self.recs my_histo = self.histogram.compute_histogram(data) self.histo_stack.append(my_histo) if i_slice == num_slices - 1: self.recs_histogram = self.histogram.merge_histograms(self.histo_stack) self.histo_stack.clear() def _dump_data_to_file(self, step_name, data=None): if data is None: data = self.radios if step_name not in self._data_dump: return if isinstance(data, garray.GPUArray): data = data.get() super()._dump_data_to_file(step_name, data=data) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/pipeline/helical/helical_reconstruction.py0000644000175000017500000005677514654107202023561 0ustar00pierrepierrefrom os.path import join, isfile, dirname from math import ceil from time import time import numpy as np import copy from ...resources.logger import LoggerOrPrint from ...io.writer import merge_hdf5_files from ...cuda.utils import collect_cuda_gpus try: import nabuxx SpanStrategy = nabuxx.span_strategy.SpanStrategy except: logger_tmp = LoggerOrPrint(None) logger_tmp.info("Nabuxx not available. Loading python implementation for SpanStrategy") from .span_strategy import SpanStrategy from .helical_chunked_regridded_cuda import CudaHelicalChunkedRegriddedPipeline from ..fullfield.reconstruction import collect_cuda_gpus, FullFieldReconstructor avail_gpus = collect_cuda_gpus() or {} class HelicalReconstructorRegridded: """ A class for obtaining a full-volume reconstructions from helical-scan datasets. """ _pipeline_cls = CudaHelicalChunkedRegriddedPipeline _process_name = "reconstruction" _pipeline_mode = "helical" reading_granularity = 100 """ The data angular span which needs to be read for a reconstruction is read step by step, reading each time a maximum of reading_granularity radios, and doing the preprocessing till phase retrieval for each of these angular groups """ def __init__(self, process_config, logger=None, extra_options=None, cuda_options=None): """ Initialize a LocalReconstruction object. This class is used for managing pipelines Parameters ---------- process_config: ProcessConfig object Data structure with process configuration logger: Logger, optional logging object extra_options: dict, optional Dictionary with advanced options. 
Please see 'Other parameters' below cuda_options: dict, optional Dictionary with cuda options passed to `nabu.cuda.processing.CudaProcessing` Other parameters ----------------- Advanced options can be passed in the 'extra_options' dictionary. These can be: - "gpu_mem_fraction": 0.9, - "cpu_mem_fraction": 0.9, - "use_phase_margin": True, - "max_chunk_size": None, - "phase_margin": None, - "dry_run": 0, - "diag_zpro_run": 0, """ self.logger = LoggerOrPrint(logger) self.process_config = process_config self._set_extra_options(extra_options) self._get_resources() ### intrication problem: this is used in fullfield's compute_margin to clamp the margin but not used by the present pipeline ### Set it to a big number that will never clamp self.n_z = 10000000 # a big number self.n_x = 10000000 # a big number ### self._compute_margin() # self._margin_v, self._margin_h = self._compute_phase_margin() self._setup_span_info() self._compute_max_chunk_size() self._get_reconstruction_range() self._build_tasks() self.pipeline = None self.cuda_options = cuda_options def _set_extra_options(self, extra_options): if extra_options is None: extra_options = {} advanced_options = { "gpu_mem_fraction": 0.9, "cpu_mem_fraction": 0.9, "use_phase_margin": True, "max_chunk_size": None, "phase_margin": None, "dry_run": 0, "diag_zpro_run": 0, } advanced_options.update(extra_options) self.extra_options = advanced_options self.gpu_mem_fraction = self.extra_options["gpu_mem_fraction"] self.cpu_mem_fraction = self.extra_options["cpu_mem_fraction"] self.use_phase_margin = self.extra_options["use_phase_margin"] self.dry_run = self.extra_options["dry_run"] self.diag_zpro_run = self.extra_options["diag_zpro_run"] self._do_histograms = self.process_config.nabu_config["postproc"]["output_histogram"] if self.diag_zpro_run: self.process_config.processing_options.get("phase", None) self._do_histograms = False self.reading_granularity = 10 self._histogram_merged = False self._span_info = None def _get_reconstruction_range(self): rec_cfg = self.process_config.nabu_config["reconstruction"] self.z_min = rec_cfg["start_z"] self.z_max = rec_cfg["end_z"] + 1 z_fract_min = rec_cfg["start_z_fract"] z_fract_max = rec_cfg["end_z_fract"] z_min_mm = rec_cfg["start_z_mm"] z_max_mm = rec_cfg["end_z_mm"] if z_min_mm != 0.0 or z_max_mm != 0.0: z_min_mm += self.z_offset_mm z_max_mm += self.z_offset_mm d_v, d_h = self.process_config.dataset_info.radio_dims[::-1] z_start, z_end = (self._span_info.get_doable_span()).view_heights_minmax z_end += 1 h_s = np.arange(z_start, z_end) fact_mm = self.process_config.dataset_info.pixel_size * 1.0e-3 z_mm_s = fact_mm * (-self._span_info.z_pix_per_proj[0] + (d_v - 1) / 2 - h_s) self.z_min = 0 self.z_max = len(z_mm_s) if z_mm_s[-1] > z_mm_s[0]: for i in range(len(z_mm_s) - 1): if (z_min_mm - z_mm_s[i]) * (z_min_mm - z_mm_s[i + 1]) <= 0: self.z_min = i break for i in range(len(z_mm_s) - 1): if (z_max_mm - z_mm_s[i]) * (z_max_mm - z_mm_s[i]) <= 0: self.z_max = i + 1 break else: for i in range(len(z_mm_s) - 1): if (z_max_mm - z_mm_s[i]) * (z_max_mm - z_mm_s[i + 1]) <= 0: self.z_max = len(z_mm_s) - 2 - i break for i in range(len(z_mm_s) - 1): if (z_min_mm - z_mm_s[i]) * (z_min_mm - z_mm_s[i + 1]) <= 0: self.z_min = len(z_mm_s) - 1 - i break elif z_fract_min != 0.0 or z_fract_max != 0.0: z_start, z_max = (self._span_info.get_doable_span()).view_heights_minmax # the meaming of z_min and z_max is: position in slices units from the # first available slice and in the direction of the scan self.z_min = int(round(z_start * (0 - 
z_fract_min) + z_max * z_fract_min)) self.z_max = int(round(z_start * (0 - z_fract_max) + z_max * z_fract_max)) + 1 def _compute_translations_margin(self): return 0, 0 def _compute_cone_overlap(self): return 0, 0 _get_resources = FullFieldReconstructor._get_resources _get_memory = FullFieldReconstructor._get_memory _get_gpu = FullFieldReconstructor._get_gpu _compute_phase_margin = FullFieldReconstructor._compute_phase_margin _compute_margin = FullFieldReconstructor._compute_margin _compute_unsharp_margin = FullFieldReconstructor._compute_unsharp_margin _print_tasks = FullFieldReconstructor._print_tasks _instantiate_pipeline_if_necessary = FullFieldReconstructor._instantiate_pipeline_if_necessary _destroy_pipeline = FullFieldReconstructor._destroy_pipeline _give_progress_info = FullFieldReconstructor._give_progress_info get_relative_files = FullFieldReconstructor.get_relative_files merge_histograms = FullFieldReconstructor.merge_histograms merge_data_dumps = FullFieldReconstructor.merge_data_dumps _get_chunk_length = FullFieldReconstructor._get_chunk_length # redefined here, and with self, otherwise @static and inheritance gives "takes 1 positional argument but 2 were given" # when called from inside the inherite class def _get_delta_z(self, task): return task["sub_region"][1] - task["sub_region"][0] def _get_task_key(self): """ Get the 'key' (number) associated to the current task/pipeline """ return self.pipeline.sub_region[-2:] # Gpu required memory size does not depend on the number of slices def _compute_max_chunk_size(self): cpu_mem = self.resources["mem_avail_GB"] * self.cpu_mem_fraction user_max_chunk_size = self.extra_options["max_chunk_size"] if self.diag_zpro_run: if user_max_chunk_size is not None: user_max_chunk_size = min( user_max_chunk_size, max(self.process_config.dataset_info.radio_dims[1] // 4, 10) ) else: user_max_chunk_size = max(self.process_config.dataset_info.radio_dims[1] // 4, 10) self.cpu_max_chunk_size = self.estimate_chunk_size( cpu_mem, self.process_config, chunk_step=1, user_max_chunk_size=user_max_chunk_size ) if user_max_chunk_size is not None: self.cpu_max_chunk_size = min(self.cpu_max_chunk_size, user_max_chunk_size) self.user_slices_at_once = self.cpu_max_chunk_size # cannot use the estimate_chunk_size from computations.py beacause it has the estimate_required_memory hard-coded def estimate_chunk_size(self, available_memory_GB, process_config, chunk_step=1, user_max_chunk_size=None): """ Estimate the maximum chunk size given an avaiable amount of memory. Parameters ----------- available_memory_GB: float available memory in Giga Bytes (GB - not GiB !). 
process_config: ProcessConfig ProcessConfig object """ chunk_size = chunk_step radios_and_sinos = False if ( "reconstruction" in process_config.processing_steps and process_config.processing_options["reconstruction"]["enable_halftomo"] ): radios_and_sinos = True max_dz = process_config.dataset_info.radio_dims[1] chunk_size = chunk_step last_good_chunk_size = chunk_size while True: required_mem = self._pipeline_cls.estimate_required_memory( process_config, chunk_size=chunk_size, reading_granularity=self.reading_granularity, margin_v=self._margin_v, span_info=self._span_info, diag_zpro_run=self.diag_zpro_run, ) required_mem_GB = required_mem / 1e9 if required_mem_GB > available_memory_GB: break last_good_chunk_size = chunk_size if user_max_chunk_size is not None and chunk_size > user_max_chunk_size: last_good_chunk_size = user_max_chunk_size break chunk_size += chunk_step return last_good_chunk_size # different because of dry_run def _build_tasks(self): if self.dry_run: self.tasks = [] else: self._compute_volume_chunks() # this is very different def _compute_volume_chunks(self): margin_v = self._margin_v # self._margin_far_up = min(margin_v, self.z_min) # self._margin_far_down = min(margin_v, n_z - (self.z_max + 1)) ## It will be the reading process which pads self._margin_far_up = margin_v self._margin_far_down = margin_v # | margin_up | n_slices | margin_down | # |-----------|-----------------|--------------| # |----------------------------------------------------| # delta_z n_slices = self.user_slices_at_once z_start, z_end = (self._span_info.get_doable_span()).view_heights_minmax z_end += 1 if (self.z_min, self.z_max) == (0, 0): self.z_min, self.z_max = z_start, z_end my_z_min = z_start my_z_end = z_end else: if self.z_max <= self.z_min: message = f"""" The input file provide start_z end_z {self.z_min,self.z_max} but it is necessary that start_z < end_z """ raise ValueError(message) if self._span_info.z_pix_per_proj[-1] >= self._span_info.z_pix_per_proj[0]: my_z_min = z_start + self.z_min my_z_end = z_start + self.z_max else: my_z_min = z_end - self.z_max my_z_end = z_end - self.z_min my_z_min = max(z_start, my_z_min) my_z_end = min(z_end, my_z_end) print("my_z_min my_z_end ", my_z_min, my_z_end) if my_z_min >= my_z_end: message = f""" The requested vertical span, after translation to absolute doable heights would be {my_z_min, my_z_end} is not doable (start>=end). Scans are often shorter than expected ThereFore : CONSIDER TO INCREASE angular_tolerance_steps """ raise ValueError(message) # if my_z_min != self.z_min or my_z_end != self.z_max: # message = f""" The requested vertical span given by self.z_min, self.z_max+1 ={self.z_min, self.z_max} # is not withing the doable span which is {z_start, z_end} # """ # raise ValueError(message) tasks = [] n_stages = ceil((my_z_end - my_z_min) / n_slices) curr_z_min = my_z_min curr_z_max = my_z_min + n_slices for i in range(n_stages): if curr_z_max >= my_z_end: curr_z_max = my_z_end margin_down = margin_v margin_up = margin_v tasks.append( { "sub_region": (curr_z_min - margin_up, curr_z_max + margin_down), "phase_margin": ((margin_up, margin_down), (0, 0)), } ) if curr_z_max == my_z_end: # No need for further tasks break curr_z_min += n_slices curr_z_max += n_slices ## ## ########################################################################################### self.tasks = tasks self.n_slices = n_slices self._print_tasks() def _setup_span_info(self): """We create here an instance of SpanStrategy class for helical scans. 
This class do all the accounting for the doable slices, giving for each the useful angle, the shifts .. """ # projections_subsampling = self.process_config.dataset_info.projections_subsampling projections_subsampling = self.process_config.nabu_config["dataset"]["projections_subsampling"] radio_shape = self.process_config.dataset_info.radio_dims[::-1] dz_per_proj = self.process_config.nabu_config["reconstruction"]["dz_per_proj"] z_per_proj = self.process_config.dataset_info.z_per_proj dx_per_proj = self.process_config.nabu_config["reconstruction"]["dx_per_proj"] x_per_proj = self.process_config.dataset_info.x_per_proj tot_num_images = len(self.process_config.processing_options["read_chunk"]["files"]) // projections_subsampling if z_per_proj is not None: z_per_proj = np.array(z_per_proj) self.logger.info(" z_per_proj has been explicitely provided") if len(z_per_proj) != tot_num_images: message = f""" The provided array z_per_proj, which has length {len(z_per_proj)} must match in lenght the number of radios which is {tot_num_images} """ raise ValueError(message) else: z_per_proj = self.process_config.dataset_info.z_translation if dz_per_proj is not None: self.logger.info("correcting vertical displacement by provided screw rate dz_per_proj") z_per_proj += np.arange(tot_num_images) * dz_per_proj if x_per_proj is not None: x_per_proj = np.array(x_per_proj) self.logger.info(" x_per_proj has been explicitely provided") if len(x_per_proj) != tot_num_images: message = f""" The provided array x_per_proj, which has length {len(x_per_proj)} must match in lenght the number of radios which is {tot_num_images} """ raise ValueError(message) else: x_per_proj = self.process_config.dataset_info.x_translation if dx_per_proj is not None: self.logger.info("correcting vertical displacement by provided screw rate dx_per_proj") x_per_proj += np.arange(tot_num_images) * dx_per_proj x_per_proj = x_per_proj - x_per_proj[0] self.z_offset_mm = z_per_proj[0] * self.process_config.dataset_info.pixel_size * 1.0e-3 # micron to mm z_per_proj = z_per_proj - z_per_proj[0] binning = self.process_config.nabu_config["dataset"]["binning"] if binning is not None: if np.isscalar(binning): binning = (binning, binning) binning_x, binning_z = binning x_per_proj = x_per_proj / binning_x z_per_proj = z_per_proj / binning_z x_per_proj = projections_subsampling * x_per_proj z_per_proj = projections_subsampling * z_per_proj angles_rad = self.process_config.processing_options["reconstruction"]["angles"] angles_rad = np.unwrap(angles_rad) angles_deg = np.rad2deg(angles_rad) redundancy_angle_deg = self.process_config.nabu_config["reconstruction"]["redundancy_angle_deg"] do_helical_half_tomo = self.process_config.nabu_config["reconstruction"]["helical_halftomo"] self.logger.info("Creating SpanStrategy object for helical ") t0 = time() self._span_info = SpanStrategy( z_offset_mm=self.z_offset_mm, z_pix_per_proj=z_per_proj, x_pix_per_proj=x_per_proj, detector_shape_vh=radio_shape, phase_margin_pix=self._margin_v, projection_angles_deg=angles_deg, pixel_size_mm=self.process_config.dataset_info.pixel_size * 1.0e-3, # micron to mm require_redundancy=(redundancy_angle_deg > 0), angular_tolerance_steps=self.process_config.nabu_config["reconstruction"]["angular_tolerance_steps"], ) duration = time() - t0 self.logger.info(f"Creating SpanStrategy object for helical in {duration} seconds") if self.dry_run: info_string = self._span_info.get_informative_string() print(" Informations about the doable vertical span") print(info_string) return def 
_instantiate_pipeline(self, task): self.logger.debug("Creating a new pipeline object") args = [self.process_config, task["sub_region"]] dz = self._get_delta_z(task) pipeline = self._pipeline_cls( *args, logger=self.logger, phase_margin=task["phase_margin"], reading_granularity=self.reading_granularity, span_info=self._span_info, diag_zpro_run=self.diag_zpro_run, # cuda_options=self.cuda_options ) self.pipeline = pipeline # kept to save diagnostic def _process_task(self, task): self.pipeline.process_chunk(sub_region=task["sub_region"]) key = len(list(self.diagnostic_per_chunk.keys())) self.diagnostic_per_chunk[key] = copy.deepcopy(self.pipeline.diagnostic) # kept for diagnostic and dry run def reconstruct(self): self._print_tasks() self.diagnostic_per_chunk = {} tasks = self.tasks self.results = {} self._histograms = {} self._data_dumps = {} prev_task = None for task in tasks: if prev_task is None: prev_task = task self._give_progress_info(task) self._instantiate_pipeline_if_necessary(task, prev_task) if self.dry_run: info_string = self._span_info.get_informative_string() print(" SPAN_INFO informations ") print(info_string) return self._process_task(task) if self.pipeline.writer is not None: task_key = self._get_task_key() self.results[task_key] = self.pipeline.writer.fname if self.pipeline.histogram_writer is not None: # self._do_histograms self._histograms[task_key] = self.pipeline.histogram_writer.fname if len(self.pipeline._data_dump) > 0: self._data_dumps[task_key] = {} for step_name, writer in self.pipeline._data_dump.items(): self._data_dumps[task_key][step_name] = writer.fname prev_task = task ## kept in order to speed it up by omitting the super slow python writing ## of thousand of unused urls def merge_hdf5_reconstructions( self, output_file=None, prefix=None, files=None, process_name=None, axis=0, merge_histograms=True, output_dir=None, ): """ Merge existing hdf5 files by creating a HDF5 virtual dataset. Parameters ---------- output_file: str, optional Output file name. If not given, the file prefix in section "output" of nabu config will be taken. """ out_cfg = self.process_config.nabu_config["output"] out_dir = output_dir or out_cfg["location"] prefix = prefix or "" # Prevent issue when out_dir is empty, which happens only if dataset/location is a relative path. # TODO this should be prevented earlier if out_dir is None or len(out_dir.strip()) == 0: out_dir = dirname(dirname(self.results[list(self.results.keys())[0]])) # if output_file is None: output_file = join(out_dir, prefix + out_cfg["file_prefix"]) + ".hdf5" if isfile(output_file): msg = str("File %s already exists" % output_file) if out_cfg["overwrite_results"]: msg += ". Overwriting as requested in configuration file" self.logger.warning(msg) else: msg += ". Set overwrite_results to True in [output] to overwrite existing files." 
self.logger.fatal(msg) raise ValueError(msg) local_files = files if local_files is None: local_files = self.get_relative_files() if local_files == []: self.logger.error("No files to merge") return entry = getattr(self.process_config.dataset_info.dataset_scanner, "entry", "entry") process_name = process_name or self._process_name h5_path = join(entry, *[process_name, "results", "data"]) # self.logger.info("Merging %ss to %s" % (process_name, output_file)) print("omitting config in call to merge_hdf5_files because export2dict too slow") merge_hdf5_files( local_files, h5_path, output_file, process_name, output_entry=entry, output_filemode="a", processing_index=0, config={ self._process_name + "_stages": {str(k): v for k, v in zip(self.results.keys(), local_files)}, "diagnostics": self.diagnostic_per_chunk, }, # config={ # self._process_name + "_stages": {str(k): v for k, v in zip(self.results.keys(), local_files)}, # "nabu_config": self.process_config.nabu_config, # "processing_options": self.process_config.processing_options, # }, base_dir=out_dir, axis=axis, overwrite=out_cfg["overwrite_results"], ) if merge_histograms: self.merge_histograms(output_file=output_file) return output_file merge_hdf5_files = merge_hdf5_reconstructions ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/pipeline/helical/helical_utils.py0000644000175000017500000000263714402565210021622 0ustar00pierrepierrefrom ...resources.logger import LoggerOrPrint import numpy as np logger = LoggerOrPrint(None) def find_mirror_indexes(angles_deg, tolerance_factor=1.0): """Return an array of indexes whose i-th element is the index of the angles_deg element whose value is closest to angles_deg[i] + 180. It is used for padding in halftomo. Parameters: ----------- angles_deg: an ndarray of floats tolerance_factor: float if the mirror positions are not within a distance less than the tolerance (average angular step times tolerance_factor) from the ideal position, a warning is raised """ av_step = abs(np.diff(angles_deg).mean()) tolerance = av_step * tolerance_factor tmp_mirror_angles_deg = angles_deg + 180 mirror_angle_relative_indexes = (abs(abs(np.mod(tmp_mirror_angles_deg[:, None] - angles_deg, 360) - 180))).argmax( axis=-1 ) mirror_values = angles_deg[mirror_angle_relative_indexes] differences = abs(np.mod(mirror_values - angles_deg, 360) - 180) if differences.max() > tolerance: logger.warning( f"""In function find_mirror_indexes the mirror positions are farther than the tolerance from the ideal positions. The tolerance is {tolerance}, given by average step {av_step} and tolerance_factor {tolerance_factor}, and the maximum error is {differences.max()} """ ) return mirror_angle_relative_indexes ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1708073564.0 nabu-2024.2.1/nabu/pipeline/helical/nabu_config.py0000644000175000017500000001755214563621134021263 0ustar00pierrepierrefrom ..fullfield.nabu_config import * import copy ## keep the text below for future inclusion in the documentation # start_z, end_z , start_z_mm, end_z_mm # ---------------------------------------------------------------- # By default, all the volume is reconstructed slice by slice, along the axis 'z'.
# ** option 1) you can set # start_z_mm=0 # end_z_mm = 0 # # Now, concerning start_z and end_z, Use positive integers, with start_z < end_z # The reconstructed vertical region will be # slice start = first doable + start_z # slice end = first doable + end_z # or less if such range needs to be clipped to the doable one # # As an example start_z= 10, end_z = 20 # for reconstructing 10 slices close to scan start. # NOTE: we are proceeding in the direction of the scan so that, in millimiters, # the start may be above or below the end # To reconstruct the whole doable volume set # # start_z= 0 # end_z =-1 # # ** option 2) using start_z_mm, end_z_mm # Use positive floats, in millimiters. They indicate the height above the sample stage # The values of start_z and end_z are not used in this case help_start_end_z = """ If start_z_mm , end_z_mm are seto to zero, then start_z and end_z will be effective unless end_z_fract is different from zero. In this latter case the vertical range will be given in terms o the fractional position between the first doable and last doable slices. Otherwhise, if start_z_mm and end_z_mm are not zero, the slices whose height above the sample stage, in millimiters, between start_z_mm and end_z_mm are reconstructed """ # we need to deepcopy this in order not to mess the original nabu_config of the full-field pipeline nabu_config = copy.deepcopy(nabu_config) nabu_config["preproc"]["processes_file"] = { "default": "", "help": "Path tgo the file where some operations should be stored for later use. By default it is 'xxx_nabu_processes.h5'", "validator": optional_file_location_validator, "type": "required", } nabu_config["preproc"]["double_flatfield_enabled"]["default"] = 1 nabu_config["reconstruction"].update( { "dz_per_proj": { "default": 0, "help": " A positive DZPERPROJ means that the rotation axis is going up. Alternatively the vertical translations, can be given through an array using the variable z_per_proj_file", "validator": float_validator, "type": "optional", }, "z_per_proj_file": { "default": "", "help": "Alternative to dz_per_proj. A file where each line has one value: vertical displacements of the axis. There should be as many values as there are projection images.", "validator": optional_file_location_validator, "type": "optional", }, "dx_per_proj": { "default": 0, "help": " A positive value means that the rotation axis is going on the rigth. Alternatively the horizontal translations, can be given through an array using the variable x_per_proj_file", "validator": float_validator, "type": "optional", }, "x_per_proj_file": { "default": "", "help": "Alternative to dx_per_proj. A file where each line has one value: horizontal displacements of the axis. There should be as many values as there are projection images.", "validator": optional_file_location_validator, "type": "optional", }, "axis_to_the_center": { "default": "1", "help": "Whether to shift start_x and start_y so to have the axis at the center", "validator": boolean_validator, "type": "optional", }, "auto_size": { "default": "1", "help": "Wether to set automatically start_x end_x start_y end_y ", "validator": boolean_validator, "type": "optional", }, "use_hbp": { "default": "0", "help": "Wether to use hbp routine instead of the backprojector from fbp ", "validator": boolean_validator, "type": "optional", }, "fan_source_distance_meters": { "default": 1.0e9, "help": "For HBP, for the description of the fan geometry, the source to axis distance. 
Defaults to a large value which implies parallel geometry", "validator": float_validator, "type": "optional", }, "start_z_mm": { "default": "0", "help": help_start_end_z, "validator": float_validator, "type": "optional", }, "end_z_mm": { "default": "0", "help": " To determine the reconstructed vertical range: the height in millimiters above the stage below which slices are reconstructed ", "validator": float_validator, "type": "optional", }, "start_z_fract": { "default": "0", "help": help_start_end_z, "validator": float_validator, "type": "optional", }, "end_z_fract": { "default": "0", "help": " To determine the reconstructed vertical range: the height in fractional position between first doable slice and last doable slice above the stage below which slices are reconstructed ", "validator": float_validator, "type": "optional", }, "start_z": { "default": "0", "help": "the first slice of the reconstructed range. Numbered going in the direction of the scan and starting with number zero for the first doable slice", "validator": slice_num_validator, "type": "optional", }, "end_z": { "default": "-1", "help": "the " "end" " slice of the reconstructed range. Numbered going in the direction of the scan and starting with number zero for the first doable slice", "validator": slice_num_validator, "type": "optional", }, } ) nabu_config["pipeline"].update( { "skip_after_flatfield_dump": { "default": "0", "help": "When the writing of the flatfielded data is activated, if this option is set, then the phase and reconstruction steps are skipped", "validator": boolean_validator, "type": "optional", }, } ) nabu_config["reconstruction"].update( { "angular_tolerance_steps": { "default": "3.0", "help": "the angular tolerance, an angular width expressed in units of an angular step, which is tolerated in the criteria for deciding if a slice is reconstructable or not", "validator": float_validator, "type": "advanced", }, "redundancy_angle_deg": { "default": "0", "help": "Can be 0,180 or 360. If there are dead detector regions (notably scintillator junction (stripes) which need to be complemented at +-360 for local tomo or +- 180 for conventional tomo. This may have an impact on the doable vertical span (you can check it with the --dry-run 1 option)", "validator": float_validator, "type": "advanced", }, "enable_halftomo": { "default": "0", "help": "nabu-helical applies the same treatment for half-tomo as for full-tomo. 
Always let this key to zero", "validator": boolean_validator, "type": "advanced", }, "helical_halftomo": { "default": "1", "help": "Wether to consider doable slices those which are contributed by an angular span greater or equal to 360, instead of just 180 or more", "validator": boolean_validator, "type": "advanced", }, } ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1699603354.0 nabu-2024.2.1/nabu/pipeline/helical/processconfig.py0000644000175000017500000000532714523361632021652 0ustar00pierrepierrefrom .nabu_config import nabu_config, renamed_keys from .dataset_validator import HelicalDatasetValidator from ..fullfield import processconfig as ff_processconfig from ...resources import dataset_analyzer class ProcessConfig(ff_processconfig.ProcessConfig): default_nabu_config = nabu_config config_renamed_keys = renamed_keys _use_horizontal_translations = False def _configure_save_steps(self): self._dump_sinogram = False steps_to_save = self.nabu_config["pipeline"]["save_steps"] if steps_to_save in (None, ""): self.steps_to_save = [] return steps_to_save = [s.strip() for s in steps_to_save.split(",")] for step in self.processing_steps: step = step.strip() if step in steps_to_save: self.processing_options[step]["save"] = True self.processing_options[step]["save_steps_file"] = self.get_save_steps_file(step_name=step) # "sinogram" is a special keyword, not explicitly in the processing steps if "sinogram" in steps_to_save: self._dump_sinogram = True self._dump_sinogram_file = self.get_save_steps_file(step_name="sinogram") self.steps_to_save = steps_to_save def _update_dataset_info_with_user_config(self): super()._update_dataset_info_with_user_config() self._get_translation_file("reconstruction", "z_per_proj_file", "z_per_proj", last_dim=1) self._get_translation_file("reconstruction", "x_per_proj_file", "x_per_proj", last_dim=1) def _get_user_sino_normalization(self): """is called by the base class but it is not used in helical""" pass def _coupled_validation(self): self.logger.debug("Doing coupled validation") self._dataset_validator = HelicalDatasetValidator(self.nabu_config, self.dataset_info) for what in ["rec_params", "rec_region", "binning", "subsampling_factor"]: setattr(self, what, getattr(self._dataset_validator, what)) print(what, self._dataset_validator) def _browse_dataset(self, dataset_info): """ Browse a dataset and builds a data structure with the relevant information. 
""" self.logger.debug("Browsing dataset") if dataset_info is not None: self.dataset_info = dataset_info else: self.dataset_info = dataset_analyzer.analyze_dataset( self.nabu_config["dataset"]["location"], extra_options={ "exclude_projections": self.nabu_config["dataset"]["exclude_projections"], "hdf5_entry": self.nabu_config["dataset"]["hdf5_entry"], }, logger=self.logger, ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/pipeline/helical/span_strategy.py0000644000175000017500000006072414550227307021673 0ustar00pierrepierreimport math import numpy as np from ...resources.logger import LoggerOrPrint from ...utils import DictToObj class SpanStrategy: def __init__( self, z_pix_per_proj, x_pix_per_proj, detector_shape_vh, phase_margin_pix, projection_angles_deg, require_redundancy=False, pixel_size_mm=0.1, z_offset_mm=0.0, logger=None, angular_tolerance_steps=0.0, ): """ This class does all the accounting for the reconstructible slices, giving for each one the list of the useful angles, of the vertical and horizontal shifts,and more ... Parameters ---------- z_pix_per_proj : array of floats an array of floats with one entry per projection, in pixel units. The values are the vertical displacements of the detector. An decreasing z means that the rotation axis is following the positive direction of the detector vertical axis, which is pointing toward the ground. In the experimental setup, the vertical detector axis is pointing toward the ground. Moreover the values are offsetted so that the first value is zero. The offset value, in millimiters is z_offset_mm and it is the vertical position of the sample stage relatively to the center of the detector. A negative z_offset_mm means that the sample stage is below the detector for the first projection, and this is almost always the case, because the sample is above the sample stage. A z_pix=0 value indicates that the translation-rotation stage "ground" is exactly at the beam height ( supposed to be near the central line of the detector) plus z_offset_mm. A positive z_pix means that the translation stage has been lowered, compared to the first projection, in order to scan higher parts of the sample. ( the sample is above the translation-rotation stage). x_pix_per_proj : array of floats one entry per projection. The horizontal displacement of the detector respect to the rotation axis. A positive x means that the sample shadow on the detector is moving toward the left of the detector. (the detector is going right) detector_shape_vh : a tuple of two ints the vertical and horizontal dimensions phase_margin_pix : int the maximum margin needed for the different kernels (phase, unsharp..) otherwhere in the pipeline projection_angles_deg : array of floats per each projection the rotation angle of the sample in degree. require_redundancy: bool, optional, defaults to False It can be set to True, when there are dead zones in the detector. In this case the minimal required angular span is increased from 360 to 2*360 in order to enforce the necessary redundancy, which allows the correction of the dead zones. The lines which do not satisfy this requirement are not doable. z_offset_mm: float the vertical position of the sample stage relatively to the center of the detector. A negative z_offset_mm means that the sample stage is below the detector for the first projection, and this is almost always the case, because the sample is above the sample stage. 
pixel_size_mm: float, the pixel size in millimeters. This value is used to give results in units of " millimeters above the sample stage". Although internally everything is calculated in pixel units, it is useful to incorporate such information in the spanstrategy object which will then be able to set up reconstruction information according to several request formats: be they in sequential number of reconstructible slices, or millimeters above the stage. angular_tolerance_steps: float, defaults to zero the angular tolerance, an angular width expressed in units of an angular step, which is tolerated in the criteria for deciding if a slice is reconstructable or not logger : a logger, optional """ self.logger = LoggerOrPrint(logger) self.require_redundancy = require_redundancy self.z_pix_per_proj = z_pix_per_proj self.x_pix_per_proj = x_pix_per_proj self.detector_shape_vh = detector_shape_vh self.total_num_images = len(self.z_pix_per_proj) self.phase_margin_pix = phase_margin_pix self.pix_size_mm = pixel_size_mm self.z_offset_mm = z_offset_mm self.angular_tolerance_steps = angular_tolerance_steps # internally we use increasing angles, so that in all inequalities, that are used # to check the span, only such case has to be contemplated. # To do so, if needed, we change the sign of the angles. if projection_angles_deg[-1] > projection_angles_deg[0]: self.projection_angles_deg_internally = projection_angles_deg self.angles_are_inverted = False else: self.projection_angles_deg_internally = -projection_angles_deg self.angles_are_inverted = True self.projection_angles_deg = projection_angles_deg if ( len(self.x_pix_per_proj) != self.total_num_images or len(self.projection_angles_deg_internally) != self.total_num_images ): message = f""" all the arguments z_pix_per_proj, x_pix_per_proj and projection_angles_deg must have the same length but their lengths were {len(self.z_pix_per_proj) }, {len(self.x_pix_per_proj) }, {len(self.projection_angles_deg_internally) } respectively """ raise ValueError(message) ## the information to be built is initialised to None here below """ For a given slice, the procedure for obtaining the useful angles is based on the "sunshine" image. The sunshine image has two dimensions, the second one is the projection number while the first runs over the heights of the slices. All the logic is based on this image: when a given pixel of this image is zero, this corresponds to a pair (height,projection) for which there is no contribution of that projection to that slice. """ self.sunshine_image = None """ This will be an array of integer heights. We use a one-to-one correspondence between these integers and slices in the reconstructed volume. The value of a given item is the vertical coordinate of a horizontal line in the detector (or above or below (negative values) ). More precisely the line over which the corresponding slice gets projected at the beginning of the scan ( in other words for the first vertical translation entry of the z_pix_per_proj argument) Illustratively, a view height equal to zero corresponds to a slice which projects on row 0 when translation is given by the first value in self.z_pix_per_proj. Negative integers for the heights are possible too, according to the direction of translations, which may bring it above or below the detector. ( Illustratively the values in z_pix_per_proj are always positive for the simple fact that the sample is above the roto-translation stage and the stage must be lowered in height for the beam to hit the sample. 
A scan always starts from a positive z_pix translation but the z_pix value may either decrease or increase. In fact, for practical reasons, after having done a scan in one direction it is convenient to scan also while coming back after a previous scan) """ self.total_view_heights = None """ A list which will contain, for every reconstructable height listed in self.total_view_heights, and in the same order, a pair of two integers: the first is the sequential number of the first projection for which the height is projected inside the detector, the second integer is the last projection number for which the projection occurs inside the detector. """ self.on_detector_projection_intervals = None """ Same as self.on_detector_projection_intervals, but considering also the margins (phase, unsharp, etc.), so that the height is projected inside the detector while keeping a safe phase_margin_pix distance from the upper and lower border of the detector. """ self.within_margin_projection_intervals = None """ This will contain, for projection number i_pro, the integer value given by ceil(z_pix_per_proj[i_pro]). This array, together with the array self.fract_complement_to_integer_shift_v here below, will be used for cropping when the data are collected for a given chunk to be reconstructed. """ self.integer_shift_v = np.zeros([self.total_num_images], "i") """ The fractional vertical shifts are positive floats < 1.0 pixels. This is the fractional part which, added to self.integer_shift_v, gives back z_pix_per_proj. Together with integer_shift_v, this array is meant to be used, by other modules, for cropping when the data are collected for a given chunk to be reconstructed.""" self.fract_complement_to_integer_shift_v = np.zeros([self.total_num_images], "f") self._setup_ephemerides() self._setup_sunshine() def get_doable_span(self): """return an object with two properties: view_heights_minmax: containing minimum and maximum doable height ( detector reference at iproj=0) z_pix_minmax : containing minimum and maximum heights above the roto-translational sample stage """ vertical_profile = self.sunshine_image.sum(axis=1) doable_indexes = np.arange(len(vertical_profile))[vertical_profile > 0] vertical_span = doable_indexes.min(), doable_indexes.max() if not (vertical_profile[vertical_span[0] : vertical_span[1] + 1] > 0).all(): message = """ Something wrong occurred in the span preprocessing. It appears that some intermediate slices are not doable. The doable span should instead be contiguous. Please signal the problem""" raise RuntimeError(message) view_heights_minmax = self.total_view_heights[list(vertical_span)] hmin, hmax = view_heights_minmax d_v, d_h = self.detector_shape_vh z_min, z_max = (-self.z_pix_per_proj[0] + (d_v - 1) / 2 - hmax, -self.z_pix_per_proj[0] + (d_v - 1) / 2 - hmin) res = { "view_heights_minmax": view_heights_minmax, "z_pix_minmax": (z_min, z_max), "z_mm_minmax": (z_min * self.pix_size_mm, z_max * self.pix_size_mm), } return DictToObj(res) def get_informative_string(self): doable_span_v = self.get_doable_span() if self.z_pix_per_proj[-1] >= self.z_pix_per_proj[0]: direction = "ascending" else: direction = "descending" s = f""" Doable vertical span -------------------- The scan has been performed with an {direction} vertical translation of the rotation axis. The detector vertical axis is upside down. Detector reference system at iproj=0: from vertical view height ... {doable_span_v.view_heights_minmax[0]} up to (included) ... 
{doable_span_v.view_heights_minmax[1]} The slice that projects to the first line of the first projection corresponds to vertical height = 0 In voxels, the vertical doable span measures: {doable_span_v.z_pix_minmax[1] - doable_span_v.z_pix_minmax[0]} And in millimeters above the stage: from vertical height above stage ( mm units) ... {doable_span_v.z_mm_minmax[0] - self.z_offset_mm } up to (included) ... {doable_span_v.z_mm_minmax[1] - self.z_offset_mm } """ return s def get_chunk_info(self, span_v_absolute): """ This method returns an object containing all the information that is needed to reconstruct the corresponding chunk angle_index_span: a pair of integers indicating the start and the end of the useful angles in the array of all the scan angles self.projection_angles_deg span_v: a pair of two integers indicating the start and end of the span relative to the lowest value of array self.total_view_heights integer_shift_v: an array, containing for each one of the useful projections of the span, the integer part of the vertical shift to be used in cropping, fract_complement_to_integer_shift_v : the fractional remainder for cropping. z_pix_per_proj: an array, containing for each used projection of the span the vertical shift x_pix_per_proj: .... the horizontal shift angles_rad : an array containing, for each useful projection of the chunk, the angle in radians Parameters: ----------- span_v_absolute: tuple of integers a pair of two integers: the first view height ( referred to the detector y axis at iproj=0) and the second view height, with the first height smaller than the second. """ span_v = (span_v_absolute[0] - self.total_view_heights[0], span_v_absolute[1] - self.total_view_heights[0]) sunshine_subset = self.sunshine_image[span_v[0] : span_v[1]] angular_profile = sunshine_subset.sum(axis=0) angle_indexes = np.arange(len(self.projection_angles_deg_internally))[angular_profile > 0] angle_index_span = angle_indexes.min(), angle_indexes.max() + 1 if not (np.less(0, angular_profile[angle_index_span[0] : angle_index_span[1]]).all()): message = """ Something wrong occurred in the span preprocessing. It appears that some intermediate slices are not doable. The doable span should instead be contiguous. 
Please signal the problem""" raise RuntimeError(message) chunk_angles_deg = self.projection_angles_deg[angle_indexes] my_slicer = slice(angle_index_span[0], angle_index_span[1]) values = ( angle_index_span, span_v_absolute, self.integer_shift_v[my_slicer], self.fract_complement_to_integer_shift_v[my_slicer], self.z_pix_per_proj[my_slicer], self.x_pix_per_proj[my_slicer], np.deg2rad(chunk_angles_deg) * (1 - 2 * int(self.angles_are_inverted)), ) key_names = ( "angle_index_span", "span_v", "integer_shift_v", "fract_complement_to_integer_shift_v", "z_pix_per_proj", "x_pix_per_proj", "angles_rad", ) return DictToObj(dict(zip(key_names, values))) def _setup_ephemerides(self): """ A function which will set : * self.integer_shift_v * self.fract_complement_to_integer_shift_v * self.total_view_heights * self.on_detector_projection_intervals * self.within_margin_projection_intervals """ for i_pro in range(self.total_num_images): trans_v = self.z_pix_per_proj[i_pro] self.integer_shift_v[i_pro] = math.ceil(trans_v) self.fract_complement_to_integer_shift_v[i_pro] = math.ceil(trans_v) - trans_v ## The two following lines initialize the view heights; then, considering the vertical translation, # the field of view will be expanded total_view_top = self.detector_shape_vh[0] total_view_bottom = 0 total_view_top = max(total_view_top, int(math.ceil(total_view_top + self.z_pix_per_proj.max()))) total_view_bottom = min(total_view_bottom, int(math.floor(total_view_bottom + self.z_pix_per_proj.min()))) self.total_view_heights = np.arange(total_view_bottom, total_view_top + 1) ## where possible, only data from within the safe phase margin will be considered (within_margin), ## if it is enough for 360 or more degrees. ## If it is not enough we'll complete with data close to the border, provided that data comes from within the detector. ## This will be ## the case for the first and last doable slices. self.within_margin_projection_intervals = np.zeros(self.total_view_heights.shape + (2,), "i") self.on_detector_projection_intervals = np.zeros(self.total_view_heights.shape + (2,), "i") self.on_detector_projection_intervals[:, 1] = 0 # empty intervals self.within_margin_projection_intervals[:, 1] = 0 for i_h, height in enumerate(self.total_view_heights): previous_is_inside_detector = False previous_is_inside_margin = False pos_inside_filtered_v = height - self.integer_shift_v is_inside_detector = np.less_equal(0, pos_inside_filtered_v) is_inside_detector *= np.less( pos_inside_filtered_v, self.detector_shape_vh[0] - np.less(0, self.fract_complement_to_integer_shift_v) ) is_inside_margin = np.less_equal(self.phase_margin_pix, pos_inside_filtered_v) is_inside_margin *= np.less( pos_inside_filtered_v, self.detector_shape_vh[0] - self.phase_margin_pix - np.less(0, self.fract_complement_to_integer_shift_v), ) tmp = np.arange(self.total_num_images)[is_inside_detector] if len(tmp): self.on_detector_projection_intervals[i_h, :] = (tmp.min(), tmp.max()) tmp = np.arange(self.total_num_images)[is_inside_margin] if len(tmp): self.within_margin_projection_intervals[i_h, :] = (tmp.min(), tmp.max()) def _setup_sunshine(self): """It prepares * self.sunshine_image an image which, for every height contained in self.total_view_heights, and in the same order, contains the list of factors, one per projection of the total list. Each factor is a backprojection weight. A non-doable height corresponds to a line full of zeros. A doable height must correspond to a line having one and only one segment of contiguous non-zero elements. 
""" self.sunshine_image = np.zeros(self.total_view_heights.shape + (self.total_num_images,), "f") avg_angular_step_deg = np.diff(self.projection_angles_deg_internally).mean() projection_angles_for_interp = np.concatenate( [ [self.projection_angles_deg_internally[0] - avg_angular_step_deg], self.projection_angles_deg_internally, [self.projection_angles_deg_internally[-1] + avg_angular_step_deg], ] ) data_container_for_interp = np.zeros_like(projection_angles_for_interp) num_angular_periods = math.ceil( (self.projection_angles_deg_internally.max() - self.projection_angles_deg_internally.min()) / 360 ) for i_h, height in enumerate(self.total_view_heights): first_last_on_dect = self.on_detector_projection_intervals[i_h] first_last_within_margin = self.within_margin_projection_intervals[i_h] if first_last_on_dect[1] == 0: # this line never entered in the fov continue angle_on_dect_first_last = self.projection_angles_deg_internally[first_last_on_dect] angle_within_margin_first_last = self.projection_angles_deg_internally[first_last_within_margin] # a mask which is positive for angular positions for which the height i_h gets projected within the detector mask_on_dect = ( np.less_equal(angle_on_dect_first_last[0], self.projection_angles_deg_internally) * np.less_equal(self.projection_angles_deg_internally, angle_on_dect_first_last[1]) ).astype("f") # a mask which is positive for angular positions for which the height i_h gets projected within the margins mask_within_margin = ( np.less_equal(angle_within_margin_first_last[0], self.projection_angles_deg_internally) * np.less_equal(self.projection_angles_deg_internally, angle_within_margin_first_last[1]) ).astype("f") ## create a line which collects contributions from redundant angles detector_collector = np.zeros(self.projection_angles_deg_internally.shape, "f") margin_collector = np.zeros(self.projection_angles_deg_internally.shape, "f") # the following loop tracks, for each projection, the total weight available at the projection angle # The additional weight is coming from redundant angles. # In this sense the sunshine_image implements a first rudimentary reweighting, # which could be in principle used in the full pipeline, althought the regridded pipeline # implements a, better, reweighting o its own. 
for i_shift in range(-num_angular_periods, num_angular_periods + 1): signus = ( 1 if (self.projection_angles_deg_internally[-1] > self.projection_angles_deg_internally[0]) else -1 ) data_container_for_interp[1:-1] = mask_on_dect detector_collector = detector_collector + mask_on_dect * np.interp( (self.projection_angles_deg_internally + i_shift * 360) * signus, signus * projection_angles_for_interp, data_container_for_interp, left=0, right=0, ) data_container_for_interp[1:-1] = mask_within_margin margin_collector = margin_collector + mask_within_margin * np.interp( (self.projection_angles_deg_internally + i_shift * 360) * signus, signus * projection_angles_for_interp, data_container_for_interp, left=0, right=0, ) detector_shined_angles = self.projection_angles_deg_internally[detector_collector > 0.99] margin_shined_angles = self.projection_angles_deg_internally[margin_collector > 0.99] if not len(detector_shined_angles) > 1: continue avg_step_deg = abs(avg_angular_step_deg) if len(margin_shined_angles): angular_span_safe_margin = ( margin_shined_angles.max() - margin_shined_angles.min() + avg_step_deg * (1.01 + self.angular_tolerance_steps) ) else: angular_span_safe_margin = 0 angular_span_bare_border = ( detector_shined_angles.max() - detector_shined_angles.min() + avg_step_deg * (1.01 + self.angular_tolerance_steps) ) if not self.require_redundancy: if angular_span_safe_margin >= 360: self.sunshine_image[i_h] = margin_collector elif angular_span_bare_border >= 360: self.sunshine_image[i_h] = detector_collector else: redundancy_angle_deg = 360 if angular_span_safe_margin >= 360 and angular_span_safe_margin > 2 * ( redundancy_angle_deg + avg_step_deg ): self.sunshine_image[i_h] = margin_collector elif angular_span_bare_border >= 360 and angular_span_bare_border > 2 * ( redundancy_angle_deg + avg_step_deg ): self.sunshine_image[i_h] = detector_collector sunshine_mask = np.less(0.99, self.sunshine_image) self.sunshine_image[np.array([True]) ^ sunshine_mask] = 1.0 self.sunshine_image[:] = 1 / self.sunshine_image self.sunshine_image[np.array([True]) ^ sunshine_mask] = 0.0 shp = self.sunshine_image.shape X, Y = np.meshgrid(np.arange(shp[1]), np.arange(shp[0])) condition = self.sunshine_image > 0 self.sunshine_starts = X.min(axis=1, initial=shp[1], where=condition) self.sunshine_ends = X.max(axis=1, initial=0, where=condition) self.sunshine_ends[self.sunshine_ends > 0] += 1 ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5127568 nabu-2024.2.1/nabu/pipeline/helical/tests/0000755000175000017500000000000014730277752017600 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/pipeline/helical/tests/__init__.py0000644000175000017500000000000014402565210021660 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/pipeline/helical/weight_balancer.py0000644000175000017500000000761514402565210022120 0ustar00pierrepierreimport numpy as np import math class WeightBalancer: def __init__(self, rot_center, angles_rad): """This class contains the method for rebalancing the weight prior to backprojection. The weights of halfomo redundacy data ( the central part) are rebalanced. In a pipeline, the weights rebalanced by the method balance_weight, have to be applied to the ramp-filtered data prior to backprojection. 
As a matter of fact the method balance_weights could be called as a function, but in order to be conformant to Nabu, we create this class and follow the scheme initialisation + application. Parameters ========== rot_center : float the center of rotation in pixel units angles_rad : the angles corresponding to the to be rebalanced projections """ self.rot_center = rot_center self.my_angles_rad = angles_rad def balance_weights(self, radios_weights): """ The parameter radios_weights is a stack having having the same weight as the stack of projections. It is modified in place, correcting the value of overlapping data, so that the sum is always one """ radios_weights[:] = self._rebalance(radios_weights) def _rebalance(self, radios_weights): """rebalance the weights, within groups of equivalent (up to multiple of 180), data pixels""" balanced = np.zeros_like(radios_weights) n_span = int(math.ceil(self.my_angles_rad[-1] - self.my_angles_rad[0]) / np.pi) center = (radios_weights.shape[-1] - 1) / 2 nloop = balanced.shape[0] for i in range(nloop): w_res = balanced[i] angle = self.my_angles_rad[i] for i_half_turn in range(-n_span - 1, n_span + 2): if i_half_turn == 0: w_res[:] += radios_weights[i] continue shifted_angle = angle + i_half_turn * np.pi insertion_index = np.searchsorted(self.my_angles_rad, shifted_angle) if insertion_index in [0, self.my_angles_rad.shape[0]]: if insertion_index == 0: if abs(self.my_angles_rad[0] - shifted_angle) > np.pi / 100: continue myimage = radios_weights[0] else: if abs(self.my_angles_rad[-1] - shifted_angle) > np.pi / 100: continue myimage = radios_weights[-1] else: partition = shifted_angle - self.my_angles_rad[insertion_index - 1] myimage = (1.0 - partition) * radios_weights[insertion_index - 1] + partition * radios_weights[ insertion_index ] if i_half_turn % 2 == 0: w_res[:] += myimage else: myimage = np.fliplr(myimage) w_res[:] += shift(myimage, (2 * self.rot_center - 2 * center)) mask = np.equal(0, radios_weights) balanced[:] = radios_weights / balanced balanced[mask] = 0 return balanced def shift(arr, shift, fill_value=0.0): """trivial horizontal shift. 
Contrarily to scipy.ndimage.interpolation.shift, this shift does not cut the tails abruptly, but by interpolation """ result = np.zeros_like(arr) num1 = int(math.floor(shift)) num2 = num1 + 1 partition = shift - num1 for num, factor in zip([num1, num2], [(1 - partition), partition]): if num > 0: result[:, :num] += fill_value * factor result[:, num:] += arr[:, :-num] * factor elif num < 0: result[:, num:] += fill_value * factor result[:, :num] += arr[:, -num:] * factor else: result[:] += arr * factor return result ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1732264041.0 nabu-2024.2.1/nabu/pipeline/params.py0000644000175000017500000000724214720040151016653 0ustar00pierrepierreflatfield_modes = { "true": True, "1": True, "false": False, "0": False, "forced": "force-load", "force-load": "force-load", "force-compute": "force-compute", } phase_retrieval_methods = { "": None, "none": None, "paganin": "paganin", "tie": "paganin", "ctf": "CTF", } unsharp_methods = { "gaussian": "gaussian", "log": "log", "laplacian": "log", "imagej": "imagej", "none": None, "": None, } padding_modes = { "edges": "edge", "edge": "edge", "mirror": "mirror", "zeros": "zeros", "zero": "zeros", } reconstruction_methods = { "fbp": "FBP", "cone": "cone", "conic": "cone", "none": None, "": None, "mlem": "mlem", "fluo": "mlem", "em": "mlem", "hbp": "HBP", "ghbp": "HBP", } fbp_filters = { "ramlak": "ramlak", "ram-lak": "ramlak", "none": None, "": None, "shepp-logan": "shepp-logan", "cosine": "cosine", "hamming": "hamming", "hann": "hann", "tukey": "tukey", "lanczos": "lanczos", "hilbert": "hilbert", } iterative_methods = { "tv": "TV", "wavelets": "wavelets", "l2": "L2", "ls": "L2", "sirt": "SIRT", "em": "EM", } optim_algorithms = { "chambolle": "chambolle-pock", "chambollepock": "chambolle-pock", "fista": "fista", } reco_implementations = { "astra": "astra", "corrct": "corrct", "corr-ct": "corrct", "nabu": "nabu", "": None, } files_formats = { "h5": "hdf5", "hdf5": "hdf5", "nexus": "hdf5", "nx": "hdf5", "npy": "npy", "npz": "npz", "tif": "tiff", "tiff": "tiff", "jp2": "jp2", "jp2k": "jp2", "j2k": "jp2", "jpeg2000": "jp2", "edf": "edf", "vol": "vol", } distribution_methods = { "local": "local", "slurm": "slurm", "": "local", "preview": "preview", } log_levels = { "0": "error", "1": "warning", "2": "info", "3": "debug", } sino_normalizations = { "none": None, "": None, "chebyshev": "chebyshev", "subtraction": "subtraction", "division": "division", } cor_methods = { "auto": "centered", "centered": "centered", "global": "global", "sino sliding window": "sino-sliding-window", "sino-sliding-window": "sino-sliding-window", "sliding window": "sliding-window", "sliding-window": "sliding-window", "sino growing window": "sino-growing-window", "sino-growing-window": "sino-growing-window", "growing window": "growing-window", "growing-window": "growing-window", "sino-coarse-to-fine": "sino-coarse-to-fine", "composite-coarse-to-fine": "composite-coarse-to-fine", "near": "composite-coarse-to-fine", "fourier-angles": "fourier-angles", "fourier angles": "fourier-angles", "fourier-angle": "fourier-angles", "fourier angle": "fourier-angles", "octave-accurate": "octave-accurate", "vo": "vo", } tilt_methods = { "1d-correlation": "1d-correlation", "1dcorrelation": "1d-correlation", "polarfft": "fft-polar", "polar-fft": "fft-polar", "fft-polar": "fft-polar", } rings_methods = { "none": None, "": None, "munch": "munch", "mean-subtraction": "mean-subtraction", "mean_subtraction": "mean-subtraction", 
"mean-division": "mean-division", "mean_division": "mean-division", "vo": "vo", } detector_distortion_correction_methods = {"none": None, "": None, "identity": "identity", "map_xz": "map_xz"} radios_rotation_mode = { "none": None, "": None, "chunk": "chunk", "chunks": "chunk", "full": "full", } exclude_projections_type = { "indices": "indices", "angular_range": "angular_range", "angles": "angles", } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1723556963.0 nabu-2024.2.1/nabu/pipeline/processconfig.py0000644000175000017500000001704714656662143020263 0ustar00pierrepierreimport os from .config import parse_nabu_config_file from ..utils import is_writeable from ..resources.logger import Logger, PrinterLogger from .config import validate_config from ..resources.dataset_analyzer import analyze_dataset from .estimators import DetectorTiltEstimator class ProcessConfigBase: """ A class for describing the Nabu process configuration. """ # Must be overriden by inheriting class default_nabu_config = None config_renamed_keys = None def __init__( self, conf_fname=None, conf_dict=None, dataset_info=None, create_logger=False, ): """ Initialize a ProcessConfig class. Parameters ---------- conf_fname: str Path to the nabu configuration file. If provided, the parameters `conf_dict` is ignored. conf_dict: dict A dictionary describing the nabu processing steps. If provided, the parameter `conf_fname` is ignored. dataset_info: DatasetAnalyzer A `DatasetAnalyzer` class instance. checks: bool, optional, default is True Whether to perform checks on configuration and datasets (recommended !) remove_unused_radios: bool, optional, default is True Whether to remove unused radios, i.e radios present in the dataset, but not explicitly listed in the scan metadata. create_logger: str or bool, optional Whether to create a Logger object. Default is False, meaning that the logger object creation is left to the user. If set to True, a Logger object is created, and logs will be written to the file "nabu_dataset_name.log". If set to a string, a Logger object is created, and the logs will be written to the file specified by this string. """ # Step (1a): create 'nabu_config' self._parse_configuration(conf_fname, conf_dict) self._create_logger(create_logger) # Step (1b): create 'dataset_info' self._browse_dataset(dataset_info) # Step (2) self._update_dataset_info_with_user_config() # Step (3): estimate tilt, CoR, ... 
self._dataset_estimations() # Step (4) self._coupled_validation() # Step (5) self._build_processing_steps() # Step (6) self._configure_save_steps() self._configure_resume() def _create_logger(self, create_logger): if create_logger is False: self.logger = PrinterLogger() return elif create_logger is True: dataset_loc = self.nabu_config["dataset"]["location"] dataset_fname_rel = os.path.basename(dataset_loc) if os.path.isfile(dataset_loc): logger_filename = os.path.join( os.path.abspath(os.getcwd()), os.path.splitext(dataset_fname_rel)[0] + "_nabu.log" ) else: logger_filename = os.path.join(os.path.abspath(os.getcwd()), dataset_fname_rel + "_nabu.log") elif isinstance(create_logger, str): logger_filename = create_logger else: raise ValueError("Expected bool or str for create_logger") if not is_writeable(os.path.dirname(logger_filename)): self.logger = PrinterLogger() self.logger.error("Cannot create logger file %s: no permission to write therein" % logger_filename) else: self.logger = Logger("nabu", level=self.nabu_config["pipeline"]["verbosity"], logfile=logger_filename) def _parse_configuration(self, conf_fname, conf_dict): """ Parse the user configuration and builds a dictionary. Parameters ---------- conf_fname: str Path to the .conf file. Mutually exclusive with 'conf_dict' conf_dict: dict Dictionary with the configuration. Mutually exclusive with 'conf_fname' """ if not ((conf_fname is None) ^ (conf_dict is None)): raise ValueError("You must either provide 'conf_fname' or 'conf_dict'") if conf_fname is not None: if not os.path.isfile(conf_fname): raise ValueError("No such file: %s" % conf_fname) self.conf_fname = conf_fname self.conf_dict = parse_nabu_config_file(conf_fname) else: self.conf_dict = conf_dict if self.default_nabu_config is None or self.config_renamed_keys is None: raise ValueError( "'default_nabu_config' and 'config_renamed_keys' must be specified by classes inheriting from ProcessConfig" ) self.nabu_config = validate_config( self.conf_dict, self.default_nabu_config, self.config_renamed_keys, ) def _browse_dataset(self, dataset_info): """ Browse a dataset and builds a data structure with the relevant information. """ self.logger.debug("Browsing dataset") if dataset_info is not None: self.dataset_info = dataset_info else: extra_options = { "exclude_projections": self.nabu_config["dataset"]["exclude_projections"], "hdf5_entry": self.nabu_config["dataset"]["hdf5_entry"], "nx_version": self.nabu_config["dataset"]["nexus_version"], } self.dataset_info = analyze_dataset( self.nabu_config["dataset"]["location"], extra_options=extra_options, logger=self.logger ) def _update_dataset_info_with_user_config(self): """ Update the 'dataset_info' (DatasetAnalyzer class instance) data structure with options from user configuration. """ raise ValueError("Base class") def _get_rotation_axis_position(self): self.dataset_info.axis_position = self.nabu_config["reconstruction"]["rotation_axis_position"] def _update_rotation_angles(self): raise ValueError("Base class") def _dataset_estimations(self): """ Perform estimation of several parameters like center of rotation and detector tilt angle. """ self.logger.debug("Doing dataset estimations") self._get_tilt() self._get_cor() def _get_cor(self): raise ValueError("Base class") def _get_tilt(self): tilt = self.nabu_config["preproc"]["tilt_correction"] if isinstance(tilt, str): # auto-tilt... 
self.tilt_estimator = DetectorTiltEstimator( self.dataset_info, do_flatfield=self.nabu_config["preproc"]["flatfield"], logger=self.logger, autotilt_options=self.nabu_config["preproc"]["autotilt_options"], ) tilt = self.tilt_estimator.find_tilt(tilt_method=tilt) self.dataset_info.detector_tilt = tilt def _coupled_validation(self): """ Validate together the dataset information and user configuration. Update 'dataset_info' and 'nabu_config' """ raise ValueError("Base class") def _build_processing_steps(self): """ Build the processing steps, i.e a tuple (steps, options) where - steps is a list of str (list of processing steps names) - options is a dict with processing options """ raise ValueError("Base class") build_processing_steps = _build_processing_steps # COMPAT. def _configure_save_steps(self): raise ValueError("Base class") def _configure_resume(self): raise ValueError("Base class") ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1723556968.0 nabu-2024.2.1/nabu/pipeline/reader.py0000644000175000017500000001127014656662150016647 0ustar00pierrepierrefrom multiprocessing.pool import ThreadPool import numpy as np from nabu.utils import get_num_threads from ..misc.binning import binning as image_binning from ..io.reader import NXTomoReader, EDFStackReader # # NXTomoReader with binning # def bin_image_stack(src_stack, dst_stack, binning_factor=(2, 2), num_threads=8): def _apply_binning(img_res_tuple): img, res = img_res_tuple res[:] = image_binning(img, binning_factor) if dst_stack is None: dst_stack = np.zeros((src_stack.shape[0],) + image_binning(src_stack[0], binning_factor).shape, dtype="f") with ThreadPool(num_threads) as tp: tp.map(_apply_binning, zip(src_stack, dst_stack)) return dst_stack def NXTomoReaderBinning(binning_factor, *nxtomoreader_args, num_threads=None, **nxtomoreader_kwargs): [ nxtomoreader_kwargs.pop(kwarg, None) for kwarg in ["processing_func", "processing_func_args", "processing_func_kwargs"] ] nxtomoreader_kwargs["processing_func"] = bin_image_stack nxtomoreader_kwargs["processing_func_kwargs"] = { "binning_factor": binning_factor, "num_threads": num_threads or get_num_threads(), } return NXTomoReader( *nxtomoreader_args, **nxtomoreader_kwargs, ) # # NXTomoReader with distortion correction # def apply_distortion_correction_on_images_stack(src_stack, dst_stack, distortion_corrector, num_threads=8): _, subregion = distortion_corrector.get_actual_shapes_source_target() src_x_start, src_x_end, src_z_start, src_z_end = subregion if dst_stack is None: dst_stack = np.zeros([src_stack.shape[0], src_z_end - src_z_start, src_x_end - src_x_start], "f") def apply_corrector(i_img_tuple): i, img = i_img_tuple dst_stack[i] = distortion_corrector.transform(img) with ThreadPool(num_threads) as tp: tp.map(apply_corrector, enumerate(src_stack)) return dst_stack def NXTomoReaderDistortionCorrection(distortion_corrector, *nxtomoreader_args, num_threads=None, **nxtomoreader_kwargs): [ nxtomoreader_kwargs.pop(kwarg, None) for kwarg in ["processing_func", "processing_func_args", "processing_func_kwargs"] ] nxtomoreader_kwargs["processing_func"] = apply_distortion_correction_on_images_stack nxtomoreader_kwargs["processing_func_args"] = [distortion_corrector] nxtomoreader_kwargs["processing_func_kwargs"] = {"num_threads": num_threads or get_num_threads()} return NXTomoReader( *nxtomoreader_args, **nxtomoreader_kwargs, ) # # EDF Reader with binning # def EDFStackReaderBinning(binning_factor, *edfstackreader_args, **edfstackreader_kwargs): [ 
edfstackreader_kwargs.pop(kwarg, None) for kwarg in ["processing_func", "processing_func_args", "processing_func_kwargs"] ] edfstackreader_kwargs["processing_func"] = image_binning edfstackreader_kwargs["processing_func_args"] = [binning_factor] return EDFStackReader( *edfstackreader_args, **edfstackreader_kwargs, ) # # EDF Reader with distortion correction # def apply_distortion_correction_on_image(image, distortion_corrector): return distortion_corrector.transform(image) def EDFStackReaderDistortionCorrection(distortion_corrector, *edfstackreader_args, **edfstackreader_kwargs): [ edfstackreader_kwargs.pop(kwarg, None) for kwarg in ["processing_func", "processing_func_args", "processing_func_kwargs"] ] edfstackreader_kwargs["processing_func"] = apply_distortion_correction_on_image edfstackreader_kwargs["processing_func_args"] = [distortion_corrector] return EDFStackReader( *edfstackreader_args, **edfstackreader_kwargs, ) def load_darks_flats( dataset_info, sub_region, processing_func=None, processing_func_args=None, processing_func_kwargs=None ): """ Load the (reduced) darks and flats and crop them to the sub-region currently used. At this stage, dataset_info.flats should be a dict in the form {num: array} Parameters ---------- sub_region: 2-tuple of 3-tuples of int Tuple in the form ((start_y, end_y), (start_x, end_x)) """ (start_y, end_y), (start_x, end_x) = sub_region processing_func_args = processing_func_args or [] processing_func_kwargs = processing_func_kwargs or {} def proc(img): if processing_func is None: return img return processing_func(img, *processing_func_args, **processing_func_kwargs) res = { "flats": {k: proc(flat_k)[start_y:end_y, start_x:end_x] for k, flat_k in dataset_info.flats.items()}, "darks": {k: proc(dark_k)[start_y:end_y, start_x:end_x] for k, dark_k in dataset_info.darks.items()}, } return res ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5127568 nabu-2024.2.1/nabu/pipeline/tests/0000755000175000017500000000000014730277752016177 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734019212.0 nabu-2024.2.1/nabu/pipeline/tests/test_estimators.py0000644000175000017500000001400214726604214021767 0ustar00pierrepierreimport os import pytest import numpy as np from nabu.testutils import utilstest, __do_long_tests__ from nabu.resources.dataset_analyzer import HDF5DatasetAnalyzer, analyze_dataset from nabu.resources.nxflatfield import update_dataset_info_flats_darks from nabu.resources.utils import extract_parameters from nabu.pipeline.estimators import CompositeCOREstimator from nabu.pipeline.config import parse_nabu_config_file from nabu.pipeline.estimators import SinoCORFinder, CORFinder # # Test CoR estimation with "composite-coarse-to-fine" (aka "near" in the legacy system vocable) # @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls dataset_downloaded_path = utilstest.getfile("test_composite_cor_finder_data.h5") cls.theta_interval = 4.5 * 1 # this is given. 
Radios in the middle of steps 4.5 degree long # are set to zero for compression # You can still change it to a multiple of 4.5 cls.cor_pix = 1321.625 cls.abs_tol = 0.0001 cls.dataset_info = HDF5DatasetAnalyzer(dataset_downloaded_path) update_dataset_info_flats_darks(cls.dataset_info, True) cls.cor_options = extract_parameters("side=300.0; near_width = 20.0", sep=";") @pytest.mark.skipif(not (__do_long_tests__), reason="Need NABU_LONG_TESTS=1 for this test") @pytest.mark.usefixtures("bootstrap") class TestCompositeCorFinder: def test(self): cor_finder = CompositeCOREstimator( self.dataset_info, theta_interval=self.theta_interval, cor_options=self.cor_options ) cor_position = cor_finder.find_cor() message = "Computed CoR %f " % cor_position + " and real CoR %f do not coincide" % self.cor_pix assert np.isclose(self.cor_pix, cor_position, atol=self.abs_tol), message @pytest.fixture(scope="class") def bootstrap_bamboo_reduced(request): cls = request.cls cls.abs_tol = 0.2 # Dataset without estimated_cor_frm_motor (non regression test) dataset_relpath = os.path.join("bamboo_reduced.nx") dataset_downloaded_path = utilstest.getfile(dataset_relpath) conf_relpath = os.path.join("bamboo_reduced.conf") conf_downloaded_path = utilstest.getfile(conf_relpath) cls.ds_std = analyze_dataset(dataset_downloaded_path) update_dataset_info_flats_darks(cls.ds_std, True) cls.conf_std = parse_nabu_config_file(conf_downloaded_path) # Dataset with estimated_cor_frm_motor dataset_relpath = os.path.join("bamboo_reduced_bliss.nx") dataset_downloaded_path = utilstest.getfile(dataset_relpath) conf_relpath = os.path.join("bamboo_reduced_bliss.conf") conf_downloaded_path = utilstest.getfile(conf_relpath) cls.ds_bliss = analyze_dataset(dataset_downloaded_path) update_dataset_info_flats_darks(cls.ds_bliss, True) cls.conf_bliss = parse_nabu_config_file(conf_downloaded_path) @pytest.mark.skipif(not (__do_long_tests__), reason="need environment variable NABU_LONG_TESTS=1") @pytest.mark.usefixtures("bootstrap_bamboo_reduced") class TestCorNearPos: # TODO adapt test file true_cor = 339.486 - 0.5 def test_cor_sliding_standard(self): cor_options = extract_parameters(self.conf_std["reconstruction"].get("cor_options", None), sep=";") for side in [None, "from_file", "center"]: if side is not None: cor_options.update({"side": side}) finder = CORFinder("sliding-window", self.ds_std, do_flatfield=True, cor_options=cor_options) cor = finder.find_cor() message = f"Computed CoR {cor} and expected CoR {self.true_cor} do not coincide. Near_pos options was set to {cor_options.get('near_pos',None)}." assert np.isclose(self.true_cor, cor, atol=self.abs_tol + 0.5), message # FIXME def test_cor_fourier_angles_standard(self): cor_options = extract_parameters(self.conf_std["reconstruction"].get("cor_options", None), sep=";") # TODO modify test files if "near_pos" in cor_options and "near" in cor_options.get("side", "") == "near": cor_options["side"] = cor_options["near_pos"] # for side in [None, "from_file", "center"]: if side is not None: cor_options.update({"side": side}) finder = SinoCORFinder("fourier-angles", self.ds_std, do_flatfield=True, cor_options=cor_options) cor = finder.find_cor() message = f"Computed CoR {cor} and expected CoR {self.true_cor} do not coincide. Near_pos options was set to {cor_options.get('near_pos',None)}." 
assert np.isclose(self.true_cor + 0.5, cor, atol=self.abs_tol), message def test_cor_sliding_bliss(self): cor_options = extract_parameters(self.conf_bliss["reconstruction"].get("cor_options", None), sep=";") # TODO modify test files if "near_pos" in cor_options and "near" in cor_options.get("side", "") == "near": cor_options["side"] = cor_options["near_pos"] # for side in [None, "from_file", "center"]: if side is not None: cor_options.update({"side": side}) finder = CORFinder("sliding-window", self.ds_bliss, do_flatfield=True, cor_options=cor_options) cor = finder.find_cor() message = f"Computed CoR {cor} and expected CoR {self.true_cor} do not coincide. Near_pos options was set to {cor_options.get('near_pos',None)}." assert np.isclose(self.true_cor, cor, atol=self.abs_tol), message def test_cor_fourier_angles_bliss(self): cor_options = extract_parameters(self.conf_bliss["reconstruction"].get("cor_options", None), sep=";") for side in [None, "from_file", "center"]: if side is not None: cor_options.update({"side": side}) finder = SinoCORFinder("fourier-angles", self.ds_bliss, do_flatfield=True, cor_options=cor_options) cor = finder.find_cor() message = f"Computed CoR {cor} and expected CoR {self.true_cor} do not coincide. Near_pos options was set to {cor_options.get('near_pos',None)}." assert np.isclose(self.true_cor + 0.5, cor, atol=self.abs_tol), message ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731681010.0 nabu-2024.2.1/nabu/pipeline/utils.py0000644000175000017500000000672714715655362016563 0ustar00pierrepierrefrom ..utils import deprecated_class from .config_validators import str2bool from dataclasses import dataclass import os # # Decorators and callback mechanism # def use_options(step_name, step_attr): def decorator(func): def wrapper(*args, **kwargs): self = args[0] if step_name not in self.processing_steps: self.__setattr__(step_attr, None) return self._steps_name2component[step_name] = step_attr self._steps_component2name[step_attr] = step_name return func(*args, **kwargs) return wrapper return decorator def pipeline_step(step_attr, step_desc): def decorator(func): def wrapper(*args, **kwargs): self = args[0] if getattr(self, step_attr, None) is None: return self.logger.info(step_desc) res = func(*args, **kwargs) step_name = self._steps_component2name[step_attr] callbacks = self._callbacks.get(step_name, None) if callbacks is not None: for callback in callbacks: callback(self) if self.datadump_manager is not None and step_name in self.datadump_manager.data_dump: self.datadump_manager.dump_data_to_file( step_name, self.radios, crop_margin=not (self._radios_were_cropped) ) return res return wrapper return decorator # # sub-region, shapes, etc # def get_subregion(sub_region, ndim=3): """ Return a "normalized" sub-region in the form ((start_z, end_z), (start_y, end_y), (start_x, end_x)). Parameters ---------- sub_region: tuple A tuple of tuples or tuple of integers. Notes ----- The parameter "sub_region" is normally a tuple of tuples of integers. However it can be more convenient to use tuple of integers. This function will attempt at catching the different cases, but will fail if 'sub_region' contains heterogeneous types (ex. 
tuples along with int) """ if sub_region is None: res = ((None, None),) elif hasattr(sub_region[0], "__iter__"): if set(map(len, sub_region)) != set([2]): raise ValueError("Expected each tuple to be in the form (start, end)") res = sub_region else: if len(sub_region) % 2: raise ValueError("Expected even number of elements") starts, ends = sub_region[::2], sub_region[1::2] res = tuple([(s, e) for s, e in zip(starts, ends)]) if len(res) != ndim: res += ((None, None),) * (ndim - len(res)) return res # # Writer - moved to pipeline.writer # from .writer import WriterManager WriterConfigurator = deprecated_class("WriterConfigurator moved to nabu.pipeline.writer.WriterManager", do_print=True)( WriterManager ) @dataclass class EnvSettings: """This class centralises the definitions, possibly documentation, and access to environnt variable driven settings. It is meant to be used in the following way: from nabu.utils import nabu_env_settings if not nabu_env_settings.skip_tomoscan_checks: do something """ skip_tomoscan_checks: bool = False def _get_nabu_environment_variables(): nabu_env_settings = EnvSettings() nabu_env_settings.skip_tomoscan_checks = str2bool(os.getenv("SKIP_TOMOSCAN_CHECK", "0")) return nabu_env_settings nabu_env_settings = _get_nabu_environment_variables() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734019212.0 nabu-2024.2.1/nabu/pipeline/writer.py0000644000175000017500000001775414726604214016731 0ustar00pierrepierrefrom os import path from tomoscan.esrf import TIFFVolume, MultiTIFFVolume, EDFVolume, JP2KVolume from tomoscan.esrf.volume.singleframebase import VolumeSingleFrameBase from ..utils import check_supported, get_num_threads from ..resources.logger import LoggerOrPrint from ..io.writer import NXProcessWriter, HSTVolVolume, NXVolVolume from ..io.utils import convert_dict_values from .params import files_formats class WriterManager: """ This class is a wrapper on top of all "writers". It will create the right "writer" with all the necessary options, and the histogram writer. The layout is the following. * Single-big-file volume formats (big-tiff, .vol): - no start index - everything is increasingly written in one file * Multiple-frames per file (HDF5 + master-file): - needs a start index (change file_prefix) for each partial file - needs a subdirectory for partial files * One-file-per-frame (tiff, edf, jp2) - start_index When saving intermediate steps (eg. sinogram): HDF5 format is always used. """ _overwrite_warned = False _writer_classes = { "hdf5": NXVolVolume, "tiff": TIFFVolume, "bigtiff": MultiTIFFVolume, "jp2": JP2KVolume, "edf": EDFVolume, "vol": HSTVolVolume, } def __init__( self, output_dir, file_prefix, file_format="hdf5", overwrite=False, start_index=0, logger=None, metadata=None, histogram=False, extra_options=None, ): """ Create a Writer from a set of parameters. Parameters ---------- output_dir: str Directory where the file(s) will be written. file_prefix: str File prefix (without leading path) start_index: int, optional Index to start the files numbering (filename_0123.ext). Default is 0. Ignored for HDF5 extension. logger: nabu.resources.logger.Logger, optional Logger object metadata: dict, optional Metadata, eg. information on various processing steps. For HDF5, it will go to "configuration" histogram: bool, optional Whether to also write a histogram of data. If set to True, it will configure an additional "writer". extra_options: dict, optional Other advanced options to pass to Writer class. 
""" self.extra_options = extra_options or {} self._set_file_format(file_format) self.overwrite = overwrite self.start_index = start_index self.logger = LoggerOrPrint(logger) self.histogram = histogram self.output_dir = output_dir self.file_prefix = file_prefix self.metadata = convert_dict_values(metadata or {}, {None: "None"}) self._init_writer() self._init_histogram_writer() def _set_file_format(self, file_format): check_supported(file_format, files_formats, "file format") self.file_format = files_formats[file_format] self._is_bigtiff = file_format in ["tiff", "tif"] and any( [self.extra_options.get(opt, False) for opt in ["tiff_single_file", "use_bigtiff"]] ) if self._is_bigtiff: self.file_format = "bigtiff" @staticmethod def get_first_fname(vol_writer): if hasattr(vol_writer, "file_path"): return path.dirname(vol_writer.file_path) dirname = vol_writer.data_url.file_path() fname = vol_writer.data_url.data_path().format( volume_basename=vol_writer._volume_basename, index_zfill6=vol_writer.start_index, data_extension=vol_writer.extension or vol_writer.DEFAULT_DATA_EXTENSION, ) return path.join(dirname, fname) @staticmethod def get_fname(vol_writer): if hasattr(vol_writer, "file_path"): # several frames per file - return the file itself return vol_writer.file_path # one file per frame - return the directory return vol_writer.data_url.file_path() def _init_writer(self): self._writer_was_already_initialized = self.extra_options.get("writer_initialized", False) if self.file_format in ["tiff", "edf", "jp2", "hdf5"]: writer_kwargs = { "folder": self.output_dir, "volume_basename": self.file_prefix, "start_index": self.start_index, "overwrite": self.overwrite, } if self.file_format == "hdf5": writer_kwargs["data_path"] = self.metadata.get("entry", "entry") writer_kwargs["process_name"] = self.metadata.get("process_name", "reconstruction") writer_kwargs["create_subfolder"] = self.extra_options.get("create_subfolder", True) elif self.file_format == "jp2": writer_kwargs["cratios"] = self.metadata.get("jpeg2000_compression_ratio", None) writer_kwargs["clip_values"] = self.metadata.get("float_clip_values", None) writer_kwargs["n_threads"] = get_num_threads() elif self.file_format in ["vol", "bigtiff"]: writer_kwargs = { "file_path": path.join( self.output_dir, self.file_prefix + "." + self.file_format.replace("bigtiff", "tiff") ), "overwrite": self.overwrite, "append": self.extra_options.get("single_output_file_initialized", False), } if self.file_format == "vol": writer_kwargs["hst_metadata"] = self.extra_options.get("raw_vol_metadata", {}) else: raise ValueError("Unsupported file format: %s" % self.file_format) self._h5_entry = self.metadata.get("entry", "entry") self.writer = self._writer_classes[self.file_format](**writer_kwargs) self.fname = self.get_fname(self.writer) # In certain cases, tomoscan needs to remove any previous existing volume filess # and avoid calling 'clean_output_data' when writing downstream (for chunk processing) if isinstance(self.writer, VolumeSingleFrameBase): self.writer.skip_existing_data_files_removal = self._writer_was_already_initialized # --- if path.exists(self.fname): err = "File already exists: %s" % self.fname if self.overwrite: if not (self.__class__._overwrite_warned): self.logger.warning(err + ". 
It will be overwritten as requested in configuration") self.__class__._overwrite_warned = True else: self.logger.fatal(err) raise ValueError(err) def _init_histogram_writer(self): if not self.histogram: return separate_histogram_file = not (self.file_format == "hdf5") if separate_histogram_file: fmode = "w" hist_fname = path.join(self.output_dir, "histogram_%05d.hdf5" % self.start_index) else: fmode = "a" hist_fname = self.fname # Nabu's original NXProcessWriter has to be used here, as histogram is not 3D self.histogram_writer = NXProcessWriter( hist_fname, entry=self._h5_entry, filemode=fmode, overwrite=True, ) def write_histogram(self, data, config=None, processing_index=1): if not (self.histogram): return self.histogram_writer.write( data, "histogram", processing_index=processing_index, config=config, is_frames_stack=False, direct_access=False, ) def _write_metadata(self): self.writer.metadata = self.metadata self.writer.save_metadata() def write_data(self, data, metadata=None): self.writer.data = data if metadata is not None: self.writer.metadata = metadata self.writer.save() # self._write_metadata() ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5127568 nabu-2024.2.1/nabu/pipeline/xrdct/0000755000175000017500000000000014730277752016161 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/pipeline/xrdct/__init__.py0000644000175000017500000000000014315516747020257 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5127568 nabu-2024.2.1/nabu/preproc/0000755000175000017500000000000014730277752014702 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/preproc/__init__.py0000644000175000017500000000043414402565210016775 0ustar00pierrepierrefrom .ccd import CCDFilter, Log from .ctf import CTFPhaseRetrieval from .distortion import DistortionCorrection from .double_flatfield import DoubleFlatField from .flatfield import FlatField, FlatFieldDataUrls from .phase import PaganinPhaseRetrieval from .shift import VerticalShift ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/preproc/alignment.py0000644000175000017500000000057314402565210017220 0ustar00pierrepierre# Backward compat. from ..estimation.alignment import AlignmentBase from ..estimation.cor import ( CenterOfRotation, CenterOfRotationAdaptiveSearch, CenterOfRotationGrowingWindow, CenterOfRotationSlidingWindow, ) from ..estimation.translation import DetectorTranslationAlongBeam from ..estimation.focus import CameraFocus from ..estimation.tilt import CameraTilt ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/preproc/ccd.py0000644000175000017500000001203114550227307015771 0ustar00pierrepierreimport numpy as np from ..utils import check_supported from silx.math.medianfilter import medfilt2d class CCDFilter: """ Filtering applied on radios. """ _supported_ccd_corrections = ["median_clip"] def __init__( self, radios_shape: tuple, correction_type: str = "median_clip", median_clip_thresh: float = 0.1, abs_diff=False, preserve_borders=False, ): """ Initialize a CCDCorrection instance. Parameters ----------- radios_shape: tuple A tuple describing the shape of the radios stack, in the form `(n_radios, n_z, n_x)`. 
correction_type: str Correction type for radios ("median_clip", "sigma_clip", ...) median_clip_thresh: float, optional Threshold for the median clipping method. abs_diff: boolean by default False: the correction is triggered when img - median > threshold. If equals True: correction is triggered for abs(img-media) > threshold preserve borders: boolean by default False: If equals True: the borders (width=1) are not modified. Notes ------ A CCD correction is a process (usually filtering) taking place in the radios space. Available filters: - median_clip: if the value of the current pixel exceeds the median of adjacent pixels (a 3x3 neighborhood) more than a threshold, then this pixel value is set to the median value. """ self._set_radios_shape(radios_shape) check_supported(correction_type, self._supported_ccd_corrections, "CCD correction mode") self.correction_type = correction_type self.median_clip_thresh = median_clip_thresh self.abs_diff = abs_diff self.preserve_borders = preserve_borders def _set_radios_shape(self, radios_shape): if len(radios_shape) == 2: self.radios_shape = (1,) + radios_shape elif len(radios_shape) == 3: self.radios_shape = radios_shape else: raise ValueError("Expected radios to have 2 or 3 dimensions") n_radios, n_z, n_x = self.radios_shape self.n_radios = n_radios self.n_angles = n_radios self.shape = (n_z, n_x) @staticmethod def median_filter(img): """ Perform a median filtering on an image. """ return medfilt2d(img, (3, 3), mode="reflect") def median_clip_mask(self, img, return_medians=False): """ Compute a mask indicating whether a pixel is valid or not, according to the median-clip method. Parameters ---------- img: numpy.ndarray Input image return_medians: bool, optional Whether to return the median values additionally to the mask. """ median_values = self.median_filter(img) if not self.abs_diff: invalid_mask = img >= median_values + self.median_clip_thresh else: invalid_mask = abs(img - median_values) > self.median_clip_thresh if return_medians: return invalid_mask, median_values else: return invalid_mask def median_clip_correction(self, radio, output=None): """ Compute the median clip correction on one image. Parameters ---------- radios: numpy.ndarray, optional A radio image. output: numpy.ndarray, optional Output array """ assert radio.shape == self.shape if output is None: output = np.copy(radio) else: output[:] = radio[:] invalid_mask, medians = self.median_clip_mask(radio, return_medians=True) if self.preserve_borders: fixed_border = np.array(radio[[0, 0, -1, -1], [0, -1, 0, -1]]) output[invalid_mask] = medians[invalid_mask] if self.preserve_borders: output[[0, 0, -1, -1], [0, -1, 0, -1]] = fixed_border return output class Log: """ Helper class to take -log(radios) Parameters ----------- clip_min: float, optional Before taking the logarithm, the values are clipped to this minimum. clip_max: float, optional Before taking the logarithm, the values are clipped to this maximum. """ def __init__(self, radios_shape, clip_min=None, clip_max=None): self.radios_shape = radios_shape self.clip_min = clip_min self.clip_max = clip_max def take_logarithm(self, radios): """ Take the negative logarithm of a radios chunk. Processing is done in-place ! Parameters ----------- radios: array Radios chunk. 
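Examples
--------
A minimal sketch, with synthetic values chosen only for illustration:

>>> import numpy as np
>>> radios = np.random.uniform(0.5, 1.0, size=(10, 32, 32)).astype(np.float32)
>>> log_proc = Log(radios.shape, clip_min=1e-6, clip_max=10.0)
>>> radios = log_proc.take_logarithm(radios)  # radios now holds -log of the clipped values, modified in-place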
""" if (self.clip_min is not None) or (self.clip_max is not None): np.clip(radios, self.clip_min, self.clip_max, out=radios) np.log(radios, out=radios) else: np.log(radios, out=radios) radios[:] *= -1 return radios ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/preproc/ccd_cuda.py0000644000175000017500000001371514654107202016773 0ustar00pierrepierreimport numpy as np from ..preproc.ccd import CCDFilter, Log from ..processing.medfilt_cuda import MedianFilter from ..utils import get_cuda_srcfile, updiv, deprecated_class from ..cuda.utils import __has_pycuda__ if __has_pycuda__: from ..cuda.kernel import CudaKernel # COMPAT. from .flatfield_cuda import ( CudaFlatField as CudaFlatfield_, CudaFlatFieldArrays as CudaFlatFieldArrays_, CudaFlatFieldDataUrls as CudaFlatFieldDataUrls_, ) FlatField = deprecated_class( "preproc.ccd_cuda.CudaFlatField was moved to preproc.flatfield_cuda.CudaFlatField", do_print=True )(CudaFlatfield_) FlatFieldArrays = deprecated_class( "preproc.ccd_cuda.CudaFlatFieldArrays was moved to preproc.flatfield_cuda.CudaFlatFieldArrays", do_print=True )(CudaFlatFieldArrays_) FlatFieldDataUrls = deprecated_class( "preproc.ccd_cuda.CudaFlatFieldDataUrls was moved to preproc.flatfield_cuda.CudaFlatFieldDataUrls", do_print=True )(CudaFlatFieldDataUrls_) # class CudaCCDFilter(CCDFilter): def __init__( self, radios_shape, correction_type="median_clip", median_clip_thresh=0.1, abs_diff=False, cuda_options=None, ): """ Initialize a CudaCCDCorrection instance. Please refer to the documentation of CCDCorrection. """ super().__init__( radios_shape, correction_type=correction_type, median_clip_thresh=median_clip_thresh, ) self._set_cuda_options(cuda_options) self.cuda_median_filter = None if correction_type == "median_clip": self.cuda_median_filter = MedianFilter( self.shape, footprint=(3, 3), mode="reflect", threshold=median_clip_thresh, abs_diff=abs_diff, cuda_options={ "device_id": self.cuda_options["device_id"], "ctx": self.cuda_options["ctx"], "cleanup_at_exit": self.cuda_options["cleanup_at_exit"], }, ) def _set_cuda_options(self, user_cuda_options): self.cuda_options = {"device_id": None, "ctx": None, "cleanup_at_exit": None} if user_cuda_options is None: user_cuda_options = {} self.cuda_options.update(user_cuda_options) def median_clip_correction(self, radio, output=None): """ Compute the median clip correction on one image. Parameters ---------- radio: pycuda.gpuarray A radio image output: pycuda.gpuarray, optional Output data. """ assert radio.shape == self.shape return self.cuda_median_filter.medfilt2(radio, output=output) CudaCCDCorrection = deprecated_class("CudaCCDCorrection is replaced with CudaCCDFilter", do_print=True)(CudaCCDFilter) class CudaLog(Log): """ Helper class to take -log(radios) """ def __init__(self, radios_shape, clip_min=None, clip_max=None): """ Initialize a Log processing. Parameters ----------- radios_shape: tuple The shape of 3D radios stack. clip_min: float, optional Data smaller than this value is replaced by this value. clip_max: float, optional. Data bigger than this value is replaced by this value. 
""" super().__init__(radios_shape, clip_min=clip_min, clip_max=clip_max) self._init_kernels() def _init_kernels(self): self._do_clip_min = int(self.clip_min is not None) self._do_clip_max = int(self.clip_max is not None) self.clip_min = np.float32(self.clip_min or 0) self.clip_max = np.float32(self.clip_max or 1) self._nlog_srcfile = get_cuda_srcfile("ElementOp.cu") nz, ny, nx = self.radios_shape self._nx = np.int32(nx) self._ny = np.int32(ny) self._nz = np.int32(nz) self._nthreadsperblock = (16, 16, 4) # TODO tune ? self._nblocks = tuple([updiv(n, p) for n, p in zip([nx, ny, nz], self._nthreadsperblock)]) self.nlog_kernel = CudaKernel( # pylint: disable=E0606 "nlog", filename=self._nlog_srcfile, signature="Piiiff", options=[ "-DDO_CLIP_MIN=%d" % self._do_clip_min, "-DDO_CLIP_MAX=%d" % self._do_clip_max, ], ) def take_logarithm(self, radios, clip_min=None, clip_max=None): """ Take the negative logarithm of a radios chunk. Parameters ----------- radios: `pycuda.gpuarray.GPUArray` Radios chunk If not provided, a new GPU array is created. clip_min: float, optional Before taking the logarithm, the values are clipped to this minimum. clip_max: float, optional Before taking the logarithm, the values are clipped to this maximum. """ clip_min = clip_min or self.clip_min clip_max = clip_max or self.clip_max if radios.flags.c_contiguous: self.nlog_kernel( radios, self._nx, self._ny, self._nz, clip_min, clip_max, grid=self._nblocks, block=self._nthreadsperblock, ) else: # map-like operations cannot be directly applied on 3D arrays # that are not C-contiguous. We have to process image per image. # TODO it's even worse when each single frame is not C-contiguous. For now this case is not handled nz = np.int32(1) nthreadsperblock = (32, 32, 1) nblocks = tuple([updiv(n, p) for n, p in zip([int(self._nx), int(self._ny), int(nz)], nthreadsperblock)]) for i in range(radios.shape[0]): self.nlog_kernel( radios[i], self._nx, self._ny, nz, clip_min, clip_max, grid=nblocks, block=nthreadsperblock ) return radios ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/preproc/ctf.py0000644000175000017500000003530014654107202016014 0ustar00pierrepierreimport math import numpy as np from scipy.fft import rfft2, irfft2, fft2, ifft2 from ..resources.logger import LoggerOrPrint from ..misc import fourier_filters from ..misc.padding import pad_interpolate, recut from ..utils import get_num_threads, deprecation_warning class GeoPars: """ A class to describe the geometry of a phase contrast radiography with a source obtained by a focussing system, possibly astigmatic, which is at distance z1_vh from the sample. The detector is at z2 from the sample """ def __init__( self, z1_vh=None, z2=None, pix_size_det=1e-6, wavelength=None, magnification=True, length_scale=10.0e-6, logger=None, ): """ Parameters ---------- z1_vh : None, a float, or a sequence of two floats the source sample distance (meters), if None the parallel beam is assumed. If two floats are given then they are taken as the distance of the vertically focused source (horizontal line) and the horizontaly focused source (vertical line) for KB mirrors. z2 : float the sample detector distance (meters). pix_size_det: float or tuple pixel size in meters. If a tuple is passed, it is interpreted as (horizontal_size, vertical_size) wavelength: float beam wave length (meters). 
magnification: boolean defaults to True if false no magnification is considered length_scale: float rescaling length scale, meant to avoid having too big or too small numbers. defaults to 10.0e-6 logger: Logger, optional A logging object """ self.logger = LoggerOrPrint(logger) if z1_vh is None: self.z1_vh = None else: if hasattr(type(z1_vh), "__iter__"): self.z1_vh = np.array(z1_vh) else: self.z1_vh = np.array([z1_vh, z1_vh]) self.z2 = z2 self.magnification = magnification if np.isscalar(pix_size_det): self.pix_size_det_xy = (pix_size_det, pix_size_det) else: self.pix_size_det_xy = pix_size_det self.pix_size_det = self.pix_size_det_xy[0] # COMPAT if self.magnification and self.z1_vh is not None: self.M_vh = (self.z1_vh + self.z2) / self.z1_vh else: self.M_vh = np.array([1, 1]) self.logger.debug("Magnification : h ({}) ; v ({}) ".format(self.M_vh[1], self.M_vh[0])) self.length_scale = length_scale self.wavelength = wavelength self.maxM = self.M_vh.max() # we bring everything to highest magnification self.pix_size_rec_xy = [p / self.maxM for p in self.pix_size_det_xy] self.pix_size_rec = self.pix_size_rec_xy[0] # COMPAT which_unit = int(np.sum(np.array([self.pix_size_rec > small for small in [1.0e-6, 1.0e-7]]).astype(np.int32))) self.pixelsize_string = [ "{:.1f} nm".format(self.pix_size_rec * 1e9), "{:.3f} um".format(self.pix_size_rec * 1e6), "{:.1f} um".format(self.pix_size_rec * 1e6), ][which_unit] if self.magnification: self.logger.debug( "All images are resampled to smallest pixelsize: {}".format(self.pixelsize_string), ) else: self.logger.debug("Pixelsize images: {}".format(self.pixelsize_string)) class CTFPhaseRetrieval: """ This class implements the CTF formula of [1] in its regularised form which avoids the zeros of unregularised_filter_denominator (unreg_filter_denom is the so here named denominator/delta_beta of equation 8). References ----------- [1] B. Yu, L. Weber, A. Pacureanu, M. Langer, C. Olivier, P. Cloetens, and F. Peyrin, "Evaluation of phase retrieval approaches in magnified X-ray phase nano computerized tomography applied to bone tissue", Optics Express, Vol 26, No 9, 11110-11124 (2018) """ def __init__( self, shape, geo_pars, delta_beta, padded_shape="auto", padding_mode="reflect", translation_vh=None, normalize_by_mean=False, lim1=1.0e-5, lim2=0.2, use_rfft=False, fftw_num_threads=None, fft_num_threads=None, logger=None, ): """ Initialize a Contrast Transfer Function phase retrieval. Parameters ---------- geo_pars: GeoPars the geometry description delta_beta : float the delta/beta ratio padded_shape: str or tuple, optional Padded image shape, in the form (num_rows, num_columns) i.e (vertical, horizontal). By default, it is twice the image shape. padding_mode: str Padding mode. It must be valid for the numpy.pad function translation_vh: array, optional Shift in the form (y, x). It is used to perform a translation of the image before applying the CTF filter. normalize_by_mean: bool Whether to divide the (padded) image with its mean before applying the CTF filter. lim1: float >0 the regulariser strenght at low frequencies lim2: float >0 the regulariser strenght at high frequencies use_rfft: bool, optional Whether to use real-to-complex (R2C) FFT instead of usual complex-to-complex (C2C). fftw_num_threads: bool or None or int, optional DEPRECATED - please use fft_num_threads instead. fft_num_threads: bool or None or int, optional Number of threads to use for FFT. If a number is provided: number of threads to use for FFT. 
You can pass a negative number to use N - fft_num_threads cores. logger: optional a logger object """ self.logger = LoggerOrPrint(logger) if not isinstance(geo_pars, GeoPars): raise ValueError("Expected GeoPars instance for 'geo_pars' parameter") self.geo_pars = geo_pars self._calc_shape(shape, padded_shape, padding_mode) self.delta_beta = delta_beta # COMPAT. if fftw_num_threads is not None: deprecation_warning("'fftw_num_threads' is replaced with 'fft_num_threads'", func_name="ctf_fftw") fft_num_threads = fftw_num_threads # --- self.lim = None self.lim1 = lim1 self.lim2 = lim2 self.normalize_by_mean = normalize_by_mean self.translation_vh = translation_vh self._setup_fft(use_rfft, fft_num_threads) self._get_ctf_filter() def _calc_shape(self, shape, padded_shape, padding_mode): if np.isscalar(shape): shape = (shape, shape) else: assert len(shape) == 2 self.shape = shape if padded_shape is None or padded_shape is False: padded_shape = self.shape # no padding elif isinstance(padded_shape, (tuple, list, np.ndarray)): pass elif padded_shape == "auto": padded_shape = (2 * self.shape[0], 2 * self.shape[1]) self.shape_padded = tuple(padded_shape) self.padding_mode = padding_mode def _setup_fft(self, use_rfft, fft_num_threads): self.use_rfft = use_rfft self._fft_func = rfft2 if use_rfft else fft2 self._ifft_func = irfft2 if use_rfft else ifft2 self.fft_num_threads = get_num_threads(fft_num_threads) def _get_ctf_filter(self): """ The parameter "length_scale" was mentioned, in the octave code, as a rescaling length scale, which is meant to avoid having too big or too small numbers. From the mathematical point of view, it is infact completely trasparent: its action is on fsamplex, fsampley and betash, betasv. But these latters ( beta's) are multiplied by the formers (fsample's) so that "length_scale" mathematically disappears, however in case of simple precision float the exponent of a float number ranges from -38 to 38, and one could approach it as an example by taking the square of a very small number ( 1.0e-19), and losing significant bits in he mantissa or getting zero, or the square of a big number, thus generting inf Althought the values involved in our x-ray regimes seems safe, with respect to these problems, this length_scale parameters does not hurt. """ padded_img_shape = self.shape_padded fsample_vh = np.array( [ self.geo_pars.length_scale / self.geo_pars.pix_size_rec_xy[1], self.geo_pars.length_scale / self.geo_pars.pix_size_rec_xy[0], ] ) if not self.use_rfft: ff_index_vh = list(map(np.fft.fftfreq, padded_img_shape)) else: ff_index_vh = [np.fft.fftfreq(padded_img_shape[0]), np.fft.rfftfreq(padded_img_shape[1])] # if padded_img_shape[1]%2 == 0 : # change to holotomo_slave indexing (by a transparent 2pi shift) # ff_index_x[ ff_index_x == -0.5 ] = +0.5 # if padded_img_shape[0]%2 == 0 : # change to holotomo_slave indexing (by a transparent 2pi shift) # ff_index_y[ ff_index_y == -0.5 ] = +0.5 frequencies_vh = np.array( np.meshgrid(ff_index_vh[0] * fsample_vh[0], ff_index_vh[1] * fsample_vh[1], indexing="ij") ) frequencies_squared_vh = frequencies_vh * frequencies_vh """ --------------- fresnelnumbers and forward propagators ------------------- In the limit of parallel beam, z1_h and z1_v would be infinite so that the here below distances becomes z2 which is sample-detector distance. 
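For instance, with illustrative values z1 = 0.1 m and z2 = 1.0 m, the effective propagation distance is z1*z2/(z1+z2) = 0.1*1.0/1.1 ≈ 0.091 m; as z1 tends to infinity (parallel beam), this expression tends to z2.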
""" if self.geo_pars.z1_vh is not None: distances_vh = (self.geo_pars.z1_vh * self.geo_pars.z2) / (self.geo_pars.z1_vh + self.geo_pars.z2) else: distances_vh = np.array([self.geo_pars.z2, self.geo_pars.z2]) """ Citing David Paganin (2002) : The intensity I_{R_1}(r⊥,z) at a distance z of a weakly refracting object illuminated by a point source at distance R1 behind the said object, is related to the intensity I∞(r⊥,z), which would result from normally incident collimated illumination of the same object, by (Pogany et al., 1997): I_{R_1}(r⊥,z) = 1/{M^2} I∞(r⊥/M , z/M) where M is the magnification. This explains the effective distance formula expressed by distancesh, distancesv above. ------------------------------------------------------------------------------ """ lambda_dist_vh = self.geo_pars.wavelength * distances_vh / (self.geo_pars.length_scale**2) """ --------------------------------------------------------------------------------------- -> cut_v at first maximum of ctf largest distance In the paraxial expansion of the Fresnel propagator, the phase is equal to 1/2 K_{parallel}^2 wavelength * Distance /2 / pi When this is equal to a multiple of 2 pi, the effect of propagation disappears and we have a singularity in equation 8. The sampling in the reciprocal space is done with a step length of 2*pi*fsamplex,y (note: fsamples are the plain inverse of pixel size) The first singularity occurs at a frequence number K_{parallel}/2/pi/ fsamplex for K_{parallel}^2 = 2 * (2pi)^2 /( wavelength * distance ) which would correspond to K_{parallel}/2/pi/ fsamplex = sqrt( 2/wavelength*distance) / fsample Question:( why the factor 2 appear at the denominator in the square root below?) Answer : the below defined cut corresponds to the first maximum of the denominator, before arriving to the first zero. In this way the regularisation is already at ( almost) full strenght on the first pole.ation ## is already at ( almost) full strenght on the first pole. """ self.cut_v = math.sqrt(1.0 / 2 / lambda_dist_vh[0]) / fsample_vh[0] self.cut_v = min(self.cut_v, 0.5) self.logger.debug("Normalized cut-off = {:5.3f}".format(self.cut_v)) self.r = fourier_filters.get_lowpass_filter( padded_img_shape, cutoff_par=( 0.5 / (self.cut_v + 1.0 / padded_img_shape[0]), 0.01 * padded_img_shape[0] / (1 + self.cut_v * padded_img_shape[0]), ), use_rfft=self.use_rfft, ) self.r /= self.r[0, 0] self.lim = self.lim1 * self.r + self.lim2 * (1 - self.r) # more methods exist in the original code, and they are initialized starting from here # (ht_app1, ht_app2... 
ht_app7) fresnel_phase = ( np.pi * lambda_dist_vh[1] * frequencies_squared_vh[1] + np.pi * lambda_dist_vh[0] * frequencies_squared_vh[0] ) if self.delta_beta: unreg_filter_denom = np.sin(fresnel_phase) + (1.0 / self.delta_beta) * np.cos(fresnel_phase) else: unreg_filter_denom = np.sin(fresnel_phase) self.unreg_filter_denom = unreg_filter_denom.astype(np.float32) self._ctf_filter_denom = (2 * self.unreg_filter_denom * self.unreg_filter_denom + self.lim).astype(np.complex64) def _apply_filter(self, img): img_f = self._fft_func(img, workers=self.fft_num_threads) img_f *= self.unreg_filter_denom unreg_filter_denom_0_mean = self.unreg_filter_denom[0, 0] nf, mf = img.shape # here it is assumed that the average of img is 1 and the DC component is removed img_f[0, 0] -= nf * mf * unreg_filter_denom_0_mean ## formula 8, with regularisation to stay at a safe distance from the poles img_f /= self._ctf_filter_denom ph = self._ifft_func(img_f, workers=self.fft_num_threads).real return ph def retrieve_phase(self, img, output=None): """ Apply the CTF filter to retrieve the phase. Parameters ---------- img: np.ndarray Projection image. It must have been already flat-fielded. Returns -------- ph: numpy.ndarray Phase image """ padded_img = pad_interpolate( img, self.shape_padded, translation_vh=self.translation_vh, padding_mode=self.padding_mode ) if self.normalize_by_mean: padded_img /= padded_img.mean() phase_img = self._apply_filter(padded_img) res = recut(phase_img, img.shape) if output is not None: output[:, :] = res[:, :] return output return res __call__ = retrieve_phase CtfFilter = CTFPhaseRetrieval ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1730906677.0 nabu-2024.2.1/nabu/preproc/ctf_cuda.py0000644000175000017500000001222414712705065017016 0ustar00pierrepierreimport numpy as np from ..utils import calc_padding_lengths, updiv, get_cuda_srcfile, docstring from ..cuda.processing import CudaProcessing from ..cuda.utils import __has_pycuda__ from ..processing.padding_cuda import CudaPadding from ..processing.fft_cuda import get_fft_class from .phase_cuda import CudaPaganinPhaseRetrieval from .ctf import CTFPhaseRetrieval if __has_pycuda__: from pycuda import gpuarray as garray # TODO: # - better padding scheme (for now 2*shape) # - rework inheritance scheme ? (base class SingleDistancePhaseRetrieval and its cuda counterpart) class CudaCTFPhaseRetrieval(CTFPhaseRetrieval): """ Cuda back-end of CTFPhaseRetrieval """ @docstring(CTFPhaseRetrieval) def __init__( self, shape, geo_pars, delta_beta, padded_shape="auto", padding_mode="reflect", translation_vh=None, normalize_by_mean=False, lim1=1.0e-5, lim2=0.2, use_rfft=True, fftw_num_threads=None, # COMPAT. fft_num_threads=None, logger=None, cuda_options=None, fft_backend="vkfft", ): """ Initialize a CudaCTFPhaseRetrieval. Parameters ---------- shape: tuple Shape of the images to process padding_mode: str Padding mode. Default is "reflect". Other parameters ----------------- Please refer to CTFPhaseRetrieval documentation. 
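Examples
--------
A rough sketch with purely illustrative values (it assumes a CUDA-capable setup with the optional GPU dependencies installed, and a projection image that has already been flat-fielded):

>>> import numpy as np
>>> import pycuda.gpuarray as garray
>>> from nabu.preproc.ctf import GeoPars
>>> geo = GeoPars(z1_vh=0.1, z2=1.0, pix_size_det=1e-6, wavelength=1e-10)
>>> ctf = CudaCTFPhaseRetrieval((2048, 2048), geo, delta_beta=250.0)
>>> d_img = garray.to_gpu(np.ones((2048, 2048), dtype=np.float32))
>>> d_phase = ctf.retrieve_phase(d_img)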
""" if not use_rfft: raise ValueError("Only use_rfft=True is supported") self.cuda_processing = CudaProcessing(**(cuda_options or {})) super().__init__( shape, geo_pars, delta_beta, padded_shape=padded_shape, padding_mode=padding_mode, translation_vh=translation_vh, normalize_by_mean=normalize_by_mean, lim1=lim1, lim2=lim2, logger=logger, use_rfft=True, fft_num_threads=False, ) self._init_ctf_filter() self._init_cuda_padding() self._init_fft(fft_backend) self._init_mult_kernel() def _init_ctf_filter(self): self._mean_scale_factor = self.unreg_filter_denom[0, 0] * np.prod(self.shape_padded) self._d_filter_num = self.cuda_processing.to_device("_d_filter_num", self.unreg_filter_denom).astype("f") self._d_filter_denom = self.cuda_processing.to_device( "_d_filter_denom", (1.0 / (2 * self.unreg_filter_denom * self.unreg_filter_denom + self.lim)).astype("f") ) def _init_cuda_padding(self): pad_width = calc_padding_lengths(self.shape, self.shape_padded) # Custom coordinate transform to get directly FFT layout R, C = np.indices(self.shape, dtype=np.int32, sparse=True) coords_R = np.roll(np.pad(R.ravel(), pad_width[0], mode=self.padding_mode), -pad_width[0][0]) coords_C = np.roll(np.pad(C.ravel(), pad_width[1], mode=self.padding_mode), -pad_width[1][0]) self.cuda_padding = CudaPadding( self.shape, (coords_R, coords_C), mode=self.padding_mode, # propagate cuda options ? ) def _init_fft(self, fft_backend): fft_cls = get_fft_class(backend=fft_backend) self.cufft = fft_cls(shape=self.shape_padded, dtype=np.float32, r2c=True) self.d_radio_padded = self.cuda_processing.allocate_array("d_radio_padded", self.shape_padded, "f") self.d_radio_f = self.cuda_processing.allocate_array("d_radio_f", self.cufft.shape_out, np.complex64) def _init_mult_kernel(self): self.cpxmult_kernel = self.cuda_processing.kernel( "CTF_kernel", filename=get_cuda_srcfile("ElementOp.cu"), signature="PPPfii", ) Nx = np.int32(self.shape_padded[1] // 2 + 1) Ny = np.int32(self.shape_padded[0]) self._cpxmult_kernel_args = [ self.d_radio_f, self._d_filter_num, self._d_filter_denom, np.float32(self._mean_scale_factor), Nx, Ny, ] blk = (32, 32, 1) grd = (updiv(Nx, blk[0]), updiv(Ny, blk[1])) self._cpxmult_kernel_kwargs = {"grid": grd, "block": blk} set_input = CudaPaganinPhaseRetrieval.set_input def retrieve_phase(self, image, output=None): """ Perform padding on an image. Please see the documentation of CTFPhaseRetrieval.retrieve_phase(). """ self.set_input(image) self.cuda_padding.pad(image, output=self.d_radio_padded) if self.normalize_by_mean: m = garray.sum(self.d_radio_padded).get() / np.prod(self.shape_padded) # pylint: disable=E0606 self.d_radio_padded /= m self.cufft.fft(self.d_radio_padded, output=self.d_radio_f) self.cpxmult_kernel(*self._cpxmult_kernel_args, **self._cpxmult_kernel_kwargs) self.cufft.ifft(self.d_radio_f, output=self.d_radio_padded) if output is None: output = self.cuda_processing.allocate_array("d_output", self.shape) output[:, :] = self.d_radio_padded[: self.shape[0], : self.shape[1]] return output ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/preproc/distortion.py0000644000175000017500000000635514402565210017444 0ustar00pierrepierreimport numpy as np from scipy.interpolate import RegularGridInterpolator from ..utils import check_supported from ..estimation.distortion import estimate_flat_distortion def correct_distortion_interpn(image, coords, bounds_error=False, fill_value=None): """ Correct image distortion with scipy.interpolate.interpn. 
Parameters ---------- image: array Distorted image coords: array Coordinates of the distortion correction to apply, with the shape (Ny, Nx, 2) """ foo = RegularGridInterpolator( (np.arange(image.shape[0]), np.arange(image.shape[1])), image, bounds_error=bounds_error, method="linear", fill_value=fill_value, ) return foo(coords) class DistortionCorrection: """ A class for estimating and correcting image distortion. """ estimation_methods = { "fft-correlation": estimate_flat_distortion, } correction_methods = { "interpn": correct_distortion_interpn, } def __init__( self, estimation_method="fft-correlation", estimation_kwargs=None, correction_method="interpn", correction_kwargs=None, ): """ Initialize a DistortionCorrection object. Parameters ----------- estimation_method: str Name of the method to use for estimating the distortion estimation_kwargs: dict, optional Named arguments to pass to the estimation method, in the form of a dictionary. correction_method: str Name of the method to use for correcting the distortion correction_kwargs: dict, optional Named arguments to pass to the correction method, in the form of a dictionary. """ self._set_estimator(estimation_method, estimation_kwargs) self._set_corrector(correction_method, correction_kwargs) def _set_estimator(self, estimation_method, estimation_kwargs): check_supported(estimation_method, self.estimation_methods.keys(), "estimation method") self.estimator = self.estimation_methods[estimation_method] self._estimator_kwargs = estimation_kwargs or {} def _set_corrector(self, correction_method, correction_kwargs): check_supported(correction_method, self.correction_methods.keys(), "correction method") self.corrector = self.correction_methods[correction_method] self._corrector_kwargs = correction_kwargs or {} def estimate_distortion(self, image, reference_image): return self.estimator(image, reference_image, **self._estimator_kwargs) estimate = estimate_distortion def correct_distortion(self, image, coords): image_corrected = self.corrector(image, coords, **self._corrector_kwargs) fill_value = self._corrector_kwargs.get("fill_value", None) if fill_value is not None and np.isnan(fill_value): mask = np.isnan(image_corrected) image_corrected[mask] = image[mask] return image_corrected correct = correct_distortion def estimate_and_correct(self, image, reference_image): coords = self.estimate_distortion(image, reference_image) image_corrected = self.correct_distortion(image, coords) return image_corrected ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/preproc/double_flatfield.py0000644000175000017500000001732714654107202020535 0ustar00pierrepierrefrom os import path import numpy as np from scipy.ndimage import gaussian_filter from silx.io.url import DataUrl from ..utils import check_shape, get_2D_3D_shape from ..io.reader import HDF5Reader from ..io.writer import NXProcessWriter from .ccd import Log class DoubleFlatField: _default_h5_path = "/entry/double_flatfield/results" _small = 1e-7 def __init__( self, shape, result_url=None, sub_region=None, detector_corrector=None, input_is_mlog=True, output_is_mlog=False, average_is_on_log=False, sigma_filter=None, filter_mode="reflect", log_clip_min=None, log_clip_max=None, ): """ Init double flat field by summing a series of urls and considering the same subregion of them. 
Parameters ---------- shape: tuple Expected shape of radios chunk to process result_url: url, optional where the double-flatfield is stored after being computed, and possibly read (instead of re-computed) before processing the images. sub_region: tuple, optional If provided, this must be a tuple in the form (start_x, end_x, start_y, end_y). Each image will be cropped to this region. This is used to specify a chunk of files. Each of the parameters can be None; in this case the default start and end are taken in each dimension. input_is_mlog: boolean, default True the input is considered as the minus logarithm of normalised radios output_is_mlog: boolean, default False the output is considered as the minus logarithm of normalised radios average_is_on_log: boolean, default False if True, the averaging is performed on the minus logarithm of the data log_clip_min: float, optional lower clipping value applied to the data prior to taking the logarithm log_clip_max: float, optional upper clipping value applied to the data prior to taking the logarithm sigma_filter: optional if given, a high-pass filter is applied: signal - gaussian_filter(signal, sigma, filter_mode) filter_mode: optional, default 'reflect' the padding scheme applied at the borders (same as scipy.ndimage.gaussian_filter) """ self.radios_shape = get_2D_3D_shape(shape) self.n_angles = self.radios_shape[0] self.shape = self.radios_shape[1:] self._log_clip_min = log_clip_min self._log_clip_max = log_clip_max self._init_filedump(result_url, sub_region, detector_corrector) self._init_processing(input_is_mlog, output_is_mlog, average_is_on_log, sigma_filter, filter_mode) self._computed = False def _load_dff_dump(self): res = self.reader.get_data(self.result_url) if self.detector_corrector is not None: if res.ndim == 2: res = self.detector_corrector.transform(res) else: for i in range(res.shape[0]): res[i] = self.detector_corrector.transform(res[i]) if res.ndim == 3 and res.shape[0] == 1: res = res.reshape(res.shape[1], res.shape[2]) if res.shape != self.shape: raise ValueError( "Data in %s has shape %s, but expected %s" % (self.result_url.file_path(), str(res.shape), str(self.shape)) ) return res def _init_filedump(self, result_url, sub_region, detector_corrector=None): if isinstance(result_url, str): result_url = DataUrl(file_path=result_url, data_path=self._default_h5_path) self.sub_region = sub_region self.detector_corrector = detector_corrector self.result_url = result_url self.writer = None self.reader = None if self.result_url is None: return if path.exists(result_url.file_path()): if detector_corrector is None: adapted_subregion = sub_region else: adapted_subregion = self.detector_corrector.get_adapted_subregion(sub_region) self.reader = HDF5Reader(sub_region=adapted_subregion) else: self.writer = NXProcessWriter(self.result_url.file_path()) def _init_processing(self, input_is_mlog, output_is_mlog, average_is_on_log, sigma_filter, filter_mode): self.input_is_mlog = input_is_mlog self.output_is_mlog = output_is_mlog self.average_is_on_log = average_is_on_log self.sigma_filter = sigma_filter if self.sigma_filter is not None and abs(float(self.sigma_filter)) < 1e-4: self.sigma_filter = None self.filter_mode = filter_mode proc = lambda x, o: np.copyto(o, x) self._mlog = Log((1,) + self.shape, clip_min=self._log_clip_min, clip_max=self._log_clip_max) if self.input_is_mlog: if not self.average_is_on_log: proc = lambda x, o: np.exp(-x, out=o) else: if self.average_is_on_log: proc = self._proc_mlog postproc = lambda x: x if self.output_is_mlog: if not self.average_is_on_log: postproc = self._proc_mlog else: if self.average_is_on_log: postproc = lambda x: np.exp(-x) self.proc = proc self.postproc = postproc def
_proc_mlog(self, x, o): o[:] = x[:] self._mlog.take_logarithm(o) return o def compute_double_flatfield(self, radios, recompute=False): """ Read the radios and generate the "double flat field" by averaging and possibly other processing. Parameters ---------- radios: array Input radios chunk. recompute: bool, optional Whether to recompute the double flatfield if already computed. """ if self._computed and not (recompute): return self.doubleflatfield # pylint: disable=E0203 acc = np.zeros(radios[0].shape, "f") tmpdat = np.zeros(radios[0].shape, "f") for ima in radios: self.proc(ima, tmpdat) acc += tmpdat acc /= radios.shape[0] if self.sigma_filter is not None: acc = acc - gaussian_filter(acc, self.sigma_filter, mode=self.filter_mode) self.doubleflatfield = self.postproc(acc) # Handle small values to avoid issues when dividing self.doubleflatfield[np.abs(self.doubleflatfield) < self._small] = 1.0 self.doubleflatfield = self.doubleflatfield.astype("f") if self.writer is not None: self.writer.write(self.doubleflatfield, "double_flatfield") self._computed = True return self.doubleflatfield def get_double_flatfield(self, radios=None, compute=False): """ Get the double flat field or a subregion of it. Parameters ---------- radios: array, optional Input radios chunk compute: bool, optional Whether to compute the double flatfield anyway even if a dump file exists. """ if self.reader is None: if radios is None: raise ValueError("result_url was not provided. Please provide 'radios' to this function") return self.compute_double_flatfield(radios) if radios is not None and compute: res = self.compute_double_flatfield(radios) else: res = self._load_dff_dump() self._computed = True return res def apply_double_flatfield(self, radios): """ Apply the "double flatfield" filter on a chunk of radios. The processing is done in-place ! """ check_shape(radios.shape, self.radios_shape, "radios") dff = self.get_double_flatfield(radios=radios) for i in range(self.n_angles): radios[i] /= dff return radios ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/preproc/double_flatfield_cuda.py0000644000175000017500000001407314654107202021524 0ustar00pierrepierrefrom .double_flatfield import DoubleFlatField from ..utils import check_shape from ..cuda.utils import __has_pycuda__ from ..cuda.processing import CudaProcessing from ..processing.unsharp_cuda import CudaUnsharpMask from .ccd_cuda import CudaLog if __has_pycuda__: import pycuda.gpuarray as garray import pycuda.cumath as cumath class CudaDoubleFlatField(DoubleFlatField): def __init__( self, shape, result_url=None, sub_region=None, detector_corrector=None, input_is_mlog=True, output_is_mlog=False, average_is_on_log=False, sigma_filter=None, filter_mode="reflect", log_clip_min=None, log_clip_max=None, cuda_options=None, ): """ Init double flat field with Cuda backend. 
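The parameters are the same as for nabu.preproc.double_flatfield.DoubleFlatField.

Examples
--------
A rough sketch, assuming a CUDA device and radios already transferred to the GPU as a pycuda GPUArray (values are synthetic):

>>> import numpy as np
>>> import pycuda.gpuarray as garray
>>> radios = np.random.uniform(0.5, 1.0, size=(100, 64, 64)).astype(np.float32)
>>> d_radios = garray.to_gpu(radios)
>>> dff = CudaDoubleFlatField(d_radios.shape)
>>> d_radios = dff.apply_double_flatfield(d_radios)  # divide each radio by the averaged double flat field, in-place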
""" self.cuda_processing = CudaProcessing(**(cuda_options or {})) super().__init__( shape, result_url=result_url, sub_region=sub_region, detector_corrector=detector_corrector, input_is_mlog=input_is_mlog, output_is_mlog=output_is_mlog, average_is_on_log=average_is_on_log, sigma_filter=sigma_filter, filter_mode=filter_mode, log_clip_min=log_clip_min, log_clip_max=log_clip_max, ) self._init_gaussian_filter() def _init_gaussian_filter(self): if self.sigma_filter is None: return self._unsharp_mask = CudaUnsharpMask(self.shape, self.sigma_filter, -1.0, mode=self.filter_mode, method="log") @staticmethod def _proc_copy(x, o): o[:] = x[:] return o @staticmethod def _proc_expm(x, o): o[:] = x[:] o[:] *= -1 cumath.exp(o, out=o) # pylint: disable=E0606 return o def _init_processing(self, input_is_mlog, output_is_mlog, average_is_on_log, sigma_filter, filter_mode): self.input_is_mlog = input_is_mlog self.output_is_mlog = output_is_mlog self.average_is_on_log = average_is_on_log self.sigma_filter = sigma_filter if self.sigma_filter is not None and abs(float(self.sigma_filter)) < 1e-4: self.sigma_filter = None self.filter_mode = filter_mode # proc = lambda x,o: np.copyto(o, x) proc = self._proc_copy self._mlog = CudaLog((1,) + self.shape, clip_min=self._log_clip_min, clip_max=self._log_clip_max) if self.input_is_mlog: if not self.average_is_on_log: # proc = lambda x,o: np.exp(-x, out=o) proc = self._proc_expm else: if self.average_is_on_log: # proc = lambda x,o: -np.log(x, out=o) proc = self._proc_mlog # postproc = lambda x: x postproc = self._proc_copy if self.output_is_mlog: if not self.average_is_on_log: # postproc = lambda x: -np.log(x) postproc = self._proc_mlog else: if self.average_is_on_log: # postproc = lambda x: np.exp(-x) postproc = self._proc_expm self.proc = proc self.postproc = postproc def compute_double_flatfield(self, radios, recompute=False): """ Read the radios and generate the "double flat field" by averaging and possibly other processing. Parameters ---------- radios: array Input radios chunk. recompute: bool, optional Whether to recompute the double flatfield if already computed. """ if not (isinstance(radios, garray.GPUArray)): # pylint: disable=E0606 raise ValueError("Expected pycuda.gpuarray.GPUArray for radios") if self._computed and not (recompute): return self.doubleflatfield acc = garray.zeros(radios[0].shape, "f") tmpdat = garray.zeros(radios[0].shape, "f") for i in range(radios.shape[0]): self.proc(radios[i], tmpdat) acc += tmpdat acc /= radios.shape[0] if self.sigma_filter is not None: # acc = acc - gaussian_filter(acc, self.sigma_filter, mode=self.filter_mode) self._unsharp_mask.unsharp(acc, tmpdat) acc[:] = tmpdat[:] self.postproc(acc, tmpdat) self.doubleflatfield = tmpdat # Handle small values to avoid issues when dividing # self.doubleflatfield[np.abs(self.doubleflatfield) < self._small] = 1. cumath.fabs(self.doubleflatfield, out=acc) acc -= self._small # acc = abs(doubleflatfield) - _small garray.if_positive(acc, self.doubleflatfield, garray.zeros_like(acc) + self._small, out=self.doubleflatfield) if self.writer is not None: self.writer.write(self.doubleflatfield.get(), "double_flatfield") self._computed = True return self.doubleflatfield def get_double_flatfield(self, radios=None, compute=False): """ Get the double flat field or a subregion of it. Parameters ---------- radios: array, optional Input radios chunk compute: bool, optional Whether to compute the double flatfield anyway even if a dump file exists. 
""" if self.reader is None: if radios is None: raise ValueError("result_url was not provided. Please provide 'radios' to this function") return self.compute_double_flatfield(radios) if radios is not None and compute: res = self.compute_double_flatfield(radios) else: res = self._load_dff_dump() res = garray.to_gpu(res) self._computed = True return res def apply_double_flatfield(self, radios): """ Apply the "double flatfield" filter on a chunk of radios. The processing is done in-place ! """ check_shape(radios.shape, self.radios_shape, "radios") dff = self.get_double_flatfield(radios=radios) for i in range(self.n_angles): radios[i] /= dff return radios ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/preproc/double_flatfield_variable_region.py0000644000175000017500000000434714402565210023741 0ustar00pierrepierrefrom .double_flatfield import ( DoubleFlatField, DoubleFlatField, check_shape, get_2D_3D_shape, ) from ..misc.binning import get_binning_function class DoubleFlatFieldVariableRegion(DoubleFlatField): def __init__( self, shape, result_url=None, binning_x=None, binning_z=None, detector_corrector=None, ): """This class provides the division by the double flat field. At variance with the standard class, it store as member the whole field, and performs the division by the proper region according to the positionings of the processed radios which is passed by the argument array sub_regions_per_radio to the method apply_double_flatfield_for_sub_regions """ self.radios_shape = get_2D_3D_shape(shape) self.n_angles = self.radios_shape[0] self.shape = self.radios_shape[1:] self._init_filedump(result_url, None, detector_corrector) data = self._load_dff_full_dump() if (binning_z, binning_x) != (1, 1): print(" (binning_z, binning_x) ", (binning_z, binning_x)) binning_function = get_binning_function((binning_z, binning_x)) if binning_function is None: raise NotImplementedError(f"Binning factor for {(binning_z, binning_x)} is not implemented yet") self.data = binning_function(data) else: self.data = data def _load_dff_full_dump(self): res = self.reader.get_data(self.result_url) if self.detector_corrector is not None: self.detector_corrector.set_full_transformation() res = self.detector_corrector.transform(res, do_full=True) return res def apply_double_flatfield_for_sub_regions(self, radios, sub_regions_per_radio): """ Apply the "double flatfield" filter on a chunk of radios. The processing is done in-place ! 
""" my_double_ff = self.data for i in range(radios.shape[0]): s_x, e_x, s_y, e_y = sub_regions_per_radio[i] dff = my_double_ff[s_y:e_y, s_x:e_x] check_shape(radios[i].shape, dff.shape, "radios") radios[i] /= dff return radios ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1723556968.0 nabu-2024.2.1/nabu/preproc/flatfield.py0000644000175000017500000004567114656662150017220 0ustar00pierrepierrefrom multiprocessing.pool import ThreadPool from bisect import bisect_left import numpy as np from ..io.reader import load_images_from_dataurl_dict from ..utils import check_supported, deprecated_class, get_num_threads class FlatFieldArrays: """ A class for flat-field normalization """ # the variable below will be True for the derived class # which is taylored for to helical case _full_shape = False _supported_interpolations = ["linear", "nearest"] def __init__( self, radios_shape: tuple, flats, darks, radios_indices=None, interpolation: str = "linear", distortion_correction=None, nan_value=1.0, radios_srcurrent=None, flats_srcurrent=None, n_threads=None, ): """ Initialize a flat-field normalization process. Parameters ---------- radios_shape: tuple A tuple describing the shape of the radios stack, in the form `(n_radios, n_z, n_x)`. flats: dict Dictionary where each key is the flat index, and the value is a numpy.ndarray of the flat image. darks: dict Dictionary where each key is the dark index, and the value is a numpy.ndarray of the dark image. radios_indices: array of int, optional Array containing the radios indices in the scan. `radios_indices[0]` is the index of the first radio, and so on. interpolation: str, optional Interpolation method for flat-field. See below for more details. distortion_correction: DistortionCorrection, optional A DistortionCorrection object. If provided, it is used to correct flat distortions based on each radio. nan_value: float, optional Which float value is used to replace nan/inf after flat-field. radios_srcurrent: array, optional Array with the same shape as radios_indices. Each item contains the synchrotron electric current. If not None, normalization with current is applied. Please refer to "Notes" for more information on this normalization. flats_srcurrent: array, optional Array with the same length as "flats". Each item is a measurement of the synchrotron electric current for the corresponding flat. The items must be ordered in the same order as the flats indices (`flats.keys()`). This parameter must be used along with 'radios_srcurrent'. Please refer to "Notes" for more information on this normalization. n_threads: int or None, optional Number of threads to use for flat-field correction. Default is to use half the threads. Important ---------- `flats` and `darks` are expected to be a dictionary with integer keys (the flats/darks indices) and numpy array values. You can use the following helper functions: `nabu.io.reader.load_images_from_dataurl_dict` and `nabu.io.utils.create_dict_of_indices` Notes ------ Usually, when doing a scan, only one or a few darks/flats are acquired. However, the flat-field normalization has to be performed on each radio, although incoming beam can fluctuate between projections. The usual way to overcome this is to interpolate between flats. If interpolation="nearest", the first flat is used for the first radios subset, the second flat is used for the second radios subset, and so on. If interpolation="linear", the normalization is done as a linear function of the radio index. 
The normalization with synchrotron electric current is done as follows. Let s = sr/sr_max denote the ratio between current and maximum current, D be the dark-current frame, and X' be the normalized frame. Then: srcurrent_normalization(X) = X' = (X - D)/s + D flatfield_normalization(X') = (X' - D)/(F' - D) = (X - D) / (F - D) * sF/sX So current normalization boils down to a scalar multiplication after flat-field. """ if self._full_shape: # this is never going to happen in this base class. But in the derived class for helical # which needs to keep the full shape if radios_indices is not None: radios_shape = (len(radios_indices),) + radios_shape[1:] self._set_parameters(radios_shape, radios_indices, interpolation, nan_value) self._set_flats_and_darks(flats, darks) self._precompute_flats_indices_weights() self._configure_srcurrent_normalization(radios_srcurrent, flats_srcurrent) self.distortion_correction = distortion_correction self.n_threads = min(1, get_num_threads(n_threads) // 2) def _set_parameters(self, radios_shape, radios_indices, interpolation, nan_value): self._set_radios_shape(radios_shape) if radios_indices is None: radios_indices = np.arange(0, self.n_radios, dtype=np.int32) else: radios_indices = np.array(radios_indices, dtype=np.int32) self._check_radios_and_indices_congruence(radios_indices) self.radios_indices = radios_indices self.interpolation = interpolation check_supported(interpolation, self._supported_interpolations, "Interpolation mode") self.nan_value = nan_value self._radios_idx_to_pos = dict(zip(self.radios_indices, np.arange(self.radios_indices.size))) def _set_radios_shape(self, radios_shape): if len(radios_shape) == 2: self.radios_shape = (1,) + radios_shape elif len(radios_shape) == 3: self.radios_shape = radios_shape else: raise ValueError("Expected radios to have 2 or 3 dimensions") n_radios, n_z, n_x = self.radios_shape self.n_radios = n_radios self.n_angles = n_radios self.shape = (n_z, n_x) def _set_flats_and_darks(self, flats, darks): self._check_frames(flats, "flats", 1, 9999) self.n_flats = len(flats) self.flats = flats self._sorted_flat_indices = sorted(self.flats.keys()) if self._full_shape: # this is never going to happen in this base class. 
But in the derived class for helical # which needs to keep the full shape self.shape = flats[self._sorted_flat_indices[0]].shape self._flat2arrayidx = dict(zip(self._sorted_flat_indices, np.arange(self.n_flats))) self.flats_arr = np.zeros((self.n_flats,) + self.shape, "f") for i, idx in enumerate(self._sorted_flat_indices): self.flats_arr[i] = self.flats[idx] self._check_frames(darks, "darks", 1, 1) self.darks = darks self.n_darks = len(darks) self._sorted_dark_indices = sorted(self.darks.keys()) self._dark = None def _check_frames(self, frames, frames_type, min_frames_required, max_frames_supported): n_frames = len(frames) if n_frames < min_frames_required: raise ValueError("Need at least %d %s" % (min_frames_required, frames_type)) if n_frames > max_frames_supported: raise ValueError( "Flat-fielding with more than %d %s is not supported" % (max_frames_supported, frames_type) ) self._check_frame_shape(frames, frames_type) def _check_frame_shape(self, frames, frames_type): for frame_idx, frame in frames.items(): if frame.shape != self.shape: raise ValueError( "Invalid shape for %s %s: expected %s, but got %s" % (frames_type, frame_idx, str(self.shape), str(frame.shape)) ) def _check_radios_and_indices_congruence(self, radios_indices): if radios_indices.size != self.n_radios: raise ValueError( "Expected radios_indices to have length %s = n_radios, but got length %d" % (self.n_radios, radios_indices.size) ) def _precompute_flats_indices_weights(self): """ Build two arrays: "indices" and "weights". These arrays contain pre-computed information so that the interpolated flat is obtained with flat_interpolated = weight_prev * flat_prev + weight_next * flat_next where weight_prev, weight_next = weights[2*i], weights[2*i+1] idx_prev, idx_next = indices[2*i], indices[2*i+1] flat_prev, flat_next = flats[idx_prev], flats[idx_next] In words: - If a projection has an index between two flats, the equivalent flat is a linear interpolation between "previous flat" and "next flat". 
- If a projection has the same index as a flat, only this flat is used for normalization (this case normally never occurs, but it's handled in the code) """ def _interp_linear(idx, prev_next): if len(prev_next) == 1: # current index corresponds to an acquired flat weights = (1, 0) f_idx = (self._flat2arrayidx[prev_next[0]], -1) else: prev_idx, next_idx = prev_next delta = next_idx - prev_idx w1 = 1 - (idx - prev_idx) / delta w2 = 1 - (next_idx - idx) / delta weights = (w1, w2) f_idx = (self._flat2arrayidx[prev_idx], self._flat2arrayidx[next_idx]) return f_idx, weights def _interp_nearest(idx, prev_next): if len(prev_next) == 1: # current index corresponds to an acquired flat weights = (1, 0) f_idx = (self._flat2arrayidx[prev_next[0]], -1) else: prev_idx, next_idx = prev_next idx_to_take = prev_idx if abs(idx - prev_idx) < abs(idx - next_idx) else next_idx weights = (1, 0) f_idx = (self._flat2arrayidx[idx_to_take], -1) return f_idx, weights self.flats_idx = np.zeros((self.n_radios, 2), dtype=np.int32) self.flats_weights = np.zeros((self.n_radios, 2), dtype=np.float32) for i, idx in enumerate(self.radios_indices): prev_next = self.get_previous_next_indices(self._sorted_flat_indices, idx) if self.interpolation == "nearest": f_idx, weights = _interp_nearest(idx, prev_next) elif self.interpolation == "linear": f_idx, weights = _interp_linear(idx, prev_next) # pylint: disable=E0606 self.flats_idx[i] = f_idx self.flats_weights[i] = weights # pylint: disable=E1307 def _configure_srcurrent_normalization(self, radios_srcurrent, flats_srcurrent): self.normalize_srcurrent = False if radios_srcurrent is None or flats_srcurrent is None: return radios_srcurrent = np.array(radios_srcurrent) if radios_srcurrent.size != self.n_radios: raise ValueError( "Expected 'radios_srcurrent' to have %d elements but got %d" % (self.n_radios, radios_srcurrent.size) ) flats_srcurrent = np.array(flats_srcurrent) if flats_srcurrent.size != self.n_flats: raise ValueError( "Expected 'flats_srcurrent' to have %d elements but got %d" % (self.n_flats, flats_srcurrent.size) ) self.normalize_srcurrent = True self.radios_srcurrent = radios_srcurrent self.flats_srcurrent = flats_srcurrent self.srcurrent_ratios = np.zeros(self.n_radios, "f") # Flats SRCurrent is obtained with "nearest" interp, to emulate an already-done flats SR current normalization for i, radio_idx in enumerate(self.radios_indices): flat_idx = self.get_nearest_index(self._sorted_flat_indices, radio_idx) flat_srcurrent = self.flats_srcurrent[self._flat2arrayidx[flat_idx]] self.srcurrent_ratios[i] = flat_srcurrent / self.radios_srcurrent[i] @staticmethod def get_previous_next_indices(arr, idx): pos = bisect_left(arr, idx) if pos == len(arr): # outside range return (arr[-1],) if arr[pos] == idx: return (idx,) if pos == 0: return (arr[0],) return arr[pos - 1], arr[pos] @staticmethod def get_nearest_index(arr, idx): pos = bisect_left(arr, idx) if pos == len(arr) or arr[pos] == idx: return arr[-1] return arr[pos - 1] if idx - arr[pos - 1] < arr[pos] - idx else arr[pos] @staticmethod def interp(pos, indices, weights, array, slice_y=slice(None, None), slice_x=slice(None, None)): """ Interpolate between two values. 
The interpolator consists in pre-computed arrays such that prev, next = indices[pos] w1, w2 = weights[pos] interpolated_value = w1 * array[prev] + w2 * array[next] """ prev_idx = indices[pos, 0] next_idx = indices[pos, 1] if slice_y != slice(None, None) or slice_x != slice(None, None): w1 = weights[pos, 0][slice_y, slice_x] w2 = weights[pos, 1][slice_y, slice_x] else: w1 = weights[pos, 0] w2 = weights[pos, 1] if next_idx == -1: val = array[prev_idx] else: val = w1 * array[prev_idx] + w2 * array[next_idx] return val def get_flat(self, pos, dtype=np.float32, slice_y=slice(None, None), slice_x=slice(None, None)): flat = self.interp(pos, self.flats_idx, self.flats_weights, self.flats_arr, slice_y=slice_y, slice_x=slice_x) if flat.dtype != dtype: flat = np.ascontiguousarray(flat, dtype=dtype) return flat def get_dark(self): if self._dark is None: first_dark_idx = self._sorted_dark_indices[0] dark = np.ascontiguousarray(self.darks[first_dark_idx], dtype=np.float32) self._dark = dark return self._dark def remove_invalid_values(self, img): if self.nan_value is None: return invalid_mask = np.logical_not(np.isfinite(img)) img[invalid_mask] = self.nan_value def normalize_radios(self, radios): """ Apply a flat-field normalization, with the current parameters, to a stack of radios. The processing is done in-place, meaning that the radios content is overwritten. Parameters ----------- radios: numpy.ndarray Radios chunk """ do_flats_distortion_correction = self.distortion_correction is not None dark = self.get_dark() def apply_flatfield(i): radio_data = radios[i] radio_data -= dark flat = self.get_flat(i) flat = flat - dark if do_flats_distortion_correction: flat = self.distortion_correction.estimate_and_correct(flat, radio_data) np.divide(radio_data, flat, out=radio_data) self.remove_invalid_values(radio_data) if self.n_threads > 2: with ThreadPool(self.n_threads) as tp: tp.map(apply_flatfield, range(self.n_radios)) else: for i in range(self.n_radios): apply_flatfield(i) if self.normalize_srcurrent: radios *= self.srcurrent_ratios[:, np.newaxis, np.newaxis] return radios def normalize_single_radio( self, radio, radio_idx, dtype=np.float32, slice_y=slice(None, None), slice_x=slice(None, None) ): """ Apply a flat-field normalization to a single projection image. """ dark = self.get_dark()[slice_y, slice_x] radio -= dark radio_pos = self._radios_idx_to_pos[radio_idx] flat = self.get_flat(radio_pos, dtype=dtype, slice_y=slice_y, slice_x=slice_x) flat = flat - dark if self.distortion_correction is not None: flat = self.distortion_correction.estimate_and_correct(flat, radio) radio /= flat if self.normalize_srcurrent: radio *= self.srcurrent_ratios[radio_pos] self.remove_invalid_values(radio) return radio FlatField = FlatFieldArrays @deprecated_class( "FlatFieldDataUrls is deprecated since 2024.2.0 and will be removed in a future version", do_print=True ) class FlatFieldDataUrls(FlatField): def __init__( self, radios_shape: tuple, flats: dict, darks: dict, radios_indices=None, interpolation: str = "linear", distortion_correction=None, nan_value=1.0, radios_srcurrent=None, flats_srcurrent=None, **chunk_reader_kwargs, ): """ Initialize a flat-field normalization process with DataUrls. Parameters ---------- radios_shape: tuple A tuple describing the shape of the radios stack, in the form `(n_radios, n_z, n_x)`. flats: dict Dictionary where the key is the flat index, and the value is a silx.io.DataUrl pointing to the flat. 
darks: dict Dictionary where the key is the dark index, and the value is a silx.io.DataUrl pointing to the dark. radios_indices: array, optional Array containing the radios indices. `radios_indices[0]` is the index of the first radio, and so on. interpolation: str, optional Interpolation method for flat-field. See below for more details. distortion_correction: DistortionCorrection, optional A DistortionCorrection object. If provided, it is used to correct flat distortions based on each radio. nan_value: float, optional Which float value is used to replace nan/inf after flat-field. Other Parameters ---------------- The other named parameters are passed to ChunkReader(). Please read its documentation for more information. Notes ------ Usually, when doing a scan, only one or a few darks/flats are acquired. However, the flat-field normalization has to be performed on each radio, although incoming beam can fluctuate between projections. The usual way to overcome this is to interpolate between flats. If interpolation="nearest", the first flat is used for the first radios subset, the second flat is used for the second radios subset, and so on. If interpolation="linear", the normalization is done as a linear function of the radio index. """ flats_arrays_dict = load_images_from_dataurl_dict(flats, **chunk_reader_kwargs) darks_arrays_dict = load_images_from_dataurl_dict(darks, **chunk_reader_kwargs) super().__init__( radios_shape, flats_arrays_dict, darks_arrays_dict, radios_indices=radios_indices, interpolation=interpolation, distortion_correction=distortion_correction, nan_value=nan_value, radios_srcurrent=radios_srcurrent, flats_srcurrent=flats_srcurrent, ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1723556968.0 nabu-2024.2.1/nabu/preproc/flatfield_cuda.py0000644000175000017500000001262314656662150020203 0ustar00pierrepierreimport numpy as np from nabu.cuda.processing import CudaProcessing from ..preproc.flatfield import FlatFieldArrays from ..utils import deprecated_class, get_cuda_srcfile from ..io.reader import load_images_from_dataurl_dict from ..cuda.utils import __has_pycuda__ class CudaFlatFieldArrays(FlatFieldArrays): def __init__( self, radios_shape, flats, darks, radios_indices=None, interpolation="linear", distortion_correction=None, nan_value=1.0, radios_srcurrent=None, flats_srcurrent=None, cuda_options=None, ): """ Initialize a flat-field normalization CUDA process. Please read the documentation of nabu.preproc.flatfield.FlatField for help on the parameters. 
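Examples
--------
A rough sketch, assuming a CUDA device; flats and darks are plain numpy arrays while the radios live on the GPU (values are synthetic):

>>> import numpy as np
>>> import pycuda.gpuarray as garray
>>> n_radios, n_z, n_x = 100, 64, 64
>>> flats = {0: np.ones((n_z, n_x), "f"), 99: np.ones((n_z, n_x), "f")}
>>> darks = {0: np.zeros((n_z, n_x), "f")}
>>> cuda_ff = CudaFlatFieldArrays((n_radios, n_z, n_x), flats, darks)
>>> d_radios = garray.to_gpu(np.random.uniform(0.4, 0.9, size=(n_radios, n_z, n_x)).astype(np.float32))
>>> d_radios = cuda_ff.normalize_radios(d_radios)  # in-place flat-field normalization on the GPU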
""" # if distortion_correction is not None: raise NotImplementedError("Flats distortion correction is not implemented with the Cuda backend") # super().__init__( radios_shape, flats, darks, radios_indices=radios_indices, interpolation=interpolation, distortion_correction=distortion_correction, radios_srcurrent=radios_srcurrent, flats_srcurrent=flats_srcurrent, nan_value=nan_value, ) self.cuda_processing = CudaProcessing(**(cuda_options or {})) self._init_cuda_kernels() self._load_flats_and_darks_on_gpu() def _init_cuda_kernels(self): # TODO if self.interpolation != "linear": raise ValueError("Interpolation other than linar is not yet implemented in the cuda back-end") # self._cuda_fname = get_cuda_srcfile("flatfield.cu") options = [ "-DN_FLATS=%d" % self.n_flats, "-DN_DARKS=%d" % self.n_darks, ] if self.nan_value is not None: options.append("-DNAN_VALUE=%f" % self.nan_value) self.cuda_kernel = self.cuda_processing.kernel( "flatfield_normalization", self._cuda_fname, signature="PPPiiiPP", options=options ) self._nx = np.int32(self.shape[1]) self._ny = np.int32(self.shape[0]) def _load_flats_and_darks_on_gpu(self): # Flats self.d_flats = self.cuda_processing.allocate_array("d_flats", (self.n_flats,) + self.shape, np.float32) for i, flat_idx in enumerate(self._sorted_flat_indices): self.d_flats[i].set(np.ascontiguousarray(self.flats[flat_idx], dtype=np.float32)) # Darks self.d_darks = self.cuda_processing.allocate_array("d_darks", (self.n_darks,) + self.shape, np.float32) for i, dark_idx in enumerate(self._sorted_dark_indices): self.d_darks[i].set(np.ascontiguousarray(self.darks[dark_idx], dtype=np.float32)) self.d_darks_indices = self.cuda_processing.to_device( "d_darks_indices", np.array(self._sorted_dark_indices, dtype=np.int32) ) # Indices self.d_flats_indices = self.cuda_processing.to_device("d_flats_indices", self.flats_idx) self.d_flats_weights = self.cuda_processing.to_device("d_flats_weights", self.flats_weights) def normalize_radios(self, radios): """ Apply a flat-field correction, with the current parameters, to a stack of radios. Parameters ----------- radios_shape: `pycuda.gpuarray.GPUArray` Radios chunk. 
""" if not (isinstance(radios, self.cuda_processing.array_class)): raise ValueError("Expected a pycuda.gpuarray (got %s)" % str(type(radios))) if radios.dtype != np.float32: raise ValueError("radios must be in float32 dtype (got %s)" % str(radios.dtype)) if radios.shape != self.radios_shape: raise ValueError("Expected radios shape = %s but got %s" % (str(self.radios_shape), str(radios.shape))) self.cuda_kernel( radios, self.d_flats, self.d_darks, self._nx, self._ny, np.int32(self.n_radios), self.d_flats_indices, self.d_flats_weights, ) if self.normalize_srcurrent: for i in range(self.n_radios): radios[i] *= self.srcurrent_ratios[i] return radios CudaFlatField = CudaFlatFieldArrays @deprecated_class( "CudaFlatFieldDataUrls is deprecated since version 2024.2.0 and will be removed in a future version", do_print=True ) class CudaFlatFieldDataUrls(CudaFlatField): def __init__( self, radios_shape, flats, darks, radios_indices=None, interpolation="linear", distortion_correction=None, nan_value=1.0, radios_srcurrent=None, flats_srcurrent=None, cuda_options=None, **chunk_reader_kwargs, ): flats_arrays_dict = load_images_from_dataurl_dict(flats, **chunk_reader_kwargs) darks_arrays_dict = load_images_from_dataurl_dict(darks, **chunk_reader_kwargs) super().__init__( radios_shape, flats_arrays_dict, darks_arrays_dict, radios_indices=radios_indices, interpolation=interpolation, distortion_correction=distortion_correction, nan_value=nan_value, radios_srcurrent=radios_srcurrent, flats_srcurrent=flats_srcurrent, cuda_options=cuda_options, ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/preproc/flatfield_variable_region.py0000644000175000017500000000612314402565210022401 0ustar00pierrepierreimport numpy as np from .flatfield import FlatFieldArrays, load_images_from_dataurl_dict, check_supported class FlatFieldArraysVariableRegion(FlatFieldArrays): _full_shape = True def _check_frame_shape(self, frames, frames_type): # in helical the flat is the whole one and its shape does not necesseraly match the smaller frames. # Therefore no check is done to allow this. pass def _check_radios_and_indices_congruence(self, radios_indices): """At variance with parent class, preprocesing is done with on a fraction of the radios, whose length may vary. So we dont enforce here that the lenght is always the same """ pass def _normalize_radios(self, radios, sub_indexes, sub_regions_per_radio): """ Apply a flat-field normalization, with the current parameters, to a stack of radios. The processing is done in-place, meaning that the radios content is overwritten. 
""" if len(sub_regions_per_radio) != len(sub_indexes): message = f""" The length of sub_regions_per_radio,which is {len(sub_regions_per_radio)} , does not correspond to the length of sub_indexes which is {len(sub_indexes)} """ raise ValueError(message) do_flats_distortion_correction = self.distortion_correction is not None whole_dark = self.get_dark() for i, (idx, sub_r) in enumerate(zip(sub_indexes, sub_regions_per_radio)): start_x, end_x, start_y, end_y = sub_r slice_x = slice(start_x, end_x) slice_y = slice(start_y, end_y) self.normalize_single_radio(radios[i], idx, dtype=np.float32, slice_y=slice_y, slice_x=slice_x) return radios class FlatFieldDataVariableRegionUrls(FlatFieldArraysVariableRegion): def __init__( self, radios_shape: tuple, flats: dict, darks: dict, radios_indices=None, interpolation: str = "linear", distortion_correction=None, nan_value=1.0, radios_srcurrent=None, flats_srcurrent=None, **chunk_reader_kwargs, ): flats_arrays_dict = load_images_from_dataurl_dict(flats, **chunk_reader_kwargs) darks_arrays_dict = load_images_from_dataurl_dict(darks, **chunk_reader_kwargs) _flats_indexes = list(flats_arrays_dict.keys()) _flats_indexes.sort() self.flats_indexes = np.array(_flats_indexes) self.flats_stack = np.array([flats_arrays_dict[i] for i in self.flats_indexes], "f") flats_arrays_dict = dict([[indx, flat] for indx, flat in zip(self.flats_indexes, self.flats_stack)]) super().__init__( radios_shape, flats_arrays_dict, darks_arrays_dict, radios_indices=radios_indices, interpolation=interpolation, distortion_correction=distortion_correction, nan_value=nan_value, radios_srcurrent=radios_srcurrent, flats_srcurrent=flats_srcurrent, ) self._sorted_flat_indices = np.array(self._sorted_flat_indices, "i") ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/preproc/phase.py0000644000175000017500000003310314654107202016337 0ustar00pierrepierrefrom math import pi from bisect import bisect import numpy as np from scipy.fft import rfft2, irfft2, fft2, ifft2 from ..utils import generate_powers, get_decay, check_supported, get_num_threads, deprecation_warning # COMPAT. from .ctf import CTFPhaseRetrieval # def lmicron_to_db(Lmicron, energy, distance): """ Utility to convert the "Lmicron" parameter of PyHST to a value of delta/beta. Parameters ----------- Lmicron: float Length in microns, values of the parameter "PAGANIN_Lmicron" in PyHST2 parameter file. energy: float Energy in keV. distance: float Sample-detector distance in microns Notes -------- The conversion is done using the formula .. math:: L^2 = \\pi \\lambda D \\frac{\\delta}{\\beta} """ L2 = Lmicron**2 wavelength = 1.23984199e-3 / energy return L2 / (pi * wavelength * distance) class PaganinPhaseRetrieval: available_padding_modes = ["zeros", "mean", "edge", "symmetric", "reflect"] powers = generate_powers() def __init__( self, shape, distance=0.5, energy=20, delta_beta=250.0, pixel_size=1e-6, padding="edge", use_rfft=True, use_R2C=None, fftw_num_threads=None, fft_num_threads=None, ): """ Paganin Phase Retrieval for an infinitely distant point source. Formula (10) in [1]. Parameters ---------- shape: int or tuple Shape of each radio, in the format (num_rows, num_columns), i.e (size_vertical, size_horizontal). If an integer is provided, the shape is assumed to be square. distance : float, optional Propagation distance in meters. energy : float, optional Energy in keV. 
delta_beta: float, optional delta/beta ratio, where n = (1 - delta) + i*beta is the complex refractive index of the sample. pixel_size : float or tuple, optional Detector pixel size in meters. Default is 1e-6 (one micron). If a tuple is passed, the pixel size is set as (horizontal_size, vertical_size). padding : str, optional Padding method. Available are "zeros", "mean", "edge", "symmetric", "reflect". Default is "edge". Please refer to the "Padding" section below for more details. use_rfft: bool, optional Whether to use Real-to-Complex (R2C) transform instead of standard Complex-to-Complex transform, providing better performance. use_R2C: bool, optional DEPRECATED, use use_rfft instead fftw_num_threads: bool or None or int, optional DEPRECATED - please use fft_num_threads fft_num_threads: bool or None or int, optional Number of threads for FFT. Default is to use all available threads. You can pass a negative number to use N - fft_num_threads cores. Important ---------- Mind the units! Distance and pixel size are in meters, and energy is in keV. Notes ------ **Padding methods** The phase retrieval is a convolution done in Fourier domain using FFT, so the Fourier transform size has to be at least twice the size of the original data. Mathematically, the data should be padded with zeros before being Fourier transformed. However, in practice, this can lead to artefacts at the edges (Gibbs effect) if the data does not go to zero at the edges. Apart from applying an apodization (Hamming, Blackman, etc), a common strategy to avoid these artefacts is to pad the data. In tomography reconstruction, this is usually done by replicating the last value(s) of the edges; but one can think of other methods: - "zeros": the data is simply padded with zeros. - "mean": the upper side of extended data is padded with the mean of the first row, the lower side with the mean of the last row, etc. - "edge": the data is padded by replicating the edges. This is the default mode. - "symmetric": the data is padded by mirroring the data with respect to its edges. See ``numpy.pad()``. - "reflect": the data is padded by reflecting the data with respect to its edges, including the edges. See ``numpy.pad()``. **Formulas** The radio is divided, in the Fourier domain, by the original "Paganin filter" `[1]`. .. math:: F = 1 + \\frac{\\delta}{\\beta} \\lambda D \\pi |k|^2 where k is the wave vector. References ----------- [1] D. Paganin et al., "Simultaneous phase and amplitude extraction from a single defocused image of a homogeneous object", Journal of Microscopy, Vol 206, Part 1, 2002 """ self._init_parameters(distance, energy, pixel_size, delta_beta, padding) self._calc_shape(shape) # COMPAT. if use_R2C is not None: deprecation_warning("'use_R2C' is replaced with 'use_rfft'", func_name="pag_r2c") if fftw_num_threads is not None: deprecation_warning("'fftw_num_threads' is replaced with 'fft_num_threads'", func_name="pag_fftw") fft_num_threads = fftw_num_threads # --- self._get_fft(use_rfft, fft_num_threads) self.compute_filter() def _init_parameters(self, distance, energy, pixel_size, delta_beta, padding): self.distance_cm = distance * 1e2 self.distance_micron = distance * 1e6 self.energy_kev = energy if np.isscalar(pixel_size): self.pixel_size_xy_micron = (pixel_size * 1e6, pixel_size * 1e6) else: self.pixel_size_xy_micron = pixel_size * 1e6 # COMPAT.
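        # Backward-compatibility: also expose a scalar pixel size (first component of pixel_size_xy_micron).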
self.pixel_size_micron = self.pixel_size_xy_micron[0] # self.delta_beta = delta_beta self.wavelength_micron = 1.23984199e-3 / self.energy_kev self.padding = padding self.padding_methods = { "zeros": self._pad_zeros, "mean": self._pad_mean, "edge": self._pad_edge, "symmetric": self._pad_sym, "reflect": self._pad_reflect, } def _get_fft(self, use_rfft, fft_num_threads): self.use_rfft = use_rfft self.use_R2C = use_rfft # Compat. self.fft_num_threads = get_num_threads(fft_num_threads) if self.use_rfft: self.fft_func = rfft2 self.ifft_func = irfft2 else: self.fft_func = fft2 self.ifft_func = ifft2 def _calc_shape(self, shape): if np.isscalar(shape): shape = (shape, shape) else: assert len(shape) == 2 self.shape = shape self._calc_padded_shape() def _calc_padded_shape(self): """ Compute the padded shape. If margin = 0, length_padded = next_power(2*length). Otherwise : length_padded = next_power(2*(length - margins)) Principle ---------- <--------------------- nx_p ---------------------> | | original data | | < -- Pl - ><-- L -->< -- nx --><-- R --><-- Pr --> <----------- nx0 -----------> Pl, Pr : left/right padding length L, R : left/right margin nx : length of inner data (and length of final result) nx0 : length of original data nx_p : total length of padded data """ n_y, n_x = self.shape n_y_p = self._get_next_power(2 * n_y) n_x_p = self._get_next_power(2 * n_x) self.shape_padded = (n_y_p, n_x_p) self.data_padded = np.zeros((n_y_p, n_x_p), dtype=np.float64) self.pad_top_len = (n_y_p - n_y) // 2 self.pad_bottom_len = n_y_p - n_y - self.pad_top_len self.pad_left_len = (n_x_p - n_x) // 2 self.pad_right_len = n_x_p - n_x - self.pad_left_len def _get_next_power(self, n): """ Given a number, get the closest (upper) number p such that p is a power of 2, 3, 5 and 7. """ idx = bisect(self.powers, n) if self.powers[idx - 1] == n: return n return self.powers[idx] def compute_filter(self): nyp, nxp = self.shape_padded fftfreq = np.fft.rfftfreq if self.use_rfft else np.fft.fftfreq fy = np.fft.fftfreq(nyp, d=self.pixel_size_xy_micron[1]) fx = fftfreq(nxp, d=self.pixel_size_xy_micron[0]) self._coords_grid = np.add.outer(fy**2, fx**2) # k2 = self._coords_grid D = self.distance_micron L = self.wavelength_micron db = self.delta_beta self.paganin_filter = 1.0 / (1 + db * L * D * pi * k2) def pad_with_values(self, data, top_val=0, bottom_val=0, left_val=0, right_val=0): """ Pad the data into `self.padded_data` with values. Parameters ---------- data: numpy.ndarray data (radio) top_val: float or numpy.ndarray, optional Value(s) to fill the top of the padded data with. bottom_val: float or numpy.ndarray, optional Value(s) to fill the bottom of the padded data with. left_val: float or numpy.ndarray, optional Value(s) to fill the left of the padded data with. right_val: float or numpy.ndarray, optional Value(s) to fill the right of the padded data with. """ self.data_padded.fill(0) Pu, Pd = self.pad_top_len, self.pad_bottom_len Pl, Pr = self.pad_left_len, self.pad_right_len self.data_padded[:Pu, :] = top_val self.data_padded[-Pd:, :] = bottom_val self.data_padded[:, :Pl] = left_val self.data_padded[:, -Pr:] = right_val self.data_padded[Pu:-Pd, Pl:-Pr] = data # Transform the data to the FFT layout self.data_padded = np.roll(self.data_padded, (-Pu, -Pl), axis=(0, 1)) def _pad_zeros(self, data): return self.pad_with_values(data, top_val=0, bottom_val=0, left_val=0, right_val=0) def _pad_mean(self, data): """ Pad the data at each border with a different constant value. 
The value depends on the padding size: - On the left, value = mean(first data column) - On the right, value = mean(last data column) - On the top, value = mean(first data row) - On the bottom, value = mean(last data row) """ return self.pad_with_values( data, top_val=np.mean(data[0, :]), bottom_val=np.mean(data[-1, :]), left_val=np.mean(data[:, 0]), right_val=np.mean(data[:, -1]), ) def _pad_numpy(self, data, mode): data_padded = np.pad( data, ((self.pad_top_len, self.pad_bottom_len), (self.pad_left_len, self.pad_right_len)), mode=mode ) # Transform the data to the FFT layout Pu, Pl = self.pad_top_len, self.pad_left_len return np.roll(data_padded, (-Pu, -Pl), axis=(0, 1)) def _pad_edge(self, data): self.data_padded = self._pad_numpy(data, mode="edge") def _pad_sym(self, data): self.data_padded = self._pad_numpy(data, mode="symmetric") def _pad_reflect(self, data): self.data_padded = self._pad_numpy(data, mode="reflect") def pad_data(self, data, padding_method=None): padding_method = padding_method or self.padding check_supported(padding_method, self.available_padding_modes, "padding mode") if padding_method not in self.padding_methods: raise ValueError( "Unknown padding method %s. Available are: %s" % (padding_method, str(list(self.padding_methods.keys()))) ) pad_func = self.padding_methods[padding_method] pad_func(data) return self.data_padded def apply_filter(self, radio, padding_method=None, output=None): self.pad_data(radio, padding_method=padding_method) radio_f = self.fft_func(self.data_padded, workers=self.fft_num_threads) radio_f *= self.paganin_filter radio_filtered = self.ifft_func(radio_f, workers=self.fft_num_threads).real s0, s1 = self.shape if output is None: return radio_filtered[:s0, :s1] else: output[:, :] = radio_filtered[:s0, :s1] return output def lmicron_to_db(self, Lmicron): """ Utility to convert the "Lmicron" parameter of PyHST to a value of delta/beta. Please see the doc of nabu.preproc.phase.lmicron_to_db() """ return lmicron_to_db(Lmicron, self.energy_kev, self.distance_micron) __call__ = apply_filter retrieve_phase = apply_filter def compute_paganin_margin(shape, cutoff=1e3, **pag_kwargs): """ Compute the convolution margin to use when calling PaganinPhaseRetrieval class. 
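    The margin is estimated from the spatial extent of the Paganin filter once
    transformed back to real space: the vertical and horizontal margins correspond
    to the decay length of this convolution kernel along each axis, controlled by
    `cutoff`.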
Parameters ----------- shape: tuple Detector shape in the form (n_z, n_x) """ P = PaganinPhaseRetrieval(shape, **pag_kwargs) ifft_func = np.fft.irfft2 if P.use_rfft else np.fft.ifft2 conv_kernel = ifft_func(P.paganin_filter) vmax = conv_kernel[0, 0] v_margin = get_decay(conv_kernel[:, 0], cutoff=cutoff, vmax=vmax) h_margin = get_decay(conv_kernel[0, :], cutoff=cutoff, vmax=vmax) # If the Paganin filter is very narrow, then the corresponding convolution # kernel is constant, and np.argmax() gives 0 (when it should give the max value) if v_margin == 0: v_margin = shape[0] if h_margin == 0: h_margin = shape[1] return v_margin, h_margin ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1730906677.0 nabu-2024.2.1/nabu/preproc/phase_cuda.py0000644000175000017500000001150114712705065017337 0ustar00pierrepierreimport numpy as np import pycuda.driver as cuda from ..utils import get_cuda_srcfile, check_supported, docstring from ..cuda.processing import CudaProcessing from ..processing.fft_cuda import get_fft_class from .phase import PaganinPhaseRetrieval class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval): supported_paddings = ["zeros", "constant", "edge"] @docstring(PaganinPhaseRetrieval) def __init__( self, shape, distance=0.5, energy=20, delta_beta=250.0, pixel_size=1e-6, padding="edge", cuda_options=None, fftw_num_threads=None, # COMPAT. fft_num_threads=None, fft_backend="vkfft", ): """ Please refer to the documentation of nabu.preproc.phase.PaganinPhaseRetrieval """ padding = self._check_padding(padding) self.cuda_processing = CudaProcessing(**(cuda_options or {})) super().__init__( shape, distance=distance, energy=energy, delta_beta=delta_beta, pixel_size=pixel_size, padding=padding, use_rfft=True, fft_num_threads=False, ) self._init_gpu_arrays() self._init_fft(fft_backend) self._init_padding_kernel() self._init_mult_kernel() def _check_padding(self, padding): check_supported(padding, self.supported_paddings, "padding") if padding == "zeros": padding = "constant" return padding def _init_gpu_arrays(self): self.d_paganin_filter = self.cuda_processing.to_device( "d_paganin_filter", np.ascontiguousarray(self.paganin_filter, dtype=np.float32) ) # overwrite parent method, don't initialize any FFT plan def _get_fft(self, use_rfft, fft_num_threads): self.use_rfft = use_rfft def _init_fft(self, fft_backend): fft_cls = get_fft_class(backend=fft_backend) self.cufft = fft_cls(shape=self.data_padded.shape, dtype=np.float32, r2c=True) self.d_radio_padded = self.cuda_processing.allocate_array("d_radio_padded", self.cufft.shape, "f") self.d_radio_f = self.cuda_processing.allocate_array("d_radio_f", self.cufft.shape_out, np.complex64) def _init_padding_kernel(self): kern_signature = {"constant": "Piiiiiiiiffff", "edge": "Piiiiiiii"} self.padding_kernel = self.cuda_processing.kernel( "padding_%s" % self.padding, filename=get_cuda_srcfile("padding.cu"), signature=kern_signature[self.padding], ) Ny, Nx = self.shape Nyp, Nxp = self.shape_padded self.padding_kernel_args = [ self.d_radio_padded, Nx, Ny, Nxp, Nyp, self.pad_left_len, self.pad_right_len, self.pad_top_len, self.pad_bottom_len, ] # TODO configurable constant values if self.padding == "constant": self.padding_kernel_args.extend([0, 0, 0, 0]) def _init_mult_kernel(self): self.cpxmult_kernel = self.cuda_processing.kernel( "inplace_complexreal_mul_2Dby2D", filename=get_cuda_srcfile("ElementOp.cu"), signature="PPii", ) self.cpxmult_kernel_args = [ self.d_radio_f, self.d_paganin_filter, self.shape_padded[1] // 2 + 1, 
self.shape_padded[0], ] def set_input(self, data): assert data.shape == self.shape assert data.dtype == np.float32 # Rectangular memcopy # TODO profile, and if needed include this copy in the padding kernel if isinstance(data, np.ndarray) or isinstance(data, self.cuda_processing.array_class): self.d_radio_padded[: self.shape[0], : self.shape[1]] = data[:, :] elif isinstance(data, cuda.DeviceAllocation): # TODO manual memcpy2D raise NotImplementedError("pycuda buffers are not supported yet") else: raise ValueError("Expected either numpy array, pycuda array or pycuda buffer") def get_output(self, output): s0, s1 = self.shape if output is None: # copy D2H return self.d_radio_padded[:s0, :s1].get() assert output.shape == self.shape assert output.dtype == np.float32 output[:, :] = self.d_radio_padded[:s0, :s1] return output def apply_filter(self, radio, output=None): self.set_input(radio) self.padding_kernel(*self.padding_kernel_args) self.cufft.fft(self.d_radio_padded, output=self.d_radio_f) self.cpxmult_kernel(*self.cpxmult_kernel_args) self.cufft.ifft(self.d_radio_f, output=self.d_radio_padded) return self.get_output(output) __call__ = apply_filter retrieve_phase = apply_filter ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1730363900.0 nabu-2024.2.1/nabu/preproc/shift.py0000644000175000017500000000654014710640774016372 0ustar00pierrepierrefrom math import floor import numpy as np class VerticalShift: def __init__(self, radios_shape, shifts): """ This class is used when a vertical translation (along the tomography rotation axis) occurred. These translations are meant "per projection" and can be due either to mechanical errors, or can be applied purposefully with known motor movements to smear rings artefacts. The object is initialised with an array of shifts: one shift for each projection. A positive shifts means that the axis has moved in the positive Z direction. The interpolation is done taking for a pixel (y,x) the pixel found at (y+shft,x) in the recorded images. The method apply_vertical_shifts performs the correctionson the radios. Parameters ---------- radios_shape: tuple Shape of the radios chunk, in the form (n_radios, n_y, n_x) shifts: sequence of floats one shift for each projection Notes ------ During the acquisition, there might be other translations, each of them orthogonal to the rotation axis. - A "horizontal" translation in the detector plane: this is handled directly in the Backprojection operation. - A translation along the beam direction: this one is of no concern for parallel-beam geometry """ self.radios_shape = radios_shape self.shifts = shifts self._init_interp_coefficients() def _init_interp_coefficients(self): self.interp_infos = [] for s in self.shifts: s0 = int(floor(s)) f = s - s0 self.interp_infos.append([s0, f]) def _check(self, radios, iangles): assert np.min(iangles) >= 0 assert np.max(iangles) < len(self.interp_infos) assert len(iangles) == radios.shape[0] def apply_vertical_shifts(self, radios, iangles, output=None): """ Parameters ---------- radios: a sequence of np.array The input radios. If the optional parameter is not given, they are modified in-place iangles: a sequence of integers Must have the same lenght as radios. It contains the index at which the shift is found in `self.shifts` given by `shifts` argument in the initialisation of the object. output: a sequence of np.array, optional If given, it will be modified to contain the shifted radios. Must be of the same shape of `radios`. 
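        Examples
        --------
        A short usage sketch, assuming `radios` is a `(n_radios, n_y, n_x)` float32
        array and `shifts` gives one vertical shift (in pixels) per projection, as
        in the unit tests::

            shifter = VerticalShift(radios.shape, shifts)
            shifter.apply_vertical_shifts(radios, range(radios.shape[0]))  # in-place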
""" self._check(radios, iangles) newradio = np.zeros_like(radios[0]) for radio, ia in zip(radios, iangles): newradio[:] = 0 S0, f = self.interp_infos[ia] s0 = S0 if s0 > 0: newradio[:-s0] = radio[s0:] * (1 - f) elif s0 == 0: newradio[:] = radio[s0:] * (1 - f) else: newradio[-s0:] = radio[:s0] * (1 - f) s0 = S0 + 1 if s0 > 0: newradio[:-s0] += radio[s0:] * f elif s0 == 0: newradio[:] += radio[s0:] * f else: newradio[-s0:] += radio[:s0] * f if output is None: radios[ia] = newradio else: output[ia] = newradio ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1706619687.0 nabu-2024.2.1/nabu/preproc/shift_cuda.py0000644000175000017500000001010514556171447017362 0ustar00pierrepierreimport numpy as np from ..cuda.utils import __has_pycuda__ from ..cuda.processing import CudaProcessing from ..processing.muladd_cuda import CudaMulAdd from .shift import VerticalShift class CudaVerticalShift(VerticalShift): def __init__(self, radios_shape, shifts, **cuda_options): """ Vertical Shifter, Cuda backend. """ super().__init__(radios_shape, shifts) self.cuda_processing = CudaProcessing(**(cuda_options or {})) self._init_cuda_arrays() def _init_cuda_arrays(self): interp_infos_arr = np.zeros((len(self.interp_infos), 2), "f") self._d_interp_infos = self.cuda_processing.to_device("_d_interp_infos", interp_infos_arr) self._d_radio_new = self.cuda_processing.allocate_array("_d_radio_new", self.radios_shape[1:], "f") self._d_radio = self.cuda_processing.allocate_array("_d_radio", self.radios_shape[1:], "f") self.muladd_kernel = CudaMulAdd(ctx=self.cuda_processing.ctx) def apply_vertical_shifts(self, radios, iangles, output=None): """ Parameters ---------- radios: 3D pycuda.gpuarray.GPUArray The input radios. If the optional parameter is not given, they are modified in-place iangles: a sequence of integers Must have the same lenght as radios. It contains the index at which the shift is found in `self.shifts` given by `shifts` argument in the initialisation of the object. output: 3D pycuda.gpuarray.GPUArray, optional If given, it will be modified to contain the shifted radios. Must be of the same shape of `radios`. 
""" self._check(radios, iangles) n_a, n_z, n_x = radios.shape assert n_z == self.radios_shape[1] x_slice = slice(0, n_x) # slice(None, None) def nonempty_subregion(region): if region is None: return True z_slice = region[0] return z_slice.stop - z_slice.start > 0 d_radio_new = self._d_radio_new d_radio = self._d_radio for ia in iangles: d_radio_new.fill(0) d_radio[:] = radios[ia, :, :] # mul-add kernel won't work with pycuda view S0, f = self.interp_infos[ia] f = np.float32(f) s0 = S0 if s0 > 0: # newradio[:-s0] = radio[s0:] * (1 - f) dst_region = (slice(0, n_z - s0), x_slice) other_region = (slice(s0, n_z), x_slice) elif s0 == 0: # newradio[:] = radio[s0:] * (1 - f) dst_region = None other_region = (slice(s0, n_z), x_slice) else: # newradio[-s0:] = radio[:s0] * (1 - f) dst_region = (slice(-s0, n_z), x_slice) other_region = (slice(0, n_z + s0), x_slice) if all([nonempty_subregion(reg) for reg in [dst_region, other_region]]): self.muladd_kernel( d_radio_new, d_radio, 1, 1 - f, dst_region=dst_region, other_region=other_region, ) s0 = S0 + 1 if s0 > 0: # newradio[:-s0] += radio[s0:] * f dst_region = (slice(0, n_z - s0), x_slice) other_region = (slice(s0, n_z), x_slice) elif s0 == 0: # newradio[:] += radio[s0:] * f dst_region = None other_region = (slice(s0, n_z), x_slice) else: # newradio[-s0:] += radio[:s0] * f dst_region = (slice(-s0, n_z), x_slice) other_region = (slice(0, n_z + s0), x_slice) if all([nonempty_subregion(reg) for reg in [dst_region, other_region]]): self.muladd_kernel(d_radio_new, d_radio, 1, f, dst_region=dst_region, other_region=other_region) if output is None: radios[ia, :, :] = d_radio_new[:, :] else: output[ia, :, :] = d_radio_new[:, :] ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5167568 nabu-2024.2.1/nabu/preproc/tests/0000755000175000017500000000000014730277752016044 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/preproc/tests/__init__.py0000644000175000017500000000000114315516747020143 0ustar00pierrepierre ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/preproc/tests/test_ccd_corr.py0000644000175000017500000000431114402565210021213 0ustar00pierrepierreimport pytest import numpy as np from nabu.utils import median2 as nabu_median_filter from nabu.testutils import get_data from nabu.cuda.utils import get_cuda_context, __has_pycuda__ from nabu.preproc.ccd import CCDFilter if __has_pycuda__: import pycuda.gpuarray as garray from nabu.preproc.ccd_cuda import CudaCCDFilter @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.data = get_data("mri_proj_astra.npz")["data"] cls.data /= cls.data.max() cls.put_hotspots_in_data() if __has_pycuda__: cls.ctx = get_cuda_context() @pytest.mark.usefixtures("bootstrap") class TestCCDFilter: @classmethod def put_hotspots_in_data(cls): # Put 5 hot spots in the data # (row, column, deviation from median) cls.hotspots = [(50, 51, 0.04), (151, 150, 0.08), (202, 303, 0.12), (322, 203, 0.14)] cls.threshold = 0.1 # parameterize ? 
data_medfilt = nabu_median_filter(cls.data) for r, c, deviation_from_median in cls.hotspots: cls.data[r, c] = data_medfilt[r, c] + deviation_from_median def check_detected_hotspots_locations(self, res): diff = self.data - res rows, cols = np.where(diff > 0) hotspots_arr = np.array(self.hotspots) M = hotspots_arr[:, -1] > self.threshold hotspots_rows = hotspots_arr[M, 0] hotspots_cols = hotspots_arr[M, 1] assert np.allclose(hotspots_rows, rows) assert np.allclose(hotspots_cols, cols) def test_median_clip(self): ccd_filter = CCDFilter(self.data.shape, median_clip_thresh=self.threshold) res = np.zeros_like(self.data) res = ccd_filter.median_clip_correction(self.data, output=res) self.check_detected_hotspots_locations(res) @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test") def test_cuda_median_clip(self): d_radios = garray.to_gpu(self.data) cuda_ccd_correction = CudaCCDFilter(d_radios.shape, median_clip_thresh=self.threshold) d_out = garray.zeros_like(d_radios) cuda_ccd_correction.median_clip_correction(d_radios, output=d_out) res = d_out.get() self.check_detected_hotspots_locations(res) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1730906677.0 nabu-2024.2.1/nabu/preproc/tests/test_ctf.py0000644000175000017500000002342614712705065020231 0ustar00pierrepierreimport pytest import numpy as np import scipy.interpolate from nabu.processing.fft_cuda import get_available_fft_implems from nabu.testutils import get_data as nabu_get_data from nabu.testutils import __do_long_tests__ from nabu.preproc.flatfield import FlatFieldArrays from nabu.preproc.ccd import CCDFilter from nabu.preproc import ctf from nabu.estimation.distortion import estimate_flat_distortion from nabu.misc.filters import correct_spikes from nabu.preproc.distortion import DistortionCorrection from nabu.cuda.utils import __has_pycuda__, get_cuda_context __has_cufft__ = False if __has_pycuda__: from nabu.preproc.ctf_cuda import CudaCTFPhaseRetrieval avail_fft = get_available_fft_implems() __has_cufft__ = len(avail_fft) > 0 @pytest.fixture(scope="class") def bootstrap_TestCtf(request): cls = request.cls cls.abs_tol = 1.0e-4 test_data = nabu_get_data("ctf_tests_data_all_pars.npz") cls.rand_disp_vh = test_data["rh"] ## the dimension number 1 is over holotomo distances, so far our filter is for one distance only cls.rand_disp_vh.shape = [cls.rand_disp_vh.shape[0], cls.rand_disp_vh.shape[2]] cls.dark = test_data["dark"] cls.flats = [test_data["ref0"], test_data["ref1"]] cls.im = test_data["im"] cls.ipro = int(test_data["ipro"]) cls.expected_result = test_data["result"] cls.ref_plain = test_data["ref_plain_float_flat"] cls.flats_n = test_data["refns"] cls.img_shape_vh = test_data["img_shape_vh"] cls.padded_img_shape_vh = test_data["padded_img_shape_vh"] cls.z1_vh = test_data["z1_vh"] cls.z2 = test_data["z2"] cls.pix_size_det = test_data["pix_size_det"][()] cls.length_scale = test_data["length_scale"] cls.wavelength = test_data["wave_length"] cls.remove_spikes_threshold = test_data["remove_spikes_threshold"] cls.delta_beta = 27 @pytest.mark.usefixtures("bootstrap_TestCtf") class TestCtf: def check_result(self, res, ref, error_message): diff = np.abs(res - ref) diff[diff > np.percentile(diff, 99)] = 0 assert diff.max() < self.abs_tol * (np.abs(ref).mean()), error_message def test_ctf_id16_way(self): """test the ctf phase retrieval. The cft filter, of the CtfFilter class is iniitalised with the geomety informations contained in geo_pars object of the GeoPars class. 
The geometry encompasses the case of an astigmatic wavefront with vertical and horizontal sources which are at distances z1_vh[0] and z1_vh[1] from the object. In the case of parallel geometry, put z1_vh[0] = z1_vh[1] = R where R is a large value (meters). The SI unit system is used, but the same results should be obtained with any homogeneous choice of the distance units. The img_shape is the shape of the images which will be processed. padded_img_shape is an intermediate shape which needs to be larger than the img_shape to avoid border effects due to convolutions. Length scale is an internal parameter which should not affect the result in any way, unless there are serious numerical problems involving very small lengths. You can safely keep the default value. """ geo_pars = ctf.GeoPars( z1_vh=self.z1_vh, z2=self.z2, pix_size_det=self.pix_size_det, length_scale=self.length_scale, wavelength=self.wavelength, ) flats = FlatFieldArrays( [1200] + list(self.img_shape_vh), {0: self.flats[0], 1200: self.flats[1]}, {0: self.dark} ) my_flat = flats.get_flat(self.ipro) my_img = self.im - self.dark my_flat = my_flat - self.dark new_coordinates = estimate_flat_distortion( my_flat, my_img, tile_size=100, interpolation_kind="cubic", padding_mode="edge", correction_spike_threshold=3, ) interpolator = scipy.interpolate.RegularGridInterpolator( (np.arange(my_flat.shape[0]), np.arange(my_flat.shape[1])), my_flat, bounds_error=False, method="linear", fill_value=None, ) my_flat = interpolator(new_coordinates) my_img = my_img / my_flat my_img = correct_spikes(my_img, self.remove_spikes_threshold) my_shift = self.rand_disp_vh[:, self.ipro] ctf_filter = ctf.CtfFilter( self.dark.shape, geo_pars, self.delta_beta, padded_shape=self.padded_img_shape_vh, translation_vh=my_shift, normalize_by_mean=True, lim1=1.0e-5, lim2=0.2, ) phase = ctf_filter.retrieve_phase(my_img) self.check_result( phase, self.expected_result, "retrieved phase and reference result differ beyond the accepted tolerance" ) @pytest.mark.skipif(not (__do_long_tests__), reason="need environment variable NABU_LONG_TESTS=1") def test_ctf_id16_class(self): geo_pars = ctf.GeoPars( z1_vh=self.z1_vh, z2=self.z2, pix_size_det=self.pix_size_det, length_scale=self.length_scale, wavelength=self.wavelength, ) distortion_correction = DistortionCorrection( estimation_method="fft-correlation", estimation_kwargs={ "tile_size": 100, "interpolation_kind": "cubic", "padding_mode": "edge", "correction_spike_threshold": 3.0, }, correction_method="interpn", correction_kwargs={"fill_value": None}, ) flats = FlatFieldArrays( [1200] + list(self.img_shape_vh), {0: self.flats[0], 1200: self.flats[1]}, {0: self.dark}, distortion_correction=distortion_correction, ) # The "correct_spikes" function is numerically unstable (comparison with a float threshold). # If float32 is used for the image, one spike is detected while it is not in the previous test # (although the max difference between the inputs is about 1e-8). # We use float64 data type for the image to make tests pass.
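        # (np.ndarray.astype returns a copy by default, so the in-place normalization below leaves cls.im untouched)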
img = self.im.astype(np.float64) flats.normalize_single_radio(img, self.ipro) img = correct_spikes(img, self.remove_spikes_threshold) shift = self.rand_disp_vh[:, self.ipro] ctf_filter = ctf.CtfFilter( img.shape, geo_pars, self.delta_beta, padded_shape=self.padded_img_shape_vh, translation_vh=shift, normalize_by_mean=True, lim1=1.0e-5, lim2=0.2, ) phase = ctf_filter.retrieve_phase(img) message = "retrieved phase and reference result differ beyond the accepted tolerance" assert np.abs(phase - self.expected_result).max() < 10 * self.abs_tol * ( np.abs(self.expected_result).mean() ), message def test_ctf_plain_way(self): geo_pars = ctf.GeoPars( z1_vh=None, z2=self.z2, pix_size_det=self.pix_size_det, length_scale=self.length_scale, wavelength=self.wavelength, ) flatfielder = FlatFieldArrays( [1] + list(self.img_shape_vh), {0: self.flats[0], 1200: self.flats[1]}, {0: self.dark}, radios_indices=[self.ipro], ) spikes_corrector = CCDFilter( self.dark.shape, median_clip_thresh=self.remove_spikes_threshold, abs_diff=True, preserve_borders=True ) img = self.im.astype("f") img = flatfielder.normalize_radios(np.array([img]))[0] img = spikes_corrector.median_clip_correction(img) ctf_args = [img.shape, geo_pars, self.delta_beta] ctf_kwargs = {"padded_shape": self.padded_img_shape_vh, "normalize_by_mean": True, "lim1": 1.0e-5, "lim2": 0.2} ctf_filter = ctf.CtfFilter(*ctf_args, **ctf_kwargs) phase = ctf_filter.retrieve_phase(img) self.check_result(phase, self.ref_plain, "Something wrong with CtfFilter") # Test R2C ctf_numpy = ctf.CtfFilter(*ctf_args, **ctf_kwargs, use_rfft=True) phase_r2c = ctf_numpy.retrieve_phase(img) self.check_result(phase_r2c, self.ref_plain, "Something wrong with CtfFilter-R2C") # Test multi-core FFT ctf_fft = ctf.CtfFilter(*ctf_args, **ctf_kwargs, use_rfft=True, fft_num_threads=0) if ctf_fft.use_rfft: phase_fft = ctf_fft.retrieve_phase(img) self.check_result(phase_r2c, self.ref_plain, "Something wrong with CtfFilter-FFT") @pytest.mark.skipif(not (__has_pycuda__ and __has_cufft__), reason="pycuda and (scikit-cuda or vkfft)") def test_cuda_ctf(self): data = nabu_get_data("brain_phantom.npz")["data"] delta_beta = 50.0 energy_kev = 22.0 distance_m = 1.0 pix_size_m = 0.1e-6 geo_pars = ctf.GeoPars(z2=distance_m, pix_size_det=pix_size_m, wavelength=1.23984199e-9 / energy_kev) ctx = get_cuda_context() for normalize in [True, False]: ctf_filter = ctf.CTFPhaseRetrieval( data.shape, geo_pars, delta_beta=delta_beta, normalize_by_mean=normalize, use_rfft=True ) cuda_ctf_filter = CudaCTFPhaseRetrieval( data.shape, geo_pars, delta_beta=delta_beta, use_rfft=True, normalize_by_mean=normalize, ) ref = ctf_filter.retrieve_phase(data) d_data = cuda_ctf_filter.cuda_processing.to_device("_d_data", data) res = cuda_ctf_filter.retrieve_phase(d_data).get() err_max = np.max(np.abs(res - ref)) assert err_max < 1e-2, "Something wrong with retrieve_phase(normalize_by_mean=%s)" % (str(normalize)) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/preproc/tests/test_double_flatfield.py0000644000175000017500000000547114550227307022737 0ustar00pierrepierreimport os.path as path from math import exp import tempfile import numpy as np import pytest from silx.io.url import DataUrl from tomoscan.esrf.mock import MockNXtomo from nabu.io.reader import HDF5Reader from nabu.preproc.double_flatfield import DoubleFlatField from nabu.cuda.utils import __has_pycuda__, get_cuda_context if __has_pycuda__: import pycuda.gpuarray as garray from 
nabu.preproc.double_flatfield_cuda import CudaDoubleFlatField, __has_pycuda__ @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.tmpdir = tempfile.TemporaryDirectory() dname = cls.tmpdir.name cls.dname = dname radios = MockNXtomo( path.join(dname, "tmp"), 10, n_ini_proj=10, dim=100, n_refs=1, scene="increasing value", ).scan.projections reader = HDF5Reader() cls.radios = [] Rkeys = list(radios.keys()) for k in Rkeys: dataurl = radios[k] data = reader.get_data(dataurl) cls.radios.append(data) cls.radios = np.array(cls.radios) cls.ff_dump_url = DataUrl( file_path=path.join(cls.dname, "dff.h5"), data_path="/entry/double_flatfield/results/data" ) cls.ff_cuda_dump_url = DataUrl( file_path=path.join(cls.dname, "dff_cuda.h5"), data_path="/entry/double_flatfield/results/data" ) golden = 0 for i in range(10): golden += exp(-i) cls.golden = golden / 10 cls.tol = 1e-4 if __has_pycuda__: cls.ctx = get_cuda_context(cleanup_at_exit=False) yield if __has_pycuda__: cls.ctx.pop() @pytest.mark.usefixtures("bootstrap") class TestDoubleFlatField: def test_dff_numpy(self): dff = DoubleFlatField(self.radios.shape, result_url=self.ff_dump_url) mydf = dff.get_double_flatfield(radios=self.radios) assert path.isfile(dff.result_url.file_path()) dff2 = DoubleFlatField(self.radios.shape, result_url=self.ff_dump_url) mydf2 = dff2.get_double_flatfield(radios=self.radios) assert np.max(np.abs(mydf2 - mydf)) < self.tol assert np.max(np.abs(mydf - self.golden)) < self.tol @pytest.mark.skipif(not (__has_pycuda__), reason="Need pycuda for double flatfield with cuda backend") def test_dff_cuda(self): dff = CudaDoubleFlatField(self.radios.shape, result_url=self.ff_cuda_dump_url, cuda_options={"ctx": self.ctx}) d_radios = garray.to_gpu(self.radios) mydf = dff.get_double_flatfield(radios=d_radios).get() assert path.isfile(dff.result_url.file_path()) dff2 = CudaDoubleFlatField(self.radios.shape, result_url=self.ff_cuda_dump_url, cuda_options={"ctx": self.ctx}) mydf2 = dff2.get_double_flatfield(radios=d_radios).get() assert np.max(np.abs(mydf2 - mydf)) < self.tol assert np.max(np.abs(mydf - self.golden)) < self.tol ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1723556968.0 nabu-2024.2.1/nabu/preproc/tests/test_flatfield.py0000644000175000017500000005110514656662150021406 0ustar00pierrepierreimport os import numpy as np import pytest from nabu.cuda.utils import get_cuda_context, __has_pycuda__ from nabu.preproc.flatfield import FlatField if __has_pycuda__: from nabu.preproc.flatfield_cuda import CudaFlatField # Flats values should be O(k) so that linear interpolation between flats gives exact results flatfield_tests_cases = { "simple_nearest_interp": { "image_shape": (100, 512), "radios_values": np.arange(10) + 1, "radios_indices": None, "flats_values": [0.5], "flats_indices": [1], "darks_values": [1], "darks_indices": [0], "expected_result": np.arange(0, -2 * 10, -2), }, "two_flats_no_radios_indices": { "image_shape": (100, 512), "radios_values": np.arange(10) + 1, "radios_indices": None, "flats_values": [2, 11], "flats_indices": [0, 9], "darks_values": [1], "darks_indices": [0], "expected_result": np.arange(10) / (np.arange(10) + 1), }, "two_flats_with_radios_indices": { # Type D F R R R F # IDX 0 1 2 3 4 5 # Value 1 4 9 16 25 8 # F_interp 5 6 7 "image_shape": (16, 17), # R_k = (k + 1)**2 "radios_values": [9, 16, 25], "radios_indices": [2, 3, 4], # F_k = k+3 "flats_values": [4, 8], "flats_indices": [1, 5], # D_k = 1 "darks_values": [1], "darks_indices": [0], # Expected 
normalization result: N_k = k "expected_result": [2, 3, 4], }, "three_flats_srcurrent": { # Type D F R R | R F R R F # IDX 0 1 2 3 | 4 5 6 7 8 # Value 1 4 9 16 | 25 8 46 67 14 # F_interp 5 6 | 7 10 12 # srCurrent 20 10 10 | 16 32 16 8 16 "image_shape": (16, 17), "radios_values": [9, 16, 25, 46, 67], "radios_indices": [2, 3, 4, 6, 7], "flats_values": [4, 8, 14], "flats_indices": [1, 5, 8], "darks_values": [1], "darks_indices": [0], # "expected_result": [2, 3, 4, 5, 6], # without SR normalization "expected_result": [4, 6, 8, 10, 12], # sr_flat/sr_radio = 2 "radios_srcurrent": [10, 16, 16, 16, 8], # "flats_srcurrent": [20, 18, 6, 4, 2], "flats_srcurrent": [20, 32, 16], }, } def generate_test_flatfield_generalized( image_shape, radios_indices, radios_values, flats_indices, flats_values, darks_indices, darks_values, dtype=np.uint16, ): """ Parameters ----------- image_shape: tuple of int shape of each image. radios_indices: array of int Indices where radios are found in the dataset. radios_values: array of scalars Value for each radio image. Length must be equal to `radios_shape[0]`. flats_indices: array of int Indices where flats are found in the dataset. flats_values: array of scalars Values of flat images. Length must be equal to `len(flats_indices)` darks_indices: array of int Indices where darks are found in the dataset. darks_values: array of scalars Values of dark images. Length must be equal to `len(darks_indices)` Returns ------- radios: numpy.ndarray 3D array with raw radios darks: dict of arrays Dictionary where each key is the dark indice, and value is an array flats: dict of arrays Dictionary where each key is the flat indice, and value is an array """ # Radios radios = np.zeros((len(radios_values),) + image_shape, dtype="f") n_radios = radios.shape[0] for i in range(n_radios): radios[i].fill(radios_values[i]) img_shape = radios.shape[1:] # Flats flats = {} for i, flat_idx in enumerate(flats_indices): flats[flat_idx] = np.zeros(img_shape, dtype=dtype) + flats_values[i] # Darks darks = {} for i, dark_idx in enumerate(darks_indices): darks[dark_idx] = np.zeros(img_shape, dtype=dtype) + darks_values[i] return radios, flats, darks @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.tmp_files = [] cls.tmp_dirs = [] cls.n_radios = 10 cls.n_z = 100 cls.n_x = 512 if __has_pycuda__: cls.ctx = get_cuda_context() yield # Tear-down for fname in cls.tmp_files: os.remove(fname) for dname in cls.tmp_dirs: os.rmdir(dname) @pytest.mark.usefixtures("bootstrap") class TestFlatField: def get_test_elements(self, case_name): config = flatfield_tests_cases[case_name] radios_stack, flats, darks = generate_test_flatfield_generalized( config["image_shape"], config["radios_indices"], config["radios_values"], config["flats_indices"], config["flats_values"], config["darks_indices"], config["darks_values"], ) # fname = flats_url[list(flats_url.keys())[0]].file_path() # self.tmp_files.append(fname) # self.tmp_dirs.append(os.path.dirname(fname)) return radios_stack, flats, darks, config @staticmethod def check_normalized_radios(radios_corr, expected_values): # must be the same value everywhere in the radio std = np.std(np.std(radios_corr, axis=-1), axis=-1) assert np.max(np.abs(std)) < 1e-7 # radios values must be 0, -2, -4, ... assert np.allclose(radios_corr[:, 0, 0], expected_values) def test_flatfield_simple(self): """ Test the flat-field normalization on a radios stack with 1 dark and 1 flat. (I - D)/(F - D) where I = (1, 2, ...), D = 1, F = 0.5 = (0, -2, -4, -6, ...) 
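        For instance, the third radio has I = 3, so (I - D)/(F - D) = (3 - 1)/(0.5 - 1) = -4,
        which is the third value of the expected sequence.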
""" radios_stack, flats, darks, config = self.get_test_elements("simple_nearest_interp") flatfield = FlatField(radios_stack.shape, flats, darks) radios_corr = flatfield.normalize_radios(np.copy(radios_stack)) self.check_normalized_radios(radios_corr, config["expected_result"]) def test_flatfield_simple_subregion(self): """ Same as test_flatfield_simple, but in a vertical subregion of the radios. """ radios_stack, flats, darks, config = self.get_test_elements("simple_nearest_interp") end_z = 51 flats = {k: arr[:end_z, :] for k, arr in flats.items()} darks = {k: arr[:end_z, :] for k, arr in darks.items()} radios_chunk = np.copy(radios_stack[:, :end_z, :]) # we only have a chunk in memory. Instantiate the class with the # corresponding subregion to only load the relevant part of dark/flat flatfield = FlatField( radios_chunk.shape, flats, darks, ) radios_corr = flatfield.normalize_radios(radios_chunk) self.check_normalized_radios(radios_corr, config["expected_result"]) def test_flatfield_linear_interp(self): """ Test flat-field normalization with 1 dark and 2 flats, with linear interpolation between flats. I = 1 2 3 4 5 6 7 8 9 10 D = 1 (one dark) F = 2 11 (two flats) F_i = 2 3 4 5 6 7 8 9 10 11 (interpolated flats) R = 0 .5 .66 .75 .8 .83 .86 = (I-D)/(F-D) = (I-1)/I """ radios_stack, flats, darks, config = self.get_test_elements("two_flats_no_radios_indices") flatfield = FlatField(radios_stack.shape, flats, darks) radios_corr = flatfield.normalize_radios(np.copy(radios_stack)) self.check_normalized_radios(radios_corr, config["expected_result"]) # Test 2: one of the flats is not at the beginning/end # I = 1 2 3 4 5 6 7 8 9 10 # F = 2 11 # F_i = 2 3.8 5.6 7.4 9.2 11 11 11 11 11 # R = 0 .357 .435 .469 .488 .5 .6 .7 .8 .9 flats = {k: v.copy() for k, v in flats.items()} flats[5] = flats[9] flats.pop(9) flatfield = FlatField(radios_stack.shape, flats, darks) radios_corr = flatfield.normalize_radios(np.copy(radios_stack)) self.check_normalized_radios( radios_corr, [0.0, 0.35714286, 0.43478261, 0.46875, 0.48780488, 0.5, 0.6, 0.7, 0.8, 0.9] ) @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test") def test_cuda_flatfield(self): """ Test the flat-field with cuda back-end. 
""" radios_stack, flats, darks, config = self.get_test_elements("two_flats_no_radios_indices") cuda_flatfield = CudaFlatField( radios_stack.shape, flats, darks, ) d_radios = cuda_flatfield.cuda_processing.to_device("d_radios", radios_stack.astype("f")) cuda_flatfield.normalize_radios(d_radios) radios_corr = d_radios.get() self.check_normalized_radios(radios_corr, config["expected_result"]) # Linear interpolation, two flats, one dark def test_twoflats_simple(self): radios, flats, darks, config = self.get_test_elements("two_flats_with_radios_indices") FF = FlatField(radios.shape, flats, darks, radios_indices=config["radios_indices"]) FF.normalize_radios(radios) self.check_normalized_radios(radios, config["expected_result"]) def _setup_numerical_issue(self): radios, flats, darks, config = self.get_test_elements("two_flats_with_radios_indices") flats_copy = {} darks_copy = {} for flat_idx, flat in flats.items(): flats_copy[flat_idx] = flat.copy() flats_copy[flat_idx][0, 0] = 99 for dark_idx, dark in darks.items(): darks_copy[dark_idx] = dark.copy() darks_copy[dark_idx][0, 0] = 99 radios[:, 0, 0] = 99 return radios, flats_copy, darks_copy, config def _check_numerical_issue(self, radios, expected_result, nan_value=None): if nan_value is None: assert np.all(np.logical_not(np.isfinite(radios[:, 0, 0]))), "First pixel should be nan or inf" radios[:, 0, 0] = radios[:, 1, 1] self.check_normalized_radios(radios, expected_result) else: assert np.all(np.isfinite(radios)), "No inf/nan value should be there" assert np.allclose(radios[:, 0, 0], nan_value, atol=1e-7), ( "Handled NaN should have nan_value=%f" % nan_value ) radios[:, 0, 0] = radios[:, 1, 1] self.check_normalized_radios(radios, expected_result) def test_twoflats_numerical_issue(self): """ Same as above, but for the first radio: I==Dark and Flat==Dark For this radio, nan is replaced with 1. 
""" radios, flats, darks, config = self._setup_numerical_issue() radios0 = radios.copy() # FlatField without NaN handling yields NaN and raises RuntimeWarning FF_no_nan_handling = FlatField( radios.shape, flats, darks, radios_indices=config["radios_indices"], nan_value=None ) with pytest.warns(RuntimeWarning): FF_no_nan_handling.normalize_radios(radios) self._check_numerical_issue(radios, config["expected_result"], None) # FlatField with NaN handling nan_value = 50 radios = radios0.copy() FF_with_nan_handling = FlatField( radios.shape, flats, darks, radios_indices=config["radios_indices"], nan_value=nan_value ) with pytest.warns(RuntimeWarning): FF_with_nan_handling.normalize_radios(radios) self._check_numerical_issue(radios, config["expected_result"], nan_value) @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test") def test_cuda_twoflats_numerical_issue(self): """ Same as above, with the Cuda backend """ radios, flats, darks, config = self._setup_numerical_issue() radios0 = radios.copy() FF_no_nan_handling = CudaFlatField( radios.shape, flats, darks, radios_indices=config["radios_indices"], nan_value=None ) d_radios = FF_no_nan_handling.cuda_processing.to_device("radios", radios) # In a cuda kernel, no one can hear you scream FF_no_nan_handling.normalize_radios(d_radios) radios = d_radios.get() self._check_numerical_issue(radios, config["expected_result"], None) # FlatField with NaN handling nan_value = 50 d_radios.set(radios0) FF_with_nan_handling = CudaFlatField( radios.shape, flats, darks, radios_indices=config["radios_indices"], nan_value=nan_value ) FF_with_nan_handling.normalize_radios(d_radios) radios = d_radios.get() self._check_numerical_issue(radios, config["expected_result"], nan_value) def test_srcurrent(self): radios, flats, darks, config = self.get_test_elements("three_flats_srcurrent") FF = FlatField( radios.shape, flats, darks, radios_indices=config["radios_indices"], radios_srcurrent=config["radios_srcurrent"], flats_srcurrent=config["flats_srcurrent"], ) radios_corr = FF.normalize_radios(np.copy(radios)) self.check_normalized_radios(radios_corr, config["expected_result"]) @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test") def test_srcurrent_cuda(self): radios, flats, darks, config = self.get_test_elements("three_flats_srcurrent") FF = CudaFlatField( radios.shape, flats, darks, radios_indices=config["radios_indices"], radios_srcurrent=config["radios_srcurrent"], flats_srcurrent=config["flats_srcurrent"], ) d_radios = FF.cuda_processing.to_device("radios", radios) FF.normalize_radios(d_radios) radios_corr = d_radios.get() self.check_normalized_radios(radios_corr, config["expected_result"]) # This test should be closer to the ESRF standard setting. # There are 2 flats, one dark, 4000 radios. # dark : indice=0 value=10 # flat1 : indice=1 value=4202 # flat2 : indice=2102 value=2101 # # The projections have the following indices: # j: 0 1 1998 1999 2000 2001 3999 # idx: [102, 103, ..., 2100, 2101, 2203, 2204, ..., 4200, 4201, 4202] # Notice the gap in the middle. # # The linear interpolation is # flat_i = (n2 - i)/(n2 - n1)*flat_1 + (i - n1)/(n2 - n1)*flat_2 # where n1 and n2 are the indices of flat_1 and flat_2 respectively. # With the above values, we have flat_i = 4203 - i. # # The projections values are dark + i*(flat_i - dark), # so that the normalization norm_i = (proj_i - dark)/(flat_i - dark) gives # # idx 102 103 104 ... # flat 4101 4102 4103 ... # norm 102 103 104 ... 
# class FlatFieldTestDataset: # Parameters shp = (27, 32) n1 = 1 # flat indice 1 n2 = 2102 # flat indice 2 dark_val = 10 darks = {0: np.zeros(shp, "f") + dark_val} flats = {n1: np.zeros(shp, "f") + (n2 - 1) * 2, n2: np.zeros(shp, "f") + n2 - 1} projs_idx = list(range(102, 2102)) + list(range(2203, 4203)) # gap in the middle def __init__(self): self._generate_projections() def get_flat_idx(self, proj_idx): flats_idx = sorted(list(self.flats.keys())) if proj_idx <= flats_idx[0]: return (flats_idx[0],) elif proj_idx > flats_idx[0] and proj_idx < flats_idx[1]: return flats_idx else: return (flats_idx[1],) def get_flat(self, idx): flatidx = self.get_flat_idx(idx) if len(flatidx) == 1: flat = self.flats[flatidx[0]] else: nf1, nf2 = flatidx w1 = (nf2 - idx) / (nf2 - nf1) flat = w1 * self.flats[nf1] + (1 - w1) * self.flats[nf2] return flat def _generate_projections(self): self.projs_data = np.zeros((len(self.projs_idx),) + self.shp, "f") self.projs = {} for i, proj_idx in enumerate(self.projs_idx): flat = self.get_flat(proj_idx) proj_val = self.dark_val + proj_idx * (flat[0, 0] - self.dark_val) self.projs[str(proj_idx)] = np.zeros(self.shp, "f") + proj_val self.projs_data[i] = self.projs[str(proj_idx)] @pytest.fixture(scope="class") def bootstraph5(request): cls = request.cls cls.dataset = FlatFieldTestDataset() n1, n2 = cls.dataset.n1, cls.dataset.n2 # Interpolation function cls._weight1 = lambda i: (n2 - i) / (n2 - n1) cls.tol = 5e-4 cls.tol_std = 1e-3 yield @pytest.mark.usefixtures("bootstraph5") class TestFlatFieldH5: def check_normalization(self, projs): # Check that each projection is filled with the same values std_projs = np.std(projs, axis=(-2, -1)) assert np.max(np.abs(std_projs)) < self.tol_std # Check that the normalized radios are equal to 102, 103, 104, ... errs = projs[:, 0, 0] - self.dataset.projs_idx assert np.max(np.abs(errs)) < self.tol, "Something wrong with flat-field normalization" def test_flatfield(self): flatfield = FlatField( self.dataset.projs_data.shape, self.dataset.flats, self.dataset.darks, radios_indices=self.dataset.projs_idx, interpolation="linear", ) projs = np.copy(self.dataset.projs_data) flatfield.normalize_radios(projs) self.check_normalization(projs) @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test") def test_cuda_flatfield(self): cuda_flatfield = CudaFlatField( self.dataset.projs_data.shape, self.dataset.flats, self.dataset.darks, radios_indices=self.dataset.projs_idx, ) d_projs = cuda_flatfield.cuda_processing.to_device("d_projs", self.dataset.projs_data) cuda_flatfield.normalize_radios(d_projs) projs = d_projs.get() self.check_normalization(projs) # # Another test with more than two flats. 
# # Here we have # # F_i = i + 2 # R_i = i*(F_i - 1) + 1 # N_i = (R_i - D)/(F_i - D) = i*(F_i - 1)/( F_i - 1) = i # def generate_test_flatfield(n_radios, radio_shape, flat_interval, h5_fname): radios = np.zeros((n_radios,) + radio_shape, "f") dark_data = np.ones(radios.shape[1:], "f") flats = {} # F_i = i + 2 # R_i = i*(F_i - 1) + 1 # N_i = (R_i - D)/(F_i - D) = i*(F_i - 1)/( F_i - 1) = i for i in range(n_radios): f_i = i + 2 if (i % flat_interval) == 0: flats[i] = np.zeros(radio_shape, "f") + f_i radios[i] = i * (f_i - 1) + 1 darks = {0: dark_data} return radios, flats, darks @pytest.fixture(scope="class") def bootstrap_multiflats(request): cls = request.cls n_radios = 50 radio_shape = (20, 21) cls.flat_interval = 11 h5_fname = "testff.h5" radios, flats, dark = generate_test_flatfield(n_radios, radio_shape, cls.flat_interval, h5_fname) cls.radios = radios cls.flats = flats cls.darks = dark cls.expected_results = np.arange(n_radios) cls.tol = 5e-4 cls.tol_std = 1e-4 yield @pytest.mark.usefixtures("bootstrap_multiflats") class TestFlatFieldMultiFlat: def check_normalization(self, projs): # Check that each projection is filled with the same values std_projs = np.std(projs, axis=(-2, -1)) assert np.max(np.abs(std_projs)) < self.tol_std # Check that the normalized radios are equal to 0, 1, 2, ... stop = (projs.shape[0] // self.flat_interval) * self.flat_interval errs = projs[:stop, 0, 0] - self.expected_results[:stop] assert np.max(np.abs(errs)) < self.tol, "Something wrong with flat-field normalization" def test_flatfield(self): flatfield = FlatField(self.radios.shape, self.flats, self.darks, interpolation="linear") projs = np.copy(self.radios) flatfield.normalize_radios(projs) print(projs[:, 0, 0]) self.check_normalization(projs) @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test") def test_cuda_flatfield(self): cuda_flatfield = CudaFlatField( self.radios.shape, self.flats, self.darks, ) d_projs = cuda_flatfield.cuda_processing.to_device("radios", self.radios) cuda_flatfield.normalize_radios(d_projs) projs = d_projs.get() self.check_normalization(projs) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1730906677.0 nabu-2024.2.1/nabu/preproc/tests/test_paganin.py0000644000175000017500000000566614712705065021100 0ustar00pierrepierreimport pytest import numpy as np from nabu.preproc.phase import PaganinPhaseRetrieval from nabu.processing.fft_cuda import get_available_fft_implems from nabu.testutils import generate_tests_scenarios, get_data from nabu.thirdparty.tomopy_phase import retrieve_phase from nabu.cuda.utils import __has_pycuda__ __has_cufft__ = False if __has_pycuda__: from nabu.preproc.phase_cuda import CudaPaganinPhaseRetrieval avail_fft = get_available_fft_implems() __has_cufft__ = len(avail_fft) > 0 scenarios = { "distance": [1], "energy": [35], "delta_beta": [1e1], "margin": [((50, 50), (0, 0)), None], } scenarios = generate_tests_scenarios(scenarios) @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.data = get_data("mri_proj_astra.npz")["data"] cls.rtol = 1.1e-6 cls.rtol_pag = 5e-3 @pytest.mark.usefixtures("bootstrap") class TestPaganin: """ Test the Paganin phase retrieval. The reference implementation is tomopy. 
""" @staticmethod def get_paganin_instance_and_data(cfg, data): pag_kwargs = cfg.copy() margin = pag_kwargs.pop("margin") if margin is not None: data = np.pad(data, margin, mode="edge") paganin = PaganinPhaseRetrieval(data.shape, **pag_kwargs) return paganin, data, pag_kwargs @staticmethod def crop_to_margin(data, margin): if margin is None: return data ((U, D), (L, R)) = margin D = None if D == 0 else -D R = None if R == 0 else -R return data[U:D, L:R] @pytest.mark.parametrize("config", scenarios) def test_paganin(self, config): paganin, data, _ = self.get_paganin_instance_and_data(config, self.data) res = paganin.apply_filter(data) data_tomopy = np.atleast_3d(np.copy(data)).T res_tomopy = retrieve_phase( data_tomopy, pixel_size=paganin.pixel_size_xy_micron[0] * 1e-4, dist=paganin.distance_cm, energy=paganin.energy_kev, alpha=1.0 / (4 * 3.141592**2 * paganin.delta_beta), ) res_tomopy = self.crop_to_margin(res_tomopy[0].T, config["margin"]) res = self.crop_to_margin(res, config["margin"]) errmax = np.max(np.abs(res - res_tomopy) / np.max(res_tomopy)) assert errmax < self.rtol_pag, "Max error is too high" @pytest.mark.skipif( not (__has_pycuda__ and __has_cufft__), reason="Need pycuda and (scikit-cuda or vkfft) for this test" ) @pytest.mark.parametrize("config", scenarios) def test_gpu_paganin(self, config): paganin, data, pag_kwargs = self.get_paganin_instance_and_data(config, self.data) gpu_paganin = CudaPaganinPhaseRetrieval(data.shape, **pag_kwargs) ref = paganin.apply_filter(data) res = gpu_paganin.apply_filter(data) errmax = np.max(np.abs((res - ref) / np.max(ref))) assert errmax < self.rtol, "Max error is too high" ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1706619687.0 nabu-2024.2.1/nabu/preproc/tests/test_vshift.py0000644000175000017500000000546314556171447020770 0ustar00pierrepierreimport pytest import numpy as np from scipy.ndimage import shift as ndshift from nabu.preproc.shift import VerticalShift from nabu.cuda.utils import __has_pycuda__, get_cuda_context if __has_pycuda__: import pycuda.gpuarray as garray from nabu.preproc.shift_cuda import CudaVerticalShift @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls data = np.zeros([13, 11], "f") slope = 100 + np.arange(13) data[:] = slope[:, None] cls.radios = np.array([data] * 17) cls.shifts = 0.3 + np.arange(17) cls.indexes = range(17) # given the shifts and the radios we build the golden reference golden = [] for iradio in range(17): projection_number = cls.indexes[iradio] my_shift = cls.shifts[projection_number] padded_radio = np.concatenate( [cls.radios[iradio], np.zeros([1, 11], "f")], axis=0 ) # needs padding because ndshifs does not work as expected shifted_padded_radio = ndshift(padded_radio, [-my_shift, 0], mode="constant", cval=0.0, order=1).astype("f") shifted_radio = shifted_padded_radio[:-1] golden.append(shifted_radio) cls.golden = np.array(golden) cls.tol = 1e-5 if __has_pycuda__: cls.ctx = get_cuda_context() @pytest.mark.usefixtures("bootstrap") class TestVerticalShift: def test_vshift(self): radios = self.radios.copy() new_radios = np.zeros_like(radios) Shifter = VerticalShift(radios.shape, self.shifts) Shifter.apply_vertical_shifts(radios, self.indexes, output=new_radios) assert abs(new_radios - self.golden).max() < self.tol Shifter.apply_vertical_shifts(radios, self.indexes) assert abs(radios - self.golden).max() < self.tol @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test") def test_cuda_vshift(self): d_radios = 
garray.to_gpu(self.radios) d_radios2 = d_radios.copy() d_out = garray.zeros_like(d_radios) Shifter = CudaVerticalShift(d_radios.shape, self.shifts) Shifter.apply_vertical_shifts(d_radios, self.indexes, output=d_out) assert abs(d_out.get() - self.golden).max() < self.tol Shifter.apply_vertical_shifts(d_radios, self.indexes) assert abs(d_radios.get() - self.golden).max() < self.tol # Test with negative shifts radios2 = self.radios.copy() Shifter_neg = VerticalShift(self.radios.shape, -self.shifts) Shifter_neg.apply_vertical_shifts(radios2, self.indexes) Shifter_neg_cuda = CudaVerticalShift(d_radios.shape, -self.shifts) Shifter_neg_cuda.apply_vertical_shifts(d_radios2, self.indexes) err_max = np.max(np.abs(d_radios2.get() - radios2)) assert err_max < 1e-6, "Something wrong for negative translations: max error = %.2e" % err_max ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5167568 nabu-2024.2.1/nabu/processing/0000755000175000017500000000000014730277752015404 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/processing/__init__.py0000644000175000017500000000000014550227307017472 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731941746.0 nabu-2024.2.1/nabu/processing/azim.py0000644000175000017500000001564414716652562016730 0ustar00pierrepierrefrom multiprocessing.pool import ThreadPool import numpy as np try: from skimage.transform import warp_polar __have_skimage__ = True except ImportError: __have_skimage__ = False def azimuthal_integration(img, axes=(-2, -1), domain="direct"): """ Computes azimuthal integration of an image or a stack of images. Parameters ---------- img : `numpy.array_like` The image or stack of images. axes : tuple(int, int), optional Axes of that need to be azimuthally integrated. The default is (-2, -1). domain : string, optional Domain of the integration. Options are: "direct" | "fourier". Default is "direct". Raises ------ ValueError Error returned when not passing images or wrong axes. NotImplementedError In case of tack of images for the moment. Returns ------- `numpy.array_like` The azimuthally integrated profile. 
""" if not len(img.shape) >= 2: raise ValueError("Input image should be at least 2-dimensional.") if not len(axes) == 2: raise ValueError("Input axes should be 2.") img_axes_dims = np.array((img.shape[axes[0]], img.shape[axes[1]])) if domain.lower() == "direct": half_dims = (img_axes_dims - 1) / 2 xx = np.linspace(-half_dims[0], half_dims[0], img_axes_dims[0]) yy = np.linspace(-half_dims[1], half_dims[1], img_axes_dims[1]) else: xx = np.fft.fftfreq(img_axes_dims[0], 1 / img_axes_dims[0]) yy = np.fft.fftfreq(img_axes_dims[1], 1 / img_axes_dims[1]) xy = np.stack(np.meshgrid(xx, yy, indexing="ij")) r = np.sqrt(np.sum(xy**2, axis=0)) img_tr_op = [*range(len(img.shape))] for a in axes: img_tr_op.append(img_tr_op.pop(a)) img = np.transpose(img, img_tr_op) if len(img.shape) > 2: img_old_shape = img.shape[:-2] img = np.reshape(img, [-1, *img_axes_dims]) r_l = np.floor(r) r_u = r_l + 1 w_l = (r_u - r) * img w_u = (r - r_l) * img r_all = np.concatenate((r_l.flatten(), r_u.flatten())).astype(np.int64) if len(img.shape) == 2: w_all = np.concatenate((w_l.flatten(), w_u.flatten())) return np.bincount(r_all, weights=w_all) else: num_imgs = img.shape[0] az_img = [None] * num_imgs for ii in range(num_imgs): w_all = np.concatenate((w_l[ii, :].flatten(), w_u[ii, :].flatten())) az_img[ii] = np.bincount(r_all, weights=w_all) az_img = np.array(az_img) return np.reshape(az_img, (*img_old_shape, az_img.shape[-1])) def do_radial_distribution(ip, X0, Y0, mR, nBins=None, use_calibration=False, cal=None, return_radii=False): """ Translates the Java method `doRadialDistribution` (from imagej) into Python using NumPy. Done by chatgpt-4o on 2024-11-08 Args: - ip: A 2D numpy array representing the image. - X0, Y0: Coordinates of the center. - mR: Maximum radius. - nBins: Number of bins (optional, defaults to 3*mR/4). - use_calibration: Boolean indicating if calibration should be applied. - cal: Calibration object with attributes `pixel_width` and `units` (optional). """ if nBins is None: nBins = int(3 * mR / 4) Accumulator = np.zeros((2, nBins)) # Define the bounding box xmin, xmax = X0 - mR, X0 + mR ymin, ymax = Y0 - mR, Y0 + mR # Create grid of coordinates x = np.arange(xmin, xmax) y = np.arange(ymin, ymax) xv, yv = np.meshgrid(x, y, indexing="ij") # Calculate the radius for each point R = np.sqrt((xv - X0) ** 2 + (yv - Y0) ** 2) # Bin calculation bins = np.floor((R / mR) * nBins).astype(int) bins = np.clip(bins - 1, 0, nBins - 1) # Adjust bins to be in range [0, nBins-1] # Accumulate values for b in range(nBins): mask = bins == b Accumulator[0, b] = np.sum(mask) Accumulator[1, b] = np.sum(ip[mask]) # Normalize integrated intensity Accumulator[1] /= Accumulator[0] if use_calibration and cal is not None: # Apply calibration if units are provided radii = cal.pixel_width * mR * (np.arange(1, nBins + 1) / nBins) units = cal.units else: # Use pixel units radii = mR * (np.arange(1, nBins + 1) / nBins) units = "pixels" if return_radii: return radii, Accumulator[1] else: return Accumulator[1] # OK-ish, but small discrepancy with do_radial_distribution. # 20-40X faster than above methods for (2048, 2048) images # Also it assumes a uniform sampling # No idea why there is this "offset=1", to be investigated - perhaps radius=0 is also calculated ? 
def azimuthal_integration_skimage(img, center=None, offset=1): shape2 = [int(s // 2 * 1.4142) for s in img.shape] s = min(img.shape) // 2 img_polar = warp_polar(img, output_shape=shape2, center=center) return img_polar.mean(axis=0)[offset : offset + s] def _apply_on_images_stack(func, images_stack, n_threads=4, func_args=None, func_kwargs=None): func_args = func_args or [] func_kwargs = func_kwargs or {} def _process_image(img): return func(img, *func_args, **func_kwargs) with ThreadPool(n_threads) as tp: res = tp.map(_process_image, images_stack) return np.array(res) def _apply_on_patches_stack(func, images_stack, n_threads=4, func_args=None, func_kwargs=None): (n_images, n_patchs_y, img_shape_y, n_patchs_x, img_shape_x) = images_stack.shape func_args = func_args or [] func_kwargs = func_kwargs or {} out_sample = func(images_stack[0, 0, :, 0, :], *func_args, **func_kwargs) out_shape = out_sample.shape out_dtype = out_sample.dtype def _process_image(img): res = np.zeros((n_patchs_y, n_patchs_x) + out_shape, dtype=out_dtype) for i in range(n_patchs_y): for j in range(n_patchs_x): res[i, j] = func(img[i, :, j, :], *func_args, **func_kwargs) return res with ThreadPool(n_threads) as tp: res = tp.map(_process_image, images_stack) return np.array(res) def azimuthal_integration_imagej_stack(images_stack, n_threads=4): if images_stack.ndim == 3: img_shape = images_stack.shape[-2:] _apply = _apply_on_images_stack elif images_stack.ndim == 5: img_shape = np.array(images_stack.shape)[[-3, -1]] _apply = _apply_on_patches_stack else: raise ValueError s = min(img_shape) return _apply( do_radial_distribution, images_stack, n_threads=n_threads, func_args=[s // 2, s // 2, s // 2], func_kwargs={"nBins": s // 2, "return_radii": False}, ) def azimuthal_integration_skimage_stack(images_stack, n_threads=4): if images_stack.ndim == 3: return _apply_on_images_stack(azimuthal_integration_skimage, images_stack, n_threads=n_threads) elif images_stack.ndim == 5: return _apply_on_patches_stack(azimuthal_integration_skimage, images_stack, n_threads=n_threads) else: raise ValueError ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/processing/convolution_cuda.py0000644000175000017500000003575214654107202021330 0ustar00pierrepierrefrom os.path import dirname import numpy as np from ..utils import updiv, get_cuda_srcfile from ..cuda.utils import __has_pycuda__ from ..misc.utils import ConvolutionInfos from ..cuda.processing import CudaProcessing if __has_pycuda__: from pycuda.compiler import SourceModule class Convolution: """ A class for performing convolution on GPU with CUDA, but without using textures (unlike for example in ``silx.opencl.convolution``) """ def __init__(self, shape, kernel, axes=None, mode=None, extra_options=None, cuda_options=None): """ Constructor of Cuda Convolution. Parameters ----------- shape: tuple Shape of the array. kernel: array-like Convolution kernel (1D, 2D or 3D). axes: tuple, optional Axes along which the convolution is performed, for batched convolutions. mode: str, optional Boundary handling mode. Available modes are: - "reflect": cba|abcd|dcb - "nearest": aaa|abcd|ddd - "wrap": bcd|abcd|abc - "constant": 000|abcd|000 Default is "reflect". extra_options: dict, optional Advanced options (dict). 
Current options are: - "allocate_input_array": True - "allocate_output_array": True - "allocate_tmp_array": True - "sourcemodule_kwargs": {} - "batch_along_flat_dims": True """ self.cuda = CudaProcessing(**(cuda_options or {})) self._configure_extra_options(extra_options) self._determine_use_case(shape, kernel, axes) self._allocate_memory(mode) self._init_kernels() def _configure_extra_options(self, extra_options): self.extra_options = { "allocate_input_array": True, "allocate_output_array": True, "allocate_tmp_array": True, "sourcemodule_kwargs": {}, "batch_along_flat_dims": True, } extra_opts = extra_options or {} self.extra_options.update(extra_opts) self.sourcemodule_kwargs = self.extra_options["sourcemodule_kwargs"] def _get_dimensions(self, shape, kernel): self.shape = shape self.data_ndim = self._check_dimensions(shape=shape, name="Data") self.kernel_ndim = self._check_dimensions(arr=kernel, name="Kernel") Nx = shape[-1] if self.data_ndim >= 2: Ny = shape[-2] else: Ny = 1 if self.data_ndim >= 3: Nz = shape[-3] else: Nz = 1 self.Nx = np.int32(Nx) self.Ny = np.int32(Ny) self.Nz = np.int32(Nz) def _determine_use_case(self, shape, kernel, axes): """ Determine the convolution use case from the input/kernel shape, and axes. """ self._get_dimensions(shape, kernel) if self.kernel_ndim > self.data_ndim: raise ValueError("Kernel dimensions cannot exceed data dimensions") data_ndim = self.data_ndim kernel_ndim = self.kernel_ndim self.kernel = kernel.astype("f") convol_infos = ConvolutionInfos() k = (data_ndim, kernel_ndim) if k not in convol_infos.use_cases: raise ValueError( "Cannot find a use case for data ndim = %d and kernel ndim = %d" % (data_ndim, kernel_ndim) ) possible_use_cases = convol_infos.use_cases[k] # If some dimensions are "flat", make a batched convolution along them # Ex. data_dim = (1, Nx) -> batched 1D convolution if self.extra_options["batch_along_flat_dims"] and (1 in self.shape): axes = tuple([curr_dim for numels, curr_dim in zip(self.shape, range(len(self.shape))) if numels != 1]) # self.use_case_name = None for uc_name, uc_params in possible_use_cases.items(): if axes in convol_infos.allowed_axes[uc_name]: self.use_case_name = uc_name self.use_case_desc = uc_params["name"] self.use_case_kernels = uc_params["kernels"].copy() if self.use_case_name is None: raise ValueError( "Cannot find a use case for data ndim = %d, kernel ndim = %d and axes=%s" % (data_ndim, kernel_ndim, str(axes)) ) # TODO implement this use case if self.use_case_name == "batched_separable_2D_1D_3D": raise NotImplementedError("The use case %s is not implemented" % self.use_case_name) # self.axes = axes # Replace "axes=None" with an actual value (except for ND-ND) allowed_axes = convol_infos.allowed_axes[self.use_case_name] if len(allowed_axes) > 1: # The default choice might impact perfs self.axes = allowed_axes[0] or allowed_axes[1] self.separable = self.use_case_name.startswith("separable") self.batched = self.use_case_name.startswith("batched") def _allocate_memory(self, mode): self.mode = mode or "reflect" # The current implementation does not support kernel size bigger than data size, # except for mode="nearest" for i, dim_size in enumerate(self.shape): if min(self.kernel.shape) > dim_size and i in self.axes: print( "Warning: kernel support is too large for data dimension %d (%d). 
Forcing convolution mode to 'nearest'" % (i, dim_size) ) self.mode = "nearest" # option_array_names = { "allocate_input_array": "data_in", "allocate_output_array": "data_out", "allocate_tmp_array": "data_tmp", } # Nonseparable transforms do not need tmp array if not (self.separable): self.extra_options["allocate_tmp_array"] = False # Allocate arrays for option_name, array_name in option_array_names.items(): if self.extra_options[option_name]: value = self.cuda.allocate_array("value", self.shape, np.float32) else: value = None setattr(self, array_name, value) if isinstance(self.kernel, np.ndarray): self.d_kernel = self.cuda.to_device("d_kernel", self.kernel) else: if not (isinstance(self.kernel, self.cuda.array_class)): raise ValueError("kernel must be either numpy array or pycuda array") self.d_kernel = self.kernel self._old_input_ref = None self._old_output_ref = None self._c_modes_mapping = { "periodic": 2, "wrap": 2, "nearest": 1, "replicate": 1, "reflect": 0, "constant": 3, } mp = self._c_modes_mapping if self.mode.lower() not in mp: raise ValueError( """ Mode %s is not available. Available modes are: %s """ % (self.mode, str(mp.keys())) ) if self.mode.lower() == "constant": raise NotImplementedError("mode='constant' is not implemented yet") self._c_conv_mode = mp[self.mode] def _init_kernels(self): if self.kernel_ndim > 1: if np.abs(np.diff(self.kernel.shape)).max() > 0: raise NotImplementedError("Non-separable convolution with non-square kernels is not implemented yet") # Compile source module compile_options = [str("-DUSED_CONV_MODE=%d" % self._c_conv_mode)] fname = get_cuda_srcfile("convolution.cu") nabu_cuda_dir = dirname(fname) include_dirs = [nabu_cuda_dir] self.sourcemodule_kwargs["options"] = compile_options self.sourcemodule_kwargs["include_dirs"] = include_dirs with open(fname) as fid: cuda_src = fid.read() self._module = SourceModule(cuda_src, **self.sourcemodule_kwargs) # pylint: disable=E0606 # Blocks, grid self._block_size = {1: (32, 1, 1), 2: (32, 32, 1), 3: (16, 8, 8)}[self.data_ndim] # TODO tune self._n_blocks = tuple([int(updiv(a, b)) for a, b in zip(self.shape[::-1], self._block_size)]) # Prepare cuda kernel calls self._cudakernel_signature = { 1: "PPPiiii", 2: "PPPiiiii", 3: "PPPiiiiii", }[self.kernel_ndim] self.cuda_kernels = {} for axis, kern_name in enumerate(self.use_case_kernels): self.cuda_kernels[axis] = self._module.get_function(kern_name) self.cuda_kernels[axis].prepare(self._cudakernel_signature) # Cuda kernel arguments kernel_args = [ self._n_blocks, self._block_size, None, None, self.d_kernel.gpudata, np.int32(self.kernel.shape[0]), self.Nx, self.Ny, self.Nz, ] if self.kernel_ndim == 2: kernel_args.insert(5, np.int32(self.kernel.shape[1])) if self.kernel_ndim == 3: kernel_args.insert(5, np.int32(self.kernel.shape[2])) kernel_args.insert(6, np.int32(self.kernel.shape[1])) self.kernel_args = tuple(kernel_args) # If self.data_tmp is allocated, separable transforms can be performed # by a series of batched transforms, without any copy, by swapping refs. self.swap_pattern = None if self.separable: if self.data_tmp is not None: self.swap_pattern = { 2: [("data_in", "data_tmp"), ("data_tmp", "data_out")], 3: [ ("data_in", "data_out"), ("data_out", "data_tmp"), ("data_tmp", "data_out"), ], } else: raise NotImplementedError("For now, data_tmp has to be allocated") def _get_swapped_arrays(self, i): """ Get the input and output arrays to use when using a "swap pattern". Swapping refs enables to avoid copies between temp. array and output. 
For example, a separable 2D->1D convolution on 2D data reads: data_tmp = convol(data_input, kernel, axis=1) # step i=0 data_out = convol(data_tmp, kernel, axis=0) # step i=1 :param i: current step number of the separable convolution """ n_batchs = len(self.axes) in_ref, out_ref = self.swap_pattern[n_batchs][i] d_in = getattr(self, in_ref) d_out = getattr(self, out_ref) return d_in, d_out def _configure_kernel_args(self, cuda_kernel_args, input_ref, output_ref): # TODO more elegant if isinstance(input_ref, self.cuda.array_class): input_ref = input_ref.gpudata if isinstance(output_ref, self.cuda.array_class): output_ref = output_ref.gpudata if input_ref is not None or output_ref is not None: cuda_kernel_args = list(cuda_kernel_args) if input_ref is not None: cuda_kernel_args[2] = input_ref if output_ref is not None: cuda_kernel_args[3] = output_ref cuda_kernel_args = tuple(cuda_kernel_args) return cuda_kernel_args @staticmethod def _check_dimensions(arr=None, shape=None, name="", dim_min=1, dim_max=3): if shape is not None: ndim = len(shape) elif arr is not None: ndim = arr.ndim else: raise ValueError("Please provide either arr= or shape=") if ndim < dim_min or ndim > dim_max: raise ValueError("%s dimensions should be between %d and %d" % (name, dim_min, dim_max)) return ndim def _check_array(self, arr): if not (isinstance(arr, self.cuda.array_class) or isinstance(arr, np.ndarray)): raise TypeError("Expected either pycuda.gpuarray or numpy.ndarray") if arr.dtype != np.float32: raise TypeError("Data must be float32") if arr.shape != self.shape: raise ValueError("Expected data shape = %s" % str(self.shape)) def _set_arrays(self, array, output=None): # Either copy H->D or update references. if isinstance(array, np.ndarray): self.data_in[:] = array[:] else: self._old_input_ref = self.data_in self.data_in = array data_in_ref = self.data_in if output is not None: if not (isinstance(output, np.ndarray)): self._old_output_ref = self.data_out self.data_out = output # Update Cuda kernel arguments with new array references self.kernel_args = self._configure_kernel_args(self.kernel_args, data_in_ref, self.data_out) def _separable_convolution(self): assert len(self.axes) == len(self.use_case_kernels) # Separable: one kernel call per data dimension for i, axis in enumerate(self.axes): in_ref, out_ref = self._get_swapped_arrays(i) self._batched_convolution(axis, input_ref=in_ref, output_ref=out_ref) def _batched_convolution(self, axis, input_ref=None, output_ref=None): # Batched: one kernel call in total cuda_kernel = self.cuda_kernels[axis] cuda_kernel_args = self._configure_kernel_args(self.kernel_args, input_ref, output_ref) ev = cuda_kernel.prepared_call(*cuda_kernel_args) def _nd_convolution(self): assert len(self.use_case_kernels) == 1 cuda_kernel = self._module.get_function(self.use_case_kernels[0]) ev = cuda_kernel.prepared_call(*self.kernel_args) def _recover_arrays_references(self): if self._old_input_ref is not None: self.data_in = self._old_input_ref self._old_input_ref = None if self._old_output_ref is not None: self.data_out = self._old_output_ref self._old_output_ref = None self.kernel_args = self._configure_kernel_args(self.kernel_args, self.data_in, self.data_out) def _get_output(self, output): if output is None: res = self.data_out.get() else: res = output if isinstance(output, np.ndarray): output[:] = self.data_out[:] self._recover_arrays_references() return res def convolve(self, array, output=None): """ Convolve an array with the class kernel. :param array: Input array. 
Can be numpy.ndarray or pycuda.gpuarray.GPUArray. :param output: Output array. Can be numpy.ndarray or pycuda.gpuarray.GPUArray. """ self._check_array(array) self._set_arrays(array, output=output) if self.axes is not None: if self.separable: self._separable_convolution() elif self.batched: assert len(self.axes) == 1 self._batched_convolution(self.axes[0]) # else: ND-ND convol else: # ND-ND convol self._nd_convolution() res = self._get_output(output) return res __call__ = convolve ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/processing/fft_base.py0000644000175000017500000001303014550227307017513 0ustar00pierrepierreimport numpy as np from ..utils import BaseClassError class _BaseFFT: """ A base class for FFTs. """ implem = "none" ProcessingCls = BaseClassError def __init__(self, shape, dtype, r2c=True, axes=None, normalize="rescale", **backend_options): """ Base class for Fast Fourier Transform (FFT). Parameters ---------- shape: list of int Shape of the input data dtype: str or numpy.dtype Data type of the input data r2c: bool, optional Whether to use real-to-complex transform for real-valued input. Default is True. axes: list of int, optional Axes along which FFT is computed. * For 2D transform: axes=(1,0) * For batched 1D transform of 2D image: axes=(-1,) normalize: str, optional Whether to normalize FFT and IFFT. Possible values are: * "rescale": in this case, Fourier data is divided by "N" before IFFT, so that IFFT(FFT(data)) = data. This corresponds to numpy norm=None, i.e. norm="backward". * "ortho": in this case, FFT and IFFT are adjoints of each other, the transform is unitary. Both FFT and IFFT are scaled with 1/sqrt(N). * "none": no normalization is done: IFFT(FFT(data)) = data*N Other parameters ----------------- backend_options: dict, optional Parameters to pass to CudaProcessing or OpenCLProcessing class. 
""" self._init_backend(backend_options) self._set_dtypes(dtype, r2c) self._set_shape_and_axes(shape, axes) self._configure_batched_transform() self._configure_normalization(normalize) self._compute_fft_plans() def _init_backend(self, backend_options): self.processing = self.ProcessingCls(**backend_options) def _set_dtypes(self, dtype, r2c): self.dtype = np.dtype(dtype) dtypes_mapping = { np.dtype("float32"): np.complex64, np.dtype("float64"): np.complex128, np.dtype("complex64"): np.complex64, np.dtype("complex128"): np.complex128, } if self.dtype not in dtypes_mapping: raise ValueError("Invalid input data type: got %s" % self.dtype) self.dtype_out = dtypes_mapping[self.dtype] self.r2c = r2c def _set_shape_and_axes(self, shape, axes): # Input shape if np.isscalar(shape): shape = (shape,) self.shape = shape # Axes default_axes = tuple(range(len(self.shape))) if axes is None: self.axes = default_axes else: self.axes = tuple(np.array(default_axes)[np.array(axes)]) # Output shape shape_out = self.shape if self.r2c: reduced_dim = self.axes[-1] if self.axes is not None else -1 shape_out = list(shape_out) shape_out[reduced_dim] = shape_out[reduced_dim] // 2 + 1 shape_out = tuple(shape_out) self.shape_out = shape_out def _configure_batched_transform(self): pass def _configure_normalization(self, normalize): pass def _compute_fft_plans(self): pass class _BaseVKFFT(_BaseFFT): """ FFT using VKFFT backend """ implem = "vkfft" backend = "none" ProcessingCls = BaseClassError vkffs_cls = BaseClassError def _configure_batched_transform(self): if self.axes is not None and len(self.shape) == len(self.axes): self.axes = None return if self.r2c: # batched Real-to-complex transforms are supported only along fast axes if not (is_fast_axes(len(self.shape), self.axes)): raise ValueError("For %dD R2C, only batched transforms along fast axes are allowed" % (len(self.shape))) self._vkfft_ndim = len(self.axes) self.axes = None # vkfft still can do a batched transform by providing dim=XX, axes=None def _configure_normalization(self, normalize): self.normalize = normalize self._vkfft_norm = { "rescale": 1, "backward": 1, "ortho": "ortho", "none": 0, }.get(self.normalize, 1) def _set_shape_and_axes(self, shape, axes): super()._set_shape_and_axes(shape, axes) self._vkfft_ndim = None def _compute_fft_plans(self): self._vkfft_plan = self.vkffs_cls( self.shape, self.dtype, ndim=self._vkfft_ndim, inplace=False, norm=self._vkfft_norm, r2c=self.r2c, dct=False, axes=self.axes, strides=None, **self._vkfft_other_init_kwargs, ) def fft(self, array, output=None): if output is None: output = self.output_fft = self.processing.allocate_array( "output_fft", self.shape_out, dtype=self.dtype_out ) return self._vkfft_plan.fft(array, dest=output) def ifft(self, array, output=None): if output is None: output = self.output_ifft = self.processing.allocate_array("output_ifft", self.shape, dtype=self.dtype) return self._vkfft_plan.ifft(array, dest=output) def is_fast_axes(ndim, axes): """ Return true if "axes" are the fast dimensions """ all_axes = list(range(ndim)) axes = sorted([ax + ndim if ax < 0 else ax for ax in axes]) # transform "-1" to an actual axis index (1 for 2D) return all_axes[-len(axes) :] == axes ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1730906677.0 nabu-2024.2.1/nabu/processing/fft_cuda.py0000644000175000017500000002133414712705065017525 0ustar00pierrepierreimport os import warnings from multiprocessing import get_context from multiprocessing.pool import Pool import numpy as np from 
..utils import check_supported from .fft_base import _BaseFFT, _BaseVKFFT try: from pyvkfft.cuda import VkFFTApp as vk_cufft __has_vkfft__ = True except (ImportError, OSError): __has_vkfft__ = False vk_cufft = None from ..cuda.processing import CudaProcessing Plan = None cu_fft = None cu_ifft = None __has_skcuda__ = None def init_skcuda(): # This needs to be done here, because scikit-cuda creates a Cuda context at import, # which can mess things up in some cases. # Ugly solution to an ugly problem. global __has_skcuda__, Plan, cu_fft, cu_ifft try: from skcuda.fft import Plan from skcuda.fft import fft as cu_fft from skcuda.fft import ifft as cu_ifft __has_skcuda__ = True except ImportError: __has_skcuda__ = False class SKCUFFT(_BaseFFT): implem = "skcuda" backend = "cuda" ProcessingCls = CudaProcessing def _configure_batched_transform(self): if __has_skcuda__ is None: init_skcuda() if not (__has_skcuda__): raise ImportError("Please install pycuda and scikit-cuda to use the CUDA back-end") self.cufft_batch_size = 1 self.cufft_shape = self.shape self._cufft_plan_kwargs = {} if (self.axes is not None) and (len(self.axes) < len(self.shape)): # In the easiest case, the transform is computed along the fastest dimensions: # - 1D transforms of lines of 2D data # - 2D transforms of images of 3D data (stacked along slow dim) # - 1D transforms of 3D data along fastest dim # Otherwise, we have to configure cuda "advanced memory layout". data_ndims = len(self.shape) if data_ndims == 2: n_y, n_x = self.shape along_fast_dim = self.axes[0] == 1 self.cufft_shape = n_x if along_fast_dim else n_y self.cufft_batch_size = n_y if along_fast_dim else n_x if not (along_fast_dim): # Batched vertical 1D FFT on 2D data need advanced data layout # http://docs.nvidia.com/cuda/cufft/#advanced-data-layout self._cufft_plan_kwargs = { "inembed": np.int32([0]), "istride": n_x, "idist": 1, "onembed": np.int32([0]), "ostride": n_x, "odist": 1, } if data_ndims == 3: # TODO/FIXME - the following work for C2C but not R2C ?! 
# fast_axes = [(1, 2), (2, 1), (2,)] fast_axes = [(2,)] if self.axes not in fast_axes: raise NotImplementedError( "With the CUDA backend, batched transform on 3D data is only supported along fastest dimensions" ) self.cufft_batch_size = self.shape[0] self.cufft_shape = self.shape[1:] if len(self.axes) == 1: # 1D transform on 3D data: here only supported along fast dim, so batch_size is Nx*Ny self.cufft_batch_size = np.prod(self.shape[:2]) self.cufft_shape = (self.shape[-1],) if len(self.cufft_shape) == 1: self.cufft_shape = self.cufft_shape[0] def _configure_normalization(self, normalize): self.normalize = normalize if self.normalize == "ortho": # TODO raise NotImplementedError("Normalization mode 'ortho' is not implemented with CUDA backend yet.") self.cufft_scale_inverse = self.normalize == "rescale" def _compute_fft_plans(self): self.plan_forward = Plan( # pylint: disable = E1102 self.cufft_shape, self.dtype, self.dtype_out, batch=self.cufft_batch_size, stream=self.processing.stream, **self._cufft_plan_kwargs, # cufft extensible plan API is only supported after 0.5.1 # (commit 65288d28ca0b93e1234133f8d460dc6becb65121) # but there is still no official 0.5.2 # ~ auto_allocate=True # cufft extensible plan API ) self.plan_inverse = Plan( # pylint: disable = E1102 self.cufft_shape, # not shape_out self.dtype_out, self.dtype, batch=self.cufft_batch_size, stream=self.processing.stream, **self._cufft_plan_kwargs, # cufft extensible plan API is only supported after 0.5.1 # (commit 65288d28ca0b93e1234133f8d460dc6becb65121) # but there is still no official 0.5.2 # ~ auto_allocate=True ) def fft(self, array, output=None): if output is None: output = self.output_fft = self.processing.allocate_array( "output_fft", self.shape_out, dtype=self.dtype_out ) cu_fft(array, output, self.plan_forward, scale=False) # pylint: disable = E1102 return output def ifft(self, array, output=None): if output is None: output = self.output_ifft = self.processing.allocate_array("output_ifft", self.shape, dtype=self.dtype) cu_ifft( # pylint: disable = E1102 array, output, self.plan_inverse, scale=self.cufft_scale_inverse, ) return output class VKCUFFT(_BaseVKFFT): """ Cuda FFT, using VKFFT backend """ implem = "vkfft" backend = "cuda" ProcessingCls = CudaProcessing vkffs_cls = vk_cufft def _init_backend(self, backend_options): super()._init_backend(backend_options) self._vkfft_other_init_kwargs = {"stream": self.processing.stream} def _has_vkfft(x): # should be run from within a Process try: from nabu.processing.fft_cuda import VKCUFFT, __has_vkfft__ if not __has_vkfft__: return False vk = VKCUFFT((16,), "f") avail = True except (ImportError, RuntimeError, OSError, NameError): avail = False return avail def has_vkfft(safe=True): """ Determine whether pyvkfft is available. For Cuda GPUs, vkfft relies on nvrtc which supports a narrow range of Cuda devices. Unfortunately, it's not possible to determine whether vkfft is available before creating a Cuda context. So we create a process (from scratch, i.e no fork), do the test within, and exit. This function cannot be tested from a notebook/console, a proper entry point has to be created (if __name__ == "__main__"). 
""" if not safe: return _has_vkfft(None) ctx = get_context("spawn") with Pool(1, context=ctx) as p: v = p.map(_has_vkfft, [1])[0] return v def _has_skfft(x): # should be run from within a Process try: from nabu.processing.fft_cuda import SKCUFFT sk = SKCUFFT((16,), "f") avail = True except (ImportError, RuntimeError, OSError, NameError): avail = False return avail def has_skcuda(safe=True): """ Determine whether scikit-cuda/CUFFT is available. Currently, scikit-cuda will create a Cuda context for Cublas, which can mess up the current execution. Do it in a separate thread. """ if not safe: return _has_skfft(None) ctx = get_context("spawn") with Pool(1, context=ctx) as p: v = p.map(_has_skfft, [1])[0] return v def get_fft_class(backend="vkfft"): backends = { "scikit-cuda": SKCUFFT, "skcuda": SKCUFFT, "cufft": SKCUFFT, "scikit": SKCUFFT, "vkfft": VKCUFFT, "pyvkfft": VKCUFFT, } def get_fft_cls(asked_fft_backend): asked_fft_backend = asked_fft_backend.lower() check_supported(asked_fft_backend, list(backends.keys()), "Cuda FFT backend name") return backends[asked_fft_backend] asked_fft_backend_env = os.environ.get("NABU_FFT_BACKEND", "") if asked_fft_backend_env != "": return get_fft_cls(asked_fft_backend_env) avail_fft_implems = get_available_fft_implems() if len(avail_fft_implems) == 0: raise RuntimeError("Could not any Cuda FFT implementation. Please install either scikit-cuda or pyvkfft") if backend not in avail_fft_implems: warnings.warn("Could not get FFT backend '%s'" % backend, RuntimeWarning) backend = avail_fft_implems[0] return get_fft_cls(backend) def get_available_fft_implems(): avail_implems = [] if has_vkfft(safe=True): avail_implems.append("vkfft") if has_skcuda(safe=True): avail_implems.append("skcuda") return avail_implems ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/processing/fft_opencl.py0000644000175000017500000000255214550227307020070 0ustar00pierrepierrefrom multiprocessing import get_context from multiprocessing.pool import Pool from .fft_base import _BaseVKFFT from ..opencl.processing import OpenCLProcessing try: from pyvkfft.opencl import VkFFTApp as vk_clfft __has_vkfft__ = True except (ImportError, OSError): __has_vkfft__ = False vk_clfft = None class VKCLFFT(_BaseVKFFT): """ OpenCL FFT, using VKFFT backend """ implem = "vkfft" backend = "opencl" ProcessingCls = OpenCLProcessing vkffs_cls = vk_clfft def _init_backend(self, backend_options): super()._init_backend(backend_options) self._vkfft_other_init_kwargs = {"queue": self.processing.queue} def _has_vkfft(x): # should be run from within a Process try: from nabu.processing.fft_opencl import VKCLFFT, __has_vkfft__ if not __has_vkfft__: return False vk = VKCLFFT((16,), "f") avail = True except (RuntimeError, OSError): avail = False return avail def has_vkfft(safe=True): """ Determine whether pyvkfft is available. This function cannot be tested from a notebook/console, a proper entry point has to be created (if __name__ == "__main__"). 
""" if not safe: return _has_vkfft(None) ctx = get_context("spawn") with Pool(1, context=ctx) as p: v = p.map(_has_vkfft, [1])[0] return v ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/processing/fftshift.py0000644000175000017500000001142014550227307017560 0ustar00pierrepierreimport numpy as np from ..utils import BaseClassError, get_opencl_srcfile, updiv from ..opencl.kernel import OpenCLKernel from ..opencl.processing import OpenCLProcessing from pyopencl.tools import dtype_to_ctype as cl_dtype_to_ctype class FFTshiftBase: KernelCls = BaseClassError ProcessingCls = BaseClassError dtype_to_ctype = BaseClassError backend = "none" def __init__(self, shape, dtype, dst_dtype=None, axes=None, **backend_options): """ Parameters ---------- shape: tuple Array shape - can be 1D or 2D. 3D is not supported. dtype: str or numpy.dtype Data type, eg. "f", numpy.complex64, ... dst_dtype: str or numpy.dtype Output data type. If not provided (default), the shift is done in-place. axes: tuple, optional Axes over which to shift. Default is None, which shifts all axes. Other parameters ---------------- backend_options: named arguments to pass to CudaProcessing or OpenCLProcessing """ # if axes not in [1, (1,), (-1,)]: raise NotImplementedError # self.processing = self.ProcessingCls(**backend_options) self.shape = shape if len(self.shape) not in [1, 2]: raise ValueError("Expected 1D or 2D array") self.dtype = np.dtype(dtype) self.dst_dtype = dst_dtype if dst_dtype is None: self._configure_inplace_shift() else: self._configure_out_of_place_shift() self._configure_kenel_initialization() self._fftshift_kernel = self.KernelCls(*self._kernel_init_args, **self._kernel_init_kwargs) self._configure_kernel_call() def _configure_inplace_shift(self): self.inplace = True # in-place on odd-sized array is more difficult - see fftshift.cl if self.shape[-1] & 1: raise NotImplementedError # self._kernel_init_args = [ "fftshift_x_inplace", ] self._kernel_init_kwargs = { "options": [ "-DDTYPE=%s" % self.dtype_to_ctype(self.dtype), ], } def _configure_out_of_place_shift(self): self.inplace = False self._kernel_init_args = [ "fftshift_x", ] self._kernel_init_kwargs = { "options": [ "-DDTYPE=%s" % self.dtype_to_ctype(self.dtype), "-DDTYPE_OUT=%s" % self.dtype_to_ctype(np.dtype(self.dst_dtype)), ], } additional_flag = None input_is_complex = np.iscomplexobj(np.ones(1, dtype=self.dtype)) output_is_complex = np.iscomplexobj(np.ones(1, dtype=self.dst_dtype)) if not (input_is_complex) and output_is_complex: additional_flag = "-DCAST_TO_COMPLEX" if input_is_complex and not (output_is_complex): additional_flag = "-DCAST_TO_REAL" if additional_flag is not None: self._kernel_init_kwargs["options"].append(additional_flag) def _call_fftshift_inplace(self, arr, direction): self._fftshift_kernel( # pylint: disable=E1102 arr, np.int32(self.shape[1]), np.int32(self.shape[0]), np.int32(direction), **self._kernel_kwargs ) return arr def _call_fftshift_out_of_place(self, arr, dst, direction): if dst is None: dst = self.processing.allocate_array("dst", arr.shape, dtype=self.dst_dtype) self._fftshift_kernel( # pylint: disable=E1102 arr, dst, np.int32(self.shape[1]), np.int32(self.shape[0]), np.int32(direction), **self._kernel_kwargs ) return dst def fftshift(self, arr, dst=None): if self.inplace: return self._call_fftshift_inplace(arr, 1) else: return self._call_fftshift_out_of_place(arr, dst, 1) def ifftshift(self, arr, dst=None): if self.inplace: return 
self._call_fftshift_inplace(arr, -1) else: return self._call_fftshift_out_of_place(arr, dst, -1) class OpenCLFFTshift(FFTshiftBase): KernelCls = OpenCLKernel ProcessingCls = OpenCLProcessing dtype_to_ctype = cl_dtype_to_ctype backend = "opencl" def _configure_kenel_initialization(self): self._kernel_init_args.append(self.processing.ctx) self._kernel_init_kwargs.update( { "filename": get_opencl_srcfile("fftshift.cl"), "queue": self.processing.queue, } ) def _configure_kernel_call(self): # TODO in-place fftshift needs to launch only arr.size//2 threads block = (16, 16, 1) grid = [updiv(a, b) * b for a, b in zip(self.shape[::-1], block)] self._kernel_kwargs = {"global_size": grid, "local_size": block} ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/processing/histogram.py0000644000175000017500000002616214654107202017745 0ustar00pierrepierrefrom math import log2, ceil import numpy as np from silx.math import Histogramnd from tomoscan.io import HDF5File from ..utils import check_supported from ..resources.logger import LoggerOrPrint class PartialHistogram: """ A class for computing histogram progressively. In certain cases, it is cumbersome to compute a histogram directly on a big chunk of data (ex. data not fitting in memory, disk access too slow) while some parts of the data are readily available in-memory. """ histogram_methods = ["fixed_bins_width", "fixed_bins_number"] bin_width_policies = ["uint16"] backends = ["numpy", "silx"] def __init__(self, method="fixed_bins_width", bin_width="uint16", num_bins=None, min_bins=None, backend="silx"): """ Initialize a PartialHistogram class. Parameters ---------- method: str, optional Partial histogram computing method. Available are: - `fixed_bins_width`: all the histograms are computed with the same bin width. The class adapts to the data range and computes the number of bins accordingly. - `fixed_bins_number`: all the histograms are computed with the same number of bins. The class adapts to the data range and computes the bin width accordingly. Default is "fixed_bins_width" bin_width: str or float, optional Policy for histogram bins when method="fixed_bins_width". Available are: - "uint16": The bin width is computed so that floating-point elements `f1` and `f2` satisfying `|f1 - f2| < bin_width` implies `f1_converted - f2_converted < 1` once cast to uint16. - A number: all the bins have this fixed width. Default is "uint16" num_bins: int, optional Number of bins when method = 'fixed_bins_number'. min_bins: int, optional Minimum number of bins when method = 'fixed_bins_width'. backend: str, optional Which histogram backend to use for computations. Available are "silx", "numpy". Fastest is "silx". 
""" check_supported(method, self.histogram_methods, "histogram computing method") self.method = method check_supported(backend, self.backends, "histogram backend") self.backend = backend self._set_bin_width(bin_width) self._set_num_bins(num_bins) self.min_bins = min_bins self._set_histogram_methods() def _set_bin_width(self, bin_width): if self.method == "fixed_bins_number": self.bin_width = None return if isinstance(bin_width, str): check_supported(bin_width, self.bin_width_policies, "bin width policy") self._fixed_bw = False else: bin_width = float(bin_width) self._fixed_bw = True self.bin_width = bin_width def _set_num_bins(self, num_bins): if self.method == "fixed_bins_width": self.num_bins = None return if self.method == "fixed_bins_number" and num_bins is None: raise ValueError("Need to specify num_bins for method='fixed_bins_number'") self.num_bins = int(num_bins) def _set_histogram_methods(self): self._histogram_methods = { "fixed_bins_number": { "compute": self._compute_histogram_fixed_nbins, "merge": self._merge_histograms_fixed_nbins, }, "fixed_bins_width": { "compute": self._compute_histogram_fixed_bw, "merge": self._merge_histograms_fixed_bw, }, } assert set(self._histogram_methods.keys()) == set(self.histogram_methods) @staticmethod def _get_histograms_and_bins(histograms, center=False, dont_truncate_bins=False): histos = [h[0] for h in histograms] if dont_truncate_bins: bins = [h[1] for h in histograms] else: if center: bins = [0.5 * (h[1][1:] + h[1][:-1]) for h in histograms] else: bins = [h[1][:-1] for h in histograms] return histos, bins # # Histogram with fixed number of bins # def _compute_histogram_fixed_nbins(self, data, data_range=None): if data.ndim > 1: data = data.ravel() dmin, dmax = data.min(), data.max() if data_range is None else data_range if self.backend == "numpy": res = np.histogram(data, bins=self.num_bins) elif self.backend == "silx": histogrammer = Histogramnd(data, n_bins=self.num_bins, histo_range=(dmin, dmax), last_bin_closed=True) res = histogrammer.histo, histogrammer.edges[0] # pylint: disable=E1136 else: raise ValueError("Unknown backend") return res def _merge_histograms_fixed_nbins(self, histograms, dont_truncate_bins=False): histos, bins = self._get_histograms_and_bins(histograms, dont_truncate_bins=dont_truncate_bins) res = np.histogram( np.hstack(bins), weights=np.hstack(histos), bins=self.num_bins, ) return res # # Histogram with fixed bin width # def _bin_width_u16(self, dmin, dmax): return (dmax - dmin) / 65535.0 def _bin_width_fixed(self, dmin, dmax): return self.bin_width def get_bin_width(self, dmin, dmax): if self._fixed_bw: return self._bin_width_fixed(dmin, dmax) elif self.bin_width == "uint16": return self._bin_width_u16(dmin, dmax) else: raise ValueError() def _compute_histogram_fixed_bw(self, data, data_range=None): dmin, dmax = data.min(), data.max() if data_range is None else data_range min_bins = self.min_bins or 1 bw_max = self.get_bin_width(dmin, dmax) nbins = 0 bw_factor = 1 while nbins < min_bins: bw = 2 ** round(log2(bw_max)) / bw_factor nbins = int((dmax - dmin) / bw) bw_factor *= 2 res = np.histogram(data, bins=nbins) return res def _merge_histograms_fixed_bw(self, histograms, **kwargs): histos, bins = self._get_histograms_and_bins(histograms, center=False) dmax = max([b[-1] for b in bins]) dmin = min([b[0] for b in bins]) bw_max = max([b[1] - b[0] for b in bins]) res = np.histogram(np.hstack(bins), weights=np.hstack(histos), bins=int((dmax - dmin) / bw_max)) return res # # Dispatch methods # def compute_histogram(self, 
data, data_range=None): compute_hist_func = self._histogram_methods[self.method]["compute"] return compute_hist_func(data, data_range=data_range) def merge_histograms(self, histograms, **kwargs): merge_hist_func = self._histogram_methods[self.method]["merge"] return merge_hist_func(histograms, **kwargs) class VolumeHistogram: """ A class for computing the histogram of an entire volume. Unless explicitly specified, histogram is computed in several passes so that not all the volume is loaded in memory. """ def __init__(self, data_url, chunk_size_slices=100, chunk_size_GB=None, nbins=1e6, logger=None): """ Initialize a VolumeHistogram object. Parameters ---------- fname: DataUrl DataUrl to the HDF5 file. chunk_size_slices: int, optional Compute partial histograms of groups of slices. This is the default behavior, where the groups size is 100 slices. This parameter is mutually exclusive with 'chunk_size_GB'. chunk_size_GB: float, optional Maximum memory (in GB) to use when computing the histogram by group of slices. This parameter is mutually exclusive with 'chunk_size_slices'. nbins: int, optional Histogram number of bins. Default is 1e6. """ self.data_url = data_url self.logger = LoggerOrPrint(logger) self._get_data_info() self._set_chunk_size(chunk_size_slices, chunk_size_GB) self.nbins = int(nbins) self._init_histogrammer() def _get_data_info(self): self.fname = self.data_url.file_path() self.data_path = self.data_url.data_path() with HDF5File(self.fname, "r") as fid: try: data_ptr = fid[self.data_path] except KeyError: msg = str( "Could not access HDF5 path %s in file %s. Please check that this file \ actually contains a reconstruction and that the HDF5 path is correct" % (self.data_path, self.fname) ) self.logger.fatal(msg) raise ValueError(msg) if data_ptr.ndim != 3: msg = "Expected data to have 3 dimensions, got %d" % data_ptr.ndim raise ValueError(msg) self.data_shape = data_ptr.shape self.data_dtype = data_ptr.dtype self.data_nbytes_GB = np.prod(data_ptr.shape) * data_ptr.dtype.itemsize / 1e9 def _set_chunk_size(self, chunk_size_slices, chunk_size_GB): if not ((chunk_size_slices is not None) ^ (chunk_size_GB is not None)): raise ValueError("Please specify either chunk_size_slices or chunk_size_GB") if chunk_size_slices is None: chunk_size_slices = int(chunk_size_GB / (np.prod(self.data_shape[1:]) * self.data_dtype.itemsize / 1e9)) self.chunk_size = chunk_size_slices self.logger.debug("Computing histograms by groups of %d slices" % self.chunk_size) def _init_histogrammer(self): self.histogrammer = PartialHistogram(method="fixed_bins_number", num_bins=self.nbins) def _compute_histogram(self, data): return self.histogrammer.compute_histogram(data.ravel()) # 1D def compute_volume_histogram(self): n_z = self.data_shape[0] histograms = [] n_steps = ceil(n_z / self.chunk_size) with HDF5File(self.fname, "r") as fid: for chunk_id in range(n_steps): self.logger.debug("Computing histogram %d/%d" % (chunk_id + 1, n_steps)) z_slice = slice(chunk_id * self.chunk_size, (chunk_id + 1) * self.chunk_size) images_stack = fid[self.data_path][z_slice, :, :] hist = self._compute_histogram(images_stack) histograms.append(hist) res = self.histogrammer.merge_histograms(histograms) return res def hist_as_2Darray(hist, center=True, dtype="f"): hist, bins = hist if bins.size != hist.size: # assert bins.size == hist.size +1 if center: bins = 0.5 * (bins[1:] + bins[:-1]) else: bins = bins[:-1] res = np.zeros((2, hist.size), dtype=dtype) res[0] = hist res[1] = bins.astype(dtype) return res def 
add_last_bin(histo_bins): """ Add the last bin (max value) to a list of bin edges. """ res = np.zeros(histo_bins.size + 1, dtype=histo_bins.dtype) res[:-1] = histo_bins[:] res[-1] = res[-2] + (res[1] - res[0]) return res ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/processing/histogram_cuda.py0000644000175000017500000000641214654107202020735 0ustar00pierrepierreimport numpy as np from ..utils import get_cuda_srcfile, updiv from ..cuda.utils import __has_pycuda__ from .histogram import PartialHistogram, VolumeHistogram if __has_pycuda__: import pycuda.gpuarray as garray from ..cuda.processing import CudaProcessing class CudaPartialHistogram(PartialHistogram): def __init__( self, method="fixed_bins_number", bin_width="uint16", num_bins=None, min_bins=None, cuda_options=None, ): if method == "fixed_bins_width": raise NotImplementedError("Histogram with fixed bins width is not implemented with the Cuda backend") super().__init__( method=method, bin_width=bin_width, num_bins=num_bins, min_bins=min_bins, ) self.cuda_processing = CudaProcessing(**(cuda_options or {})) # pylint: disable=E0606 self._init_cuda_histogram() def _init_cuda_histogram(self): self.cuda_hist = self.cuda_processing.kernel( "histogram", filename=get_cuda_srcfile("histogram.cu"), signature="PiiiffPi", ) self.d_hist = self.cuda_processing.allocate_array("d_hist", self.num_bins, dtype=np.uint32) def _compute_histogram_fixed_nbins(self, data, data_range=None): if isinstance(data, np.ndarray): data = self.cuda_processing.to_device("data", data) if data_range is None: # Should be possible to do both in one single pass with ReductionKernel # and garray.vec.float2, but the last step in volatile shared memory # still gives errors. To be investigated... # pylint: disable=E0606 data_min = garray.min(data).get()[()] data_max = garray.max(data).get()[()] else: data_min, data_max = data_range Nz, Ny, Nx = data.shape block = (16, 16, 4) grid = ( updiv(Nx, block[0]), updiv(Ny, block[1]), updiv(Nz, block[2]), ) self.d_hist.fill(0) self.cuda_hist( data, Nx, Ny, Nz, data_min, data_max, self.d_hist, self.num_bins, grid=grid, block=block, ) # Return a result in the same format as numpy.histogram res_hist = self.d_hist.get() res_bins = np.linspace(data_min, data_max, num=self.num_bins + 1, endpoint=True) return res_hist, res_bins class CudaVolumeHistogram(VolumeHistogram): def __init__( self, data_url, chunk_size_slices=100, chunk_size_GB=None, nbins=1e6, logger=None, cuda_options=None, ): self.cuda_options = cuda_options super().__init__( data_url, chunk_size_slices=chunk_size_slices, chunk_size_GB=chunk_size_GB, nbins=nbins, logger=logger, ) def _init_histogrammer(self): self.histogrammer = CudaPartialHistogram( method="fixed_bins_number", num_bins=self.nbins, cuda_options=self.cuda_options, ) def _compute_histogram(self, data): return self.histogrammer.compute_histogram(data) # 3D ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/processing/kernel_base.py0000644000175000017500000001041614654107202020215 0ustar00pierrepierre""" Base class for CudaKernel and OpenCLKernel Should not be used directly """ from ..utils import updiv class KernelBase: """ A base class for OpenCL and Cuda kernels. Parameters ----------- kernel_name: str Name of the CUDA kernel. 
filename: str, optional Path to the file name containing kernels definitions src: str, optional Source code of kernels definitions automation_params: dict, optional Automation parameters, see below Automation parameters ---------------------- automation_params is a dictionary with the following keys and default values. guess_block: bool (True) If block is not specified during calls, choose a block size based on the size/dimensions of the first array. Mind that it is unlikely to be the optimal choice. guess_grid: bool (True): If the grid size is not specified during calls, choose a grid size based on the size of the first array. follow_device_ptr: bool (True) specify gpuarray.gpudata for all cuda GPUArrays (and pyopencl.array.data for pyopencl arrays). Otherwise, raise an error. """ _default_automation_params = { "guess_block": True, "guess_grid": True, "follow_device_ptr": True, } def __init__( self, kernel_name, filename=None, src=None, automation_params=None, silent_compilation_warnings=False, ): self.check_filename_src(filename, src) self.set_automation_params(automation_params) self.silent_compilation_warnings = silent_compilation_warnings def check_filename_src(self, filename, src): err_msg = "Please provide either filename or src" if filename is None and src is None: raise ValueError(err_msg) if filename is not None and src is not None: raise ValueError(err_msg) if filename is not None: with open(filename) as fid: src = fid.read() self.filename = filename self.src = src def set_automation_params(self, automation_params): self.automation_params = self._default_automation_params.copy() self.automation_params.update(automation_params or {}) @staticmethod def guess_grid_size(shape, block_size): # python: (z, y, x) -> cuda: (x, y, z) res = tuple(map(lambda x: updiv(x[0], x[1]), zip(shape[::-1], block_size))) if len(res) == 2: res += (1,) return res @staticmethod def guess_block_size(shape): """ Guess a block size based on the shape of an array. """ ndim = len(shape) if ndim == 1: return (128, 1, 1) if ndim == 2: return (32, 32, 1) else: return (16, 8, 8) def get_block_grid(self, *args, **kwargs): block = None grid = None if ("block" not in kwargs) or (kwargs["block"] is None): if self.automation_params["guess_block"]: block = self.guess_block_size(args[0].shape) else: raise ValueError("Please provide block size") else: block = kwargs["block"] if ("grid" not in kwargs) or (kwargs["grid"] is None): if self.automation_params["guess_grid"]: grid = self.guess_grid_size(args[0].shape, block) else: raise ValueError("Please provide block grid") else: grid = kwargs["grid"] self.last_block_size = block self.last_grid_size = grid return block, grid def follow_device_arr(self, args): raise ValueError("Base class") def _prepare_call(self, *args, **kwargs): block, grid = self.get_block_grid(*args, **kwargs) # pycuda crashes when any element of block/grid is not a python int (ex. numpy.int64). # A weird behavior once observed is "data.shape" returning (np.int64, int, int) (!). # Ensure that everything is a python integer. 
grid = tuple(int(x) for x in grid) if block is not None: block = tuple(int(x) for x in block) # args = self.follow_device_arr(args) return grid, block, args, kwargs ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/processing/medfilt_cuda.py0000644000175000017500000001307314550227307020371 0ustar00pierrepierrefrom os.path import dirname import numpy as np from pycuda.compiler import SourceModule from ..utils import updiv, get_cuda_srcfile from ..cuda.processing import CudaProcessing class MedianFilter: """ A class for performing median filter on GPU with CUDA """ def __init__( self, shape, footprint=(3, 3), mode="reflect", threshold=None, cuda_options=None, abs_diff=False, ): """Constructor of Cuda Median Filter. Parameters ---------- shape: tuple Shape of the array, in the format (n_rows, n_columns) footprint: tuple Size of the median filter, in the format (y, x). mode: str Boundary handling mode. Available modes are: - "reflect": cba|abcd|dcb - "nearest": aaa|abcd|ddd - "wrap": bcd|abcd|abc - "constant": 000|abcd|000 Default is "reflect". threshold: float, optional Threshold for the "thresholded median filter". A thresholded median filter only replaces a pixel value by the median if this pixel value is greater or equal than median + threshold. abs_diff: bool, optional Whether to perform conditional threshold as abs(value - median) Notes ------ Please refer to the documentation of the CudaProcessing class for the other parameters. """ self.cuda_processing = CudaProcessing(**(cuda_options or {})) self._set_params(shape, footprint, mode, threshold, abs_diff) self.cuda_processing.init_arrays_to_none(["d_input", "d_output"]) self._init_kernels() def _set_params(self, shape, footprint, mode, threshold, abs_diff): self.data_ndim = len(shape) if self.data_ndim == 2: ny, nx = shape nz = 1 elif self.data_ndim == 3: nz, ny, nx = shape else: raise ValueError("Expected 2D or 3D data") self.shape = shape self.Nx = np.int32(nx) self.Ny = np.int32(ny) self.Nz = np.int32(nz) if len(footprint) != 2: raise ValueError("3D median filter is not implemented yet") if not ((footprint[0] & 1) and (footprint[1] & 1)): raise ValueError("Must have odd-sized footprint") self.footprint = footprint self._set_boundary_mode(mode) self.do_threshold = False self.abs_diff = abs_diff if threshold is not None: self.threshold = np.float32(threshold) self.do_threshold = True else: self.threshold = np.float32(0) def _set_boundary_mode(self, mode): self.mode = mode # Some code duplication from convolution self._c_modes_mapping = { "periodic": 2, "wrap": 2, "nearest": 1, "replicate": 1, "reflect": 0, "constant": 3, } mp = self._c_modes_mapping if self.mode.lower() not in mp: raise ValueError( """ Mode %s is not available. 
Available modes are: %s """ % (self.mode, str(mp.keys())) ) if self.mode.lower() == "constant": raise NotImplementedError("mode='constant' is not implemented yet") self._c_conv_mode = mp[self.mode] def _init_kernels(self): # Compile source module compile_options = [ "-DUSED_CONV_MODE=%d" % self._c_conv_mode, "-DMEDFILT_X=%d" % self.footprint[1], "-DMEDFILT_Y=%d" % self.footprint[0], "-DDO_THRESHOLD=%d" % (int(self.do_threshold) + int(self.abs_diff)), ] fname = get_cuda_srcfile("medfilt.cu") nabu_cuda_dir = dirname(fname) include_dirs = [nabu_cuda_dir] self.sourcemodule_kwargs = {} self.sourcemodule_kwargs["options"] = compile_options self.sourcemodule_kwargs["include_dirs"] = include_dirs with open(fname) as fid: cuda_src = fid.read() self._module = SourceModule(cuda_src, **self.sourcemodule_kwargs) self.cuda_kernel_2d = self._module.get_function("medfilt2d") # Blocks, grid self._block_size = {2: (32, 32, 1), 3: (16, 8, 8)}[self.data_ndim] # TODO tune self._n_blocks = tuple([updiv(a, b) for a, b in zip(self.shape[::-1], self._block_size)]) def medfilt2(self, image, output=None): """ Perform a median filter on an image (or batch of images). Parameters ----------- images: numpy.ndarray or pycuda.gpuarray 2D image or 3D stack of 2D images output: numpy.ndarray or pycuda.gpuarray, optional Output of filtering. If provided, it must have the same shape as the input array. """ self.cuda_processing.set_array("d_input", image) if output is not None: self.cuda_processing.set_array("d_output", output) else: self.cuda_processing.allocate_array("d_output", self.shape) self.cuda_kernel_2d( self.cuda_processing.d_input, self.cuda_processing.d_output, self.Nx, self.Ny, self.Nz, self.threshold, grid=self._n_blocks, block=self._block_size, ) self.cuda_processing.recover_arrays_references(["d_input", "d_output"]) if output is None: return self.cuda_processing.d_output.get() else: return output ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/processing/muladd.py0000644000175000017500000000157714550227307017225 0ustar00pierrepierreimport numpy as np from .processing_base import ProcessingBase class MulAdd: processing_cls = ProcessingBase def __init__(self, **backend_options): self.processing = self.processing_cls(**(backend_options or {})) self._init_finalize() def _init_finalize(self): pass def mul_add(self, dst, other, fac_dst, fac_other, dst_region=None, other_region=None): if dst_region is None: dst_slice_y = dst_slice_x = slice(None, None) else: dst_slice_y, dst_slice_x = dst_region if other_region is None: other_slice_y = other_slice_x = slice(None, None) else: other_slice_y, other_slice_x = other_region dst[dst_slice_y, dst_slice_x] = ( fac_dst * dst[dst_slice_y, dst_slice_x] + fac_other * other[other_slice_y, other_slice_x] ) __call__ = mul_add ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/processing/muladd_cuda.py0000644000175000017500000000465514654107202020215 0ustar00pierrepierreimport numpy as np from nabu.utils import get_cuda_srcfile, updiv from .muladd import MulAdd from ..cuda.utils import __has_pycuda__ from ..cuda.processing import CudaProcessing if __has_pycuda__: import pycuda.gpuarray as garray class CudaMulAdd(MulAdd): processing_cls = CudaProcessing def _init_finalize(self): self._init_kernel() def _init_kernel(self): self.muladd_kernel = self.processing.kernel( "mul_add", filename=get_cuda_srcfile("ElementOp.cu"), # signature="PPiiffiiii" ) def 
mul_add(self, dst, other, fac_dst, fac_other, dst_region=None, other_region=None): """ 'region' should be a tuple (slice(y_start, y_end), slice(x_start, x_end)) """ if dst_region is None: dst_coords = (0, dst.shape[1], 0, dst.shape[0]) else: dst_coords = (dst_region[1].start, dst_region[1].stop, dst_region[0].start, dst_region[0].stop) if other_region is None: other_coords = (0, other.shape[1], 0, other.shape[0]) else: other_coords = (other_region[1].start, other_region[1].stop, other_region[0].start, other_region[0].stop) delta_x = np.diff(dst_coords[:2]) delta_y = np.diff(dst_coords[2:]) if delta_x != np.diff(other_coords[:2]) or delta_y != np.diff(other_coords[2:]): raise ValueError("Invalid dst_region and other_region provided. Regions must have the same size") if delta_x == 0 or delta_y == 0: raise ValueError("delta_x or delta_y is 0") # can't use "int4" in pycuda ? int2 seems fine. Go figure # pylint: disable=E0606 dst_x_range = np.array(dst_coords[:2], dtype=garray.vec.int2) dst_y_range = np.array(dst_coords[2:], dtype=garray.vec.int2) other_x_range = np.array(other_coords[:2], dtype=garray.vec.int2) other_y_range = np.array(other_coords[2:], dtype=garray.vec.int2) block = (32, 32, 1) grid = [updiv(length, b) for (length, b) in zip((delta_x[0], delta_y[0]), block)] self.muladd_kernel( dst, other, np.int32(dst.shape[1]), np.int32(other.shape[1]), np.float32(fac_dst), np.float32(fac_other), dst_x_range, dst_y_range, other_x_range, other_y_range, grid=grid, block=block, ) __call__ = mul_add ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/processing/padding_base.py0000644000175000017500000000547214550227307020355 0ustar00pierrepierreimport numpy as np from ..utils import check_supported class PaddingBase: """ A class for performing padding based on coordinate transform. The Cuda and OpenCL backends will subclass this class. """ supported_modes = ["constant", "edge", "reflect", "symmetric", "wrap"] def __init__(self, shape, pad_width, mode="constant", **kwargs): """ Initialize a Padding object. Parameters ---------- shape: tuple Image shape pad_width: tuple Padding width for each axis. Please see the documentation of numpy.pad(). mode: str Padding mode Other parameters ---------------- constant_values: tuple Tuple containing the values to fill when mode="constant" (as in numpy.pad) """ if len(shape) != 2: raise ValueError("This class only works on images") self.shape = shape self._set_mode(mode, **kwargs) self._get_padding_arrays(pad_width) def _set_mode(self, mode, **kwargs): # COMPAT. if mode == "edges": mode = "edge" # check_supported(mode, self.supported_modes, "padding mode") self.mode = mode self._kwargs = kwargs def _get_padding_arrays(self, pad_width): self.pad_width = pad_width if isinstance(pad_width, tuple) and isinstance(pad_width[0], np.ndarray): # user-defined coordinate transform err_msg = "pad_width must be either a scalar, a tuple in the form ((a, b), (c, d)), or a tuple of two one-dimensional numpy arrays (eg. 
use numpy.indices(..., sparse=True))" if len(pad_width) != 2: raise ValueError(err_msg) if any([np.squeeze(pw).ndim > 1 for pw in pad_width]): raise ValueError(err_msg) if self.mode == "constant": raise ValueError("Custom coordinate transform does not work with mode='constant'") self.mode = "custom" self.coords_rows, self.coords_cols = pad_width else: if self.mode == "constant": # no need for coordinate transform here constant_values = self._kwargs.get("constant_values", 0) self.padded_array_constant = np.pad( np.zeros(self.shape, dtype="f"), self.pad_width, mode="constant", constant_values=constant_values ) self.padded_shape = self.padded_array_constant.shape return R, C = np.indices(self.shape, dtype=np.int32, sparse=True) self.coords_rows = np.pad(R.ravel(), self.pad_width[0], mode=self.mode) self.coords_cols = np.pad(C.ravel(), self.pad_width[1], mode=self.mode) self.padded_shape = (self.coords_rows.size, self.coords_cols.size) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/processing/padding_cuda.py0000644000175000017500000000520114550227307020345 0ustar00pierrepierreimport numpy as np from ..utils import get_cuda_srcfile, updiv from ..cuda.processing import CudaProcessing from ..cuda.utils import __has_pycuda__ from .padding_base import PaddingBase class CudaPadding(PaddingBase): """ A class for performing padding on GPU using Cuda """ backend = "cuda" # TODO docstring from base class def __init__(self, shape, pad_width, mode="constant", cuda_options=None, **kwargs): super().__init__(shape, pad_width, mode=mode, **kwargs) self.cuda_processing = self.processing = CudaProcessing(**(cuda_options or {})) self._init_cuda_coordinate_transform() def _init_cuda_coordinate_transform(self): if self.mode == "constant": self.d_padded_array_constant = self.processing.to_device( "d_padded_array_constant", self.padded_array_constant ) return self._coords_transform_kernel = self.processing.kernel( "coordinate_transform", filename=get_cuda_srcfile("padding.cu"), signature="PPPPiii", ) self._coords_transform_block = (32, 32, 1) self._coords_transform_grid = [ updiv(a, b) for a, b in zip(self.padded_shape[::-1], self._coords_transform_block) ] self.d_coords_rows = self.processing.to_device("d_coords_rows", self.coords_rows) self.d_coords_cols = self.processing.to_device("d_coords_cols", self.coords_cols) def _pad_constant(self, image, output): pad_y, pad_x = self.pad_width self.d_padded_array_constant[pad_y[0] : pad_y[0] + self.shape[0], pad_x[0] : pad_x[0] + self.shape[1]] = image[ : ] output[:] = self.d_padded_array_constant[:] return output def pad(self, image, output=None): """ Pad an array. Parameters ---------- image: pycuda.gpuarray.GPUArray Image to pad output: pycuda.gpuarray.GPUArray, optional Output image. If provided, must be in the expected shape. 
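        Examples
        --------
        Illustrative sketch only (assumes a working CUDA context; not taken from the nabu documentation):

        >>> padding = CudaPadding((512, 512), ((16, 16), (16, 16)), mode="reflect")
        >>> d_img = padding.processing.allocate_array("d_img", (512, 512), dtype="f")
        >>> d_padded = padding.pad(d_img)  # device array of shape (544, 544)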
""" if output is None: output = self.processing.allocate_array("d_output", self.padded_shape) if self.mode == "constant": return self._pad_constant(image, output) self._coords_transform_kernel( image, output, self.d_coords_cols, self.d_coords_rows, np.int32(self.shape[1]), np.int32(self.padded_shape[1]), np.int32(self.padded_shape[0]), grid=self._coords_transform_grid, block=self._coords_transform_block, ) return output __call__ = pad ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/processing/padding_opencl.py0000644000175000017500000000600214654107202020705 0ustar00pierrepierreimport numpy as np from ..utils import get_opencl_srcfile from ..opencl.processing import OpenCLProcessing from .padding_base import PaddingBase from ..opencl.utils import __has_pyopencl__ if __has_pyopencl__: from ..opencl.memcpy import OpenCLMemcpy2D class OpenCLPadding(PaddingBase): """ A class for performing padding on GPU using OpenCL """ backend = "opencl" # TODO docstring from base class def __init__(self, shape, pad_width, mode="constant", opencl_options=None, **kwargs): super().__init__(shape, pad_width, mode=mode, **kwargs) self.opencl_processing = self.processing = OpenCLProcessing(**(opencl_options or {})) self.queue = self.opencl_processing.queue self._init_opencl_coordinate_transform() def _init_opencl_coordinate_transform(self): if self.mode == "constant": self.d_padded_array_constant = self.processing.to_device( "d_padded_array_constant", self.padded_array_constant ) self.memcpy2D = OpenCLMemcpy2D(ctx=self.processing.ctx, queue=self.queue) # pylint: disable=E0606 return self._coords_transform_kernel = self.processing.kernel( "coordinate_transform", filename=get_opencl_srcfile("padding.cl"), ) self._coords_transform_global_size = self.padded_shape[::-1] self.d_coords_rows = self.processing.to_device("d_coords_rows", self.coords_rows) self.d_coords_cols = self.processing.to_device("d_coords_cols", self.coords_cols) def _pad_constant(self, image, output): pad_y, pad_x = self.pad_width # the following line is not implemented in pyopencl # self.d_padded_array_constant[pad_y[0] : pad_y[0] + self.shape[0], pad_x[0] : pad_x[0] + self.shape[1]] = image[:] # cl.enqueue_copy is too cumbersome to use for Buffer <-> Buffer. # Use a dedicated kernel instead. # This is not optimal (two copies) - TODO write a constant padding kernel self.memcpy2D(self.d_padded_array_constant, image, image.shape[::-1], dst_offset_xy=(pad_x[0], pad_y[0])) output[:] = self.d_padded_array_constant[:] return output def pad(self, image, output=None): """ Pad an array. Parameters ---------- image: pyopencl array Image to pad output: pyopencl array Output image. If provided, must be in the expected shape. 
""" if output is None: output = self.processing.allocate_array("d_output", self.padded_shape) if self.mode == "constant": return self._pad_constant(image, output) self._coords_transform_kernel( self.queue, image, output, self.d_coords_cols, self.d_coords_rows, np.int32(self.shape[1]), np.int32(self.padded_shape[1]), np.int32(self.padded_shape[0]), global_size=self._coords_transform_global_size, ) return output ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1723477091.0 nabu-2024.2.1/nabu/processing/processing_base.py0000644000175000017500000001047414656426143021127 0ustar00pierrepierreimport numpy as np from ..utils import BaseClassError """ Base class for OpenCLProcessing and CudaProcessing Should not be used directly """ class ProcessingBase: array_class = None dtype_to_ctype = BaseClassError def __init__(self): self._allocated = {} def init_arrays_to_none(self, arrays_names): """ Initialize arrays to None. After calling this method, the current instance will have self.array_name = None, and self._old_array_name = None. Parameters ---------- arrays_names: list of str List of arrays names. """ for array_name in arrays_names: setattr(self, array_name, None) setattr(self, "_old_" + array_name, None) self._allocated[array_name] = False def recover_arrays_references(self, arrays_names): """ Performs self._array_name = self._old_array_name, for each array_name in arrays_names. Parameters ---------- arrays_names: list of str List of array names """ for array_name in arrays_names: old_arr = getattr(self, "_old_" + array_name, None) if old_arr is not None: setattr(self, array_name, old_arr) def _allocate_array_mem(self, shape, dtype): raise ValueError("Base class") def allocate_array(self, array_name, shape, dtype=np.float32): """ Allocate a GPU array on the current context/stream/device, and set 'self.array_name' to this array. Parameters ---------- array_name: str Name of the array (for book-keeping) shape: tuple of int Array shape dtype: numpy.dtype, optional Data type. Default is float32. """ if not self._allocated.get(array_name, False): new_device_arr = self._allocate_array_mem(shape, dtype) setattr(self, array_name, new_device_arr) self._allocated[array_name] = True return getattr(self, array_name) def set_array(self, array_name, array_ref, dtype=np.float32): """ Set the content of a device array. Parameters ---------- array_name: str Array name. This method will look for self.array_name. array_ref: array (numpy or GPU array) Array containing the data to copy to 'array_name'. dtype: numpy.dtype, optional Data type. Default is float32. """ if isinstance(array_ref, self.array_class): current_arr = getattr(self, array_name, None) setattr(self, "_old_" + array_name, current_arr) setattr(self, array_name, array_ref) elif isinstance(array_ref, np.ndarray): self.allocate_array(array_name, array_ref.shape, dtype=dtype) getattr(self, array_name).set(array_ref) else: raise ValueError("Expected numpy array or pycuda array") return getattr(self, array_name) def get_array(self, array_name): return getattr(self, array_name, None) # COMPAT. 
_init_arrays_to_none = init_arrays_to_none _recover_arrays_references = recover_arrays_references _allocate_array = allocate_array _set_array = set_array def check_array(self, arr, expected_shape, expected_dtype="f", check_contiguous=True): """ Check whether a given array is suitable for being processed (shape, dtype, contiguous) """ if arr.shape != expected_shape: raise ValueError("Expected shape %s but got %s" % (str(expected_shape), str(arr.shape))) if arr.dtype != np.dtype(expected_dtype): raise ValueError("Expected data type %s but got %s" % (str(expected_dtype), str(arr.dtype))) if check_contiguous: if isinstance(arr, np.ndarray) and not (arr.flags["C_CONTIGUOUS"]): raise ValueError("Expected C-contiguous array") if isinstance(arr, self.array_class) and not arr.flags.c_contiguous: raise ValueError("Expected C-contiguous array") def kernel(self, *args, **kwargs): raise ValueError("Base class") def to_device(self, array_name, array): arr_ref = self.allocate_array(array_name, array.shape, dtype=array.dtype) arr_ref.set(array) return arr_ref ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/processing/roll_opencl.py0000644000175000017500000000435314654107202020256 0ustar00pierrepierre# # WIP ! # # pylint: skip-file import numpy as np from ..opencl.utils import __has_pyopencl__ from ..utils import get_opencl_srcfile if __has_pyopencl__: import pyopencl as cl from ..opencl.processing import OpenCLProcessing from ..opencl.kernel import OpenCLKernel from pyopencl.tools import dtype_to_ctype as cl_dtype_to_ctype class OpenCLRoll: def __init__(self, dtype, direction=1, offset=None, **processing_kwargs): self.processing = OpenCLProcessing(queue=processing_kwargs.get("queue", None)) self.dtype = np.dtype(dtype) compile_options = ["-DDTYPE=%s" % cl_dtype_to_ctype(self.dtype)] self.offset = offset or 0 self.roll_kernel = OpenCLKernel( "roll_forward_x", None, queue=self.processing.queue, filename=get_opencl_srcfile("roll.cl"), options=compile_options, ) self.shmem = cl.LocalMemory(self.dtype.itemsize) self.direction = direction if self.direction < 0: self.revert_kernel = OpenCLKernel( "revert_array_x", None, queue=self.processing.queue, filename=get_opencl_srcfile("roll.cl"), options=compile_options, ) def __call__(self, arr): ny, nx = arr.shape # Launch one big horizontal workgroup wg_x = min((nx - self.offset) // 2, self.processing.queue.device.max_work_group_size) local_size = (wg_x, 1, 1) global_size = [wg_x, ny] if self.direction < 0: local_size2 = None global_size2 = [nx - self.offset, ny] self.revert_kernel( arr, np.int32(nx), np.int32(ny), np.int32(self.offset), local_size=local_size2, global_size=global_size2 ) self.roll_kernel( arr, np.int32(nx), np.int32(ny), np.int32(self.offset), self.shmem, local_size=local_size, global_size=global_size, ) if self.direction < 0: self.revert_kernel( arr, np.int32(nx), np.int32(ny), np.int32(self.offset), local_size=local_size2, global_size=global_size2 ) return arr ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/processing/rotation.py0000644000175000017500000000345414550227307017612 0ustar00pierrepierretry: from skimage.transform import rotate __have__skimage__ = True except ImportError: __have__skimage__ = False class Rotation: supported_modes = { "constant": "constant", "zeros": "constant", "edge": "edge", "edges": "edge", "symmetric": "symmetric", "sym": "symmetric", "reflect": "reflect", "wrap": "wrap", "periodic": "wrap", 
} def __init__(self, shape, angle, center=None, mode="edge", reshape=False, **sk_kwargs): """ Initiate a Rotation object. Parameters ---------- shape: tuple of int Shape of the images to process angle: float Rotation angle in DEGREES center: tuple of float, optional Coordinates of the center of rotation, in the format (X, Y) (mind the non-python convention !). Default is ((Nx - 1)/2.0, (Ny - 1)/2.0) mode: str, optional Padding mode. Default is "edge". reshape: bool, optional Other Parameters ----------------- All the other parameters are passed directly to scikit image 'rotate' function: order, cval, clip, preserve_range. """ self.shape = shape self.angle = angle self.center = center self.mode = mode self.reshape = reshape self.sk_kwargs = sk_kwargs def rotate(self, img, output=None): if not __have__skimage__: raise ValueError("scikit-image is needed for using rotate()") res = rotate(img, self.angle, resize=self.reshape, center=self.center, mode=self.mode, **self.sk_kwargs) if output is not None: output[:] = res[:] return output else: return res __call__ = rotate ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/processing/rotation_cuda.py0000644000175000017500000000535414654107202020603 0ustar00pierrepierreimport numpy as np from .rotation import Rotation from ..utils import get_cuda_srcfile, updiv from ..cuda.utils import __has_pycuda__, copy_array from ..cuda.processing import CudaProcessing if __has_pycuda__: from ..cuda.kernel import CudaKernel import pycuda.driver as cuda class CudaRotation(Rotation): def __init__(self, shape, angle, center=None, mode="edge", reshape=False, cuda_options=None, **sk_kwargs): if center is None: center = ((shape[1] - 1) / 2.0, (shape[0] - 1) / 2.0) super().__init__(shape, angle, center=center, mode=mode, reshape=reshape, **sk_kwargs) self._init_cuda_rotation(cuda_options) def _init_cuda_rotation(self, cuda_options): cuda_options = cuda_options or {} self.cuda_processing = CudaProcessing(**cuda_options) self._allocate_arrays() self._init_rotation_kernel() def _allocate_arrays(self): self._d_image_cua = cuda.np_to_array(np.zeros(self.shape, "f"), "C") # pylint: disable=E0606 self.cuda_processing.init_arrays_to_none(["d_output"]) def _init_rotation_kernel(self): self.cuda_rotation_kernel = CudaKernel("rotate", get_cuda_srcfile("rotation.cu")) # pylint: disable=E0606 self.texref_image = self.cuda_rotation_kernel.module.get_texref("tex_image") self.texref_image.set_filter_mode(cuda.filter_mode.LINEAR) # bilinear self.texref_image.set_address_mode(0, cuda.address_mode.CLAMP) # TODO tune self.texref_image.set_address_mode(1, cuda.address_mode.CLAMP) # TODO tune self.cuda_rotation_kernel.prepare("Piiffff", [self.texref_image]) self.texref_image.set_array(self._d_image_cua) self._cos_theta = np.cos(np.deg2rad(self.angle)) self._sin_theta = np.sin(np.deg2rad(self.angle)) self._Nx = np.int32(self.shape[1]) self._Ny = np.int32(self.shape[0]) self._center_x = np.float32(self.center[0]) self._center_y = np.float32(self.center[1]) self._block = (32, 32, 1) # tune ? 
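        # Launch roughly one thread per output pixel: the grid is the image shape divided by the block size, rounded up.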
self._grid = (updiv(self.shape[1], self._block[1]), updiv(self.shape[0], self._block[0]), 1) def rotate(self, img, output=None, do_checks=True): copy_array(self._d_image_cua, img, check=do_checks) if output is not None: d_out = output else: self.cuda_processing.allocate_array("d_output", self.shape, np.float32) d_out = self.cuda_processing.d_output self.cuda_rotation_kernel( d_out, self._Nx, self._Ny, self._cos_theta, self._sin_theta, self._center_x, self._center_y, grid=self._grid, block=self._block, ) return d_out __call__ = rotate ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.5167568 nabu-2024.2.1/nabu/processing/tests/0000755000175000017500000000000014730277752016546 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/processing/tests/__init__.py0000644000175000017500000000000014550227307020634 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734019212.0 nabu-2024.2.1/nabu/processing/tests/test_fft.py0000644000175000017500000002400414726604214020726 0ustar00pierrepierrefrom itertools import permutations import pytest import numpy as np from scipy.fft import fftn, ifftn, rfftn, irfftn from nabu.testutils import generate_tests_scenarios, get_data, get_array_of_given_shape, __do_long_tests__ from nabu.cuda.utils import get_cuda_context, __has_pycuda__ from nabu.processing.fft_cuda import SKCUFFT, VKCUFFT, get_available_fft_implems from nabu.opencl.utils import __has_pyopencl__, get_opencl_context from nabu.processing.fft_opencl import VKCLFFT, has_vkfft as has_cl_vkfft from nabu.processing.fft_base import is_fast_axes available_cuda_fft = get_available_fft_implems() __has_vkfft__ = "vkfft" in available_cuda_fft __has_skcuda__ = "skcuda" in available_cuda_fft scenarios = { "shape": [(256,), (300,), (300, 301), (300, 302)], "r2c": [True, False], "precision": ["simple"], "backend": ["cuda", "opencl"], } if __do_long_tests__: scenarios["shape"].extend([(307,), (125, 126, 260)]) scenarios["precision"].append("double") scenarios = generate_tests_scenarios(scenarios) @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.data = get_data("chelsea.npz")["data"] cls.abs_tol = { "simple": { 1: 5e-3, 2: 1.0e0, 3: 5e2, # ! 
}, "double": { 1: 1e-10, 2: 1e-9, 3: 1e-7, }, } if __has_pycuda__: cls.cu_ctx = get_cuda_context(cleanup_at_exit=False) if __has_pyopencl__: cls.cl_ctx = get_opencl_context("all") yield if __has_pycuda__: cls.cu_ctx.pop() def _get_fft_cls(backend): fft_cls = None if backend == "cuda": if not (__has_vkfft__ and __has_pycuda__): pytest.skip("Need vkfft and pycuda to use VKCUFFT") fft_cls = VKCUFFT if backend == "opencl": if not (has_cl_vkfft() and __has_pyopencl__): pytest.skip("Need vkfft and pyopencl to use VKCLFFT") fft_cls = VKCLFFT return fft_cls @pytest.mark.usefixtures("bootstrap") class TestFFT: def _get_data_array(self, config): r2c = config["r2c"] shape = config["shape"] precision = config["precision"] dtype = { True: {"simple": np.float32, "double": np.float64}, False: {"simple": np.complex64, "double": np.complex128}, }[r2c][precision] data = get_array_of_given_shape(self.data, shape, dtype) return data @staticmethod def check_result(res, ref, config, tol, name=""): err_max = np.max(np.abs(res - ref)) err_msg = "%s FFT(%s, r2c=%s): tol=%.2e, but max error = %.2e" % ( name, str(config["shape"]), str(config["r2c"]), tol, err_max, ) assert np.allclose(res, ref, atol=tol), err_msg def _do_fft(self, data, r2c, axes=None, return_fft_obj=False, backend_cls=None): ctx = self.cu_ctx if backend_cls.backend == "cuda" else self.cl_ctx fft = backend_cls(data.shape, data.dtype, r2c=r2c, axes=axes, ctx=ctx) d_data = fft.processing.allocate_array("_data", data.shape, dtype=data.dtype) d_data.set(data) d_out = fft.fft(d_data) res = d_out.get() return (res, fft) if return_fft_obj else res @staticmethod def _do_reference_fft(data, r2c, axes=None): ref_fft_func = rfftn if r2c else fftn ref = ref_fft_func(data, axes=axes) return ref @staticmethod def _do_reference_ifft(data, r2c, axes=None): ref_ifft_func = irfftn if r2c else ifftn ref = ref_ifft_func(data, axes=axes) return ref @pytest.mark.skipif( not (__has_skcuda__ and __has_pycuda__), reason="Need pycuda and (scikit-cuda or vkfft) for this test" ) @pytest.mark.parametrize("config", scenarios) def test_sckcuda(self, config): r2c = config["r2c"] shape = config["shape"] precision = config["precision"] ndim = len(shape) if ndim == 3 and not (__do_long_tests__): pytest.skip("3D FFTs are done only for long tests - use NABU_LONG_TESTS=1") data = self._get_data_array(config) res, cufft = self._do_fft(data, r2c, return_fft_obj=True, backend_cls=SKCUFFT) ref = self._do_reference_fft(data, r2c) tol = self.abs_tol[precision][ndim] self.check_result(res, ref, config, tol, name="skcuda") # Complex-to-complex can also be performed on real data (as in numpy.fft.fft(real_data)) if not (r2c): res = self._do_fft(data, False, backend_cls=SKCUFFT) ref = self._do_reference_fft(data, False) self.check_result(res, ref, config, tol, name="skcuda") # IFFT res = cufft.ifft(cufft.output_fft).get() self.check_result(res, data, config, tol, name="skcuda") # Perhaps we should also check against numpy/scipy ifft, # but it does not yield the good shape for R2C on odd-sized data @pytest.mark.skipif( not (__has_skcuda__ and __has_pycuda__), reason="Need pycuda and (scikit-cuda or vkfft) for this test" ) @pytest.mark.parametrize("config", scenarios) def test_skcuda_batched(self, config): shape = config["shape"] if len(shape) == 1: return elif len(shape) == 3 and not (__do_long_tests__): pytest.skip("3D FFTs are done only for long tests - use NABU_LONG_TESTS=1") r2c = config["r2c"] tol = self.abs_tol[config["precision"]][len(shape)] data = self._get_data_array(config) if data.ndim 
== 2: axes_to_test = [(0,), (1,)] elif data.ndim == 3: # axes_to_test = [(1, 2), (2, 1), (2,)] # See fft.py: works for C2C but not R2C ? axes_to_test = [(2,)] for axes in axes_to_test: res, cufft = self._do_fft(data, r2c, axes=axes, return_fft_obj=True, backend_cls=SKCUFFT) ref = self._do_reference_fft(data, r2c, axes=axes) self.check_result(res, ref, config, tol, name="skcuda batched axes=%s" % (str(axes))) # IFFT res = cufft.ifft(cufft.output_fft).get() self.check_result(res, data, config, tol, name="skcuda") @pytest.mark.parametrize("config", scenarios) def test_vkfft(self, config): backend = config["backend"] fft_cls = _get_fft_cls(backend) r2c = config["r2c"] shape = config["shape"] precision = config["precision"] ndim = len(shape) if ndim == 3 and not (__do_long_tests__): pytest.skip("3D FFTs are done only for long tests - use NABU_LONG_TESTS=1") if ndim >= 2 and r2c and shape[-1] & 1: pytest.skip("R2C with odd-sized fast dimension is not supported in VKFFT") # FIXME - vkfft + POCL fail for R2C in one dimension if config["backend"] == "opencl" and r2c and ndim == 1: if self.cl_ctx.devices[0].platform.name.strip().lower() == "portable computing language": pytest.skip("Something wrong with vkfft + pocl for R2C 1D") # --- data = self._get_data_array(config) res, fft_obj = self._do_fft(data, r2c, return_fft_obj=True, backend_cls=fft_cls) ref = self._do_reference_fft(data, r2c) tol = self.abs_tol[precision][ndim] self.check_result(res, ref, config, tol, name="vkfft_%s" % backend) # Complex-to-complex can also be performed on real data (as in numpy.fft.fft(real_data)) if not (r2c): res = self._do_fft(data, False, backend_cls=fft_cls) ref = self._do_reference_fft(data, False) self.check_result(res, ref, config, tol, name="vkfft_%s" % backend) # IFFT res = fft_obj.ifft(fft_obj.output_fft).get() self.check_result(res, data, config, tol, name="vkfft_%s" % backend) @pytest.mark.parametrize("config", scenarios) def test_vkfft_batched(self, config): backend = config["backend"] fft_cls = _get_fft_cls(backend) shape = config["shape"] if len(shape) == 1: return elif len(shape) == 3 and not (__do_long_tests__): pytest.skip("3D FFTs are done only for long tests - use NABU_LONG_TESTS=1") r2c = config["r2c"] tol = self.abs_tol[config["precision"]][len(shape)] data = self._get_data_array(config) if data.ndim >= 2 and r2c and shape[-1] & 1: pytest.skip("R2C with odd-sized fast dimension is not supported in VKFFT") # For R2C, only fastest axes are supported by vkfft if data.ndim == 2: axes_to_test = [(1,)] elif data.ndim == 3: axes_to_test = [ (1, 2), (2,), ] for axes in axes_to_test: res, cufft = self._do_fft(data, r2c, axes=axes, return_fft_obj=True, backend_cls=fft_cls) ref = self._do_reference_fft(data, r2c, axes=axes) self.check_result(res, ref, config, tol, name="vkfft_%s batched axes=%s" % (backend, str(axes))) # IFFT res = cufft.ifft(cufft.output_fft).get() self.check_result(res, data, config, tol, name="vkfft_%s" % backend) @pytest.mark.skipif(not (__do_long_tests__), reason="Use NABU_LONG_TESTS=1 for this test") def test_fast_axes_utility_function(self): axes_to_test = { 2: { (0, 1): True, (1,): True, (-1,): True, (-2,): False, (0,): False, }, 3: { (0, 1, 2): True, (0, 1): False, (1, 2): True, (2, 1): True, (-2, -1): True, (2,): True, (-1,): True, }, } for ndim, axes_ in axes_to_test.items(): for axes, is_fast in axes_.items(): possible_axes = [axes] if len(axes) > 1: possible_axes = list(permutations(axes, len(axes))) for ax in possible_axes: assert is_fast_axes(ndim, ax) is is_fast 
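# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the nabu sources): minimal use of the FFT
# classes exercised in the tests above. The constructor arguments and the
# `processing` / `output_fft` attributes mirror the test code; treat the exact
# signatures as assumptions for other nabu versions. Requires pycuda + vkfft;
# an existing context can be passed with ctx=... as done in the tests.
import numpy as np
from nabu.processing.fft_cuda import VKCUFFT

data = np.random.rand(256, 300).astype(np.float32)
fft = VKCUFFT(data.shape, data.dtype, r2c=True)  # real-to-complex transform
d_in = fft.processing.allocate_array("data", data.shape, dtype=data.dtype)
d_in.set(data)
d_out = fft.fft(d_in)                        # device array with rfftn-like layout
roundtrip = fft.ifft(fft.output_fft).get()   # back to host, approximately equal to `data`
# ---------------------------------------------------------------------------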
././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1719842437.0 nabu-2024.2.1/nabu/processing/tests/test_fftshift.py0000644000175000017500000000507214640533205021764 0ustar00pierrepierreimport numpy as np import pytest from nabu.cuda.utils import get_cuda_context, __has_pycuda__ from nabu.opencl.utils import __has_pyopencl__, get_opencl_context from nabu.testutils import get_data, generate_tests_scenarios, __do_long_tests__ if __has_pyopencl__: from nabu.processing.fftshift import OpenCLFFTshift configs = { "shape": [(300, 451), (300, 300), (255, 300)], "axes": [(1,)], "dtype_in_out": [(np.float32, np.complex64), (np.complex64, np.float32)], "inplace": [True, False], } scenarios = generate_tests_scenarios(configs) @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.data = get_data("chelsea.npz")["data"] cls.tol = 1e-7 if __has_pycuda__: cls.cu_ctx = get_cuda_context(cleanup_at_exit=False) if __has_pyopencl__: cls.cl_ctx = get_opencl_context(device_type="all") yield if __has_pycuda__: cls.cu_ctx.pop() @pytest.mark.skip(reason="OpenCL fftshift is a prototype") @pytest.mark.usefixtures("bootstrap") class TestFFTshift: def _do_test_fftshift(self, config, fftshift_cls): shape = config["shape"] dtype = config["dtype_in_out"][0] dst_dtype = config["dtype_in_out"][1] axes = config["axes"] inplace = config["inplace"] if inplace and shape[-1] & 1: pytest.skip("Not Implemented") data = np.ascontiguousarray(self.data[: shape[0], : shape[1]], dtype=dtype) backend = fftshift_cls.backend ctx = self.cu_ctx if backend == "cuda" else self.cl_ctx backend_options = {"ctx": ctx} if not (inplace): fftshift = fftshift_cls(data.shape, dtype, dst_dtype=dst_dtype, axes=axes, **backend_options) else: fftshift = fftshift_cls(data.shape, dtype, axes=axes, **backend_options) d_data = fftshift.processing.allocate_array("data", shape, dtype) d_data.set(data) d_res = fftshift.fftshift(d_data) assert ( np.max(np.abs(d_res.get() - np.fft.fftshift(data, axes=axes))) == 0 ), "something wrong with fftshift_%s(%s)" % (backend, str(config)) # @pytest.mark.skipif(not (__has_pycuda__), reason="Need pycuda for this test") # @pytest.mark.parametrize("config", scenarios) # def test_cuda_transpose(self, config): # self._do_test_transpose(config, CudaTranspose) @pytest.mark.skipif(not (__has_pyopencl__), reason="Need pyopencl for this test") @pytest.mark.parametrize("config", scenarios) def test_opencl_fftshift(self, config): self._do_test_fftshift(config, OpenCLFFTshift) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1706619687.0 nabu-2024.2.1/nabu/processing/tests/test_histogram.py0000644000175000017500000000435314556171447022161 0ustar00pierrepierreimport pytest import numpy as np from nabu.testutils import get_data from nabu.processing.histogram import PartialHistogram from nabu.cuda.utils import __has_pycuda__, get_cuda_context if __has_pycuda__: from nabu.processing.histogram_cuda import CudaPartialHistogram import pycuda.gpuarray as garray @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.data = get_data("mri_rec_astra.npz")["data"] cls.data /= 10 cls.data[:100] *= 10 cls.data_part_1 = cls.data[:100] cls.data_part_2 = cls.data[100:] cls.data0 = cls.data.ravel() cls.bin_tol = 1e-5 * (cls.data0.max() - cls.data0.min()) cls.hist_rtol = 1.5e-3 if __has_pycuda__: cls.ctx = get_cuda_context(cleanup_at_exit=False) yield if __has_pycuda__: cls.ctx.pop() @pytest.mark.usefixtures("bootstrap") class TestPartialHistogram: def 
compare_histograms(self, hist1, hist2): errmax_bins = np.max(np.abs(hist1[1] - hist2[1])) assert errmax_bins < self.bin_tol errmax_hist = np.max(np.abs(hist1[0] - hist2[0]) / hist2[0].max()) assert errmax_hist / hist2[0].max() < self.hist_rtol def test_fixed_nbins(self): partial_hist = PartialHistogram(method="fixed_bins_number", num_bins=1e6) hist1 = partial_hist.compute_histogram(self.data_part_1.ravel()) hist2 = partial_hist.compute_histogram(self.data_part_2.ravel()) hist = partial_hist.merge_histograms([hist1, hist2]) ref = np.histogram(self.data0, bins=partial_hist.num_bins) self.compare_histograms(hist, ref) @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test") def test_fixed_nbins_cuda(self): partial_hist = CudaPartialHistogram(method="fixed_bins_number", num_bins=1e6, cuda_options={"ctx": self.ctx}) data_part_1 = garray.to_gpu(np.tile(self.data_part_1, (1, 1, 1))) data_part_2 = garray.to_gpu(np.tile(self.data_part_2, (1, 1, 1))) hist1 = partial_hist.compute_histogram(data_part_1) hist2 = partial_hist.compute_histogram(data_part_2) hist = partial_hist.merge_histograms([hist1, hist2]) ref = np.histogram(self.data0, bins=partial_hist.num_bins) self.compare_histograms(hist, ref) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/processing/tests/test_medfilt.py0000644000175000017500000000524214550227307021575 0ustar00pierrepierreimport pytest import numpy as np from silx.math.medianfilter import medfilt2d from nabu.testutils import generate_tests_scenarios, get_data from nabu.cuda.utils import get_cuda_context, __has_pycuda__ if __has_pycuda__: from nabu.processing.medfilt_cuda import MedianFilter import pycuda.gpuarray as garray scenarios = generate_tests_scenarios( { "input_on_gpu": [False, True], "output_on_gpu": [False, True], "footprint": [(3, 3), (5, 5)], "mode": ["reflect", "nearest"], "batched_2d": [False, True], } ) @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.data = np.ascontiguousarray(get_data("brain_phantom.npz")["data"][::2, ::2][:-1, :]) cls.tol = 1e-7 cls.ctx = get_cuda_context(cleanup_at_exit=False) cls.allocate_numpy_arrays() cls.allocate_cuda_arrays() yield cls.ctx.pop() @pytest.mark.skipif(not (__has_pycuda__), reason="Need Cuda/pycuda for this test") @pytest.mark.usefixtures("bootstrap") class TestMedianFilter(object): @classmethod def allocate_numpy_arrays(cls): shape = cls.data.shape cls.input = cls.data cls.input3d = np.tile(cls.input, (2, 1, 1)) @classmethod def allocate_cuda_arrays(cls): shape = cls.data.shape cls.d_input = garray.to_gpu(cls.input) cls.d_output = garray.zeros_like(cls.d_input) cls.d_input3d = garray.to_gpu(cls.input3d) cls.d_output3d = garray.zeros_like(cls.d_input3d) # parametrize on a class method will use the same class, and launch this # method with different scenarios. 
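    # (Each scenario toggles GPU vs. host input/output, the median footprint, the boundary mode, and 2D vs. batched-2D processing.)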
@pytest.mark.parametrize("config", scenarios) def testMedfilt(self, config): if config["input_on_gpu"]: input_data = self.d_input if not (config["batched_2d"]) else self.d_input3d else: input_data = self.input if not (config["batched_2d"]) else self.input3d if config["output_on_gpu"]: output_data = self.d_output if not (config["batched_2d"]) else self.d_output3d else: output_data = None # Cuda median filter medfilt = MedianFilter( input_data.shape, footprint=config["footprint"], mode=config["mode"], cuda_options={"ctx": self.ctx}, ) res = medfilt.medfilt2(input_data, output=output_data) if config["output_on_gpu"]: res = res.get() # Reference (scipy) ref = medfilt2d(self.input, config["footprint"][0], mode=config["mode"]) max_absolute_error = np.max(np.abs(res - ref)) assert max_absolute_error < self.tol, "Something wrong with configuration %s" % str(config) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/processing/tests/test_muladd.py0000644000175000017500000000360714550227307021422 0ustar00pierrepierreimport pytest import numpy as np from nabu.processing.muladd import MulAdd from nabu.testutils import get_data from nabu.cuda.utils import get_cuda_context, __has_pycuda__ if __has_pycuda__: from nabu.processing.muladd_cuda import CudaMulAdd @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.data = get_data("chelsea.npz")["data"].astype("f") # (300, 451) cls.tol = 1e-7 if __has_pycuda__: cls.cu_ctx = get_cuda_context(cleanup_at_exit=False) yield if __has_pycuda__: cls.cu_ctx.pop() @pytest.mark.usefixtures("bootstrap") class TestMulad: def test_muladd(self): dst = self.data.copy() other = self.data.copy() mul_add = MulAdd() # Test with no subregion mul_add(dst, other, 1, 2) assert np.allclose(dst, self.data * 1 + other * 2) # Test with x-y subregion dst = self.data.copy() mul_add(dst, other, 0.5, 1.7, (slice(10, 200), slice(15, 124)), (slice(100, 290), slice(200, 309))) assert np.allclose(dst[10:200, 15:124], self.data[10:200, 15:124] * 0.5 + self.data[100:290, 200:309] * 1.7) @pytest.mark.skipif(not (__has_pycuda__), reason="Need Cuda/pycuda for this test") def test_cuda_muladd(self): mul_add = CudaMulAdd(ctx=self.cu_ctx) dst = mul_add.processing.to_device("dst", self.data) other = mul_add.processing.to_device("other", (self.data / 2).astype("f")) # Test with no subregion mul_add(dst, other, 3, 5) assert np.allclose(dst.get(), self.data * 3 + (self.data / 2) * 5) # Test with x-y subregion dst.set(self.data) mul_add(dst, other, 0.5, 1.7, (slice(10, 200), slice(15, 124)), (slice(100, 290), slice(200, 309))) assert np.allclose( dst.get()[10:200, 15:124], self.data[10:200, 15:124] * 0.5 + (self.data / 2)[100:290, 200:309] * 1.7 ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/processing/tests/test_padding.py0000644000175000017500000002104014550227307021551 0ustar00pierrepierreimport numpy as np import pytest from nabu.cuda.utils import get_cuda_context, __has_pycuda__ from nabu.opencl.utils import __has_pyopencl__, get_opencl_context from nabu.processing.padding_cuda import CudaPadding from nabu.processing.padding_opencl import OpenCLPadding from nabu.utils import calc_padding_lengths, get_cuda_srcfile from nabu.testutils import __do_long_tests__ from nabu.testutils import get_data, generate_tests_scenarios scenarios = { "shape": [(511, 512), (512, 511)], "pad_width": [((256, 255), (128, 127))], "mode_cuda": CudaPadding.supported_modes[:2] 
if __has_pycuda__ else [], "mode_opencl": OpenCLPadding.supported_modes[:2] if __has_pyopencl__ else [], "constant_values": [0, ((1.0, 2.0), (3.0, 4.0))], "output_is_none": [True, False], "backend": ["cuda", "opencl"], } if __do_long_tests__: scenarios["mode_cuda"] = CudaPadding.supported_modes if __has_pycuda__ else [] scenarios["mode_opencl"] = OpenCLPadding.supported_modes if __has_pyopencl__ else [] scenarios["pad_width"].extend([((0, 0), (6, 7))]) scenarios = generate_tests_scenarios(scenarios) @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.data = get_data("brain_phantom.npz")["data"] cls.tol = 1e-7 if __has_pycuda__: cls.cu_ctx = get_cuda_context(cleanup_at_exit=False) if __has_pyopencl__: cls.cl_ctx = get_opencl_context(device_type="all") yield if __has_pycuda__: cls.cu_ctx.pop() @pytest.mark.usefixtures("bootstrap") class TestPadding: @pytest.mark.parametrize("config", scenarios) def test_padding(self, config): backend = config["backend"] shape = config["shape"] padding_mode = config["mode_%s" % backend] data = self.data[: shape[0], : shape[1]] kwargs = {} if padding_mode == "constant": kwargs["constant_values"] = config["constant_values"] ref = np.pad(data, config["pad_width"], mode=padding_mode, **kwargs) PaddingCls = CudaPadding if backend == "cuda" else OpenCLPadding if backend == "cuda": backend_options = {"cuda_options": {"ctx": self.cu_ctx}} else: backend_options = {"opencl_options": {"ctx": self.cl_ctx}} padding = PaddingCls( config["shape"], config["pad_width"], mode=padding_mode, constant_values=config["constant_values"], **backend_options, ) if config["output_is_none"]: output = None else: output = padding.processing.allocate_array("output", ref.shape, dtype="f") d_img = padding.processing.allocate_array("d_img", data.shape, dtype="f") d_img.set(np.ascontiguousarray(data, dtype="f")) res = padding.pad(d_img, output=output) err_max = np.max(np.abs(res.get() - ref)) assert err_max < self.tol, str("Something wrong with padding for configuration %s" % (str(config))) @pytest.mark.skipif(not (__has_pycuda__) and not (__has_pyopencl__), reason="need pycuda or pyopencl") def test_custom_coordinate_transform(self): data = self.data R, C = np.indices(data.shape, dtype=np.int32) pad_width = ((256, 255), (254, 251)) mode = "reflect" coords_R = np.pad(R, pad_width[0], mode=mode)[:, 0] coords_C = np.pad(C, pad_width[1], mode=mode)[0, :] # Further transform of coordinates - here FFT layout coords_R = np.roll(coords_R, -pad_width[0][0]) coords_C = np.roll(coords_C, -pad_width[1][0]) padding_classes_to_test = [] if __has_pycuda__: padding_classes_to_test.append(CudaPadding) if __has_pyopencl__: padding_classes_to_test.append(OpenCLPadding) for padding_cls in padding_classes_to_test: ctx = self.cl_ctx if padding_cls.backend == "opencl" else self.cu_ctx padding = padding_cls(data.shape, (coords_R, coords_C), mode=mode, ctx=ctx) d_img = padding.processing.allocate_array("d_img", data.shape, dtype="f") d_img.set(data) d_out = padding.processing.allocate_array("d_out", padding.padded_shape, dtype="f") res = padding.pad(d_img, output=d_out) ref = np.roll(np.pad(data, pad_width, mode=mode), (-pad_width[0][0], -pad_width[1][0]), axis=(0, 1)) err_max = np.max(np.abs(d_out.get() - ref)) assert err_max < self.tol, "Something wrong with custom padding" # # The following is testing a previous version of padding kernels # They use specific code (instead of a generic coordinate transform) # if __has_pycuda__: from nabu.cuda.kernel import CudaKernel import pycuda.gpuarray 
as garray scenarios_legacy = [ { "shape": (512, 501), "shape_padded": (1023, 1022), "constant_values": ((1.0, 2.0), (3.0, 4.0)), }, ] # parametrize with fixture and "params=" will launch a new class for each scenario. # the attributes set to "cls" will remain for all the tests done in this class # with the current scenario. @pytest.fixture(scope="class", params=scenarios_legacy) def bootstrap_legacy(request): cls = request.cls cls.data = get_data("mri_proj_astra.npz")["data"] cls.tol = 1e-7 cls.params = request.param cls.ctx = get_cuda_context(cleanup_at_exit=False) cls._calc_pad() cls._init_kernels() yield cls.ctx.pop() @pytest.mark.skipif(not (__has_pycuda__), reason="Need Cuda and pycuda for this test") @pytest.mark.usefixtures("bootstrap_legacy") class TestPaddingLegacy: @classmethod def _calc_pad(cls): cls.shape = cls.params["shape"] cls.data = np.ascontiguousarray(cls.data[: cls.shape[0], : cls.shape[1]]) cls.shape_padded = cls.params["shape_padded"] ((pt, pb), (pl, pr)) = calc_padding_lengths(cls.shape, cls.shape_padded) cls.pad_top_len = pt cls.pad_bottom_len = pb cls.pad_left_len = pl cls.pad_right_len = pr @classmethod def _init_kernels(cls): cls.pad_kern = CudaKernel( "padding_constant", filename=get_cuda_srcfile("padding.cu"), signature="Piiiiiiiiffff", ) cls.pad_edge_kern = CudaKernel( "padding_edge", filename=get_cuda_srcfile("padding.cu"), signature="Piiiiiiii", ) cls.d_data_padded = garray.zeros(cls.shape_padded, "f") def _init_padding(self, arr=None): arr = arr or self.data self.d_data_padded.fill(0) Ny, Nx = self.shape self.d_data_padded[:Ny, :Nx] = self.data def _pad_numpy(self, arr=None, **np_pad_kwargs): arr = arr or self.data data_padded_ref = np.pad( arr, ((self.pad_top_len, self.pad_bottom_len), (self.pad_left_len, self.pad_right_len)), **np_pad_kwargs ) # Put in the FFT layout data_padded_ref = np.roll(data_padded_ref, (-self.pad_top_len, -self.pad_left_len), axis=(0, 1)) return data_padded_ref def test_constant_padding(self): self._init_padding() # Pad using the cuda kernel ((val_top, val_bottom), (val_left, val_right)) = self.params["constant_values"] Ny, Nx = self.shape Nyp, Nxp = self.shape_padded self.pad_kern( self.d_data_padded, Nx, Ny, Nxp, Nyp, self.pad_left_len, self.pad_right_len, self.pad_top_len, self.pad_bottom_len, val_left, val_right, val_top, val_bottom, ) # Pad using numpy data_padded_ref = self._pad_numpy(mode="constant", constant_values=self.params["constant_values"]) # Compare errmax = np.max(np.abs(self.d_data_padded.get() - data_padded_ref)) assert errmax < self.tol, "Max error is too high" def test_edge_padding(self): self._init_padding() # Pad using the cuda kernel ((val_top, val_bottom), (val_left, val_right)) = self.params["constant_values"] Ny, Nx = self.shape Nyp, Nxp = self.shape_padded self.pad_edge_kern( self.d_data_padded, Nx, Ny, Nxp, Nyp, self.pad_left_len, self.pad_right_len, self.pad_top_len, self.pad_bottom_len, ) # Pad using numpy data_padded_ref = self._pad_numpy(mode="edge") # Compare errmax = np.max(np.abs(self.d_data_padded.get() - data_padded_ref)) assert errmax < self.tol, "Max error is too high" ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/processing/tests/test_roll.py0000644000175000017500000000446014550227307021122 0ustar00pierrepierreimport numpy as np import pytest from nabu.cuda.utils import get_cuda_context, __has_pycuda__ from nabu.opencl.utils import __has_pyopencl__, get_opencl_context from nabu.testutils import get_data, 
generate_tests_scenarios, __do_long_tests__ from nabu.processing.roll_opencl import OpenCLRoll configs_roll = { "shape": [(300, 451), (300, 300), (255, 300)], "offset_x": [0, 10, 155], "dtype": [np.float32], # , np.complex64], } scenarios_roll = generate_tests_scenarios(configs_roll) @pytest.fixture(scope="class") def bootstrap_roll(request): cls = request.cls cls.data = get_data("chelsea.npz")["data"] cls.tol = 1e-7 if __has_pycuda__: cls.cu_ctx = get_cuda_context(cleanup_at_exit=False) if __has_pyopencl__: cls.cl_ctx = get_opencl_context(device_type="all") yield if __has_pycuda__: cls.cu_ctx.pop() @pytest.mark.usefixtures("bootstrap_roll") class TestRoll: @staticmethod def _compute_ref(data, direction, offset): ref = data.copy() ref[:, offset:] = np.roll(data[:, offset:], direction, axis=1) return ref @pytest.mark.skipif(not (__has_pyopencl__), reason="Need pyopencl for this test") @pytest.mark.parametrize("config", scenarios_roll) def test_opencl_roll(self, config): shape = config["shape"] dtype = config["dtype"] offset_x = config["offset_x"] data = np.ascontiguousarray(self.data[: shape[0], : shape[1]], dtype=dtype) ref_forward = self._compute_ref(data, 1, offset_x) ref_backward = self._compute_ref(data, -1, offset_x) roll_forward = OpenCLRoll(dtype, direction=1, offset=offset_x, ctx=self.cl_ctx) d_data = roll_forward.processing.allocate_array("data", data.shape, dtype=dtype) d_data.set(data) roll_backward = OpenCLRoll(dtype, direction=-1, offset=offset_x, queue=roll_forward.processing.queue) roll_forward(d_data) # from spire.utils import ims # ims([d_data.get(), ref_forward, d_data.get() - ref_forward]) assert np.allclose(d_data.get(), ref_forward), "roll_forward: something wrong with config=%s" % (str(config)) d_data.set(data) roll_backward(d_data) assert np.allclose(d_data.get(), ref_backward), "roll_backward: something wrong with config=%s" % (str(config)) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/processing/tests/test_rotation.py0000644000175000017500000000526114550227307022011 0ustar00pierrepierreimport numpy as np import pytest from nabu.testutils import generate_tests_scenarios from nabu.processing.rotation_cuda import Rotation from nabu.processing.rotation import __have__skimage__ from nabu.cuda.utils import __has_pycuda__, get_cuda_context if __have__skimage__: from skimage.transform import rotate from skimage.data import chelsea ny, nx = chelsea().shape[:2] if __has_pycuda__: from nabu.processing.rotation_cuda import CudaRotation import pycuda.gpuarray as garray if __have__skimage__: scenarios = generate_tests_scenarios( { # ~ "output_is_none": [False, True], "mode": ["edge"], "angle": [5.0, 10.0, 45.0, 57.0, 90.0], "center": [None, ((nx - 1) / 2.0, (ny - 1) / 2.0), ((nx - 1) / 2.0, ny - 1)], } ) else: scenarios = {} @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.image = chelsea().mean(axis=-1, dtype=np.float32) if __has_pycuda__: cls.ctx = get_cuda_context(cleanup_at_exit=False) cls.d_image = garray.to_gpu(cls.image) yield if __has_pycuda__: cls.ctx.pop() @pytest.mark.skipif(not (__have__skimage__), reason="Need scikit-image for rotation") @pytest.mark.usefixtures("bootstrap") class TestRotation: def _get_reference_rotation(self, config): return rotate( self.image, config["angle"], resize=False, center=config["center"], order=1, mode=config["mode"], clip=False, # preserve_range=False, ) def _check_result(self, res, config, tol): ref = self._get_reference_rotation(config) mae 
= np.max(np.abs(res - ref)) err_msg = str("Max error is too high for this configuration: %s" % str(config)) assert mae < tol, err_msg # parametrize on a class method will use the same class, and launch this # method with different scenarios. @pytest.mark.parametrize("config", scenarios) def test_rotation(self, config): R = Rotation(self.image.shape, config["angle"], center=config["center"], mode=config["mode"]) res = R(self.image) self._check_result(res, config, 1e-6) @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda rotation") @pytest.mark.parametrize("config", scenarios) def test_cuda_rotation(self, config): R = CudaRotation( self.image.shape, config["angle"], center=config["center"], mode=config["mode"], cuda_options={"ctx": self.ctx}, ) d_res = R(self.d_image) res = d_res.get() self._check_result(res, config, 0.5) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/processing/tests/test_transpose.py0000644000175000017500000000531014550227307022163 0ustar00pierrepierreimport numpy as np import pytest from nabu.cuda.utils import get_cuda_context, __has_pycuda__ from nabu.opencl.utils import __has_pyopencl__, get_opencl_context from nabu.testutils import get_data, generate_tests_scenarios, __do_long_tests__ from nabu.processing.transpose import CudaTranspose, OpenCLTranspose configs = { "shape": [(300, 451), (300, 300), (255, 300)], "output_is_none": [True, False], "dtype_in_out": [(np.float32, np.float32)], } if __do_long_tests__: configs["dtype_in_out"].extend( [(np.float32, np.complex64), (np.complex64, np.complex64), (np.uint8, np.uint16), (np.uint8, np.int32)] ) scenarios = generate_tests_scenarios(configs) @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.data = get_data("chelsea.npz")["data"] cls.tol = 1e-7 if __has_pycuda__: cls.cu_ctx = get_cuda_context(cleanup_at_exit=False) if __has_pyopencl__: cls.cl_ctx = get_opencl_context(device_type="all") yield if __has_pycuda__: cls.cu_ctx.pop() @pytest.mark.usefixtures("bootstrap") class TestTranspose: def _do_test_transpose(self, config, transpose_cls): shape = config["shape"] dtype = config["dtype_in_out"][0] dtype_out = config["dtype_in_out"][1] data = np.ascontiguousarray(self.data[: shape[0], : shape[1]], dtype=dtype) backend = transpose_cls.backend if backend == "opencl" and not (np.iscomplexobj(dtype(1))) and np.iscomplexobj(dtype_out(1)): pytest.skip("pyopencl does not support real to complex scalar cast") ctx = self.cu_ctx if backend == "cuda" else self.cl_ctx backend_options = {"ctx": ctx} transpose = transpose_cls(data.shape, dtype, dst_dtype=dtype_out, **backend_options) d_data = transpose.processing.allocate_array("data", shape, dtype) d_data.set(data) if config["output_is_none"]: d_out = None else: d_out = transpose.processing.allocate_array("output", shape[::-1], dtype_out) d_res = transpose(d_data, dst=d_out) assert ( np.max(np.abs(d_res.get() - data.T)) == 0 ), "something wrong with transpose(shape=%s, dtype=%s, dtype_out=%s)" % (shape, dtype, dtype_out) @pytest.mark.skipif(not (__has_pycuda__), reason="Need pycuda for this test") @pytest.mark.parametrize("config", scenarios) def test_cuda_transpose(self, config): self._do_test_transpose(config, CudaTranspose) @pytest.mark.skipif(not (__has_pyopencl__), reason="Need pyopencl for this test") @pytest.mark.parametrize("config", scenarios) def test_opencl_transpose(self, config): self._do_test_transpose(config, OpenCLTranspose) 
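# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the nabu sources): out-of-place GPU
# transpose as exercised in the test above. Names and signatures mirror the
# test code; treat them as assumptions for other nabu versions. Requires pycuda;
# a context can be passed with ctx=... as in the tests.
import numpy as np
from nabu.processing.transpose import CudaTranspose

img = np.arange(12, dtype=np.float32).reshape(3, 4)
transpose = CudaTranspose(img.shape, np.float32)   # optionally: dst_dtype=..., ctx=...
d_img = transpose.processing.allocate_array("data", img.shape, np.float32)
d_img.set(img)
d_t = transpose(d_img)             # device array of shape (4, 3)
assert np.array_equal(d_t.get(), img.T)
# ---------------------------------------------------------------------------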
././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/processing/tests/test_unsharp.py0000644000175000017500000001044714550227307021634 0ustar00pierrepierrefrom itertools import product import numpy as np import pytest from nabu.processing.unsharp import UnsharpMask from nabu.processing.unsharp_opencl import OpenclUnsharpMask, __have_opencl__ as __has_pyopencl__ from nabu.cuda.utils import __has_pycuda__, get_cuda_context from nabu.testutils import get_data if __has_pyopencl__: from pyopencl import CommandQueue import pyopencl.array as parray from silx.opencl.common import ocl if __has_pycuda__: import pycuda.gpuarray as garray from nabu.processing.unsharp_cuda import CudaUnsharpMask try: from skimage.filters import unsharp_mask __has_skimage__ = True except ImportError: __has_skimage__ = False @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.data = get_data("brain_phantom.npz")["data"] cls.imagej_results = get_data("dirac_unsharp_imagej.npz") cls.tol = 1e-4 cls.sigma = 1.6 cls.coeff = 0.5 if __has_pycuda__: cls.ctx = get_cuda_context(cleanup_at_exit=False) if __has_pyopencl__: cls.cl_ctx = ocl.create_context() yield if __has_pycuda__: cls.ctx.pop() @pytest.mark.usefixtures("bootstrap") class TestUnsharp: def get_reference_result(self, method, data=None): if data is None: data = self.data unsharp_mask = UnsharpMask(data.shape, self.sigma, self.coeff, method=method) return unsharp_mask.unsharp(data) def check_result(self, result, method, data=None, error_msg_prefix=None): reference = self.get_reference_result(method, data=data) mae = np.max(np.abs(result - reference)) err_msg = str( "%s: max error is too high with method=%s: %.2e > %.2e" % (error_msg_prefix or "", method, mae, self.tol) ) assert mae < self.tol, err_msg @pytest.mark.skipif(not (__has_skimage__), reason="Need scikit-image for this test") def test_mode_gaussian(self): dirac = np.zeros((43, 43), "f") dirac[dirac.shape[0] // 2, dirac.shape[1] // 2] = 1 sigma_list = [0.2, 0.5, 1.0, 2.0, 3.0] coeff_list = [0.5, 1.0, 3.0] for sigma, coeff in product(sigma_list, coeff_list): res = UnsharpMask(dirac.shape, sigma, coeff, method="gaussian").unsharp(dirac) ref = unsharp_mask(dirac, radius=sigma, amount=coeff, preserve_range=True) assert np.max(np.abs(res - ref)) < 1e-6, "Something wrong with mode='gaussian', sigma=%.2f, coeff=%.2f" % ( sigma, coeff, ) def test_mode_imagej(self): dirac = np.zeros(self.imagej_results["images"][0].shape, dtype="f") dirac[dirac.shape[0] // 2, dirac.shape[1] // 2] = 1 for sigma, coeff, ref in zip( self.imagej_results["sigma"], self.imagej_results["amount"], self.imagej_results["images"] ): res = UnsharpMask(dirac.shape, sigma, coeff, method="imagej").unsharp(dirac) assert np.max(np.abs(res - ref)) < 1e-3, "Something wrong with mode='imagej', sigma=%.2f, coeff=%.2f" % ( sigma, coeff, ) @pytest.mark.skipif(not (__has_pyopencl__), reason="Need pyopencl for this test") def test_opencl_unsharp(self): cl_queue = CommandQueue(self.cl_ctx) d_image = parray.to_device(cl_queue, self.data) d_out = parray.zeros_like(d_image) for method in OpenclUnsharpMask.avail_methods: d_image = parray.to_device(cl_queue, self.data) d_out = parray.zeros_like(d_image) opencl_unsharp = OpenclUnsharpMask(self.data.shape, self.sigma, self.coeff, method=method, ctx=self.cl_ctx) opencl_unsharp.unsharp(d_image, output=d_out) res = d_out.get() self.check_result(res, method, error_msg_prefix="OpenclUnsharpMask") @pytest.mark.skipif(not (__has_pycuda__), 
reason="Need cuda/pycuda for this test") def test_cuda_unsharp(self): d_image = garray.to_gpu(self.data) d_out = garray.zeros_like(d_image) for method in CudaUnsharpMask.avail_methods: cuda_unsharp = CudaUnsharpMask(self.data.shape, self.sigma, self.coeff, method=method, ctx=self.ctx) cuda_unsharp.unsharp(d_image, output=d_out) res = d_out.get() self.check_result(res, method, error_msg_prefix="CudaUnsharpMask") ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/processing/transpose.py0000644000175000017500000001047214550227307017767 0ustar00pierrepierreimport numpy as np from ..utils import get_opencl_srcfile, get_cuda_srcfile, updiv, BaseClassError, MissingComponentError from ..opencl.utils import __has_pyopencl__ from ..cuda.utils import __has_pycuda__ if __has_pyopencl__: from ..opencl.kernel import OpenCLKernel from ..opencl.processing import OpenCLProcessing from pyopencl.tools import dtype_to_ctype as cl_dtype_to_ctype else: OpenCLKernel = OpenCLProcessing = cl_dtype_to_ctype = MissingComponentError("need pyopencl to use this class") if __has_pycuda__: from ..cuda.kernel import CudaKernel from ..cuda.processing import CudaProcessing from pycuda.tools import base_dtype_to_ctype as cu_dtype_to_ctype else: CudaKernel = CudaProcessing = cu_dtype_to_ctype = MissingComponentError("need pycuda to use this class") # pylint: disable=E1101, E1102 class TransposeBase: """ A class for transposing (out-of-place) a cuda or opencl array """ KernelCls = BaseClassError ProcessingCls = BaseClassError dtype_to_ctype = BaseClassError backend = "none" def __init__(self, shape, dtype, dst_dtype=None, **backend_options): self.processing = self.ProcessingCls(**(backend_options or {})) self.shape = shape self.dtype = dtype self.dst_dtype = dst_dtype or dtype if len(shape) != 2: raise ValueError("Expected 2D array") self._kernel_init_args = [ "transpose", ] self._kernel_init_kwargs = { "options": [ "-DSRC_DTYPE=%s" % self.dtype_to_ctype(self.dtype), "-DDST_DTYPE=%s" % self.dtype_to_ctype(self.dst_dtype), ], } self._configure_kenel_initialization() self._transpose_kernel = self.KernelCls(*self._kernel_init_args, **self._kernel_init_kwargs) self._configure_kernel_call() def __call__(self, arr, dst=None): if dst is None: dst = self.processing.allocate_array("dst", self.shape[::-1], dtype=self.dst_dtype) self._transpose_kernel(arr, dst, np.int32(self.shape[1]), np.int32(self.shape[0]), **self._kernel_kwargs) return dst class CudaTranspose(TransposeBase): KernelCls = CudaKernel ProcessingCls = CudaProcessing dtype_to_ctype = cu_dtype_to_ctype backend = "cuda" def _configure_kenel_initialization(self): self._kernel_init_kwargs.update( { "filename": get_cuda_srcfile("transpose.cu"), "signature": "PPii", } ) def _configure_kernel_call(self): block = (32, 32, 1) grid = [updiv(a, b) for a, b in zip(self.shape, block)] self._kernel_kwargs = {"grid": grid, "block": block} class OpenCLTranspose(TransposeBase): KernelCls = OpenCLKernel ProcessingCls = OpenCLProcessing dtype_to_ctype = cl_dtype_to_ctype backend = "opencl" def _configure_kenel_initialization(self): self._kernel_init_args.append(self.processing.ctx) self._kernel_init_kwargs.update( { "filename": get_opencl_srcfile("transpose.cl"), "queue": self.processing.queue, } ) def _configure_kernel_call(self): block = (16, 16, 1) grid = [updiv(a, b) * b for a, b in zip(self.shape, block)] self._kernel_kwargs = {"global_size": grid, "local_size": block} # # An attempt to have a simplified access to transpose 
operation # # (backend, shape, dtype, dtype_out) _transposes_store = {} def transpose(array, dst=None, **backend_options): if hasattr(array, "with_queue"): backend = "opencl" transpose_cls = OpenCLTranspose backend_options["queue"] = array.queue # ! elif hasattr(array, "bind_to_texref"): backend = "cuda" transpose_cls = CudaTranspose else: raise ValueError("array should be either a pycuda.gpuarray.GPUArray or pyopencl.array.Array instance") dst_dtype = dst.dtype if dst is not None else None key = (backend, array.shape, np.dtype(array.dtype), dst_dtype) transpose_instance = _transposes_store.get(key, None) if transpose_instance is None: transpose_instance = transpose_cls(array.shape, array.dtype, dst_dtype=dst_dtype, **backend_options) _transposes_store[key] = transpose_instance return transpose_instance(array, dst=dst) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/processing/unsharp.py0000644000175000017500000000517214550227307017432 0ustar00pierrepierreimport numpy as np from scipy.ndimage import convolve1d from silx.image.utils import gaussian_kernel class UnsharpMask: """ A helper class for unsharp masking. """ avail_methods = ["gaussian", "log", "imagej"] def __init__(self, shape, sigma, coeff, mode="reflect", method="gaussian"): """ Initialize a Unsharp mask. `UnsharpedImage = (1 + coeff)*Image - coeff * ConvolutedImage` If method == "log": `UnsharpedImage = Image + coeff*ConvolutedImage` Parameters ----------- shape: tuple Shape of the image. sigma: float Standard deviation of the Gaussian kernel coeff: float Coefficient in the linear combination of unsharp mask mode: str, optional Convolution mode. Default is "reflect" method: str, optional Method of unsharp mask. Can be "gaussian" (default) or "log" (Laplacian of Gaussian), or "imagej". Notes ----- The computation is the following depending on the method: - For method="gaussian": output = (1 + coeff) * image - coeff * image_blurred - For method="log": output = image + coeff * image_blurred - For method="imagej": output = (image - coeff*image_blurred)/(1-coeff) """ self.shape = shape self.ndim = len(self.shape) self.sigma = sigma self.coeff = coeff self._set_method(method) self.mode = mode self._compute_gaussian_kernel() def _set_method(self, method): if method not in self.avail_methods: raise ValueError("Unknown unsharp method '%s'. Available are %s" % (method, str(self.avail_methods))) self.method = method def _compute_gaussian_kernel(self): self._gaussian_kernel = np.ascontiguousarray(gaussian_kernel(self.sigma), dtype=np.float32) def _blur2d(self, image): res1 = convolve1d(image, self._gaussian_kernel, axis=1, mode=self.mode) res = convolve1d(res1, self._gaussian_kernel, axis=0, mode=self.mode) return res def unsharp(self, image, output=None): """ Reference unsharp mask implementation. 
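        If `output` is provided, the result is written into it and `output` is returned;
        otherwise a new array is returned. The formula applied depends on the `method`
        chosen at initialization (see the class docstring).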
""" image_b = self._blur2d(image) if self.method == "gaussian": res = (1 + self.coeff) * image - self.coeff * image_b elif self.method == "log": res = image + self.coeff * image_b else: # "imagej": res = (image - self.coeff * image_b) / (1 - self.coeff) if output is not None: output[:] = res[:] return output return res ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/processing/unsharp_cuda.py0000644000175000017500000000415014654107202020415 0ustar00pierrepierrefrom ..cuda.utils import __has_pycuda__ from ..processing.convolution_cuda import Convolution from ..cuda.processing import CudaProcessing from .unsharp import UnsharpMask if __has_pycuda__: from pycuda.elementwise import ElementwiseKernel class CudaUnsharpMask(UnsharpMask): def __init__(self, shape, sigma, coeff, mode="reflect", method="gaussian", **cuda_options): """ Unsharp Mask, cuda backend. """ super().__init__(shape, sigma, coeff, mode=mode, method=method) self.cuda_processing = CudaProcessing(**(cuda_options or {})) self._init_convolution() self._init_mad_kernel() self.cuda_processing.init_arrays_to_none(["_d_out"]) def _init_convolution(self): self.convolution = Convolution( self.shape, self._gaussian_kernel, mode=self.mode, extra_options={ # Use the lowest amount of memory "allocate_input_array": False, "allocate_output_array": False, "allocate_tmp_array": True, }, ) def _init_mad_kernel(self): # garray.GPUArray.mul_add is out of place... self.mad_kernel = ElementwiseKernel( # pylint: disable=E0606 "float* array, float fac, float* other, float otherfac", "array[i] = fac * array[i] + otherfac * other[i]", name="mul_add", ) def unsharp(self, image, output=None): if output is None: output = self.cuda_processing.allocate_array("_d_out", self.shape, "f") self.convolution(image, output=output) if self.method == "gaussian": self.mad_kernel(output, -self.coeff, image, 1.0 + self.coeff) elif self.method == "log": # output = output * coeff + image where output was image_blurred self.mad_kernel(output, self.coeff, image, 1.0) else: # "imagej": # output = (image - coeff*image_blurred)/(1-coeff) where output was image_blurred self.mad_kernel(output, -self.coeff / (1 - self.coeff), image, 1.0 / (1 - self.coeff)) return output ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/processing/unsharp_opencl.py0000644000175000017500000000511514550227307020767 0ustar00pierrepierretry: import pyopencl.array as parray from pyopencl.elementwise import ElementwiseKernel from ..opencl.processing import OpenCLProcessing __have_opencl__ = True except ImportError: __have_opencl__ = False from .unsharp import UnsharpMask class OpenclUnsharpMask(UnsharpMask): def __init__( self, shape, sigma, coeff, mode="reflect", method="gaussian", **opencl_options, ): """ NB: For now, this class is designed to use the lowest amount of GPU memory as possible. Therefore, the input and output image/volumes are assumed to be already on device. 
""" if not (__have_opencl__): raise ImportError("Need pyopencl") super().__init__(shape, sigma, coeff, mode=mode, method=method) self.cl_processing = OpenCLProcessing(**(opencl_options or {})) self._init_convolution() self._init_mad_kernel() def _init_convolution(self): # Do it here because silx creates OpenCL contexts all over the place at import from silx.opencl.convolution import Convolution as CLConvolution self.convolution = CLConvolution( self.shape, self._gaussian_kernel, mode=self.mode, ctx=self.cl_processing.ctx, extra_options={ # Use the lowest amount of memory "allocate_input_array": False, "allocate_output_array": False, "allocate_tmp_array": True, "dont_use_textures": True, }, ) def _init_mad_kernel(self): # parray.Array.mul_add is out of place... self.mad_kernel = ElementwiseKernel( self.cl_processing.ctx, "float* array, float fac, float* other, float otherfac", "array[i] = fac * array[i] + otherfac * other[i]", name="mul_add", ) def unsharp(self, image, output): # For now image and output are assumed to be already allocated on device assert isinstance(image, self.cl_processing.array_class) assert isinstance(output, self.cl_processing.array_class) self.convolution(image, output=output) if self.method == "gaussian": self.mad_kernel(output, -self.coeff, image, 1.0 + self.coeff) elif self.method == "log": self.mad_kernel(output, self.coeff, image, 1.0) else: # "imagej": self.mad_kernel(output, -self.coeff / (1 - self.coeff), image, 1.0 / (1 - self.coeff)) return output # Alias OpenCLUnsharpMask = OpenclUnsharpMask ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1734442985.520757 nabu-2024.2.1/nabu/reconstruction/0000755000175000017500000000000014730277752016311 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/reconstruction/__init__.py0000644000175000017500000000024314402565210020402 0ustar00pierrepierrefrom .reconstructor import Reconstructor from .rings import MunchDeringer, munchetal_filter from .sinogram import SinoBuilder, convert_halftomo, SinoNormalization ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1723556963.0 nabu-2024.2.1/nabu/reconstruction/cone.py0000644000175000017500000004466614656662143017626 0ustar00pierrepierrefrom math import sqrt import numpy as np from ..cuda.kernel import CudaKernel from ..cuda.processing import CudaProcessing from ..reconstruction.filtering_cuda import CudaSinoFilter from ..utils import get_cuda_srcfile try: import astra __have_astra__ = True except ImportError: __have_astra__ = False class ConebeamReconstructor: """ A reconstructor for cone-beam geometry using the astra toolbox. """ default_extra_options = { "axis_correction": None, "clip_outer_circle": False, "scale_factor": None, "filter_cutoff": 1.0, "outer_circle_value": 0.0, # "use_astra_fdk": True, "use_astra_fdk": False, } def __init__( self, sinos_shape, source_origin_dist, origin_detector_dist, angles=None, volume_shape=None, rot_center=None, relative_z_position=None, pixel_size=None, padding_mode="zeros", filter_name=None, slice_roi=None, cuda_options=None, extra_options=None, ): """ Initialize a cone beam reconstructor. This reconstructor works on slabs of data, meaning that one partial volume is obtained from one stack of sinograms. To reconstruct a full volume, the reconstructor must be called on a series of sinograms stacks, with an updated "relative_z_position" each time. 
Parameters ----------- sinos_shape: tuple Shape of the sinograms stack, in the form (n_sinos, n_angles, prj_width) source_origin_dist: float Distance, in pixel units, between the beam source (cone apex) and the "origin". The origin is defined as the center of the sample origin_detector_dist: float Distance, in pixel units, between the center of the sample and the detector. angles: array, optional Rotation angles in radians. If provided, its length should be equal to sinos_shape[1]. volume_shape: tuple of int, optional Shape of the output volume slab, in the form (n_z, n_y, n_x). If not provided, the output volume slab shape is (sinos_shape[0], sinos_shape[2], sinos_shape[2]). rot_center: float, optional Rotation axis position. Default is `(detector_width - 1)/2.0` relative_z_position: float, optional Position of the central slice of the slab, with respect to the full stack of slices. By default it is set to zero, meaning that the current slab is assumed in the middle of the stack axis_correction: array, optional Array of the same size as the number of projections. Each corresponds to a horizontal displacement. pixel_size: float or tuple, optional Size of the pixel. Possible options: - Nothing is provided (default): in this case, all lengths are normalized with respect to the pixel size, i.e 'source_origin_dist' and 'origin_detector_dist' should be expressed in pixels (and 'pixel_size' is set to 1). - A scalar number is provided: in this case it is the spacing between two pixels (in each dimension) - A tuple is provided: in this case it is the spacing between two pixels in both dimensions, vertically then horizontally, i.e (detector_spacing_y, detector_spacing_x) scale_factor: float, optional Post-reconstruction scale factor. padding_mode: str, optional How to pad the data before applying FDK. By default this is done by astra with zero-padding. If padding_mode is other than "zeros", it will be done by nabu and the padded data is passed to astra where no additional padding is done. Beware that in its current implementation, this option almost doubles the memory needed. slice_roi: Whether to reconstruct only a region of interest for each horizontal slice. This parameter must be in the form (start_x, end_x, start_y, end_y) with no negative values. Note that the current implementation just crops the final reconstructed volume, i.e there is no speed or memory benefit. use_astra_fdk: bool Whether to use the native Astra Toolbox FDK implementation. If set to False, the cone-beam pre-weighting and projections padding/filtering is done by nabu. Note that this parameter is automatically set to False if padding_mode != "zeros". Notes ------ This reconstructor is using the astra toolbox [1]. Therefore the implementation uses Astra's reference frame, which is centered on the sample (source and detector move around the sample). For more information see Fig. 2 of paper [1]. To define the cone-beam geometry, two distances are needed: - Source-origin distance (hereby d1) - Origin-detector distance (hereby d2) The magnification at distance d2 is m = 1+d2/d1, so given a detector pixel size p_s, the sample voxel size is p_s/m. 
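        As a quick worked example with illustrative numbers: for d1 = 1000 and d2 = 500 (in pixel units),
        the magnification is m = 1 + 500/1000 = 1.5, so a detector pixel size of p_s = 3 µm corresponds
        to a sample voxel size of p_s/m = 2 µm.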
To make things simpler, this class internally uses a different (but equivalent) geometry: - d2 is set to zero, meaning that the detector is (virtually) moved to the center of the sample - The detector is "re-scaled" to have a pixel size equal to the voxel size (p_s/m) Having the detector in the same plane as the sample center simplifies things when it comes to slab-wise reconstruction: defining a volume slab (in terms of z_min, z_max) is equivalent to define the detector bounds, like in parallel geometry. References ----------- [1] Aarle, Wim & Palenstijn, Willem & Cant, Jeroen & Janssens, Eline & Bleichrodt, Folkert & Dabravolski, Andrei & De Beenhouwer, Jan & Batenburg, Kees & Sijbers, Jan. (2016). Fast and flexible X-ray tomography using the ASTRA toolbox. Optics Express. 24. 25129-25147. 10.1364/OE.24.025129. """ self._configure_extra_options(extra_options) self._init_cuda(cuda_options) self._set_sino_shape(sinos_shape) self._orig_prog_geom = None self._init_geometry( source_origin_dist, origin_detector_dist, pixel_size, angles, volume_shape, rot_center, relative_z_position, slice_roi, ) self._init_fdk(padding_mode, filter_name) self._alg_id = None self._vol_id = None self._proj_id = None def _configure_extra_options(self, extra_options): self.extra_options = self.default_extra_options.copy() self.extra_options.update(extra_options or {}) def _init_cuda(self, cuda_options): cuda_options = cuda_options or {} self.cuda = CudaProcessing(**cuda_options) def _set_sino_shape(self, sinos_shape): if len(sinos_shape) != 3: raise ValueError("Expected a 3D shape") self.sinos_shape = sinos_shape self.n_sinos, self.n_angles, self.prj_width = sinos_shape def _init_fdk(self, padding_mode, filter_name): self.padding_mode = padding_mode self._use_astra_fdk = bool(self.extra_options.get("use_astra_fdk", True)) self._use_astra_fdk &= padding_mode in ["zeros", "constant", None, "none"] if self._use_astra_fdk: return self.sino_filter = CudaSinoFilter( self.sinos_shape[1:], filter_name=filter_name, padding_mode=self.padding_mode, # TODO (?) configure FFT backend extra_options={"cutoff": self.extra_options.get("filter_cutoff", 1.0)}, cuda_options={"ctx": self.cuda.ctx}, ) # In astra, FDK pre-weighting does the "n_a/(pi/2) multiplication" # TODO not sure where this "magnification **2" factor comes from ? 
mult_factor = self.n_angles / 3.141592 * 2 / (self.magnification**2) self.sino_filter.set_filter(self.sino_filter.filter_f * mult_factor, normalize=False) # def _set_pixel_size(self, pixel_size): if pixel_size is None: det_spacing_y = det_spacing_x = 1 elif np.iterable(pixel_size): det_spacing_y, det_spacing_x = pixel_size else: # assuming scalar det_spacing_y = det_spacing_x = pixel_size self._det_spacing_y = det_spacing_y self._det_spacing_x = det_spacing_x def _set_slice_roi(self, slice_roi): self.slice_roi = slice_roi self._vol_geom_n_x = self.n_x self._vol_geom_n_y = self.n_y self._crop_data = True if slice_roi is None: return start_x, end_x, start_y, end_y = slice_roi if roi_is_centered(self.volume_shape[1:], (slice(start_y, end_y), slice(start_x, end_x))): # For FDK, astra can only reconstruct subregion centered around the origin self._vol_geom_n_x = self.n_x - start_x * 2 self._vol_geom_n_y = self.n_y - start_y * 2 else: raise NotImplementedError( "Cone-beam geometry supports only slice_roi centered around origin (got slice_roi=%s with n_x=%d, n_y=%d)" % (str(slice_roi), self.n_x, self.n_y) ) def _init_geometry( self, source_origin_dist, origin_detector_dist, pixel_size, angles, volume_shape, rot_center, relative_z_position, slice_roi, ): if angles is None: self.angles = np.linspace(0, 2 * np.pi, self.n_angles, endpoint=True) else: self.angles = angles if volume_shape is None: volume_shape = (self.sinos_shape[0], self.sinos_shape[2], self.sinos_shape[2]) self.volume_shape = volume_shape self.n_z, self.n_y, self.n_x = self.volume_shape self.source_origin_dist = source_origin_dist self.origin_detector_dist = origin_detector_dist self.magnification = 1 + origin_detector_dist / source_origin_dist self._set_slice_roi(slice_roi) self.vol_geom = astra.create_vol_geom(self._vol_geom_n_y, self._vol_geom_n_x, self.n_z) self.vol_shape = astra.geom_size(self.vol_geom) self._cor_shift = 0.0 self.rot_center = rot_center if rot_center is not None: self._cor_shift = (self.sinos_shape[-1] - 1) / 2.0 - rot_center self._set_pixel_size(pixel_size) self._axis_corrections = self.extra_options.get("axis_correction", None) self._create_astra_proj_geometry(relative_z_position) def _create_astra_proj_geometry(self, relative_z_position): # This object has to be re-created each time, because once the modifications below are done, # it is no more a "cone" geometry but a "cone_vec" geometry, and cannot be updated subsequently # (see astra/functions.py:271) self.proj_geom = astra.create_proj_geom( "cone", self._det_spacing_x, self._det_spacing_y, self.n_sinos, self.prj_width, self.angles, self.source_origin_dist, self.origin_detector_dist, ) self.relative_z_position = relative_z_position or 0.0 # This will turn the geometry of type "cone" into a geometry of type "cone_vec" if self._orig_prog_geom is None: self._orig_prog_geom = self.proj_geom self.proj_geom = astra.geom_postalignment(self.proj_geom, (self._cor_shift, 0)) # (src, detector_center, u, v) = (srcX, srcY, srcZ, dX, dY, dZ, uX, uY, uZ, vX, vY, vZ) vecs = self.proj_geom["Vectors"] # To adapt the center of rotation: # dX = cor_shift * cos(theta) - origin_detector_dist * sin(theta) # dY = origin_detector_dist * cos(theta) + cor_shift * sin(theta) if self._axis_corrections is not None: # should we check that dX and dY match the above formulas ? 
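            # In the two lines below, cor_shift is replaced by the per-angle value
            # cor_shift + axis_correction(theta), i.e. the dX/dY formulas above are applied
            # with an angle-dependent center-of-rotation shift.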
cor_shifts = self._cor_shift + self._axis_corrections vecs[:, 3] = cor_shifts * np.cos(self.angles) - self.origin_detector_dist * np.sin(self.angles) vecs[:, 4] = self.origin_detector_dist * np.cos(self.angles) + cor_shifts * np.sin(self.angles) # To adapt the z position: # Component 2 of vecs is the z coordinate of the source, component 5 is the z component of the detector position # We need to re-create the same inclination of the cone beam, thus we need to keep the inclination of the two z positions. # The detector is centered on the rotation axis, thus moving it up or down, just moves it out of the reconstruction volume. # We can bring back the detector in the correct volume position, by applying a rigid translation of both the detector and the source. # The translation is exactly the amount that brought the detector up or down, but in the opposite direction. vecs[:, 2] = -self.relative_z_position def _set_output(self, volume): if volume is not None: expected_shape = self.vol_shape # if not (self._crop_data) else self._output_cropped_shape self.cuda.check_array(volume, expected_shape) self.cuda.set_array("output", volume) if volume is None: self.cuda.allocate_array("output", self.vol_shape) d_volume = self.cuda.get_array("output") z, y, x = d_volume.shape self._vol_link = astra.data3d.GPULink(d_volume.ptr, x, y, z, d_volume.strides[-2]) self._vol_id = astra.data3d.link("-vol", self.vol_geom, self._vol_link) def _set_input(self, sinos): self.cuda.check_array(sinos, self.sinos_shape) self.cuda.set_array("sinos", sinos) # self.cuda.sinos is now a GPU array # TODO don't create new link/proj_id if ptr is the same ? # But it seems Astra modifies the input sinogram while doing FDK, so this might be not relevant d_sinos = self.cuda.get_array("sinos") # self._proj_data_link = astra.data3d.GPULink(d_sinos.ptr, self.prj_width, self.n_angles, self.n_z, sinos.strides[-2]) self._proj_data_link = astra.data3d.GPULink( d_sinos.ptr, self.prj_width, self.n_angles, self.n_sinos, d_sinos.strides[-2] ) self._proj_id = astra.data3d.link("-sino", self.proj_geom, self._proj_data_link) def _preprocess_data(self): if self._use_astra_fdk: return d_sinos = self.cuda.sinos fdk_preweighting( d_sinos, self._orig_prog_geom, relative_z_position=self.relative_z_position, cor_shift=self._cor_shift ) for i in range(d_sinos.shape[0]): self.sino_filter.filter_sino(d_sinos[i], output=d_sinos[i]) def _update_reconstruction(self): if self._use_astra_fdk: cfg = astra.astra_dict("FDK_CUDA") else: cfg = astra.astra_dict("BP3D_CUDA") cfg["ReconstructionDataId"] = self._vol_id cfg["ProjectionDataId"] = self._proj_id if self._alg_id is not None: astra.algorithm.delete(self._alg_id) self._alg_id = astra.algorithm.create(cfg) def reconstruct(self, sinos, output=None, relative_z_position=None): """ sinos: numpy.ndarray or pycuda.gpuarray Sinograms, with shape (n_sinograms, n_angles, width) output: pycuda.gpuarray, optional Output array. If not provided, a new numpy array is returned relative_z_position: int, optional Position of the central slice of the slab, with respect to the full stack of slices. 
By default it is set to zero, meaning that the current slab is assumed in the middle of the stack """ self._create_astra_proj_geometry(relative_z_position) self._set_input(sinos) self._set_output(output) self._preprocess_data() self._update_reconstruction() astra.algorithm.run(self._alg_id) # # NB: Could also be done with # from astra.experimental import direct_BP3D # projector_id = astra.create_projector("cuda3d", self.proj_geom, self.vol_geom, options=None) # direct_BP3D(projector_id, self._vol_link, self._proj_data_link) # result = self.cuda.get_array("output") if output is None: result = result.get() if self.extra_options.get("scale_factor", None) is not None: result *= np.float32(self.extra_options["scale_factor"]) # in-place for pycuda self.cuda.recover_arrays_references(["sinos", "output"]) return result def __del__(self): if getattr(self, "_alg_id", None) is not None: astra.algorithm.delete(self._alg_id) if getattr(self, "_vol_id", None) is not None: astra.data3d.delete(self._vol_id) if getattr(self, "_proj_id", None) is not None: astra.data3d.delete(self._proj_id) def selection_is_centered(size, start, stop): """ Return True if (start, stop) define a selection that is centered on the middle of the array. """ if stop > 0: stop -= size return stop == -start def roi_is_centered(shape, slice_): """ Return True if "slice_" define a selection that is centered on the middle of the array. """ return all([selection_is_centered(shp, s.start, s.stop) for shp, s in zip(shape, slice_)]) def fdk_preweighting(d_sinos, proj_geom, relative_z_position=0.0, cor_shift=0.0): preweight_kernel = CudaKernel( "devFDK_preweight", filename=get_cuda_srcfile("cone.cu"), signature="Piiifffffiii", ) # n_angles, n_z, n_x = d_sinos.shape n_z, n_angles, n_x = d_sinos.shape det_origin = sqrt(proj_geom["DistanceOriginDetector"] ** 2 + cor_shift**2) block = (32, 16, 1) grid = (((n_x + 32 - 1) // 32) * ((n_z + 32 - 1) // 32), (n_angles + 16 - 1) // 16, 1) preweight_kernel( d_sinos, np.uint32(n_x), # unsigned int projPitch, np.uint32(0), # unsigned int startAngle, np.uint32(n_angles), # unsigned int endAngle, np.float32(proj_geom["DistanceOriginSource"]), # float fSrcOrigin, np.float32(det_origin), # float fDetOrigin, np.float32(relative_z_position), # float fZShift, np.float32(proj_geom["DetectorSpacingX"]), # float fDetUSize, np.float32(proj_geom["DetectorSpacingY"]), # float fDetVSize, np.int32(n_angles), # dims.iProjAngles; np.int32(n_x), # dims.iProjU; // number of detectors in the U direction np.int32(n_z), # dims.iProjV // number of detectors in the V direction block=block, grid=grid, ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1723556968.0 nabu-2024.2.1/nabu/reconstruction/fbp.py0000644000175000017500000001162714656662150017436 0ustar00pierrepierreimport numpy as np import pycuda.driver as cuda from ..utils import updiv, get_cuda_srcfile from ..cuda.utils import copy_array, check_textures_availability from ..cuda.processing import CudaProcessing from ..cuda.kernel import CudaKernel from .filtering_cuda import CudaSinoFilter from .sinogram_cuda import CudaSinoMult from .fbp_base import BackprojectorBase class CudaBackprojector(BackprojectorBase): backend = "cuda" kernel_filename = "backproj.cu" backend_processing_class = CudaProcessing SinoFilterClass = CudaSinoFilter SinoMultClass = CudaSinoMult def _check_textures_availability(self): self._use_textures = self.extra_options.get("use_textures", True) and check_textures_availability() def _get_kernel_signature(self): 
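        # Typecodes of the prepared kernel call: 'P' = pointer, 'i' = int32, 'f' = float32.
        # Entries are blanked below when the corresponding kernel argument is compiled out
        # (the axis-correction array, or the sinogram pointer when texture references are used).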
kern_full_sig = list("PPiifiiffPPPf") if self._axis_correction is None: kern_full_sig[11] = "" if self._use_textures: # texture references - no object is passed (deprecated, removed in Cuda 12) kern_full_sig[1] = "" return "".join(kern_full_sig) def _get_kernel_options(self): super()._get_kernel_options() self._kernel_options.update( { "file_name": get_cuda_srcfile(self.kernel_filename), "kernel_signature": self._get_kernel_signature(), "texture_name": "tex_projections", } ) def _prepare_kernel_args(self): super()._prepare_kernel_args() self.kern_proj_kwargs.update( { "shared_size": self._kernel_options["shared_size"], } ) # texture references - no object is passed (deprecated, removed in Cuda 12) if self._use_textures: self.kern_proj_args.pop(1) else: self._d_sino = self._processing.allocate_array("_d_sino", self.sino_shape) self.kern_proj_args[1] = self._d_sino.gpudata def _prepare_textures(self): if self._use_textures: self.texref_proj = self.gpu_projector.module.get_texref(self._kernel_options["texture_name"]) self.texref_proj.set_filter_mode(cuda.filter_mode.LINEAR) self.gpu_projector.prepare(self._kernel_options["kernel_signature"], [self.texref_proj]) # Bind texture self._d_sino_cua = cuda.np_to_array(np.zeros(self.sino_shape, "f"), "C") self.texref_proj.set_array(self._d_sino_cua) else: # d_sino_ref = self._d_sino.gpudata # self.kern_proj_args.insert(2, d_sino_ref) self.gpu_projector.prepare(self._kernel_options["kernel_signature"], []) def _compile_kernels(self): self._prepare_kernel_args() if self._use_textures: self._kernel_options["sourcemodule_options"].append("-DUSE_TEXTURES") self.gpu_projector = CudaKernel( self._kernel_options["kernel_name"], filename=self._kernel_options["file_name"], options=self._kernel_options["sourcemodule_options"], silent_compilation_warnings=True, # textures and Cuda 11 ) if self.halftomo and self.rot_center < self.dwidth: self.sino_mult = CudaSinoMult(self.sino_shape, self.rot_center, ctx=self._processing.ctx) self._prepare_textures() # has to be done after compilation for Cuda (to bind texture to built kernel) def _transfer_to_texture(self, sino, do_checks=True): if do_checks and not (sino.flags.c_contiguous): raise ValueError("Expected C-Contiguous array") if self._use_textures: copy_array(self._d_sino_cua, sino, check=do_checks) else: if id(self._d_sino) == id(sino): return self._d_sino[:] = sino[:] # COMPAT. Backprojector = CudaBackprojector class PolarBackprojector(Backprojector): """ Cuda Backprojector with output in polar coordinates. 
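    Concretely, the output slice is sampled on an (angle, radius) grid of shape (n_angles, n_x)
    instead of the usual Cartesian (n_y, n_x) grid (see the patched `_set_angles` below).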
""" cuda_fname = "backproj_polar.cu" cuda_kernel_name = "backproj_polar" # patch parent method: force slice_shape to (n_angles, n_x) def _set_angles(self, angles, n_angles): Backprojector._set_angles(self, angles, n_angles) self.slice_shape = (self.n_angles, self.n_x) # patch parent method: def _set_slice_roi(self, slice_roi): if slice_roi is not None: raise ValueError("slice_roi is not supported with this class") Backprojector._set_slice_roi(self, slice_roi) # patch parent method: don't do the 4X compute-workload optimization for this kernel def _get_kernel_options(self): Backprojector._get_kernel_options(self) block = self._kernel_options["block"] self._kernel_options["grid"] = (updiv(self.n_x, block[0]), updiv(self.n_y, block[1])) # patch parent method: update kernel args def _compile_kernels(self): n_y = self.n_y self.n_y = self.n_angles Backprojector._compile_kernels(self) self.n_y = n_y ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/reconstruction/fbp_base.py0000644000175000017500000004130414654107202020411 0ustar00pierrepierreimport numpy as np from ..utils import updiv, nextpow2, convert_index, deprecation_warning from ..processing.processing_base import ProcessingBase from .filtering import SinoFilter from .sinogram import SinoMult from .sinogram import get_extended_sinogram_width class BackprojectorBase: """ Base class for backprojectors. """ backend = "numpy" default_padding_mode = "zeros" kernel_name = "backproj" default_extra_options = { "padding_mode": None, "axis_correction": None, "centered_axis": False, "clip_outer_circle": False, "scale_factor": None, "filter_cutoff": 1.0, "outer_circle_value": 0.0, } kernel_filename = None backend_processing_class = ProcessingBase SinoFilterClass = SinoFilter SinoMultClass = SinoMult _sino_filter_other_options = {} def __init__( self, sino_shape, slice_shape=None, angles=None, rot_center=None, padding_mode=None, halftomo=False, filter_name=None, slice_roi=None, scale_factor=None, extra_options=None, backend_options=None, ): """ Initialize a Backprojector. Parameters ----------- sino_shape: tuple Shape of the sinogram, in the form `(n_angles, detector_width)` (for backprojecting one sinogram) or `(n_sinos, n_angles, detector_width)`. slice_shape: int or tuple, optional Shape of the slice. By default, the slice shape is (n_x, n_x) where `n_x = detector_width` angles: array-like, optional Rotation anles in radians. By default, angles are equispaced between [0, pi[. rot_center: float, optional Rotation axis position. Default is `(detector_width - 1)/2.0` padding_mode: str, optional Padding mode when filtering the sinogram. Can be "zeros" (default) or "edges". filter_name: str, optional Name of the filter for filtered-backprojection. slice_roi: tuple, optional. Whether to backproject in a restricted area. If set, it must be in the form (start_x, end_x, start_y, end_y). `end_x` and `end_y` are non inclusive ! For example if the detector has 2048 pixels horizontally, then you can choose `start_x=0` and `end_x=2048`. If one of the value is set to None, it is replaced with a default value (0 for start, n_x and n_y for end) scale_factor: float, optional Scaling factor for backprojection. For example, to get the linear absorption coefficient in 1/cm, this factor has to be set as the pixel size in cm. DEPRECATED - please use this parameter in "extra_options" extra_options: dict, optional Advanced extra options. See the "Extra options" section for more information. 
backend_options: dict, optional OpenCL/Cuda options passed to the OpenCLProcessing or CudaProcessing class. Other parameters ----------------- extra_options: dict, optional Dictionary with a set of advanced options. The default are the following: - "padding_mode": "zeros" Padding mode when filtering the sinogram. Can be "zeros" or "edges". DEPRECATED - please use "padding_mode" directly in parameters. - "axis_correction": None Whether to set a correction for the rotation axis. If set, this should be an array with as many elements as the number of angles. This is useful when there is an horizontal displacement of the rotation axis. - centered_axis: bool Whether to "center" the slice on the rotation axis position. If set to True, then the reconstructed region is centered on the rotation axis. - scale_factor: float Scaling factor for backprojection. For example, to get the linear absorption coefficient in 1/cm, this factor has to be set as the pixel size in cm. - clip_outer_circle: False Whether to set to zero the pixels outside the reconstruction mask - filter_cutoff: float Cut-off frequency usef for Fourier filter. Default is 1.0 """ self._processing = self.backend_processing_class(**(backend_options or {})) self._configure_extra_options(scale_factor, padding_mode, extra_options=extra_options) self._check_textures_availability() self._init_geometry(sino_shape, slice_shape, angles, rot_center, halftomo, slice_roi) self._init_filter(filter_name) self._allocate_memory() self._compute_angles() self._compile_kernels() def _configure_extra_options(self, scale_factor, padding_mode, extra_options=None): extra_options = extra_options or {} # compat. scale_factor = None if scale_factor is not None: deprecation_warning( "Please use the parameter 'scale_factor' in the 'extra_options' dict", do_print=True, func_name="fbp_scale_factor", ) scale_factor = extra_options.get("scale_factor", None) or scale_factor or 1.0 # if "padding_mode" in extra_options: deprecation_warning( "Please use 'padding_mode' directly in Backprojector arguments, not in 'extra_options'", do_print=True, func_name="fbp_padding_mode", ) # self._backproj_scale_factor = scale_factor self._axis_array = None self.extra_options = self.default_extra_options.copy() self.extra_options.update(extra_options) self.padding_mode = padding_mode or self.extra_options["padding_mode"] or self.default_padding_mode self._axis_array = self.extra_options["axis_correction"] def _init_geometry(self, sino_shape, slice_shape, angles, rot_center, halftomo, slice_roi): if slice_shape is not None and slice_roi is not None: raise ValueError("slice_shape and slice_roi cannot be used together") self.sino_shape = sino_shape if len(sino_shape) == 2: n_angles, dwidth = sino_shape else: raise ValueError("Expected 2D sinogram") self.dwidth = dwidth self.halftomo = halftomo if rot_center is None: if halftomo: raise ValueError("Need to know 'rot_center' when using halftomo") rot_center = (self.dwidth - 1) / 2.0 self.rot_center = rot_center self._set_slice_shape(slice_shape) self.axis_pos = self.rot_center self._set_angles(angles, n_angles) self._set_slice_roi(slice_roi) # # offset = start - move # move = 0 if not(centered_axis) else start + (n-1)/2. 
- c if self.extra_options["centered_axis"]: self.offsets = { "x": self.rot_center - (self.n_x - 1) / 2.0, "y": self.rot_center - (self.n_y - 1) / 2.0, } # self._set_axis_corr() def _set_slice_shape(self, slice_shape): if not (self.halftomo): n_x = n_y = self.dwidth else: n_x = n_y = get_extended_sinogram_width(self.dwidth, self.rot_center) if slice_shape is not None: if np.isscalar(slice_shape): slice_shape = (slice_shape, slice_shape) n_y, n_x = slice_shape self.n_x = n_x self.n_y = n_y self.slice_shape = (n_y, n_x) def _set_angles(self, angles, n_angles): self.n_angles = n_angles if angles is None: angles = n_angles if np.isscalar(angles): end_angle = np.pi if not (self.halftomo) else 2 * np.pi take_end_angle = self.halftomo angles = np.linspace(0, end_angle, angles, take_end_angle) else: assert len(angles) == self.n_angles, "expected %d angles but got %d" % (len(angles), self.n_angles) self.angles = angles def _set_slice_roi(self, slice_roi): self.offsets = {"x": 0, "y": 0} self.slice_roi = slice_roi if slice_roi is None: return start_x, end_x, start_y, end_y = slice_roi # convert negative indices start_x = convert_index(start_x, self.n_x, 0) start_y = convert_index(start_y, self.n_y, 0) end_x = convert_index(end_x, self.n_x, self.n_x) end_y = convert_index(end_y, self.n_y, self.n_y) self.slice_shape = (end_y - start_y, end_x - start_x) self.n_x = self.slice_shape[-1] self.n_y = self.slice_shape[-2] self.offsets = {"x": start_x, "y": start_y} def _allocate_memory(self): # 1D textures are not supported in pyopencl self.h_msin = np.zeros((1, self.n_angles), "f") self.h_cos = np.zeros((1, self.n_angles), "f") # self._d_sino = self._processing.allocate_array("d_sino", self.sino_shape, "f") self._processing.init_arrays_to_none(["_d_slice", "d_sino"]) def _compute_angles(self): self.h_cos[0] = np.cos(self.angles).astype("f") self.h_msin[0] = (-np.sin(self.angles)).astype("f") self._d_msin = self._processing.set_array("d_msin", self.h_msin[0]) self._d_cos = self._processing.set_array("d_cos", self.h_cos[0]) if self._axis_correction is not None: self._d_axcorr = self._processing._set_array("d_axcorr", self._axis_correction) def _set_axis_corr(self): axcorr = self.extra_options["axis_correction"] self._axis_correction = axcorr if axcorr is None: return if len(axcorr) != self.n_angles: raise ValueError("Expected %d angles but got %d" % (self.n_angles, len(axcorr))) self._axis_correction = np.zeros((1, self.n_angles), dtype=np.float32) self._axis_correction[0, :] = axcorr[:] # pylint: disable=E1136 def _init_filter(self, filter_name): self.filter_name = filter_name if filter_name in ["None", "none"]: self.sino_filter = None return sinofilter_other_kwargs = {} if self.backend != "numpy": sinofilter_other_kwargs["%s_options" % self.backend] = {"ctx": self._processing.ctx} self.sino_filter = self.SinoFilterClass( self.sino_shape, filter_name=self.filter_name, padding_mode=self.padding_mode, extra_options={"cutoff": self.extra_options.get("filter_cutoff", 1.0)}, **sinofilter_other_kwargs, ) if self.halftomo: # When doing half-tomography, each projections is seen "twice". # SinoFilter normalizes with pi/n_angles, but in half-tomography here n_angles is somehow halved. # TODO it should even be "n_turns", where n_turns can be computed from the angles self.sino_filter.set_filter(self.sino_filter.filter_f * (self.n_angles / np.pi * 2)) def reset_rot_center(self, rot_center): """ Define a new center of rotation for the current backprojector. 
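        The kernel arguments are updated in place (axis position and, when `centered_axis` is
        enabled, the x/y offsets), so the backprojector does not need to be rebuilt.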
""" self.rot_center = rot_center self.axis_pos = rot_center # See kernels signature of backproj.cu and backproj.cl. # The ifdef makes things a bit more complicated proj_arg_idx = 4 if self.backend == "cuda" and self._use_textures: proj_arg_idx = 3 self.kern_proj_args[proj_arg_idx] = rot_center if self.extra_options["centered_axis"]: self.offsets = { "x": self.rot_center - (self.n_x - 1) / 2.0, "y": self.rot_center - (self.n_y - 1) / 2.0, } self.kern_proj_args[proj_arg_idx + 3] = self.offsets["x"] self.kern_proj_args[proj_arg_idx + 4] = self.offsets["y"] # Try to factorize some code between Cuda and OpenCL # Not ideal, as cuda uses "grid" = n_blocks_launched, # while OpenCL uses "global_size" = n_threads_launched def _get_kernel_options(self): sourcemodule_options = [] # We use blocks of 16*16 (see why in kernel doc), and one thread # handles 2 pixels per dimension. block = (16, 16, 1) # The Cuda kernel is optimized for 16x16 threads blocks # If one of the dimensions is smaller than 16, it has to be addapted if self.n_x < 16 or self.n_y < 16: tpb_x = min(int(nextpow2(self.n_x)), 16) tpb_y = min(int(nextpow2(self.n_y)), 16) block = (tpb_x, tpb_y, 1) sourcemodule_options.append("-DSHARED_SIZE=%d" % (tpb_x * tpb_y)) grid = (updiv(updiv(self.n_x, block[0]), 2), updiv(updiv(self.n_y, block[1]), 2)) if self.extra_options["clip_outer_circle"]: sourcemodule_options.append("-DCLIP_OUTER_CIRCLE") shared_size = int(np.prod(block)) * 2 if self._axis_correction is not None: sourcemodule_options.append("-DDO_AXIS_CORRECTION") shared_size += int(np.prod(block)) shared_size *= 4 # sizeof(float32) self._kernel_options = { "kernel_name": self.kernel_name, "sourcemodule_options": sourcemodule_options, "grid": grid, "block": block, "shared_size": shared_size, } def _prepare_kernel_args(self): self._get_kernel_options() self.kern_proj_args = [ None, # output d_slice holder None, # placeholder for sino (OpenCL or Cuda+no-texture) np.int32(self.n_angles), np.int32(self.dwidth), np.float32(self.axis_pos), np.int32(self.n_x), np.int32(self.n_y), np.float32(self.offsets["x"]), np.float32(self.offsets["y"]), self._d_cos, self._d_msin, np.float32(self._backproj_scale_factor), ] if self._axis_correction is not None: self.kern_proj_args.insert(-1, self._d_axcorr) self.kern_proj_kwargs = { "grid": self._kernel_options["grid"], "block": self._kernel_options["block"], } def _set_output(self, output, check=False): self._output_is_ndarray = isinstance(output, np.ndarray) if output is None or self._output_is_ndarray: self._processing.allocate_array("_d_slice", self.slice_shape, dtype=np.float32) output = self._processing._d_slice # pylint: disable=E1101 elif check: assert output.dtype == np.float32 assert output.shape == self.slice_shape, "Expected output shape %s but got %s" % ( self.slice_shape, output.shape, ) if self.extra_options.get("clip_outer_circle", False): out_circle_val = self.extra_options.get("outer_circle_value", 0) if out_circle_val != 0: output.fill(out_circle_val) return output def _set_kernel_slice_arg(self, d_slice): self.kern_proj_args[0] = d_slice def backproj(self, sino, output=None, do_checks=True): if self.halftomo and self.rot_center < self.dwidth: self.sino_mult.prepare_sino(sino) self._transfer_to_texture(sino) d_slice = self._set_output(output, check=do_checks) self._set_kernel_slice_arg(d_slice) self.gpu_projector(*self.kern_proj_args, **self.kern_proj_kwargs) if output is not None and not (self._output_is_ndarray): return output else: return self._processing._d_slice.get(ary=output) def 
filtered_backprojection(self, sino, output=None): # if isinstance(sino, self._processing.array_class): d_sino = sino else: d_sino = self._processing.to_device("d_sino", sino) # if self.sino_filter is not None: filt_kwargs = {} # if a new device array was allocated for sinogram, then the filtering can overwrite it, # since it won't affect user argument if id(d_sino) != id(sino): filt_kwargs = {"output": d_sino} # sino_to_backproject = self.sino_filter(d_sino, **filt_kwargs) else: sino_to_backproject = d_sino return self.backproj(sino_to_backproject, output=output) fbp = filtered_backprojection # shorthand def __repr__(self): res = "%s(sino_shape=%s, slice_shape=%s, rot_center=%.2f, halftomo=%s)" % ( self.__class__.__name__, str(self.sino_shape), str(self.slice_shape), self.rot_center, self.halftomo, ) return res ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1707838209.0 nabu-2024.2.1/nabu/reconstruction/fbp_opencl.py0000644000175000017500000000631214562705401020762 0ustar00pierrepierreimport pyopencl as cl from ..utils import get_opencl_srcfile from ..opencl.processing import OpenCLProcessing from ..opencl.kernel import OpenCLKernel from ..opencl.utils import allocate_texture, check_textures_availability, copy_to_texture from .filtering_opencl import OpenCLSinoFilter from .sinogram_opencl import OpenCLSinoMult from .fbp_base import BackprojectorBase class OpenCLBackprojector(BackprojectorBase): default_extra_options = {**BackprojectorBase.default_extra_options, "use_textures": True} backend = "opencl" kernel_filename = "backproj.cl" backend_processing_class = OpenCLProcessing SinoFilterClass = OpenCLSinoFilter SinoMultClass = OpenCLSinoMult def _check_textures_availability(self): self._use_textures = self.extra_options.get("use_textures", True) and check_textures_availability( self._processing.ctx ) def _get_kernel_options(self): super()._get_kernel_options() self._kernel_options.update( { "file_name": get_opencl_srcfile(self.kernel_filename), } ) def _prepare_kernel_args(self): super()._prepare_kernel_args() block = self.kern_proj_kwargs.pop("block") local_size = block grid = self.kern_proj_kwargs.pop("grid") global_size = (grid[0] * block[0], grid[1] * block[1]) # global_size = (updiv(self.n_x, 2), updiv(self.n_y, 2)) self.kern_proj_args.insert(0, self._processing.queue) # OpenCLProcessing.__call__ expects first arg to be queue self.kern_proj_kwargs.update( { "global_size": global_size, "local_size": local_size, } ) def _prepare_textures(self): if self._use_textures: d_sino_ref = self.d_sino_tex = allocate_texture(self._processing.ctx, self.sino_shape) self._kernel_options["sourcemodule_options"].append("-DUSE_TEXTURES") else: self._d_sino = self._processing.allocate_array("_d_sino", self.sino_shape) d_sino_ref = self._d_sino.data self.kern_proj_args[2] = d_sino_ref def _compile_kernels(self): self._prepare_kernel_args() self._prepare_textures() # has to be done before compilation for OpenCL (to pass -DUSE_TEXTURES) self.kern_proj_args.append(cl.LocalMemory(self._kernel_options["shared_size"])) self.gpu_projector = OpenCLKernel( self._kernel_options["kernel_name"], self._processing.ctx, filename=self._kernel_options["file_name"], options=self._kernel_options["sourcemodule_options"], ) if self.halftomo and self.rot_center < self.dwidth: self.sino_mult = OpenCLSinoMult(self.sino_shape, self.rot_center, ctx=self._processing.ctx) def _transfer_to_texture(self, sino, do_checks=True): if self._use_textures: return copy_to_texture(self._processing.queue, 
self.d_sino_tex, sino) else: if id(self._d_sino) == id(sino): return return cl.enqueue_copy(self._processing.queue, self._d_sino.data, sino.data) def _set_kernel_slice_arg(self, d_slice): self.kern_proj_args[1] = d_slice ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/reconstruction/filtering.py0000644000175000017500000002072314654107202020635 0ustar00pierrepierrefrom math import pi import numpy as np from scipy.fft import rfft, irfft from silx.image.tomography import compute_fourier_filter, get_next_power from ..processing.padding_base import PaddingBase from ..utils import check_supported, get_num_threads # # COMPAT. # from .filtering_cuda import CudaSinoFilter # SinoFilter = deprecated_class( # "From version 2023.1, 'filtering_cuda.CudaSinoFilter' should be used instead of 'filtering.SinoFilter'. In the future, 'filtering.SinoFilter' will be a numpy-only class.", # do_print=True, # )(CudaSinoFilter) # # class SinoFilter: available_filters = [ "ramlak", "shepp-logan", "cosine", "hamming", "hann", "tukey", "lanczos", "hilbert", ] """ A class for sinogram filtering. It does the following: - pad input array - Fourier transform each row - multiply with a 1D or 2D filter - inverse Fourier transform """ available_padding_modes = PaddingBase.supported_modes default_extra_options = {"cutoff": 1.0, "fft_threads": 0} # use all threads by default def __init__( self, sino_shape, filter_name=None, padding_mode="zeros", extra_options=None, ): self._init_extra_options(extra_options) self._set_padding_mode(padding_mode) self._calculate_shapes(sino_shape) self._init_fft() self._allocate_memory() self._compute_filter(filter_name) def _init_extra_options(self, extra_options): self.extra_options = self.default_extra_options.copy() self.extra_options.update(extra_options or {}) def _set_padding_mode(self, padding_mode): # Compat. if padding_mode == "edges": padding_mode = "edge" if padding_mode == "zeros": padding_mode = "constant" # check_supported(padding_mode, self.available_padding_modes, "padding mode") self.padding_mode = padding_mode def _calculate_shapes(self, sino_shape): self.ndim = len(sino_shape) if self.ndim == 2: n_angles, dwidth = sino_shape n_sinos = 1 elif self.ndim == 3: n_sinos, n_angles, dwidth = sino_shape else: raise ValueError("Invalid sinogram number of dimensions") self.sino_shape = sino_shape self.n_angles = n_angles self.dwidth = dwidth # Make sure to use int() here, otherwise pycuda/pyopencl will crash in some cases self.dwidth_padded = int(get_next_power(2 * self.dwidth)) self.sino_padded_shape = (n_angles, self.dwidth_padded) if self.ndim == 3: self.sino_padded_shape = (n_sinos,) + self.sino_padded_shape sino_f_shape = list(self.sino_padded_shape) sino_f_shape[-1] = sino_f_shape[-1] // 2 + 1 self.sino_f_shape = tuple(sino_f_shape) self.pad_left = (self.dwidth_padded - self.dwidth) // 2 self.pad_right = self.dwidth_padded - self.dwidth - self.pad_left def _init_fft(self): pass def _allocate_memory(self): pass def set_filter(self, h_filt, normalize=True): """ Set a filter for sinogram filtering. Parameters ---------- h_filt: numpy.ndarray Array containing the filter. Each line of the sinogram will be filtered with this filter. It has to be the Real-to-Complex Fourier Transform of some real filter, padded to 2*sinogram_width. normalize: bool or float, optional Whether to normalize (multiply) the filter with pi/num_angles. 
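        For example (illustrative), a custom real-space filter `h` of length `dwidth_padded` would be
        passed as its R2C transform, `set_filter(scipy.fft.rfft(h))`, which has the expected
        `dwidth_padded // 2 + 1` samples.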
""" if h_filt.size != self.sino_f_shape[-1]: raise ValueError( """ Invalid filter size: expected %d, got %d. Please check that the filter is the Fourier R2C transform of some real 1D filter. """ % (self.sino_f_shape[-1], h_filt.size) ) if not (np.iscomplexobj(h_filt)): print("Warning: expected a complex Fourier filter") self.filter_f = h_filt.copy() if normalize: self.filter_f *= pi / self.n_angles self.filter_f = self.filter_f.astype(np.complex64) def _compute_filter(self, filter_name): self.filter_name = filter_name or "ram-lak" # TODO add this one into silx if self.filter_name == "hilbert": freqs = np.fft.fftfreq(self.dwidth_padded) filter_f = 1.0 / (2 * pi * 1j) * np.sign(freqs) # else: filter_f = compute_fourier_filter( self.dwidth_padded, self.filter_name, cutoff=self.extra_options["cutoff"], ) filter_f = filter_f[: self.dwidth_padded // 2 + 1] # R2C self.set_filter(filter_f, normalize=True) def _check_array(self, arr): if arr.dtype != np.float32: raise ValueError("Expected data type = numpy.float32") if arr.shape != self.sino_shape: raise ValueError("Expected sinogram shape %s, got %s" % (self.sino_shape, arr.shape)) def filter_sino(self, sino, output=None, no_output=False): """ Perform the sinogram siltering. Parameters ---------- sino: numpy.ndarray or pycuda.gpuarray.GPUArray Input sinogram (2D or 3D) output: numpy.ndarray or pycuda.gpuarray.GPUArray, optional Output array. no_output: bool, optional If set to True, no copy is be done. The resulting data lies in self.d_sino_padded. """ self._check_array(sino) # sino_padded = np.pad( # sino, ((0, 0), (0, self.dwidth_padded - self.dwidth)), mode=self.padding_mode # ) # pad with a FFT-friendly layout sino_padded = np.pad(sino, ((0, 0), (self.pad_left, self.pad_right)), mode=self.padding_mode) sino_padded_f = rfft(sino_padded, axis=1, workers=get_num_threads(self.extra_options["fft_threads"])) sino_padded_f *= self.filter_f sino_filtered = irfft(sino_padded_f, axis=1, workers=get_num_threads(self.extra_options["fft_threads"])) if output is None: res = np.zeros(self.sino_shape, dtype=np.float32) else: res = output if self.ndim == 2: # res[:] = sino_filtered[:, : self.dwidth] # pylint: disable=E1126 # ?! res[:] = sino_filtered[:, self.pad_left : -self.pad_right] # pylint: disable=E1126 # ?! else: # res[:] = sino_filtered[:, :, : self.dwidth] # pylint: disable=E1126 # ?! res[:] = sino_filtered[:, :, self.pad_left : -self.pad_right] # pylint: disable=E1126 # ?! return res __call__ = filter_sino def filter_sinogram( sinogram, padded_width, filter_name="ramlak", padding_mode="constant", normalize=True, filter_cutoff=1.0, **padding_kwargs, ): """ Simple function to filter sinogram. Parameters ---------- sinogram: numpy.ndarray Sinogram, two dimensional array with shape (n_angles, sino_width) padded_width: int Width to use for padding. Must be greater than sinogram width (i.e than sinogram.shape[-1]) filter_name: str, optional Which filter to use. Default is ramlak (roughly equivalent to abs(nu) in frequency domain) padding_mode: str, optional Which padding mode to use. Default is zero-padding. 
normalize: bool, optional Whether to multiply the filtered sinogram with pi/n_angles filter_cutoff: float, optional frequency cutoff for filter """ n_angles, width = sinogram.shape # Initially, padding was done this way # sinogram_padded = np.pad(sinogram, ((0, 0), (0, padded_width - width)), mode=padding_mode, **padding_kwargs) # pad_left = (padded_width - width) // 2 pad_right = padded_width - width - pad_left sinogram_padded = np.pad(sinogram, ((0, 0), (pad_left, pad_right)), mode=padding_mode, **padding_kwargs) # fourier_filter = compute_fourier_filter(padded_width, filter_name, cutoff=filter_cutoff) if normalize: fourier_filter *= np.pi / n_angles fourier_filter = fourier_filter[: padded_width // 2 + 1] # R2C sino_f = rfft(sinogram_padded, axis=1) sino_f *= fourier_filter # sino_filtered = irfft(sino_f, axis=1)[:, :width] # pylint: disable=E1126 # ?! sino_filtered = irfft(sino_f, axis=1)[:, pad_left:-pad_right] # pylint: disable=E1126 # ?! return sino_filtered ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1730906677.0 nabu-2024.2.1/nabu/reconstruction/filtering_cuda.py0000644000175000017500000000771114712705065021641 0ustar00pierrepierreimport numpy as np from ..cuda.processing import CudaProcessing from ..utils import get_cuda_srcfile from ..processing.padding_cuda import CudaPadding from ..processing.fft_cuda import get_fft_class from .filtering import SinoFilter class CudaSinoFilter(SinoFilter): default_extra_options = {**SinoFilter.default_extra_options, **{"fft_backend": "vkfft"}} def __init__( self, sino_shape, filter_name=None, padding_mode="zeros", extra_options=None, cuda_options=None, ): self._cuda_options = cuda_options or {} self.cuda = CudaProcessing(**self._cuda_options) super().__init__(sino_shape, filter_name=filter_name, padding_mode=padding_mode, extra_options=extra_options) self._init_kernels() def _init_fft(self): fft_cls = get_fft_class(self.extra_options["fft_backend"]) self.fft = fft_cls( self.sino_padded_shape, dtype=np.float32, axes=(-1,), ) def _allocate_memory(self): self.d_filter_f = self.cuda.allocate_array("d_filter_f", (self.sino_f_shape[-1],), dtype=np.complex64) self.d_sino_padded = self.cuda.allocate_array("d_sino_padded", self.fft.shape) self.d_sino_f = self.cuda.allocate_array("d_sino_f", self.fft.shape_out, dtype=np.complex64) def set_filter(self, h_filt, normalize=True): super().set_filter(h_filt, normalize=normalize) self.d_filter_f[:] = self.filter_f[:] def _init_kernels(self): # pointwise complex multiplication fname = get_cuda_srcfile("ElementOp.cu") if self.ndim == 2: kernel_name = "inplace_complex_mul_2Dby1D" kernel_sig = "PPii" else: kernel_name = "inplace_complex_mul_3Dby1D" kernel_sig = "PPiii" self.mult_kernel = self.cuda.kernel(kernel_name, filename=fname, signature=kernel_sig) self.kern_args = (self.d_sino_f, self.d_filter_f) self.kern_args += self.d_sino_f.shape[::-1] # padding self.padding_kernel = CudaPadding( self.sino_shape, ((0, 0), (self.pad_left, self.pad_right)), mode=self.padding_mode, cuda_options=self._cuda_options, ) def filter_sino(self, sino, output=None): """ Perform the sinogram siltering. Parameters ---------- sino: numpy.ndarray or pycuda.gpuarray.GPUArray Input sinogram (2D or 3D) output: pycuda.gpuarray.GPUArray, optional Output array. no_output: bool, optional If set to True, no copy is be done. The resulting data lies in self.d_sino_padded. 
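Examples
--------
A minimal sketch (needs a CUDA device, pycuda, and an available FFT backend such as pyvkfft; the input below is synthetic):

>>> import numpy as np
>>> sino = np.random.rand(500, 2048).astype(np.float32)
>>> cuda_filter = CudaSinoFilter(sino.shape, filter_name="hann")
>>> d_sino_filtered = cuda_filter.filter_sino(sino)  # device array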
""" self._check_array(sino) if not (isinstance(sino, self.cuda.array_class)): sino = self.cuda.set_array("sino", sino) elif not (sino.flags.c_contiguous): # Transfer the device array into another, c-contiguous, device array # We can throw an error as well in this case, but often we so something like fbp(radios[:, i, :]) sino_tmp = self.cuda.allocate_array("sino_contig", sino.shape) sino_tmp.set(sino) sino = sino_tmp # Padding self.padding_kernel(sino, output=self.d_sino_padded) # FFT self.fft.fft(self.d_sino_padded, output=self.d_sino_f) # multiply padded sinogram with filter in the Fourier domain self.mult_kernel(*self.kern_args) # TODO tune block size ? # iFFT self.fft.ifft(self.d_sino_f, output=self.d_sino_padded) # return if output is None: res = self.cuda.allocate_array("output", self.sino_shape) else: res = output if self.ndim == 2: res[:] = self.d_sino_padded[:, self.pad_left : self.pad_left + self.dwidth] else: res[:] = self.d_sino_padded[:, :, self.pad_left : self.pad_left + self.dwidth] return res __call__ = filter_sino ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/reconstruction/filtering_opencl.py0000644000175000017500000000756514550227307022212 0ustar00pierrepierreimport numpy as np from ..utils import get_opencl_srcfile from ..opencl.processing import OpenCLProcessing from ..processing.padding_opencl import OpenCLPadding from ..opencl.memcpy import OpenCLMemcpy2D from .filtering import SinoFilter try: from pyvkfft.opencl import VkFFTApp as clfft # pylint: disable=E0401 __has_vkfft__ = True except: __has_vkfft__ = False class OpenCLSinoFilter(SinoFilter): def __init__( self, sino_shape, filter_name=None, padding_mode="zeros", extra_options=None, opencl_options=None, ): self._opencl_options = opencl_options or {} self.opencl = OpenCLProcessing(**self._opencl_options) self.queue = self.opencl.queue super().__init__(sino_shape, filter_name=filter_name, padding_mode=padding_mode, extra_options=extra_options) self._init_kernels() def _init_fft(self): if not (__has_vkfft__): raise ImportError("Please install pyvkfft to use this class") self.fft = clfft(self.sino_padded_shape, np.float32, self.queue, r2c=True, ndim=1, inplace=False) def _allocate_memory(self): self.d_sino_padded = self.opencl.allocate_array("d_sino_padded", self.sino_padded_shape, dtype=np.float32) self.d_sino_f = self.opencl.allocate_array("d_sino_f", self.sino_f_shape, np.complex64) self.d_filter_f = self.opencl.allocate_array("d_filter_f", (self.sino_f_shape[-1],), dtype=np.complex64) def set_filter(self, h_filt, normalize=True): super().set_filter(h_filt, normalize=normalize) self.d_filter_f[:] = self.filter_f[:] def _init_kernels(self): # pointwise complex multiplication fname = get_opencl_srcfile("ElementOp.cl") if self.ndim == 2: kernel_name = "inplace_complex_mul_2Dby1D" else: kernel_name = "inplace_complex_mul_3Dby1D" self.mult_kernel = self.opencl.kernel(kernel_name, filename=fname) # padding self.padding_kernel = OpenCLPadding( self.sino_shape, ((0, 0), (self.pad_left, self.pad_right)), mode=self.padding_mode, opencl_options={"queue": self.queue}, ) # memcpy2D self.memcpy2D = OpenCLMemcpy2D(queue=self.queue) def filter_sino(self, sino, output=None): """ Perform the sinogram siltering. Parameters ---------- sino: numpy.ndarray or pyopencl.array Input sinogram (2D or 3D) output: pyopencl.array, optional Output array. no_output: bool, optional If set to True, no copy is be done. The resulting data lies in self.d_sino_padded. 
""" self._check_array(sino) sino = self.opencl.set_array("sino", sino) # Padding self.padding_kernel.pad(sino, output=self.d_sino_padded) # FFT self.fft.fft(self.d_sino_padded, self.d_sino_f) # multiply padded sinogram with filter in the Fourier domain self.mult_kernel( self.queue, self.d_sino_f, self.d_filter_f, *(np.int32(self.d_sino_f.shape[::-1])), # pylint: disable=E1133 # local_size=None, global_size=self.d_sino_f.shape[::-1], ) # TODO tune block size ? # iFFT self.fft.ifft(self.d_sino_f, self.d_sino_padded) # return if output is None: res = self.opencl.allocate_array("output", self.sino_shape) else: res = output if self.ndim == 2: # res[:] = self.d_sino_padded[:, self.pad_left : self.pad_left + self.dwidth] self.memcpy2D(res, self.d_sino_padded, res.shape[::-1], src_offset_xy=(self.pad_left, 0)) else: res[:] = self.d_sino_padded[:, :, self.pad_left : self.pad_left + self.dwidth] return res __call__ = filter_sino ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731681010.0 nabu-2024.2.1/nabu/reconstruction/hbp.py0000644000175000017500000004275614715655362017452 0ustar00pierrepierreimport math import numpy as np from ..utils import get_cuda_srcfile from ..cuda.processing import __has_pycuda__ if __has_pycuda__: from ..cuda.kernel import CudaKernel from .sinogram_cuda import CudaSinoMult from .fbp import CudaBackprojector try: import pycuda.driver as cuda from pycuda import gpuarray as garray __have_hbp__ = True except: __have_hbp__ = False def buildConebeamGeometry( anglesRad, rotAxisProjectionFromLeftPixelUnits, sourceSampleDistanceVoxelUnits, opticalAxisFromLeftPixelUnits=None ): """Generate fanbeam/conebeam projection matrices (as required by the backprojector) based on geometry parameters""" if opticalAxisFromLeftPixelUnits is None: if hasattr(rotAxisProjectionFromLeftPixelUnits, "__iter__"): opticalAxisFromLeftPixelUnits = rotAxisProjectionFromLeftPixelUnits[0] else: opticalAxisFromLeftPixelUnits = rotAxisProjectionFromLeftPixelUnits t = opticalAxisFromLeftPixelUnits d = sourceSampleDistanceVoxelUnits if hasattr(rotAxisProjectionFromLeftPixelUnits, "__iter__"): P_list = [ np.array([[0, -t / d, 1, a], [1, 0, 0, 0], [0, -1 / d, 0, 1]], dtype=np.float64) # pylint: disable=E1130 for a in rotAxisProjectionFromLeftPixelUnits ] else: a = rotAxisProjectionFromLeftPixelUnits P_list = [ np.array([[0, -t / d, 1, a], [1, 0, 0, 0], [0, -1 / d, 0, 1]], dtype=np.float64) # pylint: disable=E1130 ] * len(anglesRad) R = lambda w: np.array( [[1, 0, 0, 0], [0, np.cos(w), np.sin(w), 0], [0, -np.sin(w), np.cos(w), 0], [0, 0, 0, 1]], dtype=np.float64 ) return np.array([P @ R(-w) for P, w in zip(P_list, anglesRad)]) class HierarchicalBackprojector(CudaBackprojector): kernel_filename = "hierarchical_backproj.cu" def _init_geometry(self, sino_shape, slice_shape, angles, rot_center, halftomo, slice_roi): super()._init_geometry(sino_shape, slice_shape, angles, rot_center, halftomo, slice_roi) # pylint: disable=E1130 # -angles because different convention for the rotation direction self.angles = -self.angles # to do the reconstruction in reduction_steps steps self.reduction_steps = self.extra_options.get("hbp_reduction_steps", 2) reduction_factor = int(math.ceil((sino_shape[-2]) ** (1 / self.reduction_steps))) # TODO customize axis_source_meters = 1.0e9 voxel_size_microns = 1.0 # axis_cor = self.extra_options.get("axis_correction", None) if axis_cor is None: axis_cor = 0 bpgeometry = buildConebeamGeometry( self.angles, self.rot_center + axis_cor, 1.0e6 * axis_source_meters / 
voxel_size_microns ) self.setup_hbp(bpgeometry, reductionFactor=reduction_factor, legs=self.extra_options.get("hbp_legs", 4)) def setup_hbp( self, bpgeometry, reductionFactor=20, grid_wh_factors=(1, 1), fac=1, legs=4, ): # This implementation seems not to use textures self._use_textures = False # for the non texture implementation, this big number will discard texture limitations large_factor_for_non_texture_memory_access = 2**10 # TODO: read limits from device info. self.GPU_MAX_GRIDSIZE = 2**15 * large_factor_for_non_texture_memory_access self.GPU_MAX_GRIDS = 2**11 * large_factor_for_non_texture_memory_access if self.sino_shape[0] != len(bpgeometry): raise ValueError("self.sino_shape[0] != len(bpgeometry)") if self.sino_shape[0] != len(self.angles): raise ValueError("self.sino_shape[0] != len(self.angles)") if self.sino_shape[1] > self.GPU_MAX_GRIDSIZE: raise ValueError(f"self.sino_shape[1] > {self.GPU_MAX_GRIDSIZE} not supported by GPU") if self.sino_shape[0] > self.GPU_MAX_GRIDSIZE: raise ValueError(f"self.sino_shape[0] > {self.GPU_MAX_GRIDSIZE} currently not supported") self.reductionFactor = reductionFactor self.legs = legs self.bpsetupsH = bpgeometry.astype(np.float32) # self.bpsetupsD = cuda.mem_alloc(self.bpsetupsH.nbytes) # cuda.memcpy_htod(self.bpsetupsD, self.bpsetupsH) self.bpsetupsD = self._processing.to_device("bpsetupsD", self.bpsetupsH) # if allocate_cuda_sinogram: # self.sinogramD = cuda.mem_alloc(self.sino_shape[0] * self.sino_shape[1] * self.float_size) # else: # self.sinogramD = None self.sinogramD = None self.whf = grid_wh_factors if self.sino_shape[1] * 2 * self.whf[0] * fac > self.GPU_MAX_GRIDSIZE: print(f"WARNING: gridsampling limited to {self.GPU_MAX_GRIDSIZE}") self.whf[0] = self.GPU_MAX_GRIDSIZE / (self.sino_shape[1] * 2 * fac) ############################################### ########## create intermediate grids ########## ############################################### self.reductionFactors = [] self.grids = [] # shapes self.gridTransforms = [] # grid-to-world self.gridInvTransforms = [] # world-to-grid self.gridTransformsH = [] # host buffer self.gridTransformsD = [] # device buffer ### first level grid: will receive backprojections # #################################################### N = self.slice_shape[1] * fac angularRange = abs(self.angles.ptp()) / self.sino_shape[0] * reductionFactor ngrids = int(math.ceil(self.sino_shape[0] / reductionFactor)) grid_width = int( np.rint(2 * N * self.whf[0]) ) # double sampling to account/compensate for diamond shaped grid of ray-intersections grid_height = int( math.ceil(angularRange * N * self.whf[1]) ) # small-angle approximation, generates as much "lines" as needed to account for all intersection levels m = (len(self.angles) // reductionFactor) * reductionFactor # TODO: improve angle calculation for more general cases tmpangles = np.angle( np.average(np.exp(1.0j * self.angles[:m].reshape(m // reductionFactor, reductionFactor)), axis=1) ) tmpangles = np.concatenate((tmpangles, (np.angle(np.average(np.exp(1.0j * self.angles[m:]))),)))[:ngrids] gridAinvT = self._getAinvT(N, grid_height, grid_width) setupRs = self._getRotationMatrices(tmpangles) pad = int(math.ceil(ngrids / legs) * legs - ngrids) # add nan-padding for inline-signaling of unused grids self.gridTransforms += [ np.array( [(R @ gridAinvT) for R in setupRs] + [np.ones((3, 3), np.float32) * math.nan] * pad, dtype=np.float32 ) ] self.gridInvTransforms += [np.array([np.linalg.inv(t) for t in self.gridTransforms[-1]], dtype=np.float32)] self.grids += 
[(grid_height, grid_width, int(math.ceil(ngrids / legs)))] self.reductionFactors += [reductionFactor] ### intermediate level grids: accumulation grids ### #################################################### # Actual iteration count typically within 1-5. Cf. break condition for i in range(100): # for a reasonable (with regard to memory requirement) grid-aspect ratio in the intermediate levels, # the covered angular range per grid should not exceed 28.6°, i.e., # fewer than 7 (6.3) or 13 (12.6) grids for a 180° / 360° scan is not reasonable if int(math.ceil(ngrids / reductionFactor)) < 20: break angularRange *= reductionFactor ngrids = int(math.ceil(ngrids / reductionFactor)) grid_height = int( math.ceil(angularRange * N * self.whf[1]) ) # implicit small angle approximation, whose validity is # asserted by the preceding "break" gridAinvT = self._getAinvT(N, grid_height, grid_width) prevAngles = tmpangles m = (len(prevAngles) // reductionFactor) * reductionFactor # TODO: improve angle calculation for more general cases tmpangles = np.angle( np.average(np.exp(1.0j * prevAngles[:m].reshape(m // reductionFactor, reductionFactor)), axis=1) ) tmpangles = np.concatenate((tmpangles, (np.angle(np.average(np.exp(1.0j * prevAngles[m:]))),)))[:ngrids] setupRsRed = self._getRotationMatrices(tmpangles) pad = int(math.ceil(ngrids / legs) * legs - ngrids) self.gridTransforms += [ np.array( [(R @ gridAinvT) for R in setupRsRed] + [np.ones((3, 3), np.float32) * math.nan] * pad, dtype=np.float32, ) ] self.gridInvTransforms += [np.array([np.linalg.inv(t) for t in self.gridTransforms[-1]], dtype=np.float32)] self.grids += [(grid_height, grid_width, int(math.ceil(ngrids / legs)))] self.reductionFactors += [reductionFactor] ##### final accumulation grid ################# ############################################### reductionFactor = ngrids ngrids = 1 grid_size = self.slice_shape[1] grid_width = grid_size grid_height = grid_size # gridAinvT = self._getAinvT(N, grid_height, grid_width) gridAinvT = self._getAinvT(N, grid_height, grid_width, 1 / fac) self.gridTransforms += [ np.array([gridAinvT] * legs, dtype=np.float32) ] # inflate transform list for convenience in reconstruction loop self.gridInvTransforms += [np.array([np.linalg.inv(t) for t in self.gridTransforms[-1]], dtype=np.float32)] self.grids += [(grid_height, grid_width, ngrids)] self.reductionFactors += [reductionFactor] #### accumulation grids ##### self.gridTransformsD = [] self.gridInvTransformsD = [] self.gridsD = [] max_grid_size = get_max_grid_size(self.grids) for i in range(len(self.grids)): gridTransformH = np.array(self.gridTransforms[i][:, :2, :3], dtype=np.float32, order="C").copy() gridInvTransformH = np.array(self.gridInvTransforms[i][:, :2, :3], dtype=np.float32, order="C").copy() self.gridTransformsD.append(self._processing.to_device("gridTransformsD%d " % i, gridTransformH.ravel())) self.gridInvTransformsD.append( self._processing.to_device("gridInvTransformsD%d" % i, gridInvTransformH.ravel()) ) if legs == 1 or i + 1 != (len(self.grids)): if i < 2: self.gridsD.append(self._processing.allocate_array("gridsD%d" % i, max_grid_size)) else: self.gridsD.append(self.gridsD[i % 2]) else: self.gridsD.append(self._processing.allocate_array("gridsD%d" % i, get_max_grid_size(self.grids[-1:]))) self.imageBufferShape = (grid_size, grid_size) self.imageBufferD = self._processing.allocate_array( "imageBufferD", self.imageBufferShape[0] * self.imageBufferShape[1] ) self.imageBufferH = np.zeros(self.imageBufferShape, dtype=np.float32) def 
_getAinvT(self, finalGridWidthAndHeight, currentGridHeight, currentGridWidth, scale=1): N = finalGridWidthAndHeight grid_height = currentGridHeight grid_width = currentGridWidth # shifts a texture coordinate from corner origin to center origin T = np.array(((1, 0, -0.5 * (grid_height - 1)), (0, 1, -0.5 * (grid_width - 1)), (0, 0, 1)), dtype=np.float32) # scales texture coordinates (of subsampled grid) into the unit/cooridnate system of a fully sampled grid Ainv = np.array( (((N - 1) / (grid_height - 1) * scale, 0, 0), (0, (N - 1) / (grid_width - 1) * scale, 0), (0, 0, 1)), dtype=np.float32, ) return Ainv @ T def _getRotationMatrices(self, angles): return [ np.array(((np.cos(a), np.sin(a), 0), (-np.sin(a), np.cos(a), 0), (0, 0, 1)), dtype=np.float32) for a in angles ] def _compile_kernels(self): # pylint: disable=E0606 self.backprojector = CudaKernel( "backprojector", filename=get_cuda_srcfile(self.kernel_filename), signature="PPiiiiPiifPi", ) self.aggregator = CudaKernel( "aggregator", filename=get_cuda_srcfile(self.kernel_filename), signature="iPPiiiiPiiiP" ) self.clip_outer_circle_kernel = CudaKernel( "clip_outer_circle", filename=get_cuda_srcfile(self.kernel_filename), signature="Pii" ) # Duplicate of fbp.py ... if self.halftomo and self.rot_center < self.dwidth: self.sino_mult = CudaSinoMult(self.sino_shape, self.rot_center, ctx=self._processing.ctx) # def _set_sino(self, sino, do_checks=True): if do_checks and not (sino.flags.c_contiguous): raise ValueError("Expected C-Contiguous array") else: self._d_sino = self._processing.allocate_array("_d_sino", self.sino_shape) if id(self._d_sino) == id(sino): return self._d_sino[:] = sino[:] def backproj(self, sino, output=None, do_checks=True, reference=False): if self.halftomo and self.rot_center < self.dwidth: self.sino_mult.prepare_sino(sino) self._set_sino(sino) lws = (64, 4, 4) if reference: gws = getGridSize(self.grids[-1], lws) (grid_height, grid_width, ngrids) = self.grids[-1] self.backprojector( self.bpsetupsD, self.gridTransformsD[-1].gpudata, np.int32(self.sino_shape[0]), np.int32(grid_width), np.int32(grid_height), np.int32(ngrids), self.gridsD[-1], np.int32(self.sino_shape[1]), np.int32(self.sino_shape[0]), np.float32(self._backproj_scale_factor), self._d_sino, np.int32(0), # offset block=lws, grid=gws, ) else: for leg in list(range(0, self.legs)): gridOffset = leg * self.grids[0][2] projOffset = gridOffset * self.reductionFactors[0] gws = getGridSize(self.grids[0], lws) (grid_height, grid_width, ngrids) = self.grids[0] self.backprojector( self.bpsetupsD, self.gridTransformsD[0][6 * gridOffset :], np.int32(self.reductionFactors[0]), np.int32(grid_width), np.int32(grid_height), np.int32(ngrids), self.gridsD[0], np.int32(self.sino_shape[1]), np.int32(self.sino_shape[0]), np.float32(self._backproj_scale_factor), self._d_sino, np.int32(projOffset), block=lws, grid=gws, ) for i in range(1, len(self.grids)): if self.grids[i][2] >= 8: lws = (16, 16, 4) else: lws = (32, 32, 1) gws = getGridSize(self.grids[i], lws) (new_grid_height, new_grid_width, new_ngrids) = self.grids[i] (prev_grid_height, prev_grid_width, prev_ngrids) = self.grids[i - 1] gridOffset = leg * self.grids[i][2] prevGridOffset = leg * self.grids[i - 1][2] self.aggregator( np.int32((i + 1 == len(self.grids)) and (leg > 0)), self.gridTransformsD[i][6 * gridOffset :], self.gridInvTransformsD[i - 1][6 * prevGridOffset :], np.int32(self.reductionFactors[i]), np.int32(new_grid_width), np.int32(new_grid_height), np.int32(new_ngrids), self.gridsD[i], np.int32(prev_grid_width), 
np.int32(prev_grid_height), np.int32(prev_ngrids), self.gridsD[i - 1], block=lws, grid=gws, ) if self.extra_options.get("clip_outer_circle", False): lws = (16, 16, 1) ny, nx = self.slice_shape gws = getGridSize((nx, ny, 1), lws) self.clip_outer_circle_kernel(self.gridsD[-1], np.int32(ny), np.int32(nx), block=lws, grid=gws) # FIXME pycuda fails to do a discontiguous memcpy for more than 2^31 bytes if self.gridsD[-1].nbytes > 2**31: r1d = self.gridsD[-1].get() r2d = np.ascontiguousarray(r1d.reshape(self.slice_shape)) if output is not None: output[:] = r2d[:] return output else: return r2d # -------- else: return self.gridsD[-1].reshape(self.slice_shape).get(ary=output) def get_max_grid_size(grids): size_max = 0 for dims in grids: size = 1 for d in dims: size = size * d if size > size_max: size_max = size return size_max def getGridSize(minimum, local): m, l = np.array(minimum), np.array(local) new = (m // l) * l new[new < m] += l[new < m] return tuple(map(int, new // l)) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1732264041.0 nabu-2024.2.1/nabu/reconstruction/mlem.py0000644000175000017500000000700014720040151017566 0ustar00pierrepierreimport numpy as np try: import corrct as cct __have_corrct__ = True except ImportError: __have_corrct__ = False class MLEMReconstructor: """ A reconstructor for MLEM reconstruction using the CorrCT toolbox. """ default_extra_options = { "compute_shifts": False, "tomo_consistency": False, "v_min_for_v_shifts": 0, "v_max_for_v_shifts": None, "v_min_for_u_shifts": 0, "v_max_for_u_shifts": None, } def __init__( self, sinos_shape, angles_rad, shifts_uv=None, cor=None, n_iterations=50, extra_options=None, ): """ """ if not (__have_corrct__): raise ImportError("Need corrct package") self.angles_rad = angles_rad self.n_iterations = n_iterations self._configure_extra_options(extra_options) self._set_sino_shape(sinos_shape) self._set_shifts(shifts_uv, cor) def _configure_extra_options(self, extra_options): self.extra_options = self.default_extra_options.copy() self.extra_options.update(extra_options or {}) def _set_sino_shape(self, sinos_shape): if len(sinos_shape) != 3: raise ValueError("Expected a 3D shape") self.sinos_shape = sinos_shape self.n_sinos, self.n_angles, self.prj_width = sinos_shape if self.n_angles != len(self.angles_rad): raise ValueError( f"Number of angles ({len(self.angles_rad)}) does not match size of sinograms ({self.n_angles})." ) def _set_shifts(self, shifts_uv, cor): if shifts_uv is None: self.shifts_uv = np.zeros([self.n_angles, 2]) else: if shifts_uv.shape[0] != self.n_angles: raise ValueError( f"Number of shifts given ({shifts_uv.shape[0]}) does not mathc the number of projections ({self.n_angles})." ) self.shifts_uv = shifts_uv.copy() self.cor = cor def reconstruct(self, data_vwu): """ data_align_vwu: numpy.ndarray or pycuda.gpuarray Raw data, with shape (n_sinograms, n_angles, width) output: optional Output array. If not provided, a new numpy array is returned """ if not isinstance(data_vwu, np.ndarray): data_vwu = data_vwu.get() data_vwu /= data_vwu.mean() # MLEM recons self.vol_geom_align = cct.models.VolumeGeometry.get_default_from_data(data_vwu) self.prj_geom_align = cct.models.ProjectionGeometry.get_default_parallel() # Vertical shifts were handled in pipeline. 
Set them to ZERO self.shifts_uv[:, 1] = 0.0 self.prj_geom_align.set_detector_shifts_vu(self.shifts_uv.T[::-1]) variances_align = cct.processing.compute_variance_poisson(data_vwu) self.weights_align = cct.processing.compute_variance_weight(variances_align, normalized=True) # , use_std=True self.data_term_align = cct.data_terms.DataFidelity_wl2(self.weights_align) solver = cct.solvers.MLEM(verbose=True, data_term=self.data_term_align) self.solver_opts = dict(lower_limit=0) # , x_mask=cct.processing.circular_mask(vol_geom_align.shape_xyz[:-2]) with cct.projectors.ProjectorUncorrected( self.vol_geom_align, self.angles_rad, rot_axis_shift_pix=self.cor, prj_geom=self.prj_geom_align ) as A: rec, _ = solver(A, data_vwu, iterations=self.n_iterations, **self.solver_opts) return rec ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/reconstruction/projection.py0000644000175000017500000002161314550227307021031 0ustar00pierrepierreimport numpy as np import pycuda.driver as cuda import pycuda.gpuarray as garray from ..utils import updiv, get_cuda_srcfile from ..cuda.utils import copy_array from ..cuda.kernel import CudaKernel from ..cuda.processing import CudaProcessing _sizeof_float32 = np.dtype(np.float32).itemsize class Projector: """ A class for performing a tomographic projection (Radon Transform) using Cuda. """ _projector_name = "joseph_projector" _projector_signature = "PiiPfPPPPiiifii" def __init__( self, slice_shape, angles, rot_center=None, detector_width=None, normalize=False, extra_options=None, cuda_options=None, ): """ Initialize a Cuda tomography forward projector. Parameters ----------- slice_shape: tuple Shape of the slice: (num_rows, num_columns). angles: int or sequence Either an integer number of angles, or a list of custom angles values in radian. param rot_center: float, optional Rotation axis position. Default is `(shape[1]-1)/2.0`. detector_width: int, optional Detector width in pixels. If `detector_width > slice_shape[1]`, the projection data will be surrounded with zeros. Using `detector_width < slice_shape[1]` might result in a local tomography setup. normalize: bool, optional Whether to normalize projection. If set to True, sinograms are multiplied by the factor pi/(2*nprojs). extra_options: dict, optional Current allowed options: offset_x, axis_corrections cuda_options: dict, optional Cuda options passed to the CudaProcessing class. """ self.cuda_processing = CudaProcessing(**(cuda_options or {})) self._configure_extra_options(extra_options) self._init_geometry(slice_shape, rot_center, angles, detector_width) self.normalize = normalize self._allocate_memory() self._compute_angles() self._proj_precomputations() self._compile_kernels() def _configure_extra_options(self, extra_options): self.extra_options = { "offset_x": None, "axis_corrections": None, # TODO } extra_opts = extra_options or {} self.extra_options.update(extra_opts) def _init_geometry(self, slice_shape, rot_center, angles, detector_width): if np.isscalar(slice_shape): slice_shape = (slice_shape, slice_shape) self.shape = slice_shape if np.isscalar(angles): angles = np.linspace(0, np.pi, angles, endpoint=False, dtype="f") self.angles = angles self.nprojs = len(angles) self.dwidth = detector_width or self.shape[1] self.sino_shape = (self.nprojs, self.dwidth) # In PYHST (c_hst_project_1over.cu), axis_pos is overwritten to (dimslice-1)/2. # So tuning axis position is done in another way. 
In CCspace.c: # offset_x = start_x - move_x # start_x = start_voxel_1 (zero-based, so 0 by default) # MOVE_X = start_x + (num_x - 1)/2 - ROTATION_AXIS_POSITION; self.axis_pos = self.rot_center = rot_center or (self.shape[1] - 1) / 2.0 self.offset_x = self.extra_options["offset_x"] or np.float32(self.axis_pos - (self.shape[1] - 1) / 2.0) self.axis_pos0 = np.float32((self.shape[1] - 1) / 2.0) def _allocate_memory(self): self.dimgrid_x = updiv(self.dwidth, 16) self.dimgrid_y = updiv(self.nprojs, 16) self._dimrecx = self.dimgrid_x * 16 self._dimrecy = self.dimgrid_y * 16 self.d_sino = garray.zeros((self._dimrecy, self._dimrecx), np.float32) self.d_angles = garray.zeros((self._dimrecy,), np.float32) self._d_beginPos = garray.zeros((2, self._dimrecy), np.int32) self._d_strideJoseph = garray.zeros((2, self._dimrecy), np.int32) self._d_strideLine = garray.zeros((2, self._dimrecy), np.int32) self.d_axis_corrections = garray.zeros((self.nprojs,), np.float32) if self.extra_options.get("axis_corrections", None) is not None: self.d_axis_corrections.set(self.extra_options["axis_corrections"]) # Textures self.d_image_cua = cuda.np_to_array(np.zeros((self.shape[0] + 2, self.shape[1] + 2), "f"), "C") def _compile_kernels(self): self.gpu_projector = CudaKernel( self._projector_name, filename=get_cuda_srcfile("proj.cu"), ) self.texref_slice = self.gpu_projector.module.get_texref("texSlice") self.texref_slice.set_array(self.d_image_cua) self.texref_slice.set_filter_mode(cuda.filter_mode.LINEAR) self.gpu_projector.prepare(self._projector_signature, [self.texref_slice]) self.kernel_args = ( self.d_sino.gpudata, np.int32(self.shape[1]), np.int32(self.dwidth), self.d_angles.gpudata, np.float32(self.axis_pos0), self.d_axis_corrections.gpudata, self._d_beginPos.gpudata, self._d_strideJoseph.gpudata, self._d_strideLine.gpudata, np.int32(self.nprojs), np.int32(self._dimrecx), np.int32(self._dimrecy), self.offset_x, np.int32(1), # josephnoclip, 1 by default np.int32(self.normalize), ) self._proj_kernel_blk = (16, 16, 1) self._proj_kernel_grd = (self.dimgrid_x, self.dimgrid_y, 1) def _compute_angles(self): angles2 = np.zeros(self._dimrecy, dtype=np.float32) # dimrecy != num_projs angles2[: self.nprojs] = np.copy(self.angles) angles2[self.nprojs :] = angles2[self.nprojs - 1] self.angles2 = angles2 self.d_angles[:] = angles2[:] def _proj_precomputations(self): beginPos = np.zeros((2, self._dimrecy), dtype=np.int32) strideJoseph = np.zeros((2, self._dimrecy), dtype=np.int32) strideLine = np.zeros((2, self._dimrecy), dtype=np.int32) cos_angles = np.cos(self.angles2) sin_angles = np.sin(self.angles2) dimslice = self.shape[1] M1 = np.abs(cos_angles) > 0.70710678 M1b = np.logical_not(M1) M2 = cos_angles > 0 M2b = np.logical_not(M2) M3 = sin_angles > 0 M3b = np.logical_not(M3) case1 = M1 * M2 case2 = M1 * M2b case3 = M1b * M3 case4 = M1b * M3b beginPos[:, case1] = 0 strideJoseph[0][case1] = 1 strideJoseph[1][case1] = 0 strideLine[0][case1] = 0 strideLine[1][case1] = 1 beginPos[:, case2] = dimslice - 1 strideJoseph[0][case2] = -1 strideJoseph[1][case2] = 0 strideLine[0][case2] = 0 strideLine[1][case2] = -1 beginPos[0][case3] = dimslice - 1 beginPos[1][case3] = 0 strideJoseph[0][case3] = 0 strideJoseph[1][case3] = 1 strideLine[0][case3] = -1 strideLine[1][case3] = 0 beginPos[0][case4] = 0 beginPos[1][case4] = dimslice - 1 strideJoseph[0][case4] = 0 strideJoseph[1][case4] = -1 strideLine[0][case4] = 1 strideLine[1][case4] = 0 self._d_beginPos.set(beginPos) self._d_strideJoseph.set(strideJoseph) self._d_strideLine.set(strideLine) 
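    # Example use of this projector (a rough sketch, assuming a CUDA device with pycuda;
    # the slice below is synthetic and the variable names are illustrative):
    #
    #   slice_image = np.ascontiguousarray(np.random.rand(512, 512), dtype=np.float32)
    #   projector = Projector(slice_image.shape, 500, normalize=True)
    #   sino = projector.projection(slice_image)  # numpy array of shape (500, 512)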
def _check_input_array(self, image): if image.shape != self.shape: raise ValueError("Expected slice shape = %s, got %s" % (str(self.shape), str(image.shape))) if image.dtype != np.dtype("f"): raise ValueError("Expected float32 data type, got %s" % str(image.dtype)) if not isinstance(image, (np.ndarray, garray.GPUArray)): raise ValueError("Expected either numpy.ndarray or pyopencl.array.Array") if isinstance(image, np.ndarray): if not image.flags["C_CONTIGUOUS"]: raise ValueError("Please use C-contiguous arrays") def set_image(self, image, check=True): if check: self._check_input_array(image) copy_array( self.d_image_cua, image, dst_x_in_bytes=_sizeof_float32, dst_y=1, check=False, # cannot check when using offsets ) def projection(self, image, output=None, do_checks=True): """ Perform the projection of an image. Parameters ----------- image: array Image to forward project output: array, optional Output image """ self.set_image(image, check=do_checks) self.gpu_projector(*self.kernel_args, grid=self._proj_kernel_grd, block=self._proj_kernel_blk) if output is None: res = self.d_sino.get() res = res[: self.nprojs, : self.dwidth] # copy ? else: output[:, :] = self.d_sino[: self.nprojs, : self.dwidth] res = output return res __call__ = projection ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/reconstruction/reconstructor.py0000644000175000017500000001627614654107202021576 0ustar00pierrepierreimport numpy as np from ..utils import convert_index class Reconstructor: """ Abstract base class for reconstructors. A `Reconstructor` is a helper to reconstruct slices in arbitrary directions (not only usual "horizontal slices") in parallel-beam tomography. Current limitations: - Limitation to the three main axes - One instance of Reconstructor can only reconstruct successive slices Typical scenarios examples: - "I want to reconstruct several slices along 'z'", where `z` is the vertical axis. In this case, we reconstruct "horizontal slices" in planes perpendicular to the rotation axis. - "I want to reconstruct slices along 'y'". Here `y` is an axis perpendicular to `z`, i.e we reconstruct "vertical slices". A `Reconstructor` is tied to the set of slices to reconstruct (axis and orientation). Once defined, it cannot be changed ; i.e another class has to be instantiated to reconstruct slices in other axes/indices. The volume geometry conventions are defined below:: __________ / /| / / | z / / | ^ /_________/ | | | | | | y | | / | / | | / | / | | / | / |__________|/ |/ ---------- > x The axis `z` parallel to the rotation axis. The usual parallel-beam tomography setting reconstructs slices along `z`, i.e in planes parallel to (x, y). """ def __init__(self, shape, indices_range, axis="z", vol_type="sinograms", slices_roi=None): """ Initialize a reconstructor. Parameters ----------- shape: tuple Shape of the stack of sinograms or projections. indices_range: tuple Range of indices to reconstruct, in the form (start, end). As the standard Python behavior, the upper bound is not included. For example, to reconstruct 100 slices (numbered from 0 to 99), then you can provide (0, 100) or (0, None). Providing (0, 99) or (0, -1) will omit the last slice. axis: str Axis along which the slices are reconstructed. This axis is orthogonal to the slices planes. This parameter can be either "x", "y", or "z". Default is "z" (reconstruct slices perpendicular to the rotation axis). 
vol_type: str, optional Whether the parameter `shape` describes a volume of sinograms or projections. The two are the same except that axes 0 and 1 are swapped. Can be "sinograms" (default) or "projections". slices_roi: tuple, optional Define a Region Of Interest to reconstruct a part of each slice. By default, the whole slice is reconstructed for each slice index. This parameter is in the form `(start_u, end_u, start_v, end_v)`, where `u` and `v` are horizontal and vertical axes on the reconstructed slice respectively, regardless of its orientation. If one of the values is set to None, it will be replaced by the corresponding default value. Examples --------- To reconstruct the first two horizontal slices, i.e along `z`: `R = Reconstructor(vol_shape, [0, 1])` To reconstruct vertical slices 0-100 along the `y` axis: `R = Reconstructor(vol_shape, (0, 100), axis="y")` """ self._set_shape(shape, vol_type) self._set_axis(axis) self._set_indices(indices_range) self._configure_geometry(slices_roi) def _set_shape(self, shape, vol_type): if "sinogram" in vol_type.lower(): self.vol_type = "sinograms" elif "projection" in vol_type.lower(): self.vol_type = "projections" else: raise ValueError("vol_type can be either 'sinograms' or 'projections'") if len(shape) != 3: raise ValueError("Expected a 3D array description, but shape does not have 3 dims") self.shape = shape if self.vol_type == "sinograms": n_z, n_a, n_x = shape else: n_a, n_z, n_x = shape self.sinos_shape = (n_z, n_a, n_x) self.projs_shape = (n_a, n_z, n_x) self.data_shape = self.sinos_shape if self.vol_type == "sinograms" else self.projs_shape self.n_a = n_a self.n_x = n_x self.n_y = n_x # square slice by default self.n_z = n_z self._n = {"x": self.n_x, "y": self.n_y, "z": self.n_z} def _set_axis(self, axis): if axis.lower() not in ["x", "y", "z"]: raise ValueError("axis can be either 'x', 'y' or 'z' (got %s)" % axis) self.axis = axis def _set_indices(self, indices_range): start, end = indices_range npix = self._n[self.axis] start = convert_index(start, npix, 0) end = convert_index(end, npix, npix) self.indices_range = (start, end) self._idx_start = start self._idx_end = end self.indices = np.arange(start, end) def _configure_geometry(self, slices_roi): self.slices_roi = slices_roi or (None, None, None, None) start_u, end_u, start_v, end_v = self.slices_roi uv_to_xyz = { "z": ("x", "y"), # reconstruct along z: u = x, v = y "y": ("y", "z"), # reconstruct along y: u = y, v = z "x": ("y", "z"), # reconstruct along x: u = y, v = z } rotated_axes = uv_to_xyz[self.axis] u_max = self._n[rotated_axes[0]] v_max = self._n[rotated_axes[1]] start_u = convert_index(start_u, u_max, 0) end_u = convert_index(end_u, u_max, u_max) start_v = convert_index(start_v, v_max, 0) end_v = convert_index(end_v, v_max, v_max) self.slices_roi = (start_u, end_u, start_v, end_v) if self.axis == "z": self.backprojector_roi = self.slices_roi start_z, end_z = self._idx_start, self._idx_end if self.axis == "y": self.backprojector_roi = (start_u, end_u, self._idx_start, self._idx_end) start_z, end_z = start_v, end_v if self.axis == "x": self.backprojector_roi = (self._idx_start, self._idx_end, start_u, end_u) start_z, end_z = start_v, end_v else: raise ValueError("Invalid axis") self._z_indices = np.arange(start_z, end_z) self.output_shape = ( self._z_indices.size, self.backprojector_roi[3] - self.backprojector_roi[2], self.backprojector_roi[1] - self.backprojector_roi[0], ) def _check_data(self, data): if data.shape != self.data_shape: raise ValueError( "Invalid data shape: 
expected %s shape %s, but got %s" % (self.vol_type, self.data_shape, data.shape) ) if data.dtype != np.float32: raise ValueError("Expected float32 data type") def reconstruct(self): raise ValueError("Base class") __call__ = reconstruct ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/reconstruction/reconstructor_cuda.py0000644000175000017500000000315414550227307022565 0ustar00pierrepierreimport numpy as np from .fbp import Backprojector from .reconstructor import Reconstructor class CudaReconstructor(Reconstructor): def __init__(self, shape, indices, axis="z", vol_type="sinograms", slices_roi=None, **backprojector_options): Reconstructor.__init__(self, shape, indices, axis=axis, vol_type=vol_type, slices_roi=slices_roi) self._init_backprojector(**backprojector_options) def _init_backprojector(self, **backprojector_options): self.backprojector = Backprojector( self.sinos_shape[1:], slice_roi=self.backprojector_roi, **backprojector_options ) def reconstruct(self, data, output=None): """ Reconstruct from sinograms or projections. """ self._check_data(data) B = self.backprojector if output is None: output = np.zeros(self.output_shape, "f") new_output = True else: assert output.shape == self.output_shape, str( "Expected output_shape = %s, got %s" % (str(self.output_shape), str(output.shape)) ) assert output.dtype == np.float32 new_output = False def reconstruct_fbp(data, output, i, i0): if self.vol_type == "sinograms": current_sino = data[i] else: current_sino = data[:, i, :] if new_output: output[i0] = B.fbp(current_sino) else: B.fbp(current_sino, output=output[i0]) for i0, i in enumerate(self._z_indices): reconstruct_fbp(data, output, i, i0) return output ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/reconstruction/rings.py0000644000175000017500000002246714550227307020007 0ustar00pierrepierreimport numpy as np from scipy.fft import rfft, irfft from silx.image.tomography import get_next_power from ..thirdparty.pore3d_deringer_munch import munchetal_filter from ..utils import get_2D_3D_shape, get_num_threads, check_supported from ..misc.fourier_filters import get_bandpass_filter try: from algotom.prep.removal import remove_all_stripe __has_algotom__ = True except ImportError: __has_algotom__ = False class MunchDeringer: def __init__(self, sigma, sinos_shape, levels=None, wname="db15", padding=None, padding_mode="edge"): """ Initialize a "Munch Et Al" sinogram deringer. See References for more information. Parameters ----------- sigma: float Standard deviation of the damping parameter. The higher value of sigma, the more important the filtering effect on the rings. levels: int, optional Number of wavelets decomposition levels. By default (None), the maximum number of decomposition levels is used. wname: str, optional Default is "db15" (Daubechies, 15 vanishing moments) sinos_shape: tuple, optional Shape of the sinogram (or sinograms stack). padding: tuple of two int, optional Horizontal padding to use for reducing the aliasing artefacts References ---------- B. Munch, P. Trtik, F. Marone, M. Stampanoni, Stripe and ring artifact removal with combined wavelet-Fourier filtering, Optics Express 17(10):8567-8591, 2009. 
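Examples
--------
A minimal sketch (needs PyWavelets; the sinogram stack below is synthetic):

>>> import numpy as np
>>> sinos = np.random.rand(4, 500, 2048).astype(np.float32)
>>> deringer = MunchDeringer(1.0, sinos.shape, levels=4, padding=True)
>>> sinos = deringer.remove_rings(sinos)  # processed in-place unless 'output' is given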
""" self._get_shapes(sinos_shape, padding) self.sigma = sigma self.levels = levels self.wname = wname self.padding_mode = padding_mode self._check_can_use_wavelets() def _get_shapes(self, sinos_shape, padding): n_z, n_a, n_x = get_2D_3D_shape(sinos_shape) self.sinos_shape = n_z, n_a, n_x self.n_angles = n_a self.n_z = n_z self.n_x = n_x # Handle "padding=True" or "padding=False" if isinstance(padding, bool): if padding: padding = (n_x // 2, n_x // 2) else: padding = None # if padding is not None: pad_x1, pad_x2 = padding if np.iterable(pad_x1) or np.iterable(pad_x2): raise ValueError("Expected padding in the form (x1, x2)") self.sino_padded_shape = (n_a, n_x + pad_x1 + pad_x2) self.padding = padding def _check_can_use_wavelets(self): if munchetal_filter is None: raise ValueError("Need pywavelets to use this class") def _destripe_2D(self, sino, output): if self.padding is not None: sino = np.pad(sino, ((0, 0), self.padding), mode=self.padding_mode) res = munchetal_filter(sino, self.levels, self.sigma, wname=self.wname) if self.padding is not None: res = res[:, self.padding[0] : -self.padding[1]] output[:] = res return output def remove_rings(self, sinos, output=None): """ Main function to performs rings artefacts removal on sinogram(s). CAUTION: this function defaults to in-place processing, meaning that the sinogram(s) you pass will be overwritten. Parameters ---------- sinos: numpy.ndarray Sinogram or stack of sinograms. output: numpy.ndarray, optional Output array. If set to None (default), the output overwrites the input. """ if output is None: output = sinos if sinos.ndim == 2: return self._destripe_2D(sinos, output) n_sinos = sinos.shape[0] for i in range(n_sinos): self._destripe_2D(sinos[i], output[i]) return output class VoDeringer: """ An interface to Nghia Vo's "remove_all_stripe". Needs algotom to run. """ def __init__(self, sinos_shape, **remove_all_stripe_options): self._check_requirement() self._get_shapes(sinos_shape) self._remove_all_stripe_kwargs = remove_all_stripe_options def _check_requirement(self): if not __has_algotom__: raise ImportError("Need algotom") def _get_shapes(self, sinos_shape): n_z, n_a, n_x = get_2D_3D_shape(sinos_shape) self.sinos_shape = n_z, n_a, n_x self.n_angles = n_a self.n_z = n_z self.n_x = n_x def remove_rings_sinogram(self, sino, output=None): new_sino = remove_all_stripe(sino, **self._remove_all_stripe_kwargs) # out-of-place if output is not None: output[:] = new_sino[:] return output return new_sino def remove_rings_sinograms(self, sinos, output=None): if output is None: output = sinos for i in range(sinos.shape[0]): output[i] = self.remove_rings_sinogram(sinos[i]) return output def remove_rings_radios(self, radios): sinos = np.moveaxis(radios, 1, 0) # (n_a, n_z, n_x) --> (n_z, n_a, n_x) return self.remove_rings_sinograms(sinos) remove_rings = remove_rings_sinograms class SinoMeanDeringer: supported_modes = ["subtract", "divide"] def __init__(self, sinos_shape, mode="subtract", filter_cutoff=None, padding_mode="edge", fft_num_threads=None): """ Rings correction with mean subtraction/division. The principle of this method is to subtract (or divide) the sinogram by its mean along a certain axis. In short: sinogram -= filt(sinogram.mean(axis=0)) where `filt` is some bandpass filter. 
Parameters ---------- sinos_shape: tuple of int Sinograms shape, in the form (n_angles, n_x) or (n_sinos, n_angles, n_x) mode: str, optional Operation to do on the sinogram, either "subtract" or "divide" filter_cutoff: tuple, optional Cut-off of the bandpass filter applied on the sinogram profiles. Empty (default) means no filtering. Possible values forms are: - (sigma_low, sigma_high): two float values defining the standard deviation of gaussian(sigma_low) * (1 - gaussian(sigma_high)). High values of sigma mean stronger effect of associated filters. - ((cutoff_low, transition_low), (cutoff_high, transition_high)) where "cutoff" is in normalized Nyquist frequency (0.5 is the maximum frequency), and "transition" is the width of filter decay in fraction of the cutoff frequency padding_mode: str, optional Padding mode when filtering the sinogram profile. Should be "constant" (i.e "zeros") for mathematical correctness, but in practice this yields a Gibbs effect when replicating the sinogram, so "edges" is recommended. fft_num_threads: int, optional How many threads to use for computing the fast Fourier transform when filtering the sinogram profile. Defaut is all the available threads. """ self._get_shapes(sinos_shape) check_supported(mode, self.supported_modes, "operation mode") self.mode = mode self._init_filter(filter_cutoff, fft_num_threads, padding_mode) def _get_shapes(self, sinos_shape): n_z, n_a, n_x = get_2D_3D_shape(sinos_shape) self.sinos_shape = n_z, n_a, n_x self.n_angles = n_a self.n_z = n_z self.n_x = n_x def _init_filter(self, filter_cutoff, fft_num_threads, padding_mode): self.filter_cutoff = filter_cutoff self._filter_f = None if filter_cutoff is None: return self._filter_size = get_next_power(self.n_x * 2) self._filter_f = get_bandpass_filter( (1, self._filter_size), cutoff_lowpass=filter_cutoff[0], cutoff_highpass=filter_cutoff[1], use_rfft=True, data_type=np.float32, ).ravel() self._fft_n_threads = get_num_threads(fft_num_threads) # compat if padding_mode == "edges": padding_mode = "edge" # self.padding_mode = padding_mode size_diff = self._filter_size - self.n_x self._pad_left, self._pad_right = size_diff // 2, size_diff - size_diff // 2 def _apply_filter(self, sino_profile): if self._filter_f is None: return sino_profile sino_profile = np.pad(sino_profile, (self._pad_left, self._pad_right), mode=self.padding_mode) sino_f = rfft(sino_profile, workers=self._fft_n_threads) sino_f *= self._filter_f return irfft(sino_f, workers=self._fft_n_threads)[self._pad_left : -self._pad_right] # ascontiguousarray ? 
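    # Example use (a rough sketch; the stack below is synthetic, and the sinograms are modified in-place):
    #
    #   sinos = np.random.rand(4, 500, 2048).astype(np.float32)
    #   deringer = SinoMeanDeringer(sinos.shape, mode="subtract")
    #   deringer.remove_rings(sinos)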
def remove_rings_sinogram(self, sino, output=None): # if output is not None: raise NotImplementedError # sino_profile = sino.mean(axis=0) sino_profile = self._apply_filter(sino_profile) if self.mode == "subtract": sino -= sino_profile elif self.mode == "divide": sino /= sino_profile return sino def remove_rings_sinograms(self, sinos, output=None): # if output is not None: raise NotImplementedError # for i in range(sinos.shape[0]): self.remove_rings_sinogram(sinos[i]) remove_rings = remove_rings_sinograms ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1730906677.0 nabu-2024.2.1/nabu/reconstruction/rings_cuda.py0000644000175000017500000003261014712705065020774 0ustar00pierrepierreimport numpy as np from ..utils import docstring, get_cuda_srcfile, updiv from ..cuda.processing import CudaProcessing, __has_pycuda__ from ..processing.padding_cuda import CudaPadding from ..processing.fft_cuda import get_fft_class, get_available_fft_implems from ..processing.transpose import CudaTranspose from .rings import MunchDeringer, SinoMeanDeringer, VoDeringer if __has_pycuda__: import pycuda.gpuarray as garray from ..cuda.kernel import CudaKernel try: from pycudwt import Wavelets __have_pycudwt__ = True except ImportError: __have_pycudwt__ = False # pylint: disable=E0606 class CudaMunchDeringer(MunchDeringer): def __init__( self, sigma, sinos_shape, levels=None, wname="db15", padding=None, padding_mode="edge", fft_backend="vkfft", cuda_options=None, ): """ Initialize a "Munch Et Al" sinogram deringer with the Cuda backend. See References for more information. Parameters ----------- sigma: float Standard deviation of the damping parameter. The higher value of sigma, the more important the filtering effect on the rings. levels: int, optional Number of wavelets decomposition levels. By default (None), the maximum number of decomposition levels is used. wname: str, optional Default is "db15" (Daubechies, 15 vanishing moments) sinos_shape: tuple, optional Shape of the sinogram (or sinograms stack). References ---------- B. Munch, P. Trtik, F. Marone, M. Stampanoni, Stripe and ring artifact removal with combined wavelet-Fourier filtering, Optics Express 17(10):8567-8591, 2009. """ super().__init__(sigma, sinos_shape, levels=levels, wname=wname, padding=padding, padding_mode=padding_mode) self._check_can_use_wavelets() self.cuda_processing = CudaProcessing(**(cuda_options or {})) self.ctx = self.cuda_processing.ctx self._init_pycudwt() self._init_padding() self._init_fft(fft_backend) self._setup_fw_kernel() def _check_can_use_wavelets(self): if not (__have_pycudwt__ and __has_pycuda__): raise ValueError("Needs pycuda and pycudwt to use this class") def _init_padding(self): if self.padding is None: return self.padder = CudaPadding( self.sinos_shape[1:], ((0, 0), self.padding), mode=self.padding_mode, cuda_options={"ctx": self.cuda_processing.ctx}, ) def _init_fft(self, fft_backend): self.fft_cls = get_fft_class(backend=fft_backend) # For all k >= 1, we perform a batched (I)FFT along axis 0 on an array # of shape (n_a/2^k, n_x/2^k) (up to DWT size rounding) if self.fft_cls.implem == "vkfft": self._create_plans_vkfft() else: self._create_plans_skfft() def _create_plans_skfft(self): self._fft_plans = {} for level, d_vcoeff in self._d_vertical_coeffs.items(): self._fft_plans[level] = self.fft_cls(d_vcoeff.shape, np.float32, r2c=True, axes=(0,), ctx=self.ctx) def _create_plans_vkfft(self): """ VKFFT does not support batched R2C transforms along axis 0 ("slow axis"). 
We can either use C2C (faster, but needs more memory) or transpose the arrays to do R2C along axis=1. Here we transpose the arrays. """ self._fft_plans = {} self._transpose_forward_1 = {} self._transpose_forward_2 = {} self._transpose_inverse_1 = {} self._transpose_inverse_2 = {} for level, d_vcoeff in self._d_vertical_coeffs.items(): shape = d_vcoeff.shape # Normally, a batched 1D fft on 2D data of shape (Ny, Nx) along axis 0 returns an array of shape (Ny/2+1, Nx): # # (Ny, Nx) --[fft_0]--> (Ny/2, Nx) # f32 c64 # # In this case, we can only do batched 1D transform along axis 1, so we have to trick with transposes: # # (Ny, Nx) --[T]--> (Nx, Ny) --[fft_1]--> (Nx, Ny/2) --[T]--> (Ny/2, Nx) # f32 f32 c64 c64 # # (In both cases IFFT is done the same way from right to left) self._transpose_forward_1[level] = CudaTranspose(shape, np.float32, ctx=self.ctx) self._fft_plans[level] = self.fft_cls(shape[::-1], np.float32, r2c=True, ctx=self.ctx) self._transpose_forward_2[level] = CudaTranspose((shape[1], shape[0] // 2 + 1), np.complex64, ctx=self.ctx) self._transpose_inverse_1[level] = CudaTranspose((shape[0] // 2 + 1, shape[1]), np.complex64, ctx=self.ctx) self._transpose_inverse_2[level] = CudaTranspose(shape[::-1], np.float32, ctx=self.ctx) def _init_pycudwt(self): if self.levels is None: self.levels = 100 # will be clipped by pycudwt sino_shape = self.sinos_shape[1:] if self.padding is None else self.sino_padded_shape self.cudwt = Wavelets(np.zeros(sino_shape, "f"), self.wname, self.levels) self.levels = self.cudwt.levels # Access memory allocated by "pypwt" from pycuda self._d_sino = garray.empty(sino_shape, np.float32, gpudata=self.cudwt.image_int_ptr()) self._get_vertical_coeffs() def _get_vertical_coeffs(self): self._d_vertical_coeffs = {} # Transfer the (0-memset) coefficients in order to get all the shapes coeffs = self.cudwt.coeffs for i in range(self.cudwt.levels): shape = coeffs[i + 1][1].shape self._d_vertical_coeffs[i + 1] = garray.empty( shape, np.float32, gpudata=self.cudwt.coeff_int_ptr(3 * i + 2) ) def _setup_fw_kernel(self): self._fw_kernel = CudaKernel( "kern_fourierwavelets", filename=get_cuda_srcfile("fourier_wavelets.cu"), signature="Piif", ) def _apply_fft(self, level): d_coeffs = self._d_vertical_coeffs[level] # All the memory is allocated (or re-used) under the hood if self.fft_cls.implem == "vkfft": d_coeffs_t = self._transpose_forward_1[level]( d_coeffs ) # allocates self._transpose_forward_1[level].processing.dst d_coeffs_t_f = self._fft_plans[level].fft(d_coeffs_t) # allocates self._fft_plans[level].output_fft d_coeffs_f = self._transpose_forward_2[level]( d_coeffs_t_f ) # allocates self._transpose_forward_2[level].processing.dst else: d_coeffs_f = self._fft_plans[level].fft(d_coeffs) return d_coeffs_f def _apply_ifft(self, d_coeffs_f, level): d_coeffs = self._d_vertical_coeffs[level] if self.fft_cls.implem == "vkfft": d_coeffs_t_f = self._transpose_inverse_1[level](d_coeffs_f, dst=self._fft_plans[level].output_fft) d_coeffs_t = self._fft_plans[level].ifft( d_coeffs_t_f, output=self._transpose_forward_1[level].processing.dst ) self._transpose_inverse_2[level](d_coeffs_t, dst=d_coeffs) else: self._fft_plans[level].ifft(d_coeffs_f, output=d_coeffs) def _destripe_2D(self, d_sino, output): if not (d_sino.flags.c_contiguous): sino = self.cuda_processing.allocate_array("_d_sino", d_sino.shape, np.float32) sino[:] = d_sino[:] else: sino = d_sino if self.padding is not None: sino = self.padder.pad(sino) # set the "image" for DWT (memcpy D2D) self._d_sino.set(sino) # perform 
forward DWT self.cudwt.forward() for i in range(self.cudwt.levels): level = i + 1 Ny, Nx = self._d_vertical_coeffs[level].shape # Batched FFT along axis 0 d_vertical_coeffs_f = self._apply_fft(level) # Dampen the wavelets coefficients self._fw_kernel(d_vertical_coeffs_f, Nx, Ny, self.sigma) # IFFT self._apply_ifft(d_vertical_coeffs_f, level) # Finally, inverse DWT self.cudwt.inverse() d_out = self._d_sino if self.padding is not None: d_out = self._d_sino[:, self.padding[0] : -self.padding[1]] # memcpy2D output.set(d_out) return output def can_use_cuda_deringer(): """ Check wether cuda implementation of deringer can be used. Checking for installed modules is not enough, as for example pyvkfft can be installed without cuda devices """ can_do_fft = get_available_fft_implems() != [] return can_do_fft and __have_pycudwt__ class CudaVoDeringer(VoDeringer): """ An interface to topocupy's "remove_all_stripe". """ def _check_requirement(self): # Do it here, otherwise cupy shows warnings at import even if not used from ..thirdparty.tomocupy_remove_stripe import remove_all_stripe_pycuda, __have_tomocupy_deringer__ if not (__have_tomocupy_deringer__): raise ImportError("need cupy") self._remove_all_stripe_pycuda = remove_all_stripe_pycuda def remove_rings_radios(self, radios): return self._remove_all_stripe_pycuda(radios, layout="radios", **self._remove_all_stripe_kwargs) def remove_rings_sinograms(self, sinos): return self._remove_all_stripe_pycuda(sinos, layout="sinos", **self._remove_all_stripe_kwargs) def remove_rings_sinogram(self, sino): sinos = sino.reshape((1, sino.shape[0], -1)) # no copy self.remove_rings_sinograms(sinos) return sino remove_rings = remove_rings_sinograms class CudaSinoMeanDeringer(SinoMeanDeringer): @docstring(SinoMeanDeringer) def __init__( self, sinos_shape, mode="subtract", filter_cutoff=None, padding_mode="edge", fft_num_threads=None, **cuda_options, ): self.processing = CudaProcessing(**(cuda_options or {})) super().__init__(sinos_shape, mode, filter_cutoff, padding_mode, fft_num_threads) self._init_kernels() def _init_kernels(self): self.d_sino_profile = self.processing.allocate_array("sino_profile", self.n_x) self._mean_kernel = self.processing.kernel( "vertical_mean", filename=get_cuda_srcfile("normalization.cu"), signature="PPiii", ) self._mean_kernel_block = (32, 1, 32) self._mean_kernel_grid = [updiv(a, b) for a, b in zip(self.sinos_shape[::-1], self._mean_kernel_block)] self._mean_kernel_args = [self.d_sino_profile, np.int32(self.n_x), np.int32(self.n_angles), np.int32(self.n_z)] self._mean_kernel_kwargs = { "grid": self._mean_kernel_grid, "block": self._mean_kernel_block, } self._op_kernel = self.processing.kernel( "inplace_generic_op_3Dby1D", filename=get_cuda_srcfile("ElementOp.cu"), signature="PPiii", options=["-DGENERIC_OP=%d" % (3 if self.mode == "divide" else 1)], ) self._op_kernel_block = (16, 16, 4) self._op_kernel_grid = [updiv(a, b) for a, b in zip(self.sinos_shape[::-1], self._op_kernel_block)] self._op_kernel_args = [self.d_sino_profile, np.int32(self.n_x), np.int32(self.n_angles), np.int32(self.n_z)] self._op_kernel_kwargs = { "grid": self._op_kernel_grid, "block": self._op_kernel_block, } def _init_filter(self, filter_cutoff, fft_num_threads, padding_mode): super()._init_filter(filter_cutoff, fft_num_threads, padding_mode) if filter_cutoff is None: return self._d_filter_f = self.processing.to_device("_filter_f", self._filter_f) self.padder = CudaPadding( (self.n_x, 1), ((self._pad_left, self._pad_right), (0, 0)), mode=self.padding_mode, 
cuda_options={"ctx": self.processing.ctx}, ) fft_cls = get_fft_class() self._fft = fft_cls(self._filter_size, np.float32, r2c=True) def _apply_filter(self, sino_profile): if self._filter_f is None: return sino_profile sino_profile = sino_profile.reshape((-1, 1)) # view sino_profile_p = self.padder.pad(sino_profile).ravel() sino_profile_f = self._fft.fft(sino_profile_p) sino_profile_f *= self._d_filter_f self._fft.ifft(sino_profile_f, output=sino_profile_p) self.d_sino_profile[:] = sino_profile_p[self._pad_left : -self._pad_right] return self.d_sino_profile def remove_rings_sinogram(self, sino, output=None): # if output is not None: raise NotImplementedError # if not (sino.flags.c_contiguous): d_sino = self.processing.allocate_array("d_sino", sino.shape, np.float32) d_sino[:] = sino[:] else: d_sino = sino self._mean_kernel(d_sino, *self._mean_kernel_args, **self._mean_kernel_kwargs) self._apply_filter(self.d_sino_profile) self._op_kernel(d_sino, *self._op_kernel_args, **self._op_kernel_kwargs) if not (sino.flags.c_contiguous): sino[:] = self.processing.d_sino[:] return sino def remove_rings_sinograms(self, sinograms): for i in range(sinograms.shape[0]): self.remove_rings_sinogram(sinograms[i]) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/reconstruction/sinogram.py0000644000175000017500000004120214550227307020470 0ustar00pierrepierreimport numpy as np from scipy.interpolate import interp1d from ..utils import get_2D_3D_shape, check_supported, deprecated_class, deprecated class SinoBuilder: """ A class to build sinograms. """ def __init__( self, sinos_shape=None, radios_shape=None, rot_center=None, halftomo=False, angles=None, interpolate=False ): """ Initialize a SinoBuilder instance. Parameters ---------- sinos_shape: tuple of int Shape of the stack of sinograms, in the form `(n_z, n_angles, n_x)`. If not provided, it is derived from `radios_shape`. radios_shape: tuple of int Shape of the chunk of radios, in the form `(n_angles, n_z, n_x)`. If not provided, it is derived from `sinos_shape`. rot_center: int or array Rotation axis position. A scalar indicates the same rotation axis position for all the projections. halftomo: bool Whether "half tomography" is enabled. Default is False. interpolate: bool, optional Only used if halftomo=True. Whether to re-grid the second part of sinograms to match projection k with projection k + n_a/2. This forces each pair of projection (k, k + n_a/2) to be separated by exactly 180 degrees. angles: array, optional Rotation angles (in radians). Used and required only when halftomo and interpolate are True. """ self._get_shapes(sinos_shape, radios_shape) self.set_rot_center(rot_center) self._configure_halftomo(halftomo, interpolate, angles) def _get_shapes(self, sinos_shape, radios_shape=None): if (sinos_shape is None) and (radios_shape is None): raise ValueError("Need to provide sinos_shape and/or radios_shape") if sinos_shape is None: n_a, n_z, n_x = get_2D_3D_shape(radios_shape) sinos_shape = (n_z, n_a, n_x) elif len(sinos_shape) == 2: sinos_shape = (1,) + sinos_shape if radios_shape is None: n_z, n_a, n_x = get_2D_3D_shape(sinos_shape) radios_shape = (n_a, n_z, n_x) elif len(radios_shape) == 2: radios_shape = (1,) + radios_shape self.sinos_shape = sinos_shape self.radios_shape = radios_shape n_a, n_z, n_x = radios_shape self.n_angles = n_a self.n_z = n_z self.n_x = n_x def set_rot_center(self, rot_center): """ Set the rotation axis position for the current radios/sinos stack. 
rot_center: int or array Rotation axis position. A scalar indicates the same rotation axis position for all the projections. """ if rot_center is None: rot_center = (self.n_x - 1) / 2.0 if not (np.isscalar(rot_center)): rot_center = np.array(rot_center) if rot_center.size != self.n_angles: raise ValueError( "Expected rot_center to have %d elements but got %d" % (self.n_angles, rot_center.size) ) self.rot_center = rot_center def _configure_halftomo(self, halftomo, interpolate, angles): self.halftomo = halftomo self.interpolate = interpolate self.angles = angles self._halftomo_flip = False if not self.halftomo: return if interpolate and (angles is None): raise ValueError("The parameter 'angles' has to be provided when using halftomo=True and interpolate=True") self.extended_sino_width = get_extended_sinogram_width(self.n_x, self.rot_center) # If CoR is on the left: "flip" the logic if self.rot_center < (self.n_x - 1) / 2: self.rot_center = self.n_x - 1 - self.rot_center self._halftomo_flip = True # if abs(self.rot_center - ((self.n_x - 1) / 2.0)) < 1: # which tol ? raise ValueError("Half tomography: incompatible rotation axis position: %.2f" % self.rot_center) self.sinos_halftomo_shape = (self.n_z, (self.n_angles + 1) // 2, self.extended_sino_width) def _check_array_shape(self, array, kind="radio"): expected_shape = self.radios_shape if "radio" in kind else self.sinos_shape assert array.shape == expected_shape, "Expected radios shape %s, but got %s" % (expected_shape, array.shape) @property def output_shape(self): """ Get the output sinograms shape. """ if self.halftomo: return self.sinos_halftomo_shape return self.sinos_shape # # 2D # def _get_sino_simple(self, radios, i): return radios[:, i, :] # view def _get_sino_halftomo(self, sino, output=None): # TODO output is ignored for now if self.interpolate: match_half_sinos_parts(sino, self.angles) elif self.n_angles & 1: # Odd number of projections - add one line in the end sino = np.vstack([sino, np.zeros_like(sino[-1])]) if self._halftomo_flip: sino = sino[:, ::-1] if self.rot_center > self.n_x: # (hopefully rare) case where CoR is outside FoV result = _convert_halftomo_right(sino, self.extended_sino_width) else: # Standard case result = convert_halftomo(sino, self.extended_sino_width) if self._halftomo_flip: result = result[:, ::-1] return result def get_sino(self, radios, i, output=None): """ The the sinogram at a given index. Parameters ---------- radios: array 3D array with shape (n_z, n_angles, n_x) i: int Sinogram index Returns ------- sino: array Two dimensional array with shape (n_angles2, n_x2) where the dimensions are determined by the current settings. """ sino = self._get_sino_simple(radios, i) if self.halftomo: return self._get_sino_halftomo(sino, output=None) else: return sino def convert_sino(self, sino, output=None): if not self.halftomo: return sino return self._get_sino_halftomo(sino, output=output) # # 3D # def _get_sinos_simple(self, radios, output=None): res = np.rollaxis(radios, 1, 0) # view if output is not None: output[...] = res[...] 
# copy return output return res def _get_sinos_halftomo(self, radios, output=None): n_a, n_z, n_x = radios.shape if output is None: output = np.zeros(self.sinos_halftomo_shape, dtype=np.float32) elif output.shape != self.output_shape: raise ValueError("Expected output shape to be %s, but got %s" % (str(output.shape), str(self.output_shape))) for i in range(n_z): sino = self._get_sino_simple(radios, i) output[i] = self._get_sino_halftomo(sino) return output def get_sinos(self, radios, output=None): if self.halftomo: return self._get_sinos_halftomo(radios, output=output) else: return self._get_sinos_simple(radios, output=output) @deprecated("Use get_sino() or get_sinos() instead", do_print=True) def radios_to_sinos(self, radios, output=None, copy=False): """ DEPRECATED. Use get_sinos() or get_sino() instead. """ return self.get_sinos(radios, output=output) SinoProcessing = deprecated_class("'SinoProcessing' was renamed 'SinoBuilder'", do_print=True)(SinoBuilder) class SinoMult: """ A class for preparing a sinogram for half-tomography reconstruction, without stitching the two parts """ def __init__(self, sino_shape, rot_center): self._set_shape(sino_shape) self._prepare_weights(rot_center) def _set_shape(self, sino_shape): _, self.n_a, self.n_x = get_2D_3D_shape(sino_shape) def _prepare_weights(self, rot_center): n_x = self.n_x middle = (n_x - 1) / 2.0 if rot_center >= middle: overlap_width = int(2 * (n_x - 1 - rot_center)) self.overlap_region = slice(-overlap_width, None) self.pad_left, self.pad_right = 0, n_x - overlap_width else: overlap_width = int(2 * rot_center) self.overlap_region = slice(0, overlap_width) self.pad_left, self.pad_right = n_x - overlap_width, 0 weights = np.linspace(0, 1, overlap_width, endpoint=True) if rot_center >= middle: weights = weights[::-1] self.weights = np.ascontiguousarray(weights, dtype="f") overlap_region_indices = np.arange(self.n_x)[self.overlap_region] self.start_x = overlap_region_indices[0] self.end_x = overlap_region_indices[-1] self.extended_width = n_x + self.pad_left + self.pad_right def prepare_sino(self, sino): sino[:, self.overlap_region] *= self.weights return sino def convert_halftomo(sino, extended_width, transition_width=None): """ Converts a sinogram into a sinogram with extended FOV with the "half tomography" setting. """ assert sino.ndim == 2 assert (sino.shape[0] % 2) == 0 na, nx = sino.shape na2 = na // 2 r = extended_width // 2 d = transition_width or nx - r res = np.zeros((na2, 2 * r), dtype="f") sino1 = sino[:na2, :] sino2 = sino[na2:, ::-1] res[:, : r - d] = sino1[:, : r - d] # w1 = np.linspace(0, 1, 2 * d, endpoint=True) res[:, r - d : r + d] = (1 - w1) * sino1[:, r - d :] + w1 * sino2[:, 0 : 2 * d] # res[:, r + d :] = sino2[:, 2 * d :] return res # This function can have a cuda counterpart, see test_interpolation.py def match_half_sinos_parts(sino, angles, output=None): """ Modifies the lower part of the half-acquisition sinogram so that each projection pair is separated by exactly 180 degrees. This means that `new_sino[k]` and `new_sino[k + n_angles//2]` will be 180 degrees apart. Parameters ---------- sino: numpy.ndarray Two dimensional array with the sinogram in the form (n_angles, n_x) angles: numpy.ndarray One dimensional array with the rotation angles. output: numpy.array, optional Output sinogram. By default, the array 'sino' is modified in-place. Notes ----- This function assumes that the angles are in an increasing order. 
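    Examples
    --------
    A minimal, illustrative sketch with synthetic values (the shapes and the random
    sinogram below are arbitrary, not taken from a real acquisition):

    >>> import numpy as np
    >>> n_angles, n_x = 100, 64
    >>> sino = np.random.rand(n_angles, n_x).astype("f")
    >>> # slightly irregular angular sampling over [0, 2*pi), kept in increasing order
    >>> angles = np.sort(
    ...     np.linspace(0, 2 * np.pi, n_angles, endpoint=False) + 1e-3 * np.random.rand(n_angles)
    ... )
    >>> sino_matched = match_half_sinos_parts(sino.copy(), angles)
    >>> sino_matched.shape == sino.shape
    True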
""" n_a = angles.size n_a_2 = n_a // 2 sino_part1 = sino[:n_a_2, :] sino_part2 = sino[n_a_2:, :] angles = np.rad2deg(angles) # more numerically stable ? angles_1 = angles[:n_a_2] angles_2 = angles[n_a_2:] angles_2_target = angles_1 + 180.0 interpolator = interp1d(angles_2, sino_part2, axis=0, kind="linear", copy=False, fill_value="extrapolate") if output is None: output = sino else: output[:n_a_2, :] = sino[:n_a_2, :] output[n_a_2:, :] = interpolator(angles_2_target) return output # EXPERIMENTAL def _convert_halftomo_right(sino, extended_width): """ Converts a sinogram into a sinogram with extended FOV with the "half tomography" setting, with a CoR outside the image support. """ assert sino.ndim == 2 na, nx = sino.shape assert (na % 2) == 0 rotation_axis_position = extended_width // 2 assert rotation_axis_position > nx sino2 = np.pad(sino, ((0, 0), (0, rotation_axis_position - nx)), mode="reflect") return convert_halftomo(sino2, extended_width) def get_extended_sinogram_width(sino_width, rotation_axis_position): """ Compute the width (in pixels) of the extended sinogram for half-acquisition setting. """ middle = (sino_width - 1) / 2.0 if rotation_axis_position >= middle: overlap_width = int(2 * (sino_width - 1 - rotation_axis_position)) else: overlap_width = int(2 * rotation_axis_position) return 2 * sino_width - overlap_width def prepare_half_tomo_sinogram(sino, rot_center, get_extended_sino=True): if get_extended_sino: sino = sino.copy() n_angles, n_x = sino.shape middle = (n_x - 1) / 2.0 if rot_center >= middle: overlap_width = int(2 * (n_x - 1 - rot_center)) overlap_region = slice(-overlap_width, None) pad_left, pad_right = 0, n_x - overlap_width else: overlap_width = int(2 * rot_center) overlap_region = slice(0, overlap_width) pad_left, pad_right = n_x - overlap_width, 0 weights = np.linspace(0, 1, overlap_width, endpoint=True) if rot_center >= middle: weights = weights[::-1] sino[:, overlap_region] *= weights if get_extended_sino: return np.pad(sino, ((0, 0), (pad_left, pad_right)), mode="constant") return sino class SinoNormalization: """ A class for sinogram normalization utilities. """ kinds = [ "chebyshev", "subtraction", "division", ] operations = {"subtraction": np.subtract, "division": np.divide} def __init__(self, kind="chebyshev", sinos_shape=None, radios_shape=None, normalization_array=None): """ Initialize a SinoNormalization class. Parameters ----------- kind: str, optional Normalization type. They can be the following: - chebyshev: Each sinogram line is estimated by a Chebyshev polynomial of degree 2. This estimation is then subtracted from the sinogram. - subtraction: Each sinogram is subtracted with a user-provided array. The array can be 1D (angle-independent) and 2D (angle-dependent) - division: same as previously, but with a division operation. Default is "chebyshev" sinos_shape: tuple, optional Shape of the sinogram or sinogram stack. Either this parameter or 'radios_shape' has to be provided. radios_shape: tuple, optional Shape of the projections or projections stack. Either this parameter or 'sinos_shape' has to be provided. normalization_array: numpy.ndarray, optional Normalization array when kind='subtraction' or kind='division'. 
""" self._get_shapes(sinos_shape, radios_shape) self._set_kind(kind, normalization_array) _get_shapes = SinoBuilder._get_shapes def _set_kind(self, kind, normalization_array): check_supported(kind, self.kinds, "sinogram normalization kind") self.normalization_kind = kind self._normalization_instance_method = self._normalize_chebyshev # default if kind in ["subtraction", "division"]: if not isinstance(normalization_array, np.ndarray): raise ValueError( "Expected 'normalization_array' to be provided as a numpy array for normalization kind='%s'" % kind ) if normalization_array.shape[-1] != self.sinos_shape[-1]: n_a, n_x = self.sinos_shape[-2:] raise ValueError("Expected normalization_array to have shape (%d, %d) or (%d, )" % (n_a, n_x, n_x)) self.norm_operation = self.operations[kind] self._normalization_instance_method = self._normalize_op self.normalization_array = normalization_array # # Chebyshev normalization # def _normalize_chebyshev_2D(self, sino): output = sino # inplace Nr, Nc = sino.shape J = np.arange(Nc) x = 2.0 * (J + 0.5 - Nc / 2) / Nc sum0 = Nc f2 = 3.0 * x * x - 1.0 sum1 = (x**2).sum() sum2 = (f2**2).sum() for i in range(Nr): ff0 = sino[i, :].sum() ff1 = (x * sino[i, :]).sum() ff2 = (f2 * sino[i, :]).sum() output[i, :] = sino[i, :] - (ff0 / sum0 + ff1 * x / sum1 + ff2 * f2 / sum2) return output def _normalize_chebyshev_3D(self, sino): for i in range(sino.shape[0]): self._normalize_chebyshev_2D(sino[i]) return sino def _normalize_chebyshev(self, sino): if sino.ndim == 2: self._normalize_chebyshev_2D(sino) else: self._normalize_chebyshev_3D(sino) return sino # # Array subtraction/division # def _normalize_op(self, sino): if sino.ndim == 2: self.norm_operation(sino, self.normalization_array, out=sino) else: for i in range(sino.shape[0]): self.norm_operation(sino[i], self.normalization_array, out=sino[i]) return sino # # Dispatch # def normalize(self, sino): """ Normalize a sinogram or stack of sinogram. The process is done in-place, meaning that the sinogram content is overwritten. """ return self._normalization_instance_method(sino) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/reconstruction/sinogram_cuda.py0000644000175000017500000002464314654107202021472 0ustar00pierrepierreimport numpy as np from ..utils import get_cuda_srcfile, updiv, deprecated_class from .sinogram import SinoBuilder, SinoNormalization, SinoMult from .sinogram import _convert_halftomo_right # FIXME Temporary patch from ..cuda.processing import CudaProcessing class CudaSinoBuilder(SinoBuilder): def __init__( self, sinos_shape=None, radios_shape=None, rot_center=None, halftomo=False, angles=None, cuda_options=None ): """ Initialize a CudaSinoBuilder instance. Please see the documentation of nabu.reconstruction.sinogram.Builder and nabu.cuda.processing.CudaProcessing. """ super().__init__( sinos_shape=sinos_shape, radios_shape=radios_shape, rot_center=rot_center, halftomo=halftomo, angles=angles ) self.cuda_processing = CudaProcessing(**(cuda_options or {})) self._init_cuda_halftomo() def _init_cuda_halftomo(self): if not (self.halftomo): return kernel_name = "halftomo_kernel" self.halftomo_kernel = self.cuda_processing.kernel( kernel_name, get_cuda_srcfile("halftomo.cu"), signature="PPPiii", ) blk = (32, 32, 1) # tune ? 
self._halftomo_blksize = blk self._halftomo_gridsize = (updiv(self.extended_sino_width, blk[0]), updiv((self.n_angles + 1) // 2, blk[1]), 1) d = self.n_x - self.extended_sino_width // 2 # will have to be adapted for varying axis pos self.halftomo_weights = np.linspace(0, 1, 2 * abs(d), endpoint=True, dtype="f") self.d_halftomo_weights = self.cuda_processing.to_device("d_halftomo_weights", self.halftomo_weights) # Allocate one single sinogram (kernel needs c-contiguous array). # If odd number of angles: repeat last angle. self.d_sino = self.cuda_processing.allocate_array( "d_sino", (self.n_angles + (self.n_angles & 1), self.n_x), "f" ) self.h_sino = self.d_sino.get() # self.cuda_processing.init_arrays_to_none(["d_output"]) if self._halftomo_flip: self.xflip_kernel = self.cuda_processing.kernel( "reverse2D_x", get_cuda_srcfile("ElementOp.cu"), signature="Pii" ) blk = (32, 32, 1) self._xflip_blksize = blk self._xflip_gridsize_1 = (updiv(self.n_x, blk[0]), updiv(self.n_angles, blk[1]), 1) self._xflip_gridsize_2 = self._halftomo_gridsize # # 2D # def _get_sino_halftomo(self, sino, output=None): if output is None: output = self.cuda_processing.allocate_array("d_output", self.output_shape[1:]) elif output.shape != self.output_shape[1:]: raise ValueError("Expected output to have shape %s but got %s" % (self.output_shape[1:], output.shape)) d_sino = self.d_sino n_a, n_x = sino.shape d_sino[:n_a] = sino[:] if self.n_angles & 1: d_sino[-1, :].fill(0) if self._halftomo_flip: self.xflip_kernel(d_sino, n_x, n_a, grid=self._xflip_gridsize_1, block=self._xflip_blksize) # Sometimes CoR is set well outside the FoV. Not supported by cuda backend for now. # TODO/FIXME: TEMPORARY PATCH, waiting for cuda implementation if self.rot_center > self.n_x: d_sino.get(self.h_sino) # copy D2H res = _convert_halftomo_right(self.h_sino, self.extended_sino_width) output.set(res) # copy H2D # else: self.halftomo_kernel( d_sino, output, self.d_halftomo_weights, n_a, n_x, self.extended_sino_width // 2, grid=self._halftomo_gridsize, block=self._halftomo_blksize, ) if self._halftomo_flip: self.xflip_kernel( output, self.extended_sino_width, (n_a + 1) // 2, grid=self._xflip_gridsize_2, block=self._xflip_blksize ) return output # # 3D # def _get_sinos_simple(self, radios, output=None): if output is None: return radios.transpose(axes=(1, 0, 2)) # view else: # why can't I do a discontig single copy ? 
for i in range(radios.shape[1]): output[i] = radios[:, i, :] return output def _get_sinos_halftomo(self, radios, output=None): if output is None: output = self.cuda_processing.allocate_array("output", self.output_shape, "f") elif output.shape != self.output_shape: raise ValueError("Expected output to have shape %s but got %s" % (self.output_shape, output.shape)) for i in range(self.n_z): sino = self._get_sino_simple(radios, i) self._get_sino_halftomo(sino, output=output[i]) return output CudaSinoProcessing = deprecated_class("'CudaSinoProcessing' was renamed 'CudaSinoBuilder'", do_print=True)( CudaSinoBuilder ) class CudaSinoMult(SinoMult): def __init__(self, sino_shape, rot_center, **cuda_options): super().__init__(sino_shape, rot_center) self.cuda_processing = CudaProcessing(**cuda_options) self._init_kernel() def _init_kernel(self): self.halftomo_kernel = self.cuda_processing.kernel( "halftomo_prepare_sinogram", filename=get_cuda_srcfile("halftomo.cu"), signature="PPiiii" ) self.d_weights = self.cuda_processing.set_array("d_weights", self.weights) self._halftomo_kernel_other_args = [ self.d_weights, np.int32(self.n_a), np.int32(self.n_x), np.int32(self.start_x), np.int32(self.end_x), ] self._grid = (self.n_x, self.n_a) self._blk = (32, 32, 1) # tune ? def prepare_sino(self, sino): sino = self.cuda_processing.set_array("d_sino", sino) self.halftomo_kernel(sino, *self._halftomo_kernel_other_args, grid=self._grid, block=self._blk) return sino class CudaSinoNormalization(SinoNormalization): def __init__( self, kind="chebyshev", sinos_shape=None, radios_shape=None, normalization_array=None, cuda_options=None ): super().__init__( kind=kind, sinos_shape=sinos_shape, radios_shape=radios_shape, normalization_array=normalization_array ) self._get_shapes(sinos_shape, radios_shape) self.cuda_processing = CudaProcessing(**(cuda_options or {})) self._init_cuda_normalization() _get_shapes = SinoBuilder._get_shapes # # Chebyshev normalization # def _init_cuda_normalization(self): self._d_tmp = self.cuda_processing.allocate_array("_d_tmp", self.sinos_shape[-2:], "f") if self.normalization_kind == "chebyshev": self._chebyshev_kernel = self.cuda_processing.kernel( "normalize_chebyshev", filename=get_cuda_srcfile("normalization.cu"), signature="Piii", ) self._chebyshev_kernel_args = [np.int32(self.n_x), np.int32(self.n_angles), np.int32(self.n_z)] blk = (1, 64, 16) # TODO tune ? 
self._chebyshev_kernel_kwargs = { "block": blk, "grid": (1, int(updiv(self.n_angles, blk[1])), int(updiv(self.n_z, blk[2]))), } elif self.normalization_array is not None: normalization_array = self.normalization_array # If normalization_array is 1D, make a 2D array by repeating the line if normalization_array.ndim == 1: normalization_array = np.tile(normalization_array, (self.n_angles, 1)) self._d_normalization_array = self.cuda_processing.to_device( "_d_normalization_array", normalization_array.astype("f") ) # pylint: disable=E0606 if self.normalization_kind == "subtraction": generic_op_val = 1 elif self.normalization_kind == "division": generic_op_val = 3 self._norm_kernel = self.cuda_processing.kernel( "inplace_generic_op_2Dby2D", filename=get_cuda_srcfile("ElementOp.cu"), signature="PPii", options=["-DGENERIC_OP=%d" % generic_op_val], ) self._norm_kernel_args = [self._d_normalization_array, np.int32(self.n_angles), np.int32(self.n_x)] blk = (32, 32, 1) self._norm_kernel_kwargs = { "block": blk, "grid": (int(updiv(self.n_angles, blk[0])), int(updiv(self.n_x, blk[1])), 1), } def _normalize_chebyshev(self, sinos): if sinos.flags.c_contiguous: self._chebyshev_kernel(sinos, *self._chebyshev_kernel_args, **self._chebyshev_kernel_kwargs) else: # This kernel seems to have an issue on arrays that are not C-contiguous. # We have to process image per image. nz = np.int32(1) nthreadsperblock = (1, 32, 1) # TODO tune nblocks = (1, int(updiv(self.n_angles, nthreadsperblock[1])), 1) for i in range(sinos.shape[0]): self._d_tmp[:] = sinos[i][:] self._chebyshev_kernel( self._d_tmp, np.int32(self.n_x), np.int32(self.n_angles), np.int32(1), grid=nblocks, block=nthreadsperblock, ) sinos[i][:] = self._d_tmp[:] return sinos # # Array subtraction/division # def _normalize_op(self, sino): if sino.ndim == 2: # Things can go wrong if "sino" is a non-contiguous 2D array # But this should be handled outside this function, as the processing is in-place self._norm_kernel(sino, *self._norm_kernel_args, **self._norm_kernel_kwargs) else: if sino.flags.forc: # Contiguous 3D array. But pycuda wants the same shape for both operands. for i in range(sino.shape[0]): self._norm_kernel(sino[i], *self._norm_kernel_args, **self._norm_kernel_kwargs) else: # Non-contiguous 2D array. Make a temp. 
copy for i in range(sino.shape[0]): self._d_tmp[:] = sino[i][:] self._norm_kernel(self._d_tmp, *self._norm_kernel_args, **self._norm_kernel_kwargs) sino[i][:] = self._d_tmp[:] return sino ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734019212.0 nabu-2024.2.1/nabu/reconstruction/sinogram_opencl.py0000644000175000017500000000271614726604214022040 0ustar00pierrepierreimport numpy as np from ..opencl.kernel import OpenCLKernel from ..opencl.processing import OpenCLProcessing from ..utils import get_opencl_srcfile from .sinogram import SinoMult class OpenCLSinoMult(SinoMult): def __init__(self, sino_shape, rot_center, **opencl_options): super().__init__(sino_shape, rot_center) self.opencl_processing = OpenCLProcessing(**opencl_options) self._init_kernel() def _init_kernel(self): self.halftomo_kernel = OpenCLKernel( "halftomo_prepare_sinogram", self.opencl_processing.ctx, filename=get_opencl_srcfile("halftomo.cl"), ) self.d_weights = self.opencl_processing.set_array("d_weights", self.weights) self._halftomo_kernel_other_args = [ self.d_weights, np.int32(self.n_a), np.int32(self.n_x), np.int32(self.start_x), np.int32(self.end_x), ] self._global_size = (self.n_x, self.n_a) self._local_size = None # (32, 32, 1) # tune ? def prepare_sino(self, sino): sino = self.opencl_processing.set_array("d_sino", sino) ev = self.halftomo_kernel( self.opencl_processing.queue, sino, *self._halftomo_kernel_other_args, global_size=self._global_size, local_size=self._local_size, ) if self.opencl_processing.device_type == "cpu": ev.wait() return sino ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1734442985.520757 nabu-2024.2.1/nabu/reconstruction/tests/0000755000175000017500000000000014730277752017453 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/reconstruction/tests/__init__.py0000644000175000017500000000000114315516747021552 0ustar00pierrepierre ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1723556963.0 nabu-2024.2.1/nabu/reconstruction/tests/test_cone.py0000644000175000017500000004266614656662143022025 0ustar00pierrepierreimport pytest import numpy as np from scipy.ndimage import gaussian_filter, shift from nabu.utils import subdivide_into_overlapping_segment, clip_circle try: import astra __has_astra__ = True except ImportError: __has_astra__ = False from nabu.cuda.utils import __has_pycuda__, get_cuda_context if __has_pycuda__: from nabu.reconstruction.cone import ConebeamReconstructor if __has_astra__: from astra.extrautils import clipCircle @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.vol_shape = (128, 126, 126) cls.n_angles = 180 cls.prj_width = 192 # detector larger than the sample cls.src_orig_dist = 1000 cls.orig_det_dist = 100 cls.volume, cls.cone_data = generate_hollow_cube_cone_sinograms( cls.vol_shape, cls.n_angles, cls.src_orig_dist, cls.orig_det_dist, prj_width=cls.prj_width ) if __has_pycuda__: cls.ctx = get_cuda_context() @pytest.mark.skipif(not (__has_pycuda__ and __has_astra__), reason="Need pycuda and astra for this test") @pytest.mark.usefixtures("bootstrap") class TestCone: def _create_cone_reconstructor(self, relative_z_position=None): return ConebeamReconstructor( self.cone_data.shape, self.src_orig_dist, self.orig_det_dist, relative_z_position=relative_z_position, volume_shape=self.volume.shape, cuda_options={"ctx": self.ctx}, ) def test_simple_cone_reconstruction(self): 
C = self._create_cone_reconstructor() res = C.reconstruct(self.cone_data) delta = np.abs(res - self.volume) # Can we do better ? We already had to lowpass-filter the volume! # First/last slices are OK assert np.max(delta[:8]) < 1e-5 assert np.max(delta[-8:]) < 1e-5 # Middle region has a relatively low error assert np.max(delta[40:-40]) < 0.11 # Transition zones between "zero" and "cube" has a large error assert np.max(delta[10:25]) < 0.2 assert np.max(delta[-25:-10]) < 0.2 # End of transition zones have a smaller error assert np.max(delta[25:40]) < 0.125 assert np.max(delta[-40:-25]) < 0.125 def test_against_explicit_astra_calls(self): C = self._create_cone_reconstructor() res = C.reconstruct(self.cone_data) # # Check that ConebeamReconstructor is consistent with these calls to astra # # "vol_geom" shape layout is (y, x, z). But here this geometry is used for the reconstruction # (i.e sinogram -> volume)and not for projection (volume -> sinograms). # So we assume a square slice. Mind that this is a particular case. vol_geom = astra.create_vol_geom(self.vol_shape[2], self.vol_shape[2], self.vol_shape[0]) angles = np.linspace(0, 2 * np.pi, self.n_angles, True) proj_geom = astra.create_proj_geom( "cone", 1.0, 1.0, self.cone_data.shape[0], self.prj_width, angles, self.src_orig_dist, self.orig_det_dist, ) sino_id = astra.data3d.create("-sino", proj_geom, data=self.cone_data) rec_id = astra.data3d.create("-vol", vol_geom) cfg = astra.astra_dict("FDK_CUDA") cfg["ReconstructionDataId"] = rec_id cfg["ProjectionDataId"] = sino_id alg_id = astra.algorithm.create(cfg) astra.algorithm.run(alg_id) res_astra = astra.data3d.get(rec_id) # housekeeping astra.algorithm.delete(alg_id) astra.data3d.delete(rec_id) astra.data3d.delete(sino_id) assert ( np.max(np.abs(res - res_astra)) < 5e-4 ), "ConebeamReconstructor results are inconsistent with plain calls to astra" def test_projection_full_vs_partial(self): """ In the ideal case, all the data volume (and reconstruction) fits in memory. In practice this is rarely the case, so we have to reconstruct the volume slabs by slabs. The slabs should be slightly overlapping to avoid "stitching" artefacts at the edges. 
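        Below, the volume is split with subdivide_into_overlapping_segment(n_z, slab_size, overlap);
        each returned slab is a tuple (z_min, z_inner_min, z_inner_max, z_max) where [z_min, z_max)
        is the (overlapping) sub-volume actually projected and [z_inner_min, z_inner_max) is the
        inner part kept in the stitched result.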
""" # Astra seems to duplicate the projection data, even if all GPU memory is handled externally # Let's try with (n_z * n_y * n_x + 2 * n_a * n_z * n_x) * 4 < mem_limit # 256^3 seems OK with n_a = 200 (180 MB) n_z = n_y = n_x = 256 n_a = 200 src_orig_dist = 1000 orig_det_dist = 100 volume, cone_data = generate_hollow_cube_cone_sinograms( vol_shape=(n_z, n_y, n_x), n_angles=n_a, src_orig_dist=src_orig_dist, orig_det_dist=orig_det_dist ) C_full = ConebeamReconstructor(cone_data.shape, src_orig_dist, orig_det_dist, cuda_options={"ctx": self.ctx}) vol_geom = astra.create_vol_geom(n_y, n_x, n_z) proj_geom = astra.create_proj_geom("cone", 1.0, 1.0, n_z, n_x, C_full.angles, src_orig_dist, orig_det_dist) proj_id, projs_full_geom = astra.create_sino3d_gpu(volume, proj_geom, vol_geom) astra.data3d.delete(proj_id) # Do the same slab-by-slab inner_slab_size = 64 overlap = 16 slab_size = inner_slab_size + overlap * 2 slabs = subdivide_into_overlapping_segment(n_z, slab_size, overlap) projs_partial_geom = np.zeros_like(projs_full_geom) for slab in slabs: z_min, z_inner_min, z_inner_max, z_max = slab rel_z_pos = (z_min + z_max) / 2 - n_z / 2 subvolume = volume[z_min:z_max, :, :] C = ConebeamReconstructor( (z_max - z_min, n_a, n_x), src_orig_dist, orig_det_dist, relative_z_position=rel_z_pos, cuda_options={"ctx": self.ctx}, ) proj_id, projs = astra.create_sino3d_gpu(subvolume, C.proj_geom, C.vol_geom) astra.data3d.delete(proj_id) projs_partial_geom[z_inner_min:z_inner_max] = projs[z_inner_min - z_min : z_inner_max - z_min] error_profile = [ np.max(np.abs(proj_partial - proj_full)) for proj_partial, proj_full in zip(projs_partial_geom, projs_full_geom) ] assert np.all(np.isclose(error_profile, 0.0, atol=0.0375)), "Mismatch between full-cone and slab geometries" def test_cone_reconstruction_magnified_vs_demagnified(self): """ This will only test the astra toolbox. When reconstructing a volume from cone-beam data, the volume "should" have a smaller shape than the projection data shape (because of cone magnification). But astra provides the same results when backprojecting on a "de-magnified grid" and the original grid shape. 
""" n_z = n_y = n_x = 256 n_a = 500 src_orig_dist = 1000 orig_det_dist = 100 magnification = 1 + orig_det_dist / src_orig_dist angles = np.linspace(0, 2 * np.pi, n_a, True) volume, cone_data = generate_hollow_cube_cone_sinograms( vol_shape=(n_z, n_y, n_x), n_angles=n_a, src_orig_dist=src_orig_dist, orig_det_dist=orig_det_dist, apply_filter=False, ) rec_original_grid = astra_cone_beam_reconstruction( cone_data, angles, src_orig_dist, orig_det_dist, demagnify_volume=False ) rec_reduced_grid = astra_cone_beam_reconstruction( cone_data, angles, src_orig_dist, orig_det_dist, demagnify_volume=True ) m_z = (n_z - int(n_z / magnification)) // 2 m_y = (n_y - int(n_y / magnification)) // 2 m_x = (n_x - int(n_x / magnification)) // 2 assert np.allclose(rec_original_grid[m_z:-m_z, m_y:-m_y, m_x:-m_x], rec_reduced_grid) def test_reconstruction_full_vs_partial(self): n_z = n_y = n_x = 256 n_a = 500 src_orig_dist = 1000 orig_det_dist = 100 angles = np.linspace(0, 2 * np.pi, n_a, True) volume, cone_data = generate_hollow_cube_cone_sinograms( vol_shape=(n_z, n_y, n_x), n_angles=n_a, src_orig_dist=src_orig_dist, orig_det_dist=orig_det_dist, apply_filter=False, ) rec_full_volume = astra_cone_beam_reconstruction(cone_data, angles, src_orig_dist, orig_det_dist) rec_partial = np.zeros_like(rec_full_volume) inner_slab_size = 64 overlap = 18 slab_size = inner_slab_size + overlap * 2 slabs = subdivide_into_overlapping_segment(n_z, slab_size, overlap) for slab in slabs: z_min, z_inner_min, z_inner_max, z_max = slab m1, m2 = z_inner_min - z_min, z_max - z_inner_max C = ConebeamReconstructor((z_max - z_min, n_a, n_x), src_orig_dist, orig_det_dist) rec = C.reconstruct( cone_data[z_min:z_max], relative_z_position=((z_min + z_max) / 2) - n_z / 2, # (z_min + z_max)/2. ) rec_partial[z_inner_min:z_inner_max] = rec[m1 : (-m2) or None] # Compare volumes in inner circle for i in range(n_z): clipCircle(rec_partial[i]) clipCircle(rec_full_volume[i]) diff = np.abs(rec_partial - rec_full_volume) err_max_profile = np.max(diff, axis=(-1, -2)) err_median_profile = np.median(diff, axis=(-1, -2)) assert np.max(err_max_profile) < 2e-3 assert np.max(err_median_profile) < 5.1e-6 def test_reconstruction_horizontal_translations(self): n_z = n_y = n_x = 256 n_a = 500 src_orig_dist = 1000 orig_det_dist = 50 volume, cone_data = generate_hollow_cube_cone_sinograms( vol_shape=(n_z, n_y, n_x), n_angles=n_a, src_orig_dist=src_orig_dist, orig_det_dist=orig_det_dist, apply_filter=False, ) # Apply horizontal translations on projections. This could have been done directly with astra shift_min, shift_max = -2, 5 shifts_float = (shift_max - shift_min) * np.random.rand(n_a) - shift_min shifts_int = np.random.randint(shift_min, high=shift_max + 1, size=n_a) reconstructor_args = [ cone_data.shape, src_orig_dist, orig_det_dist, ] reconstructor_kwargs = { "volume_shape": volume.shape, "cuda_options": {"ctx": self.ctx}, } cone_reconstructor = ConebeamReconstructor(*reconstructor_args, **reconstructor_kwargs) rec = cone_reconstructor.reconstruct(cone_data) # Translations done with floating-point shift values give a blurring of the image that cannot be recovered. # Error tolerance has to be higher for these shifts. 
for shift_type, shifts, err_tol in [ ("integer shifts", shifts_int, 5e-3), ("float shifts", shifts_float, 1.5e-1), ]: cone_data_shifted = np.zeros_like(cone_data) [shift(cone_data[:, i, :], (0, shifts[i]), output=cone_data_shifted[:, i, :]) for i in range(n_a)] # Reconstruct with horizontal shifts cone_reconstructor_with_correction = ConebeamReconstructor( *reconstructor_args, **reconstructor_kwargs, extra_options={"axis_correction": -shifts}, ) rec_with_correction = cone_reconstructor_with_correction.reconstruct(cone_data_shifted) metric = lambda img: np.max(np.abs(clip_circle(img, radius=int(0.85 * img.shape[1] // 2)))) error_profile = np.array([metric(rec[i] - rec_with_correction[i]) for i in range(n_z)]) assert error_profile.max() < err_tol, "Max error with %s is too high" % shift_type # import matplotlib.pyplot as plt # plt.figure() # plt.plot(np.arange(n_z), error_profile) # plt.legend([shift_type]) # plt.show() def test_padding_mode(self): n_z = n_y = n_x = 256 n_a = 500 src_orig_dist = 1000 orig_det_dist = 50 volume, cone_data = generate_hollow_cube_cone_sinograms( vol_shape=(n_z, n_y, n_x), n_angles=n_a, src_orig_dist=src_orig_dist, orig_det_dist=orig_det_dist, apply_filter=False, ) reconstructor_args = [ cone_data.shape, src_orig_dist, orig_det_dist, ] reconstructor_kwargs = { "volume_shape": volume.shape, "cuda_options": {"ctx": self.ctx}, } cone_reconstructor_zero_padding = ConebeamReconstructor(*reconstructor_args, **reconstructor_kwargs) rec_z = cone_reconstructor_zero_padding.reconstruct(cone_data) for padding_mode in ["edges"]: cone_reconstructor = ConebeamReconstructor( *reconstructor_args, padding_mode=padding_mode, **reconstructor_kwargs ) rec = cone_reconstructor.reconstruct(cone_data) metric = lambda img: np.max(np.abs(clip_circle(img, radius=int(0.85 * 128)))) error_profile = np.array([metric(rec[i] - rec_z[i]) for i in range(n_z)]) # import matplotlib.pyplot as plt # plt.figure() # plt.plot(np.arange(n_z), error_profile) # plt.legend([padding_mode]) # plt.show() assert error_profile.max() < 3.1e-2, "Max error for padding=%s is too high" % padding_mode if padding_mode != "zeros": assert not (np.allclose(rec[n_z // 2], rec_z[n_z // 2])), ( "Reconstruction should be different when padding_mode=%s" % padding_mode ) def test_roi(self): n_z = n_y = n_x = 256 n_a = 500 src_orig_dist = 1000 orig_det_dist = 50 volume, cone_data = generate_hollow_cube_cone_sinograms( vol_shape=(n_z, n_y, n_x), n_angles=n_a, src_orig_dist=src_orig_dist, orig_det_dist=orig_det_dist, apply_filter=False, rot_center_shift=10, ) reconstructor_args = [ cone_data.shape, src_orig_dist, orig_det_dist, ] reconstructor_kwargs = { "volume_shape": volume.shape, "rot_center": (n_x - 1) / 2 + 10, "cuda_options": {"ctx": self.ctx}, } cone_reconstructor_full = ConebeamReconstructor(*reconstructor_args, **reconstructor_kwargs) ref = cone_reconstructor_full.reconstruct(cone_data) # roi is in the form (start_x, end_x, start_y, end_y) for roi in ((20, -20, 10, -10), (0, n_x, 0, n_y), (50, -50, 15, -15)): # convert negative indices start_x, end_x, start_y, end_y = roi if start_y < 0: start_y += n_y if start_x < 0: start_x += n_x cone_reconstructor = ConebeamReconstructor(*reconstructor_args, slice_roi=roi, **reconstructor_kwargs) rec = cone_reconstructor.reconstruct(cone_data) assert np.allclose(rec, ref[:, roi[2] : roi[3], roi[0] : roi[1]]), "Something wrong with roi=%s" % ( str(roi) ) def generate_hollow_cube_cone_sinograms( vol_shape, n_angles, src_orig_dist, orig_det_dist, prj_width=None, apply_filter=True, 
rot_center_shift=None, ): # Adapted from Astra toolbox python samples n_z, n_y, n_x = vol_shape vol_geom = astra.create_vol_geom(n_y, n_x, n_z) prj_width = prj_width or n_x prj_height = n_z angles = np.linspace(0, 2 * np.pi, n_angles, True) proj_geom = astra.create_proj_geom("cone", 1.0, 1.0, prj_width, prj_width, angles, src_orig_dist, orig_det_dist) if rot_center_shift is not None: proj_geom = astra.geom_postalignment(proj_geom, (-rot_center_shift, 0)) magnification = 1 + orig_det_dist / src_orig_dist # hollow cube cube = np.zeros(astra.geom_size(vol_geom), dtype="f") d = int(min(n_x, n_y) / 2 * (1 - np.sqrt(2) / 2)) cube[20:-20, d:-d, d:-d] = 1 cube[40:-40, d + 20 : -(d + 20), d + 20 : -(d + 20)] = 0 # d = int(min(n_x, n_y) / 2 * (1 - np.sqrt(2) / 2) * magnification) # d1 = d + 10 # d2 = d + 20 # cube[40:-40, d1:-d1, d1:-d1] = 1 # cube[60:-60, d2 : -d2, d2 : -d2] = 0 # High-frequencies yield cannot be accurately retrieved if apply_filter: cube = gaussian_filter(cube, (1.0, 1.0, 1.0)) proj_id, proj_data = astra.create_sino3d_gpu(cube, proj_geom, vol_geom) astra.data3d.delete(proj_id) # (n_z, n_angles, n_x) return cube, proj_data def astra_cone_beam_reconstruction(cone_data, angles, src_orig_dist, orig_det_dist, demagnify_volume=False): """ Handy (but data-inefficient) function to reconstruct data from cone-beam geometry """ n_z, n_a, n_x = cone_data.shape proj_geom = astra.create_proj_geom("cone", 1.0, 1.0, n_z, n_x, angles, src_orig_dist, orig_det_dist) sino_id = astra.data3d.create("-sino", proj_geom, data=cone_data) m = 1 + orig_det_dist / src_orig_dist if demagnify_volume else 1.0 n_z_vol, n_y_vol, n_x_vol = int(n_z / m), int(n_x / m), int(n_x / m) vol_geom = astra.create_vol_geom(n_y_vol, n_x_vol, n_z_vol) rec_id = astra.data3d.create("-vol", vol_geom) cfg = astra.astra_dict("FDK_CUDA") cfg["ReconstructionDataId"] = rec_id cfg["ProjectionDataId"] = sino_id alg_id = astra.algorithm.create(cfg) astra.algorithm.run(alg_id) rec = astra.data3d.get(rec_id) astra.data3d.delete(sino_id) astra.data3d.delete(rec_id) astra.algorithm.delete(alg_id) return rec ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1730906677.0 nabu-2024.2.1/nabu/reconstruction/tests/test_deringer.py0000644000175000017500000002022414712705065022654 0ustar00pierrepierreimport numpy as np import pytest from nabu.reconstruction.rings_cuda import CudaSinoMeanDeringer from nabu.testutils import compare_arrays, get_data, generate_tests_scenarios, __do_long_tests__ from nabu.reconstruction.rings import MunchDeringer, SinoMeanDeringer, VoDeringer, __has_algotom__ from nabu.thirdparty.pore3d_deringer_munch import munchetal_filter from nabu.cuda.utils import __has_pycuda__, get_cuda_context if __has_pycuda__: import pycuda.gpuarray as garray from nabu.processing.fft_cuda import get_available_fft_implems from nabu.thirdparty.tomocupy_remove_stripe import __have_tomocupy_deringer__ from nabu.reconstruction.rings_cuda import ( CudaMunchDeringer, can_use_cuda_deringer, CudaVoDeringer, ) __has_cuda_deringer__ = can_use_cuda_deringer() else: __has_cuda_deringer__ = False __have_tomocupy_deringer__ = False fw_scenarios = generate_tests_scenarios( { "levels": [4], "sigma": [1.0], "wname": ["db15"], "padding": [(100, 100)], "fft_implem": ["vkfft"], } ) if __do_long_tests__: fw_scenarios = generate_tests_scenarios( { "levels": [4, 2], "sigma": [1.0, 2.0], "wname": ["db15", "haar", "rbio4.4"], "padding": [None, (100, 100), (50, 71)], "fft_implem": ["skcuda", "vkfft"], } ) @pytest.fixture(scope="class") def 
bootstrap(request): cls = request.cls cls.sino = get_data("mri_sino500.npz")["data"] cls.sino2 = get_data("sino_bamboo_hercules.npz")["data"] cls.tol = 5e-3 cls.rings = {150: 0.5, -150: 0.5} if __has_pycuda__: cls.ctx = get_cuda_context(cleanup_at_exit=False) cls._available_fft_implems = get_available_fft_implems() yield if __has_pycuda__: cls.ctx.pop() @pytest.mark.usefixtures("bootstrap") class TestDeringer: @staticmethod def add_stripes_to_sino(sino, rings_desc): """ Create a new sinogram by adding synthetic stripes to an existing one. Parameters ---------- sino: array-like Sinogram. rings_desc: dict Dictionary describing the stripes locations and intensity. The location is an integer in [0, N[ where N is the number of columns. The intensity is a float: percentage of the current column mean value. """ sino_out = np.copy(sino) for loc, intensity in rings_desc.items(): sino_out[:, loc] += sino[:, loc].mean() * intensity return sino_out @staticmethod def get_fourier_wavelets_reference_result(sino, config): # Reference destriping with pore3d "munchetal_filter" padding = config.get("padding", None) if padding is not None: sino = np.pad(sino, ((0, 0), padding), mode="edge") ref = munchetal_filter(sino, config["levels"], config["sigma"], wname=config["wname"]) if config["padding"] is not None: ref = ref[:, padding[0] : -padding[1]] return ref @pytest.mark.skipif(munchetal_filter is None, reason="Need PyWavelets for this test") @pytest.mark.parametrize("config", fw_scenarios) def test_munch_deringer(self, config): deringer = MunchDeringer( config["sigma"], self.sino.shape, levels=config["levels"], wname=config["wname"], padding=config["padding"] ) sino = self.add_stripes_to_sino(self.sino, self.rings) ref = self.get_fourier_wavelets_reference_result(sino, config) # Wrapping with DeRinger res = np.zeros((1,) + sino.shape, dtype=np.float32) deringer.remove_rings(sino, output=res) err_max = np.max(np.abs(res[0] - ref)) assert err_max < self.tol, "Max error is too high" @pytest.mark.skipif( not (__has_cuda_deringer__) or munchetal_filter is None, reason="Need pycuda, pycudwt and (scikit-cuda or pyvkfft) for this test", ) @pytest.mark.parametrize("config", fw_scenarios) def test_cuda_munch_deringer(self, config): fft_implem = config["fft_implem"] if fft_implem not in self._available_fft_implems: pytest.skip("FFT implementation %s is not available" % fft_implem) sino = self.add_stripes_to_sino(self.sino, self.rings) deringer = CudaMunchDeringer( config["sigma"], self.sino.shape, levels=config["levels"], wname=config["wname"], padding=config["padding"], fft_backend=fft_implem, cuda_options={"ctx": self.ctx}, ) d_sino = garray.to_gpu(sino) deringer.remove_rings(d_sino) res = d_sino.get() ref = self.get_fourier_wavelets_reference_result(sino, config) err_max = np.max(np.abs(res - ref)) assert err_max < 1e-1, "Max error is too high with configuration %s" % (str(config)) @pytest.mark.skipif( not (__has_algotom__), reason="Need algotom for this test", ) def test_vo_deringer(self): deringer = VoDeringer(self.sino.shape) sino_deringed = deringer.remove_rings_sinogram(self.sino) sinos = np.tile(self.sino, (10, 1, 1)) sinos_deringed = deringer.remove_rings_sinograms(sinos) # TODO check result. The generated test sinogram is "too synthetic" for this kind of deringer @pytest.mark.skipif( not (__have_tomocupy_deringer__), reason="Need cupy for this test", ) def test_cuda_vo_deringer(self): # Beware, this deringer seems to be buggy for "too-small" sinograms # (NaNs on the edges and in some regions). 
To be investigated deringer = CudaVoDeringer(self.sino2.shape) d_sino = garray.to_gpu(self.sino2) deringer.remove_rings_sinogram(d_sino) sino = d_sino.get() if __has_algotom__: vo_deringer = VoDeringer(self.sino2.shape) sino_deringed = vo_deringer.remove_rings_sinogram(self.sino2) assert ( np.max(np.abs(sino - sino_deringed)) < 2e-3 ), "Cuda implementation of Vo deringer does not yield the same results as base implementation" def test_mean_deringer(self): deringer_no_filtering = SinoMeanDeringer(self.sino.shape, mode="subtract") sino = self.sino.copy() deringer_no_filtering.remove_rings_sinogram(sino) sino = self.sino.copy() deringer_with_filtering = SinoMeanDeringer(self.sino.shape, mode="subtract", filter_cutoff=(0, 30)) deringer_with_filtering.remove_rings_sinogram(sino) # TODO check results @pytest.mark.skipif(not (__has_pycuda__), reason="Need pycuda for this test") def test_cuda_mean_deringer(self): cuda_deringer = CudaSinoMeanDeringer( self.sino.shape, mode="subtract", filter_cutoff=( 0, 10, ), ctx=self.ctx, ) deringer = SinoMeanDeringer( self.sino.shape, mode="subtract", filter_cutoff=( 0, 10, ), ) d_sino = cuda_deringer.processing.to_device("sino", self.sino) cuda_deringer.remove_rings_sinogram(d_sino) sino = self.sino.copy() sino_d = deringer.remove_rings_sinogram(sino) dirac = np.zeros(self.sino.shape[-1], "f") dirac[dirac.size // 2] = 1 deringer_filter_response = deringer._apply_filter(dirac) d_dirac = cuda_deringer.processing.to_device("dirac", dirac) cuda_deringer_filter_response = cuda_deringer._apply_filter(d_dirac) is_close, residual = compare_arrays( deringer_filter_response, cuda_deringer_filter_response.get(), 1e-7, return_residual=True ) assert is_close, "Cuda deringer does not have the correct filter response: max_error=%.2e" % residual # There is a rather large discrepancy between the vertical_mean kernel and numpy.mean(). 
Not sure who is right is_close, residual = compare_arrays(sino_d, d_sino.get(), 1e-1, return_residual=True) assert is_close, ( "Cuda deringer does not yield the same result as base implementation: max_error=%.2e" % residual ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731681010.0 nabu-2024.2.1/nabu/reconstruction/tests/test_fbp.py0000644000175000017500000003567314715655362021651 0ustar00pierrepierreimport numpy as np import pytest from scipy.ndimage import shift from nabu.pipeline.params import fbp_filters from nabu.utils import clip_circle from nabu.testutils import get_data, generate_tests_scenarios, __do_long_tests__ from nabu.cuda.utils import get_cuda_context, __has_pycuda__ from nabu.opencl.utils import get_opencl_context, __has_pyopencl__ from nabu.processing.fft_cuda import has_skcuda, has_vkfft as has_vkfft_cu from nabu.processing.fft_opencl import has_vkfft as has_vkfft_cl __has_pycuda__ = __has_pycuda__ and (has_skcuda() or has_vkfft_cu()) __has_pyopencl__ = __has_pyopencl__ and has_vkfft_cl() if __has_pycuda__: from nabu.reconstruction.fbp import CudaBackprojector from nabu.reconstruction.hbp import HierarchicalBackprojector if __has_pyopencl__: from nabu.reconstruction.fbp_opencl import OpenCLBackprojector scenarios = generate_tests_scenarios({"backend": ["cuda", "opencl"]}) if __do_long_tests__: scenarios = generate_tests_scenarios( { "backend": ["cuda", "opencl"], "input_on_gpu": [False, True], "output_on_gpu": [False, True], "use_textures": [True, False], } ) @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.sino_512 = get_data("mri_sino500.npz")["data"] cls.ref_512 = get_data("mri_rec_astra.npz")["data"] # always use contiguous arrays cls.sino_511 = np.ascontiguousarray(cls.sino_512[:, :-1]) # Could be set to 5.0e-2 when using textures. 
When not using textures, interpolation slightly differs cls.tol = 5.1e-2 if __has_pycuda__: cls.cuda_ctx = get_cuda_context(cleanup_at_exit=False) if __has_pyopencl__: cls.opencl_ctx = get_opencl_context("all") yield if __has_pycuda__: cls.cuda_ctx.pop() def clip_to_inner_circle(img, radius_factor=0.99, out_value=0): radius = int(radius_factor * max(img.shape) / 2) return clip_circle(img, radius=radius, out_value=out_value) @pytest.mark.usefixtures("bootstrap") class TestFBP: def _get_backprojector(self, config, *bp_args, **bp_kwargs): if config["backend"] == "cuda": if not (__has_pycuda__): pytest.skip("Need pycuda + (scikit-cuda or pyvkfft)") Backprojector = CudaBackprojector ctx = self.cuda_ctx else: if not (__has_pyopencl__): pytest.skip("Need pyopencl + pyvkfft") Backprojector = OpenCLBackprojector ctx = self.opencl_ctx if config.get("use_textures", True) is False: # patch "extra_options" extra_options = bp_kwargs.pop("extra_options", {}) extra_options["use_textures"] = False bp_kwargs["extra_options"] = extra_options return Backprojector(*bp_args, **bp_kwargs, backend_options={"ctx": ctx}) @staticmethod def apply_fbp(config, backprojector, sinogram): if config.get("input_on_gpu", False): sinogram = backprojector._processing.set_array("sinogram", sinogram) if config.get("output_on_gpu", False): output = backprojector._processing.allocate_array("output", backprojector.slice_shape, dtype="f") else: output = None res = backprojector.fbp(sinogram, output=output) if config.get("output_on_gpu", False): res = res.get() return res @pytest.mark.parametrize("config", scenarios) def test_fbp_512(self, config): """ Simple test of a FBP on a 512x512 slice """ B = self._get_backprojector(config, (500, 512)) res = self.apply_fbp(config, B, self.sino_512) delta_clipped = clip_to_inner_circle(res - self.ref_512) err_max = np.max(np.abs(delta_clipped)) assert err_max < self.tol, "Something wrong with config=%s" % (str(config)) @pytest.mark.parametrize("config", scenarios) def test_fbp_511(self, config): """ Test FBP of a 511x511 slice where the rotation axis is at (512-1)/2.0 """ B = self._get_backprojector(config, (500, 511), rot_center=255.5) res = self.apply_fbp(config, B, self.sino_511) ref = self.ref_512[:-1, :-1] delta_clipped = clip_to_inner_circle(res - ref) err_max = np.max(np.abs(delta_clipped)) assert err_max < self.tol, "Something wrong with config=%s" % (str(config)) @pytest.mark.parametrize("config", scenarios) def test_fbp_roi(self, config): """ Test FBP in region of interest """ sino = self.sino_511 B0 = self._get_backprojector(config, sino.shape, rot_center=255.5) ref = B0.fbp(sino) def backproject_roi(roi, reference): B = self._get_backprojector(config, sino.shape, rot_center=255.5, slice_roi=roi) res = self.apply_fbp(config, B, sino) err_max = np.max(np.abs(res - reference)) return err_max cases = { # Test 1: use slice_roi=(0, -1, 0, -1), i.e plain FBP of whole slice 1: [(0, None, 0, None), ref], # Test 2: horizontal strip 2: [(0, None, 50, 55), ref[50:55, :]], # Test 3: vertical strip 3: [(60, 65, 0, None), ref[:, 60:65]], # Test 4: rectangular inner ROI 4: [(157, 162, 260, -10), ref[260:-10, 157:162]], } for roi, ref in cases.values(): err_max = backproject_roi(roi, ref) assert err_max < self.tol, "Something wrong with ROI = %s for config=%s" % ( str(roi), str(config), ) @pytest.mark.parametrize("config", scenarios) def test_fbp_axis_corr(self, config): """ Test the "axis correction" feature """ sino = self.sino_512 # Create a sinogram with a drift in the rotation axis def 
create_drifted_sino(sino, drifts): out = np.zeros_like(sino) for i in range(sino.shape[0]): out[i] = shift(sino[i], drifts[i]) return out drifts = np.linspace(0, 20, sino.shape[0]) sino = create_drifted_sino(sino, drifts) B = self._get_backprojector(config, sino.shape, extra_options={"axis_correction": drifts}) res = self.apply_fbp(config, B, sino) delta_clipped = clip_circle(res - self.ref_512, radius=200) err_max = np.max(np.abs(delta_clipped)) # Max error is relatively high, migh be due to interpolation of scipy shift in sinogram assert err_max < 10.0, "Max error is too high" @pytest.mark.parametrize("config", scenarios) def test_fbp_clip_circle(self, config): """ Test the "clip outer circle" parameter in (extra options) """ sino = self.sino_512 tol = 1e-5 for rot_center in [None, sino.shape[1] / 2.0 - 10, sino.shape[1] / 2.0 + 15]: B = self._get_backprojector( config, sino.shape, rot_center=rot_center, extra_options={"clip_outer_circle": True} ) res = self.apply_fbp(config, B, sino) B0 = self._get_backprojector( config, sino.shape, rot_center=rot_center, extra_options={"clip_outer_circle": False} ) res_noclip = B0.fbp(sino) ref = clip_to_inner_circle(res_noclip, radius_factor=1) abs_diff = np.abs(res - ref) err_max = np.max(abs_diff) assert err_max < tol, "Max error is too high for rot_center=%s ; %s" % (str(rot_center), str(config)) # Test with custom outer circle value B1 = self._get_backprojector( config, sino.shape, rot_center=rot_center, extra_options={"clip_outer_circle": True, "outer_circle_value": np.nan}, ) res1 = self.apply_fbp(config, B1, sino) ref1 = clip_to_inner_circle(res_noclip, radius_factor=1, out_value=np.nan) abs_diff1 = np.abs(res1 - ref1) err_max1 = np.nanmax(abs_diff1) assert err_max1 < tol, "Max error is too high for rot_center=%s ; %s" % (str(rot_center), str(config)) @pytest.mark.parametrize("config", scenarios) def test_fbp_centered_axis(self, config): """ Test the "centered_axis" parameter (in extra options) """ sino = np.pad(self.sino_512, ((0, 0), (100, 0))) rot_center = (self.sino_512.shape[1] - 1) / 2.0 + 100 B0 = self._get_backprojector(config, self.sino_512.shape) ref = B0.fbp(self.sino_512) # Check that "centered_axis" worked B = self._get_backprojector(config, sino.shape, rot_center=rot_center, extra_options={"centered_axis": True}) res = self.apply_fbp(config, B, sino) # The outside region (outer circle) is different as "res" is a wider slice diff = clip_to_inner_circle(res[50:-50, 50:-50] - ref) err_max = np.max(np.abs(diff)) assert err_max < 5e-2, "centered_axis without clip_circle: something wrong" # Check that "clip_outer_circle" works when used jointly with "centered_axis" B = self._get_backprojector( config, sino.shape, rot_center=rot_center, extra_options={ "centered_axis": True, "clip_outer_circle": True, }, ) res2 = self.apply_fbp(config, B, sino) diff = res2 - clip_to_inner_circle(res, radius_factor=1) err_max = np.max(np.abs(diff)) assert err_max < 1e-5, "centered_axis with clip_circle: something wrong" @pytest.mark.parametrize("config", scenarios) def test_fbp_filters(self, config): for filter_name in set(fbp_filters.values()): if filter_name in [None, "ramlak"]: continue B = self._get_backprojector(config, self.sino_512.shape, filter_name=filter_name) self.apply_fbp(config, B, self.sino_512) # not sure what to check in this case @pytest.mark.parametrize("config", scenarios) def test_differentiated_backprojection(self, config): # test Hilbert + DBP sino_diff = np.diff(self.sino_512, axis=1, prepend=0).astype("f") # Need to translate the 
axis a little bit, because of non-centered differentiation. # prepend -> +0.5 ; append -> -0.5 B = self._get_backprojector(config, sino_diff.shape, filter_name="hilbert", rot_center=255.5 + 0.5) rec = self.apply_fbp(config, B, sino_diff) # Looks good, but all frequencies are not recovered. Use a metric like SSIM or FRC ? @pytest.mark.skipif(not (__has_pycuda__), reason="Need pycuda for using HBP") @pytest.mark.usefixtures("bootstrap") class TestHBP: def _compare_to_reference(self, res, ref, err_msg="", radius_factor=0.9, rel_tol=0.02): delta_clipped = clip_to_inner_circle(res - ref, radius_factor=radius_factor) err_max = np.max(np.abs(delta_clipped)) err_max_rel = err_max / ref.max() assert err_max_rel < rel_tol, err_msg def test_hbp_simple(self): B = HierarchicalBackprojector(self.sino_512.shape) res = B.fbp(self.sino_512) self._compare_to_reference(res, self.ref_512) def test_hbp_input_output(self): B = HierarchicalBackprojector(self.sino_512.shape) d_sino = B._processing.to_device("d_sino2", self.sino_512) d_slice = B._processing.allocate_array("d_slice2", self.ref_512.shape) h_slice = np.zeros_like(self.ref_512) # in: host, out: host (not provided) # see test above # in: host, out: host (provided) res = B.fbp(self.sino_512, output=h_slice) self._compare_to_reference(h_slice, self.ref_512, err_msg="in: host, out: host (provided)") h_slice.fill(0) # in: host, out: device res = B.fbp(self.sino_512, output=d_slice) self._compare_to_reference(d_slice.get(), self.ref_512, err_msg="in: host, out: device") d_slice.fill(0) # in: device, out: host (not provided) res = B.fbp(d_sino) self._compare_to_reference(res, self.ref_512, err_msg="in: device, out: host (not provided)") # in: device, out: host (provided) res = B.fbp(d_sino, output=h_slice) self._compare_to_reference(h_slice, self.ref_512, err_msg="in: device, out: host (provided)") h_slice.fill(0) # in: device, out: device res = B.fbp(d_sino, output=d_slice) self._compare_to_reference(d_slice.get(), self.ref_512, err_msg="in: device, out: device") d_slice.fill(0) def test_hbp_cor(self): """ Test HBP with various sinogram shapes, obtained by truncating horizontally the original sinogram. The Center of rotation is always 255.5 (the one of original sinogram), so it also tests reconstruction with a shifted CoR. 
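        Cropping crop columns on the right moves the detector center to (512 - crop - 1) / 2
        pixels while the rotation center stays at 255.5, i.e. the axis is offset by crop / 2
        pixels from the detector center.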
""" for crop in [1, 2, 5, 10]: sino = np.ascontiguousarray(self.sino_512[:, :-crop]) B = HierarchicalBackprojector(sino.shape, rot_center=255.5) res = B.fbp(sino) # HBP always uses "centered_axis=1", so we cannot compare non-integer shifts if crop % 2 == 0: ref = self.ref_512[crop // 2 : -crop // 2, crop // 2 : -crop // 2] self._compare_to_reference(res, ref, radius_factor=0.95, rel_tol=0.02) def test_hbp_clip_circle(self): B_clip = HierarchicalBackprojector(self.sino_512.shape, extra_options={"clip_outer_circle": True}) B_noclip = HierarchicalBackprojector(self.sino_512.shape, extra_options={"clip_outer_circle": False}) res_clip = B_clip.fbp(self.sino_512) res_noclip = B_noclip.fbp(self.sino_512) self._compare_to_reference(res_clip, clip_to_inner_circle(res_noclip, radius_factor=1), "clip_circle") def test_hbp_axis_corr(self): sino = self.sino_512 # Create a sinogram with a drift in the rotation axis def create_drifted_sino(sino, drifts): out = np.zeros_like(sino) for i in range(sino.shape[0]): out[i] = shift(sino[i], drifts[i]) return out drifts = np.linspace(0, 20, sino.shape[0]) sino = create_drifted_sino(sino, drifts) B = HierarchicalBackprojector(sino.shape, extra_options={"axis_correction": drifts}) res = B.fbp(sino) # Max error is relatively high, migh be due to interpolation of scipy shift in sinogram self._compare_to_reference(res, self.ref_512, radius_factor=0.95, rel_tol=0.04, err_msg="axis_corr") @pytest.mark.skipif(not (__do_long_tests__), reason="need NABU_LONG_TESTS=1 for this test") def test_hbp_scale_factor(self): scale_factor = 0.03125 B_scaled = HierarchicalBackprojector(self.sino_512.shape, extra_options={"scale_factor": scale_factor}) B_unscaled = HierarchicalBackprojector(self.sino_512.shape) res_scaled = B_scaled.fbp(self.sino_512) res_unscaled = B_unscaled.fbp(self.sino_512) self._compare_to_reference(res_scaled, res_unscaled * scale_factor, rel_tol=1e-7, err_msg="scale_factor") ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/reconstruction/tests/test_filtering.py0000644000175000017500000001217714654107202023042 0ustar00pierrepierreimport numpy as np import pytest from nabu.reconstruction.filtering import SinoFilter, filter_sinogram from nabu.cuda.utils import __has_pycuda__ from nabu.opencl.utils import __has_pyopencl__ from nabu.testutils import get_data, generate_tests_scenarios, __do_long_tests__ if __has_pycuda__: from nabu.cuda.utils import get_cuda_context from nabu.reconstruction.filtering_cuda import CudaSinoFilter import pycuda.gpuarray as garray if __has_pyopencl__: import pyopencl.array as parray from nabu.opencl.processing import OpenCLProcessing from nabu.reconstruction.filtering_opencl import OpenCLSinoFilter, __has_vkfft__ filters_to_test = ["ramlak", "shepp-logan", "tukey"] padding_modes_to_test = ["constant", "edge"] if __do_long_tests__: filters_to_test = ["ramlak", "shepp-logan", "cosine", "hamming", "hann", "tukey", "lanczos"] padding_modes_to_test = SinoFilter.available_padding_modes tests_scenarios = generate_tests_scenarios( { "filter_name": filters_to_test, "padding_mode": padding_modes_to_test, "output_provided": [True, False], "truncated_sino": [True, False], } ) @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.sino = get_data("mri_sino500.npz")["data"] cls.sino_truncated = np.ascontiguousarray(cls.sino[:, 160:-160]) if __has_pycuda__: cls.ctx_cuda = get_cuda_context(cleanup_at_exit=False) cls.sino_cuda = garray.to_gpu(cls.sino) 
cls.sino_truncated_cuda = garray.to_gpu(cls.sino_truncated) if __has_pyopencl__: cls.cl = OpenCLProcessing(device_type="all") cls.sino_cl = parray.to_device(cls.cl.queue, cls.sino) cls.sino_truncated_cl = parray.to_device(cls.cl.queue, cls.sino_truncated) yield if __has_pycuda__: cls.ctx_cuda.pop() @pytest.mark.usefixtures("bootstrap") class TestSinoFilter: @pytest.mark.parametrize("config", tests_scenarios) def test_filter(self, config): sino = self.sino if not (config["truncated_sino"]) else self.sino_truncated sino_filter = SinoFilter( sino.shape, filter_name=config["filter_name"], padding_mode=config["padding_mode"], ) if config["output_provided"]: output = np.zeros_like(sino) else: output = None res = sino_filter.filter_sino(sino, output=output) if output is not None: assert id(res) == id(output), "when providing output, return value must not change" ref = filter_sinogram( sino, sino_filter.dwidth_padded, filter_name=config["filter_name"], padding_mode=config["padding_mode"] ) assert np.allclose(res, ref, atol=4e-6) @pytest.mark.skipif(not (__has_pycuda__), reason="Need Cuda + pycuda to use CudaSinoFilter") @pytest.mark.parametrize("config", tests_scenarios) def test_cuda_filter(self, config): sino = self.sino_cuda if not (config["truncated_sino"]) else self.sino_truncated_cuda h_sino = self.sino if not (config["truncated_sino"]) else self.sino_truncated sino_filter = CudaSinoFilter( sino.shape, filter_name=config["filter_name"], padding_mode=config["padding_mode"], cuda_options={"ctx": self.ctx_cuda}, ) if config["output_provided"]: output = garray.zeros(sino.shape, "f") else: output = None res = sino_filter.filter_sino(sino, output=output) if output is not None: assert id(res) == id(output), "when providing output, return value must not change" ref = filter_sinogram( h_sino, sino_filter.dwidth_padded, filter_name=config["filter_name"], padding_mode=config["padding_mode"] ) assert np.allclose(res.get(), ref, atol=6e-5), "test_cuda_filter: something wrong with config=%s" % ( str(config) ) @pytest.mark.skipif( not (__has_pyopencl__ and __has_vkfft__), reason="Need OpenCL + pyopencl + pyvkfft to use OpenCLSinoFilter" ) @pytest.mark.parametrize("config", tests_scenarios) def test_opencl_filter(self, config): sino = self.sino_cl if not (config["truncated_sino"]) else self.sino_truncated_cl h_sino = self.sino if not (config["truncated_sino"]) else self.sino_truncated sino_filter = OpenCLSinoFilter( sino.shape, filter_name=config["filter_name"], padding_mode=config["padding_mode"], opencl_options={"ctx": self.cl.ctx}, ) if config["output_provided"]: output = parray.zeros(self.cl.queue, sino.shape, "f") else: output = None res = sino_filter.filter_sino(sino, output=output) if output is not None: assert id(res) == id(output), "when providing output, return value must not change" ref = filter_sinogram( h_sino, sino_filter.dwidth_padded, filter_name=config["filter_name"], padding_mode=config["padding_mode"] ) assert np.allclose(res.get(), ref, atol=6e-5), "test_opencl_filter: something wrong with config=%s" % ( str(config) ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731681010.0 nabu-2024.2.1/nabu/reconstruction/tests/test_halftomo.py0000644000175000017500000001246314715655362022703 0ustar00pierrepierreimport numpy as np import pytest from nabu.processing.fft_cuda import get_available_fft_implems from nabu.testutils import get_data, generate_tests_scenarios, compare_shifted_images from nabu.cuda.utils import get_cuda_context, __has_pycuda__ from 
nabu.opencl.utils import get_opencl_context, __has_pyopencl__ from nabu.thirdparty.algotom_convert_sino import extend_sinogram __has_cufft__ = False if __has_pycuda__: avail_fft = get_available_fft_implems() __has_cufft__ = len(avail_fft) > 0 __has_pycuda__ = __has_pycuda__ and __has_cufft__ # need both for using Cuda backprojector if __has_pycuda__: from nabu.reconstruction.fbp import CudaBackprojector from nabu.reconstruction.hbp import HierarchicalBackprojector if __has_pyopencl__: from nabu.reconstruction.fbp_opencl import OpenCLBackprojector scenarios = generate_tests_scenarios({"backend": ["cuda", "opencl"]}) @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls file_desc = get_data("sino_halftomo.npz") cls.sino = file_desc["sinogram"] * 1e4 cls.rot_center = file_desc["rot_center"] cls.tol = 5e-3 if __has_pycuda__: cls.cuda_ctx = get_cuda_context() if __has_pyopencl__: cls.opencl_ctx = get_opencl_context("all") @pytest.mark.usefixtures("bootstrap") @pytest.mark.parametrize("config", scenarios) class TestHalftomo: def _get_backprojector(self, config, *bp_args, **bp_kwargs): if config["backend"] == "cuda": if not (__has_pycuda__): pytest.skip("Need pycuda + scikit-cuda or vkfft") Backprojector = CudaBackprojector ctx = self.cuda_ctx else: if not (__has_pyopencl__): pytest.skip("Need pyopencl") Backprojector = OpenCLBackprojector ctx = self.opencl_ctx if config.get("opencl_use_textures", True) is False: # patch "extra_options" extra_options = bp_kwargs.pop("extra_options", {}) extra_options["use_textures"] = False bp_kwargs["extra_options"] = extra_options return Backprojector(*bp_args, **bp_kwargs, backend_options={"ctx": ctx}) def test_halftomo_right_side(self, config, sino=None, rot_center=None): if sino is None: sino = self.sino if rot_center is None: rot_center = self.rot_center sino_extended, rot_center_ext = extend_sinogram(sino, rot_center, apply_log=False) sino_extended *= 2 # compat. with nabu normalization backprojector_extended = self._get_backprojector( config, sino_extended.shape, rot_center=rot_center_ext, halftomo=False, padding_mode="edges", angles=np.linspace(0, 2 * np.pi, sino.shape[0], True), extra_options={"centered_axis": True}, ) ref = backprojector_extended.fbp(sino_extended) backprojector = self._get_backprojector( config, sino.shape, rot_center=rot_center, halftomo=True, padding_mode="edges", extra_options={"centered_axis": True}, ) res = backprojector.fbp(sino) # The approach in algotom (used as reference) slightly differers: # - altogom extends the sinogram with padding, so that it's ready-to-use for FBP # - nabu filters the sinogram first, and then does the "half-tomo preparation". # Filtering the sinogram first is better to avoid artefacts due to sharp transition in the borders metric, upper_bound = compare_shifted_images(res, ref, return_upper_bound=True) assert metric < 5, "Something wrong for halftomo with backend %s" % (config["backend"]) def test_halftomo_left_side(self, config): sino = np.ascontiguousarray(self.sino[:, ::-1]) rot_center = sino.shape[-1] - 1 - self.rot_center return self.test_halftomo_right_side(config, sino=sino, rot_center=rot_center) def test_halftomo_cor_outside_fov(self, config): sino = np.ascontiguousarray(self.sino[:, : self.sino.shape[-1] // 2]) backprojector = self._get_backprojector(config, sino.shape, rot_center=self.rot_center, halftomo=True) res = backprojector.fbp(sino) # Just check that it runs, but no reference results. Who does this anyway ?! 
@pytest.mark.skipif(not (__has_pycuda__), reason="Need pycuda") def test_hbp_halftomo(self, config): if config["backend"] == "opencl": pytest.skip("No HBP available in OpenCL") B = HierarchicalBackprojector(self.sino.shape, halftomo=True, rot_center=self.rot_center, padding_mode="edge") res = B.fbp(self.sino) sino_extended, rot_center_ext = extend_sinogram(self.sino, self.rot_center, apply_log=False) sino_extended *= 2 # compat. with nabu normalization B_extended = HierarchicalBackprojector( sino_extended.shape, rot_center=rot_center_ext, padding_mode="edge", angles=np.linspace(0, 2 * np.pi, self.sino.shape[0], True), ) res_e = B_extended.fbp(sino_extended) # see notes in test_halftomo_right_side() metric, upper_bound = compare_shifted_images(res, res_e, return_upper_bound=True) assert metric < 5, "Something wrong for halftomo with HBP" ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1732264041.0 nabu-2024.2.1/nabu/reconstruction/tests/test_mlem.py0000644000175000017500000000733514720040151022002 0ustar00pierrepierreimport pytest import numpy as np from nabu.testutils import get_data, __do_long_tests__ from nabu.cuda.utils import __has_pycuda__ from nabu.reconstruction.mlem import MLEMReconstructor, __have_corrct__ @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls datafile = get_data("sl_mlem.npz") cls.data = datafile["data"] cls.angles_rad = datafile["angles_rad"] cls.random_u_shifts = datafile["random_u_shifts"] cls.ref_rec_noshifts = datafile["ref_rec_noshifts"] cls.ref_rec_shiftsu = datafile["ref_rec_shiftsu"] cls.ref_rec_u_rand = datafile["ref_rec_u_rand"] cls.ref_rec_shiftsv = datafile["ref_rec_shiftsv"] # cls.ref_rec_v_rand = datafile["ref_rec_v_rand"] cls.tol = 2e-4 @pytest.mark.skipif(not (__has_pycuda__ and __have_corrct__), reason="Need pycuda and corrct for this test") @pytest.mark.usefixtures("bootstrap") class TestMLEM: """These tests test the general MLEM reconstruction algorithm and the behavior of the reconstruction with respect to horizontal shifts. Only horizontal shifts are tested here because vertical shifts are handled outside the reconstruction object, but in the embedding reconstruction pipeline. See FullFieldReconstructor""" def _create_MLEM_reconstructor(self, shifts_uv=None): return MLEMReconstructor( self.data.shape, -self.angles_rad, shifts_uv, cor=0.0, n_iterations=10 # mind the sign ) def test_simple_mlem_recons(self): R = self._create_MLEM_reconstructor() rec = R.reconstruct(self.data) delta = np.abs(rec[:, ::-1, :] - self.ref_rec_noshifts) assert np.max(delta) < self.tol def test_mlem_recons_with_u_shifts(self): shifts = np.zeros((len(self.angles_rad), 2)) shifts[:, 0] = -5 R = self._create_MLEM_reconstructor(shifts) rec = R.reconstruct(self.data) delta = np.abs(rec[:, ::-1] - self.ref_rec_shiftsu) assert np.max(delta) < self.tol def test_mlem_recons_with_random_u_shifts(self): R = self._create_MLEM_reconstructor(self.random_u_shifts) rec = R.reconstruct(self.data) delta = np.abs(rec[:, ::-1] - self.ref_rec_u_rand) assert np.max(delta) < self.tol def test_mlem_recons_with_constant_v_shifts(self): from nabu.preproc.shift import VerticalShift shifts = np.zeros((len(self.angles_rad), 2)) shifts[:, 1] = -20 nv, n_angles, nu = self.data.shape radios_movements = VerticalShift( (n_angles, nv, nu), -shifts[:, 1] ) # Minus sign here mimics what is done in the pipeline. 
tmp_in = np.swapaxes(self.data, 0, 1).copy() tmp_out = np.zeros_like(tmp_in) radios_movements.apply_vertical_shifts(tmp_in, list(range(n_angles)), output=tmp_out) data = np.swapaxes(tmp_out, 0, 1).copy() R = self._create_MLEM_reconstructor(shifts) rec = R.reconstruct(data) axslice = 120 trslice = 84 axslice1 = self.ref_rec_shiftsv[axslice] axslice2 = rec[axslice, ::-1] trslice1 = self.ref_rec_shiftsv[trslice] trslice2 = rec[trslice, ::-1] # delta = np.abs(rec[:, ::-1] - self.ref_rec_shiftsv) delta_ax = np.abs(axslice1 - axslice2) delta_tr = np.abs(trslice1 - trslice2) assert max(np.max(delta_ax), np.max(delta_tr)) < self.tol @pytest.mark.skip(reason="No valid reference reconstruction for this test.") def test_mlem_recons_with_random_v_shifts(self): """NOT YET IMPLEMENTED. This is a temporary version due to unpexcted behavior of CorrCT/Astra to compute a reference implementation. See [question on Astra's github](https://github.com/astra-toolbox/astra-toolbox/discussions/520). """ ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682665866.0 nabu-2024.2.1/nabu/reconstruction/tests/test_projector.py0000644000175000017500000001362114422670612023064 0ustar00pierrepierreimport numpy as np import pytest from nabu.testutils import get_data from nabu.cuda.utils import __has_pycuda__ if __has_pycuda__: import pycuda.gpuarray as garray # from pycuda.cumath import fabs from pycuda.elementwise import ElementwiseKernel from nabu.cuda.utils import get_cuda_context from nabu.reconstruction.projection import Projector from nabu.reconstruction.fbp import Backprojector try: import astra __has_astra__ = True except ImportError: __has_astra__ = False @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.image = get_data("brain_phantom.npz")["data"] cls.sino_ref = get_data("mri_sino500.npz")["data"] cls.n_angles, cls.dwidth = cls.sino_ref.shape cls.rtol = 1e-3 if __has_pycuda__: cls.ctx = get_cuda_context() @pytest.mark.skipif(not (__has_pycuda__), reason="Need pycuda for this test") @pytest.mark.usefixtures("bootstrap") class TestProjection: def check_result(self, img1, img2, err_msg): max_diff = np.max(np.abs(img1 - img2)) assert max_diff / img1.max() < self.rtol, err_msg + " : max diff = %.3e" % max_diff def test_proj_simple(self): P = Projector(self.image.shape, self.n_angles) res = P(self.image) self.check_result(res, self.sino_ref, "Something wrong with simple projection") def test_input_output_kinds(self): P = Projector(self.image.shape, self.n_angles) # input on GPU, output on CPU d_img = garray.to_gpu(self.image) res = P(d_img) self.check_result(res, self.sino_ref, "Something wrong: input GPU, output CPU") # input on CPU, output on GPU out = garray.zeros(P.sino_shape, "f") res = P(self.image, output=out) self.check_result(out.get(), self.sino_ref, "Something wrong: input CPU, output GPU") # input and output on GPU out.fill(0) P(d_img, output=out) self.check_result(out.get(), self.sino_ref, "Something wrong: input GPU, output GPU") def test_odd_size(self): image = self.image[:511, :] P = Projector(image.shape, self.n_angles - 1) res = P(image) @pytest.mark.skipif(not (__has_astra__), reason="Need astra-toolbox for this test") def test_against_astra(self): def proj_astra(img, angles, rot_center=None): vol_geom = astra.create_vol_geom(img.shape) if np.isscalar(angles): angles = np.linspace(0, np.pi, angles, False) proj_geom = astra.create_proj_geom("parallel", 1.0, img.shape[-1], angles) if rot_center is not None: cor_shift = (img.shape[-1] - 1) / 2.0 
- rot_center proj_geom = astra.geom_postalignment(proj_geom, cor_shift) projector_id = astra.create_projector("cuda", proj_geom, vol_geom) sinogram_id, sinogram = astra.create_sino(img, projector_id) astra.data2d.delete(sinogram_id) astra.projector.delete(projector_id) return sinogram # Center of rotation to test cors = [None, 255.5, 256, 260, 270.2, 300, 150] for cor in cors: res_astra = proj_astra(self.image, 500, rot_center=cor) res_nabu = Projector(self.image.shape, 500, rot_center=cor).projection(self.image) self.check_result(res_nabu, res_astra, "Projection with CoR = %s" % str(cor)) @pytest.mark.skipif(not (__has_astra__), reason="Need astra-toolbox for this test") def test_em_reconstruction(self): """ Test iterative reconstruction: Maximum Likelyhood Expectation Maximization (MLEM) """ subsampling = 5 sino = self.sino_ref[::subsampling, :] P = Projector(self.image.shape, sino.shape[0]) B = Backprojector(sino.shape, padding_mode="edge", extra_options={"centered_axis": True}) d_sino = garray.to_gpu(np.ascontiguousarray(sino)) def EM(sino, P, B, n_it, eps=1e-6): ones = np.ones(sino.shape, "f") oinv = garray.to_gpu((1.0 / B.backproj(ones)).astype("f")) x = garray.ones_like(oinv) y = garray.zeros_like(x) proj = garray.zeros_like(sino) proj_inv = sino.copy() update_projection = ElementwiseKernel( "float* proj_inv, float* proj, float* proj_data, float eps", "proj_inv[i] = proj_data[i] / ((fabsf(proj[i]) > eps) ? (proj[i]) : (1.0f))", "update_projection", ) for k in range(n_it): # proj = P(x) P.projection(x, output=proj) update_projection(proj_inv, proj, sino, eps) # x *= B(proj_inv) * oinv B.backproj(proj_inv, output=y) x *= y x *= oinv return x rec = EM(d_sino, P, B, 50) def EM_astra(sino, rec_shape, n_it): vol_geom = astra.create_vol_geom(rec_shape) proj_geom = astra.create_proj_geom( "parallel", 1.0, sino.shape[-1], np.linspace(0, np.pi, sino.shape[0], False) ) rec_id = astra.data2d.create("-vol", vol_geom) sinogram_id = astra.data2d.create("-sino", proj_geom) astra.data2d.store(sinogram_id, sino) astra.data2d.store(rec_id, np.ones(rec_shape, "f")) # ! 
cfg = astra.astra_dict("EM_CUDA") cfg["ReconstructionDataId"] = rec_id cfg["ProjectionDataId"] = sinogram_id alg_id = astra.algorithm.create(cfg) astra.algorithm.run(alg_id, n_it) rec = astra.data2d.get(rec_id) astra.algorithm.delete(alg_id) astra.data2d.delete(rec_id) astra.data2d.delete(sinogram_id) return rec ref = EM_astra(sino, self.image.shape, 50) err_max = np.max(np.abs(rec.get() - ref)) assert err_max < 0.2, "Discrepancy between EM and EM_astra" ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1730906677.0 nabu-2024.2.1/nabu/reconstruction/tests/test_reconstructor.py0000644000175000017500000000647114712705065024001 0ustar00pierrepierreimport numpy as np import pytest from nabu.processing.fft_cuda import get_available_fft_implems from nabu.testutils import ( get_big_data, __big_testdata_dir__, generate_tests_scenarios, __do_long_tests__, ) from nabu.cuda.utils import __has_pycuda__, get_cuda_context __has_cufft__ = False if __has_pycuda__: avail_fft = get_available_fft_implems() __has_cufft__ = len(avail_fft) > 0 __has_cuda_fbp__ = __has_cufft__ and __has_pycuda__ if __has_cuda_fbp__: from nabu.reconstruction.reconstructor_cuda import CudaReconstructor from nabu.reconstruction.fbp import Backprojector as CudaBackprojector import pycuda.gpuarray as garray scenarios = generate_tests_scenarios( { "axis": ["z", "y", "x"], "vol_type": ["sinograms", "projections"], "indices": [(300, 310)], # reconstruct 10 slices "slices_roi": [None, (None, None, 200, 400), (250, 300, None, None), (120, 330, 301, 512)], } ) @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.projs = get_big_data("MRI512_projs.npy") if __has_cuda_fbp__: cls.ctx = get_cuda_context() cls.d_projs = garray.to_gpu(cls.projs) cls.ref = None cls.tol = 5e-2 @pytest.mark.skipif( __big_testdata_dir__ is None or not (__do_long_tests__), reason="need environment variable NABU_BIGDATA_DIR and NABU_LONG_TESTS=1", ) @pytest.mark.usefixtures("bootstrap") class TestReconstructor: @pytest.mark.skipif(not (__has_cuda_fbp__), reason="need pycuda and (scikit-cuda or vkfft)") @pytest.mark.parametrize("config", scenarios) def test_cuda_reconstructor(self, config): data = self.projs d_data = self.d_projs if config["vol_type"] == "sinograms": data = np.moveaxis(self.projs, 1, 0) d_data = self.d_projs.transpose(axes=(1, 0, 2)) # view reconstructor = CudaReconstructor( data.shape, config["indices"], axis=config["axis"], vol_type=config["vol_type"], slices_roi=config["slices_roi"], ) res = reconstructor.reconstruct(d_data) ref = self.get_ref() ref = self.crop_array(ref, config) err_max = np.max(np.abs(res - ref)) assert err_max < self.tol, "something wrong with reconstructor, config = %s" % str(config) def get_ref(self): if self.ref is not None: return self.ref if __has_cuda_fbp__: fbp_cls = CudaBackprojector ref = np.zeros((512, 512, 512), "f") fbp = fbp_cls((self.projs.shape[0], self.projs.shape[-1])) for i in range(512): ref[i] = fbp.fbp(self.d_projs[:, i, :]) self.ref = ref return self.ref @staticmethod def crop_array(arr, config): indices = config["indices"] axis = config["axis"] slices_roi = config["slices_roi"] or (None, None, None, None) i_slice = slice(*indices) u_slice = slice(*slices_roi[:2]) v_slice = slice(*slices_roi[-2:]) if axis == "z": z_slice, y_slice, x_slice = i_slice, v_slice, u_slice if axis == "y": z_slice, y_slice, x_slice = v_slice, i_slice, u_slice if axis == "x": z_slice, y_slice, x_slice = v_slice, u_slice, i_slice return arr[z_slice, y_slice, x_slice] 
././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/reconstruction/tests/test_sino_normalization.py0000644000175000017500000000715414402565210024772 0ustar00pierrepierreimport os.path as path import numpy as np import pytest from nabu.testutils import get_data from nabu.cuda.utils import __has_pycuda__ from nabu.reconstruction.sinogram import SinoNormalization if __has_pycuda__: from nabu.reconstruction.sinogram_cuda import CudaSinoNormalization import pycuda.gpuarray as garray @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.sino = get_data("sino_refill.npy") cls.tol = 1e-7 cls.norm_array_1D = np.arange(cls.sino.shape[-1]) + 1 cls.norm_array_2D = np.arange(cls.sino.size).reshape(cls.sino.shape) + 1 @pytest.mark.usefixtures("bootstrap") class TestSinoNormalization: def test_sino_normalization(self): sino_proc = SinoNormalization(kind="chebyshev", sinos_shape=self.sino.shape) sino = self.sino.copy() sino_proc.normalize(sino) @pytest.mark.skipif(not (__has_pycuda__), reason="Need pycuda for sinogram normalization with cuda backend") def test_sino_normalization_cuda(self): sino_proc = SinoNormalization(kind="chebyshev", sinos_shape=self.sino.shape) sino = self.sino.copy() ref = sino_proc.normalize(sino) cuda_sino_proc = CudaSinoNormalization(kind="chebyshev", sinos_shape=self.sino.shape) d_sino = garray.to_gpu(self.sino) cuda_sino_proc.normalize(d_sino) res = d_sino.get() assert np.max(np.abs(res - ref)) < self.tol def get_normalization_reference_result(self, op, normalization_arr): # Perform explicit operations to compare with numpy.divide, numpy.subtract, etc if op == "subtraction": ref = self.sino - normalization_arr elif op == "division": ref = self.sino / normalization_arr return ref def test_sino_array_subtraction_and_division(self): with pytest.raises(ValueError): SinoNormalization(kind="subtraction", sinos_shape=self.sino.shape) def compare_normalizations(normalization_arr, op): sino_normalization = SinoNormalization( kind=op, sinos_shape=self.sino.shape, normalization_array=normalization_arr ) sino = self.sino.copy() sino_normalization.normalize(sino) ref = self.get_normalization_reference_result(op, normalization_arr) assert np.allclose(sino, ref), "operation=%s, normalization_array dims=%d" % (op, normalization_arr.ndim) compare_normalizations(self.norm_array_1D, "subtraction") compare_normalizations(self.norm_array_1D, "division") compare_normalizations(self.norm_array_2D, "subtraction") compare_normalizations(self.norm_array_2D, "division") @pytest.mark.skipif(not (__has_pycuda__), reason="Need pycuda for sinogram normalization with cuda backend") def test_sino_array_subtraction_cuda(self): with pytest.raises(ValueError): CudaSinoNormalization(kind="subtraction", sinos_shape=self.sino.shape) def compare_normalizations(normalization_arr, op): sino_normalization = CudaSinoNormalization( kind=op, sinos_shape=self.sino.shape, normalization_array=normalization_arr ) sino = garray.to_gpu(self.sino) sino_normalization.normalize(sino) ref = self.get_normalization_reference_result(op, normalization_arr) assert np.allclose(sino.get(), ref) compare_normalizations(self.norm_array_1D, "subtraction") compare_normalizations(self.norm_array_2D, "subtraction") compare_normalizations(self.norm_array_1D, "division") compare_normalizations(self.norm_array_2D, "division") ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1734442985.520757 
nabu-2024.2.1/nabu/resources/0000755000175000017500000000000014730277752015242 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/resources/__init__.py0000644000175000017500000000000014315516747017340 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1734442985.520757 nabu-2024.2.1/nabu/resources/cli/0000755000175000017500000000000014730277752016011 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/resources/cli/__init__.py0000644000175000017500000000000014315516747020107 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/resources/cor.py0000644000175000017500000000025214402565210016357 0ustar00pierrepierrefrom ..pipeline.estimators import CORFinder as ECorFinder # This class has moved # The future location will be nabu.pipeline.estimators.CORFinder CORFinder = ECorFinder ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734019212.0 nabu-2024.2.1/nabu/resources/dataset_analyzer.py0000644000175000017500000004421014726604214021137 0ustar00pierrepierreimport os import numpy as np from silx.io.url import DataUrl from silx.io import get_data from tomoscan.esrf.scan.edfscan import EDFTomoScan from tomoscan.esrf.scan.nxtomoscan import NXtomoScan from ..utils import check_supported, indices_to_slices from ..io.reader import EDFStackReader, NXDarksFlats, NXTomoReader from ..io.utils import get_compacted_dataslices from .utils import get_values_from_file, is_hdf5_extension from .logger import LoggerOrPrint from ..pipeline.utils import nabu_env_settings class DatasetAnalyzer: _scanner = None kind = "none" """ Base class for datasets analyzers. """ def __init__(self, location, extra_options=None, logger=None): """ Initialize a Dataset analyzer. Parameters ---------- location: str Dataset location (directory or file name) extra_options: dict, optional Extra options on how to interpret the dataset. logger: logging object, optional Logger. If not set, messages will just be printed in stdout. """ self.logger = LoggerOrPrint(logger) self.location = location self._set_extra_options(extra_options) self._get_excluded_projections() self._set_default_dataset_values() self._init_dataset_scan() self._finish_init() def _set_extra_options(self, extra_options): if extra_options is None: extra_options = {} # COMPAT. 
advanced_options = { "force_flatfield": False, "output_dir": None, "exclude_projections": None, "hdf5_entry": None, # "nx_version": 1.0, } # -- advanced_options.update(extra_options) self.extra_options = advanced_options # pylint: disable=E1136 def _get_excluded_projections(self): excluded_projs = self.extra_options["exclude_projections"] self._ignore_projections = None if excluded_projs is None: return if excluded_projs["type"] == "angular_range": excluded_projs["type"] = "range" # compat with tomoscan #pylint: disable=E1137 values = excluded_projs["range"] for ignore_kind, dtype in {"indices": np.int32, "angles": np.float32}.items(): if excluded_projs["type"] == ignore_kind: values = get_values_from_file(excluded_projs["file"], any_size=True).astype(dtype).tolist() self._ignore_projections = {"kind": excluded_projs["type"], "values": values} # pylint: disable=E0606 def _init_dataset_scan(self, **kwargs): if self._scanner is None: raise ValueError("Base class") if self._scanner is NXtomoScan: if self.extra_options.get("hdf5_entry", None) is not None: kwargs["entry"] = self.extra_options["hdf5_entry"] if self.extra_options.get("nx_version", None) is not None: kwargs["nx_version"] = self.extra_options["nx_version"] if self._scanner is EDFTomoScan: # Assume 1 frame per file (otherwise too long to open each file) kwargs["n_frames"] = 1 self.dataset_scanner = self._scanner( # pylint: disable=E1102 self.location, ignore_projections=self._ignore_projections, **kwargs ) if self._ignore_projections is not None: self.logger.info("Excluding projections: %s" % str(self._ignore_projections)) if nabu_env_settings.skip_tomoscan_checks: self.logger.warning( " WARNING: according to nabu_env_settings.skip_tomoscan_checks, skipping virtual layout integrity check of tomoscan which is time consuming" ) self.dataset_scanner.set_check_behavior(run_check=False, raise_error=False) self.raw_flats = self.dataset_scanner.flats self.raw_darks = self.dataset_scanner.darks self.n_angles = len(self.dataset_scanner.projections) self.radio_dims = (self.dataset_scanner.dim_1, self.dataset_scanner.dim_2) self._radio_dims_notbinned = self.radio_dims # COMPAT def _finish_init(self): pass def _set_default_dataset_values(self): self._detector_tilt = None self.translations = None self.ctf_translations = None self.axis_position = None self._rotation_angles = None self.z_per_proj = None self.x_per_proj = None self._energy = None self._pixel_size = None self._distance = None self._flats_srcurrent = None self._projections = None self._projections_srcurrent = None self._reduced_flats = None self._reduced_darks = None @property def energy(self): """ Return the energy in kev. """ if self._energy is None: self._energy = self.dataset_scanner.energy return self._energy @energy.setter def energy(self, val): self._energy = val @property def distance(self): """ Return the sample-detector distance in meters. """ if self._distance is None: self._distance = abs(self.dataset_scanner.distance) return self._distance @distance.setter def distance(self, val): self._distance = val @property def pixel_size(self): """ Return the pixel size in microns. """ # TODO X and Y pixel size if self._pixel_size is None: self._pixel_size = self.dataset_scanner.pixel_size * 1e6 return self._pixel_size @pixel_size.setter def pixel_size(self, val): self._pixel_size = val def _get_rotation_angles(self): return self._rotation_angles # None by default @property def rotation_angles(self): """ Return the rotation angles in radians. 
""" return self._get_rotation_angles() @rotation_angles.setter def rotation_angles(self, angles): self._rotation_angles = angles def _is_halftomo(self): return None # base class @property def is_halftomo(self): """ Indicates whether the current dataset was performed with half acquisition. """ return self._is_halftomo() @property def detector_tilt(self): """ Return the detector tilt in degrees """ return self._detector_tilt @detector_tilt.setter def detector_tilt(self, tilt): self._detector_tilt = tilt def _get_srcurrent(self, frame_type): # To be implemented by inheriting class return None @property def projections(self): if self._projections is None: self._projections = self.dataset_scanner.projections return self._projections @projections.setter def projections(self, val): raise ValueError @property def projections_srcurrent(self): """ Return the synchrotron electric current for each projection. """ if self._projections_srcurrent is None: self._projections_srcurrent = self._get_srcurrent("radios") # pylint: disable=E1128 return self._projections_srcurrent @projections_srcurrent.setter def projections_srcurrent(self, val): self._projections_srcurrent = val @property def flats_srcurrent(self): """ Return the synchrotron electric current for each flat image. """ if self._flats_srcurrent is None: self._flats_srcurrent = self._get_srcurrent("flats") # pylint: disable=E1128 return self._flats_srcurrent @flats_srcurrent.setter def flats_srcurrent(self, val): self._flats_srcurrent = val def check_defined_attribute(self, name, error_msg=None): """ Utility function to check that a given attribute is defined. """ if getattr(self, name, None) is None: raise ValueError(error_msg or str("No information on %s was found in the dataset" % name)) @property def flats(self): """ Return the REDUCED flat-field images. Either by reducing (median) the raw flats, or a user-defined reduced flats. """ if self._reduced_flats is None: self._reduced_flats = self.get_reduced_flats() return self._reduced_flats @flats.setter def flats(self, val): self._reduced_flats = val @property def darks(self): """ Return the REDUCED flat-field images. Either by reducing (mean) the raw darks, or a user-defined reduced darks. """ if self._reduced_darks is None: self._reduced_darks = self.get_reduced_darks() return self._reduced_darks @darks.setter def darks(self, val): self._reduced_darks = val class EDFDatasetAnalyzer(DatasetAnalyzer): """ EDF Dataset analyzer for legacy ESRF acquisitions """ _scanner = EDFTomoScan kind = "edf" def _finish_init(self): pass def _get_flats_darks(self): return @property def hdf5_entry(self): """ Return the HDF5 entry of the current dataset. Not applicable for EDF (return None) """ return None def _is_halftomo(self): return None def _get_rotation_angles(self): return np.deg2rad(self.dataset_scanner.rotation_angle()) def get_reduced_flats(self, **reader_kwargs): if self.raw_flats in [None, {}]: raise FileNotFoundError("No reduced flat ('refHST') found in %s" % self.location) # A few notes: # (1) In principle we could do the reduction (mean/median) from raw frames (ref_xxxx_yyyy) # but for legacy datasets it's always already done (by fasttomo3), and EDF support is supposed to be dropped on our side # (2) We use EDFStackReader class to handle the possible additional data modifications # (eg. subsampling, binning, distortion correction...) 
# (3) The following spawns one reader instance per file, which is not elegant, # but in principle there are typically 1-2 reduced flats in a scan readers = {k: EDFStackReader([self.raw_flats[k].file_path()], **reader_kwargs) for k in self.raw_flats.keys()} return {k: readers[k].load_data()[0] for k in self.raw_flats.keys()} def get_reduced_darks(self, **reader_kwargs): # See notes in get_reduced_flats() above if self.raw_darks in [None, {}]: raise FileNotFoundError("No reduced dark ('darkend.edf' or 'dark.edf') found in %s" % self.location) readers = {k: EDFStackReader([self.raw_darks[k].file_path()], **reader_kwargs) for k in self.raw_darks.keys()} return {k: readers[k].load_data()[0] for k in self.raw_darks.keys()} @property def files(self): return sorted([u.file_path() for u in self.dataset_scanner.projections.values()]) def get_reader(self, **kwargs): return EDFStackReader(self.files, **kwargs) class HDF5DatasetAnalyzer(DatasetAnalyzer): """ HDF5 dataset analyzer """ _scanner = NXtomoScan kind = "nx" # We could import the 1000+ LoC nxtomo.nxobject.nxdetector.ImageKey... or we can do this _image_key_value = {"flats": 1, "darks": 2, "radios": 0} # @property def z_translation(self): raw_data = np.array(self.dataset_scanner.z_translation) projs_idx = np.array(list(self.projections.keys())) filtered_data = raw_data[projs_idx] return 1.0e6 * filtered_data / self.pixel_size @property def x_translation(self): raw_data = np.array(self.dataset_scanner.x_translation) projs_idx = np.array(list(self.projections.keys())) filtered_data = raw_data[projs_idx] return 1.0e6 * filtered_data / self.pixel_size def _get_rotation_angles(self): if self._rotation_angles is None: angles = np.array(self.dataset_scanner.rotation_angle) projs_idx = np.array(list(self.projections.keys())) angles = angles[projs_idx] self._rotation_angles = np.deg2rad(angles) return self._rotation_angles def _get_dataset_hdf5_url(self): if len(self.projections) > 0: frames_to_take = self.projections elif len(self.raw_flats) > 0: frames_to_take = self.raw_flats elif len(self.raw_darks) > 0: frames_to_take = self.raw_darks else: raise ValueError("No projections, no flats and no darks ?!") first_proj_idx = sorted(frames_to_take.keys())[0] first_proj_url = frames_to_take[first_proj_idx] return DataUrl( file_path=first_proj_url.file_path(), data_path=first_proj_url.data_path(), data_slice=None, scheme="silx" ) @property def dataset_hdf5_url(self): return self._get_dataset_hdf5_url() @property def hdf5_entry(self): """ Return the HDF5 entry of the current dataset """ return self.dataset_scanner.entry def _is_halftomo(self): try: is_halftomo = self.dataset_scanner.field_of_view.value.lower() == "half" except: is_halftomo = None return is_halftomo def get_data_slices(self, what): """ Return indices in the data volume where images correspond to a given kind. Parameters ---------- what: str Which keys to get. Can be "projections", "flats", "darks" Returns -------- slices: list of slice A list where each item is a slice. """ name_to_attr = { "projections": self.projections, "flats": self.raw_flats, "darks": self.raw_darks, } check_supported(what, name_to_attr.keys(), "image type") images = name_to_attr[what] # dict # we can't directly use set() on slice() object (unhashable). 
Use tuples slices = set() for du in get_compacted_dataslices(images).values(): if du.data_slice() is not None: s = (du.data_slice().start, du.data_slice().stop) else: s = None slices.add(s) slices_list = [slice(item[0], item[1]) if item is not None else None for item in list(slices)] return slices_list def _select_according_to_frame_type(self, data, frame_type): if data is None: return None return data[self.dataset_scanner.image_key_control == self._image_key_value[frame_type]] def get_reduced_flats(self, method="median", force_reload=False, **reader_kwargs): dkrf_reader = NXDarksFlats( self.dataset_hdf5_url.file_path(), data_path=self.dataset_hdf5_url.data_path(), **reader_kwargs ) return dkrf_reader.get_reduced_flats(method=method, force_reload=force_reload, as_dict=True) def get_reduced_darks(self, method="mean", force_reload=False, **reader_kwargs): dkrf_reader = NXDarksFlats( self.dataset_hdf5_url.file_path(), data_path=self.dataset_hdf5_url.data_path(), **reader_kwargs ) return dkrf_reader.get_reduced_darks(method=method, force_reload=force_reload, as_dict=True) def _get_srcurrent(self, frame_type): return self._select_according_to_frame_type(self.dataset_scanner.electric_current, frame_type) def frames_slices(self, frame_type): """ Return a list of slice objects corresponding to the data corresponding to "frame_type". For example, if the dataset flats are located at indices [1, 2, ..., 99], then frame_slices("flats") will return [slice(0, 100)]. """ return indices_to_slices( np.where(self.dataset_scanner.image_key_control == self._image_key_value[frame_type])[0] ) def get_reader(self, **kwargs): return NXTomoReader(self.dataset_hdf5_url.file_path(), data_path=self.dataset_hdf5_url.data_path(), **kwargs) def analyze_dataset(dataset_path, extra_options=None, logger=None): if not (os.path.isdir(dataset_path)): if not (os.path.isfile(dataset_path)): raise ValueError("Error: %s no such file or directory" % dataset_path) if not (is_hdf5_extension(os.path.splitext(dataset_path)[-1].replace(".", ""))): raise ValueError("Error: expected a HDF5 file") dataset_analyzer_class = HDF5DatasetAnalyzer else: # directory -> assuming EDF dataset_analyzer_class = EDFDatasetAnalyzer dataset_structure = dataset_analyzer_class(dataset_path, extra_options=extra_options, logger=logger) return dataset_structure def get_radio_pair(dataset_info, radio_angles: tuple, return_indices=False): """ Get closest radios at radio_angles[0] and radio_angles[1] angles must be in angles Parameters ---------- dataset_info: `DatasetAnalyzer` instance Data structure with the dataset information radio_angles: tuple tuple of two elements: angles (in radian) to get return_indices: bool, optional Whether to return radios indices along with the radios array. Returns ------- res: array or tuple If return_indices is True, return a tuple (radios, indices). Otherwise, return an array with the radios. """ if not (isinstance(radio_angles, tuple) and len(radio_angles) == 2): raise TypeError("radio_angles should be a tuple of two elements.") if not isinstance(radio_angles[0], (np.floating, float)) or not isinstance(radio_angles[1], (np.floating, float)): raise TypeError( f"radio_angles should be float. 
Get {type(radio_angles[0])} and {type(radio_angles[1])} instead" ) radios_indices = [] radios_indices = sorted(dataset_info.projections.keys()) angles = dataset_info.rotation_angles angles = angles - angles.min() i_radio_1 = np.argmin(np.abs(angles - radio_angles[0])) i_radio_2 = np.argmin(np.abs(angles - radio_angles[1])) radios_indices = [radios_indices[i_radio_1], radios_indices[i_radio_2]] n_radios = 2 radios = np.zeros((n_radios,) + dataset_info.radio_dims[::-1], "f") for i in range(n_radios): radio_idx = radios_indices[i] radios[i] = get_data(dataset_info.projections[radio_idx]).astype("f") if return_indices: return radios, radios_indices else: return radios ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/resources/gpu.py0000644000175000017500000001310614654107202016373 0ustar00pierrepierre""" gpu.py: general-purpose utilities for GPU """ from ..utils import check_supported try: from pycuda.driver import Device as CudaDevice from pycuda.driver import device_attribute as dev_attrs __has_pycuda__ = True except ImportError: __has_pycuda__ = False CudaDevice = type(None) try: from pyopencl import Device as CLDevice __has_pyopencl__ = True except ImportError: CLDevice = type(None) __has_pyopencl__ = False # # silx.opencl.common.Device cannot be supported as long as # silx.opencl instantiates the "ocl" singleton in __init__, # leaving opencl contexts all over the place in some cases # class GPUDescription: """ Simple description of a Graphical Processing Unit. This class is designed to be simple to understand, and to be serializable for being used by dask.distributed. """ def __init__(self, device, vendor=None, device_id=None): """ Create a description from a device. Parameters ---------- device: `pycuda.driver.Device` or `pyopencl.Device` Class describing a GPU device. """ is_cuda_device = isinstance(device, CudaDevice) is_cl_device = isinstance(device, CLDevice) if is_cuda_device: self._init_from_cuda_device(device) elif is_cl_device: self._init_from_cl_device(device) self._set_other_attrs(vendor, device_id) else: raise ValueError("Expected `pycuda.driver.Device` or `pyopencl.Device`") def _init_from_cuda_device(self, device): self._dict = { "type": "cuda", "name": device.name(), "memory_GB": device.total_memory() / 1e9, "compute_capability": device.compute_capability(), "device_id": device.get_attribute(dev_attrs.MULTI_GPU_BOARD_GROUP_ID), } def _init_from_cl_device(self, device): self._dict = { "type": "opencl", "name": device.name, "memory_GB": device.global_mem_size / 1e9, "vendor": device.vendor, } def _set_other_attrs(self, vendor, device_id): if vendor is not None: self._dict["vendor"] = vendor if device_id is not None: self._dict["device_id"] = device_id # device ID for OpenCL (!= platform ID) def _dict_to_self(self): for key, val in self._dict.items(): setattr(self, key, val) def get_dict(self): return self._dict GPU_PICK_METHODS = ["cuda", "auto"] def pick_gpus(method, cuda_gpus, opencl_platforms, n_gpus): check_supported(method, GPU_PICK_METHODS, "GPU picking method") if method == "cuda": return pick_gpus_nvidia(cuda_gpus, n_gpus) elif method == "auto": return pick_gpus_auto(cuda_gpus, opencl_platforms, n_gpus) else: return [] # TODO Fix this function, it is broken: # - returns something when n_gpus = 0 # - POCL increments device_id by 1 ! def pick_gpus_auto(cuda_gpus, opencl_platforms, n_gpus): """ Pick `n_gpus` devices with the best available driver. 
This function browse the visible Cuda GPUs and Opencl platforms to pick the GPUs with the best driver. A worker might see several implementations of a GPU driver. For example with Nvidia hardware, we can see: - The Cuda implementation (nvidia-cuda-toolkit) - OpenCL implementation by Nvidia (nvidia-opencl-icd) - OpenCL implementation by Portable OpenCL Parameters ---------- cuda_gpu: dict Dictionary where each key is an ID, and the value is a dictionary describing some attributes of the GPU (obtained with `GPUDescription`) opencl_platforms: dict Dictionary where each key is the platform name, and the value is a list of dictionary descriptions. n_gpus: int Number of GPUs to pick. """ def gpu_equal(gpu1, gpu2): # TODO find a better test ? # Some information are not always available depending on the opencl vendor ! return (gpu1["device_id"] == gpu2["device_id"]) and (gpu1["name"] == gpu2["name"]) def is_in_gpus(avail_gpus, query_gpu): for gpu in avail_gpus: if gpu_equal(gpu, query_gpu): return True return False # If some Nvidia hardware is visible, add it without question. # In the case we don't want it, we should either re-run resources discovery # with `try_cuda=False`, or mask individual devices with CUDA_VISIBLE_DEVICES. chosen_gpus = list(cuda_gpus.values()) if len(chosen_gpus) >= n_gpus: return chosen_gpus for platform, gpus in opencl_platforms.items(): for gpu_id, gpu in gpus.items(): if not (is_in_gpus(chosen_gpus, gpu)): # TODO prioritize some OpenCL implementations ? chosen_gpus.append(gpu) if len(chosen_gpus) < n_gpus: raise ValueError("Not enough GPUs: could only collect %d/%d" % (len(chosen_gpus), n_gpus)) return chosen_gpus def pick_gpus_nvidia(cuda_gpus, n_gpus): """ Pick one or more Nvidia GPUs. """ if len(cuda_gpus) < n_gpus: raise ValueError("Not enough Nvidia GPU: requested %d, but can get only %d" % (n_gpus, len(cuda_gpus))) # Sort GPUs by computing capabilities, pick the "best" ones gpus_cc = [] for gpu_id, gpu in cuda_gpus.items(): cc = gpu["compute_capability"] gpus_cc.append((gpu_id, cc[0] + 0.1 * cc[1])) gpus_cc_sorted = sorted(gpus_cc, key=lambda x: x[1], reverse=True) res = [] for i in range(n_gpus): res.append(cuda_gpus[gpus_cc_sorted[i][0]]) return res ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/resources/logger.py0000644000175000017500000000725614402565210017066 0ustar00pierrepierreimport logging import logging.config class Logger(object): def __init__(self, loggername, level="DEBUG", logfile="logger.log", console=True): """ Configure a Logger object. Parameters ----------- loggername: str Logger name. level: str, optional Logging level. Can be "debug", "info", "warning", "error", "critical". Default is "debug". logfile: str, optional File where the logs are written. If set to None, the logs are not written in a file. Default is "logger.log". console: bool, optional If set to True, the logs are (also) written in stdout/stderr. Default is True. 
""" self.loggername = loggername self.level = level self.logfile = logfile self.console = console self._configure_logger() def _configure_logger(self): conf = self._get_default_config_dict() for handler in conf["handlers"].keys(): conf["handlers"][handler]["level"] = self.level.upper() conf["loggers"][self.loggername]["level"] = self.level.upper() if not (self.console): conf["loggers"][self.loggername]["handlers"].remove("console") self.config = conf logging.config.dictConfig(conf) self.logger = logging.getLogger(self.loggername) def _get_default_config_dict(self): conf = { "version": 1, "formatters": { "default": {"format": "%(asctime)s - %(levelname)s - %(message)s", "datefmt": "%d-%m-%Y %H:%M:%S"}, "console": {"format": "%(message)s"}, }, "handlers": { "console": { "level": "DEBUG", "class": "logging.StreamHandler", "formatter": "console", "stream": "ext://sys.stdout", }, "file": { "level": "DEBUG", "class": "logging.FileHandler", "formatter": "default", "filename": self.logfile, }, }, "loggers": { self.loggername: { "level": "DEBUG", "handlers": ["console", "file"], # This logger inherits from root logger - dont propagate ! "propagate": False, } }, "disable_existing_loggers": False, } return conf def info(self, msg): return self.logger.info(msg) def debug(self, msg): return self.logger.debug(msg) def warning(self, msg): return self.logger.warning(msg) warn = warning def error(self, msg): return self.logger.error(msg) def fatal(self, msg): return self.logger.fatal(msg) def critical(self, msg): return self.logger.critical(msg) def LoggerOrPrint(logger): """ Logger that is either a "true" logger object, or a fall-back to "print". """ if logger is None: return PrinterLogger() return logger class PrinterLogger(object): def __init__(self): methods = [ "debug", "warn", "warning", "info", "error", "fatal", "critical", ] for method in methods: self.__setattr__(method, print) LogLevel = { "notset": logging.NOTSET, "debug": logging.DEBUG, "info": logging.INFO, "warn": logging.WARN, "warning": logging.WARNING, "error": logging.ERROR, "critical": logging.CRITICAL, "fatal": logging.FATAL, } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734442905.0 nabu-2024.2.1/nabu/resources/nxflatfield.py0000644000175000017500000002215414730277631020114 0ustar00pierrepierreimport os import numpy as np from nxtomo.io import HDF5File from silx.io.url import DataUrl from silx.io import get_data from tomoscan.framereducer.reducedframesinfos import ReducedFramesInfos from tomoscan.esrf.scan.nxtomoscan import NXtomoScan from ..utils import check_supported, is_writeable def get_frame_possible_urls(dataset_info, user_dir, output_dir): """ Return a dict with the possible location of reduced dark/flat frames. Parameters ---------- dataset_info: DatasetAnalyzer object DatasetAnalyzer object: data structure containing information on the parsed dataset user_dir: str or None User-provided directory location for the reduced frames. output_dir: str or None Output processing directory """ frame_types = ["flats", "darks"] h5scan = dataset_info.dataset_scanner # tomoscan object def make_dataurl(dirname, frame_type): """ The template formatting should be done by tomoscan in principle, but this complicates logging. 
""" if frame_type == "flats": dataurl_default_template = h5scan.REDUCED_FLATS_DATAURLS[0] else: dataurl_default_template = h5scan.REDUCED_DARKS_DATAURLS[0] rel_file_path = dataurl_default_template.file_path().format(scan_prefix=h5scan.get_dataset_basename()) return DataUrl( file_path=os.path.join(dirname, rel_file_path), data_path=dataurl_default_template.data_path().format(entry=h5scan.entry, index="{index}"), data_slice=dataurl_default_template.data_slice(), # not sure if needed scheme="silx", ) urls = {"user": None, "dataset": None, "output": None} if user_dir is not None: urls["user"] = {frame_type: make_dataurl(user_dir, frame_type) for frame_type in frame_types} # tomoscan.esrf.scan.hdf5scan.REDUCED_{DARKS|FLATS}_DATAURLS.file_path() is a relative path # Create a absolute path instead urls["dataset"] = { frame_type: make_dataurl(os.path.dirname(h5scan.master_file), frame_type) for frame_type in frame_types } if output_dir is not None: urls["output"] = {frame_type: make_dataurl(output_dir, frame_type) for frame_type in frame_types} return urls def save_reduced_frames(dataset_info, reduced_frames_arrays, reduced_frames_urls): reduce_func = {"flats": np.median, "darks": np.mean} # TODO configurable ? # Get "where to write". tomoscan expects a DataUrl darks_flats_dir_url = reduced_frames_urls.get("user", None) if darks_flats_dir_url is not None: output_url = darks_flats_dir_url elif is_writeable(os.path.dirname(reduced_frames_urls["dataset"]["flats"].file_path())): output_url = reduced_frames_urls["dataset"] else: output_url = reduced_frames_urls["output"] # Get the "ReducedFrameInfos" data structure expected by tomoscan def _get_additional_info(frame_type): electric_current = dataset_info.dataset_scanner.electric_current count_time = dataset_info.dataset_scanner.count_time if electric_current is not None: electric_current = { sl.start: reduce_func[frame_type](electric_current[sl]) for sl in dataset_info.frames_slices(frame_type) } electric_current = [electric_current[k] for k in sorted(electric_current.keys())] if count_time is not None: count_time = { sl.start: reduce_func[frame_type](count_time[sl]) for sl in dataset_info.frames_slices(frame_type) } count_time = [count_time[k] for k in sorted(count_time.keys())] info = ReducedFramesInfos() info.count_time = count_time info.machine_electric_current = electric_current return info flats_info = _get_additional_info("flats") darks_info = _get_additional_info("darks") # Call tomoscan to save the reduced frames dataset_info.dataset_scanner.save_reduced_darks( reduced_frames_arrays["darks"], output_urls=[output_url["darks"]], darks_infos=darks_info, metadata_output_urls=[get_metadata_url(output_url["darks"], "darks")], overwrite=True, ) dataset_info.dataset_scanner.save_reduced_flats( reduced_frames_arrays["flats"], output_urls=[output_url["flats"]], flats_infos=flats_info, metadata_output_urls=[get_metadata_url(output_url["flats"], "flats")], overwrite=True, ) dataset_info.logger.info("Saved reduced darks/flats to %s" % output_url["flats"].file_path()) return output_url, flats_info, darks_info def get_metadata_url(url, frame_type): """ Return the url of the metadata stored alongside flats/darks """ check_supported(frame_type, ["flats", "darks"], "frame type") template_url = getattr(NXtomoScan, "REDUCED_%s_METADATAURLS" % frame_type.upper())[0] return DataUrl( file_path=url.file_path(), data_path=template_url.data_path(), scheme="silx", ) def tomoscan_load_reduced_frames(dataset_info, frame_type, url): tomoscan_method = 
getattr(dataset_info.dataset_scanner, "load_reduced_%s" % frame_type) return tomoscan_method( inputs_urls=[url], return_as_url=True, return_info=True, metadata_input_urls=[get_metadata_url(url, frame_type)], ) def data_url_exists(data_url): """ Return true iff the file exists and the data URL is valid (i.e data/group is actually in the file) """ if not (os.path.isfile(data_url.file_path())): return False group_exists = False with HDF5File(data_url.file_path(), "r") as f: data_path_without_index = data_url.data_path().split("{")[0] group_exists = f.get(data_path_without_index, default=None) is not None return group_exists # pylint: disable=E1136 def update_dataset_info_flats_darks(dataset_info, flatfield_mode, output_dir=None, darks_flats_dir=None): """ Update a DatasetAnalyzer object with reduced flats/darks (hereafter "reduced frames"). How the reduced frames are loaded/computed/saved will depend on the "flatfield_mode" parameter. The principle is the following: (1) Attempt at loading already-computed reduced frames (XXX_darks.h5 and XXX_flats.h5): - First check files in the user-defined directory 'darks_flats_dir' - Then try to load from files located alongside the .nx dataset (dataset directory) - Then try to load from output_dir, if provided (2) If loading fails, or flatfield_mode == "force_compute", compute the reduced frames. (3) Save these reduced frames - Save in darks_flats_dir, if provided by user - Otherwise, save in the data directory (next to the .nx file), if write access OK - Otherwise, save in output directory """ if flatfield_mode is False: return frames_types = ["darks", "flats"] reduced_frames_urls = get_frame_possible_urls(dataset_info, darks_flats_dir, output_dir) def _compute_and_save_reduced_frames(): try: dataset_info.flats = dataset_info.get_reduced_flats() dataset_info.darks = dataset_info.get_reduced_darks() except FileNotFoundError: msg = "Could not find any flats and/or darks" raise FileNotFoundError(msg) _, flats_info, darks_info = save_reduced_frames( dataset_info, {"darks": dataset_info.darks, "flats": dataset_info.flats}, reduced_frames_urls ) dataset_info.flats_srcurrent = flats_info.machine_electric_current if flatfield_mode == "force-compute": _compute_and_save_reduced_frames() return def _can_load_from(folder_type): if reduced_frames_urls.get(folder_type, None) is None: return False return all([data_url_exists(reduced_frames_urls[folder_type][frame_type]) for frame_type in frames_types]) where_to_load_from = None if reduced_frames_urls["user"] is not None and _can_load_from("user"): where_to_load_from = "user" elif _can_load_from("dataset"): where_to_load_from = "dataset" elif _can_load_from("output"): where_to_load_from = "output" if where_to_load_from == None and flatfield_mode == "force-load": raise ValueError("Could not load darks/flats (using 'force-load')") if where_to_load_from is not None: reduced_frames_with_info = {} for frame_type in frames_types: reduced_frames_with_info[frame_type] = tomoscan_load_reduced_frames( dataset_info, frame_type, reduced_frames_urls[where_to_load_from][frame_type] ) dataset_info.logger.info( "Loaded %s from %s" % (frame_type, reduced_frames_urls[where_to_load_from][frame_type].file_path()) ) red_frames_dict, red_frames_info = reduced_frames_with_info[frame_type] setattr( dataset_info, frame_type, {k: get_data(red_frames_dict[k]) for k in red_frames_dict.keys()}, ) if frame_type == "flats": dataset_info.flats_srcurrent = red_frames_info.machine_electric_current else: _compute_and_save_reduced_frames() 
././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1734442985.520757 nabu-2024.2.1/nabu/resources/templates/0000755000175000017500000000000014730277752017240 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/resources/templates/__init__.py0000644000175000017500000000000014402565210021320 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/resources/templates/bm05_pag.conf0000644000175000017500000000061214402565210021461 0ustar00pierrepierre# # ESRF BM05 phase contrast # # # Write here your custom configuration as a python dictionary. # Any parameter not present will take the default value. # [preproc] flatfield = 1 take_logarithm = 1 [phase] method = paganin delta_beta = 100 [reconstruction] rotation_axis_position = sliding-window enable_halftomo = auto clip_outer_circle = 1 centered_axis = 1 [output] location = ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/resources/templates/id16_ctf.conf0000644000175000017500000000151214402565210021466 0ustar00pierrepierre# # ESRF ID16 single-distance CTF # # # Write here your custom configuration as a python dictionary. # Any parameter not present will take the default value. # [preproc] flatfield = 1 flat_distortion_correction_enabled = 0 flat_distortion_params = tile_size=100; interpolation_kind='linear'; padding_mode='edge'; correction_spike_threshold=None take_logarithm = 0 ccd_filter_enabled = 1 ccd_filter_threshold = 0.04 [phase] method = ctf delta_beta = 80 ctf_geometry = z1_v=None; z1_h=None; detec_pixel_size=None; magnification=True ctf_advanced_params = length_scale=1e-5; lim1=1e-5; lim2=0.2; normalize_by_mean=True [reconstruction] rotation_axis_position = global cor_options = translation_movements_file = enable_halftomo = 0 clip_outer_circle = 1 centered_axis = 1 [output] location = file_format = tiff tiff_single_file = 1 ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/resources/templates/id16_holo.conf0000644000175000017500000000140514402565210021654 0ustar00pierrepierre# # ESRF ID16 holo-tomography # # # Write here your custom configuration as a python dictionary. # Any parameter not present will take the default value. # [preproc] flatfield = 0 flat_distortion_correction_enabled = 0 flat_distortion_params = tile_size=100; interpolation_kind='linear'; padding_mode='edge'; correction_spike_threshold=None take_logarithm = 0 [phase] method = none ctf_geometry = z1_v=None; z1_h=None; detec_pixel_size=None; magnification=True ctf_advanced_params = length_scale=1e-5; lim1=1e-5; lim2=0.2; normalize_by_mean=True [reconstruction] rotation_axis_position = global cor_options = translation_movements_file = enable_halftomo = 0 clip_outer_circle = 1 centered_axis = 1 [output] location = file_format = tiff tiff_single_file = 1 ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1732264041.0 nabu-2024.2.1/nabu/resources/templates/id16a_fluo.conf0000644000175000017500000000103614720040151022014 0ustar00pierrepierre# # ESRF ID16a fluo-tomography # # # Write here your custom configuration as a python dictionary. # Any parameter not present will take the default value. 
# [dataset] hdf5_entry = all [preproc] flatfield = 0 flat_distortion_correction_enabled = 0 take_logarithm = 0 [phase] method = none [reconstruction] method = mlem rotation_axis_position = 0. cor_options = translation_movements_file = enable_halftomo = 0 clip_outer_circle = 1 centered_axis = 1 iterations = 200 [output] location = file_format = tiff tiff_single_file = 1 ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/resources/templates/id19_pag.conf0000644000175000017500000000060014402565210021461 0ustar00pierrepierre# # ESRF ID19 phase contrast # # # Write here your custom configuration as a python dictionary. # Any parameter not present will take the default value. # [preproc] flatfield = 1 take_logarithm = 1 [phase] method = paganin delta_beta = 100 [reconstruction] rotation_axis_position = sliding-window translation_movements_file = enable_halftomo = auto [output] location = ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1734442985.520757 nabu-2024.2.1/nabu/resources/tests/0000755000175000017500000000000014730277752016404 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/resources/tests/__init__.py0000644000175000017500000000000014315516747020502 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731055828.0 nabu-2024.2.1/nabu/resources/tests/test_extract.py0000644000175000017500000000033414713350324021453 0ustar00pierrepierreimport pytest from nabu.utils import list_match_queries def test_list_match_queries(): # entry0000 .... entry0099 avail = ["entry%04d" % i for i in range(100)] query = "entry0000" list_match_queries() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1723556968.0 nabu-2024.2.1/nabu/resources/tests/test_nxflatfield.py0000644000175000017500000001012314656662150022307 0ustar00pierrepierrefrom os import path from tempfile import mkdtemp from shutil import rmtree import pytest import numpy as np from silx.io import get_data from nxtomo.nxobject.nxdetector import ImageKey from nabu.testutils import generate_nx_dataset from nabu.resources.nxflatfield import update_dataset_info_flats_darks from nabu.resources.dataset_analyzer import HDF5DatasetAnalyzer test_nxflatfield_scenarios = [ { "name": "simple", "flats_pos": [slice(1, 6)], "darks_pos": [slice(0, 1)], "output_dir": None, }, { "name": "simple_with_save", "flats_pos": [slice(1, 6)], "darks_pos": [slice(0, 1)], "output_dir": None, }, { "name": "multiple_with_save", "flats_pos": [slice(0, 10), slice(30, 40)], "darks_pos": [slice(95, 100)], "output_dir": path.join("{tempdir}", "output_reduced"), }, ] # parametrize with fixture and "params=" will launch a new class for each scenario. # the attributes set to "cls" will remain for all the tests done in this class # with the current scenario. 
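# --- Illustration (hypothetical names, not used by the tests below) of the pattern
# described in the comment above: a class-scoped parametrized fixture runs the test
# class once per scenario, and anything stored on `request.cls` is visible as an
# attribute of `self` in every test method for that scenario.
#
#     _example_scenarios = [{"name": "a"}, {"name": "b"}]
#
#     @pytest.fixture(scope="class", params=_example_scenarios)
#     def _example_bootstrap(request):
#         request.cls.params = request.param  # becomes `self.params` in the tests
#         yield                               # teardown code would go after the yield
#
#     @pytest.mark.usefixtures("_example_bootstrap")
#     class TestExample:
#         def test_scenario_has_a_name(self):
#             assert "name" in self.params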
@pytest.fixture(scope="class", params=test_nxflatfield_scenarios) def bootstrap(request): cls = request.cls cls.n_projs = 265 cls.params = request.param cls.tempdir = mkdtemp(prefix="nabu_") yield rmtree(cls.tempdir) @pytest.mark.usefixtures("bootstrap") class TestNXFlatField: _reduction_func = {"flats": np.median, "darks": np.mean} def get_nx_filename(self): return path.join(self.tempdir, "dataset_" + self.params["name"] + ".nx") def get_image_key(self): keys = np.zeros(self.n_projs, np.int32) for what, val in [("flats_pos", ImageKey.FLAT_FIELD.value), ("darks_pos", ImageKey.DARK_FIELD.value)]: for pos in self.params[what]: keys[pos.start : pos.stop] = val return keys @staticmethod def check_image_key(dataset_info, frame_type, expected_slices): data_slices = dataset_info.get_data_slices(frame_type) assert set(map(str, data_slices)) == set(map(str, expected_slices)) def test_nxflatfield(self): dataset_fname = self.get_nx_filename() image_key = self.get_image_key() generate_nx_dataset(dataset_fname, image_key) dataset_info = HDF5DatasetAnalyzer(dataset_fname) # When parsing a "raw" dataset, flats/darks are a series of images. # dataset_info.flats is a dictionary where each key is the index of the frame. # For example dataset_info.flats.keys() = [10, 11, 12, 13, ..., 19, 1200, 1201, 1202, ..., 1219] for frame_type in ["darks", "flats"]: self.check_image_key(dataset_info, frame_type, self.params[frame_type + "_pos"]) output_dir = self.params.get("output_dir", None) if output_dir is not None: output_dir = output_dir.format(tempdir=self.tempdir) update_dataset_info_flats_darks(dataset_info, True, output_dir=output_dir) # After reduction (median/mean), the flats/darks are located in another file. # median(series_1) goes to entry/flats/idx1, mean(series_2) goes to entry/flats/idx2, etc. assert set(dataset_info.flats.keys()) == set(s.start for s in self.params["flats_pos"]) assert set(dataset_info.darks.keys()) == set(s.start for s in self.params["darks_pos"]) # Check that the computations were correct # Loads the entire volume in memory ! 
So keep the data volume small for the tests data_volume = get_data(dataset_info.dataset_hdf5_url) expected_flats = {} for s in self.params["flats_pos"]: expected_flats[s.start] = self._reduction_func["flats"](data_volume[s.start : s.stop], axis=0) expected_darks = {} for s in self.params["darks_pos"]: expected_darks[s.start] = self._reduction_func["darks"](data_volume[s.start : s.stop], axis=0) flats = dataset_info.flats for idx in flats.keys(): assert np.allclose(flats[idx], expected_flats[idx]) darks = dataset_info.darks for idx in darks.keys(): assert np.allclose(darks[idx], expected_darks[idx]) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/resources/tests/test_units.py0000644000175000017500000000415314402565210021143 0ustar00pierrepierreimport pytest from nabu.utils import compare_dicts from nabu.resources.utils import get_quantities_and_units class TestUnits: expected_results = { "distance = 1 m ; pixel_size = 2.0 um": {"distance": 1.0, "pixel_size": 2e-6}, "distance = 1 m ; pixel_size = 2.6 micrometer": {"distance": 1.0, "pixel_size": 2.6e-6}, "distance = 10 m ; pixel_size = 2e-6 m": {"distance": 10, "pixel_size": 2e-6}, "distance = .5 m ; pixel_size = 2.6e-4 centimeter": {"distance": 0.5, "pixel_size": 2.6e-6}, "distance = 10 cm ; pixel_size = 2.5 micrometer ; energy = 1 ev": { "distance": 0.1, "pixel_size": 2.5e-6, "energy": 1.0e-3, }, "distance = 10 cm ; pixel_size = 9.0e-3 millimeter ; energy = 19 kev": { "distance": 0.1, "pixel_size": 9e-6, "energy": 19.0, }, } expected_failures = { # typo ("ke" instead of "kev") "distance = 10 cm ; energy = 10 ke": ValueError("Cannot convert: ke"), # No units "distance = 10 ; energy = 10 kev": ValueError("not enough values to unpack (expected 2, got 1)"), # Unit not separated by space "distance = 10m; energy = 10 kev": ValueError("not enough values to unpack (expected 2, got 1)"), # Invalid separator "distance = 10 m, energy = 10 kev": ValueError("too many values to unpack (expected 2)"), } def test_conversion(self): for test_str, expected_result in self.expected_results.items(): res = get_quantities_and_units(test_str) err_msg = str( "Something wrong with quantities/units extraction from '%s': expected %s, got %s" % (test_str, str(expected_result), str(res)) ) assert compare_dicts(res, expected_result) is None, err_msg def test_valid_input(self): for test_str, expected_failure in self.expected_failures.items(): with pytest.raises(type(expected_failure)) as e_info: get_quantities_and_units(test_str) assert e_info.value.args[0] == str(expected_failure) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/resources/utils.py0000644000175000017500000001313514550227307016746 0ustar00pierrepierrefrom ast import literal_eval import numpy as np from psutil import virtual_memory, cpu_count from pyunitsystem.metricsystem import MetricSystem from pyunitsystem.energysystem import EnergySI def get_values_from_file(fname, n_values=None, shape=None, sep=None, any_size=False): """ Read a text file and scan the values inside. This function expects one value per line, or values separated with a separator defined with the `sep` parameter. Parameters ---------- fname: str Path of the text file n_values: int, optional If set to a value, this function will check that it scans exactly this number of values. Ignored if `shape` is provided shape: tuple, optional Generalization of n_values for higher dimensions. 
sep: str, optional Separator between values. Default is white space any_size: bool, optional If set to True, then the parameters 'n_values' and 'shape' are ignored. Returns -------- arr: numpy.ndarray An array containing the values scanned from the text file """ if not (any_size) and not ((n_values is not None) ^ (shape is not None)): raise ValueError("Please provide either n_values or shape") arr = np.loadtxt(fname, ndmin=1) if (n_values is not None) and (arr.shape[0] != n_values): if any_size: arr = arr[:n_values] else: raise ValueError("Expected %d values, but could get %d values" % (n_values, arr.shape[0])) if (shape is not None) and (arr.shape != shape): if any_size: arr = arr[: shape[0], : shape[1]] # TODO handle more dimensions ? else: raise ValueError("Expected shape %s, but got shape %s" % (shape, arr.shape)) return arr def get_memory_per_node(max_mem, is_percentage=True): """ Get the available memory per node in GB. Parameters ---------- max_mem: float If is_percentage is False, then number is interpreted as an absolute number in GigaBytes. Otherwise, it should be a number between 0 and 100 and is interpreted as a percentage. is_percentage: bool A boolean indicating whether the parameter max_mem is to be interpreted as a percentage of available system memory. """ sys_avail_mem = virtual_memory().available / 1e9 if is_percentage: return (max_mem / 100.0) * sys_avail_mem else: return min(max_mem, sys_avail_mem) def get_threads_per_node(max_threads, is_percentage=True): """ Get the available memory per node in GB. Parameters ---------- max_threads: float If is_percentage is False, then number is interpreted as an absolute number of threads. Otherwise, it should be a number between 0 and 100 and is interpreted as a percentage. is_percentage: bool A boolean indicating whether the parameter max_threads is to be interpreted as a percentage of available system memory. """ sys_n_threads = cpu_count(logical=True) if is_percentage: return (max_threads / 100.0) * sys_n_threads else: return min(max_threads, sys_n_threads) def extract_parameters(params_str, sep=";"): """ Extract the named parameters from a string. Example -------- The function can be used as follows: >>> extract_parameters("window_width=None; median_filt_shape=(3,3); padding_mode='wrap'") ... {'window_width': None, 'median_filt_shape': (3, 3), 'padding_mode': 'wrap'} """ if params_str in ("", None): return {} params_list = params_str.strip(sep).split(sep) res = {} for param_str in params_list: param_name, param_val_str = param_str.strip().split("=") param_name = param_name.strip() param_val_str = param_val_str.strip() param_val = literal_eval(param_val_str) res[param_name] = param_val return res def compact_parameters(params_dict, sep=";"): """ Compact the parameters from a dict into a string. This is the inverse of extract_parameters. It can be used for example in tomwer to convert parameters into a string, for example for cor_options, prior to calling a nabu method which is expecting an argument to be in the form of a string containing options. Example -------- The function can be used as follows: >>> compact_parameters( {"side":"near", "near_pos":300 } ) ... "side=near; nearpos= 300;" """ if params_dict in ({}, None): return "" res = "" for key, val in params_dict.items(): res = res + "{key} = {val} " + sep return res def is_hdf5_extension(ext): return ext.lower() in ["h5", "hdf5", "nx"] def get_quantities_and_units(string, sep=";"): """ Return a dictionary with quantities as keys, and values in SI. 
Example ------- get_quantities_and_units("pixel_size = 1.2 um ; distance = 1 m") Will return {"pixel_size": 1.2e-6, "distance": 1} """ result = {} quantities = string.split(sep) for quantity in quantities: quantity_name, value_and_unit = quantity.split("=") quantity_name = quantity_name.strip() value_and_unit = value_and_unit.strip() value, unit = value_and_unit.split() val = float(value) # Convert to SI try: # handle metrics conversion_factor = MetricSystem.from_str(unit).value except ValueError: # handle energies conversion_factor = EnergySI.from_str(unit).value / EnergySI.KILOELECTRONVOLT.value result[quantity_name] = val * conversion_factor return result ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1734442985.524757 nabu-2024.2.1/nabu/stitching/0000755000175000017500000000000014730277752015224 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678437000.0 nabu-2024.2.1/nabu/stitching/__init__.py0000644000175000017500000000000014402565210017304 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/stitching/alignment.py0000644000175000017500000002156614654107202017551 0ustar00pierrepierreimport h5py import numpy from typing import Union from silx.utils.enum import Enum as _Enum from tomoscan.volumebase import VolumeBase from tomoscan.esrf.volume.hdf5volume import HDF5Volume from nabu.io.utils import DatasetReader class AlignmentAxis2(_Enum): """Specific alignment named to help users orienting themself with specific name""" CENTER = "center" LEFT = "left" RIGTH = "right" class AlignmentAxis1(_Enum): """Specific alignment named to help users orienting themself with specific name""" FRONT = "front" CENTER = "center" BACK = "back" class _Alignment(_Enum): """Internal alignment to be used for 2D alignment""" LOWER_BOUNDARY = "lower boundary" HIGHER_BOUNDARY = "higher boundary" CENTER = "center" @classmethod def from_value(cls, value): # cast the AlignmentAxis1 and AlignmentAxis2 values to fit the generic definition if value in ("front", "left", AlignmentAxis1.FRONT, AlignmentAxis2.LEFT): return _Alignment.LOWER_BOUNDARY elif value in ("back", "right", AlignmentAxis1.BACK, AlignmentAxis2.RIGTH): return _Alignment.HIGHER_BOUNDARY elif value in (AlignmentAxis1.CENTER, AlignmentAxis2.CENTER): return _Alignment.CENTER else: return super().from_value(value) def align_frame( data: numpy.ndarray, alignment: _Alignment, alignment_axis: int, new_aligned_axis_size: int, pad_mode="constant" ): """ Align 2D array to extend if size along `alignment_axis` to `new_aligned_axis_size`. :param numpy.ndarray data: data (frame) to align (2D numpy array) :param alignment_axis: axis along which we want to align the frame. Must be in (0, 1) :param HAlignment alignment: alignment strategy :param int new_width: output data width """ if alignment_axis not in (0, 1): raise ValueError(f"alignment_axis should be in (0, 1). Get {alignment_axis}") alignment = _Alignment.from_value(alignment) aligned_axis_size = data.shape[alignment_axis] if aligned_axis_size > new_aligned_axis_size: raise ValueError( f"data.shape[alignment_axis] ({data.shape[alignment_axis]}) > new_aligned_axis_size ({new_aligned_axis_size}). 
Unable to crop data" ) elif aligned_axis_size == new_aligned_axis_size: return data else: if alignment is _Alignment.CENTER: lower_boundary = (new_aligned_axis_size - aligned_axis_size) // 2 higher_boundary = (new_aligned_axis_size - aligned_axis_size) - lower_boundary elif alignment is _Alignment.LOWER_BOUNDARY: lower_boundary = 0 higher_boundary = new_aligned_axis_size - aligned_axis_size elif alignment is _Alignment.HIGHER_BOUNDARY: lower_boundary = new_aligned_axis_size - aligned_axis_size higher_boundary = 0 else: raise ValueError(f"alignment {alignment.value} is not handled") assert lower_boundary >= 0, f"pad size must be positive - lower boundary isn't ({lower_boundary})" assert higher_boundary >= 0, f"pad size must be positive - higher boundary isn't ({higher_boundary})" if alignment_axis == 1: return numpy.pad( data, pad_width=((0, 0), (lower_boundary, higher_boundary)), mode=pad_mode, ) elif alignment_axis == 0: return numpy.pad( data, pad_width=((lower_boundary, higher_boundary), (0, 0)), mode=pad_mode, ) else: raise ValueError("alignment_axis should be in (0, 1)") def align_horizontally(data: numpy.ndarray, alignment: AlignmentAxis2, new_width: int, pad_mode="constant"): """ Align data horizontally to make sure the new data width will be `new_width`. :param numpy.ndarray data: data to align :param HAlignment alignment: alignment strategy :param int new_width: output data width """ alignment = AlignmentAxis2.from_value(alignment).value return align_frame( data=data, alignment=alignment, new_aligned_axis_size=new_width, pad_mode=pad_mode, alignment_axis=1 ) class PaddedRawData: """ Utility class to extend data when necessary. Meant to apply to a volume or to an hdf5 dataset / array. The idea behind it is to avoid loading all the data in memory. """ def __init__(self, data: Union[numpy.ndarray, h5py.Dataset], axis_1_pad_width: tuple) -> None: self._axis_1_pad_width = numpy.array(axis_1_pad_width) if not (self._axis_1_pad_width.size == 2 and self._axis_1_pad_width[0] >= 0 and self._axis_1_pad_width[1] >= 0): raise ValueError(f"'axis_1_pad_width' expects two positive elements. Get {axis_1_pad_width}") self._raw_data = data self._raw_data_end = None # note: for now we return only frames with zeros for padded frames. # in the future we could imagine having a method to mirror the existing volume, or extend the closest frame, or use a mean value...
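# Padding semantics, for illustration (hypothetical shapes): with raw data of shape
# (nz, ny, nx) and axis_1_pad_width=(2, 3), the padded object has shape (nz, ny + 5, nx);
# padded[:, 0:2, :] and padded[:, 2 + ny:2 + ny + 3, :] return zero frames, while
# padded[:, 2:2 + ny, :] maps back to the underlying raw data (only full slices along
# axes 0 and 2, positive indices and a step of 1 are supported).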
self._empty_frame = None self._dtype = None self._shape = None self._raw_data_shape = self.raw_data.shape @staticmethod def get_empty_frame(shape, dtype): return numpy.zeros( shape=shape, dtype=dtype, ) @property def empty_frame(self): if self._empty_frame is None: self._empty_frame = self.get_empty_frame( shape=(self.shape[0], 1, self.shape[2]), dtype=self.dtype, ) return self._empty_frame @property def shape(self): if self._shape is None: self._shape = tuple( ( self._raw_data_shape[0], numpy.sum( numpy.array(self._axis_1_pad_width), ) + self._raw_data_shape[1], self._raw_data_shape[2], ) ) return self._shape @property def raw_data(self): return self._raw_data @property def raw_data_start(self): return self._axis_1_pad_width[0] @property def raw_data_end(self): if self._raw_data_end is None: self._raw_data_end = self._axis_1_pad_width[0] + self._raw_data_shape[1] return self._raw_data_end @property def dtype(self): if self._dtype is None: self._dtype = self.raw_data.dtype return self._dtype def __getitem__(self, args): if not isinstance(args, tuple) and len(args) == 3: raise ValueError("only handles 3D slicing") elif not (args[0] == slice(None, None, None) and args[2] == slice(None, None, None)): raise ValueError( "slicing only handled along axis 1. First and third tuple item are expected to be empty slice as slice(None, None, None)" ) else: if numpy.isscalar(args[1]): args = ( args[0], slice(args[1], args[1] + 1, 1), args[2], ) start = args[1].start if start is None: start = 0 stop = args[1].stop if stop is None: stop = self.shape[1] step = args[1].step # some test if start < 0 or stop < 0: raise ValueError("only positive position are handled") if start >= stop: raise ValueError("start >= stop") if stop > self.shape[1]: raise ValueError("stop > self.shape[1]") if step not in (1, None): raise ValueError("for now PaddedVolume only handles steps of 1") first_part_array = None if start < self.raw_data_start and (stop - start > 0): stop_first_part = min(stop, self.raw_data_start) first_part_array = numpy.repeat(self.empty_frame, repeats=stop_first_part - start, axis=1) start = stop_first_part third_part_array = None if stop > self.raw_data_end and (stop - start > 0): if stop > self.shape[1]: raise ValueError("requested slice is out of boundaries") start_third_part = max(start, self.raw_data_end) third_part_array = numpy.repeat(self.empty_frame, repeats=stop - start_third_part, axis=1) stop = self.raw_data_end if start >= self.raw_data_start and stop >= self.raw_data_start and (stop - start > 0): second_part_array = self.raw_data[:, start - self.raw_data_start : stop - self.raw_data_start, :] else: second_part_array = None parts = tuple(filter(lambda a: a is not None, (first_part_array, second_part_array, third_part_array))) return numpy.hstack( parts, ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731053186.0 nabu-2024.2.1/nabu/stitching/config.py0000644000175000017500000014702314713343202017033 0ustar00pierrepierrefrom math import ceil from typing import Any, Iterable, Optional, Union, Sized from dataclasses import dataclass import numpy from pyunitsystem.metricsystem import MetricSystem from nxtomo.paths import nxtomo from tomoscan.factory import Factory from tomoscan.identifier import VolumeIdentifier, ScanIdentifier from tomoscan.esrf.scan.nxtomoscan import NXtomoScan from ..pipeline.config_validators import ( boolean_validator, convert_to_bool, ) from ..utils import concatenate_dict, convert_str_to_tuple from ..io.utils import get_output_volume from 
.overlap import OverlapStitchingStrategy from .utils.utils import ShiftAlgorithm from .definitions import StitchingType from .alignment import AlignmentAxis1, AlignmentAxis2 from pyunitsystem.metricsystem import MetricSystem KEY_IMG_REG_METHOD = "img_reg_method" KEY_WINDOW_SIZE = "window_size" KEY_LOW_PASS_FILTER = "low_pass" KEY_HIGH_PASS_FILTER = "high_pass" KEY_OVERLAP_SIZE = "overlap_size" KEY_SIDE = "side" OUTPUT_SECTION = "output" INPUTS_SECTION = "inputs" PRE_PROC_SECTION = "preproc" POST_PROC_SECTION = "postproc" INPUT_DATASETS_FIELD = "input_datasets" INPUT_PIXEL_SIZE_MM = "pixel_size" INPUT_VOXEL_SIZE_MM = "voxel_size" STITCHING_SECTION = "stitching" STITCHING_STRATEGY_FIELD = "stitching_strategy" STITCHING_TYPE_FIELD = "type" DATA_FILE_FIELD = "location" OVERWRITE_RESULTS_FIELD = "overwrite_results" DATA_PATH_FIELD = "data_path" AXIS_0_POS_PX = "axis_0_pos_px" AXIS_1_POS_PX = "axis_1_pos_px" AXIS_2_POS_PX = "axis_2_pos_px" AXIS_0_POS_MM = "axis_0_pos_mm" AXIS_1_POS_MM = "axis_1_pos_mm" AXIS_2_POS_MM = "axis_2_pos_mm" AXIS_0_PARAMS = "axis_0_params" AXIS_1_PARAMS = "axis_1_params" AXIS_2_PARAMS = "axis_2_params" FLIP_LR = "fliplr" FLIP_UD = "flipud" NEXUS_VERSION_FIELD = "nexus_version" OUTPUT_DTYPE = "data_type" OUTPUT_VOLUME = "output_volume" STITCHING_SLICES = "slices" CROSS_CORRELATION_SLICE_FIELD = "slice_index_for_correlation" RESCALE_FRAMES = "rescale_frames" RESCALE_PARAMS = "rescale_params" KEY_RESCALE_MIN_PERCENTILES = "rescale_min_percentile" KEY_RESCALE_MAX_PERCENTILES = "rescale_max_percentile" ALIGNMENT_AXIS_2_FIELD = "alignment_axis_2" ALIGNMENT_AXIS_1_FIELD = "alignment_axis_1" PAD_MODE_FIELD = "pad_mode" AVOID_DATA_DUPLICATION_FIELD = "avoid_data_duplication" # SLURM SLURM_SECTION = "slurm" SLURM_PARTITION = "partition" SLURM_MEM = "memory" SLURM_COR_PER_TASKS = "cpu-per-task" SLURM_NUMBER_OF_TASKS = "n_tasks" SLURM_N_JOBS = "n_jobs" SLURM_OTHER_OPTIONS = "other_options" SLURM_PREPROCESSING_COMMAND = "python_venv" SLURM_MODULES_TO_LOADS = "modules" SLURM_CLEAN_SCRIPTS = "clean_scripts" # normalization by sample NORMALIZATION_BY_SAMPLE_SECTION = "normalization_by_sample" NORMALIZATION_BY_SAMPLE_ACTIVE_FIELD = "active" NORMALIZATION_BY_SAMPLE_METHOD = "method" NORMALIZATION_BY_SAMPLE_SIDE = "side" NORMALIZATION_BY_SAMPLE_MARGIN = "margin" NORMALIZATION_BY_SAMPLE_WIDTH = "width" # kernel extra options STITCHING_KERNELS_EXTRA_PARAMS = "stitching_kernels_extra_params" KEY_THRESHOLD_FREQUENCY = "threshold_frequency" CROSS_CORRELATION_METHODS_AXIS_0 = { "": "", # for display ShiftAlgorithm.NABU_FFT.value: "will call nabu `find_shift_correlate` function - shift search in fourier space", ShiftAlgorithm.SKIMAGE.value: "use scikit image `phase_cross_correlation` function in real space", ShiftAlgorithm.NONE.value: "no shift research is done. will only get shift from motor positions", } CROSS_CORRELATION_METHODS_AXIS_2 = CROSS_CORRELATION_METHODS_AXIS_0.copy() CROSS_CORRELATION_METHODS_AXIS_2.update( { ShiftAlgorithm.CENTERED.value: "a fast and simple auto-CoR method. It only works when the CoR is not far from the middle of the detector. It does not work for half-tomography.", ShiftAlgorithm.GLOBAL.value: "a slow but robust auto-CoR.", ShiftAlgorithm.GROWING_WINDOW.value: "automatically find the CoR with a sliding-and-growing window. You can tune the option with the parameter 'cor_options'.", ShiftAlgorithm.SLIDING_WINDOW.value: "semi-automatically find the CoR with a sliding window. You have to specify on which side the CoR is (left, center, right). 
Please see the 'cor_options' parameter.", ShiftAlgorithm.COMPOSITE_COARSE_TO_FINE.value: "Estimate CoR from composite multi-angle images. Only works for 360 degrees scans.", ShiftAlgorithm.SINO_COARSE_TO_FINE.value: "Estimate CoR from sinogram. Only works for 360 degrees scans.", } ) SECTIONS_COMMENTS = { STITCHING_SECTION: "section dedicated to stich parameters\n", OUTPUT_SECTION: "section dedicated to output parameters\n", INPUTS_SECTION: "section dedicated to inputs\n", SLURM_SECTION: "section didicated to slurm. If you want to run locally avoid setting 'partition or remove this section'", NORMALIZATION_BY_SAMPLE_SECTION: "section dedicated to normalization by a sample. If activate each frame can be normalized by a sample of the frame", } DEFAULT_SHIFT_ALG_AXIS_0 = "nabu-fft" DEFAULT_SHIFT_ALG_AXIS_2 = "sliding-window" _shift_algs_axis_0 = "\n + ".join( [f"{key}: {value}" for key, value in CROSS_CORRELATION_METHODS_AXIS_0.items()] ) _shift_algs_axis_2 = "\n + ".join( [f"{key}: {value}" for key, value in CROSS_CORRELATION_METHODS_AXIS_2.items()] ) HELP_SHIFT_PARAMS = f"""options for shifts algorithms as `key1=value1,key2=value2`. For now valid keys are: - {KEY_OVERLAP_SIZE}: size to apply stitching. If not provided will take the largest size possible'. - {KEY_IMG_REG_METHOD}: algorithm to use to find overlaps between the different sections. Possible values are \n * for axis 0: {_shift_algs_axis_0}\n * and for axis 2: {_shift_algs_axis_2} - {KEY_LOW_PASS_FILTER}: low pass filter value for filtering frames before shift research - {KEY_HIGH_PASS_FILTER}: high pass filter value for filtering frames before shift research""" def _str_to_dict(my_str: Union[str, dict]): """convert a string as key_1=value_2;key_2=value_2 to a dict""" if isinstance(my_str, dict): return my_str res = {} for key_value in filter(None, my_str.split(";")): key, value = key_value.split("=") res[key] = value return res def _dict_to_str(ddict: dict): return ";".join([f"{str(key)}={str(value)}" for key, value in ddict.items()]) def str_to_shifts(my_str: Optional[str]) -> Union[str, tuple]: if my_str is None: return None elif isinstance(my_str, str): my_str = my_str.replace(" ", "") my_str = my_str.lstrip("[").lstrip("(") my_str = my_str.rstrip("]").lstrip(")") if my_str == "": return None try: shift = ShiftAlgorithm.from_value(my_str) except ValueError: shifts_as_str = filter(None, my_str.replace(";", ",").split(",")) return [float(shift) for shift in shifts_as_str] else: return shift elif isinstance(my_str, (tuple, list)): return [float(shift) for shift in my_str] else: raise TypeError("Only str or tuple of str expected expected") def _valid_stitching_kernels_params(my_dict: Union[dict, str]): if isinstance(my_dict, str): my_dict = _str_to_dict(my_str=my_dict) valid_keys = (KEY_THRESHOLD_FREQUENCY, KEY_SIDE) for key in my_dict.keys(): if not key in valid_keys: raise KeyError(f"{key} is a unrecognized key") return my_dict def _valid_shifts_params(my_dict: Union[dict, str]): if isinstance(my_dict, str): my_dict = _str_to_dict(my_str=my_dict) valid_keys = ( KEY_WINDOW_SIZE, KEY_IMG_REG_METHOD, KEY_OVERLAP_SIZE, KEY_HIGH_PASS_FILTER, KEY_LOW_PASS_FILTER, KEY_SIDE, ) for key in my_dict.keys(): if not key in valid_keys: raise KeyError(f"{key} is a unrecognized key") return my_dict def _slices_to_list_or_slice(my_str: Optional[str]) -> Union[str, slice]: if my_str is None: return None if isinstance(my_str, (tuple, list)): if len(my_str) == 2: return slice(int(my_str[0]), int(my_str[1])) elif len(my_str) == 3: return 
slice(int(my_str[0]), int(my_str[1]), int(my_str[2])) else: raise ValueError("expect at most free values to define a slice") assert isinstance(my_str, str), f"wrong type. Get {my_str}, {type(my_str)}" my_str = my_str.replace(" ", "") if ":" in my_str: split_string = my_str.split(":") start = int(split_string[0]) stop = int(split_string[1]) if len(split_string) == 2: step = None elif len(split_string) == 3: step = int(split_string[2]) else: raise ValueError(f"unable to interpret `slices` parameter: {my_str}") return slice(start, stop, step) else: my_str.replace(",", ";") return list(filter(None, my_str.split(";"))) def _scalar_or_tuple_to_bool_or_tuple_of_bool(my_str: Union[bool, tuple, str], default=False): if isinstance(my_str, bool): return my_str elif isinstance(my_str, str): my_str = my_str.replace(" ", "") my_str = my_str.lstrip("(").lstrip("[") my_str = my_str.rstrip(")").lstrip("]") my_str = my_str.replace(",", ";") values = my_str.split(";") values = tuple([convert_to_bool(value)[0] for value in values]) else: values = my_str if len(values) == 0: return default elif len(values) == 1: return values[0] else: return values from nabu.stitching.sample_normalization import Method, SampleSide class NormalizationBySample: def __init__(self) -> None: self._active = False self._method = Method.MEAN self._margin = 0 self._side = SampleSide.LEFT self._width = 30 def is_active(self): return self._active def set_is_active(self, active: bool): assert isinstance( active, bool ), f"active is expected to be a bool. Get {type(active)} instead. Value == {active}" self._active = active @property def method(self) -> Method: return self._method @method.setter def method(self, method: Union[Method, str]) -> None: self._method = Method.from_value(method) @property def margin(self) -> int: return self._margin @margin.setter def margin(self, margin: int): assert isinstance(margin, int), f"margin is expected to be an int. Get {type(margin)} instead" self._margin = margin @property def side(self) -> SampleSide: return self._side @side.setter def side(self, side: Union[SampleSide, str]): self._side = SampleSide.from_value(side) @property def width(self) -> int: return self._width @width.setter def width(self, width: int): assert isinstance(width, int), f"width is expected to be an int. 
Get {type(width)} instead" @staticmethod def from_dict(my_dict: dict): sample_normalization = NormalizationBySample() # active active = my_dict.get(NORMALIZATION_BY_SAMPLE_ACTIVE_FIELD, None) if active is not None: active = active in (True, "True", 1, "1") sample_normalization.set_is_active(active) # method method = my_dict.get(NORMALIZATION_BY_SAMPLE_METHOD, None) if method is not None: sample_normalization.method = method # margin margin = my_dict.get(NORMALIZATION_BY_SAMPLE_MARGIN, None) if margin is not None: sample_normalization.margin = int(margin) # side side = my_dict.get(NORMALIZATION_BY_SAMPLE_SIDE, None) if side is not None: sample_normalization.side = side # width width = my_dict.get(NORMALIZATION_BY_SAMPLE_WIDTH, None) if width is not None: sample_normalization.width = int(width) return sample_normalization def to_dict(self) -> dict: return { NORMALIZATION_BY_SAMPLE_ACTIVE_FIELD: self.is_active(), NORMALIZATION_BY_SAMPLE_METHOD: self.method.value, NORMALIZATION_BY_SAMPLE_MARGIN: self.margin, NORMALIZATION_BY_SAMPLE_SIDE: self.side.value, NORMALIZATION_BY_SAMPLE_WIDTH: self.width, } def __eq__(self, __value: object) -> bool: if not isinstance(__value, NormalizationBySample): return False else: return self.to_dict() == __value.to_dict() @dataclass class SlurmConfig: "configuration for slurm jobs" partition: str = "" # note: must stay empty to make by default we don't use slurm (use by the configuration file) mem: str = "128" n_jobs: int = 1 other_options: str = "" preprocessing_command: str = "" modules_to_load: tuple = tuple() clean_script: bool = "" n_tasks: int = 1 n_cpu_per_task: int = 4 def __post_init__(self) -> None: # make sure either 'modules' or 'preprocessing_command' is provided if len(self.modules_to_load) > 0 and self.preprocessing_command not in (None, ""): raise ValueError( f"Either modules {SLURM_MODULES_TO_LOADS} or preprocessing_command {SLURM_PREPROCESSING_COMMAND} can be provided. Not both." ) def to_dict(self) -> dict: "dump configuration to dict" return { SLURM_PARTITION: self.partition if self.partition is not None else "", SLURM_MEM: self.mem, SLURM_N_JOBS: self.n_jobs, SLURM_OTHER_OPTIONS: self.other_options, SLURM_PREPROCESSING_COMMAND: self.preprocessing_command, SLURM_MODULES_TO_LOADS: self.modules_to_load, SLURM_CLEAN_SCRIPTS: self.clean_script, SLURM_NUMBER_OF_TASKS: self.n_tasks, SLURM_COR_PER_TASKS: self.n_cpu_per_task, } @staticmethod def from_dict(config: dict): return SlurmConfig( partition=config.get( SLURM_PARTITION, None ), # warning: never set a default value. Would generate infinite loop from slurm call mem=config.get(SLURM_MEM, "32GB"), n_jobs=int(config.get(SLURM_N_JOBS, 10)), other_options=config.get(SLURM_OTHER_OPTIONS, ""), n_tasks=config.get(SLURM_NUMBER_OF_TASKS, 1), n_cpu_per_task=config.get(SLURM_COR_PER_TASKS, 4), preprocessing_command=config.get(SLURM_PREPROCESSING_COMMAND, ""), modules_to_load=convert_str_to_tuple(config.get(SLURM_MODULES_TO_LOADS, "")), clean_script=convert_to_bool(config.get(SLURM_CLEAN_SCRIPTS, False))[0], ) def _cast_shift_to_str(shifts: Union[tuple, str, None]) -> str: if shifts is None: return "" elif isinstance(shifts, ShiftAlgorithm): return shifts.value elif isinstance(shifts, str): return shifts elif isinstance(shifts, (tuple, list)): return ";".join([str(value) for value in shifts]) @dataclass class StitchingConfiguration: """ bass class to define stitching configuration """ axis_0_pos_px: Union[tuple, str, None] "position along axis 0 in absolute. 
unit: px" axis_1_pos_px: Union[tuple, str, None] "position along axis 1 in absolute. unit: px" axis_2_pos_px: Union[tuple, str, None] "position along axis 2 in absolute. unit: px" axis_0_pos_mm: Union[tuple, str, None] = None "position along axis 0 in absolute. unit: mm" axis_1_pos_mm: Union[tuple, str, None] = None "position along axis 0 in absolute. unit: mm" axis_2_pos_mm: Union[tuple, str, None] = None "position along axis 0 in absolute. unit: mm" axis_0_params: dict = None axis_1_params: dict = None axis_2_params: dict = None slurm_config: SlurmConfig = None flip_lr: Union[tuple, bool] = False "flip frame left-right. For scan this will be append to the NXtransformations of the detector" flip_ud: Union[tuple, bool] = False "flip frame up-down. For scan this will be append to the NXtransformations of the detector" overwrite_results: bool = False stitching_strategy: OverlapStitchingStrategy = OverlapStitchingStrategy.COSINUS_WEIGHTS stitching_kernels_extra_params: dict = None slice_for_cross_correlation: Union[str, int] = "middle" # opts for rescaling frame during stitching rescale_frames: bool = False rescale_params: dict = None normalization_by_sample: NormalizationBySample = None duplicate_data: bool = True """when possible (for HDF5) avoid duplicating data as-much-much-as-possible. Overlaping region between two frames will be duplicated. Remaining will be 'raw_data' for volume. For projection flat field will be applied""" @property def stitching_type(self): raise NotImplementedError("Base class") def __post_init__(self): if self.normalization_by_sample is None: self.normalization_by_sample = NormalizationBySample() @staticmethod def get_description_dict() -> dict: def get_pos_info(axis, unit, alternative): return f"position over {axis} in {unit}. If provided {alternative} must be set to blank. If none provided then will try to get information from existing metadata" def get_default_shift_params(window_size=None, shift_alg=None) -> str: return ";".join( [ f"{KEY_WINDOW_SIZE}={window_size or ''}", f"{KEY_IMG_REG_METHOD}={shift_alg or ''}", ] ) return { STITCHING_SECTION: { STITCHING_TYPE_FIELD: { "default": StitchingType.Z_PREPROC.value, "help": f"stitching to be applied. Must be in {StitchingType.values()}", "type": "required", }, STITCHING_STRATEGY_FIELD: { "default": "cosinus weights", "help": f"Policy to apply to compute the overlap area. 
Must be in {OverlapStitchingStrategy.values()}.", "type": "required", }, CROSS_CORRELATION_SLICE_FIELD: { "default": "middle", "help": f"slice to use for image registration", "type": "optional", }, AXIS_0_POS_PX: { "default": "", "help": get_pos_info(axis=0, unit="pixel", alternative=AXIS_0_POS_MM), "type": "optional", }, AXIS_0_POS_MM: { "default": "", "help": get_pos_info(axis=1, unit="millimeter", alternative=AXIS_0_POS_PX), "type": "optional", }, AXIS_0_PARAMS: { "default": get_default_shift_params(window_size=50, shift_alg=DEFAULT_SHIFT_ALG_AXIS_0), "help": HELP_SHIFT_PARAMS, "type": "optional", }, AXIS_1_POS_PX: { "default": "", "help": get_pos_info(axis=1, unit="pixel", alternative=AXIS_1_POS_MM), "type": "optional", }, AXIS_1_POS_MM: { "default": "", "help": get_pos_info(axis=1, unit="millimeter", alternative=AXIS_1_POS_PX), "type": "optional", }, AXIS_1_PARAMS: { "default": get_default_shift_params(), "help": f"same as {AXIS_0_PARAMS} but for axis 1", "type": "optional", }, AXIS_2_POS_PX: { "default": "", "help": get_pos_info(axis=2, unit="pixel", alternative=AXIS_2_POS_MM), "type": "optional", }, AXIS_2_POS_MM: { "default": "", "help": get_pos_info(axis=2, unit="millimeter", alternative=AXIS_1_POS_PX), "type": "optional", }, AXIS_2_PARAMS: { "default": get_default_shift_params(window_size=200, shift_alg=DEFAULT_SHIFT_ALG_AXIS_2), "help": f"same as {AXIS_0_PARAMS} but for axis 2", "type": "optional", }, FLIP_LR: { "default": False, "help": "sometime scan or volume can have a left-right flip in frame (projection/slice) space. For recent NXtomo it should be handled automatically. But for volume you might need to request some flip.", "type": "optional", }, FLIP_UD: { "default": False, "help": "sometime scan or volume can have a up_down flip in frame (projection/slice) space. For recent NXtomo it should be handled automatically. But for volume you might need to request some flip.", "type": "optional", }, RESCALE_FRAMES: { "default": False, "help": "rescale each frame before applying stithcing", "type": "advanced", }, RESCALE_PARAMS: { "default": "", "help": f"parameters for rescaling frames as 'key1=value1;key_2=value2'. Valid Keys are {KEY_RESCALE_MIN_PERCENTILES} and {KEY_RESCALE_MAX_PERCENTILES}.", "type": "advanced", }, STITCHING_KERNELS_EXTRA_PARAMS: { "default": "", "help": f"advanced parameters for some stitching kernels. must be provided as 'key1=value1;key_2=value2'. Valid keys for now are: {KEY_THRESHOLD_FREQUENCY}: threshold to be used by the {OverlapStitchingStrategy.IMAGE_MINIMUM_DIVERGENCE.value} to split images low and high frequencies in Fourier space.", "type": "advanced", }, ALIGNMENT_AXIS_2_FIELD: { "default": "center", "help": f"In case frame have different frame widths how to align them (so along volume axis 2). Valid keys are {AlignmentAxis2.values()}", "type": "advanced", }, PAD_MODE_FIELD: { "default": "constant", "help": f"pad mode to use for frame alignment. Valid values are 'constant', 'edge', 'linear_ramp', maximum', 'mean', 'median', 'minimum', 'reflect', 'symmetric', 'wrap', and 'empty'. See nupy.pad documentation for details", "type": "advanced", }, AVOID_DATA_DUPLICATION_FIELD: { "default": "1", "help": "When possible (stitching on reconstructed volume and HDF5 volume as input and output) create link to original data instead of duplicating it all. 
Warning: this will create relative link between the stiched volume and the original reconstructed volume.", "validator": boolean_validator, "type": "advanced", }, }, OUTPUT_SECTION: { OVERWRITE_RESULTS_FIELD: { "default": "1", "help": "What to do in the case where the output file exists.\nBy default, the output data is never overwritten and the process is interrupted if the file already exists.\nSet this option to 1 if you want to overwrite the output files.", "validator": boolean_validator, "type": "required", }, }, INPUTS_SECTION: { INPUT_DATASETS_FIELD: { "default": "", "help": f"Dataset to stitch together. Must be volume for {StitchingType.Z_PREPROC.value} or NXtomo for {StitchingType.Z_POSTPROC.value}", "type": "required", }, STITCHING_SLICES: { "default": "", "help": f"slices to be stitched. Must be given along axis 0 for pre-processing (z) and along axis 1 for post-processing (y)", "type": "advanced", }, }, SLURM_SECTION: { SLURM_PARTITION: { "default": "", "help": "slurm partition to be used. If empty will run locally", "type": "optional", }, SLURM_MEM: { "default": "32GB", "help": "memory to allocate for each job", "type": "optional", }, SLURM_N_JOBS: { "default": 10, "help": "number of job to launch (split computation on N parallel jobs). Once all are finished we will concatenate the result.", "type": "optional", }, SLURM_COR_PER_TASKS: { "default": 4, "help": "number of cor per task launched", "type": "optional", }, SLURM_NUMBER_OF_TASKS: { "default": 1, "help": "(for parallel execution when possible). Split each job into this number of tasks", "type": "optional", }, SLURM_OTHER_OPTIONS: { "default": "", "help": "you can provide axtra options to slurm from this string", "type": "optional", }, SLURM_PREPROCESSING_COMMAND: { "default": "", "help": "python virtual environment to use", "type": "optional", }, SLURM_MODULES_TO_LOADS: { "default": "tomotools/stable", "help": "module to load", "type": "optional", }, }, NORMALIZATION_BY_SAMPLE_SECTION: { NORMALIZATION_BY_SAMPLE_ACTIVE_FIELD: { "default": False, "help": "should we apply frame normalization by a sample or not", "type": "advanced", }, NORMALIZATION_BY_SAMPLE_METHOD: { "default": "median", "help": "method to compute the normalization value", "type": "advanced", }, NORMALIZATION_BY_SAMPLE_SIDE: { "default": "left", "help": "side to pick the sample", "type": "advanced", }, NORMALIZATION_BY_SAMPLE_MARGIN: { "default": 0, "help": "margin (in px) between border and sample", "type": "advanced", }, NORMALIZATION_BY_SAMPLE_WIDTH: { "default": 30, "help": "sample width (in px)", "type": "advanced", }, }, } def to_dict(self): """dump configuration to a dict. 
Must be serializable because might be dump to HDF5 file""" return { SLURM_SECTION: self.slurm_config.to_dict() if self.slurm_config is not None else SlurmConfig().to_dict(), STITCHING_SECTION: { STITCHING_TYPE_FIELD: self.stitching_type.value, CROSS_CORRELATION_SLICE_FIELD: str(self.slice_for_cross_correlation), AXIS_0_POS_PX: _cast_shift_to_str(self.axis_0_pos_px), AXIS_0_POS_MM: _cast_shift_to_str(self.axis_0_pos_mm), AXIS_0_PARAMS: _dict_to_str(self.axis_0_params or {}), AXIS_1_POS_PX: _cast_shift_to_str(self.axis_1_pos_px), AXIS_1_POS_MM: _cast_shift_to_str(self.axis_1_pos_mm), AXIS_1_PARAMS: _dict_to_str(self.axis_1_params or {}), AXIS_2_POS_PX: _cast_shift_to_str(self.axis_2_pos_px), AXIS_2_POS_MM: _cast_shift_to_str(self.axis_2_pos_mm), AXIS_2_PARAMS: _dict_to_str(self.axis_2_params or {}), STITCHING_STRATEGY_FIELD: OverlapStitchingStrategy.from_value(self.stitching_strategy).value, FLIP_UD: self.flip_ud, FLIP_LR: self.flip_lr, RESCALE_FRAMES: self.rescale_frames, RESCALE_PARAMS: _dict_to_str(self.rescale_params or {}), STITCHING_KERNELS_EXTRA_PARAMS: _dict_to_str(self.stitching_kernels_extra_params or {}), AVOID_DATA_DUPLICATION_FIELD: not self.duplicate_data, }, OUTPUT_SECTION: { OVERWRITE_RESULTS_FIELD: int( self.overwrite_results, ), }, NORMALIZATION_BY_SAMPLE_SECTION: self.normalization_by_sample.to_dict(), } class SingleAxisConfigMetaClass(type): """ Metaclass for single axis stitcher in order to aggregate dumper class and axis warning: this class is used by tomwer as well """ def __new__(mcls, name, bases, attrs, axis=None): # assert axis is not None mcls = super().__new__(mcls, name, bases, attrs) mcls._axis = axis return mcls @dataclass class SingleAxisStitchingConfiguration(StitchingConfiguration, metaclass=SingleAxisConfigMetaClass): """ base class to define z-stitching parameters """ slices: Union[slice, tuple, None] = ( None # slices to reconstruct. Over axis 0 for pre-processing, over axis 1 for post-processing. If None will reconstruct all ) alignment_axis_2: AlignmentAxis2 = AlignmentAxis2.CENTER pad_mode: str = "constant" # pad mode to be used for alignment @property def axis(self) -> int: # self._axis is defined by the metaclass return self._axis def settle_inputs(self) -> None: self.settle_slices() def settle_slices(self) -> tuple: raise ValueError("Base class") def get_output_object(self): raise ValueError("Base class") def to_dict(self): if isinstance(self.slices, slice): slices = f"{self.slices.start}:{self.slices.stop}:{self.slices.step}" elif self.slices in ("", None): slices = "" else: slices = ";".join(str(s) for s in self.slices) return concatenate_dict( super().to_dict(), { INPUTS_SECTION: { STITCHING_SLICES: slices, }, STITCHING_SECTION: { ALIGNMENT_AXIS_2_FIELD: self.alignment_axis_2.value, PAD_MODE_FIELD: self.pad_mode, }, }, ) @dataclass class PreProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfiguration): """ base class to define z-stitching parameters """ input_scans: tuple = () # tuple of ScanBase output_file_path: str = "" output_data_path: str = "" output_nexus_version: Optional[float] = None pixel_size: Optional[float] = None @property def stitching_type(self) -> StitchingType: if self.axis == 0: return StitchingType.Z_PREPROC elif self.axis == 1: return StitchingType.Y_PREPROC else: raise ValueError( "unexpected axis value. Only stitching over axis 0 (aka z) and 1 (aka y) are handled. 
Current axis value is %s", self.axis, ) def get_output_object(self): return NXtomoScan( scan=self.output_file_path, entry=self.output_data_path, ) def settle_inputs(self) -> None: super().settle_inputs() self.settle_input_scans() def settle_input_scans(self): self.input_scans = [ ( Factory.create_tomo_object_from_identifier(identifier) if isinstance(identifier, (str, ScanIdentifier)) else identifier ) for identifier in self.input_scans ] def slice_idx_from_str_to_int(self, index): if isinstance(index, str): index = index.lower() if index == "first": return 0 elif index == "last": return len(self.input_scans[0].projections) - 1 elif index == "middle": return max(len(self.input_scans[0].projections) // 2 - 1, 0) return int(index) def settle_slices(self) -> tuple: """ interpret the slices to be stitched if needed Nore: if slices is an instance of slice will redefine start and stop to avoid having negative indexes :return: (slices:[slice,Iterable], n_proj:int) :rtype: tuple """ slices = self.slices if isinstance(slices, Sized) and len(slices) == 0: # in this case will stitch them all slices = None if len(self.input_scans) == 0: raise ValueError("No input scan provided") if slices is None: slices = slice(0, len(self.input_scans[0].projections), 1) n_proj = slices.stop elif isinstance(slices, slice): # force slices indices to be positive start = slices.start if start < 0: start += len(self.input_scans[0].projections) + 1 stop = slices.stop if stop < 0: stop += len(self.input_scans[0].projections) + 1 step = slices.step if step is None: step = 1 n_proj = ceil((stop - start) / step) # update slices for iteration simplify things slices = slice(start, stop, step) elif isinstance(slices, (tuple, list)): n_proj = len(slices) slices = [self.slice_idx_from_str_to_int(s) for s in slices] else: raise TypeError(f"slices is expected to be a tuple or a lice. Not {type(slices)}") self.slices = slices return slices, n_proj def to_dict(self): if self.pixel_size is None: pixel_size_mm = "" else: pixel_size_mm = self.pixel_size * MetricSystem.MILLIMETER.value return concatenate_dict( super().to_dict(), { PRE_PROC_SECTION: { DATA_FILE_FIELD: self.output_file_path, DATA_PATH_FIELD: self.output_data_path, NEXUS_VERSION_FIELD: self.output_nexus_version, }, INPUTS_SECTION: { INPUT_DATASETS_FIELD: ";".join( [str(scan.get_identifier()) for scan in self.input_scans], ), INPUT_PIXEL_SIZE_MM: pixel_size_mm, }, }, ) @staticmethod def get_description_dict() -> dict: return concatenate_dict( SingleAxisStitchingConfiguration.get_description_dict(), { PRE_PROC_SECTION: { DATA_FILE_FIELD: { "default": "", "help": "output nxtomo file path", "type": "required", }, DATA_PATH_FIELD: { "default": "", "help": "output nxtomo data path", "type": "required", }, NEXUS_VERSION_FIELD: { "default": "", "help": "nexus version. 
If not provided will pick the latest one know", "type": "required", }, }, }, ) @classmethod def from_dict(cls, config: dict): if not isinstance(config, dict): raise TypeError(f"config is expected to be a dict and not {type(config)}") inputs_scans_str = config.get(INPUTS_SECTION, {}).get(INPUT_DATASETS_FIELD, None) if inputs_scans_str in (None, ""): input_scans = [] else: input_scans = identifiers_as_str_to_instances(inputs_scans_str) output_file_path = config.get(PRE_PROC_SECTION, {}).get(DATA_FILE_FIELD, None) nexus_version = config.get(PRE_PROC_SECTION, {}).get(NEXUS_VERSION_FIELD, None) if nexus_version in (None, ""): nexus_version = nxtomo.LATEST_VERSION else: nexus_version = float(nexus_version) pixel_size = config.get(INPUT_PIXEL_SIZE_MM, "").replace(" ", "") if pixel_size == "": pixel_size = None else: pixel_size = float(pixel_size) / MetricSystem.MM return cls( stitching_strategy=OverlapStitchingStrategy.from_value( config[STITCHING_SECTION].get( STITCHING_STRATEGY_FIELD, OverlapStitchingStrategy.COSINUS_WEIGHTS, ), ), axis_0_pos_px=str_to_shifts(config[STITCHING_SECTION].get(AXIS_0_POS_PX, None)), axis_0_pos_mm=str_to_shifts(config[STITCHING_SECTION].get(AXIS_0_POS_MM, None)), axis_0_params=_valid_shifts_params(_str_to_dict(config[STITCHING_SECTION].get(AXIS_0_PARAMS, {}))), axis_1_pos_px=str_to_shifts(config[STITCHING_SECTION].get(AXIS_1_POS_PX, None)), axis_1_pos_mm=str_to_shifts(config[STITCHING_SECTION].get(AXIS_1_POS_MM, None)), axis_1_params=_valid_shifts_params( _str_to_dict( config[STITCHING_SECTION].get(AXIS_1_PARAMS, {}), ) ), axis_2_pos_px=str_to_shifts(config[STITCHING_SECTION].get(AXIS_2_POS_PX, None)), axis_2_pos_mm=str_to_shifts(config[STITCHING_SECTION].get(AXIS_2_POS_MM, None)), axis_2_params=_valid_shifts_params( _str_to_dict( config[STITCHING_SECTION].get(AXIS_2_PARAMS, {}), ) ), input_scans=input_scans, output_file_path=output_file_path, output_data_path=config.get(PRE_PROC_SECTION, {}).get(DATA_PATH_FIELD, "entry_from_stitchig"), overwrite_results=config[STITCHING_SECTION].get(OVERWRITE_RESULTS_FIELD, True), output_nexus_version=nexus_version, slices=_slices_to_list_or_slice(config[INPUTS_SECTION].get(STITCHING_SLICES, None)), slurm_config=SlurmConfig.from_dict(config.get(SLURM_SECTION, {})), slice_for_cross_correlation=config[STITCHING_SECTION].get(CROSS_CORRELATION_SLICE_FIELD, "middle"), pixel_size=pixel_size, flip_ud=_scalar_or_tuple_to_bool_or_tuple_of_bool(config[STITCHING_SECTION].get(FLIP_UD, False)), flip_lr=_scalar_or_tuple_to_bool_or_tuple_of_bool(config[STITCHING_SECTION].get(FLIP_LR, False)), rescale_frames=convert_to_bool(config[STITCHING_SECTION].get(RESCALE_FRAMES, 0))[0], rescale_params=_str_to_dict(config[STITCHING_SECTION].get(RESCALE_PARAMS, {})), stitching_kernels_extra_params=_valid_stitching_kernels_params( _str_to_dict( config[STITCHING_SECTION].get(STITCHING_KERNELS_EXTRA_PARAMS, {}), ) ), alignment_axis_2=AlignmentAxis2.from_value( config[STITCHING_SECTION].get(ALIGNMENT_AXIS_2_FIELD, AlignmentAxis2.CENTER) ), pad_mode=config[STITCHING_SECTION].get(PAD_MODE_FIELD, "constant"), duplicate_data=not config[STITCHING_SECTION].get(AVOID_DATA_DUPLICATION_FIELD, False), normalization_by_sample=NormalizationBySample.from_dict(config.get(NORMALIZATION_BY_SAMPLE_SECTION, {})), ) @dataclass class PostProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfiguration): """ base class to define z-stitching parameters """ input_volumes: tuple = () # tuple of VolumeBase output_volume: Optional[VolumeIdentifier] = None voxel_size: 
Optional[float] = None alignment_axis_1: AlignmentAxis1 = AlignmentAxis1.CENTER @property def stitching_type(self) -> StitchingType: if self.axis == 0: return StitchingType.Z_POSTPROC else: raise ValueError(f"unexpected axis value. Only stitching over axis 0 (aka z) is handled. Not {self.axis}") def get_output_object(self): return self.output_volume def settle_inputs(self) -> None: super().settle_inputs() self.settle_input_volumes() def settle_input_volumes(self): self.input_volumes = [ ( Factory.create_tomo_object_from_identifier(identifier) if isinstance(identifier, (str, VolumeIdentifier)) else identifier ) for identifier in self.input_volumes ] def slice_idx_from_str_to_int(self, index): if isinstance(index, str): index = index.lower() if index == "first": return 0 elif index == "last": return self.input_volumes[0].get_volume_shape()[1] - 1 elif index == "middle": return max(self.input_volumes[0].get_volume_shape()[1] // 2 - 1, 0) return int(index) def settle_slices(self) -> tuple: """ interpret the slices to be stitched if needed Nore: if slices is an instance of slice will redefine start and stop to avoid having negative indexes :return: (slices:[slice,Iterable], n_proj:int) :rtype: tuple """ slices = self.slices if isinstance(slices, Sized) and len(slices) == 0: # in this case will stitch them all slices = None if len(self.input_volumes) == 0: raise ValueError("No input volume provided. Cannot settle slices") if slices is None: # before alignment was existing # slices = slice(0, self.input_volumes[0].get_volume_shape()[1], 1) slices = slice( 0, max([volume.get_volume_shape()[1] for volume in self.input_volumes]), 1, ) n_slices = slices.stop if isinstance(slices, slice): # force slices indices to be positive start = slices.start if start < 0: start += max([volume.get_volume_shape()[1] for volume in self.input_volumes]) + 1 stop = slices.stop if stop < 0: stop += max([volume.get_volume_shape()[1] for volume in self.input_volumes]) + 1 step = slices.step if step is None: step = 1 n_slices = ceil((stop - start) / step) # update slices for iteration simplify things slices = slice(start, stop, step) elif isinstance(slices, Iterable): n_slices = len(slices) slices = [self.slice_idx_from_str_to_int(s) for s in slices] else: raise TypeError(f"slices is expected to be a tuple or a slice. 
Not {type(slices)}") self.slices = slices return slices, n_slices @classmethod def from_dict(cls, config: dict): if not isinstance(config, dict): raise TypeError(f"config is expected to be a dict and not {type(config)}") inputs_volumes_str = config.get(INPUTS_SECTION, {}).get(INPUT_DATASETS_FIELD, None) if inputs_volumes_str in (None, ""): input_volumes = [] else: input_volumes = identifiers_as_str_to_instances(inputs_volumes_str) overwrite_results = config[STITCHING_SECTION].get(OVERWRITE_RESULTS_FIELD, True) in ("1", True, "True", 1) output_volume = config.get(POST_PROC_SECTION, {}).get(OUTPUT_VOLUME, None) if output_volume is not None: output_volume = Factory.create_tomo_object_from_identifier(output_volume) output_volume.overwrite = overwrite_results voxel_size = config.get(INPUTS_SECTION, {}).get(INPUT_VOXEL_SIZE_MM, "") voxel_size = voxel_size.replace(" ", "") if voxel_size == "": voxel_size = None else: voxel_size = float(voxel_size) * MetricSystem.MM # on the next section the one with a default value qre the optional one return cls( stitching_strategy=OverlapStitchingStrategy.from_value( config[STITCHING_SECTION].get( STITCHING_STRATEGY_FIELD, OverlapStitchingStrategy.COSINUS_WEIGHTS, ), ), axis_0_pos_px=str_to_shifts(config[STITCHING_SECTION].get(AXIS_0_POS_PX, None)), axis_0_pos_mm=str_to_shifts(config[STITCHING_SECTION].get(AXIS_0_POS_MM, None)), axis_0_params=_valid_shifts_params(config[STITCHING_SECTION].get(AXIS_0_PARAMS, {})), axis_1_pos_px=str_to_shifts(config[STITCHING_SECTION].get(AXIS_1_POS_PX, None)), axis_1_pos_mm=str_to_shifts(config[STITCHING_SECTION].get(AXIS_1_POS_MM, None)), axis_1_params=_valid_shifts_params(config[STITCHING_SECTION].get(AXIS_1_PARAMS, {})), axis_2_pos_px=str_to_shifts(config[STITCHING_SECTION].get(AXIS_2_POS_PX, None)), axis_2_pos_mm=str_to_shifts(config[STITCHING_SECTION].get(AXIS_2_POS_MM, None)), axis_2_params=_valid_shifts_params(config[STITCHING_SECTION].get(AXIS_2_PARAMS, {})), input_volumes=input_volumes, output_volume=output_volume, overwrite_results=overwrite_results, slices=_slices_to_list_or_slice(config[INPUTS_SECTION].get(STITCHING_SLICES, None)), slurm_config=SlurmConfig.from_dict(config.get(SLURM_SECTION, {})), voxel_size=voxel_size, slice_for_cross_correlation=config[STITCHING_SECTION].get(CROSS_CORRELATION_SLICE_FIELD, "middle"), flip_ud=_scalar_or_tuple_to_bool_or_tuple_of_bool(config[STITCHING_SECTION].get(FLIP_UD, False)), flip_lr=_scalar_or_tuple_to_bool_or_tuple_of_bool(config[STITCHING_SECTION].get(FLIP_LR, False)), rescale_frames=convert_to_bool(config[STITCHING_SECTION].get(RESCALE_FRAMES, 0))[0], rescale_params=_str_to_dict(config[STITCHING_SECTION].get(RESCALE_PARAMS, {})), stitching_kernels_extra_params=_valid_stitching_kernels_params( _str_to_dict( config[STITCHING_SECTION].get(STITCHING_KERNELS_EXTRA_PARAMS, {}), ) ), alignment_axis_1=AlignmentAxis1.from_value( config[STITCHING_SECTION].get(ALIGNMENT_AXIS_1_FIELD, AlignmentAxis1.CENTER) ), alignment_axis_2=AlignmentAxis2.from_value( config[STITCHING_SECTION].get(ALIGNMENT_AXIS_2_FIELD, AlignmentAxis2.CENTER) ), pad_mode=config[STITCHING_SECTION].get(PAD_MODE_FIELD, "constant"), duplicate_data=not config[STITCHING_SECTION].get(AVOID_DATA_DUPLICATION_FIELD, False), normalization_by_sample=NormalizationBySample.from_dict(config.get(NORMALIZATION_BY_SAMPLE_SECTION, {})), ) def to_dict(self): if self.voxel_size is None: voxel_size_mm = "" else: voxel_size_mm = numpy.array(self.voxel_size) / MetricSystem.MM return concatenate_dict( super().to_dict(), { INPUTS_SECTION: { 
INPUT_DATASETS_FIELD: [volume.get_identifier().to_str() for volume in self.input_volumes], INPUT_VOXEL_SIZE_MM: voxel_size_mm, }, POST_PROC_SECTION: { OUTPUT_VOLUME: ( self.output_volume.get_identifier().to_str() if self.output_volume is not None else "" ), }, STITCHING_SECTION: { ALIGNMENT_AXIS_1_FIELD: self.alignment_axis_1.value, }, }, ) @staticmethod def get_description_dict() -> dict: return concatenate_dict( SingleAxisStitchingConfiguration.get_description_dict(), { POST_PROC_SECTION: { OUTPUT_VOLUME: { "default": "", "help": "identifier of the output volume. Like hdf5:volume:[file_path]?path=[data_path] for an HDF5 volume", "type": "required", }, }, STITCHING_SECTION: { ALIGNMENT_AXIS_1_FIELD: { "default": "center", "help": f"alignment to apply over axis 1 if needed. Valid values are {AlignmentAxis1.values()}", "type": "advanced", } }, }, ) def identifiers_as_str_to_instances(list_identifiers_as_str: str) -> tuple: # convert str to a list of str that should represent identifiers if isinstance(list_identifiers_as_str, str): list_identifiers_as_str = list_identifiers_as_str.lstrip("[").lstrip("(") list_identifiers_as_str = list_identifiers_as_str.rstrip("]").rstrip(")") identifiers_as_str = convert_str_to_tuple(list_identifiers_as_str.replace(";", ",")) else: identifiers_as_str = list_identifiers_as_str if identifiers_as_str is None: return tuple() # convert identifiers as string to IdentifierType instances return tuple( [Factory.create_tomo_object_from_identifier(identifier_as_str) for identifier_as_str in identifiers_as_str] ) def dict_to_config_obj(config: dict): if not isinstance(config, dict): raise TypeError stitching_type = config.get(STITCHING_SECTION, {}).get(STITCHING_TYPE_FIELD, None) if stitching_type is None: raise ValueError("Unable to find stitching type from config dict") else: stitching_type = StitchingType.from_value(stitching_type) if stitching_type is StitchingType.Z_POSTPROC: return PostProcessedZStitchingConfiguration.from_dict(config) elif stitching_type is StitchingType.Z_PREPROC: return PreProcessedZStitchingConfiguration.from_dict(config) elif stitching_type is StitchingType.Y_PREPROC: return PreProcessedYStitchingConfiguration.from_dict(config) else: raise NotImplementedError(f"stitching type {stitching_type.value} not handled yet") def get_default_stitching_config(stitching_type: Optional[Union[StitchingType, str]]) -> tuple: """ Return a default configuration for doing stitching. :param stitching_type: if None then return a configuration were use can provide inputs for any of the stitching. 
Else return config dict dedicated to a particular stitching :return: (config, section comments) """ if stitching_type is None: return concatenate_dict(z_postproc_stitching_config, z_preproc_stitching_config) stitching_type = StitchingType.from_value(stitching_type) if stitching_type is StitchingType.Z_POSTPROC: return z_postproc_stitching_config elif stitching_type is StitchingType.Z_PREPROC: return z_preproc_stitching_config elif stitching_type is StitchingType.Y_PREPROC: return y_preproc_stitching_config else: raise NotImplementedError class PreProcessedYStitchingConfiguration(PreProcessedSingleAxisStitchingConfiguration, axis=1): pass class PreProcessedZStitchingConfiguration(PreProcessedSingleAxisStitchingConfiguration, axis=0): pass class PostProcessedZStitchingConfiguration(PostProcessedSingleAxisStitchingConfiguration, axis=0): pass y_preproc_stitching_config = PreProcessedYStitchingConfiguration.get_description_dict() z_preproc_stitching_config = PreProcessedZStitchingConfiguration.get_description_dict() z_postproc_stitching_config = PostProcessedZStitchingConfiguration.get_description_dict() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731053186.0 nabu-2024.2.1/nabu/stitching/definitions.py0000644000175000017500000000023614713343202020073 0ustar00pierrepierrefrom silx.utils.enum import Enum as _Enum class StitchingType(_Enum): Y_PREPROC = "y-preproc" Z_PREPROC = "z-preproc" Z_POSTPROC = "z-postproc" ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/stitching/frame_composition.py0000644000175000017500000001503714654107202021304 0ustar00pierrepierrefrom dataclasses import dataclass import numpy from math import ceil @dataclass class FrameComposition: """ class used to define intervals to know where to dump raw data or stitched data according to requested policy. The idea is to create this once for all for one stitching operation and reuse it for each frame. """ composed_axis: int """axis along which the composition is done""" local_start: tuple """tuple of indices on the input frames ref to know where each region start (along the composed axis)""" local_end: tuple """tuple of indices on the input frames ref to know where each region end (along the composed axis)""" global_start: tuple """tuple of indices on the output frame ref to know where each region start (along the composed axis)""" global_end: tuple """tuple of indices on the output frame ref to know where each region end (along the composed axis)""" def browse(self): for i in range(len(self.local_start)): yield ( self.local_start[i], self.local_end[i], self.global_start[i], self.global_end[i], ) def compose(self, output_frame: numpy.ndarray, input_frames: tuple): if not output_frame.ndim in (2, 3): raise TypeError( f"output_frame is expected to be 2D (gray scale) or 3D (RGB(A)) and not {output_frame.ndim}" ) for ( global_start, global_end, local_start, local_end, input_frame, ) in zip( self.global_start, self.global_end, self.local_start, self.local_end, input_frames, ): if input_frame is not None: if self.composed_axis == 0: output_frame[global_start:global_end] = input_frame[local_start:local_end] elif self.composed_axis == 1: output_frame[:, global_start:global_end] = input_frame[:, local_start:local_end] else: raise ValueError(f"composed axis must be in (0, 1). 
Get {self.composed_axis}") @staticmethod def compute_raw_frame_compositions(frames: tuple, key_lines: tuple, overlap_kernels: tuple, stitching_axis): """ compute frame composition for raw data warning: we expect frames to be ordered y downward and the frame order to keep this ordering """ assert len(frames) == len(overlap_kernels) + 1 == len(key_lines) + 1 global_start_indices = [0] # extend shifts and kernels to have a first shift of 0 and two overlaps values at 0 to # generalize processing local_start_indices = [0] local_start_indices.extend( [ceil(key_line[1] + kernel.overlap_size / 2) for (key_line, kernel) in zip(key_lines, overlap_kernels)] ) local_end_indices = list( [ceil(key_line[0] - kernel.overlap_size / 2) for (key_line, kernel) in zip(key_lines, overlap_kernels)] ) local_end_indices.append(frames[-1].shape[stitching_axis]) for ( new_local_start_index, new_local_end_index, kernel, ) in zip(local_start_indices, local_end_indices, overlap_kernels): global_start_indices.append( global_start_indices[-1] + (new_local_end_index - new_local_start_index) + kernel.overlap_size ) # global end can be easily found from global start + local start and end global_end_indices = [] for global_start_index, new_local_start_index, new_local_end_index in zip( global_start_indices, local_start_indices, local_end_indices ): global_end_indices.append(global_start_index + new_local_end_index - new_local_start_index) return FrameComposition( composed_axis=stitching_axis, local_start=tuple(local_start_indices), local_end=tuple(local_end_indices), global_start=tuple(global_start_indices), global_end=tuple(global_end_indices), ) @staticmethod def compute_stitch_frame_composition(frames, key_lines: tuple, overlap_kernels: tuple, stitching_axis: int): """ compute frame composition for stiching. """ assert len(frames) == len(overlap_kernels) + 1 == len(key_lines) + 1 assert stitching_axis in (0, 1) # position in the stitched frame; local_start_indices = [0] * len(overlap_kernels) local_end_indices = [kernel.overlap_size for kernel in overlap_kernels] # position in the global frame. 
For this one it is simpler to rely on the raw frame composition composition_raw = FrameComposition.compute_raw_frame_compositions( frames=frames, key_lines=key_lines, overlap_kernels=overlap_kernels, stitching_axis=stitching_axis, ) global_start_indices = composition_raw.global_end[:-1] global_end_indices = composition_raw.global_start[1:] return FrameComposition( composed_axis=stitching_axis, local_start=tuple(local_start_indices), local_end=tuple(local_end_indices), global_start=tuple(global_start_indices), global_end=tuple(global_end_indices), ) @staticmethod def pprint_composition(raw_composition, stitch_composition): """ util to display what the output of the composition will looks like from composition """ for i_frame, (raw_comp, stitch_comp) in enumerate(zip(raw_composition.browse(), stitch_composition.browse())): raw_local_start, raw_local_end, raw_global_start, raw_global_end = raw_comp print( f"stitch_frame[{raw_global_start}:{raw_global_end}] = frame_{i_frame}[{raw_local_start}:{raw_local_end}]" ) ( stitch_local_start, stitch_local_end, stitch_global_start, stitch_global_end, ) = stitch_comp print( f"stitch_frame[{stitch_global_start}:{stitch_global_end}] = stitched_frame_{i_frame}[{stitch_local_start}:{stitch_local_end}]" ) else: i_frame += 1 raw_local_start, raw_local_end, raw_global_start, raw_global_end = list(raw_composition.browse())[-1] print( f"stitch_frame[{raw_global_start}:{raw_global_end}] = frame_{i_frame}[{raw_local_start}:{raw_local_end}]" ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/stitching/overlap.py0000644000175000017500000004146014654107202017236 0ustar00pierrepierreimport numpy import logging from typing import Optional, Union from silx.utils.enum import Enum as _Enum from nabu.misc import fourier_filters from scipy.fft import rfftn as local_fftn from scipy.fft import irfftn as local_ifftn from tomoscan.utils.geometry import BoundingBox1D _logger = logging.getLogger(__name__) class OverlapStitchingStrategy(_Enum): MEAN = "mean" COSINUS_WEIGHTS = "cosinus weights" LINEAR_WEIGHTS = "linear weights" CLOSEST = "closest" IMAGE_MINIMUM_DIVERGENCE = "image minimum divergence" HIGHER_SIGNAL = "higher signal" DEFAULT_OVERLAP_STRATEGY = OverlapStitchingStrategy.COSINUS_WEIGHTS DEFAULT_OVERLAP_SIZE = None # could also be an int # default overlap to be take for stitching. Ig None: take the largest possible area class OverlapKernelBase: pass class ImageStichOverlapKernel(OverlapKernelBase): """ Stitch two images along Y (axis 0 in image space) """ DEFAULT_HIGH_FREQUENCY_THRESHOLD = 2 def __init__( self, stitching_axis: int, frame_unstitched_axis_size: tuple, stitching_strategy: OverlapStitchingStrategy = DEFAULT_OVERLAP_STRATEGY, overlap_size: int = DEFAULT_OVERLAP_SIZE, extra_params: Optional[dict] = None, ) -> None: """ :param stitching_axis: axis along which stitching is operate. 
Must be 0 or 1 :param frame_unstitched_axis_size: according to the stitching axis the stitched frame will always have a constant size: * If stitching_axis == 0 then it will be the frame width * If stitching_axis == 1 then it will be the frame height :param stitching_strategy: strategy / algorithm to use in order to generate the stitching :param overlap_size: size (int) of the overlap (stitching) between the two images :param extra_params: optional extra parameters used to operate the stitching """ from nabu.stitching.config import KEY_THRESHOLD_FREQUENCY # avoid cyclic import if not isinstance(overlap_size, int): raise TypeError( f"overlap_size is expected to be an int - not {overlap_size} ({type(overlap_size)})" ) if not isinstance(frame_unstitched_axis_size, int) or not frame_unstitched_axis_size > 0: raise TypeError( f"frame_unstitched_axis_size is expected to be a positive int - not {frame_unstitched_axis_size} ({type(frame_unstitched_axis_size)})" ) if stitching_axis not in (0, 1): raise ValueError( "stitching_axis is expected to be the axis along which stitching must be done. It should be 0 or 1" ) self._stitching_axis = stitching_axis self._overlap_size = abs(overlap_size) self._frame_unstitched_axis_size = frame_unstitched_axis_size self._stitching_strategy = OverlapStitchingStrategy.from_value(stitching_strategy) self._weights_img_1 = None self._weights_img_2 = None if extra_params is None: extra_params = {} self._high_frequency_threshold = extra_params.get( KEY_THRESHOLD_FREQUENCY, self.DEFAULT_HIGH_FREQUENCY_THRESHOLD ) def __str__(self) -> str: return f"stitching kernel (policy={self.stitching_strategy.value}, overlap_size={self.overlap_size}, frame={self._frame_unstitched_axis_size})" @staticmethod def __check_img(img, name): if not isinstance(img, numpy.ndarray) or img.ndim != 2: raise ValueError(f"{name} is expected to be a 2D numpy array") @property def stitched_axis(self) -> int: return self._stitching_axis @property def unstitched_axis(self) -> int: """ Utility property. The kernel operates stitching on images along a single axis (`stitching_axis`); this property returns the other axis.
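Illustrative sketch of how the two axes relate (sizes are arbitrary; ``img_up`` / ``img_down`` are hypothetical 2D frames of width 2048, not part of this module)::

    kernel = ImageStichOverlapKernel(stitching_axis=0, frame_unstitched_axis_size=2048, overlap_size=128)
    kernel.stitched_axis     # 0: rows are blended over the overlap region
    kernel.unstitched_axis   # 1: columns keep the constant size (2048)
    stitched, w1, w2 = kernel.stitch(img_up[-128:], img_down[:128])  # both inputs have shape (128, 2048)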
""" if self.stitched_axis == 0: return 1 else: return 0 @property def overlap_size(self) -> int: return self._overlap_size @overlap_size.setter def overlap_size(self, size: int): if not isinstance(size, int): raise TypeError(f"height expects a int ({type(size)} provided instead)") if not size >= 0: raise ValueError(f"height is expected to be positive") self._overlap_size = abs(size) # update weights if needed if self._weights_img_1 is not None or self._weights_img_2 is not None: self.compute_weights() @property def img_2(self) -> numpy.ndarray: return self._img_2 @property def weights_img_1(self) -> Optional[numpy.ndarray]: return self._weights_img_1 @property def weights_img_2(self) -> Optional[numpy.ndarray]: return self._weights_img_2 @property def stitching_strategy(self) -> OverlapStitchingStrategy: return self._stitching_strategy def compute_weights(self): if self.stitching_strategy is OverlapStitchingStrategy.MEAN: weights_img_1 = numpy.ones(self._overlap_size) * 0.5 weights_img_2 = weights_img_1[::-1] elif self.stitching_strategy is OverlapStitchingStrategy.CLOSEST: n_item = self._overlap_size // 2 + self._overlap_size % 2 weights_img_1 = numpy.concatenate( [ numpy.ones(n_item), numpy.zeros(self._overlap_size - n_item), ] ) weights_img_2 = numpy.concatenate( [ numpy.zeros(n_item), numpy.ones(self._overlap_size - n_item), ] ) elif self.stitching_strategy is OverlapStitchingStrategy.LINEAR_WEIGHTS: weights_img_1 = numpy.linspace(1.0, 0.0, self._overlap_size) weights_img_2 = weights_img_1[::-1] elif self.stitching_strategy is OverlapStitchingStrategy.COSINUS_WEIGHTS: angles = numpy.linspace(0.0, numpy.pi / 2.0, self._overlap_size) weights_img_1 = numpy.cos(angles) ** 2 weights_img_2 = numpy.sin(angles) ** 2 elif self.stitching_strategy in ( OverlapStitchingStrategy.IMAGE_MINIMUM_DIVERGENCE, OverlapStitchingStrategy.HIGHER_SIGNAL, ): # those strategies are not using constant weights but have treatments depending on the provided img_1 and mg_2 during stitching return else: raise NotImplementedError(f"{self.stitching_strategy} not implemented") if self._stitching_axis == 0: self._weights_img_1 = weights_img_1.reshape(-1, 1) * numpy.ones(self._frame_unstitched_axis_size).reshape( 1, -1 ) self._weights_img_2 = weights_img_2.reshape(-1, 1) * numpy.ones(self._frame_unstitched_axis_size).reshape( 1, -1 ) elif self._stitching_axis == 1: self._weights_img_1 = weights_img_1.reshape(1, -1) * numpy.ones(self._frame_unstitched_axis_size).reshape( -1, 1 ) self._weights_img_2 = weights_img_2.reshape(1, -1) * numpy.ones(self._frame_unstitched_axis_size).reshape( -1, 1 ) else: raise ValueError(f"stitching_axis should be in (0, 1). 
{self._stitching_axis} provided") def stitch(self, img_1, img_2, check_input=True) -> tuple: """Compute overlap region from the defined strategy""" if check_input: self.__check_img(img_1, "img_1") self.__check_img(img_2, "img_2") if img_1.shape != img_2.shape: raise ValueError( f"both images are expected to be of the same shape to apply stitch ({img_1.shape} vs {img_2.shape})" ) if self._stitching_strategy is OverlapStitchingStrategy.IMAGE_MINIMUM_DIVERGENCE: return ( compute_image_minimum_divergence( img_1=img_1, img_2=img_2, high_frequency_threshold=self._high_frequency_threshold, stitching_axis=self.stitched_axis, ), None, None, ) elif self._stitching_strategy is OverlapStitchingStrategy.HIGHER_SIGNAL: return ( compute_image_higher_signal( img_1=img_1, img_2=img_2, ), None, None, ) else: if self.weights_img_1 is None or self.weights_img_2 is None: self.compute_weights() return ( img_1 * self.weights_img_1 + img_2 * self.weights_img_2, self.weights_img_1, self.weights_img_2, ) def compute_image_minimum_divergence( img_1: numpy.ndarray, img_2: numpy.ndarray, high_frequency_threshold, stitching_axis: int ): """ Algorithm to improve treatment of high frequency. It split the two images into two parts: high frequency and low frequency. The two low frequency part will be stitched using a 'sinusoidal' / cosinus weights approach. When the two high frequency parts will be stitched by taking the lower divergent pixels """ # split low and high frequencies def split_image(image: numpy.ndarray, threshold: int) -> tuple: """split an image to return (low_frequency, high_frequency)""" lowpass_filter = fourier_filters.get_lowpass_filter( image.shape[-2:], cutoff_par=threshold, use_rfft=True, data_type=image.dtype, ) highpass_filter = fourier_filters.get_highpass_filter( image.shape[-2:], cutoff_par=threshold, use_rfft=True, data_type=image.dtype, ) low_fre_part = local_ifftn(local_fftn(image, axes=(-2, -1)) * lowpass_filter, axes=(-2, -1)).real high_fre_part = local_ifftn(local_fftn(image, axes=(-2, -1)) * highpass_filter, axes=(-2, -1)).real return (low_fre_part, high_fre_part) low_freq_img_1, high_freq_img_1 = split_image(img_1, threshold=high_frequency_threshold) low_freq_img_2, high_freq_img_2 = split_image(img_2, threshold=high_frequency_threshold) # handle low frequency if stitching_axis == 0: frame_cst_size = img_1.shape[1] overlap_size = img_1.shape[0] elif stitching_axis == 1: frame_cst_size = img_1.shape[0] overlap_size = img_1.shape[1] else: raise ValueError("") low_freq_stitching_kernel = ImageStichOverlapKernel( frame_unstitched_axis_size=frame_cst_size, stitching_strategy=OverlapStitchingStrategy.COSINUS_WEIGHTS, overlap_size=overlap_size, stitching_axis=stitching_axis, ) low_freq_stitched = low_freq_stitching_kernel.stitch( img_1=low_freq_img_1, img_2=low_freq_img_2, check_input=False, )[0] # handle high frequency mean_high_frequency = numpy.mean([high_freq_img_1, high_freq_img_2]) assert numpy.isscalar(mean_high_frequency) high_freq_distance_img_1 = numpy.abs(high_freq_img_1 - mean_high_frequency) high_freq_distance_img_2 = numpy.abs(high_freq_img_2 - mean_high_frequency) high_freq_stitched = numpy.where( high_freq_distance_img_1 >= high_freq_distance_img_2, high_freq_distance_img_2, high_freq_distance_img_1 ) # merge back low and high frequencies together def merge_images(low_freq: numpy.ndarray, high_freq: numpy.ndarray) -> numpy.ndarray: """merge two part of an image. 
The low frequency part with the high frequency part""" return low_freq + high_freq return merge_images(low_freq_stitched, high_freq_stitched) def compute_image_higher_signal(img_1: numpy.ndarray, img_2: numpy.ndarray): """ for each pixel, pick the value from the image having the higher signal. A use case: if one of the images contains artefacts creating stripes (from scintillator defects for example), they can be removed by this method """ # note: to be thought about. It might be interesting to rescale img_1 and img_2 # to get something more coherent return numpy.where(img_1 >= img_2, img_1, img_2) def check_overlaps(frames: Union[tuple, numpy.ndarray], positions: tuple, axis: int, raise_error: bool): """ check that each pair of juxtaposed frames has exactly one overlap (at least and at most one) :param frames: list of ordered / sorted frames along the axis to test (from higher to lower position) :param positions: positions of the frames in 3D space as (position axis 0, position axis 1, position axis 2) :param axis: axis to check :param raise_error: if True then raise an error when two juxtaposed frames do not have exactly one overlap. Else log an error """ if not isinstance(frames, (tuple, numpy.ndarray)): raise TypeError(f"frames is expected to be a tuple or a numpy array. Get {type(frames)} instead") if not isinstance(positions, tuple): raise TypeError(f"positions is expected to be a tuple of 3-element positions. Get {type(positions)} instead") assert isinstance(axis, int), "axis is expected to be an int" assert isinstance(raise_error, bool), "raise_error is expected to be a bool" def treat_error(error_msg: str): if raise_error: raise ValueError(error_msg) else: _logger.error(error_msg) if axis == 0: axis_frame_space = 0 elif axis == 1: axis_frame_space = 1 else: raise NotImplementedError(f"overlap check along axis {axis} is not implemented") # convert each frame to the appropriate bounding box according to the axis def convert_to_bb(frame: numpy.ndarray, position: tuple, axis: int): assert isinstance(axis, int) assert isinstance(position, tuple), f"position is expected to be a tuple. Get {type(position)} instead" assert len(position) == 3, f"Expect to have three items for the position.
Get {len(position)}" start_frame = position[axis] - frame.shape[axis_frame_space] // 2 end_frame = start_frame + frame.shape[axis_frame_space] return BoundingBox1D(start_frame, end_frame) bounding_boxes = { convert_to_bb(frame=frame, position=position, axis=axis): position for frame, position in zip(frames, positions) } def get_frame_index(my_bb) -> str: bb_index = tuple(bounding_boxes.keys()).index(my_bb) + 1 if bb_index in (1, 21, 31): return f"{bb_index}st" elif bb_index in (2, 22, 32): return f"{bb_index}nd" elif bb_index == (3, 23, 33): return f"{bb_index}rd" else: return f"{bb_index}th" # check that theres an overlap between two juxtaposed bb (or frame at the end) all_bounding_boxes = tuple(bounding_boxes.keys()) bb_with_expected_overlap = [ (bb_frame, bb_next_frame) for bb_frame, bb_next_frame in zip(all_bounding_boxes[:-1], all_bounding_boxes[1:]) ] for bb_pair in bb_with_expected_overlap: bb_frame, bb_next_frame = bb_pair if bb_frame.max < bb_next_frame.min: treat_error(f"provided frames seems un sorted (from the higher to the lower)") if bb_frame.min < bb_next_frame.min: treat_error( f"Seems like {get_frame_index(bb_frame)} frame is fully overlaping with frame {get_frame_index(bb_next_frame)}" ) if bb_frame.get_overlap(bb_next_frame) is None: treat_error( f"no overlap found between two juxtaposed frames - {get_frame_index(bb_frame)} and {get_frame_index(bb_next_frame)}" ) # check there is no overlap between none juxtaposed bb def pick_all_none_juxtaposed_bb(index, my_bounding_boxes: tuple): """return all the bounding boxes to check for the index 'index': :return: (tested_bounding_box, bounding_boxes_to_test) """ my_bounding_boxes = {bb_index: bb for bb_index, bb in enumerate(my_bounding_boxes)} bounding_boxes = dict( filter( lambda pair: pair[0] not in (index - 1, index, index + 1), my_bounding_boxes.items(), ) ) return my_bounding_boxes[index], bounding_boxes.values() bb_without_expected_overlap = [ pick_all_none_juxtaposed_bb(index, all_bounding_boxes) for index in range(len(all_bounding_boxes)) ] for bb_pair in bb_without_expected_overlap: bb_frame, bb_not_juxtaposed_frames = bb_pair for bb_not_juxtaposed_frame in bb_not_juxtaposed_frames: if bb_frame.get_overlap(bb_not_juxtaposed_frame) is not None: treat_error( f"overlap found between two frames not juxtaposed - {bounding_boxes[bb_frame]} and {bounding_boxes[bb_not_juxtaposed_frame]}" ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/stitching/sample_normalization.py0000644000175000017500000000402514550227307022015 0ustar00pierrepierreimport numpy from silx.utils.enum import Enum as _Enum class SampleSide(_Enum): LEFT = "left" RIGHT = "right" class Method(_Enum): MEAN = "mean" MEDIAN = "median" def normalize_frame( frame: numpy.ndarray, side: SampleSide, method: Method, sample_width: int = 50, margin_before_sample: int = 0 ): """ normalize the frame from a sample section picked at the left of the right of the frame :param frame: frame to normalize :param SampleSide side: side to pick the sample :param Method method: normalization method :param int sample_width: sample width :param int margin: margin before the sampling area """ if not isinstance(frame, numpy.ndarray): raise TypeError(f"Frame is expected to be a 2D numpy array.") if frame.ndim != 2: raise TypeError(f"Frame is expected to be a 2D numpy array. 
Get {frame.ndim}D") side = SampleSide.from_value(side) method = Method.from_value(method) if frame.shape[1] < sample_width + margin_before_sample: raise ValueError( f"frame width ({frame.shape[1]}) < sample_width + margin ({sample_width + margin_before_sample})" ) # create sample if side is SampleSide.LEFT: sample_start = margin_before_sample sample_end = margin_before_sample + sample_width sample = frame[:, sample_start:sample_end] elif side is SampleSide.RIGHT: sample_start = frame.shape[1] - (sample_width + margin_before_sample) sample_end = frame.shape[1] - margin_before_sample sample = frame[:, sample_start:sample_end] else: raise ValueError(f"side {side.value} not handled") # do normalization if method is Method.MEAN: normalization_array = numpy.mean(sample, axis=1) elif method is Method.MEDIAN: normalization_array = numpy.median(sample, axis=1) else: raise ValueError(f"side {side.value} not handled") for line in range(normalization_array.shape[0]): frame[line, :] -= normalization_array[line] return frame ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731053186.0 nabu-2024.2.1/nabu/stitching/single_axis_stitching.py0000644000175000017500000000227614713343202022147 0ustar00pierrepierrefrom .y_stitching import y_stitching from .z_stitching import z_stitching from tomoscan.identifier import BaseIdentifier from nabu.stitching.config import ( SingleAxisStitchingConfiguration, PreProcessedYStitchingConfiguration, PreProcessedZStitchingConfiguration, PostProcessedZStitchingConfiguration, ) def stitching(configuration: SingleAxisStitchingConfiguration, progress=None) -> BaseIdentifier: """ Apply stitching from provided configuration. Stitching will be applied along a single axis at the moment. like: axis 0 ^ | x-ray | --------> ------> axis 2 / / axis 1 """ if isinstance(configuration, (PreProcessedYStitchingConfiguration,)): return y_stitching(configuration=configuration, progress=progress) elif isinstance(configuration, (PreProcessedZStitchingConfiguration, PostProcessedZStitchingConfiguration)): return z_stitching(configuration=configuration, progress=progress) else: raise NotImplementedError(f"configuration type ({type(configuration)}) is not handled") ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/stitching/slurm_utils.py0000644000175000017500000002140014654107202020140 0ustar00pierrepierreimport os import copy from typing import Optional, Union import numpy from silx.io.url import DataUrl from tomoscan.tomoobject import TomoObject from tomoscan.esrf.scan.nxtomoscan import NXtomoScan from tomoscan.esrf import EDFTomoScan from tomoscan.esrf.volume import HDF5Volume, MultiTIFFVolume from tomoscan.esrf.volume.singleframebase import VolumeSingleFrameBase from ..app.bootstrap_stitching import _SECTIONS_COMMENTS from ..pipeline.config import generate_nabu_configfile from .config import ( StitchingConfiguration, get_default_stitching_config, PreProcessedSingleAxisStitchingConfiguration, PostProcessedSingleAxisStitchingConfiguration, SLURM_SECTION, ) try: from sluurp.job import SBatchScriptJob except ImportError: has_sluurp = False else: has_sluurp = True def split_stitching_configuration_to_slurm_job( configuration: StitchingConfiguration, yield_configuration: bool = False ): """ generator to split a StitchingConfiguration into several SBatchScriptJob. 
This will handle: * division into several jobs according to `slices` and `n_job` * creation of SBatchScriptJob handling slurm configuration and command to be launched :param StitchingConfiguration configuration: configuration of the stitching to launch (into several jobs) :param bool yield_configuration: if True then yield (SBatchScriptJob, StitchingConfiguration) else yield only SBatchScriptJob """ if not isinstance(configuration, StitchingConfiguration): raise TypeError( f"configuration is expected to be an instance of {StitchingConfiguration}. {type(configuration)} provided." ) if not has_sluurp: raise ImportError("sluurp not install. Please install it to distribute stitching on slurm (pip install sluurm)") slurm_configuration = configuration.slurm_config n_jobs = slurm_configuration.n_jobs stitching_type = configuration.stitching_type # cleqn slurm configurqtion slurm_configuration = slurm_configuration.to_dict() clean_script = slurm_configuration.pop("clean_scripts", False) slurm_configuration.pop("n_jobs", None) # for now other_options is not handled slurm_configuration.pop("other_options", None) if "memory" in slurm_configuration and isinstance(slurm_configuration["memory"], str): memory = slurm_configuration["memory"].lower().replace(" ", "") memory = memory.rstrip("b").rstrip("g") slurm_configuration["memory"] = memory # handle slices if None configuration.settle_inputs() slice_sub_parts = split_slices(slices=configuration.slices, n_parts=n_jobs) if isinstance(configuration, PreProcessedSingleAxisStitchingConfiguration): stitch_prefix = os.path.basename(os.path.splitext(configuration.output_file_path)[0]) configuration.output_file_path = os.path.abspath(configuration.output_file_path) elif isinstance(configuration, PostProcessedSingleAxisStitchingConfiguration): stitch_prefix = os.path.basename(os.path.splitext(configuration.output_volume.file_path)[0]) configuration.output_volume.file_path = os.path.abspath(configuration.output_volume.file_path) else: raise TypeError(f"{type(configuration)} not handled") for i_sub_part, slice_sub_part in enumerate(slice_sub_parts): sub_configuration = copy.deepcopy(configuration) # update slice sub_configuration.slices = slice_sub_part # remove slurm configuration because once on the partition we run it manually sub_configuration.slurm_config = None if isinstance(sub_configuration, PreProcessedSingleAxisStitchingConfiguration): original_output_file_path, file_extension = os.path.splitext(sub_configuration.output_file_path) sub_configuration.output_file_path = os.path.join( original_output_file_path, os.path.basename(original_output_file_path) + f"_part_{i_sub_part}" + file_extension, ) output_obj = NXtomoScan( scan=sub_configuration.output_file_path, entry=sub_configuration.output_data_path, ) elif isinstance(sub_configuration, PostProcessedSingleAxisStitchingConfiguration): if isinstance(sub_configuration.output_volume, (HDF5Volume, MultiTIFFVolume)): original_output_file_path, file_extension = os.path.splitext(sub_configuration.output_volume.file_path) sub_configuration.output_volume.file_path = os.path.join( original_output_file_path, os.path.basename(original_output_file_path) + f"_part_{i_sub_part}" + file_extension, ) elif isinstance(sub_configuration.output_volume, VolumeSingleFrameBase): url = sub_configuration.output_volume.url original_output_folder = url.file_path() sub_part_url = DataUrl( file_path=os.path.join( original_output_folder, os.path.basename(original_output_folder) + f"_part_{i_sub_part}", ), data_path=url.data_path(), 
scheme=url.scheme(), data_slice=url.data_slice(), ) sub_configuration.output_volume.url = sub_part_url output_obj = sub_configuration.output_volume else: raise TypeError(f"{type(sub_configuration)} not handled") working_directory = get_working_directory(output_obj) if working_directory is not None: script_dir = working_directory else: script_dir = "./" # save sub part nabu configuration file slurm_script_name = f"{stitch_prefix}_part_{i_sub_part}.sh" nabu_script_name = f"{stitch_prefix}_part_{i_sub_part}.conf" os.makedirs(script_dir, exist_ok=True) default_config = get_default_stitching_config(stitching_type) default_config.pop(SLURM_SECTION, None) generate_nabu_configfile( fname=os.path.join(script_dir, nabu_script_name), default_config=default_config, comments=True, sections_comments=_SECTIONS_COMMENTS, prefilled_values=sub_configuration.to_dict(), options_level="advanced", ) command = f"python3 -m nabu.app.stitching {os.path.join(script_dir, nabu_script_name)}" script = (command,) job = SBatchScriptJob( slurm_config=slurm_configuration, script=script, script_path=os.path.join(script_dir, slurm_script_name), clean_script=clean_script, working_directory=working_directory, ) job.overwrite = True if yield_configuration: yield job, sub_configuration else: yield job def split_slices(slices: Union[slice, tuple], n_parts: int): if not isinstance(n_parts, int): raise TypeError(f"n_parts should be an int. {type(n_parts)} provided") if isinstance(slices, slice): assert isinstance(slices.start, int), "slices.start must be an integer" assert isinstance(slices.stop, int), "slices.stop must be an integer" start = stop = slices.start steps_size = int(numpy.ceil((slices.stop - slices.start) / n_parts)) while stop < slices.stop: stop = min(start + steps_size, slices.stop) yield slice(start, stop, slices.step) start = stop elif isinstance(slices, (tuple, list)): start = stop = 0 steps_size = int(numpy.ceil(len(slices) / n_parts)) while stop < len(slices): stop = min(start + steps_size, len(slices)) yield slices[start:stop] start = stop else: raise TypeError(f"slices type ({type(slices)}) is not handled. Must be a slice or an Iterable") def get_working_directory(obj: TomoObject) -> Optional[str]: """ return working directory for a specific TomoObject """ if not isinstance(obj, TomoObject): raise TypeError(f"obj should be an instance of {TomoObject}. 
{type(obj)} provided") if isinstance(obj, (HDF5Volume, MultiTIFFVolume)): if obj.file_path is None: return None else: return os.path.abspath(os.path.dirname(obj.file_path)) elif isinstance(obj, VolumeSingleFrameBase): if obj.data_url is not None: return os.path.abspath(obj.data_url.file_path()) else: return None elif isinstance(obj, EDFTomoScan): return obj.path elif isinstance(obj, NXtomoScan): if obj.master_file is None: return None else: return os.path.abspath(os.path.dirname(obj.master_file)) else: raise RuntimeError(f"obj type not handled ({type(obj)})") ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1734442985.524757 nabu-2024.2.1/nabu/stitching/stitcher/0000755000175000017500000000000014730277752017051 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/stitching/stitcher/__init__.py0000644000175000017500000000000014654107202021133 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1730906677.0 nabu-2024.2.1/nabu/stitching/stitcher/base.py0000644000175000017500000001020514712705065020324 0ustar00pierrepierrefrom copy import copy from typing import Union from nabu.stitching.config import SingleAxisStitchingConfiguration from tomoscan.esrf import NXtomoScan from tomoscan.volumebase import VolumeBase from tomoscan.identifier import BaseIdentifier def get_obj_constant_side_length(obj: Union[NXtomoScan, VolumeBase], axis: int) -> int: """ return tomo object lenght that will be constant over 1D stitching. In the case of a stitching along axis 0 this will be: * the projection width for pre-processing * volume.shape[2] for post-processing In the case of a stitching along axis 1 this will be: * the projection height for pre-processing """ if isinstance(obj, NXtomoScan): if axis == 0: return obj.dim_1 elif axis in (1, 2): return obj.dim_2 elif isinstance(obj, VolumeBase) and axis == 0: return obj.get_volume_shape()[-1] else: raise TypeError(f"obj type ({type(obj)}) and axis == {axis} is not handled") class _StitcherBase: """ Any stitcher base class """ def __init__(self, configuration, progress=None) -> None: if not isinstance(configuration, SingleAxisStitchingConfiguration): raise TypeError # flag to check if the serie has been ordered yet or not self._configuration = copy(configuration) # copy configuration because we will edit it self._frame_composition = None self._progress = progress self._overlap_kernels = [] # kernels to create the stitching on overlaps. 
@property def serie_label(self) -> str: """return serie name for logs""" raise NotImplementedError("Base class") @property def reading_orders(self): """ as scan can be take on one direction or the order (rotation goes from X to Y then from Y to X) we might need to read data from one direction or another """ return self._reading_orders def order_input_tomo_objects(self): """ order inputs tomo objects """ raise NotImplementedError("Base class") def check_inputs(self): """ order inputs tomo objects """ raise NotImplementedError("Base class") def pre_processing_computation(self): """ some specific pre-processing that can be call before retrieving the data """ pass @staticmethod def param_is_auto(param): return param in ("auto", ("auto",)) def stitch(self, store_composition: bool = True) -> BaseIdentifier: """ Apply expected stitch from configuration and return the DataUrl of the object created :param bool store_composition: if True then store the composition used for stitching in frame_composition. So it can be reused by third part (like tomwer) to display composition made """ raise NotImplementedError("base class") @property def frame_composition(self): return self._frame_composition @staticmethod def from_abs_pos_to_rel_pos(abs_position: tuple): """ return relative position from on object to the other but in relative this time :param tuple abs_position: tuple containing the absolute positions :return: len(abs_position) - 1 relative position :rtype: tuple """ return tuple([pos_obj_b - pos_obj_a for (pos_obj_a, pos_obj_b) in zip(abs_position[:-1], abs_position[1:])]) @staticmethod def from_rel_pos_to_abs_pos(rel_positions: tuple, init_pos: int): """ return absolute positions from a tuple of relative position and an initial position :param tuple rel_positions: tuple containing the absolute positions :return: len(rel_positions) + 1 relative position :rtype: tuple """ abs_pos = [ init_pos, ] for rel_pos in rel_positions: abs_pos.append(abs_pos[-1] + rel_pos) return abs_pos def _compute_shifts(self): """ after this stage the final shifts must be determine """ raise NotImplementedError("base class") ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1734442985.524757 nabu-2024.2.1/nabu/stitching/stitcher/dumper/0000755000175000017500000000000014730277752020345 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/stitching/stitcher/dumper/__init__.py0000644000175000017500000000026014654107202022437 0ustar00pierrepierrefrom .postprocessing import PostProcessingStitchingDumper from .postprocessing import PostProcessingStitchingDumperNoDD from .preprocessing import PreProcessingStitchingDumper ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/stitching/stitcher/dumper/base.py0000644000175000017500000000656314654107202021626 0ustar00pierrepierreimport h5py import numpy from typing import Union, Optional from tomoscan.identifier import BaseIdentifier from nabu.stitching.config import StitchingConfiguration from tomoscan.volumebase import VolumeBase from contextlib import AbstractContextManager class DumperBase: """ Base class to define all the functions that can be used to save a stitching """ def __init__(self, configuration) -> None: assert isinstance(configuration, StitchingConfiguration) self._configuration = configuration @property def configuration(self): return self._configuration @property def output_identifier(self) -> 
BaseIdentifier: raise NotImplementedError("Base class") def save_stitched_frame( self, stitched_frame: numpy.ndarray, i_frame: int, axis: int, **kwargs, ): self.save_frame_to_disk( output_dataset=self.output_dataset, index=i_frame, stitched_frame=stitched_frame, axis=axis, region_start=0, region_end=None, ) @property def output_dataset(self) -> Optional[Union[h5py.VirtualLayout, h5py.Dataset, VolumeBase]]: return self._output_dataset @output_dataset.setter def output_dataset(self, dataset: Optional[Union[h5py.VirtualLayout, h5py.Dataset, VolumeBase]]): self._output_dataset = dataset @staticmethod def save_frame_to_disk( output_dataset: Union[h5py.Dataset, h5py.VirtualLayout], index: int, stitched_frame: Union[numpy.ndarray, h5py.VirtualSource], axis: int, region_start: int, region_end: int, ): if not isinstance(output_dataset, (h5py.VirtualLayout, h5py.Dataset, numpy.ndarray)): raise TypeError( f"'output_dataset' should be a 'h5py.Dataset' or a 'h5py.VirtualLayout'. Get {type(output_dataset)}" ) if not isinstance(stitched_frame, (h5py.VirtualSource, numpy.ndarray)): raise TypeError( f"'stitched_frame' should be a 'numpy.ndarray' or a 'h5py.VirtualSource'. Get {type(stitched_frame)}" ) if isinstance(output_dataset, h5py.VirtualLayout) and not isinstance(stitched_frame, h5py.VirtualSource): raise TypeError( "output_dataset is an instance of h5py.VirtualLayout and stitched_frame not an instance of h5py.VirtualSource" ) if axis == 0: if region_end is not None: output_dataset[index, region_start:region_end] = stitched_frame else: output_dataset[index, region_start:] = stitched_frame elif axis == 1: if region_end is not None: output_dataset[region_start:region_end, index, :] = stitched_frame else: output_dataset[region_start:, index, :] = stitched_frame elif axis == 2: if region_end is not None: output_dataset[region_start:region_end, :, index] = stitched_frame else: output_dataset[region_start:, :, index] = stitched_frame else: raise ValueError(f"provided axis ({axis}) is invalid") def create_output_dataset(self): """ function called at the beginning of the stitching to prepare output dataset """ raise NotImplementedError ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731053186.0 nabu-2024.2.1/nabu/stitching/stitcher/dumper/postprocessing.py0000644000175000017500000003402014713343202023761 0ustar00pierrepierreimport h5py import numpy import logging from typing import Optional from .base import DumperBase from nabu.stitching.config import PostProcessedSingleAxisStitchingConfiguration from nabu import version as nabu_version from nabu.io.writer import get_datetime from tomoscan.identifier import VolumeIdentifier from tomoscan.volumebase import VolumeBase from tomoscan.esrf.volume import HDF5Volume from tomoscan.io import HDF5File from contextlib import AbstractContextManager _logger = logging.getLogger(__name__) class OutputVolumeContext(AbstractContextManager): """ Utils class to Manage the data volume creation and save it (data only !). target: used for volume stitching In the case of HDF5 we want to save this directly in the file to avoid keeping the full volume in memory. 
Insure also contain processing will be common between the different processing If stitching_sources_arr_shapes is provided this mean that we want to create stitching region and then create a VDS to avoid data duplication """ def __init__( self, volume: VolumeBase, volume_shape: tuple, dtype: numpy.dtype, dumper, ) -> None: super().__init__() if not isinstance(volume, VolumeBase): raise TypeError(f"Volume is expected to be an instance of {VolumeBase}. {type(volume)} provided instead") self._volume = volume self._volume_shape = volume_shape self.__file_handler = None self._dtype = dtype self._dumper = dumper @property def _file_handler(self): return self.__file_handler def _build_hdf5_output(self): return self._file_handler.create_dataset( self._volume.data_url.data_path(), shape=self._volume_shape, dtype=self._dtype, ) def _create_stitched_volume_dataset(self): # handle the specific case of HDF5. Goal: avoid getting the full stitched volume in memory if isinstance(self._volume, HDF5Volume): self.__file_handler = HDF5File(self._volume.data_url.file_path(), mode="a") # if need to delete an existing dataset if self._volume.overwrite and self._volume.data_path in self._file_handler: try: del self._file_handler[self._volume.data_path] except Exception as e: _logger.error(f"Fail to overwrite data. Reason is {e}") data = None self._file_handler.close() self._duplicate_data = True # avoid creating a dataset for stitched volume as creation of the stitched_volume failed return data # create dataset try: data = self._build_hdf5_output() except Exception as e2: _logger.error(f"Fail to create final dataset. Reason is {e2}") data = None self._file_handler.close() self._duplicate_data = True # avoid creating a dataset for stitched volume as creation of the stitched_volume failed else: raise TypeError("only HDF5 output is handled") # else: # # for other file format: create the full dataset in memory before dumping it # data = numpy.empty(self._volume_shape, dtype=self._dtype) # self._volume.data = data return data def __enter__(self): assert self._dumper.output_dataset is None self._dumper.output_dataset = self._create_stitched_volume_dataset() return self._dumper.output_dataset def __exit__(self, exc_type, exc_value, traceback): if self._file_handler is not None: return self._file_handler.close() else: self._volume.save_data() class OutputVolumeNoDDContext(OutputVolumeContext): """ Dedicated output volume context for saving a volume without Data Duplication (DD) """ def __init__( self, volume: VolumeBase, volume_shape: tuple, dtype: numpy.dtype, dumper, stitching_sources_arr_shapes: Optional[tuple], ) -> None: if not isinstance(dumper, PostProcessingStitchingDumperNoDD): raise TypeError # TODO: compute volume_shape from here self._stitching_sources_arr_shapes = stitching_sources_arr_shapes super().__init__(volume, volume_shape, dtype, dumper) def __enter__(self): dataset = super().__enter__() assert isinstance(self._dumper, PostProcessingStitchingDumperNoDD) self._dumper.stitching_regions_hdf5_dataset = self._create_stitched_sub_region_datasets() return dataset def _build_hdf5_output(self): return h5py.VirtualLayout( shape=self._volume_shape, dtype=self._dtype, ) def __exit__(self, exc_type, exc_value, traceback): # in the case of no data duplication we need to create the virtual dataset at the end if not isinstance(self._dumper.output_dataset, h5py.VirtualLayout): raise TypeError("dumper output_dataset should be a virtual layout") 
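        # The stitched regions were written into dedicated per-region HDF5 datasets and referenced
        # through h5py.VirtualSource entries in this h5py.VirtualLayout; materializing the layout
        # below as an HDF5 virtual dataset makes the final stitched volume point to those source
        # datasets instead of duplicating their data.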
self._file_handler.create_virtual_dataset(self._volume.data_url.data_path(), layout=self._dumper.output_dataset) super().__exit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback) def _create_stitched_sub_region_datasets(self): # create datasets to store overlaps if needed if not isinstance(self._volume, HDF5Volume): raise TypeError("Avoid Data Duplication is only available for HDF5 output volume") stitching_regions_hdf5_dataset = [] for i_region, overlap_shape in enumerate(self._stitching_sources_arr_shapes): data_path = f"{self._volume.data_path}/stitching_regions/region_{i_region}" if self._volume.overwrite and data_path in self._file_handler: del self._file_handler[data_path] stitching_regions_hdf5_dataset.append( self._file_handler.create_dataset( name=data_path, shape=overlap_shape, dtype=self._dtype, ) ) self._dumper.stitching_regions_hdf5_dataset = stitching_regions_hdf5_dataset return stitching_regions_hdf5_dataset class PostProcessingStitchingDumper(DumperBase): """ dumper to be used when save data durint post-processing stitching (on recosntructed volume). Output is expected to be an NXtomo """ OutputDatasetContext = OutputVolumeContext def __init__(self, configuration) -> None: if not isinstance(configuration, PostProcessedSingleAxisStitchingConfiguration): raise TypeError( f"configuration is expected to be an instance of {PostProcessedSingleAxisStitchingConfiguration}. Get {type(configuration)} instead" ) super().__init__(configuration) self._output_dataset = None self._input_volumes = configuration.input_volumes def save_configuration(self): voxel_size = self._input_volumes[0].voxel_size def get_position(): # the z-serie is z-ordered from higher to lower. We can reuse this with pixel size and shape to # compute the position of the stitched volume if voxel_size is None: return None return numpy.array(self._input_volumes[0].position) + voxel_size * ( numpy.array(self._input_volumes[0].get_volume_shape()) / 2.0 - numpy.array(self.configuration.output_volume.get_volume_shape()) / 2.0 ) self.configuration.output_volume.voxel_size = voxel_size or "" try: self.configuration.output_volume.position = get_position() except Exception: self.configuration.output_volume.position = numpy.array([0, 0, 0]) self.configuration.output_volume.metadata.update( { "about": { "program": "nabu-stitching", "version": nabu_version, "date": get_datetime(), }, "configuration": self.configuration.to_dict(), } ) self.configuration.output_volume.save_metadata() @property def output_identifier(self) -> VolumeIdentifier: return self.configuration.output_volume.get_identifier() def create_output_dataset(self): """ function called at the beginning of the stitching to prepare output dataset """ self._dataset = h5py.VirtualLayout( shape=self._volume_shape, dtype=self._dtype, ) class PostProcessingStitchingDumperNoDD(PostProcessingStitchingDumper): """ same as PostProcessingStitchingDumper but prevent to do data duplication. In this case we need to work on HDF5 file only """ OutputDatasetContext = OutputVolumeNoDDContext def __init__(self, configuration) -> None: if not isinstance(configuration, PostProcessedSingleAxisStitchingConfiguration): raise TypeError( f"configuration is expected to be an instance of {PostProcessedSingleAxisStitchingConfiguration}. 
Get {type(configuration)} instead" ) super().__init__(configuration) self._stitching_regions_hdf5_dataset = None self._raw_regions_hdf5_dataset = None def create_output_dataset(self): """ function called at the beginning of the stitching to prepare output dataset """ self._dataset = h5py.VirtualLayout( shape=self._volume_shape, dtype=self._dtype, ) @staticmethod def create_subset_selection(dataset: h5py.Dataset, slices: tuple) -> h5py.VirtualSource: assert isinstance(dataset, h5py.Dataset), f"dataset is expected to be a h5py.Dataset. Get {type(dataset)}" assert isinstance(slices, tuple), f"slices is expected to be a tuple of slices. Get {type(slices)} instead" import h5py._hl.selections as selection virtual_source = h5py.VirtualSource(dataset) sel = selection.select(dataset.shape, slices, dataset=dataset) virtual_source.sel = sel return virtual_source @PostProcessingStitchingDumper.output_dataset.setter def output_dataset(self, dataset: Optional[h5py.VirtualLayout]): if dataset is not None and not isinstance(dataset, h5py.VirtualLayout): raise TypeError("in the case we want to avoid data duplication 'output_dataset' must be a VirtualLayout") self._output_dataset = dataset @property def stitching_regions_hdf5_dataset(self) -> Optional[tuple]: """hdf5 dataset storing the stitched regions""" return self._stitching_regions_hdf5_dataset @stitching_regions_hdf5_dataset.setter def stitching_regions_hdf5_dataset(self, datasets: tuple): self._stitching_regions_hdf5_dataset = datasets @property def raw_regions_hdf5_dataset(self) -> Optional[tuple]: """hdf5 raw dataset""" return self._raw_regions_hdf5_dataset @raw_regions_hdf5_dataset.setter def raw_regions_hdf5_dataset(self, datasets: tuple): self._raw_regions_hdf5_dataset = datasets def save_stitched_frame( self, stitched_frame: numpy.ndarray, composition_cls: dict, i_frame: int, axis: int, ): """ Save the full stitched frame to disk """ output_dataset = self.output_dataset if output_dataset is None: raise ValueError("output_dataset must be set before calling any frame stitching") stitching_regions_hdf5_dataset = self.stitching_regions_hdf5_dataset if stitching_regions_hdf5_dataset is None: raise ValueError("stitching_region_hdf5_dataset must be set before calling any frame stitching") raw_regions_hdf5_dataset = self.raw_regions_hdf5_dataset # save stitched region stitching_regions = composition_cls["overlap_composition"] for (_, _, region_start, region_end), stitching_region_hdf5_dataset in zip( stitching_regions.browse(), stitching_regions_hdf5_dataset ): assert isinstance(output_dataset, h5py.VirtualLayout) assert isinstance(stitching_region_hdf5_dataset, h5py.Dataset) stitching_region_array = stitched_frame[region_start:region_end] self.save_frame_to_disk( output_dataset=stitching_region_hdf5_dataset, index=i_frame, stitched_frame=stitching_region_array, axis=1, region_start=0, region_end=None, ) vs = self.create_subset_selection( dataset=stitching_region_hdf5_dataset, slices=( slice(0, stitching_region_hdf5_dataset.shape[0]), slice(i_frame, i_frame + 1), slice(0, stitching_region_hdf5_dataset.shape[2]), ), ) self.save_frame_to_disk( output_dataset=output_dataset, index=i_frame, axis=axis, region_start=region_start, region_end=region_end, stitched_frame=vs, ) # create virtual source of the raw data raw_regions = composition_cls["raw_composition"] for (frame_start, frame_end, region_start, region_end), raw_region_hdf5_dataset in zip( raw_regions.browse(), raw_regions_hdf5_dataset ): vs = self.create_subset_selection( 
dataset=raw_region_hdf5_dataset, slices=( slice(frame_start, frame_end), slice(i_frame, i_frame + 1), slice(0, raw_region_hdf5_dataset.shape[2]), ), ) self.save_frame_to_disk( output_dataset=output_dataset, index=i_frame, axis=1, region_start=region_start, region_end=region_end, stitched_frame=vs, ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/stitching/stitcher/dumper/preprocessing.py0000644000175000017500000000425314654107202023571 0ustar00pierrepierreimport h5py import numpy import logging from .base import DumperBase from nabu.stitching.config import PreProcessedSingleAxisStitchingConfiguration from nabu import version as nabu_version from nabu.io.writer import get_datetime from silx.io.dictdump import dicttonx from tomoscan.identifier import ScanIdentifier _logger = logging.getLogger(__name__) class PreProcessingStitchingDumper(DumperBase): """ dumper to be used when save data durint pre-processing stitching (on projections). Output is expected to be an NXtomo """ def __init__(self, configuration) -> None: if not isinstance(configuration, PreProcessedSingleAxisStitchingConfiguration): raise TypeError( f"configuration is expected to be an instance of {PreProcessedSingleAxisStitchingConfiguration}. Get {type(configuration)} instead" ) super().__init__(configuration) def save_frame_to_disk(self, output_dataset: h5py.Dataset, index: int, stitched_frame: numpy.ndarray, **kwargs): output_dataset[index] = stitched_frame def save_configuration(self): """dump configuration used for stitching at the NXtomo entry""" process_name = "stitching_configuration" config_dict = self.configuration.to_dict() # adding nabu specific information nabu_process_info = { "@NX_class": "NXentry", f"{process_name}@NX_class": "NXprocess", f"{process_name}/program": "nabu-stitching", f"{process_name}/version": nabu_version, f"{process_name}/date": get_datetime(), f"{process_name}/configuration": config_dict, } dicttonx( nabu_process_info, h5file=self.configuration.output_file_path, h5path=self.configuration.output_data_path, update_mode="replace", mode="a", ) @property def output_identifier(self) -> ScanIdentifier: return self.configuration.get_output_object().get_identifier() def create_output_dataset(self): """ function called at the beginning of the stitching to prepare output dataset """ raise NotImplementedError ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731053186.0 nabu-2024.2.1/nabu/stitching/stitcher/post_processing.py0000644000175000017500000006305214713343202022633 0ustar00pierrepierreimport logging import numpy import os import h5py from typing import Union from nabu.stitching.config import PostProcessedSingleAxisStitchingConfiguration from nabu.stitching.alignment import AlignmentAxis1 from nabu.stitching.alignment import PaddedRawData from math import ceil from tomoscan.io import HDF5File from tomoscan.esrf.scan.utils import cwd_context from tomoscan.esrf import NXtomoScan from tomoscan.series import Series from tomoscan.volumebase import VolumeBase from tomoscan.esrf.volume import HDF5Volume from typing import Iterable from contextlib import AbstractContextManager from pyunitsystem.metricsystem import MetricSystem from nabu.stitching.config import ( PostProcessedSingleAxisStitchingConfiguration, KEY_IMG_REG_METHOD, ) from nabu.stitching.utils.utils import find_volumes_relative_shifts from nabu.io.utils import DatasetReader from .single_axis import SingleAxisStitcher _logger = 
logging.getLogger(__name__) class FlippingValueError(ValueError): pass class PostProcessingStitching(SingleAxisStitcher): """ Loader to be used when load data during post-processing stitching (on recosntructed volume). Output is expected to be an NXtomo """ def __init__(self, configuration, progress=None) -> None: if not isinstance(configuration, PostProcessedSingleAxisStitchingConfiguration): raise TypeError( f"configuration is expected to be an instance of {PostProcessedSingleAxisStitchingConfiguration}. Get {type(configuration)} instead" ) self._input_volumes = configuration.input_volumes self.__output_data_type = None self._series = Series("series", iterable=self._input_volumes, use_identifiers=False) super().__init__(configuration, progress=progress) @property def stitching_axis_in_frame_space(self): if self.axis == 0: return 0 elif self.axis in (1, 2): raise NotImplementedError(f"post-processing stitching along axis {self.axis} is not handled.") else: raise NotImplementedError(f"stitching axis must be in (0, 1, 2). Get {self.axis}") def settle_flips(self): super().settle_flips() if not self.configuration.duplicate_data: if len(numpy.unique(self.configuration.flip_lr)) > 1: raise FlippingValueError( "Stitching without data duplication cannot handle volume with different flip. Please run the stitching with data duplication" ) if True in self.configuration.flip_ud: raise FlippingValueError( "Stitching without data duplication cannot handle with up / down flips. Please run the stitching with data duplication" ) def order_input_tomo_objects(self): def get_min_bound(volume): try: bb = volume.get_bounding_box(axis=self.axis) except ValueError: # if missing information bb = None if bb is not None: return bb.min else: # if can't find bounding box (missing metadata to the volume # try to get it from the scan metadata = volume.metadata or volume.load_metadata() scan_location = metadata.get("nabu_config", {}).get("dataset", {}).get("location", None) scan_entry = metadata.get("nabu_config", {}).get("dataset", {}).get("hdf5_entry", None) if scan_location is not None: # this work around (until most volume have position metadata) works only for Hdf5volume with cwd_context(os.path.dirname(volume.file_path)): o_scan = NXtomoScan(scan_location, scan_entry) bb_acqui = o_scan.get_bounding_box(axis=None) # for next step volume position will be required. # if you can find it set it directly volume.position = (numpy.array(bb_acqui.max) - numpy.array(bb_acqui.min)) / 2.0 + numpy.array( bb_acqui.min ) # for now translation are stored in pixel size ref instead of real_pixel_size volume.pixel_size = o_scan.x_real_pixel_size if bb_acqui is not None: return bb_acqui.min[0] raise ValueError("Unable to find volume position. Unable to deduce z position") try: # order volumes from higher z to lower z # if axis 0 position is provided then use directly it if self.configuration.axis_0_pos_px is not None and len(self.configuration.axis_0_pos_px) > 0: order = numpy.argsort(self.configuration.axis_0_pos_px) sorted_series = Series( self.series.name, numpy.take_along_axis(numpy.array(self.series[:]), order, axis=0)[::-1], use_identifiers=False, ) else: # else use bounding box sorted_series = Series( self.series.name, sorted(self.series[:], key=get_min_bound, reverse=True), use_identifiers=False, ) except ValueError: _logger.warning( "Unable to find volume positions in metadata. 
Expect the volume to be ordered already (decreasing along axis 0)." ) else: if sorted_series == self.series: pass elif sorted_series != self.series: if sorted_series[:] != self.series[::-1]: raise ValueError( "Unable to get comprehensive input. Ordering along axis 0 is not respected (decreasing)." ) else: _logger.warning( f"decreasing order hasn't been respected. Need to reorder {self.serie_label} ({[str(scan) for scan in sorted_series[:]]}). Will also reorder positions" ) if self.configuration.axis_0_pos_mm is not None: self.configuration.axis_0_pos_mm = self.configuration.axis_0_pos_mm[::-1] if self.configuration.axis_0_pos_px is not None: self.configuration.axis_0_pos_px = self.configuration.axis_0_pos_px[::-1] if self.configuration.axis_1_pos_mm is not None: self.configuration.axis_1_pos_mm = self.configuration.axis_1_pos_mm[::-1] if self.configuration.axis_1_pos_px is not None: self.configuration.axis_1_pos_px = self.configuration.axis_1_pos_px[::-1] if self.configuration.axis_2_pos_mm is not None: self.configuration.axis_2_pos_mm = self.configuration.axis_2_pos_mm[::-1] if self.configuration.axis_2_pos_px is not None: self.configuration.axis_2_pos_px = self.configuration.axis_2_pos_px[::-1] if not numpy.isscalar(self._configuration.flip_ud): self._configuration.flip_ud = self._configuration.flip_ud[::-1] if not numpy.isscalar(self._configuration.flip_lr): self._configuration.flip_lr = self._configuration.flip_lr[::-1] self._series = sorted_series def check_inputs(self): """ ensure input data is coherent """ # check output volume if self.configuration.output_volume is None: raise ValueError("output volume should be provided") n_volumes = len(self.series) if n_volumes == 0: raise ValueError("no scan to stitch together") if not isinstance(self.configuration.output_volume, VolumeBase): raise TypeError(f"output_volume is expected to be an instance of {VolumeBase}, not {type(self.configuration.output_volume)}") # check axis 0 position if isinstance(self.configuration.axis_0_pos_px, Iterable) and len(self.configuration.axis_0_pos_px) != ( n_volumes ): raise ValueError(f"expect {n_volumes} overlap defined. Get {len(self.configuration.axis_0_pos_px)}") if isinstance(self.configuration.axis_0_pos_mm, Iterable) and len(self.configuration.axis_0_pos_mm) != ( n_volumes ): raise ValueError(f"expect {n_volumes} overlap defined. Get {len(self.configuration.axis_0_pos_mm)}") # check axis 1 position if isinstance(self.configuration.axis_1_pos_px, Iterable) and len(self.configuration.axis_1_pos_px) != ( n_volumes ): raise ValueError(f"expect {n_volumes} overlap defined. Get {len(self.configuration.axis_1_pos_px)}") if isinstance(self.configuration.axis_1_pos_mm, Iterable) and len(self.configuration.axis_1_pos_mm) != ( n_volumes ): raise ValueError(f"expect {n_volumes} overlap defined. Get {len(self.configuration.axis_1_pos_mm)}") # check axis 2 position if isinstance(self.configuration.axis_2_pos_px, Iterable) and len(self.configuration.axis_2_pos_px) != ( n_volumes ): raise ValueError(f"expect {n_volumes} overlap defined. Get {len(self.configuration.axis_2_pos_px)}") if isinstance(self.configuration.axis_2_pos_mm, Iterable) and len(self.configuration.axis_2_pos_mm) != ( n_volumes ): raise ValueError(f"expect {n_volumes} overlap defined. Get {len(self.configuration.axis_2_pos_mm)}") self._reading_orders = [] # the first scan will define the expected reading order, and expected flip. 
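# Illustrative sketch (not from the nabu code base): the re-ordering performed above sorts the
# input objects by decreasing position along the stitching axis with numpy.argsort and
# numpy.take_along_axis. The helper name and values below are hypothetical.
import numpy

def sort_by_decreasing_position(objects, positions_px):
    """Return `objects` re-ordered so that `positions_px` is decreasing."""
    order = numpy.argsort(positions_px)  # ascending order of the positions
    return list(numpy.take_along_axis(numpy.array(objects, dtype=object), order, axis=0)[::-1])

# three stages whose centers sit at 0, 120 and 60 px along axis 0
assert sort_by_decreasing_position(["low", "high", "mid"], numpy.array([0, 120, 60])) == ["high", "mid", "low"]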
# if all scan are flipped then we will keep it this way self._reading_orders.append(1) @staticmethod def _get_bunch_of_data( bunch_start: int, bunch_end: int, step: int, volumes: tuple, flip_lr_arr: bool, flip_ud_arr: bool, ): """ goal is to load contiguous frames as much as possible... return for each volume the bunch of slice along axis 1 warning: they can have different shapes """ def get_sub_volume(volume, flip_lr, flip_ud): sub_volume = volume[:, bunch_start:bunch_end:step, :] if flip_lr: sub_volume = numpy.fliplr(sub_volume) if flip_ud: sub_volume = numpy.flipud(sub_volume) return sub_volume sub_volumes = [ get_sub_volume(volume, flip_lr, flip_ud) for volume, flip_lr, flip_ud in zip(volumes, flip_lr_arr, flip_ud_arr) ] # generator on it self: we want to iterate over the y axis n_slices_in_bunch = ceil((bunch_end - bunch_start) / step) assert isinstance(n_slices_in_bunch, int) for i in range(n_slices_in_bunch): yield [sub_volume[:, i, :] for sub_volume in sub_volumes] def compute_estimated_shifts(self): axis_0_pos_px = self.configuration.axis_0_pos_px self._axis_0_rel_ini_shifts = [] # compute overlap along axis 0 for upper_volume, lower_volume, upper_volume_axis_0_pos, lower_volume_axis_0_pos in zip( self.series[:-1], self.series[1:], axis_0_pos_px[:-1], axis_0_pos_px[1:] ): upper_volume_low_pos = upper_volume_axis_0_pos - upper_volume.get_volume_shape()[0] / 2 lower_volume_high_pos = lower_volume_axis_0_pos + lower_volume.get_volume_shape()[0] / 2 self._axis_0_rel_ini_shifts.append( int(lower_volume_high_pos - upper_volume_low_pos) # overlap are expected to be int for now ) self._axis_1_rel_ini_shifts = self.from_abs_pos_to_rel_pos(self.configuration.axis_1_pos_px) self._axis_2_rel_ini_shifts = [0.0] * (len(self.series) - 1) def _compute_positions_as_px(self): """compute if necessary position other axis 0 from volume metadata""" def get_position_as_px_on_axis(axis, pos_as_px, pos_as_mm): if pos_as_px is not None: if pos_as_mm is not None: raise ValueError( f"position of axis {axis} is provided twice: as mm and as px. 
Please provide one only ({pos_as_mm} vs {pos_as_px})" ) else: return pos_as_px elif pos_as_mm is not None: # deduce from position given in configuration and pixel size axis_N_pos_px = [] for volume, pos_in_mm in zip(self.series, pos_as_mm): voxel_size_m = self.configuration.voxel_size or volume.voxel_size axis_N_pos_px.append((pos_in_mm / MetricSystem.MILLIMETER.value) / voxel_size_m[0]) return axis_N_pos_px else: # deduce from motor position and pixel size axis_N_pos_px = [] base_position_m = self.series[0].get_bounding_box(axis=axis).min for volume in self.series: voxel_size_m = self.configuration.voxel_size or volume.voxel_size volume_axis_bb = volume.get_bounding_box(axis=axis) axis_N_mean_pos_m = (volume_axis_bb.max - volume_axis_bb.min) / 2 + volume_axis_bb.min axis_N_mean_rel_pos_m = axis_N_mean_pos_m - base_position_m axis_N_pos_px.append(int(axis_N_mean_rel_pos_m / voxel_size_m[0])) return axis_N_pos_px self.configuration.axis_0_pos_px = get_position_as_px_on_axis( axis=0, pos_as_px=self.configuration.axis_0_pos_px, pos_as_mm=self.configuration.axis_0_pos_mm, ) self.configuration.axis_0_pos_mm = None self.configuration.axis_1_pos_px = get_position_as_px_on_axis( axis=1, pos_as_px=self.configuration.axis_1_pos_px, pos_as_mm=self.configuration.axis_1_pos_mm, ) self.configuration.axis_2_pos_px = get_position_as_px_on_axis( axis=2, pos_as_px=self.configuration.axis_2_pos_px, pos_as_mm=self.configuration.axis_2_pos_mm, ) self.configuration.axis_2_pos_mm = None def _compute_shifts(self): n_volumes = len(self.configuration.input_volumes) if n_volumes == 0: raise ValueError("no scan to stich provided") slice_for_shift = self.configuration.slice_for_cross_correlation or "middle" y_rel_shifts = self._axis_0_rel_ini_shifts x_rel_shifts = self._axis_1_rel_ini_shifts dim_axis_1 = max([volume.get_volume_shape()[1] for volume in self.series]) final_rel_shifts = [] for ( upper_volume, lower_volume, x_rel_shift, y_rel_shift, flip_ud_upper, flip_ud_lower, ) in zip( self.series[:-1], self.series[1:], x_rel_shifts, y_rel_shifts, self.configuration.flip_ud[:-1], self.configuration.flip_ud[1:], ): x_cross_algo = self.configuration.axis_2_params.get(KEY_IMG_REG_METHOD, None) y_cross_algo = self.configuration.axis_0_params.get(KEY_IMG_REG_METHOD, None) # compute relative shift found_shift_y, found_shift_x = find_volumes_relative_shifts( upper_volume=upper_volume, lower_volume=lower_volume, dtype=self.get_output_data_type(), dim_axis_1=dim_axis_1, slice_for_shift=slice_for_shift, x_cross_correlation_function=x_cross_algo, y_cross_correlation_function=y_cross_algo, x_shifts_params=self.configuration.axis_2_params, y_shifts_params=self.configuration.axis_0_params, estimated_shifts=(y_rel_shift, x_rel_shift), flip_ud_lower_frame=flip_ud_lower, flip_ud_upper_frame=flip_ud_upper, alignment_axis_1=self.configuration.alignment_axis_1, alignment_axis_2=self.configuration.alignment_axis_2, overlap_axis=self.axis, ) final_rel_shifts.append( (found_shift_y, found_shift_x), ) # set back values. 
Now position should start at 0 self._axis_0_rel_final_shifts = [final_shift[0] for final_shift in final_rel_shifts] self._axis_1_rel_final_shifts = [final_shift[1] for final_shift in final_rel_shifts] self._axis_2_rel_final_shifts = [0.0] * len(final_rel_shifts) _logger.info(f"axis 2 relative shifts (x in radio ref) to be used will be {self._axis_1_rel_final_shifts}") print(f"axis 2 relative shifts (x in radio ref) to be used will be {self._axis_1_rel_final_shifts}") _logger.info(f"axis 0 relative shifts (y in radio ref) y to be used will be {self._axis_0_rel_final_shifts}") print(f"axis 0 relative shifts (y in radio ref) y to be used will be {self._axis_0_rel_final_shifts}") def get_output_data_type(self): if self.__output_data_type is None: def find_output_data_type(): first_vol = self._input_volumes[0] if first_vol.data is not None: return first_vol.data.dtype elif isinstance(first_vol, HDF5Volume): with DatasetReader(first_vol.data_url) as vol_dataset: return vol_dataset.dtype else: return first_vol.load_data(store=False).dtype self.__output_data_type = find_output_data_type() return self.__output_data_type def _create_stitched_volume(self, store_composition: bool): overlap_kernels = self._overlap_kernels self._slices_to_stitch, n_slices = self.configuration.settle_slices() # sync overwrite_results with volume overwrite parameter self.configuration.output_volume.overwrite = self.configuration.overwrite_results # init final volume final_volume = self.configuration.output_volume final_volume_shape = ( int( numpy.asarray([volume.get_volume_shape()[0] for volume in self._input_volumes]).sum() - numpy.asarray([abs(overlap) for overlap in self._axis_0_rel_final_shifts]).sum(), ), n_slices, self._stitching_constant_length, ) data_type = self.get_output_data_type() if self.progress: self.progress.total = final_volume_shape[1] y_index = 0 if isinstance(self._slices_to_stitch, slice): step = self._slices_to_stitch.step or 1 else: step = 1 output_dataset_args = { "volume": final_volume, "volume_shape": final_volume_shape, "dtype": data_type, "dumper": self.dumper, } from .dumper.postprocessing import PostProcessingStitchingDumperNoDD # TODO: FIXME: for now not very elegant but in the case of avoiding data duplication # we need to provide the the information about the stitched part shape. # this should be move to the dumper in the future if isinstance(self.dumper, PostProcessingStitchingDumperNoDD): output_dataset_args["stitching_sources_arr_shapes"] = tuple( [(abs(overlap), n_slices, self._stitching_constant_length) for overlap in self._axis_0_rel_final_shifts] ) with self.dumper.OutputDatasetContext(**output_dataset_args): # note: output_dataset is a HDF5 dataset if final volume is an HDF5 volume else is a numpy array with _RawDatasetsContext( self._input_volumes, alignment_axis_1=self.configuration.alignment_axis_1, ) as raw_datasets: # note: raw_datasets can be numpy arrays or HDF5 dataset (in the case of HDF5Volume) # to speed up we read by bunch of dataset. 
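# Illustrative sketch (not from the nabu code base): the `_data_bunch_iterator` call below walks
# the requested slices by bunches of contiguous indices so that HDF5 chunks are read efficiently.
# The simplified helper only mirrors the expected (bunch_start, bunch_end) pairs for a plain
# slice; the actual nabu helper may differ.
def iter_bunches(slices: slice, bunch_size: int):
    start = slices.start or 0
    while start < slices.stop:
        end = min(start + bunch_size, slices.stop)
        yield start, end  # half-open [start, end) range of slice indices
        start = end

# stitching slices 0..120 by bunches of 50 gives (0, 50), (50, 100), (100, 120)
assert list(iter_bunches(slice(0, 120), 50)) == [(0, 50), (50, 100), (100, 120)]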
For numpy array this doesn't change anything # but for HDF5 dataset this can speed up a lot the processing (depending on HDF5 dataset chuncks) # note: we read trhough axis 1 if isinstance(self.dumper, PostProcessingStitchingDumperNoDD): self.dumper.raw_regions_hdf5_dataset = raw_datasets for bunch_start, bunch_end in PostProcessingStitching._data_bunch_iterator( slices=self._slices_to_stitch, bunch_size=50 ): for data_frames in PostProcessingStitching._get_bunch_of_data( bunch_start, bunch_end, step=step, volumes=raw_datasets, flip_lr_arr=self.configuration.flip_lr, flip_ud_arr=self.configuration.flip_ud, ): if self.configuration.rescale_frames: data_frames = self.rescale_frames(data_frames) if self.configuration.normalization_by_sample.is_active(): data_frames = self.normalize_frame_by_sample(data_frames) sf = PostProcessingStitching.stitch_frames( frames=data_frames, axis=self.axis, output_dtype=data_type, x_relative_shifts=self._axis_1_rel_final_shifts, y_relative_shifts=self._axis_0_rel_final_shifts, overlap_kernels=overlap_kernels, dumper=self.dumper, i_frame=y_index, return_composition_cls=store_composition if y_index == 0 else False, stitching_axis=self.axis, check_inputs=y_index == 0, # on process check on the first iteration ) if y_index == 0 and store_composition: _, self._frame_composition = sf if self.progress is not None: self.progress.update() y_index += 1 # alias to general API def _create_stitching(self, store_composition): self._create_stitched_volume(store_composition=store_composition) class _RawDatasetsContext(AbstractContextManager): """ return volume data for all input volume (target: used for volume stitching). If the volume is an HDF5Volume then the HDF5 dataset will be used (on disk) If the volume is of another type then it will be loaded in memory then used (more memory consuming) """ def __init__(self, volumes: tuple, alignment_axis_1) -> None: super().__init__() for volume in volumes: if not isinstance(volume, VolumeBase): raise TypeError( f"Volumes are expected to be an instance of {VolumeBase}. {type(volume)} provided instead" ) self._volumes = volumes self.__file_handlers = [] self._alignment_axis_1 = alignment_axis_1 @property def alignment_axis_1(self): return self._alignment_axis_1 def __enter__(self): # handle the specific case of HDF5. 
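# Illustrative sketch (not from the nabu code base): the HDF5 branch below keeps an open
# h5py.Dataset handle instead of loading the whole reconstructed volume in memory, so only the
# requested slabs are read from disk. File and data paths here are hypothetical.
import h5py
import numpy

with h5py.File("stage_0_rec.h5", "w") as h5f:  # build a small stand-in volume file
    h5f["entry/reconstruction/results/data"] = numpy.zeros((64, 32, 32), dtype="float32")
with h5py.File("stage_0_rec.h5", "r") as h5f:
    dataset = h5f["entry/reconstruction/results/data"]  # lazy h5py.Dataset, stays on disk
    slab = dataset[:, 10:15, :]  # only this slab is actually read
    print(type(dataset), slab.shape)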
Goal: avoid getting the full stitched volume in memory datasets = [] shapes = {volume.get_volume_shape()[1] for volume in self._volumes} axis_1_dim = max(shapes) axis_1_need_padding = len(shapes) > 1 try: for volume in self._volumes: if volume.data is not None: data = volume.data elif isinstance(volume, HDF5Volume): file_handler = HDF5File(volume.data_url.file_path(), mode="r") dataset = file_handler[volume.data_url.data_path()] data = dataset self.__file_handlers.append(file_handler) # for other file format: load the full dataset in memory else: data = volume.load_data(store=False) if data is None: raise ValueError(f"No data found for volume {volume.get_identifier()}") if axis_1_need_padding: data = self.add_padding(data=data, axis_1_dim=axis_1_dim, alignment=self.alignment_axis_1) datasets.append(data) except Exception as e: # if some errors happen during loading HDF5 for file_handled in self.__file_handlers: file_handled.close() raise e return datasets def __exit__(self, exc_type, exc_value, traceback): success = True for file_handler in self.__file_handlers: success = success and file_handler.close() if exc_type is None: return success def add_padding(self, data: Union[h5py.Dataset, numpy.ndarray], axis_1_dim, alignment: AlignmentAxis1): alignment = AlignmentAxis1.from_value(alignment) if alignment is AlignmentAxis1.BACK: axis_1_pad_width = (axis_1_dim - data.shape[1], 0) elif alignment is AlignmentAxis1.CENTER: half_width = int((axis_1_dim - data.shape[1]) / 2) axis_1_pad_width = (half_width, axis_1_dim - data.shape[1] - half_width) elif alignment is AlignmentAxis1.FRONT: axis_1_pad_width = (0, axis_1_dim - data.shape[1]) else: raise ValueError(f"alignment {alignment} is not handled") return PaddedRawData( data=data, axis_1_pad_width=axis_1_pad_width, ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731941746.0 nabu-2024.2.1/nabu/stitching/stitcher/pre_processing.py0000644000175000017500000014431414716652562022454 0ustar00pierrepierreimport numpy import logging import h5py import os from typing import Iterable from silx.io.url import DataUrl from silx.io.utils import get_data from datetime import datetime from nxtomo.nxobject.nxdetector import ImageKey from nxtomo.application.nxtomo import NXtomo from nxtomo.nxobject.nxtransformations import NXtransformations from nxtomo.utils.transformation import build_matrix, DetYFlipTransformation, DetZFlipTransformation from nxtomo.paths.nxtomo import get_paths as _get_nexus_paths from tomoscan.io import HDF5File from tomoscan.series import Series from tomoscan.esrf import NXtomoScan, EDFTomoScan from tomoscan.esrf.scan.utils import ( get_compacted_dataslices, ) # this version has a 'return_url_set' needed here. At one point they should be merged together from nabu.stitching.config import ( PreProcessedSingleAxisStitchingConfiguration, KEY_IMG_REG_METHOD, ) from nabu.stitching.utils import find_projections_relative_shifts from functools import lru_cache as cache from .single_axis import SingleAxisStitcher from pyunitsystem.metricsystem import MetricSystem _logger = logging.getLogger(__name__) class PreProcessingStitching(SingleAxisStitcher): """ loader to be used when save data during pre-processing stitching (on projections). 
Output is expected to be an NXtomo warning: axis are provided according to the `acquisition space `_ """ def __init__(self, configuration, progress=None) -> None: """ """ if not isinstance(configuration, PreProcessedSingleAxisStitchingConfiguration): raise TypeError( f"configuration is expected to be an instance of {PreProcessedSingleAxisStitchingConfiguration}. Get {type(configuration)} instead" ) super().__init__(configuration, progress=progress) self._series = Series("series", iterable=configuration.input_scans, use_identifiers=False) self._reading_orders = [] # TODO: rename flips to axis_0_flips, axis_1_flips, axis_2_flips... self._x_flips = [] self._y_flips = [] self._z_flips = [] # 'expend' auto shift request if only set once for all if numpy.isscalar(self.configuration.axis_0_pos_px): self.configuration.axis_0_pos_px = [ self.configuration.axis_0_pos_px, ] * (len(self.series) - 1) if numpy.isscalar(self.configuration.axis_1_pos_px): self.configuration.axis_1_pos_px = [ self.configuration.axis_1_pos_px, ] * (len(self.series) - 1) if numpy.isscalar(self.configuration.axis_1_pos_px): self.configuration.axis_1_pos_px = [ self.configuration.axis_1_pos_px, ] * (len(self.series) - 1) if self.configuration.axis_0_params is None: self.configuration.axis_0_params = {} if self.configuration.axis_1_params is None: self.configuration.axis_1_params = {} if self.configuration.axis_2_params is None: self.configuration.axis_2_params = {} def pre_processing_computation(self): self.compute_reduced_flats_and_darks() @property def stitching_axis_in_frame_space(self): if self.axis == 0: return 0 elif self.axis == 1: return 1 elif self.axis == 2: raise NotImplementedError( "pre-processing stitching along axis 2 is not handled. This would require to do interpolation between frame along the rotation angle. Just not possible" ) else: raise NotImplementedError(f"stitching axis must be in (0, 1, 2). Get {self.axis}") @property def x_flips(self) -> list: return self._x_flips @property def y_flips(self) -> list: return self._y_flips def order_input_tomo_objects(self): def get_min_bound(scan): return scan.get_bounding_box(axis=self.axis).min # order scans along the stitched axis if self.axis == 0: position_along_stitched_axis = self.configuration.axis_0_pos_px elif self.axis == 1: position_along_stitched_axis = self.configuration.axis_1_pos_px else: raise ValueError( "stitching cannot be done along axis 2 for pre-processing. This would require to interpolate frame between different rotation angle" ) # if axis 0 position is provided then use directly it if position_along_stitched_axis is not None and len(position_along_stitched_axis) > 0: order = numpy.argsort(position_along_stitched_axis)[::-1] sorted_series = Series( self.series.name, numpy.take_along_axis(numpy.array(self.series[:]), order, axis=0), use_identifiers=False, ) else: # else use bounding box sorted_series = Series( self.series.name, sorted(self.series[:], key=get_min_bound, reverse=True), use_identifiers=False, ) if sorted_series != self.series: if sorted_series[:] != self.series[::-1]: raise ValueError( f"Unable to get comprehensive input. Axis {self.axis} (decreasing) ordering is not respected." ) else: _logger.warning( f"decreasing order haven't been respected. Need to reorder {self.serie_label} ({[str(scan) for scan in sorted_series[:]]}). 
Will also reorder overlap height, stitching height and invert shifts" ) if self.configuration.axis_0_pos_mm is not None: self.configuration.axis_0_pos_mm = self.configuration.axis_0_pos_mm[::-1] if self.configuration.axis_0_pos_px is not None: self.configuration.axis_0_pos_px = self.configuration.axis_0_pos_px[::-1] if self.configuration.axis_1_pos_mm is not None: self.configuration.axis_1_pos_mm = self.configuration.axis_1_pos_mm[::-1] if self.configuration.axis_1_pos_px is not None: self.configuration.axis_1_pos_px = self.configuration.axis_1_pos_px[::-1] if self.configuration.axis_2_pos_mm is not None: self.configuration.axis_2_pos_mm = self.configuration.axis_2_pos_mm[::-1] if self.configuration.axis_2_pos_px is not None: self.configuration.axis_2_pos_px = self.configuration.axis_2_pos_px[::-1] if not numpy.isscalar(self._configuration.flip_ud): self._configuration.flip_ud = self._configuration.flip_ud[::-1] if not numpy.isscalar(self._configuration.flip_lr): self._configuration.flip_ud = self._configuration.flip_lr[::-1] self._series = sorted_series def check_inputs(self): """ insure input data is coherent """ n_scans = len(self.series) if n_scans == 0: raise ValueError("no scan to stich together") for scan in self.series: from tomoscan.scanbase import TomoScanBase if not isinstance(scan, TomoScanBase): raise TypeError(f"z-preproc stitching expects instances of {TomoScanBase}. {type(scan)} provided.") # check output file path and data path are provided if self.configuration.output_file_path in (None, ""): raise ValueError("output_file_path should be provided to the configuration") if self.configuration.output_data_path in (None, ""): raise ValueError("output_data_path should be provided to the configuration") # check number of shift provided for axis_pos_px, axis_name in zip( ( self.configuration.axis_0_pos_px, self.configuration.axis_1_pos_px, self.configuration.axis_1_pos_px, self.configuration.axis_0_pos_mm, self.configuration.axis_1_pos_mm, self.configuration.axis_2_pos_mm, ), ( "axis_0_pos_px", "axis_1_pos_px", "axis_2_pos_px", "axis_0_pos_mm", "axis_1_pos_mm", "axis_2_pos_mm", ), ): if isinstance(axis_pos_px, Iterable) and len(axis_pos_px) != (n_scans): raise ValueError(f"{axis_name} expect {n_scans} shift defined. Get {len(axis_pos_px)}") self._reading_orders = [] # the first scan will define the expected reading orderd, and expected flip. # if all scan are flipped then we will keep it this way self._reading_orders.append(1) # check scans are coherent (nb projections, rotation angle, energy...) 
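# Illustrative sketch (not from the nabu code base) of the coherence test run in the loop below:
# two scans are stitchable if their projection angles match, possibly in reverse order, in which
# case the reading order of the second scan is inverted.
import numpy

def relative_reading_order(angles_0, angles_1, atol=1e-1):
    angles_0, angles_1 = numpy.asarray(angles_0), numpy.asarray(angles_1)
    if numpy.allclose(angles_0, angles_1, atol=atol):
        return 1  # same acquisition direction
    if numpy.allclose(angles_0, angles_1[::-1], atol=atol):
        return -1  # second scan acquired the other way around
    raise ValueError("projection angles differ: scans cannot be stitched")

assert relative_reading_order([0, 90, 180], [180, 90, 0]) == -1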
for scan_0, scan_1 in zip(self.series[0:-1], self.series[1:]): if len(scan_0.projections) != len(scan_1.projections): raise ValueError(f"{scan_0} and {scan_1} have a different number of projections") if isinstance(scan_0, NXtomoScan) and isinstance(scan_1, NXtomoScan): # check rotation (only of is an NXtomoScan) scan_0_angles = numpy.asarray(scan_0.rotation_angle) scan_0_projections_angles = scan_0_angles[ numpy.asarray(scan_0.image_key_control) == ImageKey.PROJECTION.value ] scan_1_angles = numpy.asarray(scan_1.rotation_angle) scan_1_projections_angles = scan_1_angles[ numpy.asarray(scan_1.image_key_control) == ImageKey.PROJECTION.value ] if not numpy.allclose(scan_0_projections_angles, scan_1_projections_angles, atol=10e-1): if numpy.allclose( scan_0_projections_angles, scan_1_projections_angles[::-1], atol=10e-1, ): reading_order = -1 * self._reading_orders[-1] else: raise ValueError(f"Angles from {scan_0} and {scan_1} are different") else: reading_order = 1 * self._reading_orders[-1] self._reading_orders.append(reading_order) # check energy if scan_0.energy is None: _logger.warning(f"no energy found for {scan_0}") elif not numpy.isclose(scan_0.energy, scan_1.energy, rtol=1e-03): _logger.warning( f"different energy found between {scan_0} ({scan_0.energy}) and {scan_1} ({scan_1.energy})" ) # check FOV if not scan_0.field_of_view == scan_1.field_of_view: raise ValueError(f"{scan_0} and {scan_1} have different field of view") # check distance if scan_0.distance is None: _logger.warning(f"no distance found for {scan_0}") elif not numpy.isclose(scan_0.distance, scan_1.distance, rtol=10e-3): raise ValueError(f"{scan_0} and {scan_1} have different sample / detector distance") # check pixel size if not numpy.isclose(scan_0.x_pixel_size, scan_1.x_pixel_size): raise ValueError( f"{scan_0} and {scan_1} have different x pixel size. {scan_0.x_pixel_size} vs {scan_1.x_pixel_size}" ) if not numpy.isclose(scan_0.y_pixel_size, scan_1.y_pixel_size): raise ValueError( f"{scan_0} and {scan_1} have different y pixel size. {scan_0.y_pixel_size} vs {scan_1.y_pixel_size}" ) for scan in self.series: # check x, y and z translation are constant (only if is an NXtomoScan) if isinstance(scan, NXtomoScan): if scan.x_translation is not None and not numpy.isclose( min(scan.x_translation), max(scan.x_translation) ): _logger.warning( "x translations appears to be evolving over time. Might end up with wrong stitching" ) if scan.y_translation is not None and not numpy.isclose( min(scan.y_translation), max(scan.y_translation) ): _logger.warning( "y translations appears to be evolving over time. Might end up with wrong stitching" ) if scan.z_translation is not None and not numpy.isclose( min(scan.z_translation), max(scan.z_translation) ): _logger.warning( "z translations appears to be evolving over time. Might end up with wrong stitching" ) def _compute_positions_as_px(self): """insure we have or we can deduce an estimated position as pixel""" def get_position_as_px_on_axis(axis, pos_as_px, pos_as_mm): if pos_as_px is not None: if pos_as_mm is not None: raise ValueError( f"position of axis {axis} is provided twice: as mm and as px. 
Please provide one only ({pos_as_mm} vs {pos_as_px})" ) else: return pos_as_px elif pos_as_mm is not None: # deduce from position given in configuration and pixel size axis_N_pos_px = [] for scan, pos_in_mm in zip(self.series, pos_as_mm): pixel_size_m = self.configuration.pixel_size or scan.pixel_size axis_N_pos_px.append((pos_in_mm / MetricSystem.MILLIMETER.value) / pixel_size_m) return axis_N_pos_px else: # deduce from motor position and pixel size axis_N_pos_px = [] base_position_m = self.series[0].get_bounding_box(axis=axis).min for scan in self.series: pixel_size_m = self.configuration.pixel_size or scan.pixel_size scan_axis_bb = scan.get_bounding_box(axis=axis) axis_N_mean_pos_m = (scan_axis_bb.max - scan_axis_bb.min) / 2 + scan_axis_bb.min axis_N_mean_rel_pos_m = axis_N_mean_pos_m - base_position_m axis_N_pos_px.append(int(axis_N_mean_rel_pos_m / pixel_size_m)) return axis_N_pos_px for axis, property_px_name, property_mm_name in zip( (0, 1, 2), ( "axis_0_pos_px", "axis_1_pos_px", "axis_2_pos_px", ), ( "axis_0_pos_mm", "axis_1_pos_mm", "axis_2_pos_mm", ), ): assert hasattr( self.configuration, property_px_name ), f"configuration API changed. should have {property_px_name}" assert hasattr( self.configuration, property_mm_name ), f"configuration API changed. should have {property_px_name}" try: new_px_position = get_position_as_px_on_axis( axis=axis, pos_as_px=getattr(self.configuration, property_px_name), pos_as_mm=getattr(self.configuration, property_mm_name), ) except ValueError: # when unable to find the position if axis == self.axis: # if we cannot find position over the stitching axis then raise an error: unable to process without raise else: _logger.warning(f"Unable to find position over axis {axis}. Set them to zero") setattr( self.configuration, property_px_name, numpy.array([0] * len(self.series)), ) else: setattr( self.configuration, property_px_name, new_px_position, ) # clear position in mm as the one we will used are the px one self.configuration.axis_0_pos_mm = None self.configuration.axis_1_pos_mm = None self.configuration.axis_2_pos_mm = None # add some log if self.configuration.axis_2_pos_mm is not None or self.configuration.axis_2_pos_px is not None: _logger.warning("axis 2 position is not handled by the stitcher. Will be ignored") axis_0_pos = ", ".join([f"{pos}px" for pos in self.configuration.axis_0_pos_px]) axis_1_pos = ", ".join([f"{pos}px" for pos in self.configuration.axis_1_pos_px]) axis_2_pos = ", ".join([f"{pos}px" for pos in self.configuration.axis_2_pos_px]) _logger.info(f"axis 0 position to be used: " + axis_0_pos) _logger.info(f"axis 1 position to be used: " + axis_1_pos) _logger.info(f"axis 2 position to be used: " + axis_2_pos) _logger.info(f"stitching will be applied along axis: {self.axis}") def compute_estimated_shifts(self): if self.axis == 0: # if we want to stitch over axis 0 (aka z) axis_0_pos_px = self.configuration.axis_0_pos_px self._axis_0_rel_ini_shifts = [] # compute overlap along axis 0 for upper_scan, lower_scan, upper_scan_axis_0_pos, lower_scan_axis_0_pos in zip( self.series[:-1], self.series[1:], axis_0_pos_px[:-1], axis_0_pos_px[1:] ): upper_scan_pos = upper_scan_axis_0_pos - upper_scan.dim_2 / 2 lower_scan_high_pos = lower_scan_axis_0_pos + lower_scan.dim_2 / 2 # simple test of overlap. 
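# Illustrative sketch (not from the nabu code base) of the overlap estimate computed here: the
# initial overlap is the distance between the top edge of the lower stage and the bottom edge of
# the upper stage, both derived from center positions (in px) and frame heights. Values below
# are hypothetical.
def estimated_overlap_px(upper_center, lower_center, upper_height, lower_height):
    upper_bottom = upper_center - upper_height / 2
    lower_top = lower_center + lower_height / 2
    overlap = lower_top - upper_bottom
    if overlap <= 0:
        raise ValueError("no overlap between the two stages")
    return int(overlap)

# two 2048 px high stages whose centers are 1900 px apart overlap by 148 px
assert estimated_overlap_px(2000, 100, 2048, 2048) == 148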
More complete test are run by check_overlaps later if lower_scan_high_pos <= upper_scan_pos: raise ValueError(f"no overlap found between {upper_scan} and {lower_scan}") self._axis_0_rel_ini_shifts.append( int(lower_scan_high_pos - upper_scan_pos) # overlap are expected to be int for now ) self._axis_1_rel_ini_shifts = self.from_abs_pos_to_rel_pos(self.configuration.axis_1_pos_px) self._axis_2_rel_ini_shifts = [0.0] * (len(self.series) - 1) elif self.axis == 1: # if we want to stitch over axis 1 (aka Y in acquisition reference - which is x in frame reference) axis_1_pos_px = self.configuration.axis_1_pos_px self._axis_1_rel_ini_shifts = [] # compute overlap along axis 0 for left_scan, right_scan, left_scan_axis_1_pos, right_scan_axis_1_pos in zip( self.series[:-1], self.series[1:], axis_1_pos_px[:-1], axis_1_pos_px[1:] ): left_scan_pos = left_scan_axis_1_pos - left_scan.dim_1 / 2 right_scan_high_pos = right_scan_axis_1_pos + right_scan.dim_1 / 2 # simple test of overlap. More complete test are run by check_overlaps later if right_scan_high_pos <= left_scan_pos: raise ValueError(f"no overlap found between {left_scan} and {right_scan}") self._axis_1_rel_ini_shifts.append( int(right_scan_high_pos - left_scan_pos) # overlap are expected to be int for now ) self._axis_0_rel_ini_shifts = self.from_abs_pos_to_rel_pos(self.configuration.axis_0_pos_px) self._axis_2_rel_ini_shifts = [0.0] * (len(self.series) - 1) else: raise NotImplementedError("stitching only forseen for axis 0 and 1 for now") def _compute_shifts(self): """ compute all shift requested (set to 'auto' in the configuration) """ n_scans = len(self.configuration.input_scans) if n_scans == 0: raise ValueError("no scan to stich provided") projection_for_shift = self.configuration.slice_for_cross_correlation or "middle" if self.axis not in (0, 1): raise NotImplementedError("only stitching over axis 0 and 2 are handled for pre-processing stitching") final_rel_shifts = [] for ( scan_0, scan_1, order_s0, order_s1, x_rel_shift, y_rel_shift, ) in zip( self.series[:-1], self.series[1:], self.reading_orders[:-1], self.reading_orders[1:], self._axis_1_rel_ini_shifts, self._axis_0_rel_ini_shifts, ): x_cross_algo = self.configuration.axis_1_params.get(KEY_IMG_REG_METHOD, None) y_cross_algo = self.configuration.axis_0_params.get(KEY_IMG_REG_METHOD, None) # compute relative shift found_shift_y, found_shift_x = find_projections_relative_shifts( upper_scan=scan_0, lower_scan=scan_1, projection_for_shift=projection_for_shift, x_cross_correlation_function=x_cross_algo, y_cross_correlation_function=y_cross_algo, x_shifts_params=self.configuration.axis_1_params, # image x map acquisition axis 1 (Y) y_shifts_params=self.configuration.axis_0_params, # image y map acquisition axis 0 (Z) invert_order=order_s1 != order_s0, estimated_shifts=(y_rel_shift, x_rel_shift), axis=self.axis, ) final_rel_shifts.append( (found_shift_y, found_shift_x), ) # set back values. 
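# Illustrative sketch (not from the nabu code base): find_projections_relative_shifts, used
# above, refines the estimated shifts with an image-registration / cross-correlation algorithm.
# The 1D FFT-based correlation below only shows the principle on a synthetic profile; the real
# nabu helpers work on 2D frames and support several algorithms.
import numpy

def refine_shift_1d(reference, shifted):
    """Return s such that `shifted` is approximately numpy.roll(reference, s)."""
    corr = numpy.fft.ifft(numpy.conj(numpy.fft.fft(reference)) * numpy.fft.fft(shifted)).real
    s = int(numpy.argmax(corr))
    return s - reference.size if s > reference.size // 2 else s

profile = numpy.sin(numpy.linspace(0, 6 * numpy.pi, 200))
assert refine_shift_1d(profile, numpy.roll(profile, 5)) == 5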
Now position should start at 0 self._axis_0_rel_final_shifts = [final_shift[0] for final_shift in final_rel_shifts] self._axis_1_rel_final_shifts = [final_shift[1] for final_shift in final_rel_shifts] self._axis_2_rel_final_shifts = [0.0] * len(final_rel_shifts) _logger.info(f"axis 1 relative shifts (x in radio ref) to be used will be {self._axis_0_rel_final_shifts}") print(f"axis 1 relative shifts (x in radio ref) to be used will be {self._axis_0_rel_final_shifts}") _logger.info(f"axis 0 relative shifts (y in radio ref) y to be used will be {self._axis_1_rel_final_shifts}") print(f"axis 0 relative shifts (y in radio ref) y to be used will be {self._axis_1_rel_final_shifts}") def _create_nx_tomo(self, store_composition: bool = False): """ create final NXtomo with stitched frames. Policy: save all projections flat fielded. So this NXtomo will only contain projections (no dark and no flat). But nabu will be able to reconstruct it with field `flatfield` set to False """ nx_tomo = NXtomo() nx_tomo.energy = self.series[0].energy start_times = list(filter(None, [scan.start_time for scan in self.series])) end_times = list(filter(None, [scan.end_time for scan in self.series])) if len(start_times) > 0: nx_tomo.start_time = ( numpy.asarray([numpy.datetime64(start_time) for start_time in start_times]).min().astype(datetime) ) else: _logger.warning("Unable to find any start_time from input") if len(end_times) > 0: nx_tomo.end_time = ( numpy.asarray([numpy.datetime64(end_time) for end_time in end_times]).max().astype(datetime) ) else: _logger.warning("Unable to find any end_time from input") title = ";".join([scan.sequence_name or "" for scan in self.series]) nx_tomo.title = f"stitch done from {title}" self._slices_to_stitch, n_proj = self.configuration.settle_slices() # handle detector (without frames) nx_tomo.instrument.detector.field_of_view = self.series[0].field_of_view nx_tomo.instrument.detector.distance = self.series[0].distance nx_tomo.instrument.detector.x_pixel_size = self.series[0].x_pixel_size nx_tomo.instrument.detector.y_pixel_size = self.series[0].y_pixel_size nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj nx_tomo.instrument.detector.tomo_n = n_proj # note: stitching process insure un-flipping of frames. 
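# Illustrative sketch (not from the nabu code base) of the start/end time aggregation used
# above: the stitched NXtomo takes the earliest start_time and the latest end_time of the
# inputs, going through numpy.datetime64 for the comparison. The values below are hypothetical
# scan metadata.
import numpy
from datetime import datetime

start_times = ["2024-01-01T10:00:00", "2024-01-01T09:30:00"]
stitched_start = numpy.asarray([numpy.datetime64(t) for t in start_times]).min().astype(datetime)
assert stitched_start == datetime(2024, 1, 1, 9, 30)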
So make sure transformations is defined as an empty set nx_tomo.instrument.detector.transformations = NXtransformations() if isinstance(self.series[0], NXtomoScan): # note: first scan is always the reference as order to read data (so no rotation_angle inversion here) rotation_angle = numpy.asarray(self.series[0].rotation_angle) nx_tomo.sample.rotation_angle = rotation_angle[ numpy.asarray(self.series[0].image_key_control) == ImageKey.PROJECTION.value ] elif isinstance(self.series[0], EDFTomoScan): nx_tomo.sample.rotation_angle = numpy.linspace( start=0, stop=self.series[0].scan_range, num=self.series[0].tomo_n ) else: raise NotImplementedError( f"scan type ({type(self.series[0])} is not handled)", NXtomoScan, isinstance(self.series[0], NXtomoScan), ) # do a sub selection of the rotation angle if a we are only computing a part of the slices def apply_slices_selection(array, slices, allow_empty: bool = False): if isinstance(slices, slice): return array[slices.start : slices.stop : 1] elif isinstance(slices, Iterable): return list([array[index] for index in slices]) else: raise RuntimeError("slices must be instance of a slice or of an iterable") nx_tomo.sample.rotation_angle = apply_slices_selection( array=nx_tomo.sample.rotation_angle, slices=self._slices_to_stitch ) # handle sample if False not in [isinstance(scan, NXtomoScan) for scan in self.series]: def get_sample_translation_for_projs(scan: NXtomoScan, attr): values = numpy.array(getattr(scan, attr)) mask = scan.image_key_control == ImageKey.PROJECTION.value return values[mask] # we consider the new x, y and z position to be at the center of the one created x_translation = [ get_sample_translation_for_projs(scan, "x_translation") for scan in self.series if scan.x_translation is not None ] if len(x_translation) > 0: # if there is some metadata about {x|y|z} translations # we want to take the mean of each frame for each projections x_translation = apply_slices_selection( numpy.array(x_translation).mean(axis=0), slices=self._slices_to_stitch, ) else: # if no NXtomo has information about x_translation. # note: if at least one has missing values the numpy.Array(x_translation) with create an error as well x_translation = [0.0] * n_proj _logger.warning("Unable to fin input nxtomo x_translation values. Set it to 0.0") nx_tomo.sample.x_translation = x_translation y_translation = [ get_sample_translation_for_projs(scan, "y_translation") for scan in self.series if scan.y_translation is not None ] if len(y_translation) > 0: y_translation = apply_slices_selection( numpy.array(y_translation).mean(axis=0), slices=self._slices_to_stitch, ) else: y_translation = [0.0] * n_proj _logger.warning("Unable to fin input nxtomo y_translation values. Set it to 0.0") nx_tomo.sample.y_translation = y_translation z_translation = [ get_sample_translation_for_projs(scan, "z_translation") for scan in self.series if scan.z_translation is not None ] if len(z_translation) > 0: z_translation = apply_slices_selection( numpy.array(z_translation).mean(axis=0), slices=self._slices_to_stitch, ) else: z_translation = [0.0] * n_proj _logger.warning("Unable to fin input nxtomo z_translation values. 
Set it to 0.0") nx_tomo.sample.z_translation = z_translation nx_tomo.sample.name = self.series[0].sample_name # compute stitched frame shape if self.axis == 0: stitched_frame_shape = ( n_proj, ( numpy.asarray([scan.dim_2 for scan in self.series]).sum() - numpy.asarray([abs(overlap) for overlap in self._axis_0_rel_final_shifts]).sum() ), self._stitching_constant_length, ) elif self.axis == 1: stitched_frame_shape = ( n_proj, self._stitching_constant_length, ( numpy.asarray([scan.dim_1 for scan in self.series]).sum() - numpy.asarray([abs(overlap) for overlap in self._axis_1_rel_final_shifts]).sum() ), ) else: raise NotImplementedError("stitching on pre-processing along axis 2 (x-ray direction) is not handled") if stitched_frame_shape[0] < 1 or stitched_frame_shape[1] < 1 or stitched_frame_shape[2] < 1: raise RuntimeError(f"Error in stitched frame shape calculation. {stitched_frame_shape} found.") # get expected output dataset first (just in case output and input files are the same) first_proj_idx = sorted(self.series[0].projections.keys())[0] first_proj_url = self.series[0].projections[first_proj_idx] if h5py.is_hdf5(first_proj_url.file_path()): first_proj_url = DataUrl( file_path=first_proj_url.file_path(), data_path=first_proj_url.data_path(), scheme="h5py", ) # first save the NXtomo entry without the frame # dicttonx will fail if the folder does not exists dir_name = os.path.dirname(self.configuration.output_file_path) if dir_name not in (None, ""): os.makedirs(dir_name, exist_ok=True) nx_tomo.save( file_path=self.configuration.output_file_path, data_path=self.configuration.output_data_path, nexus_path_version=self.configuration.output_nexus_version, overwrite=self.configuration.overwrite_results, ) transformation_matrices = { scan.get_identifier() .to_str() .center(80, "-"): numpy.array2string(build_matrix(scan.get_detector_transformations(tuple()))) for scan in self.series } _logger.info( "scan detector transformation matrices are:\n" "\n".join(["/n".join(item) for item in transformation_matrices.items()]) ) _logger.info( f"reading order is {self.reading_orders}", ) def get_output_data_type(): return numpy.float32 # because we will apply flat field correction on it and they are not raw data output_dtype = get_output_data_type() # append frames ("instrument/detector/data" dataset) with HDF5File(self.configuration.output_file_path, mode="a") as h5f: # note: nx_tomo.save already handles the possible overwrite conflict by removing # self.configuration.output_file_path or raising an error stitched_frame_path = "/".join( [ self.configuration.output_data_path, _get_nexus_paths(self.configuration.output_nexus_version).PROJ_PATH, ] ) self.dumper.output_dataset = h5f.create_dataset( name=stitched_frame_path, shape=stitched_frame_shape, dtype=output_dtype, ) # TODO: we could also create in several time and create a virtual dataset from it. 
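# Illustrative sketch (not from the nabu code base) of the idea in the TODO just above: the
# stitched frames could be exposed through an h5py virtual dataset assembled from per-stage
# datasets, which is the mechanism PostProcessingStitchingDumperNoDD relies on. File names,
# data paths and shapes below are hypothetical; virtual sources are resolved lazily at read time.
import h5py

layout = h5py.VirtualLayout(shape=(1800, 100, 2048), dtype="float32")
for i, part_file in enumerate(["part_0.h5", "part_1.h5"]):
    vsource = h5py.VirtualSource(part_file, "data", shape=(1800, 50, 2048))
    layout[:, i * 50 : (i + 1) * 50, :] = vsource  # map each stage into the stitched rows
with h5py.File("stitched_sketch.h5", "w") as h5f:
    h5f.create_virtual_dataset("stitched_data", layout, fillvalue=0)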
scans_projections_indexes = [] for scan, reverse in zip(self.series, self.reading_orders): scans_projections_indexes.append(sorted(scan.projections.keys(), reverse=(reverse == -1))) if self.progress: self.progress.total = self.get_n_slices_to_stitch() if isinstance(self._slices_to_stitch, slice): step = self._slices_to_stitch.step or 1 else: step = 1 i_proj = 0 for bunch_start, bunch_end in self._data_bunch_iterator(slices=self._slices_to_stitch, bunch_size=50): for data_frames in self._get_bunch_of_data( bunch_start, bunch_end, step=step, scans=self.series, scans_projections_indexes=scans_projections_indexes, flip_ud_arr=self.configuration.flip_ud, flip_lr_arr=self.configuration.flip_lr, reading_orders=self.reading_orders, ): if self.configuration.rescale_frames: data_frames = self.rescale_frames(data_frames) if self.configuration.normalization_by_sample.is_active(): data_frames = self.normalize_frame_by_sample(data_frames) sf = SingleAxisStitcher.stitch_frames( frames=data_frames, axis=self.axis, x_relative_shifts=self._axis_1_rel_final_shifts, y_relative_shifts=self._axis_0_rel_final_shifts, overlap_kernels=self._overlap_kernels, i_frame=i_proj, output_dtype=output_dtype, dumper=self.dumper, return_composition_cls=store_composition if i_proj == 0 else False, stitching_axis=self.axis, pad_mode=self.configuration.pad_mode, alignment=self.configuration.alignment_axis_2, new_width=self._stitching_constant_length, check_inputs=i_proj == 0, # on process check on the first iteration ) if i_proj == 0 and store_composition: _, self._frame_composition = sf if self.progress is not None: self.progress.update() i_proj += 1 # create link to this dataset that can be missing # "data/data" link if "data" in h5f[self.configuration.output_data_path]: data_group = h5f[self.configuration.output_data_path]["data"] if not stitched_frame_path.startswith("/"): stitched_frame_path = "/" + stitched_frame_path data_group["data"] = h5py.SoftLink(stitched_frame_path) if "default" not in h5f[self.configuration.output_data_path].attrs: h5f[self.configuration.output_data_path].attrs["default"] = "data" for attr_name, attr_value in zip( ("NX_class", "SILX_style/axis_scale_types", "signal"), ("NXdata", ["linear", "linear"], "data"), ): if attr_name not in data_group.attrs: data_group.attrs[attr_name] = attr_value return nx_tomo def _create_stitching(self, store_composition): self._create_nx_tomo(store_composition=store_composition) @staticmethod def get_bunch_of_data( bunch_start: int, bunch_end: int, step: int, scans: tuple, scans_projections_indexes: tuple, reading_orders: tuple, flip_lr_arr: tuple, flip_ud_arr: tuple, ): """ goal is to load contiguous projections as much as possible... :param int bunch_start: begining of the bunch :param int bunch_end: end of the bunch :param int scans: ordered scan for which we want to get data :param scans_projections_indexes: tuple with scans and scan projection indexes to be loaded :param tuple flip_lr_arr: extra information from the user to left-right flip frames :param tuple flip_ud_arr: extra information from the user to up-down flip frames :return: list of list. 
For each frame we want to stitch contains the (flat fielded) frames to stich together """ assert len(scans) == len(scans_projections_indexes) assert isinstance(flip_lr_arr, tuple) assert isinstance(flip_ud_arr, tuple) assert isinstance(step, int) scans_proj_urls = [] # for each scan store the real indices and the data url for scan, scan_projection_indexes in zip(scans, scans_projections_indexes): scan_proj_urls = {} # for each scan get the list of url to be loaded for i_proj in range(bunch_start, bunch_end): if i_proj % step != 0: continue proj_index_in_full_scan = scan_projection_indexes[i_proj] scan_proj_urls[proj_index_in_full_scan] = scan.projections[proj_index_in_full_scan] scans_proj_urls.append(scan_proj_urls) # then load data all_scan_final_data = numpy.empty((bunch_end - bunch_start, len(scans)), dtype=object) from nabu.preproc.flatfield import FlatFieldArrays for i_scan, (scan_urls, scan_flip_lr, scan_flip_ud, reading_order) in enumerate( zip(scans_proj_urls, flip_lr_arr, flip_ud_arr, reading_orders) ): i_frame = 0 _, set_of_compacted_slices = get_compacted_dataslices(scan_urls, return_url_set=True) for _, url in set_of_compacted_slices.items(): scan = scans[i_scan] url = DataUrl( file_path=url.file_path(), data_path=url.data_path(), scheme="silx", data_slice=url.data_slice(), ) raw_radios = get_data(url)[::reading_order] radio_indices = url.data_slice() if isinstance(radio_indices, slice): step = radio_indices.step if radio_indices is not None else 1 radio_indices = numpy.arange( start=radio_indices.start, stop=radio_indices.stop, step=step, dtype=numpy.int16, ) missing = [] if len(scan.reduced_flats) == 0: missing = "flats" if len(scan.reduced_darks) == 0: missing = "darks" if len(missing) > 0: _logger.warning(f"missing {'and'.join(missing)}. Unable to do flat field correction") ff_arrays = None data = raw_radios else: has_reduced_metadata = ( scan.reduced_flats_infos is not None and len(scan.reduced_flats_infos.machine_electric_current) > 0 and scan.reduced_darks_infos is not None and len(scan.reduced_darks_infos.machine_electric_current) > 0 ) if not has_reduced_metadata: _logger.warning("no metadata about current found. Won't normalize according to machine current") ff_arrays = FlatFieldArrays( radios_shape=(len(radio_indices), scan.dim_2, scan.dim_1), flats=scan.reduced_flats, darks=scan.reduced_darks, radios_indices=radio_indices, radios_srcurrent=scan.electric_current[radio_indices] if has_reduced_metadata else None, flats_srcurrent=( scan.reduced_flats_infos.machine_electric_current if has_reduced_metadata else None ), ) # note: we need to cast radios to float 32. 
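# Illustrative sketch (not from the nabu code base) of the flat-field correction applied above
# by FlatFieldArrays (leaving aside the optional machine-current normalization):
# corrected = (radio - dark) / (flat - dark), computed in float32.
import numpy

dark = numpy.full((4, 5), 100.0, dtype=numpy.float32)
flat = numpy.full((4, 5), 1100.0, dtype=numpy.float32)
radio = dark + 0.42 * (flat - dark)  # synthetic radio with 42% transmission
corrected = (radio - dark) / (flat - dark)
assert numpy.allclose(corrected, 0.42)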
Darks and flats are cast to anyway data = ff_arrays.normalize_radios(raw_radios.astype(numpy.float32)) transformations = list(scans[i_scan].get_detector_transformations(tuple())) if scan_flip_lr: transformations.append(DetZFlipTransformation(flip=True)) if scan_flip_ud: transformations.append(DetYFlipTransformation(flip=True)) transformation_matrix_det_space = build_matrix(transformations) if transformation_matrix_det_space is None or numpy.allclose( transformation_matrix_det_space, numpy.identity(3) ): flip_ud = False flip_lr = False elif numpy.array_equal(transformation_matrix_det_space, PreProcessingStitching._get_UD_flip_matrix()): flip_ud = True flip_lr = False elif numpy.allclose(transformation_matrix_det_space, PreProcessingStitching._get_LR_flip_matrix()): flip_ud = False flip_lr = True elif numpy.allclose( transformation_matrix_det_space, PreProcessingStitching._get_UD_AND_LR_flip_matrix() ): flip_ud = True flip_lr = True else: raise ValueError("case not handled... For now only handle up-down flip as left-right flip") for frame in data: if flip_ud: frame = numpy.flipud(frame) if flip_lr: frame = numpy.fliplr(frame) all_scan_final_data[i_frame, i_scan] = frame i_frame += 1 return all_scan_final_data def compute_reduced_flats_and_darks(self): """ make sure reduced dark and flats are existing otherwise compute them """ for scan in self.series: try: reduced_darks, darks_infos = scan.load_reduced_darks(return_info=True) except: _logger.info("no reduced dark found. Try to compute them.") if reduced_darks in (None, {}): reduced_darks, darks_infos = scan.compute_reduced_darks(return_info=True) try: # if we don't have write in the folder containing the .nx for example scan.save_reduced_darks(reduced_darks, darks_infos=darks_infos) except Exception as e: pass scan.set_reduced_darks(reduced_darks, darks_infos=darks_infos) try: reduced_flats, flats_infos = scan.load_reduced_flats(return_info=True) except: _logger.info("no reduced flats found. Try to compute them.") if reduced_flats in (None, {}): reduced_flats, flats_infos = scan.compute_reduced_flats(return_info=True) try: # if we don't have write in the folder containing the .nx for example scan.save_reduced_flats(reduced_flats, flats_infos=flats_infos) except Exception as e: pass scan.set_reduced_flats(reduced_flats, flats_infos=flats_infos) @staticmethod @cache(maxsize=None) def _get_UD_flip_matrix(): return DetYFlipTransformation(flip=True).as_matrix() @staticmethod @cache(maxsize=None) def _get_LR_flip_matrix(): return DetZFlipTransformation(flip=True).as_matrix() @staticmethod @cache(maxsize=None) def _get_UD_AND_LR_flip_matrix(): return numpy.matmul( PreProcessingStitching._get_UD_flip_matrix(), PreProcessingStitching._get_LR_flip_matrix(), ) @staticmethod def _get_bunch_of_data( bunch_start: int, bunch_end: int, step: int, scans: tuple, scans_projections_indexes: tuple, reading_orders: tuple, flip_lr_arr: tuple, flip_ud_arr: tuple, ): """ goal is to load contiguous projections as much as possible... :param int bunch_start: begining of the bunch :param int bunch_end: end of the bunch :param int scans: ordered scan for which we want to get data :param scans_projections_indexes: tuple with scans and scan projection indexes to be loaded :param tuple flip_lr_arr: extra information from the user to left-right flip frames :param tuple flip_ud_arr: extra information from the user to up-down flip frames :return: list of list. 
For each frame we want to stitch contains the (flat fielded) frames to stich together """ assert len(scans) == len(scans_projections_indexes) assert isinstance(flip_lr_arr, tuple) assert isinstance(flip_ud_arr, tuple) assert isinstance(step, int) scans_proj_urls = [] # for each scan store the real indices and the data url for scan, scan_projection_indexes in zip(scans, scans_projections_indexes): scan_proj_urls = {} # for each scan get the list of url to be loaded for i_proj in range(bunch_start, bunch_end): if i_proj % step != 0: continue proj_index_in_full_scan = scan_projection_indexes[i_proj] scan_proj_urls[proj_index_in_full_scan] = scan.projections[proj_index_in_full_scan] scans_proj_urls.append(scan_proj_urls) # then load data all_scan_final_data = numpy.empty((bunch_end - bunch_start, len(scans)), dtype=object) from nabu.preproc.flatfield import FlatFieldArrays for i_scan, (scan_urls, scan_flip_lr, scan_flip_ud, reading_order) in enumerate( zip(scans_proj_urls, flip_lr_arr, flip_ud_arr, reading_orders) ): i_frame = 0 _, set_of_compacted_slices = get_compacted_dataslices(scan_urls, return_url_set=True) for _, url in set_of_compacted_slices.items(): scan = scans[i_scan] url = DataUrl( file_path=url.file_path(), data_path=url.data_path(), scheme="silx", data_slice=url.data_slice(), ) raw_radios = get_data(url)[::reading_order] radio_indices = url.data_slice() if isinstance(radio_indices, slice): step = radio_indices.step if radio_indices is not None else 1 radio_indices = numpy.arange( start=radio_indices.start, stop=radio_indices.stop, step=step, dtype=numpy.int16, ) missing = [] if len(scan.reduced_flats) == 0: missing = "flats" if len(scan.reduced_darks) == 0: missing = "darks" if len(missing) > 0: _logger.warning(f"missing {'and'.join(missing)}. Unable to do flat field correction") ff_arrays = None data = raw_radios else: has_reduced_metadata = ( scan.reduced_flats_infos is not None and len(scan.reduced_flats_infos.machine_electric_current) > 0 and scan.reduced_darks_infos is not None and len(scan.reduced_darks_infos.machine_electric_current) > 0 ) if not has_reduced_metadata: _logger.warning("no metadata about current found. Won't normalize according to machine current") ff_arrays = FlatFieldArrays( radios_shape=(len(radio_indices), scan.dim_2, scan.dim_1), flats=scan.reduced_flats, darks=scan.reduced_darks, radios_indices=radio_indices, radios_srcurrent=scan.electric_current[radio_indices] if has_reduced_metadata else None, flats_srcurrent=( scan.reduced_flats_infos.machine_electric_current if has_reduced_metadata else None ), ) # note: we need to cast radios to float 32. 
Darks and flats are cast to anyway data = ff_arrays.normalize_radios(raw_radios.astype(numpy.float32)) transformations = list(scans[i_scan].get_detector_transformations(tuple())) if scan_flip_lr: transformations.append(DetZFlipTransformation(flip=True)) if scan_flip_ud: transformations.append(DetYFlipTransformation(flip=True)) transformation_matrix_det_space = build_matrix(transformations) if transformation_matrix_det_space is None or numpy.allclose( transformation_matrix_det_space, numpy.identity(3) ): flip_ud = False flip_lr = False elif numpy.array_equal(transformation_matrix_det_space, PreProcessingStitching._get_UD_flip_matrix()): flip_ud = True flip_lr = False elif numpy.allclose(transformation_matrix_det_space, PreProcessingStitching._get_LR_flip_matrix()): flip_ud = False flip_lr = True elif numpy.allclose( transformation_matrix_det_space, PreProcessingStitching._get_UD_AND_LR_flip_matrix() ): flip_ud = True flip_lr = True else: raise ValueError("case not handled... For now only handle up-down flip as left-right flip") for frame in data: if flip_ud: frame = numpy.flipud(frame) if flip_lr: frame = numpy.fliplr(frame) all_scan_final_data[i_frame, i_scan] = frame i_frame += 1 return all_scan_final_data ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731053186.0 nabu-2024.2.1/nabu/stitching/stitcher/single_axis.py0000644000175000017500000004701214713343202021715 0ustar00pierrepierreimport h5py import numpy import logging from math import ceil from typing import Optional, Iterable, Union from tomoscan.series import Series from tomoscan.identifier import BaseIdentifier from nabu.stitching.stitcher.base import _StitcherBase, get_obj_constant_side_length from nabu.stitching.stitcher_2D import stitch_raw_frames from nabu.stitching.utils.utils import ShiftAlgorithm, from_slice_to_n_elements from nabu.stitching.overlap import ( check_overlaps, ImageStichOverlapKernel, ) from nabu.stitching.config import ( SingleAxisStitchingConfiguration, KEY_RESCALE_MIN_PERCENTILES, KEY_RESCALE_MAX_PERCENTILES, ) from nabu.misc.utils import rescale_data from nabu.stitching.sample_normalization import normalize_frame as normalize_frame_by_sample from nabu.stitching.stitcher.dumper.base import DumperBase from silx.io.utils import get_data from silx.io.url import DataUrl from scipy.ndimage import shift as shift_scipy _logger = logging.getLogger(__name__) PROGRESS_BAR_STITCH_VOL_DESC = "stitch volumes" # description of the progress bar used when stitching volume. 
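The block above reduces an arbitrary detector-space transformation to at most an up-down and/or left-right flip by comparing the composed 3×3 matrix against the cached reference flip matrices. The sketch below reproduces that decision logic with plain `numpy` diagonal matrices standing in for `DetYFlipTransformation` / `DetZFlipTransformation.as_matrix()` (the actual matrices used by nabu may differ; these are only stand-ins to show the comparison flow):

```python
import numpy

identity = numpy.identity(3)
ud_flip = numpy.diag([1.0, -1.0, 1.0])   # stand-in for an up-down flip matrix
lr_flip = numpy.diag([1.0, 1.0, -1.0])   # stand-in for a left-right flip matrix
ud_and_lr_flip = ud_flip @ lr_flip


def flips_from_matrix(matrix):
    """Return (flip_ud, flip_lr), or raise if the transformation is not a pure flip."""
    if matrix is None or numpy.allclose(matrix, identity):
        return False, False
    if numpy.allclose(matrix, ud_flip):
        return True, False
    if numpy.allclose(matrix, lr_flip):
        return False, True
    if numpy.allclose(matrix, ud_and_lr_flip):
        return True, True
    raise ValueError("only pure up-down / left-right flips are handled")


print(flips_from_matrix(ud_and_lr_flip))  # -> (True, True)
```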
# Needed to retrieve advancement from file when stitching remotely class _SingleAxisMetaClass(type): """ Metaclass for single axis stitcher in order to aggregate dumper class and axis """ def __new__(mcls, name, bases, attrs, axis=None, dumper_cls=None): mcls = super().__new__(mcls, name, bases, attrs) mcls._axis = axis mcls._dumperCls = dumper_cls return mcls class SingleAxisStitcher(_StitcherBase, metaclass=_SingleAxisMetaClass): """ Any single-axis base class """ def __init__(self, configuration, *args, **kwargs) -> None: super().__init__(configuration, *args, **kwargs) if self._dumperCls is not None: self._dumper = self._dumperCls(configuration=configuration) else: self._dumper = None # initial shifts self._axis_0_rel_ini_shifts = [] """Shift between two juxtapose objects along axis 0 found from position metadata or given by the user""" self._axis_1_rel_ini_shifts = [] """Shift between two juxtapose objects along axis 1 found from position metadata or given by the user""" self._axis_2_rel_ini_shifts = [] """Shift between two juxtapose objects along axis 2 found from position metadata or given by the user""" # shifts to add once refine self._axis_0_rel_final_shifts = [] """Shift over axis 0 found once refined by the cross correlation algorithm""" self._axis_1_rel_final_shifts = [] """Shift over axis 1 found once refined by the cross correlation algorithm""" self._axis_2_rel_final_shifts = [] """Shift over axis 2 found once refined by the cross correlation algorithm""" self._slices_to_stitch = None # slices to be stitched. Obtained from calling Configuration.settle_slices self._stitching_constant_length = None # stitching width: larger volume width. Other volume will be pad def shifts_is_scalar(shifts): return isinstance(shifts, ShiftAlgorithm) or numpy.isscalar(shifts) # 'expend' shift algorithm if shifts_is_scalar(self.configuration.axis_0_pos_px): self.configuration.axis_0_pos_px = [ self.configuration.axis_0_pos_px, ] * (len(self.series) - 1) if shifts_is_scalar(self.configuration.axis_1_pos_px): self.configuration.axis_1_pos_px = [ self.configuration.axis_1_pos_px, ] * (len(self.series) - 1) if shifts_is_scalar(self.configuration.axis_1_pos_px): self.configuration.axis_1_pos_px = [ self.configuration.axis_1_pos_px, ] * (len(self.series) - 1) if numpy.isscalar(self.configuration.axis_0_params): self.configuration.axis_0_params = [ self.configuration.axis_0_params, ] * (len(self.series) - 1) if numpy.isscalar(self.configuration.axis_1_params): self.configuration.axis_1_params = [ self.configuration.axis_1_params, ] * (len(self.series) - 1) if numpy.isscalar(self.configuration.axis_2_params): self.configuration.axis_2_params = [ self.configuration.axis_2_params, ] * (len(self.series) - 1) @property def axis(self) -> int: return self._axis @property def dumper(self): return self._dumper @property def stitching_axis_in_frame_space(self): """ stitching is operated in 2D (frame) space. 
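`_SingleAxisMetaClass` lets concrete stitchers declare their stitching axis and dumper class directly in the `class` statement, as the Y and Z stitchers further down in this archive do. A self-contained toy version of the same pattern, with hypothetical names, behaves like this:

```python
class _AxisAwareMeta(type):
    """Toy equivalent of _SingleAxisMetaClass: stores the class-statement keyword arguments."""

    def __new__(mcls, name, bases, attrs, axis=None, dumper_cls=None):
        cls = super().__new__(mcls, name, bases, attrs)
        cls._axis = axis
        cls._dumperCls = dumper_cls
        return cls


class _FakeDumper:
    def __init__(self, configuration):
        self.configuration = configuration


class Base(metaclass=_AxisAwareMeta):
    @property
    def axis(self):
        return self._axis


class YStitcherToy(Base, axis=1, dumper_cls=_FakeDumper):
    pass


print(YStitcherToy().axis)      # -> 1
print(YStitcherToy._dumperCls)  # -> <class '_FakeDumper'>
```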
So the axis in frame space is different than the one in 3D ebs-tomo space (https://tomo.gitlab-pages.esrf.fr/bliss-tomo/master/modelization.html) """ raise NotImplementedError("Base class") def stitch(self, store_composition: bool = True) -> BaseIdentifier: if self.progress is not None: self.progress.set_description("order scans") self.order_input_tomo_objects() if self.progress is not None: self.progress.set_description("check inputs") self.check_inputs() self.settle_flips() if self.progress is not None: self.progress.set_description("compute shifts") self._compute_positions_as_px() self.pre_processing_computation() self.compute_estimated_shifts() self._compute_shifts() self._createOverlapKernels() if self.progress is not None: self.progress.set_description(PROGRESS_BAR_STITCH_VOL_DESC) self._create_stitching(store_composition=store_composition) if self.progress is not None: self.progress.set_description("dump configuration") self.dumper.save_configuration() return self.dumper.output_identifier @property def serie_label(self) -> str: """return serie name for logs""" return "single axis serie" def get_n_slices_to_stitch(self): """Return the number of slice to be stitched""" if self._slices_to_stitch is None: raise RuntimeError("Slices needs to be settled first") return from_slice_to_n_elements(self._slices_to_stitch) def get_final_axis_positions_in_px(self) -> dict: """ compute the final position (**in pixel**) from the initial position of the first object and the final relative shift computed (1) (1): the final relative shift is obtained from the initial shift (from motor position of provided by the user) + the refinement shift from cross correlation algorithm :return: dict with tomo object identifier (str) as key and a tuple of position in pixel (axis_0_pos, axis_1_pos, axis_2_pos) """ pos_0_shift = numpy.concatenate( ( numpy.atleast_1d(0.0), numpy.array(self._axis_0_rel_final_shifts) - numpy.array(self._axis_0_rel_ini_shifts), ) ) pos_0_cum_shift = numpy.cumsum(pos_0_shift) final_pos_axis_0 = self.configuration.axis_0_pos_px + pos_0_cum_shift pos_1_shift = numpy.concatenate( ( numpy.atleast_1d(0.0), numpy.array(self._axis_1_rel_final_shifts) - numpy.array(self._axis_1_rel_ini_shifts), ) ) pos_1_cum_shift = numpy.cumsum(pos_1_shift) final_pos_axis_1 = self.configuration.axis_1_pos_px + pos_1_cum_shift pos_2_shift = numpy.concatenate( ( numpy.atleast_1d(0.0), numpy.array(self._axis_2_rel_final_shifts) - numpy.array(self._axis_2_rel_ini_shifts), ) ) pos_2_cum_shift = numpy.cumsum(pos_2_shift) final_pos_axis_2 = self.configuration.axis_2_pos_px + pos_2_cum_shift assert len(final_pos_axis_0) == len(final_pos_axis_1) assert len(final_pos_axis_0) == len(final_pos_axis_2) assert len(final_pos_axis_0) == len(self.series) return { tomo_obj.get_identifier().to_str(): (pos_0, pos_1, pos_2) for tomo_obj, (pos_0, pos_1, pos_2) in zip( self.series, zip(final_pos_axis_0, final_pos_axis_1, final_pos_axis_2) ) } def settle_flips(self): """ User can provide some information on existing flips at frame level. 
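`get_final_axis_positions_in_px` turns the per-junction refinements into absolute positions by accumulating the difference between refined and initial shifts. The same arithmetic on toy numbers (three objects, hence two junctions; all values hypothetical):

```python
import numpy

axis_0_pos_px = numpy.array([0.0, 100.0, 200.0])   # initial absolute positions, one per object
initial_rel_shifts = numpy.array([-20.0, -20.0])    # shifts deduced from motor positions, one per junction
refined_rel_shifts = numpy.array([-18.0, -23.0])    # shifts refined by cross correlation

# the first object keeps its position, the others accumulate the refinement corrections
correction = numpy.concatenate((numpy.atleast_1d(0.0), refined_rel_shifts - initial_rel_shifts))
final_pos_axis_0 = axis_0_pos_px + numpy.cumsum(correction)

print(final_pos_axis_0)  # -> [  0. 102. 199.]
```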
The goal of this step is to get one flip_lr and on flip_ud value per scan or volume """ if numpy.isscalar(self.configuration.flip_lr): self.configuration.flip_lr = tuple([self.configuration.flip_lr] * len(self.series)) else: if not len(self.configuration.flip_lr) == len(self.series): raise ValueError("flip_lr expects a scalar value or one value per element to stitch") self.configuration.flip_lr = tuple(self.configuration.flip_lr) for elmt in self.configuration.flip_lr: if not isinstance(elmt, bool): raise TypeError if numpy.isscalar(self.configuration.flip_ud): self.configuration.flip_ud = tuple([self.configuration.flip_ud] * len(self.series)) else: if not len(self.configuration.flip_ud) == len(self.series): raise ValueError("flip_ud expects a scalar value or one value per element to stitch") self.configuration.flip_ud = tuple(self.configuration.flip_ud) for elmt in self.configuration.flip_ud: if not isinstance(elmt, bool): raise TypeError def _createOverlapKernels(self): """ after this stage the overlap kernels must be created and with the final overlap size """ if self.axis == 0: stitched_axis_rel_shifts = self._axis_0_rel_final_shifts stitched_axis_params = self.configuration.axis_0_params elif self.axis == 1: stitched_axis_rel_shifts = self._axis_1_rel_final_shifts stitched_axis_params = self.configuration.axis_1_params elif self.axis == 2: stitched_axis_rel_shifts = self._axis_2_rel_final_shifts stitched_axis_params = self.configuration.axis_2_params else: raise NotImplementedError if stitched_axis_rel_shifts is None or len(stitched_axis_rel_shifts) == 0: raise RuntimeError( f"axis {self.axis} shifts have not been defined yet. Please define them before calling this function" ) overlap_size = stitched_axis_params.get("overlap_size", None) if overlap_size in (None, "None", ""): overlap_size = -1 else: overlap_size = int(overlap_size) self._stitching_constant_length = max( [get_obj_constant_side_length(obj, axis=self.axis) for obj in self.series] ) for stitched_axis_shift in stitched_axis_rel_shifts: if overlap_size == -1: height = abs(stitched_axis_shift) else: height = overlap_size self._overlap_kernels.append( ImageStichOverlapKernel( stitching_axis=self.stitching_axis_in_frame_space, frame_unstitched_axis_size=self._stitching_constant_length, stitching_strategy=self.configuration.stitching_strategy, overlap_size=height, extra_params=self.configuration.stitching_kernels_extra_params, ) ) @property def series(self) -> Series: return self._series @property def configuration(self) -> SingleAxisStitchingConfiguration: return self._configuration @property def progress(self): return self._progress @staticmethod def _data_bunch_iterator(slices, bunch_size): """util to get indices by bunch until we reach n_frames""" if isinstance(slices, slice): # note: slice step is handled at a different level start = end = slices.start while True: start, end = end, min((end + bunch_size), slices.stop) yield (start, end) if end >= slices.stop: break # in the case of non-contiguous frames elif isinstance(slices, Iterable): for s in slices: yield (s, s + 1) else: raise TypeError(f"slices is provided as {type(slices)}. 
When Iterable or slice is expected") def rescale_frames(self, frames: tuple): """ rescale_frames if requested by the configuration """ _logger.info("apply rescale frames") def cast_percentile(percentile) -> int: if isinstance(percentile, str): percentile.replace(" ", "").rstrip("%") return int(percentile) rescale_min_percentile = cast_percentile(self.configuration.rescale_params.get(KEY_RESCALE_MIN_PERCENTILES, 0)) rescale_max_percentile = cast_percentile( self.configuration.rescale_params.get(KEY_RESCALE_MAX_PERCENTILES, 100) ) new_min = numpy.percentile(frames[0], rescale_min_percentile) new_max = numpy.percentile(frames[0], rescale_max_percentile) def rescale(data): # FIXME: takes time because browse several time the dataset, twice for percentiles and twices to get min and max when calling rescale_data... data_min = numpy.percentile(data, rescale_min_percentile) data_max = numpy.percentile(data, rescale_max_percentile) return rescale_data(data, new_min=new_min, new_max=new_max, data_min=data_min, data_max=data_max) return tuple([rescale(data) for data in frames]) def normalize_frame_by_sample(self, frames: tuple): """ normalize frame from a sample picked on the left or the right """ _logger.info("apply normalization by a sample") return tuple( [ normalize_frame_by_sample( frame=frame, side=self.configuration.normalization_by_sample.side, method=self.configuration.normalization_by_sample.method, margin_before_sample=self.configuration.normalization_by_sample.margin, sample_width=self.configuration.normalization_by_sample.width, ) for frame in frames ] ) @staticmethod def stitch_frames( frames: Union[tuple, numpy.ndarray], axis, x_relative_shifts: tuple, y_relative_shifts: tuple, output_dtype: numpy.ndarray, stitching_axis: int, overlap_kernels: tuple, dumper: DumperBase = None, check_inputs=True, shift_mode="nearest", i_frame=None, return_composition_cls=False, alignment="center", pad_mode="constant", new_width: Optional[int] = None, ) -> numpy.ndarray: """ shift frames according to provided `shifts` (as y, x tuples) then stitch all the shifted frames together and save them to output_dataset. :param tuple frames: element must be a DataUrl or a 2D numpy array :param stitching_regions_hdf5_dataset: """ if check_inputs: if len(frames) < 2: raise ValueError(f"Not enought frames provided for stitching ({len(frames)} provided)") if len(frames) != len(x_relative_shifts) + 1: raise ValueError( f"Incoherent number of shift provided ({len(x_relative_shifts)}) compare to number of frame ({len(frames)}). 
len(frames) - 1 expected" ) if len(x_relative_shifts) != len(overlap_kernels): raise ValueError( f"expect to have the same number of x_relative_shifts ({len(x_relative_shifts)}) and y_overlap ({len(overlap_kernels)})" ) if len(y_relative_shifts) != len(overlap_kernels): raise ValueError( f"expect to have the same number of y_relative_shifts ({len(y_relative_shifts)}) and y_overlap ({len(overlap_kernels)})" ) relative_positions = [(0, 0, 0)] for y_rel_pos, x_rel_pos in zip(y_relative_shifts, x_relative_shifts): relative_positions.append( ( y_rel_pos + relative_positions[-1][0], 0, # position over axis 1 (aka y) is not handled yet x_rel_pos + relative_positions[-1][2], ) ) check_overlaps( frames=tuple(frames), positions=tuple(relative_positions), axis=axis, raise_error=False, ) def check_frame_is_2d(frame): if frame.ndim != 2: raise ValueError(f"2D frame expected when {frame.ndim}D provided") # step_0 load data if from url data = [] for frame in frames: if isinstance(frame, DataUrl): data_frame = get_data(frame) if check_inputs: check_frame_is_2d(data_frame) data.append(data_frame) elif isinstance(frame, numpy.ndarray): if check_inputs: check_frame_is_2d(frame) data.append(frame) else: raise TypeError(f"frames are expected to be DataUrl or 2D numpy array. Not {type(frame)}") # step 1: shift each frames (except the first one) if stitching_axis == 0: relative_shift_along_stitched_axis = y_relative_shifts relative_shift_along_unstitched_axis = x_relative_shifts elif stitching_axis == 1: relative_shift_along_stitched_axis = x_relative_shifts relative_shift_along_unstitched_axis = y_relative_shifts else: raise NotImplementedError("") shifted_data = [data[0]] for frame, relative_shift in zip(data[1:], relative_shift_along_unstitched_axis): # note: for now we only shift data in x. 
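Before `check_overlaps` is called, the per-junction relative shifts are accumulated into absolute (axis 0, axis 1, axis 2) positions, with axis 1 pinned to zero as the comment in the code notes. A standalone version of that accumulation with hypothetical shift values:

```python
y_relative_shifts = (-40, -50)   # per-junction shifts along the stitched axis (hypothetical)
x_relative_shifts = (2, -1)      # per-junction shifts along the unstitched axis (hypothetical)

relative_positions = [(0, 0, 0)]
for y_rel_pos, x_rel_pos in zip(y_relative_shifts, x_relative_shifts):
    relative_positions.append(
        (
            y_rel_pos + relative_positions[-1][0],
            0,  # axis 1 is not handled yet
            x_rel_pos + relative_positions[-1][2],
        )
    )

print(relative_positions)  # -> [(0, 0, 0), (-40, 0, 2), (-90, 0, 1)]
```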
the y shift is handled in the FrameComposition relative_shift = numpy.asarray(relative_shift).astype(numpy.int8) if relative_shift == 0: shifted_frame = frame else: # TO speed up: should use the Fourier transform shifted_frame = shift_scipy( frame, mode=shift_mode, shift=[0, -relative_shift] if stitching_axis == 0 else [-relative_shift, 0], order=1, ) shifted_data.append(shifted_frame) # step 2: create stitched frame stitched_frame, composition_cls = stitch_raw_frames( frames=shifted_data, key_lines=( [ (int(frame.shape[stitching_axis] - abs(relative_shift / 2)), int(abs(relative_shift / 2))) for relative_shift, frame in zip(relative_shift_along_stitched_axis, frames) ] ), overlap_kernels=overlap_kernels, check_inputs=check_inputs, output_dtype=output_dtype, return_composition_cls=True, alignment=alignment, pad_mode=pad_mode, new_unstitched_axis_size=new_width, ) dumper.save_stitched_frame( stitched_frame=stitched_frame, composition_cls=composition_cls, i_frame=i_frame, axis=1, ) if return_composition_cls: return stitched_frame, composition_cls else: return stitched_frame ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/stitching/stitcher/stitcher.py0000644000175000017500000000000014654107202021221 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/stitching/stitcher/y_stitcher.py0000644000175000017500000000047514654107202021571 0ustar00pierrepierrefrom nabu.stitching.stitcher.pre_processing import PreProcessingStitching from .dumper import PreProcessingStitchingDumper class PreProcessingYStitcher( PreProcessingStitching, dumper_cls=PreProcessingStitchingDumper, axis=1, ): @property def serie_label(self) -> str: return "y-serie" ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731053186.0 nabu-2024.2.1/nabu/stitching/stitcher/z_stitcher.py0000644000175000017500000000252014713343202021561 0ustar00pierrepierrefrom nabu.stitching.stitcher.pre_processing import PreProcessingStitching from nabu.stitching.stitcher.post_processing import PostProcessingStitching from .dumper import PreProcessingStitchingDumper, PostProcessingStitchingDumperNoDD, PostProcessingStitchingDumper from nabu.stitching.stitcher.single_axis import _SingleAxisMetaClass class PreProcessingZStitcher( PreProcessingStitching, dumper_cls=PreProcessingStitchingDumper, axis=0, ): def check_inputs(self): """ insure input data is coherent """ super().check_inputs() for scan_0, scan_1 in zip(self.series[0:-1], self.series[1:]): if scan_0.dim_1 != scan_1.dim_1: raise ValueError( f"projections width are expected to be the same. 
Not the case for {scan_0} ({scan_0.dim_1} and {scan_1} ({scan_1.dim_1}))" ) class PostProcessingZStitcher( PostProcessingStitching, metaclass=_SingleAxisMetaClass, dumper_cls=PostProcessingStitchingDumper, axis=0, ): @property def serie_label(self) -> str: return "z-serie" class PostProcessingZStitcherNoDD( PostProcessingStitching, metaclass=_SingleAxisMetaClass, dumper_cls=PostProcessingStitchingDumperNoDD, axis=0, ): @property def serie_label(self) -> str: return "z-serie" ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731053186.0 nabu-2024.2.1/nabu/stitching/stitcher_2D.py0000644000175000017500000003054714713343202017742 0ustar00pierrepierreimport numpy from math import ceil from typing import Union, Optional from nabu.stitching.overlap import ImageStichOverlapKernel from nabu.stitching.frame_composition import FrameComposition from nabu.stitching.alignment import align_frame, _Alignment def stitch_raw_frames( frames: tuple, key_lines: tuple, overlap_kernels: Union[ImageStichOverlapKernel, tuple], output_dtype: numpy.dtype = numpy.float32, check_inputs=True, raw_frames_compositions: Optional[FrameComposition] = None, overlap_frames_compositions: Optional[FrameComposition] = None, return_composition_cls=False, alignment: _Alignment = "center", pad_mode="constant", new_unstitched_axis_size: Optional[int] = None, ) -> numpy.ndarray: """ stitches raw frames (already shifted and flat fielded !!!) together using raw stitching (no pixel interpolation, y_overlap_in_px is expected to be a int). Sttiching depends on the kernel used. It can be done: * vertically: X ------------------------------------------------------------------> | -------------- | | | | | Frame 1 | -------------- | | | | Frame 1 | | -------------- | | Y | --> stitching --> |~ stitching ~| | -------------- | | | | | | Frame 2 | | | Frame 2 | -------------- | | | | -------------- \/ * horizontally: ------------------------------------------------------------------> | -------------- -------------- ----------------------- | | | | | | ~ ~ | Y | | Frame 1 | | Frame 2 | --> stitching --> | Frame 1 ~ ~ Frame 2 | | | | | | | ~ ~ | | -------------- -------------- ----------------------- | \/ returns stitched_projection, raw_img_1, raw_img_2, computed_overlap proj_0 and pro_1 are already expected to be in a row. Having stitching_height_in_px in common. At top of proj_0 and at bottom of proj_1 :param tuple frames: tuple of 2D numpy array. Expected to be Z up oriented at this stage :param tuple key_lines: for each jonction define the two lines to overlaid (from the upper and the lower frames). In the reference where 0 is the bottom line of the image. :param overlap_kernels: ZStichOverlapKernel overlap kernel to be used or a list of kernel (one per overlap). Define startegy and overlap heights :param numpy.dtype output_dtype: dataset dtype. For now must be provided because flat field corrcetion change data type (numpy.float32 for now) :param bool check_inputs: if True will do more test on inputs parameters like checking frame shapes, coherence of the request.. As it can be time consuming it is optional :param raw_frames_compositions: pre computed raw frame composition. If not provided will compute them. allow providing it to speed up calculation :param overlap_frames_compositions: pre computed stitched frame composition. If not provided will compute them. allow providing it to speed up calculation :param bool return_frame_compositions: if False return simply the stitched frames. 
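The vertical-stitching test later in this archive (`test_stitch_vertically_raw_frames`) is a good minimal recipe for calling `stitch_raw_frames`. Below is a condensed sketch of that same call, assuming nabu and silx are installed; each key line places the junction at the middle of the overlapping band:

```python
import numpy
from silx.image.phantomgenerator import PhantomGenerator
from nabu.stitching.overlap import ImageStichOverlapKernel
from nabu.stitching.stitcher_2D import stitch_raw_frames

ref = PhantomGenerator.get2DPhantomSheppLogan(n=256).astype(numpy.float32)
# three vertical bands with 20 px and 10 px of overlap respectively
frame_1, frame_2, frame_3 = ref[0:100], ref[80:164], ref[154:]

kernel_1 = ImageStichOverlapKernel(frame_unstitched_axis_size=256, overlap_size=20, stitching_axis=0)
kernel_2 = ImageStichOverlapKernel(frame_unstitched_axis_size=256, overlap_size=10, stitching_axis=0)

stitched = stitch_raw_frames(
    frames=(frame_1, frame_2, frame_3),
    output_dtype=numpy.float32,
    overlap_kernels=(kernel_1, kernel_2),
    key_lines=(
        (90, 10),  # frame_1 height - overlap / 2, overlap / 2
        (79, 5),   # same rule for the second junction
    ),
)
assert stitched.shape == ref.shape
```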
Else return a tuple with stitching frame and the dictionnary with the composition frames... :param alignment: how to align frame if two frames have different size along the unstitched axis :param pad_mode: how to pad data for alignment (provided to numpy.pad function) :param new_unstitched_axis_size: size of the image along the axis not stitched. So it will be the frame width if the stitching axis is 0 and the frame height if the stitching axis is 1 """ if overlap_kernels is None or len(overlap_kernels) == 0: raise ValueError("overlap kernels must be provided") stitched_axis = overlap_kernels[0].stitched_axis unstitched_axis = overlap_kernels[0].unstitched_axis if check_inputs: # check frames are 2D numpy arrays def check_frame(proj): if not isinstance(proj, numpy.ndarray) and proj.ndim == 2: raise ValueError(f"frames are expected to be 2D numpy array") [check_frame(frame) for frame in frames] for frame_0, frame_1 in zip(frames[:-1], frames[1:]): if not (frame_0.ndim == frame_1.ndim == 2): raise ValueError("Frames are expected to be 2D") # check there is coherence between overlap kernels and frames for frame_0, frame_1, kernel in zip(frames[:-1], frames[1:], overlap_kernels): if frame_0.shape[stitched_axis] < kernel.overlap_size: raise ValueError( f"frame_0 height ({frame_0.shape[stitched_axis]}) is less than kernel overlap ({kernel.overlap_size})" ) if frame_1.shape[stitched_axis] < kernel.overlap_size: raise ValueError( f"frame_1 height ({frame_1.shape[stitched_axis]}) is less than kernel overlap ({kernel.overlap_size})" ) # check key lines are coherent with overlp kernels if not len(key_lines) == len(overlap_kernels): raise ValueError("we expect to have the same number of key_lines then the number of kernel") else: for key_line in key_lines: for value in key_line: if not isinstance(value, (int, numpy.integer)): raise TypeError(f"key_line is expected to be an integer. {type(key_line)} provided") elif value < 0: raise ValueError(f"key lines are expected to be positive values. Get {value} as key line value") # check overlap kernel stitching axis are coherent (for now make sure they are all along the same axis) if len(overlap_kernels) > 1: for previous_kernel, next_kernel in zip(overlap_kernels[:-1], overlap_kernels[1:]): if not isinstance(previous_kernel, ImageStichOverlapKernel): raise TypeError( f"overlap kernels must be instances of {ImageStichOverlapKernel}. Get {type(previous_kernel)}" ) if not isinstance(next_kernel, ImageStichOverlapKernel): raise TypeError( f"overlap kernels must be instances of {ImageStichOverlapKernel}. Get {type(next_kernel)}" ) if previous_kernel.stitched_axis != next_kernel.stitched_axis: raise ValueError( "kernels with different stitching axis provided. 
For now all kernels must have the same stitchign axis" ) if new_unstitched_axis_size is None: new_unstitched_axis_size = max([frame.shape[unstitched_axis] for frame in frames]) frames = tuple( [ align_frame( data=frame, alignment=alignment, new_aligned_axis_size=new_unstitched_axis_size, pad_mode=pad_mode, alignment_axis=unstitched_axis, ) for frame in frames ] ) # step 1: create numpy array that will contain stitching # if raw composition doesn't exists create it if raw_frames_compositions is None: raw_frames_compositions = FrameComposition.compute_raw_frame_compositions( frames=frames, overlap_kernels=overlap_kernels, key_lines=key_lines, stitching_axis=stitched_axis, ) new_stitched_axis_size = raw_frames_compositions.global_end[-1] - raw_frames_compositions.global_start[0] if stitched_axis == 0: stitched_projection_shape = ( int(new_stitched_axis_size), new_unstitched_axis_size, ) else: stitched_projection_shape = ( new_unstitched_axis_size, int(new_stitched_axis_size), ) stitch_array = numpy.empty(stitched_projection_shape, dtype=output_dtype) # step 2: set raw data # fill stitch array with raw data raw data raw_frames_compositions.compose( output_frame=stitch_array, input_frames=frames, ) # step 3 set stitched data # 3.1 create stitched overlaps stitched_overlap = [] for frame_0, frame_1, kernel, key_line in zip(frames[:-1], frames[1:], overlap_kernels, key_lines): assert kernel.overlap_size >= 0 frame_0_overlap, frame_1_overlap = get_overlap_areas( upper_frame=frame_0, lower_frame=frame_1, upper_frame_key_line=key_line[0], lower_frame_key_line=key_line[1], overlap_size=kernel.overlap_size, stitching_axis=stitched_axis, ) assert ( frame_0_overlap.shape[stitched_axis] == frame_1_overlap.shape[stitched_axis] == kernel.overlap_size ), f"{frame_0_overlap.shape[stitched_axis]} == {frame_1_overlap.shape[stitched_axis]} == {kernel.overlap_size}" stitched_overlap.append( kernel.stitch( frame_0_overlap, frame_1_overlap, )[0] ) # 3.2 fill stitched overlap on output array if overlap_frames_compositions is None: overlap_frames_compositions = FrameComposition.compute_stitch_frame_composition( frames=frames, overlap_kernels=overlap_kernels, key_lines=key_lines, stitching_axis=stitched_axis, ) overlap_frames_compositions.compose( output_frame=stitch_array, input_frames=stitched_overlap, ) if return_composition_cls: return ( stitch_array, { "raw_composition": raw_frames_compositions, "overlap_composition": overlap_frames_compositions, }, ) return stitch_array def get_overlap_areas( upper_frame: numpy.ndarray, lower_frame: numpy.ndarray, upper_frame_key_line: int, lower_frame_key_line: int, overlap_size: int, stitching_axis: int, ): """ return the requested area from lower_frame and upper_frame. Lower_frame contains at the end of it the 'real overlap' with the upper_frame. Upper_frame contains the 'real overlap' at the end of it. For some reason the user can ask the stitching height to be smaller than the `real overlap`. Here are some drawing to have a better of view of those regions: .. image:: images/stitching/z_stitch_real_overlap.png :width: 600 .. image:: z_stitch_stitch_height.png :width: 600 """ assert stitching_axis in (0, 1) for pf, pn in zip((lower_frame_key_line, upper_frame_key_line), ("lower_frame", "upper_frame")): if not isinstance(pf, (int, numpy.number)): raise TypeError(f"{pn} is expected to be a number. 
{type(pf)} provided") assert overlap_size >= 0 lf_start = ceil(lower_frame_key_line - overlap_size / 2) lf_end = ceil(lower_frame_key_line + overlap_size / 2) uf_start = ceil(upper_frame_key_line - overlap_size / 2) uf_end = ceil(upper_frame_key_line + overlap_size / 2) lf_start, lf_end = min(lf_start, lf_end), max(lf_start, lf_end) uf_start, uf_end = min(uf_start, uf_end), max(uf_start, uf_end) if lf_start < 0 or uf_start < 0: raise ValueError( f"requested overlap ({overlap_size}) is incoherent with key line positions ({lower_frame_key_line}, {upper_frame_key_line}) - expected to be smaller." ) if stitching_axis == 0: overlap_upper = upper_frame[uf_start:uf_end] overlap_lower = lower_frame[lf_start:lf_end] elif stitching_axis == 1: overlap_upper = upper_frame[:, uf_start:uf_end] overlap_lower = lower_frame[:, lf_start:lf_end] else: raise NotImplementedError if not overlap_upper.shape == overlap_lower.shape: # maybe in the future: try to reduce one according to the other ???? raise RuntimeError( f"lower and upper frame have different overlap size ({overlap_upper.shape} vs {overlap_lower.shape})" ) return overlap_upper, overlap_lower ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1734442985.524757 nabu-2024.2.1/nabu/stitching/tests/0000755000175000017500000000000014730277752016366 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/stitching/tests/__init__.py0000644000175000017500000000000014550227307020454 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/stitching/tests/test_alignment.py0000644000175000017500000000525714550227307021755 0ustar00pierrepierreimport numpy import pytest from nabu.stitching.alignment import align_horizontally, PaddedRawData from nabu.testutils import get_data def test_alignment_axis_2(): """ test 'align_horizontally' function """ dataset = get_data("chelsea.npz")["data"] # shape is (300, 451) # test if new_width < current_width: should raise an error with pytest.raises(ValueError): align_horizontally(dataset, alignment="center", new_width=10) # test some use cases res = align_horizontally( dataset, alignment="center", new_width=600, pad_mode="mean", ) assert res.shape == (300, 600) numpy.testing.assert_array_almost_equal(res[:, 74:-75], dataset) res = align_horizontally( dataset, alignment="left", new_width=600, pad_mode="median", ) assert res.shape == (300, 600) numpy.testing.assert_array_almost_equal(res[:, :451], dataset) res = align_horizontally( dataset, alignment="right", new_width=600, pad_mode="reflect", ) assert res.shape == (300, 600) numpy.testing.assert_array_almost_equal(res[:, -451:], dataset) def test_PaddedRawData(): """ test PaddedVolume class """ data = numpy.linspace( start=0, stop=20 * 6 * 3, dtype=numpy.int64, num=20 * 6 * 3, ) data = data.reshape((3, 6, 20)) padded_volume = PaddedRawData(data=data, axis_1_pad_width=(4, 1)) assert padded_volume.shape == (3, 6 + 4 + 1, 20) numpy.testing.assert_array_equal( padded_volume[:, 0, :], numpy.zeros(shape=(3, 1, 20), dtype=numpy.int64), ) numpy.testing.assert_array_equal( padded_volume[:, 3, :], numpy.zeros(shape=(3, 1, 20), dtype=numpy.int64), ) numpy.testing.assert_array_equal( padded_volume[:, 10, :], numpy.zeros(shape=(3, 1, 20), dtype=numpy.int64), ) assert padded_volume[:, 3, :].shape == (3, 1, 20) numpy.testing.assert_array_equal( padded_volume[:, 4, :], data[:, 0:1, :], # TODO: have a look, return 
a 3D array when a 2D expected... ) with pytest.raises(ValueError): padded_volume[:, 40, :] with pytest.raises(ValueError): padded_volume[:, 5:1, :] arrays = ( numpy.zeros(shape=(3, 4, 20), dtype=numpy.int64), data, numpy.zeros(shape=(3, 1, 20), dtype=numpy.int64), ) expected_volume = numpy.hstack( arrays, ) assert padded_volume[:, :, :].shape == padded_volume.shape assert expected_volume.shape == padded_volume.shape numpy.testing.assert_array_equal( padded_volume[:, :, :], expected_volume, ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731053186.0 nabu-2024.2.1/nabu/stitching/tests/test_config.py0000644000175000017500000001637514713343202021241 0ustar00pierrepierreimport os from tempfile import TemporaryDirectory import pytest from nabu.pipeline.config import ( generate_nabu_configfile, _options_levels, parse_nabu_config_file, ) from nabu.stitching.overlap import OverlapStitchingStrategy from nabu.stitching import config as stiching_config _stitching_types = list(stiching_config.StitchingType.values()) _stitching_types.append(None) def nabu_config_to_dict(nabu_config): res = {} for section, section_content in nabu_config.items(): res[section] = {} for key, values in section_content.items(): res[section][key] = values["default"] return res @pytest.mark.parametrize("stitching_type", _stitching_types) @pytest.mark.parametrize("option_level", _options_levels.keys()) def test_stitching_config(stitching_type, option_level): """ insure get_default_stitching_config is returning a dict and is coherent with the configuration classes """ with TemporaryDirectory() as output_dir: nabu_dict = stiching_config.get_default_stitching_config(stitching_type) config = nabu_config_to_dict(nabu_dict) assert isinstance(config, dict) assert "stitching" in config assert "type" in config["stitching"] stitching_type = stiching_config.StitchingType.from_value(config["stitching"]["type"]) if stitching_type is stiching_config.StitchingType.Z_POSTPROC: assert isinstance( stiching_config.dict_to_config_obj(config), stiching_config.PostProcessedSingleAxisStitchingConfiguration, ) elif stitching_type is stiching_config.StitchingType.Z_PREPROC: assert isinstance( stiching_config.dict_to_config_obj(config), stiching_config.PreProcessedSingleAxisStitchingConfiguration, ) elif stitching_type is stiching_config.StitchingType.Y_PREPROC: assert isinstance( stiching_config.dict_to_config_obj(config), stiching_config.PreProcessedSingleAxisStitchingConfiguration, ) else: raise ValueError("not handled") # dump configuration to file output_file = os.path.join(output_dir, "config.conf") generate_nabu_configfile( fname=output_file, default_config=nabu_dict, comments=True, sections_comments=stiching_config.SECTIONS_COMMENTS, options_level=option_level, prefilled_values={}, ) # load configuration from file loaded_config = parse_nabu_config_file(output_file) config_class_instance = stiching_config.dict_to_config_obj(loaded_config) if stitching_type is stiching_config.StitchingType.Z_POSTPROC: assert isinstance( config_class_instance, stiching_config.PostProcessedSingleAxisStitchingConfiguration, ) elif stitching_type is stiching_config.StitchingType.Z_PREPROC: assert isinstance( config_class_instance, stiching_config.PreProcessedSingleAxisStitchingConfiguration, ) assert isinstance(config_class_instance.to_dict(), dict) @pytest.mark.parametrize("stitching_strategy", OverlapStitchingStrategy.values()) @pytest.mark.parametrize("overwrite_results", (True, "False", 0, "1")) @pytest.mark.parametrize( 
"axis_shifts", ( "", None, "None", "", "skimage", "nabu-fft", ), ) @pytest.mark.parametrize("axis_shifts_params", ("", {}, "window_size=200")) @pytest.mark.parametrize( "slice_for_correlation", ( "middle", "3", ), ) @pytest.mark.parametrize("slices", ("middle", "0:26:2")) @pytest.mark.parametrize( "input_scans", ( "", "hdf5:scan:/data/scan.hdf5?path=entry; hdf5:scan:/data/scan.hdf5?path=entry1", ), ) @pytest.mark.parametrize( "slurm_config", ( { stiching_config.SLURM_MODULES_TO_LOADS: "tomotools", stiching_config.SLURM_PREPROCESSING_COMMAND: "", stiching_config.SLURM_CLEAN_SCRIPTS: True, stiching_config.SLURM_MEM: 56, stiching_config.SLURM_N_JOBS: 5, stiching_config.SLURM_PARTITION: "my_partition", }, ), ) def test_PreProcessedZStitchingConfiguration( stitching_strategy, overwrite_results, axis_shifts, axis_shifts_params, input_scans, slice_for_correlation, slices, slurm_config, ): """ make sure configuration works well for PreProcessedZStitchingConfiguration """ pre_process_config = stiching_config.PreProcessedZStitchingConfiguration.from_dict( { stiching_config.STITCHING_SECTION: { stiching_config.CROSS_CORRELATION_SLICE_FIELD: slice_for_correlation, stiching_config.AXIS_0_POS_PX: axis_shifts, stiching_config.AXIS_1_POS_PX: axis_shifts, stiching_config.AXIS_2_POS_PX: axis_shifts, stiching_config.AXIS_0_PARAMS: axis_shifts_params, stiching_config.AXIS_1_PARAMS: axis_shifts_params, stiching_config.AXIS_2_PARAMS: axis_shifts_params, stiching_config.STITCHING_STRATEGY_FIELD: stitching_strategy, }, stiching_config.INPUTS_SECTION: { stiching_config.INPUT_DATASETS_FIELD: input_scans, stiching_config.STITCHING_SLICES: slices, }, stiching_config.OUTPUT_SECTION: { stiching_config.OVERWRITE_RESULTS_FIELD: overwrite_results, }, stiching_config.PRE_PROC_SECTION: { stiching_config.DATA_FILE_FIELD: "my_file.nx", stiching_config.DATA_PATH_FIELD: "entry", stiching_config.NEXUS_VERSION_FIELD: None, }, stiching_config.SLURM_SECTION: slurm_config, stiching_config.NORMALIZATION_BY_SAMPLE_SECTION: { stiching_config.NORMALIZATION_BY_SAMPLE_MARGIN: 1, stiching_config.NORMALIZATION_BY_SAMPLE_SIDE: "right", stiching_config.NORMALIZATION_BY_SAMPLE_ACTIVE_FIELD: True, stiching_config.NORMALIZATION_BY_SAMPLE_METHOD: "mean", stiching_config.NORMALIZATION_BY_SAMPLE_WIDTH: 31, }, }, ) from_dict = stiching_config.PreProcessedZStitchingConfiguration.from_dict(pre_process_config.to_dict()) # workaround for scans because a new object is created each time pre_process_config.settle_inputs assert len(from_dict.input_scans) == len(pre_process_config.input_scans) from_dict.input_scans = None pre_process_config.input_scans = None assert pre_process_config == from_dict def test_PostProcessedZStitchingConfiguration(): """ make sure configuration works well for PostProcessedZStitchingConfiguration """ pass def test_description_dict(): """ make sure the description dict (used for generating the file) is working and generates a dict """ assert isinstance(stiching_config.PreProcessedSingleAxisStitchingConfiguration.get_description_dict(), dict) assert isinstance( stiching_config.PostProcessedSingleAxisStitchingConfiguration.get_description_dict(), dict, ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/stitching/tests/test_frame_composition.py0000644000175000017500000001134714654107202023505 0ustar00pierrepierreimport pytest from nabu.stitching.frame_composition import FrameComposition import numpy from nabu.stitching.overlap import OverlapStitchingStrategy, 
ImageStichOverlapKernel def test_frame_composition(): """ Test FrameComposition """ frame_0 = numpy.zeros((100, 1)) frame_1 = numpy.ones((98, 1)) frame_2 = numpy.ones((205, 1)) * 2.0 frames = (frame_0, frame_1, frame_2) y_shifts = -20, -10 kernels = [ ImageStichOverlapKernel( stitching_axis=0, frame_unstitched_axis_size=1, stitching_strategy=OverlapStitchingStrategy.LINEAR_WEIGHTS, overlap_size=4, ), ImageStichOverlapKernel( stitching_axis=0, frame_unstitched_axis_size=1, stitching_strategy=OverlapStitchingStrategy.MEAN, overlap_size=8, ), ] # check raw composition raw_composition = FrameComposition.compute_raw_frame_compositions( frames=frames, key_lines=( (90, 10), (98 - 5, 5), ), overlap_kernels=kernels, stitching_axis=0, ) assert isinstance(raw_composition, FrameComposition) assert raw_composition.local_start == (0, 12, 9) assert raw_composition.local_end == (88, 89, 205) assert raw_composition.global_start == (0, 92, 177) assert raw_composition.global_end == (88, 169, 373) stitched_data = numpy.empty((100 + 98 + 205 - 30, 1)) raw_composition.compose(output_frame=stitched_data, input_frames=frames) assert stitched_data[0, 0] == 0 assert stitched_data[150, 0] == 1.0 assert stitched_data[-1, 0] == 2.0 # check stitch composition stitch_composition = FrameComposition.compute_stitch_frame_composition( frames=frames, key_lines=( (90, 10), (98 - 5, 5), ), overlap_kernels=kernels, stitching_axis=0, ) FrameComposition.pprint_composition(raw_composition, stitch_composition) assert stitch_composition.local_start == (0, 0) assert stitch_composition.local_end == (4, 8) assert stitch_composition.global_start == (88, 169) assert stitch_composition.global_end == (92, 177) stitched_frames = [] for frame_0, frame_1, kernel, y_shift in zip(frames[:-1], frames[1:], kernels, y_shifts): # take frames once shifted frame_0_overlap = frame_0[y_shift:] frame_1_overlap = frame_1[:-y_shift] # select the overlap area frame_0_overlap = frame_0[-kernel.overlap_size :] frame_1_overlap = frame_1[: kernel.overlap_size] stitched_frames.append(kernel.stitch(frame_0_overlap, frame_1_overlap)[0]) stitch_composition.compose( output_frame=stitched_data, input_frames=stitched_frames, ) assert 0.0 < stitched_data[90, 0] < 1.0 assert 1.0 < stitched_data[172, 0] < 2.0 _raw_comp_config = ( { "key_lines": ( (17, 2), (36, 3), ), "raw_global_start": (0, 19, 53), "raw_global_end": (16, 49, 68), "raw_local_start": (0, 4, 5), "raw_local_end": (16, 34, 20), "kernels": ( ImageStichOverlapKernel( stitching_axis=0, frame_unstitched_axis_size=1, stitching_strategy=OverlapStitchingStrategy.LINEAR_WEIGHTS, overlap_size=3, ), ImageStichOverlapKernel( stitching_axis=0, frame_unstitched_axis_size=1, stitching_strategy=OverlapStitchingStrategy.MEAN, overlap_size=4, ), ), }, ) @pytest.mark.parametrize("configuration", _raw_comp_config) def test_raw_frame_composition_exotic_config(configuration): """ Test some """ frame_0 = numpy.zeros((20, 1)) frame_1 = numpy.ones((40, 1)) frame_2 = numpy.ones((20, 1)) * 2.0 frames = (frame_0, frame_1, frame_2) key_lines = configuration.get("key_lines") kernels = configuration.get("kernels") # check raw composition raw_composition = FrameComposition.compute_raw_frame_compositions( frames=frames, overlap_kernels=kernels, key_lines=key_lines, stitching_axis=0, ) assert raw_composition.global_start == configuration.get("raw_global_start") assert raw_composition.global_end == configuration.get("raw_global_end") assert raw_composition.local_start == configuration.get("raw_local_start") assert raw_composition.local_end 
== configuration.get("raw_local_end") stitched_data = numpy.empty( ( (raw_composition.global_end[-1] - raw_composition.global_start[0]), 1, ) ) raw_composition.compose(output_frame=stitched_data, input_frames=frames) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/stitching/tests/test_overlap.py0000644000175000017500000001721414654107202021437 0ustar00pierrepierreimport numpy import pytest from nabu.stitching.overlap import compute_image_minimum_divergence, compute_image_higher_signal, check_overlaps from nabu.testutils import get_data from nabu.stitching.overlap import ImageStichOverlapKernel, OverlapStitchingStrategy from nabu.stitching.stitcher_2D import stitch_raw_frames from silx.image.phantomgenerator import PhantomGenerator strategies_to_test_weights = ( OverlapStitchingStrategy.CLOSEST, OverlapStitchingStrategy.COSINUS_WEIGHTS, OverlapStitchingStrategy.LINEAR_WEIGHTS, OverlapStitchingStrategy.MEAN, ) @pytest.mark.parametrize("strategy", strategies_to_test_weights) @pytest.mark.parametrize("stitching_axis", (0, 1)) def test_overlap_stitcher(strategy, stitching_axis): frame_width = 128 frame_height = frame_width frame_1 = PhantomGenerator.get2DPhantomSheppLogan(n=frame_width) stitcher = ImageStichOverlapKernel( stitching_strategy=strategy, overlap_size=frame_height, frame_unstitched_axis_size=128, stitching_axis=stitching_axis, ) stitched_frame = stitcher.stitch(frame_1, frame_1)[0] assert stitched_frame.shape == (frame_height, frame_width) # check result is close to the expected one numpy.testing.assert_allclose(frame_1, stitched_frame, atol=10e-10) # check sum of weights ~ 1.0 numpy.testing.assert_allclose( stitcher.weights_img_1 + stitcher.weights_img_2, numpy.ones_like(stitcher.weights_img_1), ) @pytest.mark.parametrize("stitching_axis", (0, 1)) def test_compute_image_minimum_divergence(stitching_axis): """make sure the compute_image_minimum_divergence function is processing""" raw_data_1 = get_data("brain_phantom.npz")["data"] raw_data_2 = numpy.random.rand(*raw_data_1.shape) * 255.0 stitching = compute_image_minimum_divergence( raw_data_1, raw_data_2, high_frequency_threshold=2, stitching_axis=stitching_axis ) assert stitching.shape == raw_data_1.shape def test_compute_image_higher_signal(): """ make sure compute_image_higher_signal is processing """ raw_data = get_data("brain_phantom.npz")["data"] raw_data_1 = raw_data.copy() raw_data_1[40:75] = 0.0 raw_data_1[:, 210:245] = 0.0 raw_data_2 = raw_data.copy() raw_data_2[:, 100:120] = 0.0 stitching = compute_image_higher_signal(raw_data_1, raw_data_2) numpy.testing.assert_array_equal( stitching, raw_data, ) def test_check_overlaps(): """test 'check_overlaps' function""" # two frames, ordered and with an overlap check_overlaps( frames=( numpy.ones(10), numpy.ones(20), ), positions=((10, 0, 0), (0, 0, 0)), axis=0, raise_error=True, ) # two frames, ordered and without an overlap with pytest.raises(ValueError): check_overlaps( frames=( numpy.ones(10), numpy.ones(20), ), positions=((0, 0, 0), (100, 0, 0)), axis=0, raise_error=True, ) # two frames, frame 0 fully overlap frame 1 with pytest.raises(ValueError): check_overlaps( frames=( numpy.ones(20), numpy.ones(10), ), positions=((8, 0, 0), (5, 0, 0)), axis=0, raise_error=True, ) # three frames 'overlaping' as expected check_overlaps( frames=( numpy.ones(10), numpy.ones(20), numpy.ones(10), ), positions=((20, 0, 0), (10, 0, 0), (0, 0, 0)), axis=0, raise_error=True, ) # three frames: frame 0 overlap frame 1 but also frame 2 
with pytest.raises(ValueError): check_overlaps( frames=( numpy.ones(20), numpy.ones(10), numpy.ones(10), ), positions=((20, 0, 0), (15, 0, 0), (11, 0, 0)), axis=0, raise_error=True, ) @pytest.mark.parametrize("dtype", (numpy.float16, numpy.float32)) def test_stitch_vertically_raw_frames(dtype): """ ensure a stitching with 3 frames and different overlap can be done """ ref_frame_width = 256 frame_ref = PhantomGenerator.get2DPhantomSheppLogan(n=ref_frame_width).astype(dtype) # split the frame into several part frame_1 = frame_ref[0:100] frame_2 = frame_ref[80:164] frame_3 = frame_ref[154:] kernel_1 = ImageStichOverlapKernel(frame_unstitched_axis_size=ref_frame_width, overlap_size=20, stitching_axis=0) kernel_2 = ImageStichOverlapKernel(frame_unstitched_axis_size=ref_frame_width, overlap_size=10, stitching_axis=0) stitched = stitch_raw_frames( frames=(frame_1, frame_2, frame_3), output_dtype=dtype, overlap_kernels=(kernel_1, kernel_2), raw_frames_compositions=None, overlap_frames_compositions=None, key_lines=( ( 90, # frame_1 height - kernel_1 height / 2.0 10, # kernel_1 height / 2.0 ), ( 79, # frame_2 height - kernel_2 height / 2.0 ou 102-20 ? 5, # kernel_2 height / 2.0 ), ), ) assert stitched.shape == frame_ref.shape numpy.testing.assert_array_almost_equal(frame_ref, stitched) def test_stitch_vertically_raw_frames_2(): """ ensure a stitching with 3 frames and different overlap can be done """ ref_frame_width = 256 frame_ref = PhantomGenerator.get2DPhantomSheppLogan(n=ref_frame_width).astype(numpy.float32) # split the frame into several part frame_1 = frame_ref.copy() frame_2 = frame_ref.copy() frame_3 = frame_ref.copy() kernel_1 = ImageStichOverlapKernel(frame_unstitched_axis_size=ref_frame_width, overlap_size=10, stitching_axis=0) kernel_2 = ImageStichOverlapKernel(frame_unstitched_axis_size=ref_frame_width, overlap_size=10, stitching_axis=0) stitched = stitch_raw_frames( frames=(frame_1, frame_2, frame_3), output_dtype=numpy.float32, overlap_kernels=(kernel_1, kernel_2), raw_frames_compositions=None, overlap_frames_compositions=None, key_lines=((20, 20), (105, 105)), ) assert stitched.shape == frame_ref.shape numpy.testing.assert_array_almost_equal(frame_ref, stitched) @pytest.mark.parametrize("dtype", (numpy.float16, numpy.float32)) def test_stitch_horizontally_raw_frames(dtype): """ ensure a stitching with 3 frames and different overlap can be done along axis 1 """ ref_frame_width = 256 frame_ref = PhantomGenerator.get2DPhantomSheppLogan(n=ref_frame_width).astype(dtype) # split the frame into several part frame_1 = frame_ref[:, 0:100] frame_2 = frame_ref[:, 80:164] frame_3 = frame_ref[:, 154:] kernel_1 = ImageStichOverlapKernel(frame_unstitched_axis_size=ref_frame_width, overlap_size=20, stitching_axis=1) kernel_2 = ImageStichOverlapKernel(frame_unstitched_axis_size=ref_frame_width, overlap_size=10, stitching_axis=1) stitched = stitch_raw_frames( frames=(frame_1, frame_2, frame_3), output_dtype=dtype, overlap_kernels=(kernel_1, kernel_2), raw_frames_compositions=None, overlap_frames_compositions=None, key_lines=( ( 90, # frame_1 height - kernel_1 height / 2.0 10, # kernel_1 height / 2.0 ), ( 79, # frame_2 height - kernel_2 height / 2.0 ou 102-20 ? 
5, # kernel_2 height / 2.0 ), ), ) assert stitched.shape == frame_ref.shape numpy.testing.assert_array_almost_equal(frame_ref, stitched) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/stitching/tests/test_sample_normalization.py0000644000175000017500000000250414550227307024216 0ustar00pierrepierreimport numpy import pytest from nabu.stitching.sample_normalization import normalize_frame, SampleSide, Method def test_normalize_frame(): """ test normalize_frame function """ with pytest.raises(TypeError): normalize_frame("toto", "left", "median") with pytest.raises(TypeError): normalize_frame(numpy.linspace(0, 100), "left", "median") frame = numpy.ones((10, 40)) frame[:, 15:25] = numpy.arange(1, 101, step=1).reshape((10, 10)) numpy.testing.assert_array_equal( normalize_frame( frame=frame, side="left", method="mean", sample_width=10, margin_before_sample=2, )[:, 15:25], numpy.arange(0, 100, step=1).reshape((10, 10)), ) numpy.testing.assert_array_equal( normalize_frame( frame=frame, side="right", method="median", sample_width=10, margin_before_sample=2, )[:, 15:25], numpy.arange(0, 100, step=1).reshape((10, 10)), ) assert not numpy.array_equal( normalize_frame( frame=frame, side="right", method="mean", sample_width=10, margin_before_sample=20, )[:, 15:25], numpy.arange(0, 100, step=1).reshape((10, 10)), ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731053186.0 nabu-2024.2.1/nabu/stitching/tests/test_slurm_utils.py0000644000175000017500000001152614713343202022347 0ustar00pierrepierreimport os import numpy import pytest from tomoscan.esrf import NXtomoScan from tomoscan.esrf.volume import HDF5Volume from tomoscan.esrf.scan.utils import cwd_context from nabu.stitching.config import PreProcessedZStitchingConfiguration, SlurmConfig from nabu.stitching.overlap import OverlapStitchingStrategy from nabu.stitching.slurm_utils import ( split_slices, get_working_directory, split_stitching_configuration_to_slurm_job, ) from tomoscan.esrf.mock import MockNXtomo try: import sluurp except ImportError: has_sluurp = False else: has_sluurp = True def test_split_slices(): """test split_slices function""" assert tuple(split_slices(slice(0, 100, 1), n_parts=4)) == ( slice(0, 25, 1), slice(25, 50, 1), slice(50, 75, 1), slice(75, 100, 1), ) assert tuple(split_slices(slice(0, 50, 2), n_parts=3)) == ( slice(0, 17, 2), slice(17, 34, 2), slice(34, 50, 2), ) assert tuple(split_slices(slice(0, 100, 1), n_parts=1)) == (slice(0, 100, 1),) assert tuple(split_slices(("first", "middle", "last"), 2)) == ( ("first", "middle"), ("last",), ) assert tuple( split_slices( ( 10, 12, 13, ), 4, ) ) == ((10,), (12,), (13,)) with pytest.raises(TypeError): next(split_slices("dsad", 12)) def test_get_working_directory(): """test get_working_directory function""" assert get_working_directory(NXtomoScan("/this/is/my/hdf5file.hdf5", "entry")) == "/this/is/my" assert get_working_directory(HDF5Volume("/this/is/my/volume.hdf5", "entry")) == "/this/is/my" @pytest.mark.skipif(not has_sluurp, reason="sluurp not installed") def test_split_stitching_configuration_to_slurm_job(tmp_path): """ test split_stitching_configuration_to_slurm_job behavior This test is stitching two existing NXtomo (scan1 and scan2 contained in inputs_dir) and create a final_nx_tomo.nx to the output_dir The stitching will be split in two slurm jobs. 
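The expected values in `test_split_slices` above follow from a simple ceil-division split of the slice span, with the step preserved. A small stand-alone function reproducing that arithmetic (this is an illustration of the expected behaviour, not nabu's actual `split_slices` implementation):

```python
from math import ceil


def split_slice(slc, n_parts):
    """Split `slc` into `n_parts` contiguous sub-slices of (at most) equal length, keeping the step."""
    span = ceil((slc.stop - slc.start) / n_parts)
    for i in range(n_parts):
        start = slc.start + i * span
        if start >= slc.stop:
            return
        yield slice(start, min(slc.stop, start + span), slc.step)


print(tuple(split_slice(slice(0, 100, 1), 4)))
# -> (slice(0, 25, 1), slice(25, 50, 1), slice(50, 75, 1), slice(75, 100, 1))
print(tuple(split_slice(slice(0, 50, 2), 3)))
# -> (slice(0, 17, 2), slice(17, 34, 2), slice(34, 50, 2))
```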
One will create output_dir/final_nx_tomo/final_nx_tomo_part_0.nx and the second output_dir/final_nx_tomo/final_nx_tomo_part_1.nx then the concatenation (not tested here) will create a output_dir/final_nx_tomo.nx redirecting to the sub parts This test only focus on checking each sub configuration is as expected """ inputs_dir = tmp_path / "inputs" inputs_dir.mkdir() output_dir = tmp_path / "outputs" output_dir.mkdir() with cwd_context(inputs_dir): # the current working directory context help to check file path are moved to absolute. # which is important because those jobs will be launched on slurm scan_1 = MockNXtomo( os.path.join("scan_1"), n_proj=10, n_ini_proj=10, dim=100, ).scan scan_2 = MockNXtomo( os.path.join("scan_2"), n_proj=10, n_ini_proj=10, dim=100, ).scan n_jobs = 2 raw_config = PreProcessedZStitchingConfiguration( axis_0_pos_px=None, axis_0_pos_mm=None, axis_0_params={}, axis_1_pos_px=None, axis_1_pos_mm=None, axis_1_params={}, axis_2_pos_px=None, axis_2_pos_mm=None, axis_2_params={}, stitching_strategy=OverlapStitchingStrategy.MEAN, overwrite_results=True, slurm_config=SlurmConfig( partition="par-test", mem="45G", n_jobs=n_jobs, other_options="", preprocessing_command="source /my/venv", clean_script=True, ), slices=slice(0, 120, 1), input_scans=(scan_1, scan_2), output_file_path=os.path.join("../outputs/", "final_nx_tomo.nx"), output_data_path="stitched_entry", output_nexus_version=None, slice_for_cross_correlation="middle", pixel_size=None, ) sbatch_script_jobs = [] stitching_configurations = [] for job, configuration in split_stitching_configuration_to_slurm_job(raw_config, yield_configuration=True): sbatch_script_jobs.append(job) stitching_configurations.append(configuration) assert len(stitching_configurations) == n_jobs == len(sbatch_script_jobs) for i_sub_config, sub_config in enumerate(stitching_configurations): assert isinstance(sub_config, type(raw_config)) assert sub_config.slurm_config is None assert sub_config.output_file_path == os.path.join( output_dir, "final_nx_tomo", f"final_nx_tomo_part_{i_sub_config}.nx" ) assert raw_config.output_file_path == os.path.join(output_dir, "final_nx_tomo.nx") ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/stitching/tests/test_utils.py0000644000175000017500000000120714654107202021122 0ustar00pierrepierrefrom nabu.stitching.utils.utils import has_itk, find_shift_with_itk from scipy.ndimage import shift as shift_scipy import numpy import pytest from nabu.testutils import get_data @pytest.mark.parametrize("data_type", (numpy.float32, numpy.uint16)) @pytest.mark.skipif(not has_itk, reason="itk not installed") def test_find_shift_with_itk(data_type): shift = (5, 2) img1 = get_data("chelsea.npz")["data"].astype(data_type) img2 = shift_scipy( img1.copy(), shift=shift, order=1, ) img1 = img1[10:-10, 10:-10] img2 = img2[10:-10, 10:-10] assert find_shift_with_itk(img1=img1, img2=img2) == shift ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731053186.0 nabu-2024.2.1/nabu/stitching/tests/test_y_preprocessing_stitching.py0000644000175000017500000001163514713343202025255 0ustar00pierrepierreimport os import pytest import numpy from tqdm import tqdm from nabu.stitching.y_stitching import y_stitching from nabu.stitching.config import PreProcessedYStitchingConfiguration from nxtomo.application.nxtomo import NXtomo from nxtomo.nxobject.nxdetector import ImageKey from tomoscan.esrf.scan.nxtomoscan import NXtomoScan def build_nxtomos(output_dir, 
flip_lr, flip_ud) -> tuple: r""" build two nxtomos in output_dir and return the list of NXtomos ready to be stitched /\ | ______________ ______________ | |~ ~~| |~ | | |~ nxtomo 1 ~~| |~ nxtomo 0 | Z | |~ frame ~~| |~ frame | |______________| |______________| <----------------------------------------------- 90 40 0 y (in acquisition space) * ~: represent the overlap area """ dark_data = numpy.array([0] * 64 * 120, dtype=numpy.float32).reshape((64, 120)) flat_data = numpy.array([1] * 64 * 120, dtype=numpy.float32).reshape((64, 120)) normalized_data = numpy.linspace(128, 1024, num=64 * 120, dtype=numpy.float32).reshape((64, 120)) if flip_lr: dark_data = numpy.fliplr(dark_data) flat_data = numpy.fliplr(flat_data) normalized_data = numpy.fliplr(normalized_data) if flip_ud: dark_data = numpy.flipud(dark_data) flat_data = numpy.flipud(flat_data) normalized_data = numpy.flipud(normalized_data) raw_data = (normalized_data + dark_data) * (flat_data + dark_data) # create raw data scans = [] slices = (slice(0, 80), slice(60, -1)) frame_y_positions = (40, 90) for i_nxtomo, (my_slice, frame_y_position) in enumerate(zip(slices, frame_y_positions)): my_raw_data = raw_data[:, my_slice] assert my_raw_data.ndim == 2 my_dark_data = dark_data[:, my_slice] assert my_dark_data.ndim == 2 my_flat_data = flat_data[:, my_slice] assert my_flat_data.ndim == 2 n_projs = 3 nx_tomo = NXtomo() nx_tomo.sample.x_translation = [0] * (n_projs + 2) nx_tomo.sample.y_translation = [frame_y_position] * (n_projs + 2) nx_tomo.sample.z_translation = [0] * (n_projs + 2) nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=(n_projs + 2), endpoint=False) nx_tomo.instrument.detector.image_key_control = ( ImageKey.DARK_FIELD, ImageKey.FLAT_FIELD, ImageKey.PROJECTION, ImageKey.PROJECTION, ImageKey.PROJECTION, ) nx_tomo.instrument.detector.x_pixel_size = 1.0 nx_tomo.instrument.detector.y_pixel_size = 1.0 nx_tomo.instrument.detector.distance = 2.3 nx_tomo.energy = 19.2 nx_tomo.instrument.detector.data = numpy.stack( ( my_dark_data, my_flat_data, my_raw_data, my_raw_data, my_raw_data, ) ) file_path = os.path.join(output_dir, f"nxtomo_{i_nxtomo}.nx") entry = f"entry000{i_nxtomo}" nx_tomo.save(file_path=file_path, data_path=entry) scans.append(NXtomoScan(scan=file_path, entry=entry)) return scans, frame_y_positions, normalized_data @pytest.mark.parametrize("flip_lr", (True, False)) @pytest.mark.parametrize("flip_ud", (True, False)) @pytest.mark.parametrize("progress", (None, "with_tqdm")) def test_preprocessing_stitching(tmp_path, flip_lr, flip_ud, progress): if progress == "with_tqdm": progress = tqdm(total=100) nxtomo_dir = tmp_path / "nxtomos" nxtomo_dir.mkdir() output_dir = tmp_path / "output" output_dir.mkdir() output_file_path = os.path.join(output_dir, "nxtomo.nxs") nxtomos, _, normalized_data = build_nxtomos( output_dir=nxtomo_dir, flip_lr=flip_lr, flip_ud=flip_ud, ) configuration = PreProcessedYStitchingConfiguration( input_scans=nxtomos, axis_0_pos_px=None, axis_1_pos_px=None, axis_2_pos_px=None, output_file_path=output_file_path, output_data_path="stitched_volume", ) output_identifier = y_stitching( configuration=configuration, progress=progress, ) created_nx_tomo = NXtomo().load( file_path=output_identifier.file_path, data_path=output_identifier.data_path, detector_data_as="as_numpy_array", ) assert created_nx_tomo.instrument.detector.data.shape == ( 3, 64, 120, ) # 3 == number of projections, dark and flat will not be exported when doing the stitching # TODO: improve me: the relative tolerance is pretty high. 
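In `build_nxtomos` above, the raw projections are synthesised as `(normalized + dark) * (flat + dark)` with a zero dark and a unit flat, so an ideal flat-field correction gives back `normalized_data`, which is what the `assert_allclose` just after this note checks (with a loose `rtol`, as the TODO explains). A tiny numeric confirmation, assuming the standard `(raw - dark) / (flat - dark)` correction:

```python
import numpy

dark = numpy.zeros((4, 6), dtype=numpy.float32)
flat = numpy.ones((4, 6), dtype=numpy.float32)
normalized = numpy.linspace(128, 1024, num=24, dtype=numpy.float32).reshape(4, 6)

raw = (normalized + dark) * (flat + dark)    # how the test builds its raw frames
corrected = (raw - dark) / (flat - dark)     # standard flat-field correction

numpy.testing.assert_allclose(corrected, normalized)
```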
This doesn't comes from the algorithm on itself # but more on the numerical calculation and the flat field normalization numpy.testing.assert_allclose(normalized_data, created_nx_tomo.instrument.detector.data[0], rtol=0.06) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731053186.0 nabu-2024.2.1/nabu/stitching/tests/test_z_postprocessing_stitching.py0000644000175000017500000006444314713343202025462 0ustar00pierrepierreimport os import h5py import numpy import pytest from tqdm import tqdm from silx.image.phantomgenerator import PhantomGenerator from tomoscan.esrf.volume import EDFVolume, HDF5Volume from tomoscan.esrf.volume.tiffvolume import TIFFVolume, has_tifffile from tomoscan.factory import Factory as TomoscanFactory from tomoscan.utils.volume import concatenate as concatenate_volumes from nabu.stitching.alignment import AlignmentAxis1, AlignmentAxis2 from nabu.stitching.config import NormalizationBySample, PostProcessedZStitchingConfiguration from nabu.stitching.overlap import OverlapStitchingStrategy from nabu.stitching.utils import ShiftAlgorithm from nabu.stitching.z_stitching import PostProcessZStitcher, PostProcessZStitcherNoDD strategies_to_test_weights = ( OverlapStitchingStrategy.CLOSEST, OverlapStitchingStrategy.COSINUS_WEIGHTS, OverlapStitchingStrategy.LINEAR_WEIGHTS, OverlapStitchingStrategy.MEAN, ) def build_raw_volume(): """util to create some raw volume""" raw_volume = numpy.stack( [ PhantomGenerator.get2DPhantomSheppLogan(n=120).astype(numpy.float32) * 256.0, PhantomGenerator.get2DPhantomSheppLogan(n=120).astype(numpy.float32) * 128.0, PhantomGenerator.get2DPhantomSheppLogan(n=120).astype(numpy.float32) * 32.0, PhantomGenerator.get2DPhantomSheppLogan(n=120).astype(numpy.float32) * 16.0, ] ) assert raw_volume.shape == (4, 120, 120) raw_volume = numpy.rollaxis(raw_volume, axis=1, start=0) assert raw_volume.shape == (120, 4, 120) return raw_volume _VOL_CLASSES_TO_TEST_FOR_POSTPROC_STITCHING = [HDF5Volume, EDFVolume] # avoid testing glymur because doesn't handle float # if has_minimal_openjpeg: # _VOL_CLASSES_TO_TEST_FOR_POSTPROC_STITCHING.append(JP2KVolume) if has_tifffile: _VOL_CLASSES_TO_TEST_FOR_POSTPROC_STITCHING.append(TIFFVolume) def build_volumes(output_dir: str, volume_class): # create some random data. 
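# Layout of the three sub-volumes created below, along axis 0 of the 120-slice raw volume:
#   volume 1: slices 0:30 (position -15), volume 2: slices 20:80 (position -50), volume 3: slices 60:120 (position -90)
# Consecutive volumes therefore overlap by 10 and 20 slices, and stitching them back together
# is expected to reproduce `raw_volume` (this is what the tests using this helper assert).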
raw_volume = build_raw_volume() # create a simple case where consecutive volumes overlap by 10 and 20 voxels and have heights (z) of 30, 60 and 60 voxels vol_1_constructor_params = { "data": raw_volume[0:30, :, :], "metadata": { "processing_options": { "reconstruction": { "position": (-15.0, 0.0, 0.0), "voxel_size_cm": (100.0, 100.0, 100.0), } }, }, } vol_2_constructor_params = { "data": raw_volume[20:80, :, :], "metadata": { "processing_options": { "reconstruction": { "position": (-50.0, 0.0, 0.0), "voxel_size_cm": (100.0, 100.0, 100.0), } }, }, } vol_3_constructor_params = { "data": raw_volume[60:, :, :], "metadata": { "processing_options": { "reconstruction": { "position": (-90.0, 0.0, 0.0), "voxel_size_cm": (100.0, 100.0, 100.0), } }, }, } volumes = [] axis_0_positions = [] for i_vol, vol_params in enumerate([vol_1_constructor_params, vol_2_constructor_params, vol_3_constructor_params]): if volume_class == HDF5Volume: vol_params.update( { "file_path": os.path.join(output_dir, f"raw_volume_{i_vol}.hdf5"), "data_path": "volume", } ) else: vol_params.update( { "folder": os.path.join(output_dir, f"raw_volume_{i_vol}"), } ) axis_0_positions.append(vol_params["metadata"]["processing_options"]["reconstruction"]["position"][0]) volume = volume_class(**vol_params) volume.save() volumes.append(volume) return volumes, axis_0_positions, raw_volume @pytest.mark.parametrize("progress", (None, "with_tqdm")) @pytest.mark.parametrize("volume_class", (_VOL_CLASSES_TO_TEST_FOR_POSTPROC_STITCHING)) def test_PostProcessZStitcher( tmp_path, volume_class, progress, ): """ test the PostProcessZStitcher class and ensure a full stitching can be done automatically. :param volume_class: class to be used (same class for input and output for now) :param progress: None or "with_tqdm" to also exercise the progress bar code path """ if progress == "with_tqdm": progress = tqdm(total=100) # create folder to save data (and debug) raw_data_dir = tmp_path / "raw_data" raw_data_dir.mkdir() output_dir = tmp_path / "output_dir" output_dir.mkdir() volumes, axis_0_positions, raw_volume = build_volumes(output_dir=raw_data_dir, volume_class=volume_class) volume_1, volume_2, volume_3 = volumes output_volume = HDF5Volume( file_path=os.path.join(output_dir, "stitched_volume.hdf5"), data_path="stitched_volume", ) z_stich_config = PostProcessedZStitchingConfiguration( stitching_strategy=OverlapStitchingStrategy.LINEAR_WEIGHTS, overwrite_results=True, input_volumes=(volume_1, volume_2, volume_3), output_volume=output_volume, slices=None, slurm_config=None, axis_0_pos_px=axis_0_positions, axis_0_pos_mm=None, axis_0_params={"img_reg_method": ShiftAlgorithm.NONE}, axis_1_pos_px=None, axis_1_pos_mm=None, axis_1_params={"img_reg_method": ShiftAlgorithm.NONE}, axis_2_pos_px=None, axis_2_pos_mm=None, axis_2_params={"img_reg_method": ShiftAlgorithm.NONE}, slice_for_cross_correlation="middle", voxel_size=None, ) stitcher = PostProcessZStitcher(z_stich_config, progress=progress) output_identifier = stitcher.stitch() assert output_identifier.file_path == output_volume.file_path assert output_identifier.data_path == output_volume.data_path output_volume.data = None output_volume.metadata = None output_volume.load_data(store=True) output_volume.load_metadata(store=True) assert raw_volume.shape == output_volume.data.shape numpy.testing.assert_array_almost_equal(raw_volume,
output_volume.data) metadata = output_volume.metadata assert "about" in metadata assert "configuration" in metadata assert output_volume.position[0] == -60.0 assert output_volume.pixel_size == (1.0, 1.0, 1.0) slices_to_test_post = ( { "slices": (None,), "complete": True, }, { "slices": (("first",), ("middle",), ("last",)), "complete": False, }, { "slices": ((0, 1, 2), slice(3, -1, 1)), "complete": True, }, ) @pytest.mark.parametrize("flip_ud", (True, False)) @pytest.mark.parametrize("configuration_dist", slices_to_test_post) def test_DistributePostProcessZStitcher(tmp_path, configuration_dist, flip_ud): # create some random data. slices = configuration_dist["slices"] complete = configuration_dist["complete"] raw_volume = numpy.ones((80, 40, 120), dtype=numpy.float16) raw_volume[:, 0, :] = ( PhantomGenerator.get2DPhantomSheppLogan(n=120).astype(numpy.float16)[30:110, :] * 80 * 40 * 120 ) raw_volume[:, 8, :] = ( PhantomGenerator.get2DPhantomSheppLogan(n=120).astype(numpy.float16)[30:110, :] * 80 * 40 * 120 + 2 ) raw_volume[12] = 1.0 raw_volume[:, 23] = 1.2 # create folder to save data (and debug) raw_data_dir = tmp_path / "raw_data" raw_data_dir.mkdir() output_dir = tmp_path / "output_dir" output_dir.mkdir() def flip_input_data(data): if flip_ud: data = numpy.flipud(data) return data volume_1 = HDF5Volume( file_path=os.path.join(raw_data_dir, "volume_1.hdf5"), data_path="volume", data=flip_input_data(raw_volume[-60:, :, :]), metadata={ "processing_options": { "reconstruction": { "position": (-30.0, 0.0, 0.0), "voxel_size_cm": (100.0, 100.0, 100.0), } }, }, ) volume_1.save() volume_2 = HDF5Volume( file_path=os.path.join(raw_data_dir, "volume_2.hdf5"), data_path="volume", data=flip_input_data(raw_volume[:60, :, :]), metadata={ "processing_options": { "reconstruction": { "position": (-50.0, 0.0, 0.0), "voxel_size_cm": (100.0, 100.0, 100.0), } }, }, ) volume_2.save() reconstructed_sub_volumes = [] for i_slice, s in enumerate(slices): output_volume = HDF5Volume( file_path=os.path.join(output_dir, f"stitched_subvolume_{i_slice}.hdf5"), data_path="stitched_volume", ) volumes = (volume_2, volume_1) z_stich_config = PostProcessedZStitchingConfiguration( stitching_strategy=OverlapStitchingStrategy.LINEAR_WEIGHTS, axis_0_pos_px=tuple( volume.metadata["processing_options"]["reconstruction"]["position"][0] for volume in volumes ), axis_0_pos_mm=None, axis_0_params={}, axis_1_pos_px=(0, 0), axis_1_pos_mm=None, axis_1_params={}, axis_2_pos_px=None, axis_2_pos_mm=None, axis_2_params={}, overwrite_results=True, input_volumes=volumes, output_volume=output_volume, slices=s, slurm_config=None, slice_for_cross_correlation="middle", voxel_size=None, flip_ud=flip_ud, ) stitcher = PostProcessZStitcher(z_stich_config) vol_id = stitcher.stitch() reconstructed_sub_volumes.append(TomoscanFactory.create_tomo_object_from_identifier(identifier=vol_id)) final_vol = HDF5Volume( file_path=os.path.join(output_dir, "final_volume"), data_path="volume", ) if complete: concatenate_volumes(output_volume=final_vol, volumes=tuple(reconstructed_sub_volumes), axis=1) final_vol.load_data(store=True) numpy.testing.assert_almost_equal( raw_volume, final_vol.data, ) @pytest.mark.parametrize("alignment_axis_2", ("left", "right", "center")) def test_vol_z_stitching_with_alignment_axis_2(tmp_path, alignment_axis_2): """ test z volume stitching with different width (and so that requires image alignment over axis 2) """ # create some random data. 
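# The three sub-volumes built below share the axis-0 layout used in `build_volumes` but are cropped
# differently along axis 2 ([:, :, 4:-4], full width, [:, :, 10:-10]), so the stitcher has to realign
# frames of different widths according to the requested `alignment_axis_2` (left / center / right).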
raw_volume = build_raw_volume() # create folder to save data (and debug) raw_data_dir = tmp_path / "raw_data" raw_data_dir.mkdir() output_dir = tmp_path / "output_dir" output_dir.mkdir() # create a simple case where the volume have 10 voxel of overlap and a height (z) of 30 Voxels, 40 and 30 Voxels vol_1_constructor_params = { "data": raw_volume[0:30, :, 4:-4], "metadata": { "processing_options": { "reconstruction": { "position": (-15.0, 0.0, 0.0), "voxel_size_cm": (100.0, 100.0, 100.0), } }, }, } vol_2_constructor_params = { "data": raw_volume[20:80, :, :], "metadata": { "processing_options": { "reconstruction": { "position": (-50.0, 0.0, 0.0), "voxel_size_cm": (100.0, 100.0, 100.0), } }, }, } vol_3_constructor_params = { "data": raw_volume[60:, :, 10:-10], "metadata": { "processing_options": { "reconstruction": { "position": (-90.0, 0.0, 0.0), "voxel_size_cm": (100.0, 100.0, 100.0), } }, }, } raw_volumes = [] axis_0_positions = [] for i_vol, vol_params in enumerate([vol_1_constructor_params, vol_2_constructor_params, vol_3_constructor_params]): vol_params.update( { "file_path": os.path.join(raw_data_dir, f"raw_volume_{i_vol}.hdf5"), "data_path": "volume", } ) axis_0_positions.append(vol_params["metadata"]["processing_options"]["reconstruction"]["position"][0]) volume = HDF5Volume(**vol_params) volume.save() raw_volumes.append(volume) volume_1, volume_2, volume_3 = raw_volumes output_volume = HDF5Volume( file_path=os.path.join(output_dir, "stitched_volume.hdf5"), data_path="stitched_volume", ) z_stich_config = PostProcessedZStitchingConfiguration( stitching_strategy=OverlapStitchingStrategy.LINEAR_WEIGHTS, overwrite_results=True, input_volumes=(volume_1, volume_2, volume_3), output_volume=output_volume, slices=None, slurm_config=None, axis_0_pos_px=axis_0_positions, axis_0_pos_mm=None, axis_0_params={"img_reg_method": ShiftAlgorithm.NONE}, axis_1_pos_px=None, axis_1_pos_mm=None, axis_1_params={"img_reg_method": ShiftAlgorithm.NONE}, axis_2_pos_px=None, axis_2_pos_mm=None, axis_2_params={"img_reg_method": ShiftAlgorithm.NONE}, slice_for_cross_correlation="middle", voxel_size=None, alignment_axis_2=AlignmentAxis2.from_value(alignment_axis_2), ) stitcher = PostProcessZStitcher(z_stich_config, progress=None) output_identifier = stitcher.stitch() assert output_identifier.file_path == output_volume.file_path assert output_identifier.data_path == output_volume.data_path output_volume.load_data(store=True) output_volume.load_metadata(store=True) assert output_volume.data.shape == (120, 4, 120) if alignment_axis_2 == "center": numpy.testing.assert_array_almost_equal(raw_volume[:, :, 10:-10], output_volume.data[:, :, 10:-10]) elif alignment_axis_2 == "left": numpy.testing.assert_array_almost_equal(raw_volume[:, :, :-20], output_volume.data[:, :, :-20]) elif alignment_axis_2 == "right": numpy.testing.assert_array_almost_equal(raw_volume[:, :, 20:], output_volume.data[:, :, 20:]) @pytest.mark.parametrize("alignment_axis_1", ("front", "center", "back")) def test_vol_z_stitching_with_alignment_axis_1(tmp_path, alignment_axis_1): """ test z volume stitching with different number of frames (and so that requires image alignment over axis 0) """ # create some random data. 
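# Same axis-0 layout again, but volumes 1 and 3 only keep rows 1:3 along axis 1 while volume 2 keeps
# all 4 rows, so the stitcher has to pad / align the frames according to `alignment_axis_1`
# (front / center / back) before stitching.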
raw_volume = build_raw_volume() # create folder to save data (and debug) raw_data_dir = tmp_path / "raw_data" raw_data_dir.mkdir() output_dir = tmp_path / "output_dir" output_dir.mkdir() # create a simple case where consecutive volumes overlap by 10 and 20 voxels and have heights (z) of 30, 60 and 60 voxels vol_1_constructor_params = { "data": raw_volume[ 0:30, 1:3, ], "metadata": { "processing_options": { "reconstruction": { "position": (-15.0, 0.0, 0.0), "voxel_size_cm": (100.0, 100.0, 100.0), } }, }, } vol_2_constructor_params = { "data": raw_volume[20:80, :, :], "metadata": { "processing_options": { "reconstruction": { "position": (-50.0, 0.0, 0.0), "voxel_size_cm": (100.0, 100.0, 100.0), } }, }, } vol_3_constructor_params = { "data": raw_volume[ 60:, 1:3, ], "metadata": { "processing_options": { "reconstruction": { "position": (-90.0, 0.0, 0.0), "voxel_size_cm": (100.0, 100.0, 100.0), } }, }, } raw_volumes = [] axis_0_positions = [] for i_vol, vol_params in enumerate([vol_1_constructor_params, vol_2_constructor_params, vol_3_constructor_params]): vol_params.update( { "file_path": os.path.join(raw_data_dir, f"raw_volume_{i_vol}.hdf5"), "data_path": "volume", } ) axis_0_positions.append(vol_params["metadata"]["processing_options"]["reconstruction"]["position"][0]) volume = HDF5Volume(**vol_params) volume.save() raw_volumes.append(volume) volume_1, volume_2, volume_3 = raw_volumes output_volume = HDF5Volume( file_path=os.path.join(output_dir, "stitched_volume.hdf5"), data_path="stitched_volume", ) z_stich_config = PostProcessedZStitchingConfiguration( stitching_strategy=OverlapStitchingStrategy.LINEAR_WEIGHTS, overwrite_results=True, input_volumes=(volume_1, volume_2, volume_3), output_volume=output_volume, slices=None, slurm_config=None, axis_0_pos_px=axis_0_positions, axis_0_pos_mm=None, axis_0_params={"img_reg_method": ShiftAlgorithm.NONE}, axis_1_pos_px=None, axis_1_pos_mm=None, axis_1_params={"img_reg_method": ShiftAlgorithm.NONE}, axis_2_pos_px=None, axis_2_pos_mm=None, axis_2_params={"img_reg_method": ShiftAlgorithm.NONE}, slice_for_cross_correlation="middle", voxel_size=None, alignment_axis_1=AlignmentAxis1.from_value(alignment_axis_1), ) stitcher = PostProcessZStitcher(z_stich_config, progress=None) output_identifier = stitcher.stitch() assert output_identifier.file_path == output_volume.file_path assert output_identifier.data_path == output_volume.data_path output_volume.load_data(store=True) output_volume.load_metadata(store=True) assert output_volume.data.shape == (120, 4, 120) if alignment_axis_1 == "center": numpy.testing.assert_array_almost_equal(raw_volume[:, 10:-10, :], output_volume.data[:, 10:-10, :]) elif alignment_axis_1 == "front": numpy.testing.assert_array_almost_equal(raw_volume[:, :-20, :], output_volume.data[:, :-20, :]) elif alignment_axis_1 == "back": numpy.testing.assert_array_almost_equal(raw_volume[:, 20:, :], output_volume.data[:, 20:, :]) def test_normalization_by_sample(tmp_path): """ simple test of a volume stitching.
Raw volumes have 'extra' values (+2, +5, +9) that must be removed at the end thanks to the normalization """ raw_volume = build_raw_volume() # create folder to save data (and debug) raw_data_dir = tmp_path / "raw_data" raw_data_dir.mkdir() output_dir = tmp_path / "output_dir" output_dir.mkdir() # create a simple case where the volume have 10 voxel of overlap and a height (z) of 30 Voxels, 40 and 30 Voxels vol_1_constructor_params = { "data": raw_volume[0:30, :, :] + 3, "metadata": { "processing_options": { "reconstruction": { "position": (-15.0, 0.0, 0.0), "voxel_size_cm": (100.0, 100.0, 100.0), } }, }, } vol_2_constructor_params = { "data": raw_volume[20:80, :, :] + 5, "metadata": { "processing_options": { "reconstruction": { "position": (-50.0, 0.0, 0.0), "voxel_size_cm": (100.0, 100.0, 100.0), } }, }, } vol_3_constructor_params = { "data": raw_volume[60:, :, :] + 12, "metadata": { "processing_options": { "reconstruction": { "position": (-90.0, 0.0, 0.0), "voxel_size_cm": (100.0, 100.0, 100.0), } }, }, } raw_volumes = [] axis_0_positions = [] for i_vol, vol_params in enumerate([vol_1_constructor_params, vol_2_constructor_params, vol_3_constructor_params]): vol_params.update( { "file_path": os.path.join(raw_data_dir, f"raw_volume_{i_vol}.hdf5"), "data_path": "volume", } ) axis_0_positions.append(vol_params["metadata"]["processing_options"]["reconstruction"]["position"][0]) volume = HDF5Volume(**vol_params) volume.save() raw_volumes.append(volume) volume_1, volume_2, volume_3 = raw_volumes output_volume = HDF5Volume( file_path=os.path.join(output_dir, "stitched_volume.hdf5"), data_path="stitched_volume", ) normalization_by_sample = NormalizationBySample() normalization_by_sample.set_is_active(True) normalization_by_sample.width = 1 normalization_by_sample.margin = 0 normalization_by_sample.side = "left" normalization_by_sample.method = "median" z_stich_config = PostProcessedZStitchingConfiguration( stitching_strategy=OverlapStitchingStrategy.CLOSEST, overwrite_results=True, input_volumes=(volume_1, volume_2, volume_3), output_volume=output_volume, slices=None, slurm_config=None, axis_0_pos_px=axis_0_positions, axis_0_pos_mm=None, axis_0_params={"img_reg_method": ShiftAlgorithm.NONE}, axis_1_pos_px=None, axis_1_pos_mm=None, axis_1_params={"img_reg_method": ShiftAlgorithm.NONE}, axis_2_pos_px=None, axis_2_pos_mm=None, axis_2_params={"img_reg_method": ShiftAlgorithm.NONE}, slice_for_cross_correlation="middle", voxel_size=None, normalization_by_sample=normalization_by_sample, ) stitcher = PostProcessZStitcher(z_stich_config, progress=None) output_identifier = stitcher.stitch() assert output_identifier.file_path == output_volume.file_path assert output_identifier.data_path == output_volume.data_path output_volume.data = None output_volume.metadata = None output_volume.load_data(store=True) output_volume.load_metadata(store=True) assert raw_volume.shape == output_volume.data.shape numpy.testing.assert_array_almost_equal(raw_volume, output_volume.data) metadata = output_volume.metadata assert "configuration" in metadata assert "about" in metadata assert metadata["about"]["program"] == "nabu-stitching" assert output_volume.position[0] == -60.0 assert output_volume.pixel_size == (1.0, 1.0, 1.0) @pytest.mark.parametrize("data_duplication", (True, False)) def test_data_duplication(tmp_path, data_duplication): """ Test that the post-processing stitching can be done without duplicating data. 
And also making sure avoid data duplication can handle frame flips """ raw_volume = build_raw_volume() # create folder to save data (and debug) raw_data_dir = tmp_path / "raw_data" raw_data_dir.mkdir() output_dir = tmp_path / "output_dir" output_dir.mkdir() volume_1 = HDF5Volume( data=raw_volume[0:30], metadata={ "processing_options": { "reconstruction": { "position": (-15.0, 0.0, 0.0), "voxel_size_cm": (100.0, 100.0, 100.0), } }, }, file_path=os.path.join(raw_data_dir, f"raw_volume_1.hdf5"), data_path="volume", ) volume_2 = HDF5Volume( data=raw_volume[20:80], metadata={ "processing_options": { "reconstruction": { "position": (-50.0, 0.0, 0.0), "voxel_size_cm": (100.0, 100.0, 100.0), } }, }, file_path=os.path.join(raw_data_dir, f"raw_volume_2.hdf5"), data_path="volume", ) volume_3 = HDF5Volume( data=raw_volume[60:], metadata={ "processing_options": { "reconstruction": { "position": (-90.0, 0.0, 0.0), "voxel_size_cm": (100.0, 100.0, 100.0), } }, }, file_path=os.path.join(raw_data_dir, f"raw_volume_3.hdf5"), data_path="volume", ) for volume in (volume_1, volume_2, volume_3): volume.save() volume.clear_cache() output_volume = HDF5Volume( file_path=os.path.join(output_dir, "stitched_volume.hdf5"), data_path="stitched_volume", ) z_stich_config = PostProcessedZStitchingConfiguration( stitching_strategy=OverlapStitchingStrategy.CLOSEST, overwrite_results=True, input_volumes=(volume_1, volume_2, volume_3), output_volume=output_volume, slices=None, slurm_config=None, axis_0_pos_px=None, axis_0_pos_mm=None, axis_0_params={"img_reg_method": ShiftAlgorithm.NONE}, axis_1_pos_px=None, axis_1_pos_mm=None, axis_1_params={"img_reg_method": ShiftAlgorithm.NONE}, axis_2_pos_px=None, axis_2_pos_mm=None, axis_2_params={"img_reg_method": ShiftAlgorithm.NONE}, slice_for_cross_correlation="middle", voxel_size=None, duplicate_data=data_duplication, ) if data_duplication: stitcher = PostProcessZStitcher(z_stich_config, progress=None) else: stitcher = PostProcessZStitcherNoDD(z_stich_config, progress=None) output_identifier = stitcher.stitch() assert output_identifier.file_path == output_volume.file_path assert output_identifier.data_path == output_volume.data_path output_volume.data = None output_volume.metadata = None output_volume.load_data(store=True) output_volume.load_metadata(store=True) assert raw_volume.shape == output_volume.data.shape numpy.testing.assert_almost_equal(raw_volume.data, output_volume.data) with h5py.File(output_volume.file_path, mode="r") as h5f: if data_duplication: assert f"{output_volume.data_path}/stitching_regions" not in h5f assert not h5f[f"{output_volume.data_path}/results/data"].is_virtual else: assert f"{output_volume.data_path}/stitching_regions" in h5f assert h5f[f"{output_volume.data_path}/results/data"].is_virtual if not data_duplication: # make sure an error is raised if we try to ask for no data duplication and if we get some flips z_stich_config.flip_ud = (False, True, False) with pytest.raises(ValueError): stitcher = PostProcessZStitcherNoDD(z_stich_config, progress=None) stitcher.stitch() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731053186.0 nabu-2024.2.1/nabu/stitching/tests/test_z_preprocessing_stitching.py0000644000175000017500000003743514713343202025264 0ustar00pierrepierreimport os from silx.image.phantomgenerator import PhantomGenerator from scipy.ndimage import shift as scipy_shift import numpy import pytest from nabu.stitching.config import PreProcessedZStitchingConfiguration from nabu.stitching.config import 
KEY_IMG_REG_METHOD from nabu.stitching.overlap import ImageStichOverlapKernel, OverlapStitchingStrategy from nabu.stitching.z_stitching import ( PreProcessZStitcher, ) from nabu.stitching.stitcher_2D import stitch_raw_frames, get_overlap_areas from nxtomo.nxobject.nxdetector import ImageKey from nxtomo.utils.transformation import DetYFlipTransformation, DetZFlipTransformation from nxtomo.application.nxtomo import NXtomo from tomoscan.esrf.scan.nxtomoscan import NXtomoScan from nabu.stitching.utils import ShiftAlgorithm import h5py _stitching_configurations = ( # simple case where shifts are provided { "n_proj": 4, "raw_pos": ((0, 0, 0), (-90, 0, 0), (-180, 0, 0)), # requested shift to "input_pos": ((0, 0, 0), (-90, 0, 0), (-180, 0, 0)), # requested shift to "raw_shifts": ((0, 0), (-90, 0), (-180, 0)), }, # simple case where shift is found from z position { "n_proj": 4, "raw_pos": ((90, 0, 0), (0, 0, 0), (-90, 0, 0)), "input_pos": ((90, 0, 0), (0, 0, 0), (-90, 0, 0)), "check_bb": ((40, 140), (-50, 50), (-140, -40)), "axis_0_params": { KEY_IMG_REG_METHOD: ShiftAlgorithm.NONE, }, "axis_2_params": { KEY_IMG_REG_METHOD: ShiftAlgorithm.NONE, }, "raw_shifts": ((0, 0), (-90, 0), (-180, 0)), }, ) @pytest.mark.parametrize("configuration", _stitching_configurations) @pytest.mark.parametrize("dtype", (numpy.float32, numpy.int16)) def test_PreProcessZStitcher(tmp_path, dtype, configuration): """ test PreProcessZStitcher class and insure a full stitching can be done automatically. """ n_proj = configuration["n_proj"] ref_frame_width = 280 raw_frame_height = 100 ref_frame = PhantomGenerator.get2DPhantomSheppLogan(n=ref_frame_width).astype(dtype) * 256.0 # add some mark for image registration ref_frame[:, 96] = -3.2 ref_frame[:, 125] = 9.1 ref_frame[:, 165] = 4.4 ref_frame[:, 200] = -2.5 # create raw data frame_0_shift, frame_1_shift, frame_2_shift = configuration["raw_shifts"] frame_0 = scipy_shift(ref_frame, shift=frame_0_shift)[:raw_frame_height] frame_1 = scipy_shift(ref_frame, shift=frame_1_shift)[:raw_frame_height] frame_2 = scipy_shift(ref_frame, shift=frame_2_shift)[:raw_frame_height] frames = frame_0, frame_1, frame_2 frame_0_input_pos, frame_1_input_pos, frame_2_input_pos = configuration["input_pos"] frame_0_raw_pos, frame_1_raw_pos, frame_2_raw_pos = configuration["raw_pos"] # create a Nxtomo for each of those raw data raw_data_dir = tmp_path / "raw_data" raw_data_dir.mkdir() output_dir = tmp_path / "output_dir" output_dir.mkdir() z_position = ( frame_0_raw_pos[0], frame_1_raw_pos[0], frame_2_raw_pos[0], ) scans = [] for (i_frame, frame), z_pos in zip(enumerate(frames), z_position): nx_tomo = NXtomo() nx_tomo.sample.z_translation = [z_pos] * n_proj nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False) nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj nx_tomo.instrument.detector.x_pixel_size = 1.0 nx_tomo.instrument.detector.y_pixel_size = 1.0 nx_tomo.instrument.detector.distance = 2.3 nx_tomo.energy = 19.2 nx_tomo.instrument.detector.data = numpy.asarray([frame] * n_proj) file_path = os.path.join(raw_data_dir, f"nxtomo_{i_frame}.nx") entry = f"entry000{i_frame}" nx_tomo.save(file_path=file_path, data_path=entry) scans.append(NXtomoScan(scan=file_path, entry=entry)) # if requested: check bounding box check_bb = configuration.get("check_bb", None) if check_bb is not None: for scan, expected_bb in zip(scans, check_bb): assert scan.get_bounding_box(axis="z") == expected_bb output_file_path = os.path.join(output_dir, "stitched.nx") 
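# output location (HDF5 file + data path) of the stitched NXtomo produced below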
output_data_path = "stitched" z_stich_config = PreProcessedZStitchingConfiguration( stitching_strategy=OverlapStitchingStrategy.LINEAR_WEIGHTS, overwrite_results=True, axis_0_pos_px=( frame_0_input_pos[0], frame_1_input_pos[0], frame_2_input_pos[0], ), axis_1_pos_px=( frame_0_input_pos[1], frame_1_input_pos[1], frame_2_input_pos[1], ), axis_2_pos_px=( frame_0_input_pos[2], frame_1_input_pos[2], frame_2_input_pos[2], ), axis_0_pos_mm=None, axis_1_pos_mm=None, axis_2_pos_mm=None, input_scans=scans, output_file_path=output_file_path, output_data_path=output_data_path, axis_0_params=configuration.get("axis_0_params", {}), axis_1_params=configuration.get("axis_1_params", {}), axis_2_params=configuration.get("axis_2_params", {}), output_nexus_version=None, slices=None, slurm_config=None, slice_for_cross_correlation="middle", pixel_size=None, ) stitcher = PreProcessZStitcher(z_stich_config) output_identifier = stitcher.stitch() assert output_identifier.file_path == output_file_path assert output_identifier.data_path == output_data_path created_nx_tomo = NXtomo().load( file_path=output_identifier.file_path, data_path=output_identifier.data_path, detector_data_as="as_numpy_array", ) assert created_nx_tomo.instrument.detector.data.ndim == 3 mean_abs_error = configuration.get("mean_abs_error", None) if mean_abs_error is not None: assert ( numpy.mean(numpy.abs(ref_frame - created_nx_tomo.instrument.detector.data[0, :ref_frame_width, :])) < mean_abs_error ) else: numpy.testing.assert_array_almost_equal( ref_frame, created_nx_tomo.instrument.detector.data[0, :ref_frame_width, :] ) # check also other metadata are here assert created_nx_tomo.instrument.detector.distance.value == 2.3 assert created_nx_tomo.energy.value == 19.2 numpy.testing.assert_array_equal( created_nx_tomo.instrument.detector.image_key_control, numpy.asarray([ImageKey.PROJECTION.PROJECTION] * n_proj), ) # check configuration has been saved with h5py.File(output_identifier.file_path, mode="r") as h5f: assert "stitching_configuration" in h5f[output_identifier.data_path] slices_to_test_pre = ( { "slices": (None,), "complete": True, }, { "slices": (("first",), ("middle",), ("last",)), "complete": False, }, { "slices": ((0, 1, 2), slice(3, -1, 1)), "complete": True, }, ) def build_nxtomos(output_dir) -> tuple: r""" build two nxtomos in output_dir and return the list of NXtomos ready to be stitched /\ | ______________ | | nxtomo 1 | Z | | frame | | |~~~~~~~~~~~~~~| | |~~~~~~~~~~~~~~| | |______________| | ______________ | |~~~~~~~~~~~~~~| | |~~~~~~~~~~~~~~| | | nxtomo 2 | | | frame | | |______________| | <----------------------------------------------- y (in acquisition space) * ~: represent the overlap area """ n_projs = 100 raw_data = numpy.arange(100 * 128 * 128).reshape((100, 128, 128)) # create raw data frame_0 = raw_data[:, 60:] assert frame_0.ndim == 3 frame_0_pos = 40 frame_1 = raw_data[:, 0:80] assert frame_1.ndim == 3 frame_1_pos = 94 frames = (frame_0, frame_1) z_positions = (frame_0_pos, frame_1_pos) # create a Nxtomo for each of those raw data scans = [] for (i_frame, frame), z_pos in zip(enumerate(frames), z_positions): nx_tomo = NXtomo() nx_tomo.sample.z_translation = [z_pos] * n_projs nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_projs, endpoint=False) nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_projs nx_tomo.instrument.detector.x_pixel_size = 1.0 nx_tomo.instrument.detector.y_pixel_size = 1.0 nx_tomo.instrument.detector.distance = 2.3 nx_tomo.energy = 19.2 
nx_tomo.instrument.detector.data = frame file_path = os.path.join(output_dir, f"nxtomo_{i_frame}.nx") entry = f"entry000{i_frame}" nx_tomo.save(file_path=file_path, data_path=entry) scans.append(NXtomoScan(scan=file_path, entry=entry)) return scans, z_positions, raw_data @pytest.mark.parametrize("configuration_dist", slices_to_test_pre) def test_DistributePreProcessZStitcher(tmp_path, configuration_dist): slices = configuration_dist["slices"] complete = configuration_dist["complete"] raw_data_dir = tmp_path / "raw_data" raw_data_dir.mkdir() output_dir = tmp_path / "output_dir" output_dir.mkdir() scans, z_positions, raw_data = build_nxtomos(output_dir=raw_data_dir) stitched_nx_tomo = [] for s in slices: output_file_path = os.path.join(output_dir, "stitched_section.nx") output_data_path = f"stitched_{s}" z_stich_config = PreProcessedZStitchingConfiguration( axis_0_pos_px=z_positions, axis_1_pos_px=(0, 0), axis_2_pos_px=None, axis_0_pos_mm=None, axis_1_pos_mm=None, axis_2_pos_mm=None, axis_0_params={}, axis_1_params={}, axis_2_params={}, stitching_strategy=OverlapStitchingStrategy.CLOSEST, overwrite_results=True, input_scans=scans, output_file_path=output_file_path, output_data_path=output_data_path, output_nexus_version=None, slices=s, slurm_config=None, slice_for_cross_correlation="middle", pixel_size=None, ) stitcher = PreProcessZStitcher(z_stich_config) output_identifier = stitcher.stitch() assert output_identifier.file_path == output_file_path assert output_identifier.data_path == output_data_path created_nx_tomo = NXtomo().load( file_path=output_identifier.file_path, data_path=output_identifier.data_path, detector_data_as="as_numpy_array", ) stitched_nx_tomo.append(created_nx_tomo) assert len(stitched_nx_tomo) == len(slices) final_nx_tomo = NXtomo.concatenate(stitched_nx_tomo) assert isinstance(final_nx_tomo.instrument.detector.data, numpy.ndarray) final_nx_tomo.save( file_path=os.path.join(output_dir, "final_stitched.nx"), data_path="entry0000", ) if complete: len(final_nx_tomo.instrument.detector.data) == 128 # test middle numpy.testing.assert_array_almost_equal(raw_data[1], final_nx_tomo.instrument.detector.data[1, :, :]) else: len(final_nx_tomo.instrument.detector.data) == 3 # test middle numpy.testing.assert_array_almost_equal(raw_data[49], final_nx_tomo.instrument.detector.data[1, :, :]) # in the case of first, middle and last frames # test first numpy.testing.assert_array_almost_equal(raw_data[0], final_nx_tomo.instrument.detector.data[0, :, :]) # test last numpy.testing.assert_array_almost_equal(raw_data[-1], final_nx_tomo.instrument.detector.data[-1, :, :]) def test_get_overlap_areas(): """test get_overlap_areas function""" f_upper = numpy.linspace(7, 15, num=9, endpoint=True) f_lower = numpy.linspace(0, 12, num=13, endpoint=True) o_1, o_2 = get_overlap_areas( upper_frame=f_upper, lower_frame=f_lower, upper_frame_key_line=3, lower_frame_key_line=10, overlap_size=4, stitching_axis=0, ) numpy.testing.assert_array_equal(o_1, o_2) numpy.testing.assert_array_equal(o_1, numpy.linspace(8, 11, num=4, endpoint=True)) def test_frame_flip(tmp_path): """check it with some NXtomo flipped""" ref_frame_width = 280 n_proj = 10 raw_frame_width = 100 ref_frame = PhantomGenerator.get2DPhantomSheppLogan(n=ref_frame_width).astype(numpy.float32) * 256.0 # create raw data frame_0_shift = (0, 0) frame_1_shift = (-90, 0) frame_2_shift = (-180, 0) frame_0 = scipy_shift(ref_frame, shift=frame_0_shift)[:raw_frame_width] frame_1 = scipy_shift(ref_frame, shift=frame_1_shift)[:raw_frame_width] frame_2 = 
scipy_shift(ref_frame, shift=frame_2_shift)[:raw_frame_width] frames = frame_0, frame_1, frame_2 x_flips = [False, True, True] y_flips = [False, False, True] def apply_flip(args): frame, flip_x, flip_y = args if flip_x: frame = numpy.fliplr(frame) if flip_y: frame = numpy.flipud(frame) return frame frames = map(apply_flip, zip(frames, x_flips, y_flips)) # create a Nxtomo for each of those raw data raw_data_dir = tmp_path / "raw_data" raw_data_dir.mkdir() output_dir = tmp_path / "output_dir" output_dir.mkdir() z_position = (90, 0, -90) scans = [] for (i_frame, frame), z_pos, x_flip, y_flip in zip(enumerate(frames), z_position, x_flips, y_flips): nx_tomo = NXtomo() nx_tomo.sample.z_translation = [z_pos] * n_proj nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False) nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj nx_tomo.instrument.detector.x_pixel_size = 1.0 nx_tomo.instrument.detector.y_pixel_size = 1.0 nx_tomo.instrument.detector.distance = 2.3 nx_tomo.instrument.detector.transformations.add_transformation(DetZFlipTransformation(flip=x_flip)) nx_tomo.instrument.detector.transformations.add_transformation(DetYFlipTransformation(flip=y_flip)) nx_tomo.energy = 19.2 nx_tomo.instrument.detector.data = numpy.asarray([frame] * n_proj) file_path = os.path.join(raw_data_dir, f"nxtomo_{i_frame}.nx") entry = f"entry000{i_frame}" nx_tomo.save(file_path=file_path, data_path=entry) scans.append(NXtomoScan(scan=file_path, entry=entry)) output_file_path = os.path.join(output_dir, "stitched.nx") output_data_path = "stitched" assert len(scans) == 3 z_stich_config = PreProcessedZStitchingConfiguration( axis_0_pos_px=(0, -90, -180), axis_1_pos_px=(0, 0, 0), axis_2_pos_px=None, axis_0_pos_mm=None, axis_1_pos_mm=None, axis_2_pos_mm=None, axis_0_params={}, axis_1_params={}, axis_2_params={}, stitching_strategy=OverlapStitchingStrategy.LINEAR_WEIGHTS, overwrite_results=True, input_scans=scans, output_file_path=output_file_path, output_data_path=output_data_path, output_nexus_version=None, slices=None, slurm_config=None, slice_for_cross_correlation="middle", pixel_size=None, ) stitcher = PreProcessZStitcher(z_stich_config) output_identifier = stitcher.stitch() assert output_identifier.file_path == output_file_path assert output_identifier.data_path == output_data_path created_nx_tomo = NXtomo().load( file_path=output_identifier.file_path, data_path=output_identifier.data_path, detector_data_as="as_numpy_array", ) assert created_nx_tomo.instrument.detector.data.ndim == 3 # insure flipping has been taking into account numpy.testing.assert_array_almost_equal(ref_frame, created_nx_tomo.instrument.detector.data[0, :ref_frame_width, :]) assert len(created_nx_tomo.instrument.detector.transformations) == 0 ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1734442985.524757 nabu-2024.2.1/nabu/stitching/utils/0000755000175000017500000000000014730277752016364 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1722846850.0 nabu-2024.2.1/nabu/stitching/utils/__init__.py0000644000175000017500000000002514654107202020455 0ustar00pierrepierrefrom .utils import * ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731053186.0 nabu-2024.2.1/nabu/stitching/utils/post_processing.py0000644000175000017500000002604714713343202022151 0ustar00pierrepierreimport os import logging from typing import Optional, Union from nabu import version as nabu_version from 
nabu.stitching.config import ( PreProcessedSingleAxisStitchingConfiguration, PostProcessedSingleAxisStitchingConfiguration, SingleAxisStitchingConfiguration, ) from nabu.stitching.stitcher.single_axis import PROGRESS_BAR_STITCH_VOL_DESC from nabu.io.writer import get_datetime from tomoscan.factory import Factory as TomoscanFactory from silx.io.dictdump import dicttonx from nxtomo.application.nxtomo import NXtomo from tomoscan.utils.volume import concatenate as concatenate_volumes from tomoscan.esrf.volume import HDF5Volume from contextlib import AbstractContextManager from threading import Thread from time import sleep _logger = logging.getLogger(__name__) class StitchingPostProcAggregation: """ For remote stitching each process will stitch a part of the volume or of the projections. Once all of them are finished we want to aggregate them into a final volume or NXtomo. This is the goal of this class. Please be careful with the API. This is already inheriting from a tomwer class. :param stitching_config: configuration of the stitching :param futures: futures of the jobs that have been submitted :param existing_objs_ids: identifiers of already existing tomo objects to aggregate (used when no futures are provided) :param progress_bars: tqdm progress bars for each job """ def __init__( self, stitching_config: SingleAxisStitchingConfiguration, futures: Optional[tuple] = None, existing_objs_ids: Optional[tuple] = None, progress_bars: Optional[dict] = None, ) -> None: if not isinstance(stitching_config, (SingleAxisStitchingConfiguration)): raise TypeError(f"stitching_config should be an instance of {SingleAxisStitchingConfiguration}") if not ((existing_objs_ids is None) ^ (futures is None)): raise ValueError("Either existing_objs_ids or futures should be provided (can't provide both)") if progress_bars is not None and not isinstance(progress_bars, dict): raise TypeError(f"'progress_bars' should be None or an instance of a dict. Got {type(progress_bars)}") self._futures = futures self._stitching_config = stitching_config self._existing_objs_ids = existing_objs_ids self._progress_bars = progress_bars or {} @property def futures(self): return self._futures @property def progress_bars(self) -> dict: return self._progress_bars def retrieve_tomo_objects(self) -> tuple: """ Return tomo objects to be stitched together. Either from futures or from existing_objs_ids """ if self._existing_objs_ids is not None: scan_ids = self._existing_objs_ids else: results = {} _logger.info( "waiting for the slurm jobs to complete. Advancement will be displayed once the slurm job output files are available" ) for obj_id, future in self.futures.items(): results[obj_id] = future.result() failed = tuple( filter( lambda x: x.exception() is not None, self.futures.values(), ) ) if len(failed) > 0: # if some job failed: useless to do the concatenation exceptions = " ; ".join([f"{job} : {job.exception()}" for job in failed]) raise RuntimeError(f"some jobs failed. Won't do the concatenation. Exceptions are {exceptions}") canceled = tuple( filter( lambda x: x.cancelled(), self.futures.values(), ) ) if len(canceled) > 0: # if some job was cancelled: useless to do the concatenation raise RuntimeError(f"some jobs were cancelled. Won't do the concatenation.
Jobs are {' ; '.join(canceled)}") scan_ids = results.keys() return [TomoscanFactory.create_tomo_object_from_identifier(scan_id) for scan_id in scan_ids] def dump_stitching_config_as_nx_process(self, file_path: str, data_path: str, overwrite: bool, process_name: str): dict_to_dump = { process_name: { "config": self._stitching_config.to_dict(), "program": "nabu-stitching", "version": nabu_version, "date": get_datetime(), }, f"{process_name}@NX_class": "NXprocess", } dicttonx( dict_to_dump, h5file=file_path, h5path=data_path, update_mode="replace" if overwrite else "add", mode="a", ) @property def stitching_config(self) -> SingleAxisStitchingConfiguration: return self._stitching_config def process(self) -> None: """ main function """ # concatenate result _logger.info("all job succeeded. Concatenate results") if isinstance(self._stitching_config, PreProcessedSingleAxisStitchingConfiguration): # 1: case of a pre-processing stitching with self.follow_progress(): scans = self.retrieve_tomo_objects() nx_tomos = [] for scan in scans: if not os.path.exists(scan.master_file): raise RuntimeError( f"output file not created ({scan.master_file}). Stitching failed. " "Please check slurm .out files to have more information. Most likely the slurm configuration is invalid. " "(partition name not existing...)" ) nx_tomos.append( NXtomo().load( file_path=scan.master_file, data_path=scan.entry, ) ) final_nx_tomo = NXtomo.concatenate(nx_tomos) final_nx_tomo.save( file_path=self.stitching_config.output_file_path, data_path=self.stitching_config.output_data_path, overwrite=self.stitching_config.overwrite_results, ) # dump NXprocess if possible parts = self.stitching_config.output_data_path.split("/") process_name = parts[-1] + "_stitching" if len(parts) < 2: data_path = "/" else: data_path = "/".join(parts[:-1]) self.dump_stitching_config_as_nx_process( file_path=self.stitching_config.output_file_path, data_path=data_path, process_name=process_name, overwrite=self.stitching_config.overwrite_results, ) elif isinstance(self.stitching_config, PostProcessedSingleAxisStitchingConfiguration): # 2: case of a post-processing stitching with self.follow_progress(): outputs_sub_volumes = self.retrieve_tomo_objects() concatenate_volumes( output_volume=self.stitching_config.output_volume, volumes=tuple(outputs_sub_volumes), axis=1, ) if isinstance(self.stitching_config.output_volume, HDF5Volume): parts = self.stitching_config.output_volume.metadata_url.data_path().split("/") process_name = parts[-1] + "_stitching" if len(parts) < 2: data_path = "/" else: data_path = "/".join(parts[:-1]) self.dump_stitching_config_as_nx_process( file_path=self.stitching_config.output_volume.metadata_url.file_path(), data_path=data_path, process_name=process_name, overwrite=self.stitching_config.overwrite_results, ) else: raise TypeError(f"stitching_config type ({type(self.stitching_config)}) not handled") def follow_progress(self) -> AbstractContextManager: return SlurmStitchingFollowerContext( output_files_to_progress_bars={ job._get_output_file_path(): progress_bar for (job, progress_bar) in self.progress_bars.items() } ) class SlurmStitchingFollowerContext(AbstractContextManager): """Util class to provide user feedback from stitching done on slurm""" def __init__(self, output_files_to_progress_bars: dict): self._update_thread = SlurmStitchingFollowerThread(file_to_progress_bar=output_files_to_progress_bars) def __enter__(self) -> None: self._update_thread.start() def __exit__(self, *args, **kwargs): self._update_thread.join(timeout=1.5) for 
progress_bar in self._update_thread.file_to_progress_bar.values(): progress_bar.close() # close to clean display as leave == False class SlurmStitchingFollowerThread(Thread): """ Thread to check progression of stitching slurm job(s) Read slurm jobs .out file each 'delay time' and look for a tqdm line at the end. If it exists then deduce progress from it. file_to_progress_bar provide for each slurm .out file the progress bar to update """ def __init__(self, file_to_progress_bar: dict, delay_time: float = 0.5) -> None: super().__init__() self._stop_run = False self._wait_time = delay_time self._file_to_progress_bar = file_to_progress_bar self._first_run = True @property def file_to_progress_bar(self) -> dict: return self._file_to_progress_bar def run(self) -> None: while not self._stop_run: for file_path, progress_bar in self._file_to_progress_bar.items(): if self._first_run: # make sure each progress bar have been refreshed at least one progress_bar.refresh() if not os.path.exists(file_path): continue with open(file_path, "r") as f: try: last_line = f.readlines()[-1] except IndexError: continue advancement = self.cast_progress_line_from_log(line=last_line) if advancement is not None: progress_bar.n = advancement progress_bar.refresh() self._first_run = False sleep(self._wait_time) def join(self, timeout: Union[float, None] = None) -> None: self._stop_run = True return super().join(timeout) @staticmethod def cast_progress_line_from_log(line: str) -> Optional[float]: """Try to retrieve from a line from log the advancement (in percentage)""" if PROGRESS_BAR_STITCH_VOL_DESC not in line or "%" not in line: return None str_before_percentage = line.split("%")[0].split(" ")[-1] try: advancement = float(str_before_percentage) except ValueError: _logger.debug(f"Failed to retrieve advancement from log file. 
Value got is {str_before_percentage}") return None else: return advancement ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1734442985.524757 nabu-2024.2.1/nabu/stitching/utils/tests/0000755000175000017500000000000014730277752017526 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1730906677.0 nabu-2024.2.1/nabu/stitching/utils/tests/test_post-processing.py0000644000175000017500000000136114712705065024270 0ustar00pierrepierreimport pytest from nabu.stitching.stitcher.single_axis import PROGRESS_BAR_STITCH_VOL_DESC from nabu.stitching.utils.post_processing import SlurmStitchingFollowerThread @pytest.mark.parametrize( "test_case", { "dump configuration: 100%|": None, f"stitching : 100%|": None, f"{PROGRESS_BAR_STITCH_VOL_DESC}: 42%": 42.0, f"{PROGRESS_BAR_STITCH_VOL_DESC}: 56% toto: 23%": 56.0, "": None, "my%": None, }.items(), ) def test_SlurmStitchingFollowerContext(test_case): """Test that the conversion from log lines created by tqdm can be read back""" str_to_test, expected_result = test_case assert SlurmStitchingFollowerThread.cast_progress_line_from_log(str_to_test) == expected_result ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731053186.0 nabu-2024.2.1/nabu/stitching/utils/utils.py0000644000175000017500000006011214713343202020057 0ustar00pierrepierrefrom distutils.version import StrictVersion from typing import Optional, Union import logging import functools import numpy from tomoscan.scanbase import TomoScanBase from tomoscan.volumebase import VolumeBase from nxtomo.utils.transformation import build_matrix, DetYFlipTransformation from silx.utils.enum import Enum as _Enum from scipy.fft import rfftn as local_fftn from scipy.fft import irfftn as local_ifftn from ..overlap import OverlapStitchingStrategy, ImageStichOverlapKernel from ..alignment import AlignmentAxis1, AlignmentAxis2, PaddedRawData from ...misc import fourier_filters from ...estimation.alignment import AlignmentBase from ...resources.dataset_analyzer import HDF5DatasetAnalyzer from ...resources.nxflatfield import update_dataset_info_flats_darks try: import itk except ImportError: has_itk = False else: has_itk = True _logger = logging.getLogger(__name__) try: from skimage.registration import phase_cross_correlation except ImportError: _logger.warning( "Unable to load skimage. 
Please install it if you want to use it for finding shifts from `find_relative_shifts`" ) __has_sk_phase_correlation__ = False else: __has_sk_phase_correlation__ = True class ShiftAlgorithm(_Enum): """All generic shift search algorithm""" NABU_FFT = "nabu-fft" SKIMAGE = "skimage" ITK_IMG_REG_V4 = "itk-img-reg-v4" NONE = "None" # In the case of shift search on radio along axis 2 (or axis x in image space) we can benefit from the existing # nabu algorithm such as growing-window or sliding-window CENTERED = "centered" GLOBAL = "global" SLIDING_WINDOW = "sliding-window" GROWING_WINDOW = "growing-window" SINO_COARSE_TO_FINE = "sino-coarse-to-fine" COMPOSITE_COARSE_TO_FINE = "composite-coarse-to-fine" @classmethod def from_value(cls, value): if value in ("", None): return ShiftAlgorithm.NONE else: return super().from_value(value=value) def find_frame_relative_shifts( overlap_upper_frame: numpy.ndarray, overlap_lower_frame: numpy.ndarray, estimated_shifts: tuple, overlap_axis: int, x_cross_correlation_function=None, y_cross_correlation_function=None, x_shifts_params: Optional[dict] = None, y_shifts_params: Optional[dict] = None, ): """ :param overlap_axis: axis in [0, 1] on which the overlap exists. In image space. So 0 is aka y and 1 as x """ if not overlap_axis in (0, 1): raise ValueError(f"overlap_axis should be in (0, 1). Get {overlap_axis}") from nabu.stitching.config import ( KEY_LOW_PASS_FILTER, KEY_HIGH_PASS_FILTER, ) # avoid cyclic import x_cross_correlation_function = ShiftAlgorithm.from_value(x_cross_correlation_function) y_cross_correlation_function = ShiftAlgorithm.from_value(y_cross_correlation_function) if x_shifts_params is None: x_shifts_params = {} if y_shifts_params is None: y_shifts_params = {} # apply filtering if any def _str_to_int(value): if isinstance(value, str): value = value.lstrip("'").lstrip('"') value = value.rstrip("'").rstrip('"') value = int(value) return value low_pass = _str_to_int(x_shifts_params.get(KEY_LOW_PASS_FILTER, y_shifts_params.get(KEY_LOW_PASS_FILTER, None))) high_pass = _str_to_int(x_shifts_params.get(KEY_HIGH_PASS_FILTER, y_shifts_params.get(KEY_HIGH_PASS_FILTER, None))) if high_pass is None and low_pass is None: pass else: if low_pass is None: low_pass = 1 if high_pass is None: high_pass = 20 _logger.info(f"filter image for shift search (low_pass={low_pass}, high_pass={high_pass})") img_filter = fourier_filters.get_bandpass_filter( overlap_upper_frame.shape[-2:], cutoff_lowpass=low_pass, cutoff_highpass=high_pass, use_rfft=True, data_type=overlap_upper_frame.dtype, ) overlap_upper_frame = local_ifftn( local_fftn(overlap_upper_frame, axes=(-2, -1)) * img_filter, axes=(-2, -1) ).real overlap_lower_frame = local_ifftn( local_fftn(overlap_lower_frame, axes=(-2, -1)) * img_filter, axes=(-2, -1) ).real # compute shifts initial_shifts = numpy.array(estimated_shifts).copy() extra_shifts = numpy.array([0.0, 0.0]) def skimage_proxy(img1, img2): if not __has_sk_phase_correlation__: raise ValueError("scikit-image not installed. 
Cannot do phase correlation from it") else: found_shift, _, _ = phase_cross_correlation(reference_image=img1, moving_image=img2, space="real") return -found_shift shift_methods = { ShiftAlgorithm.NABU_FFT: functools.partial( find_shift_correlate, img1=overlap_upper_frame, img2=overlap_lower_frame ), ShiftAlgorithm.SKIMAGE: functools.partial(skimage_proxy, img1=overlap_upper_frame, img2=overlap_lower_frame), ShiftAlgorithm.ITK_IMG_REG_V4: functools.partial( find_shift_with_itk, img1=overlap_upper_frame, img2=overlap_lower_frame ), ShiftAlgorithm.NONE: functools.partial(lambda: (0.0, 0.0)), } res_algo = {} for shift_alg in set((x_cross_correlation_function, y_cross_correlation_function)): if shift_alg not in shift_methods: raise ValueError(f"requested image alignment function not handled ({shift_alg})") try: res_algo[shift_alg] = shift_methods[shift_alg]() except Exception as e: _logger.error(f"Failed to find shift from {shift_alg.value}. Error is {e}") res_algo[shift_alg] = (0, 0) extra_shifts = ( res_algo[y_cross_correlation_function][0], res_algo[x_cross_correlation_function][1], ) final_rel_shifts = numpy.array(extra_shifts) + initial_shifts return tuple([int(shift) for shift in final_rel_shifts]) def find_volumes_relative_shifts( upper_volume: VolumeBase, lower_volume: VolumeBase, overlap_axis: int, estimated_shifts, dim_axis_1: int, dtype, flip_ud_upper_frame: bool = False, flip_ud_lower_frame: bool = False, slice_for_shift: Union[int, str] = "middle", x_cross_correlation_function=None, y_cross_correlation_function=None, x_shifts_params: Optional[dict] = None, y_shifts_params: Optional[dict] = None, alignment_axis_2="center", alignment_axis_1="center", ): """ :param int dim_axis_1: axis 1 dimension (to handle axis 1 alignment) """ if y_shifts_params is None: y_shifts_params = {} if x_shifts_params is None: x_shifts_params = {} # convert from overlap_axis (3D acquisition space) to overlap_axis_proj_space. if overlap_axis == 1: raise NotImplementedError("finding projection shift along axis 1 is not handled for projections") elif overlap_axis == 0: overlap_axis_proj_space = 0 elif overlap_axis == 2: overlap_axis_proj_space = 1 else: raise ValueError(f"Stitching is done in 3D space. Expect axis to be in [0,2]. 
Get {overlap_axis}") alignment_axis_2 = AlignmentAxis2.from_value(alignment_axis_2) alignment_axis_1 = AlignmentAxis1.from_value(alignment_axis_1) assert dim_axis_1 > 0, "dim_axis_1 <= 0" if isinstance(slice_for_shift, str): if slice_for_shift == "first": slice_for_shift = 0 elif slice_for_shift == "last": slice_for_shift = dim_axis_1 elif slice_for_shift == "middle": slice_for_shift = dim_axis_1 // 2 else: raise ValueError("invalid slice provided to search shift", slice_for_shift) def get_slice_along_axis_1(volume: VolumeBase, index: int): assert isinstance(index, int), f"index should be an int, {type(index)} provided" volume_shape = volume.get_volume_shape() if alignment_axis_1 is AlignmentAxis1.BACK: front_empty_width = dim_axis_1 - volume_shape[1] if index < front_empty_width: return PaddedRawData.get_empty_frame(shape=(volume_shape[0], volume_shape[2]), dtype=dtype) else: return volume.get_slice(index=index - front_empty_width, axis=1) elif alignment_axis_1 is AlignmentAxis1.FRONT: if index >= volume_shape[1]: return PaddedRawData.get_empty_frame(shape=(volume_shape[0], volume_shape[2]), dtype=dtype) else: return volume.get_slice(index=index, axis=1) elif alignment_axis_1 is AlignmentAxis1.CENTER: front_empty_width = (dim_axis_1 - volume_shape[1]) // 2 back_empty_width = dim_axis_1 - front_empty_width if index < front_empty_width or index > back_empty_width: return PaddedRawData.get_empty_frame(shape=(volume_shape[0], volume_shape[2]), dtype=dtype) else: return volume.get_slice(index=index - front_empty_width, axis=1) else: raise TypeError(f"unmanaged alignment mode {alignment_axis_1.value}") upper_frame = get_slice_along_axis_1(upper_volume, index=slice_for_shift) lower_frame = get_slice_along_axis_1(lower_volume, index=slice_for_shift) if flip_ud_upper_frame: upper_frame = numpy.flipud(upper_frame.copy()) if flip_ud_lower_frame: lower_frame = numpy.flipud(lower_frame.copy()) from nabu.stitching.config import KEY_WINDOW_SIZE # avoid cyclic import w_window_size = int(y_shifts_params.get(KEY_WINDOW_SIZE, 400)) start_overlap = max(estimated_shifts[0] // 2 - w_window_size // 2, 0) end_overlap = min(estimated_shifts[0] // 2 + w_window_size // 2, min(upper_frame.shape[0], lower_frame.shape[0])) if start_overlap == 0: overlap_upper_frame = upper_frame[-end_overlap:] else: overlap_upper_frame = upper_frame[-end_overlap:-start_overlap] overlap_lower_frame = lower_frame[start_overlap:end_overlap] # align if necessary if overlap_upper_frame.shape[1] != overlap_lower_frame.shape[1]: overlap_frame_width = min(overlap_upper_frame.shape[1], overlap_lower_frame.shape[1]) if alignment_axis_2 is AlignmentAxis2.CENTER: upper_frame_left_pos = overlap_upper_frame.shape[1] // 2 - overlap_frame_width // 2 upper_frame_right_pos = upper_frame_left_pos + overlap_frame_width overlap_upper_frame = overlap_upper_frame[:, upper_frame_left_pos:upper_frame_right_pos] lower_frame_left_pos = overlap_lower_frame.shape[1] // 2 - overlap_frame_width // 2 lower_frame_right_pos = lower_frame_left_pos + overlap_frame_width overlap_lower_frame = overlap_lower_frame[:, lower_frame_left_pos:lower_frame_right_pos] elif alignment_axis_2 is AlignmentAxis2.LEFT: overlap_upper_frame = overlap_upper_frame[:, :overlap_frame_width] overlap_lower_frame = overlap_lower_frame[:, :overlap_frame_width] elif alignment_axis_2 is AlignmentAxis2.RIGTH: overlap_upper_frame = overlap_upper_frame[:, -overlap_frame_width:] overlap_lower_frame = overlap_lower_frame[:, -overlap_frame_width:] else: raise ValueError(f"Alignement 
{alignment_axis_2.value} is not handled") if not overlap_upper_frame.shape == overlap_lower_frame.shape: raise ValueError(f"Fail to get consistant overlap ({overlap_upper_frame.shape} vs {overlap_lower_frame.shape})") return find_frame_relative_shifts( overlap_upper_frame=overlap_upper_frame, overlap_lower_frame=overlap_lower_frame, estimated_shifts=estimated_shifts, x_cross_correlation_function=x_cross_correlation_function, y_cross_correlation_function=y_cross_correlation_function, x_shifts_params=x_shifts_params, y_shifts_params=y_shifts_params, overlap_axis=overlap_axis_proj_space, ) from nabu.pipeline.estimators import estimate_cor def find_projections_relative_shifts( upper_scan: TomoScanBase, lower_scan: TomoScanBase, estimated_shifts: tuple, axis: int, flip_ud_upper_frame: bool = False, flip_ud_lower_frame: bool = False, projection_for_shift: Union[int, str] = "middle", invert_order: bool = False, x_cross_correlation_function=None, y_cross_correlation_function=None, x_shifts_params: Optional[dict] = None, y_shifts_params: Optional[dict] = None, ) -> tuple: """ deduce the relative shift between the two scans. Expected behavior: * compute expected overlap area from z_translations and (sample) pixel size * call an (optional) cross correlation function from the overlap area to compute the x shift and polish the y shift from `projection_for_shift` :param TomoScanBase scan_0: :param TomoScanBase scan_1: :param tuple estimated_shifts: 'a priori' shift estimation :param int axis: axis on which the overlap / stitching is happening. In the 3D space (sample, detector referential) :param bool flip_ud_upper_frame: is the upper frame flipped :param bool flip_ud_lower_frame: is the lower frame flipped :param Union[int,str] projection_for_shift: index fo the projection to use (in projection space or in scan space ?. For now in projection) or str. If str must be in (`middle`, `first`, `last`) :param bool invert_order: are projections inverted between the two scans (case if rotation angle are inverted) :param str x_cross_correlation_function: optional method to refine x shift from computing cross correlation. For now valid values are: ("skimage", "nabu-fft") :param str y_cross_correlation_function: optional method to refine y shift from computing cross correlation. For now valid values are: ("skimage", "nabu-fft") :param x_shifts_params: parameters to find the shift over x :param y_shifts_params: parameters to find the shift over y :return: relative shift of scan_1 with scan_0 as reference: (y_shift, x_shift) :rtype: tuple :warning: this function will flip left-right and up-down frames by default. So it will return shift according to this information """ if x_shifts_params is None: x_shifts_params = {} if y_shifts_params is None: y_shifts_params = {} # convert from overlap_axis (3D acquisition space) to overlap_axis_proj_space. if axis == 1: axis_proj_space = 1 elif axis == 0: axis_proj_space = 0 elif axis == 2: raise NotImplementedError( "finding projection shift along axis 1 (x-ray direction) is not handled for projections" ) else: raise ValueError(f"Stitching is done in 3D space. Expect axis to be in [0,2]. 
Get {axis}") x_cross_correlation_function = ShiftAlgorithm.from_value(x_cross_correlation_function) y_cross_correlation_function = ShiftAlgorithm.from_value(y_cross_correlation_function) # { handle specific use case (finding shift on scan) - when using nabu COR algorithms (for axis 2) if x_cross_correlation_function in ( ShiftAlgorithm.SINO_COARSE_TO_FINE, ShiftAlgorithm.COMPOSITE_COARSE_TO_FINE, ShiftAlgorithm.CENTERED, ShiftAlgorithm.GLOBAL, ShiftAlgorithm.GROWING_WINDOW, ShiftAlgorithm.SLIDING_WINDOW, ): cor_options = x_shifts_params.copy() cor_options.pop("img_reg_method", None) # remove all none numeric options because estimate_cor will call 'literal_eval' on them upper_scan_dataset_info = HDF5DatasetAnalyzer( location=upper_scan.master_file, extra_options={"hdf5_entry": upper_scan.entry} ) update_dataset_info_flats_darks(upper_scan_dataset_info, flatfield_mode=1) upper_scan_pos = estimate_cor( method=x_cross_correlation_function.value, dataset_info=upper_scan_dataset_info, cor_options=cor_options, ) lower_scan_dataset_info = HDF5DatasetAnalyzer( location=lower_scan.master_file, extra_options={"hdf5_entry": lower_scan.entry} ) update_dataset_info_flats_darks(lower_scan_dataset_info, flatfield_mode=1) lower_scan_pos = estimate_cor( method=x_cross_correlation_function.value, dataset_info=lower_scan_dataset_info, cor_options=cor_options, ) estimated_shifts = tuple( [ estimated_shifts[0], (lower_scan_pos - upper_scan_pos), ] ) x_cross_correlation_function = ShiftAlgorithm.NONE # } else we will compute shift from the flat projections def get_flat_fielded_proj( scan: TomoScanBase, proj_index: int, reverse: bool, transformation_matrix: Optional[numpy.ndarray] ): first_proj_idx = sorted(lower_scan.projections.keys(), reverse=reverse)[proj_index] ff = scan.flat_field_correction( (scan.projections[first_proj_idx],), (first_proj_idx,), )[0] assert ff.ndim == 2, f"expects a single 2D frame. Get something with {ff.ndim} dimensions" if transformation_matrix is not None: assert ( transformation_matrix.ndim == 2 ), f"expects a 2D transformation matrix. Get a {transformation_matrix.ndim} D" if numpy.isclose(transformation_matrix[2, 2], -1): transformation_matrix[2, :] = 0 transformation_matrix[0, 2] = 0 transformation_matrix[2, 2] = 1 ff = numpy.flipud(ff) return ff if isinstance(projection_for_shift, str): if projection_for_shift.lower() == "first": projection_for_shift = 0 elif projection_for_shift.lower() == "last": projection_for_shift = -1 elif projection_for_shift.lower() == "middle": projection_for_shift = len(upper_scan.projections) // 2 else: try: projection_for_shift = int(projection_for_shift) except ValueError: raise ValueError( f"{projection_for_shift} cannot be cast to an int and is not one of the possible ('first', 'last', 'middle')" ) elif not isinstance(projection_for_shift, (int, numpy.number)): raise TypeError( f"projection_for_shift is expected to be an int. 
Not {type(projection_for_shift)} - {projection_for_shift}" ) upper_scan_transformations = list(upper_scan.get_detector_transformations(tuple())) if flip_ud_upper_frame: upper_scan_transformations.append(DetYFlipTransformation(flip=True)) upper_scan_trans_matrix = build_matrix(upper_scan_transformations) lower_scan_transformations = list(lower_scan.get_detector_transformations(tuple())) if flip_ud_lower_frame: lower_scan_transformations.append(DetYFlipTransformation(flip=True)) lower_scan_trans_matrix = build_matrix(lower_scan_transformations) upper_proj = get_flat_fielded_proj( upper_scan, projection_for_shift, reverse=False, transformation_matrix=upper_scan_trans_matrix, ) lower_proj = get_flat_fielded_proj( lower_scan, projection_for_shift, reverse=invert_order, transformation_matrix=lower_scan_trans_matrix, ) from nabu.stitching.config import KEY_WINDOW_SIZE # avoid cyclic import if axis_proj_space == 0: w_window_size = int(y_shifts_params.get(KEY_WINDOW_SIZE, 400)) else: w_window_size = int(x_shifts_params.get(KEY_WINDOW_SIZE, 400)) start_overlap = max(estimated_shifts[axis_proj_space] // 2 - w_window_size // 2, 0) end_overlap = min( estimated_shifts[axis_proj_space] // 2 + w_window_size // 2, min(upper_proj.shape[axis_proj_space], lower_proj.shape[axis_proj_space]), ) o_upper_sel = numpy.array(range(-end_overlap, -start_overlap)) overlap_upper_frame = numpy.take_along_axis( upper_proj, o_upper_sel[:, None] if axis_proj_space == 0 else o_upper_sel[None, :], axis=axis_proj_space, ) o_lower_sel = numpy.array(range(start_overlap, end_overlap)) overlap_lower_frame = numpy.take_along_axis( lower_proj, o_lower_sel[:, None] if axis_proj_space == 0 else o_upper_sel[None, :], axis=axis_proj_space, ) if not overlap_upper_frame.shape == overlap_lower_frame.shape: raise ValueError(f"Fail to get consistent overlap ({overlap_upper_frame.shape} vs {overlap_lower_frame.shape})") return find_frame_relative_shifts( overlap_upper_frame=overlap_upper_frame, overlap_lower_frame=overlap_lower_frame, estimated_shifts=estimated_shifts, x_cross_correlation_function=x_cross_correlation_function, y_cross_correlation_function=y_cross_correlation_function, x_shifts_params=x_shifts_params, y_shifts_params=y_shifts_params, overlap_axis=axis_proj_space, ) def find_shift_correlate(img1, img2, padding_mode="reflect"): alignment = AlignmentBase() cc = alignment._compute_correlation_fft( img1, img2, padding_mode, ) img_shape = img1.shape[-2:] cc_vs = numpy.fft.fftfreq(img_shape[-2], 1 / img_shape[-2]) cc_hs = numpy.fft.fftfreq(img_shape[-1], 1 / img_shape[-1]) (f_vals, fv, fh) = alignment.extract_peak_region_2d(cc, cc_vs=cc_vs, cc_hs=cc_hs) shifts_vh = alignment.refine_max_position_2d(f_vals, fv, fh) return -shifts_vh def find_shift_with_itk(img1: numpy.ndarray, img2: numpy.ndarray) -> tuple: # created from https://examples.itk.org/src/registration/common/perform2dtranslationregistrationwithmeansquares/documentation # return (y_shift, x_shift). For now shift are integers as only integer shift are handled. if not img1.dtype == img2.dtype: raise ValueError("the two images are expected to have the same type") if not img1.ndim == img2.ndim == 2: raise ValueError("the two images are expected to 2D numpy arrays") if not has_itk: _logger.warning("itk is not installed. 
Please install it to find shift with it") return (0, 0) if StrictVersion(itk.Version.GetITKVersion()) < StrictVersion("4.9.0"): _logger.error("ITK 4.9.0 is required to find shift with it.") return (0, 0) pixel_type = itk.ctype("float") img1 = numpy.ascontiguousarray(img1, dtype=numpy.float32) img2 = numpy.ascontiguousarray(img2, dtype=numpy.float32) dimension = 2 image_type = itk.Image[pixel_type, dimension] fixed_image = itk.PyBuffer[image_type].GetImageFromArray(img1) moving_image = itk.PyBuffer[image_type].GetImageFromArray(img2) transform_type = itk.TranslationTransform[itk.D, dimension] initial_transform = transform_type.New() optimizer = itk.RegularStepGradientDescentOptimizerv4.New( LearningRate=4, MinimumStepLength=0.001, RelaxationFactor=0.5, NumberOfIterations=200, ) metric = itk.MeanSquaresImageToImageMetricv4[image_type, image_type].New() registration = itk.ImageRegistrationMethodv4.New( FixedImage=fixed_image, MovingImage=moving_image, Metric=metric, Optimizer=optimizer, InitialTransform=initial_transform, ) moving_initial_transform = transform_type.New() initial_parameters = moving_initial_transform.GetParameters() initial_parameters[0] = 0 initial_parameters[1] = 0 moving_initial_transform.SetParameters(initial_parameters) registration.SetMovingInitialTransform(moving_initial_transform) identity_transform = transform_type.New() identity_transform.SetIdentity() registration.SetFixedInitialTransform(identity_transform) registration.SetNumberOfLevels(1) registration.SetSmoothingSigmasPerLevel([0]) registration.SetShrinkFactorsPerLevel([1]) registration.Update() transform = registration.GetTransform() final_parameters = transform.GetParameters() translation_along_x = final_parameters.GetElement(0) translation_along_y = final_parameters.GetElement(1) return numpy.round(translation_along_y), numpy.round(translation_along_x) def from_slice_to_n_elements(slice_: Union[slice, tuple]): """Return the number of element in a slice or in a tuple""" if isinstance(slice_, slice): return (slice_.stop - slice_.start) / (slice_.step or 1) else: return len(slice_) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731053186.0 nabu-2024.2.1/nabu/stitching/y_stitching.py0000644000175000017500000000177414713343202020114 0ustar00pierrepierrefrom tomoscan.identifier import BaseIdentifier from nabu.stitching.stitcher.y_stitcher import PreProcessingYStitcher as PreProcessYStitcher from nabu.stitching.config import PreProcessedYStitchingConfiguration def y_stitching(configuration: PreProcessedYStitchingConfiguration, progress=None) -> BaseIdentifier: """ Apply stitching from provided configuration. Stitching will be applied along the first axis - 1 (aka y). like: axis 0 ^ | x-ray | --------> ------> axis 2 / / axis 1 """ if isinstance(configuration, PreProcessedYStitchingConfiguration): stitcher = PreProcessYStitcher(configuration=configuration, progress=progress) else: raise TypeError( f"configuration is expected to be in {(PreProcessedYStitchingConfiguration, )}. 
{type(configuration)} provided" ) return stitcher.stitch() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731053186.0 nabu-2024.2.1/nabu/stitching/z_stitching.py0000644000175000017500000000345314713343202020111 0ustar00pierrepierrefrom typing import Union from tomoscan.identifier import BaseIdentifier from nabu.stitching.stitcher.z_stitcher import PreProcessingZStitcher as PreProcessZStitcher from nabu.stitching.stitcher.z_stitcher import ( PostProcessingZStitcher as PostProcessZStitcher, PostProcessingZStitcherNoDD as PostProcessZStitcherNoDD, ) from nabu.stitching.config import ( PreProcessedZStitchingConfiguration, PostProcessedZStitchingConfiguration, ) def z_stitching( configuration: Union[PreProcessedZStitchingConfiguration, PostProcessedZStitchingConfiguration], progress=None ) -> BaseIdentifier: """ Apply stitching from provided configuration. Along axis 0 (aka z) Return a DataUrl with the created NXtomo or Volume like: axis 0 ^ | x-ray | --------> ------> axis 2 / / axis 1 """ stitcher = None assert configuration.axis is not None if isinstance(configuration, PreProcessedZStitchingConfiguration): if configuration.axis == 0: stitcher = PreProcessZStitcher(configuration=configuration, progress=progress) elif isinstance(configuration, PostProcessedZStitchingConfiguration): assert configuration.axis == 0 if configuration.duplicate_data: stitcher = PostProcessZStitcher(configuration=configuration, progress=progress) else: stitcher = PostProcessZStitcherNoDD(configuration=configuration, progress=progress) if stitcher is None: raise TypeError( f"configuration is expected to be in {(PreProcessedZStitchingConfiguration, PostProcessedZStitchingConfiguration)}. {type(configuration)} provided" ) return stitcher.stitch() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/tests.py0000644000175000017500000000261114550227307014733 0ustar00pierrepierre#!/usr/bin/env python # -*- coding: utf-8 -*- import sys import os import pytest from nabu.utils import get_folder_path from nabu import __nabu_modules__ as nabu_modules def get_modules_to_test(mods): sep = os.sep modules = [] extra_args = [] for mod in mods: if mod.startswith("-"): extra_args.append(mod) continue # Test a whole module if mod.lower() in nabu_modules: mod_abspath = os.path.join(get_folder_path(mod), "tests") # Test an individual file else: mod_path = mod.replace(".", sep) + ".py" mod_abspath = get_folder_path(mod_path) # test only one file mod_split = mod_abspath.split(sep) mod_split.insert(-1, "tests") mod_abspath = sep.join(mod_split) if not (os.path.exists(mod_abspath)): print("Error: no such file or directory: %s" % mod_abspath) exit(1) modules.append(mod_abspath) return modules, extra_args def nabu_test(): nabu_folder = get_folder_path() args = sys.argv[1:] modules_to_test, extra_args = get_modules_to_test(args) if len(modules_to_test) == 0: modules_to_test = [ nabu_folder, ] pytest_args = extra_args + modules_to_test return pytest.main(pytest_args) if __name__ == "__main__": ret = nabu_test() exit(ret) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731065675.0 nabu-2024.2.1/nabu/testutils.py0000644000175000017500000001672014713373513015641 0ustar00pierrepierrefrom itertools import product import tarfile import os import numpy as np from scipy.signal.windows import gaussian from silx.resources import ExternalResources from silx.io.dictdump import nxtodict, dicttonx utilstest = ExternalResources( 
project="nabu", url_base="http://www.silx.org/pub/nabu/data/", env_key="NABU_DATA", timeout=60 ) __big_testdata_dir__ = os.environ.get("NABU_BIGDATA_DIR") if __big_testdata_dir__ is None or not (os.path.isdir(__big_testdata_dir__)): __big_testdata_dir__ = None __do_long_tests__ = os.environ.get("NABU_LONG_TESTS", False) if __do_long_tests__: try: __do_long_tests__ = bool(int(__do_long_tests__)) except: __do_long_tests__ = False __do_large_mem_tests__ = os.environ.get("NABU_LARGE_MEM_TESTS", False) if __do_large_mem_tests__: try: __do_large_mem_tests__ = bool(int(__do_large_mem_tests__)) except: __do_large_mem_tests__ = False def generate_tests_scenarios(configurations): """ Generate "scenarios" of tests. The parameter is a dictionary where: - the key is the name of a parameter - the value is a list of possible parameters This function returns a list of dictionary where: - the key is the name of a parameter - the value is one value of this parameter """ scenarios = [{key: val for key, val in zip(configurations.keys(), p_)} for p_ in product(*configurations.values())] return scenarios def get_data(*dataset_path): """ Get a dataset file from silx.org/pub/nabu/data dataset_args is a list describing a nested folder structures, ex. ["path", "to", "my", "dataset.h5"] """ dataset_relpath = os.path.join(*dataset_path) dataset_downloaded_path = utilstest.getfile(dataset_relpath) return np.load(dataset_downloaded_path) def get_array_of_given_shape(img, shape, dtype): """ From a given image, returns an array of the wanted shape and dtype. """ # Tile image until it's big enough. # "fun" fact: using any(blabla) crashes but using any([blabla]) does not, because of variables re-evaluation while any([i_dim <= s_dim for i_dim, s_dim in zip(img.shape, shape)]): img = np.tile(img, (2, 2)) if len(shape) == 1: arr = img[: shape[0], 0] elif len(shape) == 2: arr = img[: shape[0], : shape[1]] else: arr = np.tile(img, (shape[0], 1, 1))[: shape[0], : shape[1], : shape[2]] return np.ascontiguousarray(np.squeeze(arr), dtype=dtype) def get_big_data(filename): if __big_testdata_dir__ is None: return None return np.load(os.path.join(__big_testdata_dir__, filename)) def uncompress_file(compressed_file_path, target_directory): with tarfile.open(compressed_file_path) as f: f.extractall(path=target_directory) def get_file(fname): downloaded_file = dataset_downloaded_path = utilstest.getfile(fname) if ".tar" in fname: uncompress_file(downloaded_file, os.path.dirname(downloaded_file)) downloaded_file = downloaded_file.split(".tar")[0] return downloaded_file def compare_arrays(arr1, arr2, tol, diff=None, absolute_value=True, percent=None, method="max", return_residual=False): """ Utility to compare two arrays. Parameters ---------- arr1: numpy.ndarray First array to compare arr2: numpy.ndarray Second array to compare tol: float Tolerance indicating whether arrays are close to eachother. diff: numpy.ndarray, optional Difference `arr1 - arr2`. If provided, this array is taken instead of `arr1` and `arr2`. absolute_value: bool, optional Whether to take absolute value of the difference. percent: float If set, a "relative" comparison is performed instead of a subtraction: `red(|arr1 - arr2|) / (red(|arr1|) * percent) < tol` where "red" is the reduction method (mean, max or median). method: Reduction method. Can be "max", "mean", or "median". 
Returns -------- (is_close, residual) if return_residual is set to True is_close otherwise Examples -------- When using method="mean" and absolute_value=True, this function computes the Mean Absolute Difference (MAD) metric. When also using percent=1.0, this computes the Relative Mean Absolute Difference (RMD) metric. """ reductions = { "max": np.max, "mean": np.mean, "median": np.median, } if method not in reductions: raise ValueError("reduction method should be in %s" % str(list(reductions.keys()))) if diff is None: diff = arr1 - arr2 if absolute_value is not None: diff = np.abs(diff) residual = reductions[method](diff) if percent is not None: a1 = np.abs(arr1) if absolute_value else arr1 residual /= reductions[method](a1) res = residual < tol if return_residual: res = res, residual return res def gaussian_apodization_window(shape, fwhm_ratio=0.7): fwhm = fwhm_ratio * np.array(shape) sigma = fwhm / 2.355 return np.outer(*[gaussian(n, s) for n, s in zip(shape, sigma)]) def compare_shifted_images(img1, img2, fwhm_ratio=0.7, return_upper_bound=False): """ Compare two images that are slightly shifted from one another. Typically, tomography reconstruction wight slightly different CoR. Each image is Fourier-transformed, and the modulus is taken to get rid of the shift between the images. An apodization is done to filter the high frequencies that are usually less relevant. Parameters ---------- img1: numpy.ndarray First image img2: numpy.ndarray Second image fwhm_ratio: float, optional Ratio defining the apodization in the frequency domain. A small value (eg. 0.2) means that essentually only the low frequencies will be compared. A value of 1.0 means no apodization return_upper_bound: bool, optional Whether to return a (coarse) upper bound of the comparison metric Notes ----- This function roughly computes |phi(F(img1)) - phi(F(img2))| where F is the absolute value of the Fourier transform, and phi some shrinking function (here arcsinh). 
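
    Examples
    --------
    A minimal usage sketch (``img_a`` and ``img_b`` stand for two 2D reconstructions
    of the same object, possibly shifted with respect to each other):

    >>> score = compare_shifted_images(img_a, img_b, fwhm_ratio=0.7)
    >>> score, upper_bound = compare_shifted_images(img_a, img_b, return_upper_bound=True)
    >>> is_similar = score < 0.05 * upper_bound  # arbitrary acceptance threshold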
""" def _fourier_transform(img): return np.arcsinh(np.abs(np.fft.fftshift(np.fft.fft2(img)))) diff = _fourier_transform(img1) - _fourier_transform(img2) diff *= gaussian_apodization_window(img1.shape, fwhm_ratio=fwhm_ratio) res = np.abs(diff).max() if return_upper_bound: data_range = np.max(np.abs(img1)) return res, np.arcsinh(np.prod(img1.shape) * data_range) else: return res # To be improved def generate_nx_dataset(out_fname, image_key, data_volume=None, rotation_angle=None): nx_template_file = get_file("dummy.nx.tar.gz") nx_dict = nxtodict(nx_template_file) nx_entry = nx_dict["entry"] def _get_field(dict_, path): if path.startswith("/"): path = path[1:] if path.endswith("/"): path = path[:-1] split_path = path.split("/") if len(split_path) == 1: return dict_[split_path[0]] return _get_field(dict_[split_path[0]], "/".join(split_path[1:])) for name in ["image_key", "image_key_control"]: nx_entry["data"][name] = image_key nx_entry["instrument"]["detector"][name] = image_key if rotation_angle is not None: nx_entry["data"]["rotation_angle"] = rotation_angle nx_entry["sample"]["rotation_angle"] = rotation_angle dicttonx(nx_dict, out_fname) ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1734442985.524757 nabu-2024.2.1/nabu/thirdparty/0000755000175000017500000000000014730277752015422 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/thirdparty/__init__.py0000644000175000017500000000000014315516747017520 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1705062087.0 nabu-2024.2.1/nabu/thirdparty/algotom_convert_sino.py0000644000175000017500000003306014550227307022217 0ustar00pierrepierre""" Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [2021] [Nghia T. Vo, Diamond Light Source Ltd] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import numpy as np from scipy.ndimage import shift """ This function was taken from "algotom" (3b9a75a5f7b4d407ce142964ca683c067be8c246), with the following minor differences: - That the center of rotation is kept where it is, so that the sinogram is not shifted to have the CoR in the middle. - When CoR is on the left, use 2*cor instead of 2*cor+2 for overlapping region """ def extend_sinogram(sino_360, cor, apply_log=True, shift_sinogram=True): """ Extend a 360-degree sinogram (with offset center-of-rotation) for later reconstruction (Ref. [1]). Parameters ---------- sino_360 : array_like 2D array. 360-degree sinogram. cor : float or tuple of float Center-of-rotation or (Overlap_area, overlap_side). apply_log : bool, optional Apply the logarithm function if True. Returns ------- sino_pad : array_like Extended sinogram. cor : float Updated center-of-rotation referred to the converted sinogram. References ---------- .. 
[1] https://doi.org/10.1364/OE.418448 """ if apply_log is True: sino_360 = -np.log(sino_360) else: sino_360 = np.copy(sino_360) (nrow, ncol) = sino_360.shape xcenter = (ncol - 1.0) * 0.5 if isinstance(cor, tuple): (overlap, side) = cor else: if cor <= xcenter: overlap = 2 * (cor + 0) # was + 1 side = 0 else: overlap = 2 * (ncol - cor - 1) side = 1 overlap_int = int(np.floor(overlap)) sub_pixel = overlap - overlap_int if side == 1: if sub_pixel > 0.0 and shift_sinogram: sino_360 = shift(sino_360, (0, sub_pixel), mode='nearest') wei_list = np.linspace(1.0, 0.0, overlap_int) wei_mat = np.tile(wei_list, (nrow, 1)) sino_360[:, -overlap_int:] = sino_360[:, -overlap_int:] * wei_mat pad_wid = ncol - overlap_int sino_pad = np.pad(sino_360, ((0, 0), (0, pad_wid)), mode="edge") else: if sub_pixel > 0.0 and shift_sinogram: sino_360 = shift(sino_360, (0, -sub_pixel), mode='nearest') wei_list = np.linspace(0.0, 1.0, overlap_int) wei_mat = np.tile(wei_list, (nrow, 1)) sino_360[:, :overlap_int] = sino_360[:, :overlap_int] * wei_mat pad_wid = ncol - overlap_int sino_pad = np.pad(sino_360, ((0, 0), (pad_wid, 0)), mode="edge") cor = (sino_pad.shape[1] - 1.0) / 2.0 return sino_pad, cor ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1699887648.0 nabu-2024.2.1/nabu/thirdparty/pore3d_deringer_munch.py0000644000175000017500000001062414524435040022227 0ustar00pierrepierre""" The following de-striping method is adapted from the pore3d software. :Organization: Elettra - Sincrotrone Trieste S.C.p.A. :Version: 2013.05.01 References ---------- [1] F. Brun, A. Accardo, G. Kourousias, D. Dreossi, R. Pugliese. Effective implementation of ring artifacts removal filters for synchrotron radiation microtomographic images. Proc. of the 8th International Symposium on Image and Signal Processing (ISPA), pp. 672-676, Sept. 4-6, Trieste (Italy), 2013. The license follows. """ # Copyright (c) 2013, Elettra - Sincrotrone Trieste S.C.p.A. # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # * Neither the name of the copyright holders nor the names of any # contributors may be used to endorse or promote products derived # from this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. 
import numpy as np try: import pywt __has_pywt__ = True except ImportError: __has_pywt__ = False def munchetal_filter(im, wlevel, sigma, wname='db15'): """Process a sinogram image with the Munch et al. de-striping algorithm. Parameters ---------- im : array_like Image data as numpy array. wname : {'haar', 'db1'-'db20', 'sym2'-'sym20', 'coif1'-'coif5', 'dmey'} The wavelet transform to use. wlevel : int Levels of the wavelet decomposition. sigma : float Cutoff frequency of the Butterworth low-pass filtering. Example (using tiffile.py) -------------------------- >>> im = imread('original.tif') >>> im = munchetal_filter(im, 'db15', 4, 1.0) >>> imsave('filtered.tif', im) References ---------- B. Munch, P. Trtik, F. Marone, M. Stampanoni, Stripe and ring artifact removal with combined wavelet-Fourier filtering, Optics Express 17(10):8567-8591, 2009. """ # Wavelet decomposition: coeffs = pywt.wavedec2(im.astype(np.float32), wname, level=wlevel, mode="periodization") coeffsFlt = [coeffs[0]] # FFT transform of horizontal frequency bands: for i in range(1, wlevel + 1): # FFT: fcV = np.fft.fftshift(np.fft.fft(coeffs[i][1], axis=0)) my, mx = fcV.shape # Damping of vertical stripes: damp = 1 - np.exp(-(np.arange(-np.floor(my / 2.), -np.floor(my / 2.) + my) ** 2) / (2 * (sigma ** 2))) dampprime = np.kron(np.ones((1, mx)), damp.reshape((damp.shape[0], 1))) # np.tile(damp[:, np.newaxis], (1, mx)) fcV = fcV * dampprime # Inverse FFT: fcVflt = np.real(np.fft.ifft(np.fft.ifftshift(fcV), axis=0)) cVHDtup = (coeffs[i][0], fcVflt, coeffs[i][2]) coeffsFlt.append(cVHDtup) # Get wavelet reconstruction: im_f = np.real(pywt.waverec2(coeffsFlt, wname, mode="periodization")) # Return image according to input type: if (im.dtype == 'uint16'): # Check extrema for uint16 images: im_f[im_f < np.iinfo(np.uint16).min] = np.iinfo(np.uint16).min im_f[im_f > np.iinfo(np.uint16).max] = np.iinfo(np.uint16).max # Return filtered image (an additional row and/or column might be present): return im_f[0:im.shape[0], 0:im.shape[1]].astype(np.uint16) else: return im_f[0:im.shape[0], 0:im.shape[1]] if not(__has_pywt__): munchetal_filter = None ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1730906677.0 nabu-2024.2.1/nabu/thirdparty/tomocupy_remove_stripe.py0000644000175000017500000005745314712705065022625 0ustar00pierrepierre# pylint: skip-file """ This file is a "GPU" (through cupy) implementation of "remove_all_stripe". The original method is implemented by Nghia Vo in the algotom project: https://github.com/algotom/algotom/blob/master/algotom/prep/removal.py The implementation using cupy is done by Viktor Nikitin in the tomocupy project: https://github.com/tomography/tomocupy/blame/main/src/tomocupy/remove_stripe.py then moved to https://github.com/tomography/tomocupy/blob/main/src/tomocupy/processing/remove_stripe.py For now we can't rely on off-the-shelf tomocupy as it's not packaged in pypi, and compilation is quite tedious. License follows. """ # *************************************************************************** # # Copyright © 2022, UChicago Argonne, LLC # # All Rights Reserved # # Software Name: Tomocupy # # By: Argonne National Laboratory # # # # OPEN SOURCE LICENSE # # # # Redistribution and use in source and binary forms, with or without # # modification, are permitted provided that the following conditions are met: # # # # 1. Redistributions of source code must retain the above copyright notice, # # this list of conditions and the following disclaimer. # # 2. 
Redistributions in binary form must reproduce the above copyright # # notice, this list of conditions and the following disclaimer in the # # documentation and/or other materials provided with the distribution. # # 3. Neither the name of the copyright holder nor the names of its # # contributors may be used to endorse or promote products derived # # from this software without specific prior written permission. # # # # # # *************************************************************************** # # DISCLAIMER # # # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS # # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT # # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS # # FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT # # HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, # # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED # # TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR # # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF # # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING # # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # # *************************************************************************** # try: import pycuda.gpuarray as garray import cupy as cp import pywt from cupyx.scipy.ndimage import median_filter from cupyx.scipy import signal from cupyx.scipy.ndimage import binary_dilation from cupyx.scipy.ndimage import uniform_filter1d __have_tomocupy_deringer__ = True except ImportError as err: __have_tomocupy_deringer__ = False __tomocupy_deringer_import_error__ = err ###### Ring removal with wavelet filtering (adapted for cupy from pytroch_wavelet package https://pytorch-wavelets.readthedocs.io/)################################################################################ def _reflect(x, minx, maxx): """Reflect the values in matrix *x* about the scalar values *minx* and *maxx*. Hence a vector *x* containing a long linearly increasing series is converted into a waveform which ramps linearly up and down between *minx* and *maxx*. If *x* contains integers and *minx* and *maxx* are (integers + 0.5), the ramps will have repeated max and min samples. .. codeauthor:: Rich Wareham , Aug 2013 .. codeauthor:: Nick Kingsbury, Cambridge University, January 1999. """ x = cp.asanyarray(x) rng = maxx - minx rng_by_2 = 2 * rng mod = cp.fmod(x - minx, rng_by_2) normed_mod = cp.where(mod < 0, mod + rng_by_2, mod) out = cp.where(normed_mod >= rng, rng_by_2 - normed_mod, normed_mod) + minx return cp.array(out, dtype=x.dtype) def _mypad(x, pad, value=0): """ Function to do numpy like padding on Arrays. Only works for 2-D padding. 
Inputs: x (array): Array to pad pad (tuple): tuple of (left, right, top, bottom) pad sizes """ # Vertical only if pad[0] == 0 and pad[1] == 0: m1, m2 = pad[2], pad[3] l = x.shape[-2] xe = _reflect(cp.arange(-m1, l+m2, dtype='int32'), -0.5, l-0.5) return x[:, :, xe] # horizontal only elif pad[2] == 0 and pad[3] == 0: m1, m2 = pad[0], pad[1] l = x.shape[-1] xe = _reflect(cp.arange(-m1, l+m2, dtype='int32'), -0.5, l-0.5) return x[:, :, :, xe] def _conv2d(x, w, stride, pad, groups=1): """ Convolution (equivalent pytorch.conv2d) """ if pad != 0: x = cp.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)), 'constant') b, ci, hi, wi = x.shape co, _, hk, wk = w.shape ho = int(cp.floor(1 + (hi - hk) / stride[0])) wo = int(cp.floor(1 + (wi - wk) / stride[1])) out = cp.zeros([b, co, ho, wo], dtype='float32') x = cp.expand_dims(x, axis=1) w = cp.expand_dims(w, axis=0) chunk = ci//groups chunko = co//groups for g in range(groups): for ii in range(hk): for jj in range(wk): x_windows = x[:, :, g*chunk:(g+1)*chunk, ii:ho * stride[0]+ii:stride[0], jj:wo*stride[1]+jj:stride[1]] out[:, g*chunko:(g+1)*chunko] += cp.sum(x_windows * w[:, g*chunko:(g+1)*chunko, :, ii:ii+1, jj:jj+1], axis=2) return out def _conv_transpose2d(x, w, stride, pad, bias=None, groups=1): """ Transposed convolution (equivalent pytorch.conv_transpose2d) """ b, co, ho, wo = x.shape co, ci, hk, wk = w.shape hi = (ho-1)*stride[0]+hk wi = (wo-1)*stride[1]+wk out = cp.zeros([b, ci, hi, wi], dtype='float32') chunk = ci//groups chunko = co//groups for g in range(groups): for ii in range(hk): for jj in range(wk): x_windows = x[:, g*chunko:(g+1)*chunko] out[:, g*chunk:(g+1)*chunk, ii:ho*stride[0]+ii:stride[0], jj:wo*stride[1] + jj:stride[1]] += x_windows * w[g*chunko:(g+1)*chunko, :, ii:ii+1, jj:jj+1] if pad != 0: out = out[:, :, pad[0]:out.shape[2]-pad[0], pad[1]:out.shape[3]-pad[1]] return out def afb1d(x, h0, h1='zero', dim=-1): """ 1D analysis filter bank (along one dimension only) of an image Parameters ---------- x (array): 4D input with the last two dimensions the spatial input h0 (array): 4D input for the lowpass filter. Should have shape (1, 1, h, 1) or (1, 1, 1, w) h1 (array): 4D input for the highpass filter. Should have shape (1, 1, h, 1) or (1, 1, 1, w) dim (int) - dimension of filtering. d=2 is for a vertical filter (called column filtering but filters across the rows). d=3 is for a horizontal filter, (called row filtering but filters across the columns). 
Returns ------- lohi: lowpass and highpass subbands concatenated along the channel dimension """ C = x.shape[1] # Convert the dim to positive d = dim % 4 s = (2, 1) if d == 2 else (1, 2) N = x.shape[d] L = h0.size L2 = L // 2 shape = [1, 1, 1, 1] shape[d] = L h = cp.concatenate([h0.reshape(*shape), h1.reshape(*shape)]*C, axis=0) # Calculate the pad size outsize = pywt.dwt_coeff_len(N, L, mode='symmetric') p = 2 * (outsize - 1) - N + L pad = (0, 0, p//2, (p+1)//2) if d == 2 else (p//2, (p+1)//2, 0, 0) x = _mypad(x, pad=pad) lohi = _conv2d(x, h, stride=s, pad=0, groups=C) return lohi def sfb1d(lo, hi, g0, g1='zero', dim=-1): """ 1D synthesis filter bank of an image Array """ C = lo.shape[1] d = dim % 4 L = g0.size shape = [1, 1, 1, 1] shape[d] = L N = 2*lo.shape[d] s = (2, 1) if d == 2 else (1, 2) g0 = cp.concatenate([g0.reshape(*shape)]*C, axis=0) g1 = cp.concatenate([g1.reshape(*shape)]*C, axis=0) pad = (L-2, 0) if d == 2 else (0, L-2) y = _conv_transpose2d(cp.asarray(lo), cp.asarray(g0), stride=s, pad=pad, groups=C) + \ _conv_transpose2d(cp.asarray(hi), cp.asarray(g1), stride=s, pad=pad, groups=C) return y class DWTForward(): """ Performs a 2d DWT Forward decomposition of an image Args: wave (str): Which wavelet to use. """ def __init__(self, wave='db1'): super().__init__() wave = pywt.Wavelet(wave) h0_col, h1_col = wave.dec_lo, wave.dec_hi h0_row, h1_row = h0_col, h1_col self.h0_col = cp.array(h0_col).astype('float32')[ ::-1].reshape((1, 1, -1, 1)) self.h1_col = cp.array(h1_col).astype('float32')[ ::-1].reshape((1, 1, -1, 1)) self.h0_row = cp.array(h0_row).astype('float32')[ ::-1].reshape((1, 1, 1, -1)) self.h1_row = cp.array(h1_row).astype('float32')[ ::-1].reshape((1, 1, 1, -1)) def apply(self, x): """ Forward pass of the DWT. Args: x (array): Input of shape :math:`(N, C_{in}, H_{in}, W_{in})` Returns: (yl, yh) tuple of lowpass (yl) and bandpass (yh) coefficients. yh is a list of scale coefficients. yl has shape :math:`(N, C_{in}, H_{in}', W_{in}')` and yh has shape :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. The new dimension in yh iterates over the LH, HL and HH coefficients. Note: :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly downsampled shapes of the DWT pyramid. """ # Do a multilevel transform # Do 1 level of the transform lohi = afb1d(x, self.h0_row, self.h1_row, dim=3) y = afb1d(lohi, self.h0_col, self.h1_col, dim=2) s = y.shape y = y.reshape(s[0], -1, 4, s[-2], s[-1]) # pylint: disable=E1121 # this might blow up in the future x = cp.ascontiguousarray(y[:, :, 0]) yh = cp.ascontiguousarray(y[:, :, 1:]) return x, yh class DWTInverse(): """ Performs a 2d DWT Inverse reconstruction of an image Args: wave (str): Which wavelet to use. """ def __init__(self, wave='db1'): super().__init__() wave = pywt.Wavelet(wave) g0_col, g1_col = wave.rec_lo, wave.rec_hi g0_row, g1_row = g0_col, g1_col # Prepare the filters self.g0_col = cp.array(g0_col).astype('float32').reshape((1, 1, -1, 1)) self.g1_col = cp.array(g1_col).astype('float32').reshape((1, 1, -1, 1)) self.g0_row = cp.array(g0_row).astype('float32').reshape((1, 1, 1, -1)) self.g1_row = cp.array(g1_row).astype('float32').reshape((1, 1, 1, -1)) def apply(self, coeffs): """ Args: coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where: yl is a lowpass array of shape :math:`(N, C_{in}, H_{in}', W_{in}')` and yh is a list of bandpass arrays of shape :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. 
should match the format returned by DWTForward Returns: Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})` Note: :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly downsampled shapes of the DWT pyramid. """ yl, yh = coeffs lh = yh[:, :, 0] hl = yh[:, :, 1] hh = yh[:, :, 2] lo = sfb1d(yl, lh, self.g0_col, self.g1_col, dim=2) hi = sfb1d(hl, hh, self.g0_col, self.g1_col, dim=2) yl = sfb1d(lo, hi, self.g0_row, self.g1_row, dim=3) return yl def remove_stripe_fw(data, sigma, wname, level): """Remove stripes with wavelet filtering""" [nproj, nz, ni] = data.shape nproj_pad = nproj + nproj // 8 xshift = int((nproj_pad - nproj) // 2) # Accepts all wave types available to PyWavelets xfm = DWTForward(wave=wname) ifm = DWTInverse(wave=wname) # Wavelet decomposition. cc = [] sli = cp.zeros([nz, 1, nproj_pad, ni], dtype='float32') sli[:, 0, (nproj_pad - nproj)//2:(nproj_pad + nproj) // 2] = data.astype('float32').swapaxes(0, 1) for k in range(level): sli, c = xfm.apply(sli) cc.append(c) # FFT fcV = cp.fft.fft(cc[k][:, 0, 1], axis=1) _, my, mx = fcV.shape # Damping of ring artifact information. y_hat = cp.fft.ifftshift((cp.arange(-my, my, 2) + 1) / 2) damp = -cp.expm1(-y_hat**2 / (2 * sigma**2)) fcV *= cp.tile(damp, (mx, 1)).swapaxes(0, 1) # Inverse FFT. cc[k][:, 0, 1] = cp.fft.ifft(fcV, my, axis=1).real # Wavelet reconstruction. for k in range(level)[::-1]: shape0 = cc[k][0, 0, 1].shape sli = sli[:, :, :shape0[0], :shape0[1]] sli = ifm.apply((sli, cc[k])) data = sli[:, 0, (nproj_pad - nproj)//2:(nproj_pad + nproj) // 2, :ni].astype(data.dtype) # modified data = data.swapaxes(0, 1) return data ######## Titarenko ring removal ############################################################################################################################################################################ def remove_stripe_ti(data, beta, mask_size): """Remove stripes with a new method by V. Titareno """ gamma = beta*((1-beta)/(1+beta) )**cp.abs(cp.fft.fftfreq(data.shape[-1])*data.shape[-1]) gamma[0] -= 1 v = cp.mean(data, axis=0) v = v-v[:, 0:1] v = cp.fft.irfft(cp.fft.rfft(v)*cp.fft.rfft(gamma)) mask = cp.zeros(v.shape, dtype=v.dtype) mask_size = mask_size*mask.shape[1] mask[:, mask.shape[1]//2-mask_size//2:mask.shape[1]//2+mask_size//2] = 1 data[:] += v*mask return data ######## Optimized version for Vo-all ring removal in tomopy################################################################################################################################################################ def _rs_sort(sinogram, size, matindex, dim): """ Remove stripes using the sorting technique. 
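    Sorting-based removal (algorithm 3 in :cite:`Vo:18`).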
""" sinogram = cp.transpose(sinogram) matcomb = cp.asarray(cp.dstack((matindex, sinogram))) # matsort = cp.asarray([row[row[:, 1].argsort()] for row in matcomb]) ids = cp.argsort(matcomb[:,:,1],axis=1) matsort = matcomb.copy() matsort[:,:,0] = cp.take_along_axis(matsort[:,:,0],ids,axis=1) matsort[:,:,1] = cp.take_along_axis(matsort[:,:,1],ids,axis=1) if dim == 1: matsort[:, :, 1] = median_filter(matsort[:, :, 1], (size, 1)) else: matsort[:, :, 1] = median_filter(matsort[:, :, 1], (size, size)) # matsortback = cp.asarray([row[row[:, 0].argsort()] for row in matsort]) ids = cp.argsort(matsort[:,:,0],axis=1) matsortback = matsort.copy() matsortback[:,:,0] = cp.take_along_axis(matsortback[:,:,0],ids,axis=1) matsortback[:,:,1] = cp.take_along_axis(matsortback[:,:,1],ids,axis=1) sino_corrected = matsortback[:, :, 1] return cp.transpose(sino_corrected) def _mpolyfit(x,y): n= len(x) x_mean = cp.mean(x) y_mean = cp.mean(y) Sxy = cp.sum(x*y) - n*x_mean*y_mean Sxx = cp.sum(x*x) - n*x_mean*x_mean slope = Sxy / Sxx intercept = y_mean - slope*x_mean return slope,intercept def _detect_stripe(listdata, snr): """ Algorithm 4 in :cite:`Vo:18`. Used to locate stripes. """ numdata = len(listdata) listsorted = cp.sort(listdata)[::-1] xlist = cp.arange(0, numdata, 1.0) ndrop = cp.int16(0.25 * numdata) # (_slope, _intercept) = cp.polyfit(xlist[ndrop:-ndrop - 1], # listsorted[ndrop:-ndrop - 1], 1) (_slope, _intercept) = _mpolyfit(xlist[ndrop:-ndrop - 1], listsorted[ndrop:-ndrop - 1]) numt1 = _intercept + _slope * xlist[-1] noiselevel = cp.abs(numt1 - _intercept) noiselevel = cp.clip(noiselevel, 1e-6, None) val1 = cp.abs(listsorted[0] - _intercept) / noiselevel val2 = cp.abs(listsorted[-1] - numt1) / noiselevel listmask = cp.zeros_like(listdata) if (val1 >= snr): upper_thresh = _intercept + noiselevel * snr * 0.5 listmask[listdata > upper_thresh] = 1.0 if (val2 >= snr): lower_thresh = numt1 - noiselevel * snr * 0.5 listmask[listdata <= lower_thresh] = 1.0 return listmask def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True): """ Remove large stripes. 
""" drop_ratio = max(min(drop_ratio,0.8),0)# = cp.clip(drop_ratio, 0.0, 0.8) (nrow, ncol) = sinogram.shape ndrop = int(0.5 * drop_ratio * nrow) sinosort = cp.sort(sinogram, axis=0) sinosmooth = median_filter(sinosort, (1, size)) list1 = cp.mean(sinosort[ndrop:nrow - ndrop], axis=0) list2 = cp.mean(sinosmooth[ndrop:nrow - ndrop], axis=0) # listfact = cp.divide(list1, # list2, # out=cp.ones_like(list1), # where=list2 != 0) listfact = list1/list2 # Locate stripes listmask = _detect_stripe(listfact, snr) listmask = binary_dilation(listmask, iterations=1).astype(listmask.dtype) matfact = cp.tile(listfact, (nrow, 1)) # Normalize if norm is True: sinogram = sinogram / matfact sinogram1 = cp.transpose(sinogram) matcombine = cp.asarray(cp.dstack((matindex, sinogram1))) # matsort = cp.asarray([row[row[:, 1].argsort()] for row in matcombine]) ids = cp.argsort(matcombine[:,:,1],axis=1) matsort = matcombine.copy() matsort[:,:,0] = cp.take_along_axis(matsort[:,:,0],ids,axis=1) matsort[:,:,1] = cp.take_along_axis(matsort[:,:,1],ids,axis=1) matsort[:, :, 1] = cp.transpose(sinosmooth) # matsortback = cp.asarray([row[row[:, 0].argsort()] for row in matsort]) ids = cp.argsort(matsort[:,:,0],axis=1) matsortback = matsort.copy() matsortback[:,:,0] = cp.take_along_axis(matsortback[:,:,0],ids,axis=1) matsortback[:,:,1] = cp.take_along_axis(matsortback[:,:,1],ids,axis=1) sino_corrected = cp.transpose(matsortback[:, :, 1]) listxmiss = cp.where(listmask > 0.0)[0] sinogram[:, listxmiss] = sino_corrected[:, listxmiss] return sinogram def _rs_dead(sinogram, snr, size, matindex, norm=True): """ Remove unresponsive and fluctuating stripes. """ sinogram = cp.copy(sinogram) # Make it mutable (nrow, _) = sinogram.shape # sinosmooth = cp.apply_along_axis(uniform_filter1d, 0, sinogram, 10) sinosmooth = uniform_filter1d(sinogram, 10, axis=0) listdiff = cp.sum(cp.abs(sinogram - sinosmooth), axis=0) listdiffbck = median_filter(listdiff, size) listfact = listdiff/listdiffbck listmask = _detect_stripe(listfact, snr) listmask = binary_dilation(listmask, iterations=1).astype(listmask.dtype) listmask[0:2] = 0.0 listmask[-2:] = 0.0 listx = cp.where(listmask < 1.0)[0] listy = cp.arange(nrow) matz = sinogram[:, listx] listxmiss = cp.where(listmask > 0.0)[0] # finter = interpolate.interp2d(listx.get(), listy.get(), matz.get(), kind='linear') if len(listxmiss) > 0: # sinogram_c[:, listxmiss.get()] = finter(listxmiss.get(), listy.get()) ids = cp.searchsorted(listx, listxmiss) sinogram[:,listxmiss] = matz[:,ids-1]+(listxmiss-listx[ids-1])*(matz[:,ids]-matz[:,ids-1])/(listx[ids]-listx[ids-1]) # Remove residual stripes if norm is True: sinogram = _rs_large(sinogram, snr, size, matindex) return sinogram def _create_matindex(nrow, ncol): """ Create a 2D array of indexes used for the sorting technique. """ listindex = cp.arange(0.0, ncol, 1.0) matindex = cp.tile(listindex, (nrow, 1)) return matindex def remove_all_stripe(tomo, snr=3, la_size=61, sm_size=21, dim=1): """ Remove all types of stripe artifacts from sinogram using Nghia Vo's approach :cite:`Vo:18` (combination of algorithm 3,4,5, and 6). Parameters ---------- tomo : ndarray 3D tomographic data. snr : float Ratio used to locate large stripes. Greater is less sensitive. la_size : int Window size of the median filter to remove large stripes. sm_size : int Window size of the median filter to remove small-to-medium stripes. dim : {1, 2}, optional Dimension of the window. Returns ------- ndarray Corrected 3D tomographic data. 
""" matindex = _create_matindex(tomo.shape[2], tomo.shape[0]) for m in range(tomo.shape[1]): sino = tomo[:, m, :] sino = _rs_dead(sino, snr, la_size, matindex) sino = _rs_sort(sino, sm_size, matindex, dim) tomo[:, m, :] = sino return tomo def remove_all_stripe_sinos(sinos, snr=3, la_size=61, sm_size=21, dim=1): """ Same as remove_all_stripe(), but acting on sinograms """ n_sinos, n_a, n_x = sinos.shape matindex = _create_matindex(n_x, n_a) for m in range(n_sinos): sino = sinos[m] sino = _rs_dead(sino, snr, la_size, matindex) sino = _rs_sort(sino, sm_size, matindex, dim) sinos[m] = sino return sinos from ..cuda.utils import pycuda_to_cupy def remove_all_stripe_pycuda(array, layout="radios", device_id=0, **kwargs): """ Nabu interface to tomocupy "remove_all_stripe". Processing is done in-place to save memory, meaning that the content of "array" will be overwritten. Parameters ---------- array: pycuda.GPUArray Stack of radios in the shape (n_angles, n_y, n_x), if layout == "radios" Stack of sinos in the shape (n_y, n_angles, n_x), if layout == "sinos". Other Parameters ---------------- See parameters of 'remove_all_stripe """ # Init cupy. Nabu currently does not use cupy, with exception of this module, # so the initialization has to be done here. if getattr(remove_all_stripe, "_cupy_init", False) is False: from cupy import cuda cuda.Device(device_id).use() setattr(remove_all_stripe, "_cupy_init", True) # remove_all_stripe() in tomocupy expects a 3D array to build the "matindex" data structure. # The usage of this "matindex" array is not clear since this method is supposed to act on individual sinograms. # To avoid memory duplication, we use fake 3D array, i.e, we pass a series of (1, n_a, n_x) sinograms if layout == "radios": sinos = array.transpose(axes=(1, 0, 2)) # no copy else: sinos = array # is_contiguous = sinos.flags.c_contigious n_sinos, n_a, n_x = sinos.shape sinos_tmp = garray.zeros((1, n_a, n_x), dtype="f") for i in range(n_sinos): sinos_tmp[0] = sinos[i] cupy_sinos = pycuda_to_cupy(sinos_tmp) # no memory copy, the internal pointer is passed to pycuda remove_all_stripe_sinos(cupy_sinos, **kwargs) sinos[i] = sinos_tmp[0] return array ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/thirdparty/tomopy_phase.py0000644000175000017500000002067214315516747020511 0ustar00pierrepierre#!/usr/bin/env python # -*- coding: utf-8 -*- # ######################################################################### # Copyright (c) 2015-2019, UChicago Argonne, LLC. All rights reserved. # # # # Copyright 2015-2019. UChicago Argonne, LLC. This software was produced # # under U.S. Government contract DE-AC02-06CH11357 for Argonne National # # Laboratory (ANL), which is operated by UChicago Argonne, LLC for the # # U.S. Department of Energy. The U.S. Government has rights to use, # # reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR # # UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR # # ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is # # modified to produce derivative works, such modified software should # # be clearly marked, so as not to confuse it with the version available # # from ANL. 
# # # # Additionally, redistribution and use in source and binary forms, with # # or without modification, are permitted provided that the following # # conditions are met: # # # # * Redistributions of source code must retain the above copyright # # notice, this list of conditions and the following disclaimer. # # # # * Redistributions in binary form must reproduce the above copyright # # notice, this list of conditions and the following disclaimer in # # the documentation and/or other materials provided with the # # distribution. # # # # * Neither the name of UChicago Argonne, LLC, Argonne National # # Laboratory, ANL, the U.S. Government, nor the names of its # # contributors may be used to endorse or promote products derived # # from this software without specific prior written permission. # # # # THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS # # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT # # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS # # FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago # # Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, # # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, # # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; # # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT # # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN # # ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # # POSSIBILITY OF SUCH DAMAGE. # # ######################################################################### """ Module for phase retrieval. This code is part of tomopy: https://github.com/tomopy/tomopy It was adapted for being stand-alone, without job distribution. See the license above for more information. """ import numpy as np fft2 = np.fft.fft2 ifft2 = np.fft.ifft2 __author__ = "Doga Gursoy" __credits__ = "Mark Rivers, Xianghui Xiao" __copyright__ = "Copyright (c) 2015, UChicago Argonne, LLC." __docformat__ = 'restructuredtext en' __all__ = ['retrieve_phase'] BOLTZMANN_CONSTANT = 1.3806488e-16 # [erg/k] SPEED_OF_LIGHT = 299792458e+2 # [cm/s] PI = 3.14159265359 PLANCK_CONSTANT = 6.58211928e-19 # [keV*s] def _wavelength(energy): return 2 * PI * PLANCK_CONSTANT * SPEED_OF_LIGHT / energy def retrieve_phase( tomo, pixel_size=1e-4, dist=50, energy=20, alpha=1e-3, pad=True, ncore=None, nchunk=None): """ Perform single-step phase retrieval from phase-contrast measurements. Parameters ---------- tomo : ndarray 3D tomographic data. pixel_size : float, optional Detector pixel size in cm. dist : float, optional Propagation distance of the wavefront in cm. energy : float, optional Energy of incident wave in keV. alpha : float, optional Regularization parameter. pad : bool, optional If True, extend the size of the projections by padding with zeros. ncore : int, optional Number of cores that will be assigned to jobs. nchunk : int, optional Chunk size for each core. Returns ------- ndarray Approximated 3D tomographic phase data. """ # New dimensions and pad value after padding. py, pz, val = _calc_pad(tomo, pixel_size, dist, energy, pad) # Compute the reciprocal grid. dx, dy, dz = tomo.shape w2 = _reciprocal_grid(pixel_size, dy + 2 * py, dz + 2 * pz) # Filter in Fourier space. 
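# The statement below builds the single-distance (Paganin-type) filter
#     1 / (wavelength * dist * |w|^2 / (4 * pi) + alpha)
# (see _paganin_filter_factor() further down) on the reciprocal grid w2; the
# fftshift() moves its zero-frequency sample to index 0 so that it matches the
# layout of the unshifted fft2() output used in _retrieve_phase(). Each padded
# projection is then multiplied by this filter in Fourier space and
# inverse-transformed.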
phase_filter = np.fft.fftshift( _paganin_filter_factor(energy, dist, alpha, w2)) prj = np.full((dy + 2 * py, dz + 2 * pz), val, dtype='float32') arr = _retrieve_phase(tomo, phase_filter, py, pz, prj, pad) return arr def _retrieve_phase(tomo, phase_filter, px, py, prj, pad): dx, dy, dz = tomo.shape num_jobs = tomo.shape[0] normalized_phase_filter = phase_filter / phase_filter.max() for m in range(num_jobs): # Padding "constant" with border value # prj is initially filled with "val" prj[px:dy + px, py:dz + py] = tomo[m] prj[:px] = prj[px] prj[-px:] = prj[-px-1] prj[:, :py] = prj[:, py][:, np.newaxis] prj[:, -py:] = prj[:, -py-1][:, np.newaxis] fproj = fft2(prj) fproj *= normalized_phase_filter proj = np.real(ifft2(fproj)) if pad: proj = proj[px:dy + px, py:dz + py] tomo[m] = proj return tomo def _calc_pad(tomo, pixel_size, dist, energy, pad): """ Calculate new dimensions and pad value after padding. Parameters ---------- tomo : ndarray 3D tomographic data. pixel_size : float Detector pixel size in cm. dist : float Propagation distance of the wavefront in cm. energy : float Energy of incident wave in keV. pad : bool If True, extend the size of the projections by padding with zeros. Returns ------- int Pad amount in projection axis. int Pad amount in sinogram axis. float Pad value. """ dx, dy, dz = tomo.shape wavelength = _wavelength(energy) py, pz, val = 0, 0, 0 if pad: val = _calc_pad_val(tomo) py = _calc_pad_width(dy, pixel_size, wavelength, dist) pz = _calc_pad_width(dz, pixel_size, wavelength, dist) return py, pz, val def _paganin_filter_factor(energy, dist, alpha, w2): return 1 / (_wavelength(energy) * dist * w2 / (4 * PI) + alpha) def _calc_pad_width(dim, pixel_size, wavelength, dist): pad_pix = np.ceil(PI * wavelength * dist / pixel_size ** 2) return int((pow(2, np.ceil(np.log2(dim + pad_pix))) - dim) * 0.5) def _calc_pad_val(tomo): # mean of [(column 0 of radio) + (column -1 of radio)]/2. return np.mean((tomo[..., 0] + tomo[..., -1]) * 0.5) def _reciprocal_grid(pixel_size, nx, ny): """ Calculate reciprocal grid. Parameters ---------- pixel_size : float Detector pixel size in cm. nx, ny : int Size of the reciprocal grid along x and y axes. Returns ------- ndarray Grid coordinates. """ # Sampling in reciprocal space. indx = _reciprocal_coord(pixel_size, nx) indy = _reciprocal_coord(pixel_size, ny) np.square(indx, out=indx) np.square(indy, out=indy) return np.add.outer(indx, indy) def _reciprocal_coord(pixel_size, num_grid): """ Calculate reciprocal grid coordinates for a given pixel size and discretization. Parameters ---------- pixel_size : float Detector pixel size in cm. num_grid : int Size of the reciprocal grid. Returns ------- ndarray Grid coordinates. """ n = num_grid - 1 rc = np.arange(-n, num_grid, 2, dtype = np.float32) rc *= 0.5 / (n * pixel_size) return rc ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1664523751.0 nabu-2024.2.1/nabu/thirdparty/tomwer_load_flats_darks.py0000644000175000017500000001275314315516747022674 0ustar00pierrepierre""" script embedding all function needed to read flats and darks. For information calculation method is stored in the results/configuration dictionary """ """ IMPORTANT: this script is used as long as flat-fielding with "raw" flats/daks is not implemented in nabu. For now we load results from tomwer. 
""" import typing import h5py from silx.io.url import DataUrl from tomoscan.io import HDF5File from silx.io.utils import h5py_read_dataset import logging MAX_DEPTH = 2 logger = logging.getLogger(__name__) def get_process_entries(root_node: h5py.Group, depth: int) -> tuple: """ return the list of 'Nxtomo' entries at the root level :param str file_path: :return: list of valid Nxtomo node (ordered alphabetically) :rtype: tuple ..note: entries are sorted to insure consistency """ def _get_entries(node, depth_): if isinstance(node, h5py.Dataset): return {} res_buf = {} if is_process_node(node) is True: res_buf[node.name] = int(node['sequence_index'][()]) assert isinstance(node, h5py.Group) if depth_ >= 1: for sub_node in node.values(): res_buf[node.name] = _get_entries(node=sub_node, depth_=depth_-1) return res_buf res = {} for node in root_node.values(): res.update(_get_entries(node=node, depth_=depth-1)) return res def is_process_node(node): return (node.name.split('/')[-1].startswith('tomwer_process_') and 'NX_class' in node.attrs and node.attrs['NX_class'] == "NXprocess" and 'program' in node and h5py_read_dataset(node['program']) == 'tomwer_dark_refs' and 'version' in node and 'sequence_index' in node) def get_darks_frm_process_file(process_file, entry) -> typing.Union[None, dict]: """ :param process_file: :return: """ if entry is None: with HDF5File(process_file, 'r', swmr=True) as h5f: entries = get_process_entries(root_node=h5f, depth=MAX_DEPTH) if len(entries) == 0: logger.info( 'unable to find a DarkRef process in %s' % process_file) return None elif len(entries) > 0: raise ValueError('several entry found, entry should be ' 'specify') else: entry = list(entries.keys())[0] logger.info('take %s as default entry' % entry) with HDF5File(process_file, 'r', swmr=True) as h5f: dark_nodes = get_process_entries(root_node=h5f[entry], depth=MAX_DEPTH-1) index_to_path = {} for key, value in dark_nodes.items(): index_to_path[key] = key if len(dark_nodes) == 0: return {} # take the last processed dark ref last_process_index = sorted(dark_nodes.keys())[-1] last_process_dark = index_to_path[last_process_index] if(len(index_to_path)) > 1: logger.warning('several processing found for dark-ref,' 'take the last one: %s' % last_process_dark) res = {} if 'results' in h5f[last_process_dark].keys(): results_node = h5f[last_process_dark]['results'] if 'darks' in results_node.keys(): darks = results_node['darks'] for index in darks: res[int(index)] = DataUrl( file_path=process_file, data_path=darks[index].name, scheme="silx" ) return res def get_flats_frm_process_file(process_file, entry) -> typing.Union[None, dict]: """ :param process_file: :return: """ if entry is None: with HDF5File(process_file, 'r', swmr=True) as h5f: entries = get_process_entries(root_node=h5f, depth=MAX_DEPTH) if len(entries) == 0: logger.info( 'unable to find a DarkRef process in %s' % process_file) return None elif len(entries) > 0: raise ValueError('several entry found, entry should be ' 'specify') else: entry = list(entries.keys())[0] logger.info('take %s as default entry' % entry) with HDF5File(process_file, 'r', swmr=True) as h5f: dkref_nodes = get_process_entries(root_node=h5f[entry], depth=MAX_DEPTH-1) if len(dkref_nodes) == 0: return {} index_to_path = {} for key, value in dkref_nodes.items(): index_to_path[key] = key # take the last processed dark ref last_process_index = sorted(dkref_nodes.keys())[-1] last_process_dkrf = index_to_path[last_process_index] if(len(index_to_path)) > 1: logger.warning('several processing found for 
def get_flats_frm_process_file(process_file, entry) -> typing.Union[None, dict]: """ :param process_file: path of the tomwer process file :param entry: entry to read; if None, the single entry of the file is used :return: dict of flat frame indices to DataUrl, or None if no DarkRef process is found """ if entry is None: with HDF5File(process_file, 'r', swmr=True) as h5f: entries = get_process_entries(root_node=h5f, depth=MAX_DEPTH) if len(entries) == 0: logger.info( 'unable to find a DarkRef process in %s' % process_file) return None elif len(entries) > 1: raise ValueError('several entries found, an entry should be ' 'specified') else: entry = list(entries.keys())[0] logger.info('take %s as default entry' % entry) with HDF5File(process_file, 'r', swmr=True) as h5f: dkref_nodes = get_process_entries(root_node=h5f[entry], depth=MAX_DEPTH-1) if len(dkref_nodes) == 0: return {} index_to_path = {} for key, value in dkref_nodes.items(): index_to_path[key] = key # take the last processed dark ref last_process_index = sorted(dkref_nodes.keys())[-1] last_process_dkrf = index_to_path[last_process_index] if len(index_to_path) > 1: logger.warning('several processes found for dark-ref, ' 'taking the last one: %s' % last_process_dkrf) res = {} if 'results' in h5f[last_process_dkrf].keys(): results_node = h5f[last_process_dkrf]['results'] if 'flats' in results_node.keys(): flats = results_node['flats'] for index in flats: res[int(index)] = DataUrl( file_path=process_file, data_path=flats[index].name, scheme="silx" ) return res ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1731681010.0 nabu-2024.2.1/nabu/utils.py0000644000175000017500000006336314715655362014745 0ustar00pierrepierrefrom fnmatch import fnmatch from functools import partial import os from functools import partial, lru_cache from itertools import product import warnings from time import time import posixpath import numpy as np def nextpow2(N, dtype=np.int32): return 2 ** np.ceil(np.log2(N)).astype(dtype) def previouspow2(N, dtype=np.int32): return 2 ** np.floor(np.log2(N)).astype(dtype) def updiv(a, b): return (a + (b - 1)) // b def convert_index(idx, idx_max, default_val): """ Convert an index (possibly negative or None) to a non-negative integer. Parameters ---------- idx: int or None Index idx_max: int Maximum value (upper bound) for the index. default_val: int Default value if idx is None Examples --------- Given an integer `M`, `J = convert_index(i, M, XX)` returns an integer in the mathematical range [0, M] (or Python `range(0, M)`). `J` can then be used to define an upper bound of a range. """ if idx is None: return default_val if idx > idx_max: return idx_max if idx < 0: return idx % idx_max return idx def get_folder_path(foldername=""): _file_dir = os.path.dirname(os.path.abspath(__file__)) package_dir = _file_dir return os.path.join(package_dir, foldername) def get_cuda_srcfile(filename): src_relpath = os.path.join("cuda", "src") cuda_src_folder = get_folder_path(foldername=src_relpath) return os.path.join(cuda_src_folder, filename) def get_opencl_srcfile(filename): src_relpath = os.path.join("opencl", "src") src_folder = get_folder_path(foldername=src_relpath) return os.path.join(src_folder, filename) def get_resource_file(filename, subfolder=None): subfolder = subfolder or [] relpath = os.path.join("resources", *subfolder) abspath = get_folder_path(foldername=relpath) return os.path.join(abspath, filename) def indices_to_slices(indices): """ From a series of integer indices, return corresponding slice() objects. Parameters ---------- indices: collection of sorted unique integers Array of indices Examples -------- indices_to_slices([0, 1, 2, 3]) returns [slice(0, 4)] indices_to_slices([8, 9, 10, 14, 15, 16]) returns [slice(8, 11), slice(14, 17)] """ jumps = np.where(np.diff(indices) > 1)[0] if len(jumps) == 0: return [slice(indices[0], indices[-1] + 1)] jumps = np.hstack([-1, jumps, len(indices) - 1]) slices = [] for i in range(len(jumps) - 1): slices.append(slice(indices[jumps[i] + 1], indices[jumps[i + 1]] + 1)) return slices def merge_slices(slice1, slice2): """ Merge two slicing operations into one.
Examples -------- array = numpy.arange(200) array[slice(133, 412, 2)][slice(31, 35, 2)] # gives [195 199] array[merge_slices(slice(133, 412, 2), slice(31, 35, 2))] # gives [195 199] """ step1 = slice1.step or 1 step2 = slice2.step or 1 step = step1 * step2 if step == 1: step = None start = slice1.start + step1 * (slice2.start or 0) if slice2.stop is None: stop = slice1.stop else: stop = min(slice1.stop, slice1.start + step1 * slice2.stop) return slice(start, stop, step) def compacted_views(slices_): """ From a list of slice objects, returns the slice objects corresponding to a compact view. If "array" is obtained with array = np.hstack([big_array[slice1], big_array[slice2]]) Then, compacted_views([slice1, slice2]) returns [slice3, slice4] where - slice3 are the indices, in 'array', corresponding to indices of slice1 in 'big_array' - slice4 are the indices, in 'array', corresponding to indices of slice2 in 'big_array' Example ------- compacted_views([slice(1, 26), slice(526, 551)]) gives [slice(0, 25), slice(25, 50)] """ prev_start = 0 r = [] for s in slices_: start = prev_start stop = start + (s.stop - s.start) r.append(slice(start, stop)) prev_start = stop return r def get_size_from_sliced_dimension(length, slice_): """ From a given array size, returns the size of the array once it is accessed using a slice. Examples -------- If data.shape = (3500, 2160, 2560) get_size_from_sliced_dimension(data.shape[0], None) returns 3500 get_size_from_sliced_dimension(data.shape[0], slice(100, 200)) returns 100 """ return np.arange(length)[slice_].size def get_shape_from_sliced_dims(shape, slices_): """ Same as get_size_from_sliced_dimension() but in 3D """ return tuple(get_size_from_sliced_dimension(length, slice_) for length, slice_ in zip(shape, slices_)) def get_available_threads(): try: n_threads = len(os.sched_getaffinity(0)) except AttributeError: n_threads = int(os.environ.get("SLURM_CPUS_PER_TASK", os.cpu_count())) return n_threads def list_match_queries(available, queries): """ Given a list of strings, return all items matching any of one elements of "queries" """ matches = [] for a in available: for q in queries: if fnmatch(a, q): matches.append(a) return matches def is_writeable(location): """ Return True if a file/location is writeable. """ return os.access(location, os.W_OK) def is_int(num, eps=1e-7): return abs(num - int(num)) < eps def is_scalar(stuff): if isinstance(stuff, str): return False return np.isscalar(stuff) def _sizeof(Type): """ return the size (in bytes) of a scalar type, like the C behavior """ return np.dtype(Type).itemsize class _Default_format(dict): """ https://docs.python.org/3/library/stdtypes.html """ def __missing__(self, key): return key def safe_format(str_, **kwargs): """ Alternative to str.format(), but does not throw a KeyError when fields are missing. 
""" return str_.format_map(_Default_format(**kwargs)) def get_ftype(url): """ return supposed filetype of an url """ if hasattr(url, "file_path"): return os.path.splitext(url.file_path())[-1].replace(".", "") else: return os.path.splitext(url)[-1].replace(".", "") def get_2D_3D_shape(shape): if len(shape) == 2: return (1,) + shape return shape def get_subregion(sub_region): ymin, ymax = None, None if sub_region is None: xmin, xmax, ymin, ymax = None, None, None, None elif len(sub_region) == 2: first_part, second_part = sub_region if np.iterable(first_part) and np.iterable(second_part): xmin, xmax = first_part ymin, ymax = second_part else: xmin, xmax = sub_region elif len(sub_region) == 4: xmin, xmax, ymin, ymax = sub_region else: raise ValueError("Expected parameter in the form (a, b, c, d) or ((a, b), (c, d))") return xmin, xmax, ymin, ymax def get_3D_subregion(sub_region): if sub_region is None: xmin, xmax, ymin, ymax, zmin, zmax = None, None, None, None, None, None elif len(sub_region) == 3: first_part, second_part, third_part = sub_region xmin, xmax = first_part ymin, ymax = second_part zmin, zmax = third_part elif len(sub_region) == 6: xmin, xmax, ymin, ymax, zmin, zmax = sub_region else: raise ValueError( "Expected parameter in the form (xmin, xmax, ymin, ymax, zmin, zmax) or ((xmin, xmax), (ymin, ymax), (zmin, zmax))" ) return xmin, xmax, ymin, ymax, zmin, zmax def to_3D_array(arr): """ Turn an array to a (C-Contiguous) 3D array with the layout (n_arrays, n_y, n_x). """ if arr.ndim == 3: return arr return np.tile(arr, (1, 1, 1)) def view_as_images_stack(img): """ View an image (2D array) as a stack of one image (3D array). No data is duplicated. """ return img.reshape((1,) + img.shape) def rescale_integers(items, new_tot): """ " From a given sequence of integers, create a new sequence where the sum of all items must be equal to "new_tot". The relative contribution of each item to the new total is approximately kept. Parameters ---------- items: array-like Sequence of integers new_tot: int Integer indicating that the sum of the new array must be equal to this value """ cur_items = np.array(items) new_items = np.ceil(cur_items / cur_items.sum() * new_tot).astype(np.int64) excess = new_items.sum() - new_tot i = 0 while excess > 0: ind = i % new_items.size if cur_items[ind] > 0: new_items[ind] -= 1 excess -= 1 i += 1 return new_items.tolist() def merged_shape(shapes, axis=0): n_img = sum(shape[axis] for shape in shapes) if axis == 0: return (n_img,) + shapes[0][1:] elif axis == 1: return (shapes[0][0], n_img, shapes[0][2]) elif axis == 2: return shapes[0][:2] + (n_img,) def is_device_backend(backend): return backend.lower() in ["cuda", "opencl"] def get_decay(curve, cutoff=1e3, vmax=None): """ Assuming a decreasing curve, get the first point below a certain threshold. Parameters ---------- curve: numpy.ndarray A 1D array cutoff: float, optional Threshold. Default is 1000. vmax: float, optional Curve maximum value """ if vmax is None: vmax = curve.max() return np.argmax(np.abs(curve) < vmax / cutoff) @lru_cache(maxsize=1) def generate_powers(): """ Generate a list of powers of [2, 3, 5, 7], up to (2**15)*(3**9)*(5**6)*(7**5). 
""" primes = [2, 3, 5, 7] maxpow = {2: 15, 3: 9, 5: 6, 7: 5} valuations = [] for prime in primes: # disallow any odd number (for R2C transform), and any number # not multiple of 4 (Ram-Lak filter behaves strangely when # dwidth_padded/2 is not even) minval = 2 if prime == 2 else 0 valuations.append(range(minval, maxpow[prime] + 1)) powers = product(*valuations) res = [] for pw in powers: res.append(np.prod(list(map(lambda x: x[0] ** x[1], zip(primes, pw))))) return np.unique(res) def calc_padding_lengths1D(length, length_padded): """ Compute the padding lengths at both side along one dimension. Parameters ---------- length: int Number of elements along one dimension of the original array length_padded: tuple Number of elements along one dimension of the padded array Returns ------- pad_lengths: tuple A tuple under the form (padding_left, padding_right). These are the lengths needed to pad the original array. """ pad_left = (length_padded - length) // 2 pad_right = length_padded - length - pad_left return (pad_left, pad_right) def calc_padding_lengths(shape, shape_padded): """ Multi-dimensional version of calc_padding_lengths1D. Please refer to the documentation of calc_padding_lengths1D. """ assert len(shape) == len(shape_padded) padding_lengths = [] for dim_len, dim_len_padded in zip(shape, shape_padded): pad0, pad1 = calc_padding_lengths1D(dim_len, dim_len_padded) padding_lengths.append((pad0, pad1)) return tuple(padding_lengths) def partition_dict(dict_, n_partitions): keys = np.sort(list(dict_.keys())) res = [] for keys_arr in np.array_split(keys, n_partitions): d = {} for key in keys_arr: d[key] = dict_[key] res.append(d) return res def first_dict_item(dict_): keys = sorted(list(dict_.keys())) return dict_[keys[0]] def subsample_dict(dic, subsampling_factor): """ Subsample a dict where keys are integers. """ res = {} indices = sorted(dic.keys()) for i in indices[::subsampling_factor]: res[i] = dic[i] return res def compare_dicts(dic1, dic2): """ Compare two dictionaries. Return None if and only iff the dictionaries are the same. Parameters ---------- dic1: dict First dictionary dic2: dict Second dictionary Returns ------- res: result which can be the following: - None: it means that dictionaries are the same - empty string (""): the dictionaries do not have the same keys - nonempty string: path to the first differing items """ if set(dic1.keys()) != set(dic2.keys()): return "" for key, val1 in dic1.items(): val2 = dic2[key] if isinstance(val1, dict): res = compare_dicts(val1, val2) if res is not None: return posixpath.join(key, res) # str elif isinstance(val1, str): if val1 != val2: return key # Scalars elif np.isscalar(val1): if not np.isclose(val1, val2): return key # NoneType elif val1 is None: if val2 is not None: return key # Array-like elif np.iterable(val1): arr1 = np.array(val1) arr2 = np.array(val2) if arr1.ndim != arr2.ndim or arr1.dtype != arr2.dtype or not np.allclose(arr1, arr2): return key else: raise ValueError("Don't know what to do with type %s" % str(type(val1))) return None def remove_items_from_list(list_, items_to_remove): """ Remove items from a list and return the removed elements. 
Parameters ---------- list_: list List containing items to remove items_to_remove: list List of items to remove Returns -------- reduced_list: list List with removed items removed_items: dict Dictionary where the keys are the indices of removed items, and values are the items """ removed_items = {} res = [] for i, val in enumerate(list_): if val in items_to_remove: removed_items[i] = val else: res.append(val) return res, removed_items def restore_items_in_list(list_, removed_items): """ Undo the effect of the function `remove_items_from_list` Parameters ---------- list_: list List where items were removed removed_items: dict Dictionary where the keys are the indices of removed items, and values are the items """ for idx, val in removed_items.items(): list_.insert(idx, val) def check_supported(param_value, available, param_desc): if param_value not in available: raise ValueError("Unsupported %s '%s'. Available are: %s" % (param_desc, param_value, str(available))) def check_supported_enum(param_value, enum_cls, param_desc): available = enum_cls.values() return check_supported(param_value, available, param_desc) def check_shape(shape, expected_shape, name): if shape != expected_shape: raise ValueError("Expected %s shape %s but got %s" % (name, str(expected_shape), str(shape))) def copy_dict_items(dict_, keys): """ Perform a shallow copy of a subset of a dictionary. The subset if done by a list of keys. """ res = {key: dict_[key] for key in keys} return res def recursive_copy_dict(dict_): """ Perform a shallow copy of a dictionary of dictionaries. This is NOT a deep copy ! Only reference to objects are kept. """ if not (isinstance(dict_, dict)): res = dict_ else: res = {k: recursive_copy_dict(v) for k, v in dict_.items()} return res def subdivide_into_overlapping_segment(N, window_width, half_overlap): """ Divide a segment into a number of possibly-overlapping sub-segments. Parameters ---------- N: int Total segment length window_width: int Length of each segment half_overlap: int Half-length of the overlap between each sub-segment. Returns ------- segments: list A list where each item is in the form (left_margin_start, inner_segment_start, inner_segment_end, right_margin_end) """ if half_overlap > 0 and half_overlap >= window_width // 2: raise ValueError("overlap must be smaller than window_width") w_in = window_width - 2 * half_overlap n_segments = N // w_in inner_start = w_in * np.arange(n_segments) inner_end = w_in * (np.arange(n_segments) + 1) margin_left_start = np.maximum(inner_start - half_overlap, 0) margin_right_end = np.minimum(inner_end + half_overlap, N) segments = [ (left_start, i_start, i_end, right_end) for left_start, i_start, i_end, right_end in zip( margin_left_start.tolist(), inner_start.tolist(), inner_end.tolist(), margin_right_end.tolist() ) ] if N % w_in: # additional sub-segment new_margin_left_start = inner_end[-1] - half_overlap new_inner_start = inner_end[-1] new_inner_end = N segments.append((new_margin_left_start, new_inner_start, new_inner_end, new_inner_end)) return segments def get_num_threads(n=None): """ Get a number of threads (ex. to be used by fftw). If the argument is None, returns the total number of CPU threads. If the argument is negative, the total number of available threads plus this number is returned. 
Parameters ----------- n: int, optional - If a positive integer `n` is provided, then `n` threads are used - If a negative integer `n` is provided, then `n_avail + n` threads are used (so -1 to use all available threads minus one) """ n_avail = get_available_threads() if n is None or n == 0: return n_avail if n < 0: return max(1, n_avail + n) else: return min(n_avail, n) class DictToObj(object): """Utility class to transform a dictionary into an object with the dictionary items as members. Example: >>> a = DictToObj(dict(i=1, j=1)) >>> a.j + a.i """ def __init__(self, dictio): self.__dict__ = dictio def remove_parenthesis_or_brackets(input_str): """ Strip one pair of enclosing parentheses or square brackets from a string, if present. """ if input_str.startswith("(") and input_str.endswith(")") or input_str.startswith("[") and input_str.endswith("]"): input_str = input_str[1:-1] return input_str def filter_str_def(elmt): """Clean up a string element read from a text file: remove surrounding spaces and quote characters that may have been left on the left or right.""" if elmt is None: return None assert isinstance(elmt, str) elmt = elmt.lstrip(" ").rstrip(" ") elmt = elmt.lstrip("'").lstrip('"') elmt = elmt.rstrip("'").rstrip('"') elmt = elmt.lstrip(" ").rstrip(" ") for character in ("'", '"'): if elmt.startswith(character) and elmt.endswith(character): elmt = elmt[1:-1] return elmt def convert_str_to_tuple(input_str: str, none_if_empty: bool = False): """ :param str input_str: string to convert :param bool none_if_empty: if True and the conversion yields an empty tuple, return None instead of an empty tuple """ if isinstance(input_str, tuple): return input_str if not isinstance(input_str, str): raise TypeError("input_str should be a string not {}, {}".format(type(input_str), input_str)) input_str = input_str.lstrip(" ").lstrip("(").lstrip("[").lstrip(" ").rstrip(" ") input_str = remove_parenthesis_or_brackets(input_str) input_str = input_str.replace("\n", ",") elmts = input_str.split(",") elmts = [filter_str_def(elmt) for elmt in elmts] rm_empty_str = lambda a: a != "" elmts = list(filter(rm_empty_str, elmts)) if none_if_empty and len(elmts) == 0: return None else: return tuple(elmts) def concatenate_dict(dict_1, dict_2) -> dict: """Merge two dictionaries whose values are themselves dictionaries: for keys present in both, the value dictionaries are updated (concatenated) rather than replaced.""" res = dict_1.copy() for key in dict_2: if key in dict_1: res[key].update(dict_2[key]) else: res[key] = dict_2[key] return res class BaseClassError: def __init__(self, *args, **kwargs): raise ValueError("Base class") def MissingComponentError(msg): class MissingComponentCls: def __init__(self, *args, **kwargs): raise RuntimeError(msg) return MissingComponentCls # ------------------------------------------------------------------------------ # ------------------------ Image (move elsewhere ?)
---------------------------- # ------------------------------------------------------------------------------ def generate_coords(img_shp, center=None): l_r, l_c = float(img_shp[0]), float(img_shp[1]) R, C = np.mgrid[:l_r, :l_c] # np.indices is faster if center is None: center0, center1 = l_r / 2.0, l_c / 2.0 else: center0, center1 = center R += 0.5 - center0 C += 0.5 - center1 return R, C def clip_circle(img, center=None, radius=None, out_value=0): R, C = generate_coords(img.shape, center) if radius is None: radius = R.shape[-1] // 2 M = R**2 + C**2 res = np.zeros_like(img) if out_value != 0: res.fill(out_value) res[M < radius**2] = img[M < radius**2] return res def extend_image_onepixel(img): # extend of one pixel img2 = np.zeros((img.shape[0] + 2, img.shape[1] + 2), dtype=img.dtype) img2[0, 1:-1] = img[0] img2[-1, 1:-1] = img[-1] img2[1:-1, 0] = img[:, 0] img2[1:-1, -1] = img[:, -1] # middle img2[1:-1, 1:-1] = img # corners img2[0, 0] = img[0, 0] img2[-1, 0] = img[-1, 0] img2[0, -1] = img[0, -1] img2[-1, -1] = img[-1, -1] return img2 def median2(img): """ 3x3 median filter for 2D arrays, with "reflect" boundary mode. Roughly same speed as scipy median filter, but more memory demanding. """ img2 = extend_image_onepixel(img) I = np.array( [ img2[0:-2, 0:-2], img2[0:-2, 1:-1], img2[0:-2, 2:], img2[1:-1, 0:-2], img2[1:-1, 1:-1], img2[1:-1, 2:], img2[2:, 0:-2], img2[2:, 1:-1], img2[2:, 2:], ] ) return np.median(I, axis=0) # ------------------------------------------------------------------------------ # ---------------------------- Decorators -------------------------------------- # ------------------------------------------------------------------------------ _warnings = {} def measure_time(func): def wrapper(*args, **kwargs): t0 = time() res = func(*args, **kwargs) el = time() - t0 return el, res return wrapper def wip(func): def wrapper(*args, **kwargs): func_name = func.__name__ if func_name not in _warnings: _warnings[func_name] = 1 print("Warning: function %s is a work in progress, it is likely to change in the future") return func(*args, **kwargs) return wrapper def warning(msg): def decorator(func): def wrapper(*args, **kwargs): func_name = func.__name__ if func_name not in _warnings: _warnings[func_name] = 1 print(msg) res = func(*args, **kwargs) return res return wrapper return decorator def deprecated(msg, do_print=False): def decorator(func): def wrapper(*args, **kwargs): deprecation_warning(msg, do_print=do_print, func_name=func.__name__) res = func(*args, **kwargs) return res return wrapper return decorator def deprecated_class(msg, do_print=False): def decorator(cls): class wrapper: def __init__(self, *args, **kwargs): deprecation_warning(msg, do_print=do_print, func_name=cls.__name__) self.wrapped = cls(*args, **kwargs) # This is so ugly :-( def __getattr__(self, name): return getattr(self.wrapped, name) return wrapper return decorator def deprecation_warning(message, do_print=True, func_name=None): func_name_msg = str("%s: " % func_name) if func_name is not None else "" func_name = func_name or "None" if _warnings.get(func_name, False): return warnings.warn(message, DeprecationWarning) if do_print: print("Deprecation warning: %s%s" % (func_name_msg, message)) _warnings[func_name] = 1 def _docstring(dest, origin): """Implementation of docstring decorator. It patches dest.__doc__. 
""" if not isinstance(dest, type) and isinstance(origin, type): # func is not a class, but origin is, get the method with the same name try: origin = getattr(origin, dest.__name__) except AttributeError: raise ValueError("origin class has no %s method" % dest.__name__) dest.__doc__ = origin.__doc__ return dest def docstring(origin): """Decorator to initialize the docstring from another source. This is useful to duplicate a docstring for inheritance and composition. If origin is a method or a function, it copies its docstring. If origin is a class, the docstring is copied from the method of that class which has the same name as the method/function being decorated. :param origin: The method, function or class from which to get the docstring :raises ValueError: If the origin class has not method n case the """ return partial(_docstring, origin=origin) from warnings import catch_warnings # FIX for python < 3.11 # catch_warnings() does not have "action=XX" kwarg for python < 3.11 from sys import version_info if version_info.major == 3 and version_info.minor < 11: def dummy(*args, **kwargs): pass catch_warnings_old = catch_warnings def catch_warnings(*args, **kwargs): # pylint: disable=E0102 action = kwargs.pop("action", None) return catch_warnings_old(record=(dummy if action == "ignore" else False)) # --- ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1734442985.4967566 nabu-2024.2.1/nabu.egg-info/0000755000175000017500000000000014730277751014721 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734442985.0 nabu-2024.2.1/nabu.egg-info/PKG-INFO0000644000175000017500000001066714730277751016030 0ustar00pierrepierreMetadata-Version: 2.1 Name: nabu Version: 2024.2.1 Summary: Nabu - Tomography software Author-email: Pierre Paleo , Henri Payno , Alessandro Mirone , Jérôme Lesaint Maintainer-email: Pierre Paleo License: MIT License Copyright (c) 2020-2024 ESRF Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734442985.0 nabu-2024.2.1/nabu.egg-info/SOURCES.txt0000644000175000017500000002351014730277751016606 0ustar00pierrepierreLICENSE README.md pyproject.toml doc/conf.py doc/create_conf_doc.py doc/get_mathjax.py nabu/__init__.py nabu/tests.py nabu/testutils.py nabu/utils.py nabu.egg-info/PKG-INFO nabu.egg-info/SOURCES.txt nabu.egg-info/dependency_links.txt nabu.egg-info/entry_points.txt nabu.egg-info/requires.txt nabu.egg-info/top_level.txt nabu/app/__init__.py nabu/app/bootstrap.py nabu/app/bootstrap_stitching.py nabu/app/cast_volume.py nabu/app/cli_configs.py nabu/app/compare_volumes.py nabu/app/composite_cor.py nabu/app/correct_rot.py nabu/app/create_distortion_map_from_poly.py nabu/app/diag_to_pix.py nabu/app/diag_to_rot.py nabu/app/double_flatfield.py nabu/app/generate_header.py nabu/app/histogram.py nabu/app/multicor.py nabu/app/nx_z_splitter.py nabu/app/parse_reconstruction_log.py nabu/app/prepare_weights_double.py nabu/app/reconstruct.py nabu/app/reconstruct_helical.py nabu/app/reduce_dark_flat.py nabu/app/rotate.py nabu/app/shrink_dataset.py nabu/app/stitching.py nabu/app/utils.py nabu/app/validator.py nabu/app/tests/__init__.py nabu/app/tests/test_reduce_dark_flat.py nabu/cuda/__init__.py nabu/cuda/convolution.py nabu/cuda/fft.py nabu/cuda/kernel.py nabu/cuda/medfilt.py nabu/cuda/padding.py nabu/cuda/processing.py nabu/cuda/utils.py nabu/cuda/src/ElementOp.cu nabu/cuda/src/backproj.cu nabu/cuda/src/backproj_polar.cu nabu/cuda/src/boundary.h nabu/cuda/src/cone.cu nabu/cuda/src/convolution.cu nabu/cuda/src/dfi_fftshift.cu nabu/cuda/src/flatfield.cu nabu/cuda/src/fourier_wavelets.cu nabu/cuda/src/halftomo.cu nabu/cuda/src/helical_padding.cu nabu/cuda/src/hierarchical_backproj.cu nabu/cuda/src/histogram.cu nabu/cuda/src/interpolation.cu nabu/cuda/src/medfilt.cu nabu/cuda/src/normalization.cu nabu/cuda/src/padding.cu nabu/cuda/src/proj.cu nabu/cuda/src/rotation.cu nabu/cuda/src/transpose.cu nabu/cuda/tests/__init__.py nabu/estimation/__init__.py nabu/estimation/alignment.py nabu/estimation/cor.py nabu/estimation/cor_sino.py nabu/estimation/distortion.py nabu/estimation/focus.py nabu/estimation/tilt.py nabu/estimation/translation.py nabu/estimation/utils.py nabu/estimation/tests/__init__.py nabu/estimation/tests/test_alignment.py nabu/estimation/tests/test_cor.py nabu/estimation/tests/test_focus.py nabu/estimation/tests/test_tilt.py nabu/estimation/tests/test_translation.py nabu/io/__init__.py nabu/io/cast_volume.py nabu/io/detector_distortion.py nabu/io/reader.py nabu/io/reader_helical.py nabu/io/utils.py nabu/io/writer.py nabu/io/tests/__init__.py nabu/io/tests/test_cast_volume.py nabu/io/tests/test_detector_distortion.py nabu/io/tests/test_readers.py nabu/io/tests/test_writers.py nabu/misc/__init__.py nabu/misc/binning.py nabu/misc/fftshift.py nabu/misc/filters.py nabu/misc/fourier_filters.py nabu/misc/histogram.py nabu/misc/histogram_cuda.py nabu/misc/kernel_base.py nabu/misc/padding.py nabu/misc/padding_base.py nabu/misc/processing_base.py nabu/misc/rotation.py nabu/misc/rotation_cuda.py nabu/misc/transpose.py nabu/misc/unsharp.py nabu/misc/unsharp_cuda.py nabu/misc/unsharp_opencl.py nabu/misc/utils.py nabu/misc/tests/__init__.py nabu/misc/tests/test_binning.py nabu/misc/tests/test_interpolation.py nabu/opencl/__init__.py nabu/opencl/fft.py nabu/opencl/kernel.py nabu/opencl/memcpy.py nabu/opencl/padding.py nabu/opencl/processing.py nabu/opencl/utils.py nabu/opencl/src/ElementOp.cl 
nabu/opencl/src/backproj.cl nabu/opencl/src/fftshift.cl nabu/opencl/src/halftomo.cl nabu/opencl/src/padding.cl nabu/opencl/src/roll.cl nabu/opencl/src/transpose.cl nabu/opencl/tests/__init__.py nabu/pipeline/__init__.py nabu/pipeline/config.py nabu/pipeline/config_validators.py nabu/pipeline/datadump.py nabu/pipeline/dataset_validator.py nabu/pipeline/detector_distortion_provider.py nabu/pipeline/estimators.py nabu/pipeline/params.py nabu/pipeline/processconfig.py nabu/pipeline/reader.py nabu/pipeline/utils.py nabu/pipeline/writer.py nabu/pipeline/fullfield/__init__.py nabu/pipeline/fullfield/chunked.py nabu/pipeline/fullfield/chunked_cuda.py nabu/pipeline/fullfield/computations.py nabu/pipeline/fullfield/dataset_validator.py nabu/pipeline/fullfield/nabu_config.py nabu/pipeline/fullfield/processconfig.py nabu/pipeline/fullfield/reconstruction.py nabu/pipeline/helical/__init__.py nabu/pipeline/helical/dataset_validator.py nabu/pipeline/helical/fbp.py nabu/pipeline/helical/filtering.py nabu/pipeline/helical/gridded_accumulator.py nabu/pipeline/helical/helical_chunked_regridded.py nabu/pipeline/helical/helical_chunked_regridded_cuda.py nabu/pipeline/helical/helical_reconstruction.py nabu/pipeline/helical/helical_utils.py nabu/pipeline/helical/nabu_config.py nabu/pipeline/helical/processconfig.py nabu/pipeline/helical/span_strategy.py nabu/pipeline/helical/weight_balancer.py nabu/pipeline/helical/tests/__init__.py nabu/pipeline/tests/test_estimators.py nabu/pipeline/xrdct/__init__.py nabu/preproc/__init__.py nabu/preproc/alignment.py nabu/preproc/ccd.py nabu/preproc/ccd_cuda.py nabu/preproc/ctf.py nabu/preproc/ctf_cuda.py nabu/preproc/distortion.py nabu/preproc/double_flatfield.py nabu/preproc/double_flatfield_cuda.py nabu/preproc/double_flatfield_variable_region.py nabu/preproc/flatfield.py nabu/preproc/flatfield_cuda.py nabu/preproc/flatfield_variable_region.py nabu/preproc/phase.py nabu/preproc/phase_cuda.py nabu/preproc/shift.py nabu/preproc/shift_cuda.py nabu/preproc/tests/__init__.py nabu/preproc/tests/test_ccd_corr.py nabu/preproc/tests/test_ctf.py nabu/preproc/tests/test_double_flatfield.py nabu/preproc/tests/test_flatfield.py nabu/preproc/tests/test_paganin.py nabu/preproc/tests/test_vshift.py nabu/processing/__init__.py nabu/processing/azim.py nabu/processing/convolution_cuda.py nabu/processing/fft_base.py nabu/processing/fft_cuda.py nabu/processing/fft_opencl.py nabu/processing/fftshift.py nabu/processing/histogram.py nabu/processing/histogram_cuda.py nabu/processing/kernel_base.py nabu/processing/medfilt_cuda.py nabu/processing/muladd.py nabu/processing/muladd_cuda.py nabu/processing/padding_base.py nabu/processing/padding_cuda.py nabu/processing/padding_opencl.py nabu/processing/processing_base.py nabu/processing/roll_opencl.py nabu/processing/rotation.py nabu/processing/rotation_cuda.py nabu/processing/transpose.py nabu/processing/unsharp.py nabu/processing/unsharp_cuda.py nabu/processing/unsharp_opencl.py nabu/processing/tests/__init__.py nabu/processing/tests/test_fft.py nabu/processing/tests/test_fftshift.py nabu/processing/tests/test_histogram.py nabu/processing/tests/test_medfilt.py nabu/processing/tests/test_muladd.py nabu/processing/tests/test_padding.py nabu/processing/tests/test_roll.py nabu/processing/tests/test_rotation.py nabu/processing/tests/test_transpose.py nabu/processing/tests/test_unsharp.py nabu/reconstruction/__init__.py nabu/reconstruction/cone.py nabu/reconstruction/fbp.py nabu/reconstruction/fbp_base.py nabu/reconstruction/fbp_opencl.py 
nabu/reconstruction/filtering.py nabu/reconstruction/filtering_cuda.py nabu/reconstruction/filtering_opencl.py nabu/reconstruction/hbp.py nabu/reconstruction/mlem.py nabu/reconstruction/projection.py nabu/reconstruction/reconstructor.py nabu/reconstruction/reconstructor_cuda.py nabu/reconstruction/rings.py nabu/reconstruction/rings_cuda.py nabu/reconstruction/sinogram.py nabu/reconstruction/sinogram_cuda.py nabu/reconstruction/sinogram_opencl.py nabu/reconstruction/tests/__init__.py nabu/reconstruction/tests/test_cone.py nabu/reconstruction/tests/test_deringer.py nabu/reconstruction/tests/test_fbp.py nabu/reconstruction/tests/test_filtering.py nabu/reconstruction/tests/test_halftomo.py nabu/reconstruction/tests/test_mlem.py nabu/reconstruction/tests/test_projector.py nabu/reconstruction/tests/test_reconstructor.py nabu/reconstruction/tests/test_sino_normalization.py nabu/resources/__init__.py nabu/resources/cor.py nabu/resources/dataset_analyzer.py nabu/resources/gpu.py nabu/resources/logger.py nabu/resources/nxflatfield.py nabu/resources/utils.py nabu/resources/cli/__init__.py nabu/resources/templates/__init__.py nabu/resources/templates/bm05_pag.conf nabu/resources/templates/id16_ctf.conf nabu/resources/templates/id16_holo.conf nabu/resources/templates/id16a_fluo.conf nabu/resources/templates/id19_pag.conf nabu/resources/tests/__init__.py nabu/resources/tests/test_extract.py nabu/resources/tests/test_nxflatfield.py nabu/resources/tests/test_units.py nabu/stitching/__init__.py nabu/stitching/alignment.py nabu/stitching/config.py nabu/stitching/definitions.py nabu/stitching/frame_composition.py nabu/stitching/overlap.py nabu/stitching/sample_normalization.py nabu/stitching/single_axis_stitching.py nabu/stitching/slurm_utils.py nabu/stitching/stitcher_2D.py nabu/stitching/y_stitching.py nabu/stitching/z_stitching.py nabu/stitching/stitcher/__init__.py nabu/stitching/stitcher/base.py nabu/stitching/stitcher/post_processing.py nabu/stitching/stitcher/pre_processing.py nabu/stitching/stitcher/single_axis.py nabu/stitching/stitcher/stitcher.py nabu/stitching/stitcher/y_stitcher.py nabu/stitching/stitcher/z_stitcher.py nabu/stitching/stitcher/dumper/__init__.py nabu/stitching/stitcher/dumper/base.py nabu/stitching/stitcher/dumper/postprocessing.py nabu/stitching/stitcher/dumper/preprocessing.py nabu/stitching/tests/__init__.py nabu/stitching/tests/test_alignment.py nabu/stitching/tests/test_config.py nabu/stitching/tests/test_frame_composition.py nabu/stitching/tests/test_overlap.py nabu/stitching/tests/test_sample_normalization.py nabu/stitching/tests/test_slurm_utils.py nabu/stitching/tests/test_utils.py nabu/stitching/tests/test_y_preprocessing_stitching.py nabu/stitching/tests/test_z_postprocessing_stitching.py nabu/stitching/tests/test_z_preprocessing_stitching.py nabu/stitching/utils/__init__.py nabu/stitching/utils/post_processing.py nabu/stitching/utils/utils.py nabu/stitching/utils/tests/test_post-processing.py nabu/thirdparty/__init__.py nabu/thirdparty/algotom_convert_sino.py nabu/thirdparty/pore3d_deringer_munch.py nabu/thirdparty/tomocupy_remove_stripe.py nabu/thirdparty/tomopy_phase.py nabu/thirdparty/tomwer_load_flats_darks.py././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734442985.0 nabu-2024.2.1/nabu.egg-info/dependency_links.txt0000644000175000017500000000000114730277751020767 0ustar00pierrepierre ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734442985.0 
nabu-2024.2.1/nabu.egg-info/entry_points.txt0000644000175000017500000000241014730277751020214 0ustar00pierrepierre[console_scripts] nabu = nabu.app.reconstruct:main nabu-cast = nabu.app.cast_volume:main nabu-compare-volumes = nabu.app.compare_volumes:compare_volumes_cli nabu-composite-cor = nabu.app.composite_cor:main nabu-config = nabu.app.bootstrap:bootstrap nabu-diag2pix = nabu.app.diag_to_pix:main nabu-diag2rot = nabu.app.diag_to_rot:main nabu-display-timings = nabu.app.parse_reconstruction_log:parse_reclog_cli nabu-double-flatfield = nabu.app.double_flatfield:dff_cli nabu-generate-info = nabu.app.generate_header:generate_merged_info_file nabu-helical = nabu.app.reconstruct_helical:main_helical nabu-helical-correct-rot = nabu.app.correct_rot:main nabu-helical-prepare-weights-double = nabu.app.prepare_weights_double:main nabu-histogram = nabu.app.histogram:histogram_cli nabu-multicor = nabu.app.multicor:main nabu-poly2map = nabu.app.create_distortion_map_from_poly:create_distortion_maps_entry_point nabu-reduce-dark-flat = nabu.app.reduce_dark_flat:main nabu-rotate = nabu.app.rotate:rotate_cli nabu-shrink-dataset = nabu.app.shrink_dataset:shrink_cli nabu-stitching = nabu.app.stitching:main nabu-stitching-config = nabu.app.bootstrap_stitching:bootstrap_stitching nabu-test = nabu.tests:nabu_test nabu-validator = nabu.app.validator:main nabu-zsplit = nabu.app.nx_z_splitter:zsplit ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734442985.0 nabu-2024.2.1/nabu.egg-info/requires.txt0000644000175000017500000000045314730277751017323 0ustar00pierrepierrenumpy<2,>1.9.0 scipy h5py>=3.0 silx>=0.15.0 tomoscan>=2.1.0 psutil pytest tifffile tqdm [doc] sphinx cloud_sptheme myst-parser nbsphinx [full] scikit-image PyWavelets glymur pycuda!=2024.1.1 scikit-cuda pycudwt sluurp>=0.3 pyvkfft [full_nocuda] scikit-image PyWavelets glymur sluurp>=0.3 pyvkfft ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734442985.0 nabu-2024.2.1/nabu.egg-info/top_level.txt0000644000175000017500000000002614730277751017451 0ustar00pierrepierredist doc nabu scripts ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1734019212.0 nabu-2024.2.1/pyproject.toml0000644000175000017500000001037314726604214015213 0ustar00pierrepierre[build-system] requires = ["setuptools>=61.0", "wheel"] build-backend = "setuptools.build_meta" [project] name = "nabu" authors = [ {name = "Pierre Paleo", email = "pierre.paleo@esrf.fr"}, {name = "Henri Payno", email = "henri.payno@esrf.fr"}, {name = "Alessandro Mirone", email = "mirone@esrf.fr"}, {name = "Jérôme Lesaint", email = "jerome.lesaint@esrf.fr"}, ] maintainers = [ {name = "Pierre Paleo", email = "pierre.paleo@esrf.fr"} ] dynamic = ["version"] description = "Nabu - Tomography software" readme = "README.md" requires-python = ">=3.7" keywords = [ "tomography", "reconstruction", "X-ray imaging", "synchrotron radiation", "High Performance Computing", "Parallel geometry", "Conebeam geometry", "Helical geometry", "Ring artefact correction", "Geometric calibration", ] license = {file = "LICENSE"} classifiers = [ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", "Intended Audience :: Science/Research", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Environment :: Console", "License :: OSI Approved :: MIT License", "Operating 
System :: Unix", "Operating System :: MacOS :: MacOS X", "Operating System :: POSIX", "Topic :: Scientific/Engineering :: Physics", "Topic :: Scientific/Engineering :: Medical Science Apps.", ] dependencies = [ "numpy > 1.9.0, < 2", "scipy", "h5py>=3.0", "silx >= 0.15.0", "tomoscan >= 2.1.0", "psutil", "pytest", "tifffile", "tqdm", ] #packages = find [project.urls] Homepage = "https://gitlab.esrf.fr/tomotools/nabu" Documentation = "http://www.silx.org/pub/nabu/doc" Repository = "https://gitlab.esrf.fr/tomotools/nabu/-/releases" Changelog = "https://gitlab.esrf.fr/tomotools/nabu/-/blob/master/CHANGELOG.md" [project.optional-dependencies] full = [ "scikit-image", "PyWavelets", "glymur", "pycuda!=2024.1.1", "scikit-cuda", "pycudwt", "sluurp >=0.3", "pyvkfft", ] full_nocuda = [ "scikit-image", "PyWavelets", "glymur", "sluurp >=0.3", "pyvkfft", ] doc = [ "sphinx", "cloud_sptheme", "myst-parser", "nbsphinx", ] [project.scripts] nabu = "nabu.app.reconstruct:main" nabu-config = "nabu.app.bootstrap:bootstrap" nabu-test = "nabu.tests:nabu_test" nabu-histogram = "nabu.app.histogram:histogram_cli" nabu-zsplit = "nabu.app.nx_z_splitter:zsplit" nabu-rotate = "nabu.app.rotate:rotate_cli" nabu-double-flatfield = "nabu.app.double_flatfield:dff_cli" nabu-generate-info = "nabu.app.generate_header:generate_merged_info_file" nabu-validator = "nabu.app.validator:main" nabu-helical = "nabu.app.reconstruct_helical:main_helical" nabu-helical-prepare-weights-double = "nabu.app.prepare_weights_double:main" nabu-stitching-config = "nabu.app.bootstrap_stitching:bootstrap_stitching" nabu-stitching = "nabu.app.stitching:main" nabu-cast = "nabu.app.cast_volume:main" nabu-compare-volumes = "nabu.app.compare_volumes:compare_volumes_cli" nabu-shrink-dataset = "nabu.app.shrink_dataset:shrink_cli" nabu-composite-cor = "nabu.app.composite_cor:main" nabu-poly2map = "nabu.app.create_distortion_map_from_poly:create_distortion_maps_entry_point" nabu-diag2pix = "nabu.app.diag_to_pix:main" nabu-diag2rot = "nabu.app.diag_to_rot:main" nabu-helical-correct-rot = "nabu.app.correct_rot:main" nabu-multicor = "nabu.app.multicor:main" nabu-reduce-dark-flat = "nabu.app.reduce_dark_flat:main" nabu-display-timings = "nabu.app.parse_reconstruction_log:parse_reclog_cli" [tool.setuptools.dynamic] version = {attr = "nabu.version"} #readme = {file = ["README.md"]} [tool.setuptools.packages.find] where = ["."] # list of folders that contain the packages (["."] by default) # include = ["nabu*"] # package names should match these glob patterns (["*"] by default) exclude = ["sandbox","build*"] # exclude packages matching these glob patterns (empty by default) # namespaces = false # to disable scanning PEP 420 namespaces (true by default) [tool.setuptools.package-data] "nabu.cuda" = ["src/*.cu", "src/*.h"] "nabu.opencl" = ["src/*.cl", "src/*.h"] "nabu.resources" = ["templates/*.conf"] ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1734442985.524757 nabu-2024.2.1/setup.cfg0000644000175000017500000000004614730277752014124 0ustar00pierrepierre[egg_info] tag_build = tag_date = 0