././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4727333 nabu-2023.1.1/0000755000175000017500000000000000000000000012224 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1586329326.0 nabu-2023.1.1/LICENSE0000644000175000017500000000204500000000000013232 0ustar00pierrepierreMIT License Copyright (c) 2020 ESRF Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4727333 nabu-2023.1.1/PKG-INFO0000644000175000017500000000041600000000000013322 0ustar00pierrepierreMetadata-Version: 2.1 Name: nabu Version: 2023.1.1 Summary: Nabu - Tomography software Author: Pierre Paleo Author-email: pierre.paleo@esrf.fr Maintainer: Pierre Paleo Maintainer-email: pierre.paleo@esrf.fr Provides-Extra: full Provides-Extra: doc License-File: LICENSE ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1621502133.0 nabu-2023.1.1/README.md0000644000175000017500000000311600000000000013504 0ustar00pierrepierre# Nabu ESRF tomography processing software. ## Installation To install the development version: ```bash pip install [--user] git+https://gitlab.esrf.fr/tomotools/nabu.git ``` To install the stable version: ```bash pip install [--user] nabu ``` ## Usage Nabu can be used in several ways: - As a Python library, by features like `Backprojector`, `FlatField`, etc - As a standalone application with the command line interface - From Tomwer ([https://gitlab.esrf.fr/tomotools/tomwer/](https://gitlab.esrf.fr/tomotools/tomwer/)) To get quickly started, launch: ```bash nabu-config ``` Edit the generated configuration file (`nabu.conf`). Then: ```bash nabu nabu.conf --slice 500-600 ``` will reconstruct the slices 500 to 600, with processing steps depending on `nabu.conf` contents. ## Documentation The documentation can be found on the silx.org page ([https://www.silx.org/pub/nabu/doc](http://www.silx.org/pub/nabu/doc)). The latest documentation built by continuous integration can be found here: [https://tomotools.gitlab-pages.esrf.fr/nabu/](https://tomotools.gitlab-pages.esrf.fr/nabu/) ## Running the tests Once nabu is installed, running ```bash nabu-test ``` will execute all the tests. You can also specify specific module(s) to test, for example: ```bash nabu-test preproc misc ``` You can also provide more `pytest` options, for example increase verbosity with `-v`, exit at the first fail with `-x`, etc. Use `nabu-test --help` for displaying the complete options list. 
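## Using Nabu as a Python library

As mentioned in the Usage section, the pipeline building blocks (readers, flat-field normalization, reconstruction) can also be called directly from Python. The snippet below is a minimal, illustrative sketch using the same classes the command-line tools rely on; the dataset path and the 100-row chunk are placeholders to adapt to your data, and keyword defaults may differ between versions.

```python
from nabu.resources.dataset_analyzer import analyze_dataset
from nabu.resources.nxflatfield import update_dataset_info_flats_darks
from nabu.io.reader import ChunkReader
from nabu.preproc.flatfield import FlatFieldDataUrls

# Gather projections, darks and flats from a HDF5-NeXus dataset (placeholder path)
dataset_info = analyze_dataset("/path/to/scan.nx")
update_dataset_info_flats_darks(dataset_info, flatfield_mode=1)

# Read the first 100 detector rows of every projection
sub_region = (None, None, 0, 100)
reader = ChunkReader(dataset_info.projections, sub_region=sub_region, convert_float=True)
reader.load_files()
radios = reader.files_data  # numpy array of shape (n_angles, 100, n_columns)

# Flat-field normalization (in-place) with the darks/flats found in the dataset
flatfield = FlatFieldDataUrls(
    radios.shape,
    dataset_info.flats,
    dataset_info.darks,
    sorted(dataset_info.projections.keys()),
    sub_region=sub_region,
)
flatfield.normalize_radios(radios)
```

The normalized `radios` array can then be passed to the reconstruction classes (e.g. `Backprojector`); see the documentation linked above for the full API.
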
## Nabu - what's in a name ? Nabu was the Mesopotamian god of literacy, rational arts, scribes and wisdom. ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4527333 nabu-2023.1.1/nabu/0000755000175000017500000000000000000000000013151 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589945.0 nabu-2023.1.1/nabu/__init__.py0000644000175000017500000000035300000000000015263 0ustar00pierrepierre__version__ = "2023.1.1" __nabu_modules__ = [ "app", "cuda", "estimation", "io", "misc", "opencl", "pipeline", "preproc", "reconstruction", "resources", "thirdparty", ] version = __version__ ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4527333 nabu-2023.1.1/nabu/app/0000755000175000017500000000000000000000000013731 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1581878491.0 nabu-2023.1.1/nabu/app/__init__.py0000644000175000017500000000000000000000000016030 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/app/bootstrap.py0000644000175000017500000000624100000000000016323 0ustar00pierrepierrefrom os import path, environ from glob import glob from ..utils import get_folder_path from ..pipeline.config import generate_nabu_configfile, parse_nabu_config_file from ..pipeline.fullfield.nabu_config import nabu_config as default_fullfield_config from ..pipeline.helical.nabu_config import nabu_config as helical_fullfield_config from .utils import parse_params_values from .cli_configs import BootstrapConfig def bootstrap(): args = parse_params_values(BootstrapConfig, parser_description="Initialize a nabu configuration file") do_bootstrap = bool(args["bootstrap"]) do_convert = args["convert"] != "" no_comments = bool(args["nocomments"]) if do_bootstrap: print( "The --bootstrap option is now the default behavior of the nabu-config command. This option is therefore not needed anymore." ) if path.isfile(args["output"]): rep = input("File %s already exists. Overwrite ? 
[y/N]" % args["output"]) if rep.lower() != "y": print("Stopping") exit(0) opts_level = args["level"] prefilled_values = {} template_name = args["template"] if template_name != "": prefilled_values = get_config_template(template_name, if_not_found="print") if prefilled_values is None: exit(0) opts_level = "advanced" if args["dataset"] != "": prefilled_values["dataset"] = {} user_dataset = args["dataset"] if not path.isabs(user_dataset): user_dataset = path.abspath(user_dataset) print("Warning: using absolute dataset path %s" % user_dataset) if not path.exists(user_dataset): print("Error: cannot find the file or directory %s" % user_dataset) exit(1) prefilled_values["dataset"]["location"] = user_dataset if args["helical"]: my_config = helical_fullfield_config else: my_config = default_fullfield_config generate_nabu_configfile( args["output"], my_config, comments=not (no_comments), options_level=opts_level, prefilled_values=prefilled_values, ) def get_config_template(template_name, if_not_found="raise"): def handle_not_found(msg): if if_not_found == "raise": raise FileNotFoundError(msg) elif if_not_found == "print": print(msg) templates_path = get_folder_path(path.join("resources", "templates")) custom_templates_path = environ.get("NABU_TEMPLATES_PATH", None) templates = glob(path.join(templates_path, "*.conf")) if custom_templates_path is not None: templates_custom = glob(path.join(custom_templates_path, "*.conf")) templates_custom += glob(path.join(custom_templates_path, "*.cfg")) templates = templates_custom + templates available_templates_names = [path.splitext(path.basename(fname))[0] for fname in templates] if template_name not in available_templates_names: handle_not_found("Unable to find template '%s'. Available are: %s" % (template_name, available_templates_names)) return fname = templates[available_templates_names.index(template_name)] return parse_nabu_config_file(fname) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/app/bootstrap_stitching.py0000644000175000017500000000133500000000000020376 0ustar00pierrepierrefrom .cli_configs import BootstrapStitchingConfig from ..pipeline.config import generate_nabu_configfile from ..stitching.config import get_default_stitching_config, SECTIONS_COMMENTS as _SECTIONS_COMMENTS from .utils import parse_params_values def bootstrap_stitching(): args = parse_params_values( BootstrapStitchingConfig, parser_description="Initialize a nabu configuration file", ) prefilled_values = {} generate_nabu_configfile( fname=args["output"], default_config=get_default_stitching_config(args["stitching_type"]), comments=True, sections_comments=_SECTIONS_COMMENTS, options_level=args["level"], prefilled_values=prefilled_values, ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1679996432.0 nabu-2023.1.1/nabu/app/cast_volume.py0000644000175000017500000002133100000000000016624 0ustar00pierrepierre#!/usr/bin/env python # -*- coding: utf-8 -*- import argparse import os import sys import logging from argparse import RawTextHelpFormatter import numpy from tomoscan.esrf.volume.utils import guess_volumes from tomoscan.factory import Factory from tomoscan.esrf.volume import ( EDFVolume, HDF5Volume, JP2KVolume, MultiTIFFVolume, TIFFVolume, ) from nabu.io.cast_volume import ( RESCALE_MAX_PERCENTILE, RESCALE_MIN_PERCENTILE, cast_volume, get_default_output_volume, ) from nabu.pipeline.params import files_formats _logger = logging.getLogger(__name__) def main(argv=None): if argv is 
None: argv = sys.argv _volume_url_helps = "\n".join( [ f"- {(volume.__name__).ljust(15)}: {volume.example_defined_from_str_identifier()}" for volume in ( EDFVolume, HDF5Volume, JP2KVolume, MultiTIFFVolume, TIFFVolume, ) ] ) volume_help = f"""To define a volume you can either provide: \n * an url (recommanded way) - see details lower \n * a path. For hdf5 and multitiff we expect a file path. For edf, tif and jp2k we expect a folder path. In this case we will try to deduce the Volume from it. \n url must be defined like: \n{_volume_url_helps} """ parser = argparse.ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter) parser.add_argument( "input_volume", help=f"input volume. {volume_help}", ) parser.add_argument( "--output-data-type", help="output data type. Valid value are numpy default types name like (uint8, uint16, int8, int16, int32, float32, float64)", default="uint16", ) parser.add_argument( "--output_volume", help=f"output volume. Must be provided if 'output_type' isn't. Must looks like: \n{volume_help}", default=None, ) parser.add_argument( "--output_type", help=f"output type. Must be provided if 'output_volume' isn't. Valid values are {tuple(files_formats.keys())}", default=None, ) parser.add_argument( "--data_min", help=f"value to clamp to volume cast new min. Any lower value will also be clamp to this value.", default=None, ) parser.add_argument( "--data_max", help=f"value to clamp to volume cast new max. Any higher value will also be clamp to this value.", default=None, ) parser.add_argument( "--rescale_min_percentile", help=f"used to determine data_min if not provided. Expected as percentage. Default is {RESCALE_MIN_PERCENTILE}%%", default=RESCALE_MIN_PERCENTILE, ) parser.add_argument( "--rescale_max_percentile", help=f"used to determine data_max if not provided. Expected as percentage. Default is {RESCALE_MAX_PERCENTILE}%%", default=RESCALE_MAX_PERCENTILE, ) parser.add_argument( "--overwrite", dest="overwrite", action="store_true", default=False, help="Overwrite file or dataset if exists", ) options = parser.parse_args(argv[1:]) if os.path.exists(options.input_volume): volumes = guess_volumes(options.input_volume) def is_not_histogram(vol_identifier): return not (hasattr(vol_identifier, "data_path") and vol_identifier.data_path.endswith("histogram")) volumes = tuple(filter(is_not_histogram, volumes)) if len(volumes) == 0: _logger.error(f"no valid volume found in {options.input_volume}") exit(1) elif len(volumes) > 1: _logger.error( f"found several volume from {options.input_volume}. 
Please provide one full url from {[volume.get_identifier() for volume in volumes]}" ) else: input_volume = volumes[0] else: try: input_volume = Factory.create_tomo_object_from_identifier(options.input_volume) except Exception as e: raise ValueError(f"Fail to build input volume from url {options.input_volume}") from e output_format = files_formats.get(options.output_type, None) if options.output_volume is not None: # if an url is provided if ":" in options.output_volume: try: output_volume = Factory.create_tomo_object_from_identifier(options.output_volume) except Exception as e: raise ValueError(f"Fail to build output volume from {options.output_volume}") from e if output_format is not None: if not ( isinstance(output_volume, EDFVolume) and output_format == "edf" or isinstance(output_format, HDF5Volume) and output_format == "hdf5" or isinstance(output_format, JP2KVolume) and output_format == "jp2" or isinstance(output_format, (TIFFVolume, MultiTIFFVolume)) and output_format == "tiff" ): raise ValueError( "Requested 'output_type' and output volume url are incoherent. 'output_type' is optional when url provided" ) else: path_extension = os.path.splitext(options.output_volume)[-1] if path_extension == "": # if a folder ha sbeen provided we try to create a volume from this path and the output format if output_format == "tiff": output_volume = TIFFVolume( folder=options.output_volume, ) elif output_format == "edf": output_volume = EDFVolume( folder=options.output_volume, ) elif output_format == "jp2": output_volume = JP2KVolume( folder=options.output_volume, ) else: raise ValueError( f"Unable to deduce an output volume from {options.output_volume} and output format {output_format}. Please provide an output_volume as an url" ) else: # if a file path_has been provided if path_extension.lower() in ("tif", "tiff") and output_format in ( None, "tiff", ): output_volume = MultiTIFFVolume( file_path=options.output_volume, ) elif path_extension.lower() in ( "h5", "nx", "nexus", "hdf", "hdf5", ) and output_format in (None, "hdf5"): output_volume = HDF5Volume( file_path=options.output_volume, data_path="volume", ) else: raise ValueError( f"Unable to deduce an output volume from {options.output_volume} and output format {output_format}. 
Please provide an output_volume as an url" ) elif options.output_type is None: raise ValueError("'output_type' or 'output_volume' is expected") else: output_volume = get_default_output_volume(input_volume=input_volume, output_type=output_format) try: output_data_type = numpy.dtype(getattr(numpy, options.output_data_type)) except Exception as e: raise ValueError(f"Unable to get output data type from {options.output_data_type}") from e # get data_min and data_max data_min = options.data_min if data_min is not None: data_min = float(data_min) data_max = options.data_max if data_max is not None: data_max = float(data_max) # get rescale_min_percentile and rescale_min_percentile rescale_min_percentile = options.rescale_min_percentile if rescale_min_percentile is not None and isinstance(rescale_min_percentile, str): rescale_min_percentile = float(rescale_min_percentile.rstrip("%")) rescale_max_percentile = options.rescale_max_percentile if rescale_max_percentile is not None and isinstance(rescale_min_percentile, str): rescale_max_percentile = float(rescale_max_percentile.rstrip("%")) output_volume.overwrite = options.overwrite cast_volume( input_volume=input_volume, output_volume=output_volume, output_data_type=output_data_type, data_min=data_min, data_max=data_max, rescale_min_percentile=rescale_min_percentile, rescale_max_percentile=rescale_max_percentile, ) exit(0) if __name__ == "__main__": main(sys.argv) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/app/cli_configs.py0000644000175000017500000003346000000000000016570 0ustar00pierrepierre# # Default configuration for CLI tools # # Default configuration for "bootstrap" command from nabu.stitching.config import StitchingType BootstrapConfig = { "bootstrap": { "help": "DEPRECATED, this is the default behavior. Bootstrap a configuration file from scratch.", "action": "store_const", "const": 1, }, "convert": { "help": "UNSUPPORTED. This option has no effect and will disappear. Convert a PyHST configuration file to a nabu configuration file.", "default": "", }, "output": { "help": "Output filename", "default": "nabu.conf", }, "nocomments": { "help": "Remove the comments in the configuration file (default: False)", "action": "store_const", "const": 1, }, "level": { "help": "Level of options to embed in the configuration file. Can be 'required', 'optional', 'advanced'.", "default": "optional", }, "dataset": { "help": "Pre-fill the configuration file with the dataset path.", "default": "", }, "template": { "help": "Use a template configuration file. Available are: id19_pag, id16_holo, id16_ctf. You can also define your own templates via the NABU_TEMPLATES_PATH environment variable.", "default": "", }, "helical": {"help": "Prepare configuration file for helical", "default": 0, "required": False, "type": int}, } # Default configuration for "zsplit" command ZSplitConfig = { "input_file": { "help": "Input HDF5-Nexus file", "mandatory": True, }, "output_directory": { "help": "Output directory to write split files.", "mandatory": True, }, "loglevel": { "help": "Logging level. Can be 'debug', 'info', 'warning', 'error'. Default is 'info'.", "default": "info", }, "entry": { "help": "HDF5 entry to take in the input file. By default, the first entry is taken.", "default": "", }, "n_stages": { "help": "Number of expected stages (i.e different 'Z' values). 
By default it is inferred from the dataset.", "default": -1, "type": int, }, "use_virtual_dataset": { "help": "Whether to use virtual datasets for output file. Not using a virtual dataset duplicates data and thus results in big files ! However virtual datasets currently have performance issues. Default is False", "default": 0, "type": int, }, } # Default configuration for "histogram" command HistogramConfig = { "h5_file": { "help": "HDF5 file(s). It can be one or several paths to HDF5 files. You can specify entry for each file with /path/to/file.h5?entry0000", "mandatory": True, "nargs": "+", }, "output_file": { "help": "Output file (HDF5)", "mandatory": True, }, "bins": { "help": "Number of bins for histogram if they have to be computed. Default is one million.", "default": 1000000, "type": int, }, "chunk_size_slices": { "help": "If histogram are computed, specify the maximum subvolume size (in number of slices) for computing histogram.", "default": 100, "type": int, }, "chunk_size_GB": { "help": "If histogram are computed, specify the maximum subvolume size (in GibaBytes) for computing histogram.", "default": -1, "type": float, }, "loglevel": { "help": "Logging level. Can be 'debug', 'info', 'warning', 'error'. Default is 'info'.", "default": "info", }, } # Default configuration for "reconstruct" command ReconstructConfig = { "input_file": { "help": "Nabu input file", "default": "", "mandatory": True, }, "logfile": { "help": "Log file. Default is dataset_prefix_nabu.log", "default": "", }, "log_file": { "help": "Same as logfile. Deprecated, use --logfile instead.", "default": "", }, "slice": { "help": "Slice(s) indice(s) to reconstruct, in the format z1-z2. Default (empty) is the whole volume. This overwrites the configuration file start_z and end_z. You can also use --slice first, --slice last, --slice middle, and --slice all", "default": "", }, "gpu_mem_fraction": { "help": "Which fraction of GPU memory to use. Default is 0.9.", "default": 0.9, "type": float, }, "cpu_mem_fraction": { "help": "Which fraction of memory to use. Default is 0.9.", "default": 0.9, "type": float, }, "max_chunk_size": { "help": "Maximum chunk size to use.", "default": -1, "type": int, }, "phase_margin": { "help": "Specify an explicit phase margin to use when performing phase retrieval.", "default": -1, "type": int, }, "force_use_grouped_pipeline": { "help": "Force nabu to use the 'grouped' reconstruction pipeline - slower but should work for all big datasets.", "default": 0, "type": int, }, } GenerateInfoConfig = { "hist_file": { "help": "HDF5 file containing the histogram, either the reconstruction file or a dedicated histogram file.", "default": "", }, "hist_entry": { "help": "Histogram HDF5 entry. Defaults to the first available entry.", "default": "", }, "output": { "help": "Output file name", "default": "", "mandatory": True, }, "bliss_file": { "help": "HDF5 master file produced by BLISS", "default": "", }, "bliss_entry": { "help": "Entry in the HDF5 master file produced by BLISS. By default, take the first entry.", "default": "", }, "info_file": { "help": "Path to the .info file, in the case of a EDF dataset", "default": "", }, "edf_proj": { "help": "Path to a projection, in the case of a EDF dataset", "default": "", }, } RotateRadiosConfig = { "dataset": { "help": "Path to the dataset. Only HDF5 format is supported for now.", "default": "", "mandatory": True, }, "entry": { "help": "HDF5 entry. 
By default, the first entry is taken.", "default": "", }, "angle": { "help": "Rotation angle in degrees", "default": 0.0, "mandatory": True, "type": float, }, "center": { "help": "Rotation center, in the form (x, y) where x (resp. y) is the horizontal (resp. vertical) dimension, i.e along the columns (resp. lines). Default is (Nx/2 - 0.5, Ny/2 - 0.5).", "default": "", }, "output": { "help": "Path to the output file. Only HDF5 output is supported. In the case of HDF5 input, the output file will have the same structure.", "default": "", "mandatory": True, }, "loglevel": { "help": "Logging level. Can be 'debug', 'info', 'warning', 'error'. Default is 'info'.", "default": "info", }, "batchsize": { "help": "Size of the batch of images to process. Default is 100", "default": 100, "type": int, }, "use_cuda": { "help": "Whether to use Cuda if available", "default": "1", }, "use_multiprocessing": { "help": "Whether to use multiprocessing if available", "default": "1", }, } DFFConfig = { "dataset": { "help": "Path to the dataset.", "default": "", "mandatory": True, }, "entry": { "help": "HDF5 entry (for HDF5 datasets). By default, the first entry is taken.", "default": "", }, "flatfield": { "help": "Whether to perform flat-field normalization. Default is True.", "default": "1", "type": int, }, "sigma": { "default": 0.0, "help": "Enable high-pass filtering on double flatfield with this value of 'sigma'", "type": float, }, "output": { "help": "Path to the output file (HDF5).", "default": "", "mandatory": True, }, "loglevel": { "help": "Logging level. Can be 'debug', 'info', 'warning', 'error'. Default is 'info'.", "default": "info", }, "chunk_size": { "help": "Maximum number of lines to read in each projection in a single pass. Default is 100", "default": 100, "type": int, }, } CompareVolumesConfig = { "volume1": { "help": "Path to the first volume.", "default": "", "mandatory": True, }, "volume2": { "help": "Path to the first volume.", "default": "", "mandatory": True, }, "entry": { "help": "HDF5 entry. By default, the first entry is taken.", "default": "", }, "hdf5_path": { "help": "Full HDF5 path to the data. Default is /reconstruction/results/data", "default": "", }, "chunk_size": { "help": "Maximum number of images to read in each step. Default is 100.", "default": 100, "type": int, }, "stop_at": { "help": "Stop the comparison immediately when the difference exceeds this threshold. Default is to compare the full volumes.", "default": "1e-4", }, "statistics": { "help": "Compute statistics on the compared (sub-)volumes. Mind that in this case the command output will not be empty!", "default": 0, "type": int, }, } # Default configuration for "stitching" command StitchingConfig = { "input_file": { "help": "Nabu configuraiton file for stitching (can be obtain from nabu-stitching-boostrap command)", "default": "", "mandatory": True, }, "loglevel": { "help": "Logging level. Can be 'debug', 'info', 'warning', 'error'. Default is 'info'.", "default": "info", }, } # Default configuration for "stitching-bootstrap" command BootstrapStitchingConfig = { "stitching_type": { "help": f"User can provide stitching type to filter some parameters. Must be in {StitchingType.values()}.", "default": None, }, "level": { "help": "Level of options to embed in the configuration file. 
Can be 'required', 'optional', 'advanced'.", "default": "optional", }, "output": { "help": "output file to store the configuration", "default": "stitching.conf", }, } ShrinkConfig = { "input_file": { "help": "Path to the NX file", "default": "", "mandatory": True, }, "output_file": { "help": "Path to the output NX file", "default": "", "mandatory": True, }, "entry": { "help": "HDF5 entry in the file. Default is to take the first entry.", "default": "", }, "binning": { "help": "Binning factor, in the form (bin_z, bin_x). Each image (projection, dark, flat) will be binned by this factor", "default": "", }, "subsampling": {"help": "Subsampling factor for projections (and metadata)", "default": ""}, "threads": { "help": "Number of threads to use for binning. Default is 1.", "default": 1, "type": int, }, } CompositeCorConfig = { "--filename_template": { "required": True, "help": """The filename template. It can optionally contain a segment equal to "X"*ndigits which will be replaced by the stage number if several stages are requested by the user""", }, "--entry_name": { "required": False, "help": "Optional. The entry_name. It defaults to entry0000", "default": "entry0000", }, "--num_of_stages": { "type": int, "required": False, "help": "Optional. How many stages. Example: from 0 to 43 -> --num_of_stages 44. It is optional. ", }, "--oversampling": { "type": int, "default": 4, "required": False, "help": "Oversampling in the research of the axis position. Defaults to 4 ", }, "--n_subsampling_y": { "type": int, "default": 10, "required": False, "help": "How many lines we are going to take from each radio. Defaults to 10.", }, "--theta_interval": { "type": float, "default": 5, "required": False, "help": "Angular step for composing the image. Default to 5", }, "--first_stage": {"type": int, "default": None, "required": False, "help": "Optional. The first stage. "}, "--output_file": { "type": str, "required": False, "help": "Optional. Where the list of cors will be written. Default is the filename postixed with cors.txt", }, "--cor_options": { "type": str, "help": """the cor_options string used by Nabu. Example --cor_options "side='near'; near_pos = 300.0; near_width = 20.0" """, "required": True, }, } CreateDistortionMapHorizontallyMatchedFromPolyConfig = { "--nz": {"type": int, "help": "vertical dimension of the detector", "required": True}, "--nx": {"type": int, "help": "horizontal dimension of the detector", "required": True}, "--center_z": {"type": float, "help": "vertical position of the optical center", "required": True}, "--center_x": {"type": float, "help": "horizontal position of the optical center", "required": True}, "--c4": {"type": float, "help": "order 4 coefficient", "required": True}, "--c2": {"type": float, "help": "order 2 coefficient", "required": True}, "--target_file": {"type": str, "help": "The map output filename", "required": True}, "--axis_pos": { "type": float, "default": None, "help": "Optional argument. If given it will be corrected for use with the produced map. The value is printed, or given as return argument if the utility is used from a script", "required": False, }, "--loglevel": { "help": "Logging level. Can be 'debug', 'info', 'warning', 'error'. 
Default is 'info'.", "default": "info", }, } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/app/compare_volumes.py0000644000175000017500000000637700000000000017520 0ustar00pierrepierrefrom math import ceil from posixpath import join import numpy as np from tomoscan.io import HDF5File from .cli_configs import CompareVolumesConfig from ..utils import clip_circle from .utils import parse_params_values from ..io.utils import get_first_hdf5_entry, hdf5_entry_exists from ..io.reader import get_hdf5_dataset_shape def idx_1d_to_3d(idx, shape): nz, ny, nx = shape x = idx % nx idx2 = (idx - x) // nx y = idx2 % ny z = (idx2 - y) // ny return (z, y, x) def compare_volumes(fname1, fname2, h5_path, chunk_size, do_stats, stop_at_thresh, clip_outer_circle=False): result = None f1 = HDF5File(fname1, "r") f2 = HDF5File(fname2, "r") try: # Check that data is in the provided hdf5 path for fname in [fname1, fname2]: if not hdf5_entry_exists(fname, h5_path): result = "File %s do not has data in %s" % (fname, h5_path) return # Check shapes shp1 = get_hdf5_dataset_shape(fname1, h5_path) shp2 = get_hdf5_dataset_shape(fname2, h5_path) if shp1 != shp2: result = "Volumes do not have the same shape: %s vs %s" % (shp1, shp2) return # Compare volumes n_steps = ceil(shp1[0] / chunk_size) for i in range(n_steps): start = i * chunk_size end = min((i + 1) * chunk_size, shp1[0]) data1 = f1[h5_path][start:end, :, :] data2 = f2[h5_path][start:end, :, :] abs_diff = np.abs(data1 - data2) if clip_outer_circle: for j in range(abs_diff.shape[0]): abs_diff[j] = clip_circle(abs_diff[j], radius=0.9 * min(abs_diff.shape[1:])) coord_argmax = idx_1d_to_3d(np.argmax(abs_diff), data1.shape) if do_stats: mean = np.mean(abs_diff) std = np.std(abs_diff) maxabs = np.max(abs_diff) print("Chunk %d: mean = %e std = %e max = %e" % (i, mean, std, maxabs)) if stop_at_thresh is not None and abs_diff[coord_argmax] > stop_at_thresh: coord_argmax_absolute = (start + coord_argmax[0],) + coord_argmax[1:] result = "abs_diff[%s] = %e" % (coord_argmax_absolute, abs_diff[coord_argmax]) return except Exception as exc: result = "Error: %s" % (str(exc)) raise finally: f1.close() f2.close() return result def compare_volumes_cli(): args = parse_params_values( CompareVolumesConfig, parser_description="A command-line utility for comparing two volumes." ) fname1 = args["volume1"] fname2 = args["volume2"] h5_path = args["hdf5_path"] if h5_path == "": entry = args["entry"].strip() or None if entry is None: entry = get_first_hdf5_entry(fname1) h5_path = join(entry, "reconstruction/results/data") do_stats = bool(args["statistics"]) chunk_size = args["chunk_size"] stop_at_thresh = args["stop_at"] or None if stop_at_thresh is not None: stop_at_thresh = float(stop_at_thresh) res = compare_volumes(fname1, fname2, h5_path, chunk_size, do_stats, stop_at_thresh) if res is not None: print(res) if __name__ == "__main__": compare_volumes_cli() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/app/composite_cor.py0000644000175000017500000000652400000000000017157 0ustar00pierrepierreimport logging import os import sys import numpy as np import re from nabu.resources.dataset_analyzer import HDF5DatasetAnalyzer from nabu.pipeline.estimators import CompositeCOREstimator from nabu.resources.nxflatfield import update_dataset_info_flats_darks from .. 
import version from .cli_configs import CompositeCorConfig from .utils import parse_params_values from ..utils import DictToObj def main(user_args=None): "Application to extract with the composite cor finder the center of rotation for a scan or a series of scans" if user_args is None: user_args = sys.argv[1:] args = DictToObj( parse_params_values( CompositeCorConfig, parser_description=main.__doc__, program_version="nabu " + version, user_args=user_args, ) ) if len(os.path.dirname(args.filename_template)) == 0: # To make sure that other utility routines can succesfully deal with path within the current directory args.filename_template = os.path.join(".", args.filename_template) if args.first_stage is not None: if args.num_of_stages is None: args.num_of_stages = 1 # if the first_stage parameter has been given then # we are using numbers to form the names of the files. # The filename must containe a XX..X substring which will be replaced pattern = re.compile("[X]+") ps = pattern.findall(args.filename_template) if len(ps) == 0: message = f""" You have specified the "first_stage" parameter, with an integer. Therefore the "filename_template" parameter is expected to containe a XX..X subsection but none was found in the passed parameter which is {args.filename_template} """ raise ValueError(message) ls = list(map(len, ps)) idx = np.argmax(ls) args.filename_template = args.filename_template.replace(ps[idx], "{i_stage:" + "0" + str(ls[idx]) + "d}") if args.num_of_stages is None: # this way it works also in the simple case where # only the filename is provided together with the cor options num_of_stages = 1 first_stage = 0 else: num_of_stages = args.num_of_stages first_stage = args.first_stage cor_list = [] for iz in range(first_stage, first_stage + num_of_stages): if args.num_of_stages is not None: nexus_name = args.filename_template.format(i_stage=iz) else: nexus_name = args.filename_template dataset_info = HDF5DatasetAnalyzer(nexus_name, extra_options={"h5_entry": args.entry_name}) update_dataset_info_flats_darks(dataset_info, flatfield_mode=1) cor_finder = CompositeCOREstimator( dataset_info, oversampling=args.oversampling, theta_interval=args.theta_interval, n_subsampling_y=args.n_subsampling_y, take_log=True, spike_threshold=0.04, cor_options=args.cor_options, ) cor_position = cor_finder.find_cor() cor_list.append(cor_position) cor_list = np.array(cor_list).T if args.output_file is not None: output_name = args.output_file else: output_name = os.path.splitext(args.filename_template)[0] + "_cors.txt" np.savetxt( output_name, cor_list, ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/app/create_distortion_map_from_poly.py0000644000175000017500000001107100000000000022747 0ustar00pierrepierrefrom .. import version import numpy as np import h5py import argparse import sys from .cli_configs import CreateDistortionMapHorizontallyMatchedFromPolyConfig from .utils import parse_params_values from ..utils import DictToObj from ..resources.logger import Logger, LoggerOrPrint def horizontal_match(user_args=None): """This application builds two arrays. Let us call them map_x and map_z. Both are 2D arrays with shape given by (nz, nx). These maps are meant to be used to generate a corrected detector image, using them to obtain the pixel (i,j) of the corrected image by interpolating the raw data at position ( map_z(i,j), map_x(i,j) ). 
This map is determined by a user given polynomial P(rs) in the radial variable rs = sqrt( (z-center_z)**2 + (x-center_x)**2 ) / (nx/2) where center_z and center_x give the center around which the deformation is centered. The perfect position (zp,xp) , that would be observed on a perfect detector, of a photon observed at pixel (z,x) of the distorted detector is: (zp, xp) = (center_z, center_x) + P(rs) * ( z - center_z , x - center_x ) The polynomial is given by P(rs) = rs *(1 + c2 * rs**2 + c4 * rs**4) The map is rescaled and reshifted so that a perfect match is realised at the borders of a horizontal line passing by the center. This ensures coerence with the procedure of pixel size calibration which is performed moving a needle horizontally and reading the motor positions at the extreme positions. The maps are written in the target file, creating it as hdf5 file, in the datasets "/coords_source_x" "/coords_source_z" The URLs of these two maps can be used for the detector correction of type "map_xz" in the nabu configuration file as in this example [dataset] ... detector_distortion_correction = map_xz detector_distortion_correction_options = map_x="silx:./map_coordinates.h5?path=/coords_source_x" ; map_z="silx:./map_coordinates.h5?path=/coords_source_z" """ if user_args is None: user_args = sys.argv[1:] args = DictToObj( parse_params_values( CreateDistortionMapHorizontallyMatchedFromPolyConfig, parser_description=horizontal_match.__doc__, program_version="nabu " + version, user_args=user_args, ) ) logger = Logger("horizontal_match", level=args.loglevel, logfile="horizontal_match.log") nz, nx = args.nz, args.nx center_x, center_z = (args.center_x, args.center_z) c4 = args.c4 c2 = args.c2 polynomial = np.poly1d([c4, 0, c2, 0, 1, 0.0]) # change of variable cofv = np.poly1d([1.0 / (nx / 2), 0]) polynomial = nx / 2 * polynomial(cofv) left_border = 0 - center_x right_border = nx - 1 - center_x def get_rescaling_shift(left_border, right_border, polynomial): dl = polynomial(left_border) dr = polynomial(right_border) rescaling = (dr - dl) / (right_border - left_border) shift = -left_border * rescaling + dl return rescaling, shift final_grid_rescaling, final_grid_shift = get_rescaling_shift(left_border, right_border, polynomial) coords_z, coords_x = np.indices([nz, nx]) coords_z = ((coords_z - center_z) * final_grid_rescaling).astype("d") coords_x = ((coords_x - center_x) * final_grid_rescaling + final_grid_shift).astype("d") distances_goal = np.sqrt(coords_z * coords_z + coords_x * coords_x) distances_unknown = distances_goal pp_deriv = polynomial.deriv() # iteratively finding the positions to interpolated at by newton method for i in range(10): errors = polynomial(distances_unknown) - distances_goal derivative = pp_deriv(distances_unknown) distances_unknown = distances_unknown - errors / derivative distances_data_sources = distances_unknown # avoid 0/0 distances_data_sources[distances_goal < 1] = 1 distances_goal[distances_goal < 1] = 1 coords_source_z = coords_z * distances_data_sources / distances_goal + center_z coords_source_x = coords_x * distances_data_sources / distances_goal + center_x with h5py.File(args.target_file, "w") as f: f["coords_source_x"] = coords_source_x f["coords_source_z"] = coords_source_z if args.axis_pos is not None: coord_axis = args.axis_pos - center_x new_pos = (polynomial(coord_axis) - final_grid_shift) / final_grid_rescaling + center_x logger.info("New axis position at %e it was previously %e " % (new_pos, args.axis_pos)) return new_pos else: return None 
././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/app/double_flatfield.py0000644000175000017500000001247500000000000017600 0ustar00pierrepierreimport numpy as np from ..preproc.double_flatfield import DoubleFlatField from ..preproc.flatfield import FlatFieldDataUrls from ..io.reader import ChunkReader from ..io.writer import NXProcessWriter from ..resources.dataset_analyzer import analyze_dataset from ..resources.nxflatfield import update_dataset_info_flats_darks from ..resources.logger import Logger, LoggerOrPrint from .cli_configs import DFFConfig from .utils import parse_params_values class DoubleFlatFieldChunks: def __init__( self, dataset_path, output_file, chunk_size=100, sigma=None, do_flatfield=True, h5_entry=None, logger=None ): self.logger = LoggerOrPrint(logger) self.dataset_info = analyze_dataset(dataset_path, extra_options={"hdf5_entry": h5_entry}, logger=logger) self.do_flatfield = bool(do_flatfield) if self.do_flatfield: update_dataset_info_flats_darks(self.dataset_info, flatfield_mode="force-compute") self.output_file = output_file self.sigma = sigma if sigma is not None and abs(sigma) > 1e-5 else None self._init_reader(chunk_size) self._init_flatfield((None, None, 0, self.chunk_size)) self._init_dff() def _init_reader(self, chunk_size, start_idx=0): self.chunk_size = min(chunk_size, self.dataset_info.radio_dims[-1]) self.reader = ChunkReader( self.dataset_info.projections, sub_region=(None, None, start_idx, start_idx + self.chunk_size), convert_float=True, ) self.projections = self.reader.files_data def _init_flatfield(self, subregion): if not self.do_flatfield: return self.flatfield = FlatFieldDataUrls( (self.dataset_info.n_angles, self.chunk_size, self.dataset_info.radio_dims[0]), self.dataset_info.flats, self.dataset_info.darks, sorted(self.dataset_info.projections.keys()), sub_region=subregion, ) def _apply_flatfield(self): if self.do_flatfield: self.flatfield.normalize_radios(self.projections) def _set_reader_subregion(self, subregion): self.reader._set_subregion(subregion) self.reader._init_reader() self.reader._loaded = False def _init_dff(self): self.double_flatfield = DoubleFlatField( self.projections.shape, input_is_mlog=False, output_is_mlog=False, average_is_on_log=self.sigma is not None, sigma_filter=self.sigma, ) def _get_config(self): conf = { "dataset": self.dataset_info.location, "entry": self.dataset_info.hdf5_entry or None, "dff_sigma": self.sigma, "do_flatfield": self.do_flatfield, } return conf def compute_double_flatfield(self): """ Compute the double flatfield for the current dataset. 
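The projections are read in chunks of detector rows, flat-field normalized if enabled, and averaged over the rotation angles. The result is a 2D image with the detector dimensions.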
""" n_z = self.dataset_info.radio_dims[-1] chunk_size = self.chunk_size n_steps = n_z // chunk_size extra_step = bool(n_z % chunk_size) res = np.zeros(self.dataset_info.radio_dims[::-1]) for i in range(n_steps): self.logger.debug("Computing DFF batch %d/%d" % (i + 1, n_steps + int(extra_step))) subregion = (None, None, i * chunk_size, (i + 1) * chunk_size) self._set_reader_subregion(subregion) self._init_flatfield(subregion) self.reader.load_files() self._apply_flatfield() dff = self.double_flatfield.compute_double_flatfield(self.projections, recompute=True) res[subregion[-2] : subregion[-1]] = dff[:] # Need to initialize objects with a different shape if extra_step: curr_idx = (i + 1) * self.chunk_size self.logger.debug("Computing DFF batch %d/%d" % (i + 2, n_steps + int(extra_step))) self._init_reader(n_z - curr_idx, start_idx=curr_idx) self._init_flatfield(self.reader.sub_region) self._init_dff() self.reader.load_files() self._apply_flatfield() dff = self.double_flatfield.compute_double_flatfield(self.projections, recompute=True) res[curr_idx:] = dff[:] return res def write_double_flatfield(self, arr): """ Write the double flatfield image to a file """ writer = NXProcessWriter( self.output_file, entry=self.dataset_info.hdf5_entry or "entry", filemode="a", overwrite=True, ) writer.write(arr, "double_flatfield", config=self._get_config()) self.logger.info("Wrote %s" % writer.fname) def dff_cli(): args = parse_params_values( DFFConfig, parser_description="A command-line utility for computing the double flatfield of a dataset." ) logger = Logger("nabu_double_flatfield", level=args["loglevel"], logfile="nabu_double_flatfield.log") output_file = args["output"] dff = DoubleFlatFieldChunks( args["dataset"], output_file, chunk_size=args["chunk_size"], sigma=args["sigma"], do_flatfield=bool(args["flatfield"]), h5_entry=args["entry"] or None, logger=logger, ) dff_image = dff.compute_double_flatfield() dff.write_double_flatfield(dff_image) if __name__ == "__main__": dff_cli() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/app/generate_header.py0000644000175000017500000002125700000000000017414 0ustar00pierrepierreimport os import numpy as np from tomoscan.io import HDF5File from silx.third_party.EdfFile import EdfFile from ..io.utils import get_first_hdf5_entry from .utils import parse_params_values from .cli_configs import GenerateInfoConfig edf_header_hdf5_path = { # EDF Header "count_time": "/technique/scan/exposure_time", # ms in HDF5 "date": "/start_time", "energy": "/technique/scan/energy", "flip": "/technique/detector/flipping", # [x, y] in HDF5 "motors": "/instrument/positioners", "optic_used": "/technique/optic/magnification", } info_hdf5_path = { "Energy": "/technique/scan/energy", "Distance": "/technique/scan/sample_detector_distance", # mm in HDF5 and EDF # ~ "Prefix": TODO "Directory": "/technique/saving/path", "ScanRange": "/technique/scan/scan_range", "TOMO_N": "/technique/scan/tomo_n", "REF_ON": "/technique/scan/ref_on", "REF_N": "/technique/reference/ref_n", "DARK_N": "/technique/dark/dark_n", "Dim_1": "/technique/detector/size", # array "Dim_2": "/technique/detector/size", # array "Count_time": "/technique/scan/exposure_time", # ms in HDF5, s in EDF ? "Shutter_time": "/technique/scan/shutter_time", # ms in HDF5, s in EDF ? 
"Optic_used": "/technique/optic/magnification", "PixelSize": "/technique/detector/pixel_size", # array (microns) "Date": "/start_time", "Scan_Type": "/technique/scan/scan_type", "SrCurrent": "/instrument/machine/current", "Comment": "/technique/scan/comment", # empty ? } def decode_bytes(content): if isinstance(content, bytes): return content.decode() else: return content def simulate_edf_header(fname, entry, return_dict=False): edf_header = { "ByteOrder": "LowByteFirst", } with HDF5File(fname, "r") as fid: for edf_name, h5_path in edf_header_hdf5_path.items(): if edf_name == "motors": continue h5_path = entry + h5_path edf_header[edf_name] = decode_bytes(fid[h5_path][()]) h5_motors = decode_bytes(fid[entry + edf_header_hdf5_path["motors"]]) edf_header["motor_mne"] = list(h5_motors.keys()) edf_header["motor_pos"] = [v[()] for v in h5_motors.values()] # remove "scan_numbers" from motors try: idx = edf_header["motor_mne"].index("scan_numbers") edf_header["motor_mne"].pop(idx) edf_header["motor_pos"].pop(idx) except ValueError: pass # remove invalid values from motor_pos invalid_values = ["*DIS*"] try: for invalid_val in invalid_values: idx = edf_header["motor_pos"].index(invalid_val) edf_header["motor_mne"].pop(idx) edf_header["motor_pos"].pop(idx) except ValueError: pass # Format edf_header["flip"] = "" % (edf_header["flip"][0], edf_header["flip"][1]) edf_header["motor_mne"] = " ".join(edf_header["motor_mne"]) edf_header["motor_pos"] = " ".join(list(map(str, edf_header["motor_pos"]))) edf_header["count_time"] = edf_header["count_time"] * 1e-3 # HDF5: ms -> EDF: s if return_dict: return edf_header res = "" for k, v in edf_header.items(): res = res + "%s = %s ;\n" % (k, v) return res def simulate_info_file(fname, entry, return_dict=False): info_file_content = {} with HDF5File(fname, "r") as fid: for info_name, h5_path in info_hdf5_path.items(): h5_path = entry + h5_path info_file_content[info_name] = decode_bytes(fid[h5_path][()]) info_file_content["Dim_1"] = info_file_content["Dim_1"][0] info_file_content["Dim_2"] = info_file_content["Dim_2"][1] info_file_content["PixelSize"] = info_file_content["PixelSize"][0] info_file_content["Prefix"] = os.path.basename(fname) info_file_content["Col_end"] = info_file_content["Dim_1"] - 1 info_file_content["Col_beg"] = 0 info_file_content["Row_end"] = info_file_content["Dim_2"] - 1 info_file_content["Row_beg"] = 0 for what in ["Count_time", "Shutter_time"]: info_file_content[what] = info_file_content[what] * 1e-3 if return_dict: return info_file_content # Format res = "" for k, v in info_file_content.items(): k_s = str("%s=" % k) sep = " " * (24 - len(k_s)) res = res + k_s + sep + str("%s\n" % v) return res def get_hst_saturations(hist, bins, numels): aMin = bins[0] aMax = bins[-1] hist_sum = numels * 1.0 hist_cum = np.cumsum(hist) hist_cum_rev = np.cumsum(hist[::-1]) i_s1 = np.where(hist_cum > 0.00001 * hist_sum)[0][0] sat1 = aMin + i_s1 * (aMax - aMin) / (hist.size - 1) i_S1 = np.where(hist_cum > 0.002 * hist_sum)[0][0] Sat1 = aMin + i_S1 * (aMax - aMin) / (hist.size - 1) i_s2 = np.argwhere(hist_cum_rev > 0.00001 * hist_sum)[0][0] sat2 = aMin + (hist.size - 1 - i_s2) * (aMax - aMin) / (hist.size - 1) i_S2 = np.argwhere(hist_cum_rev > 0.002 * hist_sum)[0][0] Sat2 = aMin + (hist.size - 1 - i_S2) * (aMax - aMin) / (hist.size - 1) return sat1, sat2, Sat1, Sat2 def simulate_hst_vol_header(fname, entry=None, return_dict=False): with HDF5File(fname, "r") as fid: try: histogram_path = entry + "/histogram/results/data" histogram_config_path = entry + 
"/histogram/configuration" Nz, Ny, Nx = fid[histogram_config_path]["volume_shape"][()] vol_header_content = { "NUM_X": Nx, "NUM_Y": Ny, "NUM_Z": Nz, "BYTEORDER": "LOWBYTEFIRST", } hist = fid[histogram_path][()] except KeyError as err: print("Could not load histogram from %s: %s" % (fname, err)) return {} if return_dict else "" bins = hist[1] hist = hist[0] vmin = bins[0] vmax = bins[-1] + (bins[-1] - bins[-2]) s1, s2, S1, S2 = get_hst_saturations(hist, bins, Nx * Ny * Nz) vmin = bins[0] vmax = bins[-1] + (bins[-1] - bins[-2]) vol_header_content.update( { "ValMin": vmin, "ValMax": vmax, "s1": s1, "s2": s2, "S1": S1, "S2": S2, } ) if return_dict: return vol_header_content res = "" for k, v in vol_header_content.items(): res = res + "%s = %s\n" % (k, v) return res def format_as_info(d): res = "" for k, v in d.items(): k_s = str("%s=" % k) sep = " " * (24 - len(k_s)) res = res + k_s + sep + str("%s\n" % v) return res def generate_merged_info_file_content( hist_fname=None, hist_entry=None, bliss_fname=None, bliss_entry=None, first_edf_proj=None, info_file=None ): # EDF Header if first_edf_proj is None: edf_header = simulate_edf_header(bliss_fname, bliss_entry, return_dict=True) else: edf_header = EdfFile(first_edf_proj).GetHeader(0) # .info File if info_file is None: info_file_content = simulate_info_file(bliss_fname, bliss_entry, return_dict=True) info_file_content = format_as_info(info_file_content) else: with open(info_file, "r") as f: info_file_content = f.read() # .vol File vol_file_content = simulate_hst_vol_header(hist_fname, entry=hist_entry, return_dict=True) # res = format_as_info(edf_header) res += info_file_content res += format_as_info(vol_file_content) return res def str_or_none(s): if len(s.strip()) == 0: return None return s def generate_merged_info_file(): args = parse_params_values(GenerateInfoConfig, parser_description="Generate a .info file") hist_fname = str_or_none(args["hist_file"]) hist_entry = str_or_none(args["hist_entry"]) bliss_fname = str_or_none(args["bliss_file"]) bliss_entry = str_or_none(args["bliss_entry"]) first_edf_proj = str_or_none(args["edf_proj"]) info_file = str_or_none(args["info_file"]) if hist_fname is not None and hist_entry is None: hist_entry = get_first_hdf5_entry(hist_fname) if bliss_fname is not None and bliss_entry is None: bliss_entry = get_first_hdf5_entry(bliss_fname) if not ((bliss_fname is None) ^ (info_file is None)): print("Error: please provide either --bliss_file or --info_file") exit(1) if info_file is not None and first_edf_proj is None: print("Error: please provide also --edf_proj when using the EDF format") exit(1) content = generate_merged_info_file_content( hist_fname=hist_fname, hist_entry=hist_entry, bliss_fname=bliss_fname, bliss_entry=bliss_entry, first_edf_proj=first_edf_proj, info_file=info_file, ) with open(args["output"], "w") as f: f.write(content) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/app/histogram.py0000644000175000017500000001734100000000000016306 0ustar00pierrepierrefrom os import path import posixpath from silx.io.url import DataUrl from silx.io.dictdump import h5todict from ..utils import check_supported from ..io.utils import get_first_hdf5_entry, get_h5_value from ..io.writer import NXProcessWriter from ..misc.histogram import PartialHistogram, VolumeHistogram, hist_as_2Darray from ..misc.histogram_cuda import CudaVolumeHistogram from ..resources.logger import Logger, LoggerOrPrint from .utils import parse_params_values from .cli_configs 
import HistogramConfig class VolumesHistogram: """ A class for extracting or computing histograms of one or several volumes. """ available_backends = { "numpy": VolumeHistogram, "cuda": CudaVolumeHistogram, } def __init__( self, fnames, output_file, chunk_size_slices=100, chunk_size_GB=None, nbins=1e6, logger=None, backend="cuda" ): """ Initialize a VolumesHistogram object. Parameters ----------- fnames: list of str List of paths to HDF5 files. To specify an entry for each file name, use the "?" separator: /path/to/file.h5?entry0001 output_file: str Path to the output file write_histogram_if_computed: bool, optional Whether to write histograms that are computed to a file. Some volumes might be missing their histogram. In this case, the histogram is computed, and the result is written to a dedicated file in the same directory as 'output_file'. Default is True. """ self._get_files_and_entries(fnames) self.chunk_size_slices = chunk_size_slices self.chunk_size_GB = chunk_size_GB self.nbins = nbins self.logger = LoggerOrPrint(logger) self.output_file = output_file self._get_histogrammer_backend(backend) def _get_files_and_entries(self, fnames): res_fnames = [] res_entries = [] for fname in fnames: if "?" not in fname: entry = None else: fname, entry = fname.split("?") if entry == "": entry = None res_fnames.append(fname) res_entries.append(entry) self.fnames = res_fnames self.entries = res_entries def _get_histogrammer_backend(self, backend): check_supported(backend, self.available_backends.keys(), "histogram backend") self.VolumeHistogramClass = self.available_backends[backend] def _get_config_onevolume(self, fname, entry, data_shape): return { "chunk_size_slices": self.chunk_size_slices, "chunk_size_GB": self.chunk_size_GB, "bins": self.nbins, "filename": fname, "entry": entry, "volume_shape": data_shape, } def _get_config(self): conf = self._get_config_onevolume("", "", None) conf.pop("filename") conf.pop("entry") conf["filenames"] = self.fnames conf["entries"] = [entry if entry is not None else "None" for entry in self.entries] return conf def _write_histogram_onevolume(self, fname, entry, histogram, data_shape): output_file = ( path.join(path.dirname(self.output_file), path.splitext(path.basename(fname))[0]) + "_histogram" + path.splitext(fname)[1] ) self.logger.info("Writing histogram of %s into %s" % (fname, output_file)) writer = NXProcessWriter(output_file, entry, filemode="w", overwrite=True) writer.write( hist_as_2Darray(histogram), "histogram", config=self._get_config_onevolume(fname, entry, data_shape) ) def get_histogram_single_volume(self, fname, entry, write_histogram_if_computed=True, return_config=False): entry = entry or get_first_hdf5_entry(fname) hist_path = posixpath.join(entry, "histogram", "results", "data") hist_cfg_path = posixpath.join(entry, "histogram", "configuration") rec_path = posixpath.join(entry, "reconstruction", "results", "data") rec_url = DataUrl(file_path=fname, data_path=rec_path) hist = get_h5_value(fname, hist_path) config = None if hist is None: self.logger.info("No histogram found in %s, computing it" % fname) vol_histogrammer = self.VolumeHistogramClass( rec_url, chunk_size_slices=self.chunk_size_slices, chunk_size_GB=self.chunk_size_GB, nbins=self.nbins, logger=self.logger, ) hist = vol_histogrammer.compute_volume_histogram() if write_histogram_if_computed: self._write_histogram_onevolume(fname, entry, hist, vol_histogrammer.data_shape) else: if return_config: raise ValueError( "return_config must be set to True to get configuration for 
non-existing histograms" ) hist = hist_as_2Darray(hist) config = h5todict(path.splitext(fname)[0] + "_histogram" + path.splitext(fname)[1], path=hist_cfg_path) if return_config: return hist, config else: return hist def get_histogram(self, return_config=False): histograms = [] configs = [] for fname, entry in zip(self.fnames, self.entries): self.logger.info("Getting histogram for %s" % fname) hist, conf = self.get_histogram_single_volume(fname, entry, return_config=True) histograms.append(hist) configs.append(conf) self.logger.info("Merging histograms") histogrammer = PartialHistogram(method="fixed_bins_number", num_bins=self.nbins) hist = histogrammer.merge_histograms(histograms, dont_truncate_bins=True) if return_config: return hist, configs else: return hist def merge_histograms_configurations(self, configs): if configs is None or len(configs) == 0: return res_config = {"volume_shape": list(configs[0]["volume_shape"])} res_config["volume_shape"][0] = 0 for conf in configs: nz, ny, nx = conf["volume_shape"] res_config["volume_shape"][0] += nz res_config["volume_shape"] = tuple(res_config["volume_shape"]) return res_config def write_histogram(self, hist, config=None): self.logger.info("Writing final histogram to %s" % (self.output_file)) config = config or {} base_config = self._get_config() base_config.pop("volume_shape") config.update(base_config) writer = NXProcessWriter(self.output_file, "entry0000", filemode="w", overwrite=True) writer.write(hist_as_2Darray(hist), "histogram", config=config) def histogram_cli(): args = parse_params_values(HistogramConfig, parser_description="Extract/compute histogram of volume(s).") logger = Logger("nabu_histogram", level=args["loglevel"], logfile="nabu_histogram.log") output = args["output_file"].split("?")[0] if path.exists(output): logger.fatal("Output file %s already exists, not overwriting it" % output) exit(1) chunk_size_gb = float(args["chunk_size_GB"]) if chunk_size_gb <= 0: chunk_size_gb = None histogramer = VolumesHistogram( args["h5_file"], output, chunk_size_slices=int(args["chunk_size_slices"]), chunk_size_GB=chunk_size_gb, nbins=int(args["bins"]), logger=logger, ) hist, configs = histogramer.get_histogram(return_config=True) config = histogramer.merge_histograms_configurations(configs) histogramer.write_histogram(hist, config=config) if __name__ == "__main__": histogram_cli() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/app/nx_z_splitter.py0000644000175000017500000001154300000000000017213 0ustar00pierrepierreimport warnings from shutil import copy as copy_file from os import path from h5py import VirtualSource, VirtualLayout from tomoscan.io import HDF5File from ..resources.logger import Logger, LoggerOrPrint from ..io.utils import get_first_hdf5_entry from .cli_configs import ZSplitConfig from .utils import parse_params_values warnings.warn( "This command-line utility is intended as a temporary solution. 
Please do not rely too much on it.", Warning ) def _get_z_translations(fname, entry): z_path = path.join(entry, "sample", "z_translation") with HDF5File(fname, "r") as fid: z_transl = fid[z_path][:] return z_transl class NXZSplitter: def __init__(self, fname, output_dir, n_stages=None, entry=None, logger=None, use_virtual_dataset=False): self.fname = fname self._ext = path.splitext(fname)[-1] self.output_dir = output_dir self.n_stages = n_stages if entry is None: entry = get_first_hdf5_entry(fname) self.entry = entry self.logger = LoggerOrPrint(logger) self.use_virtual_dataset = use_virtual_dataset def _patch_nx_file(self, fname, mask): orig_fname = self.fname detector_path = path.join(self.entry, "instrument", "detector") sample_path = path.join(self.entry, "sample") with HDF5File(fname, "a") as fid: def patch_nx_entry(name): newval = fid[name][mask] del fid[name] fid[name] = newval detector_entries = [ path.join(detector_path, what) for what in ["count_time", "image_key", "image_key_control"] ] sample_entries = [ path.join(sample_path, what) for what in ["rotation_angle", "x_translation", "y_translation", "z_translation"] ] for what in detector_entries + sample_entries: self.logger.debug("Patching %s" % what) patch_nx_entry(what) # Patch "data" using a virtual dataset self.logger.debug("Patching data") data_path = path.join(detector_path, "data") if self.use_virtual_dataset: data_shape = fid[data_path].shape data_dtype = fid[data_path].dtype new_data_shape = (int(mask.sum()),) + data_shape[1:] vlayout = VirtualLayout(shape=new_data_shape, dtype=data_dtype) vsource = VirtualSource(orig_fname, name=data_path, shape=data_shape, dtype=data_dtype) vlayout[:] = vsource[mask, :, :] del fid[data_path] fid[detector_path].create_virtual_dataset("data", vlayout) if not (self.use_virtual_dataset): data_path = path.join(self.entry, "instrument", "detector", "data") with HDF5File(orig_fname, "r") as fid: data_arr = fid[data_path][mask, :, :] # Actually load data. Heavy ! with HDF5File(fname, "a") as fid: del fid[data_path] fid[data_path] = data_arr def z_split(self): """ Split a HDF5-NX file according to different z_translation. """ z_transl = _get_z_translations(self.fname, self.entry) different_z = set(z_transl) n_z = len(different_z) self.logger.info("Detected %d different z values: %s" % (n_z, str(different_z))) if n_z <= 1: raise ValueError("Detected only %d z-value. Stopping." 
% n_z) if self.n_stages is not None and self.n_stages != n_z: raise ValueError("Expected %d different stages, but I detected %d" % (self.n_stages, n_z)) masks = [(z_transl == z) for z in different_z] for i_z, mask in enumerate(masks): fname_curr_z = path.join( self.output_dir, path.splitext(path.basename(self.fname))[0] + str("_%06d" % i_z) + self._ext ) self.logger.info("Creating %s" % fname_curr_z) copy_file(self.fname, fname_curr_z) self._patch_nx_file(fname_curr_z, mask) def zsplit(): # Parse arguments args = parse_params_values( ZSplitConfig, parser_description="Split a HDF5-Nexus file according to z translation (z-series)" ) # Sanitize arguments fname = args["input_file"] output_dir = args["output_directory"] loglevel = args["loglevel"].upper() entry = args["entry"] if len(entry) == 0: entry = None n_stages = args["n_stages"] if n_stages < 0: n_stages = None use_virtual_dataset = bool(args["use_virtual_dataset"]) # Instantiate and execute logger = Logger("NX_z-splitter", level=loglevel, logfile="nxzsplit.log") nx_splitter = NXZSplitter( fname, output_dir, n_stages=n_stages, entry=entry, logger=logger, use_virtual_dataset=use_virtual_dataset ) nx_splitter.z_split() if __name__ == "__main__": zsplit() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/app/prepare_weights_double.py0000644000175000017500000000730300000000000021030 0ustar00pierrepierreimport h5py import numpy as np from scipy.special import erf # pylint: disable=all from scipy.special import erf import sys import os from scipy.ndimage import gaussian_filter from tomoscan.esrf.scan.hdf5scan import ImageKey, HDF5TomoScan from nabu.resources.nxflatfield import update_dataset_info_flats_darks from nabu.resources.dataset_analyzer import HDF5DatasetAnalyzer from ..io.reader import load_images_from_dataurl_dict def main(): """auxiliary program that can be called to create default input detector profiles, for nabu helical, concerning the weights of the pixels and the "double flat" renormalisation denominator. The result is an hdf5 file that can be used as a "processes_file" in the nabu configuration and is used by nabu-helical. In particulars cases the user may have fancy masks and correction map and will provide its own processes file, and will not need this. This code, and in particular the auxiliary function below (that by the way tomwer can use) provide a default construction of such maps. The double-flat is set to one and the weight is build on the basis of the flat fields from the dataset with an apodisation on the borders which allows to eliminate discontinuities in the contributions from the borders, above and below for the z-translations, and on the left or roght border for half-tomo. 
The usage is :: nabu-helical-prepare-weights-double nexus_file_name entry_name Then the resulting file can be used as processes file in the configuration file of nabu-helical """ if len(sys.argv) not in [3, 4, 5]: message = f""" Usage: {os.path.basename(sys.argv[0])} nexus_file_name entry_name [target_file name [transition_width]] """ print(message) sys.exit(1) file_name = sys.argv[1] if len(os.path.dirname(file_name)) == 0: # To make sure that other utility routines can succesfully deal with path within the current directory file_name = os.path.join(".", file_name) entry_name = sys.argv[2] process_file_name = "double.h5" transition_width = 50.0 if len(sys.argv) in [4, 5]: process_file_name = sys.argv[3] if len(sys.argv) in [5]: transition_width = float(sys.argv[4]) dataset_info = HDF5DatasetAnalyzer(file_name, extra_options={"h5_entry": entry_name}) update_dataset_info_flats_darks(dataset_info, flatfield_mode=1) mappe = 0 my_flats = load_images_from_dataurl_dict(dataset_info.flats) for key, flat in my_flats.items(): mappe += flat mappe = mappe / len(list(dataset_info.flats.keys())) create_heli_maps(mappe, process_file_name, entry_name, transition_width) def create_heli_maps(profile, process_file_name, entry_name, transition_width): profile = profile / profile.max() profile = profile.astype("f") profile = gaussian_filter(profile, 10) if os.path.exists(process_file_name): fd = h5py.File(process_file_name, "r+") else: fd = h5py.File(process_file_name, "w") def f(L, m, w): x = np.arange(L) d = (x - L + m).astype("f") res_r = (1 - erf(d / w)) / 2 d = (x - m).astype("f") res_l = (1 + erf(d / w)) / 2 return res_r * res_l border = f(profile.shape[1], 20, 13.33) border_v = f(profile.shape[0], int(round(transition_width / 2)), transition_width / 4) path_weights = entry_name + "/weights_field/results/data" path_double = entry_name + "/double_flatfield/results/data" if path_weights in fd: del fd[path_weights] if path_double in fd: del fd[path_double] fd[path_weights] = (profile * border) * border_v[:, None] fd[path_double] = np.ones_like(profile) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/app/reconstruct.py0000644000175000017500000000727500000000000016671 0ustar00pierrepierrefrom .. import version from ..resources.utils import is_hdf5_extension from ..pipeline.config import parse_nabu_config_file from ..pipeline.config_validators import convert_to_int from .cli_configs import ReconstructConfig from .utils import parse_params_values def update_reconstruction_start_end(conf_dict, user_indices): if len(user_indices) == 0: return rec_cfg = conf_dict["reconstruction"] err = None val_int, conv_err = convert_to_int(user_indices) if conv_err is None: start_z = user_indices end_z = user_indices else: if user_indices in ["first", "middle", "last"]: start_z = user_indices end_z = user_indices elif user_indices == "all": start_z = 0 end_z = -1 elif "-" in user_indices: try: start_z, end_z = user_indices.split("-") start_z = int(start_z) end_z = int(end_z) except Exception as exc: err = "Could not interpret slice indices '%s': %s" % (user_indices, str(exc)) else: err = "Could not interpret slice indices: %s" % user_indices if err is not None: print(err) exit(1) rec_cfg["start_z"] = start_z rec_cfg["end_z"] = end_z def get_log_file(arg_logfile, legacy_arg_logfile, forbidden=None): default_arg_val = "" # Compat. 
log_file --> logfile if legacy_arg_logfile != default_arg_val: logfile = legacy_arg_logfile else: logfile = arg_logfile # if forbidden is None: forbidden = [] for forbidden_val in forbidden: if logfile == forbidden_val: print("Error: --logfile argument cannot have the value %s" % forbidden_val) exit(1) if logfile == "": logfile = True return logfile def main(): args = parse_params_values( ReconstructConfig, parser_description=f"Perform a tomographic reconstruction.", program_version="nabu " + version, ) # Imports are done here, otherwise "nabu --version" takes forever from ..pipeline.fullfield.processconfig import ProcessConfig from ..pipeline.fullfield.reconstruction import FullFieldReconstructor # # A crash with scikit-cuda happens only on PPC64 platform if and nvidia-persistenced is running. # On such machines, a warm-up has to be done. import platform if platform.machine() == "ppc64le": try: from silx.math.fft.cufft import CUFFT except: # can't catch narrower - cublasNotInitialized requires cublas ! CUFFT = None # logfile = get_log_file(args["logfile"], args["log_file"], forbidden=[args["input_file"]]) conf_dict = parse_nabu_config_file(args["input_file"]) update_reconstruction_start_end(conf_dict, args["slice"].strip()) proc = ProcessConfig(conf_dict=conf_dict, create_logger=logfile) logger = proc.logger logger.info("Going to reconstruct slices (%d, %d)" % (proc.rec_region["start_z"], proc.rec_region["end_z"])) # Get extra options extra_options = { "gpu_mem_fraction": args["gpu_mem_fraction"], "cpu_mem_fraction": args["cpu_mem_fraction"], "chunk_size": args["max_chunk_size"] if args["max_chunk_size"] > 0 else None, "margin": args["phase_margin"], "force_grouped_mode": bool(args["force_use_grouped_pipeline"]), } R = FullFieldReconstructor(proc, logger=logger, extra_options=extra_options) R.reconstruct() R.merge_data_dumps() if is_hdf5_extension(proc.nabu_config["output"]["file_format"]): R.merge_hdf5_reconstructions() else: R.merge_histograms() if __name__ == "__main__": main() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/app/reconstruct_helical.py0000644000175000017500000000742300000000000020345 0ustar00pierrepierrefrom .. import version from ..resources.utils import is_hdf5_extension from ..pipeline.config import parse_nabu_config_file from ..pipeline.config_validators import convert_to_int from .cli_configs import ReconstructConfig from .utils import parse_params_values from .reconstruct import update_reconstruction_start_end, get_log_file def main_helical(): ReconstructConfig["dry_run"] = { "help": "Stops after printing some information on the reconstruction layout.", "default": 0, "type": int, } args = parse_params_values( ReconstructConfig, parser_description=f"Perform a helical tomographic reconstruction", program_version="nabu " + version, ) # Imports are done here, otherwise "nabu --version" takes forever from ..pipeline.helical.processconfig import ProcessConfig from ..pipeline.helical.helical_reconstruction import HelicalReconstructorRegridded # # A crash with scikit-cuda happens only on PPC64 platform if and nvidia-persistenced is running. # On such machines, a warm-up has to be done. import platform if platform.machine() == "ppc64le": try: from silx.math.fft.cufft import CUFFT except: # can't catch narrower - cublasNotInitialized requires cublas ! 
CUFFT = None # logfile = get_log_file(args["logfile"], args["log_file"], forbidden=[args["input_file"]]) conf_dict = parse_nabu_config_file(args["input_file"]) update_reconstruction_start_end(conf_dict, args["slice"].strip()) proc = ProcessConfig(conf_dict=conf_dict, create_logger=logfile) logger = proc.logger if "rotate_projections" in proc.processing_steps: message = """ The rotate_projections step is activated. The Helical pipelines are not yet suited for projection rotation it will soon be implemented. For the moment you should desactivate the rotation options in nabu.conf """ raise ValueError(message) # Determine which reconstructor to use reconstructor_cls = None phase_method = None if "phase" in proc.processing_steps: phase_method = proc.processing_options["phase"]["method"] # fix the reconstruction roi if not given if "reconstruction" in proc.processing_steps: rec_config = proc.processing_options["reconstruction"] rot_center = rec_config["rotation_axis_position"] Nx, Ny = proc.dataset_info.radio_dims if proc.nabu_config["reconstruction"]["auto_size"]: if 2 * rot_center > Nx: w = int(round(2 * rot_center)) else: w = int(round(2 * Nx - 2 * rot_center)) rec_config["start_x"] = int(round(rot_center - w / 2)) rec_config["end_x"] = int(round(rot_center + w / 2)) rec_config["start_y"] = rec_config["start_x"] rec_config["end_y"] = rec_config["end_x"] reconstructor_cls = HelicalReconstructorRegridded logger.debug("Using pipeline: %s" % reconstructor_cls.__name__) # Get extra options extra_options = { "gpu_mem_fraction": args["gpu_mem_fraction"], "cpu_mem_fraction": args["cpu_mem_fraction"], } extra_options.update( { ##### ??? "use_phase_margin": args["use_phase_margin"], "max_chunk_size": args["max_chunk_size"] if args["max_chunk_size"] > 0 else None, "phase_margin": args["phase_margin"], "dry_run": args["dry_run"], } ) R = reconstructor_cls(proc, logger=logger, extra_options=extra_options) R.reconstruct() if not R.dry_run: R.merge_data_dumps() if is_hdf5_extension(proc.nabu_config["output"]["file_format"]): R.merge_hdf5_reconstructions() R.merge_histograms() if __name__ == "__main__": main_helical() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/app/rotate.py0000644000175000017500000001442600000000000015610 0ustar00pierrepierreimport posixpath from os import path from math import ceil from shutil import copy from multiprocessing import cpu_count from multiprocessing.pool import ThreadPool import numpy as np from tomoscan.io import HDF5File from tomoscan.esrf.scan.hdf5scan import HDF5TomoScan from ..io.utils import get_first_hdf5_entry from ..misc.rotation import Rotation from ..resources.logger import Logger, LoggerOrPrint from ..pipeline.config_validators import optional_tuple_of_floats_validator, boolean_validator from ..misc.rotation_cuda import CudaRotation, __has_pycuda__ from .utils import parse_params_values from .cli_configs import RotateRadiosConfig class HDF5ImagesStackRotation: def __init__( self, input_file, output_file, angle, center=None, entry=None, logger=None, batch_size=100, use_cuda=True, use_multiprocessing=True, ): self.logger = LoggerOrPrint(logger) self.use_cuda = use_cuda & __has_pycuda__ self.batch_size = batch_size self.use_multiprocessing = use_multiprocessing self._browse_dataset(input_file, entry) self._get_rotation(angle, center) self._init_output_dataset(output_file) def _browse_dataset(self, input_file, entry): self.input_file = input_file if entry is None or entry == "": entry = 
get_first_hdf5_entry(input_file) self.entry = entry self.dataset_info = HDF5TomoScan(input_file, entry=entry) def _get_rotation(self, angle, center): if self.use_cuda: self.logger.info("Using Cuda rotation") rot_cls = CudaRotation else: self.logger.info("Using skimage rotation") rot_cls = Rotation if self.use_multiprocessing: self.thread_pool = ThreadPool(processes=cpu_count() - 2) self.logger.info("Using multiprocessing with %d cores" % self.thread_pool._processes) self.rotation = rot_cls((self.dataset_info.dim_2, self.dataset_info.dim_1), angle, center=center, mode="edge") def _init_output_dataset(self, output_file): self.output_file = output_file copy(self.input_file, output_file) first_proj_url = self.dataset_info.projections[list(self.dataset_info.projections.keys())[0]] self.data_path = first_proj_url.data_path() dirname, basename = posixpath.split(self.data_path) self._data_path_dirname = dirname self._data_path_basename = basename def _rotate_stack_cuda(self, images, output): # pylint: disable=E1136 self.rotation.cuda_processing.allocate_array("tmp_images_stack", images.shape) self.rotation.cuda_processing.allocate_array("tmp_images_stack_rot", images.shape) d_in = self.rotation.cuda_processing.get_array("tmp_images_stack") d_out = self.rotation.cuda_processing.get_array("tmp_images_stack_rot") n_imgs = images.shape[0] d_in[:n_imgs].set(images) for j in range(n_imgs): self.rotation.rotate(d_in[j], output=d_out[j]) d_out[:n_imgs].get(ary=output[:n_imgs]) def _rotate_stack(self, images, output): if self.use_cuda: self._rotate_stack_cuda(images, output) elif self.use_multiprocessing: out_tmp = self.thread_pool.map(self.rotation.rotate, images) print(out_tmp[0]) output[:] = np.array(out_tmp, dtype="f") # list -> np array... consumes twice as much memory else: for j in range(images.shape[0]): output[j] = self.rotation.rotate(images[j]) def rotate_images(self, suffix="_rot"): data_path = self.data_path fid = HDF5File(self.input_file, "r") fid_out = HDF5File(self.output_file, "a") try: data_ptr = fid[data_path] n_images = data_ptr.shape[0] data_out_ptr = fid_out[data_path] # Delete virtual dataset in output file, create "data_rot" dataset del fid_out[data_path] fid_out[self._data_path_dirname].create_dataset( self._data_path_basename + suffix, shape=data_ptr.shape, dtype=data_ptr.dtype ) data_out_ptr = fid_out[data_path + suffix] # read by group of images to hide latency group_size = self.batch_size images_rot = np.zeros((group_size, data_ptr.shape[1], data_ptr.shape[2]), dtype="f") n_groups = ceil(n_images / group_size) for i in range(n_groups): self.logger.info("Processing radios group %d/%d" % (i + 1, n_groups)) i_min = i * group_size i_max = min((i + 1) * group_size, n_images) images = data_ptr[i_min:i_max, :, :].astype("f") self._rotate_stack(images, images_rot) data_out_ptr[i_min:i_max, :, :] = images_rot[: i_max - i_min, :, :].astype(data_ptr.dtype) finally: fid_out[self._data_path_dirname].move(posixpath.basename(data_path) + suffix, self._data_path_basename) fid_out[data_path].attrs["interpretation"] = "image" fid.close() fid_out.close() def rotate_cli(): args = parse_params_values( RotateRadiosConfig, parser_description="A command-line utility for performing a rotation on all the radios of a dataset.", ) logger = Logger("nabu_rotate", level=args["loglevel"], logfile="nabu_rotate.log") dataset_path = args["dataset"] h5_entry = args["entry"] output_file = args["output"] center = optional_tuple_of_floats_validator("", "", args["center"]) # pylint: disable=E1121 use_cuda = 
boolean_validator("", "", args["use_cuda"]) # pylint: disable=E1121 use_multiprocessing = boolean_validator("", "", args["use_multiprocessing"]) # pylint: disable=E1121 if path.exists(output_file): logger.fatal("Output file %s already exists, not overwriting it" % output_file) exit(1) h5rot = HDF5ImagesStackRotation( dataset_path, output_file, args["angle"], center=center, entry=h5_entry, logger=logger, batch_size=args["batchsize"], use_cuda=use_cuda, use_multiprocessing=use_multiprocessing, ) h5rot.rotate_images() if __name__ == "__main__": rotate_cli() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/app/shrink_dataset.py0000644000175000017500000000664600000000000017322 0ustar00pierrepierreimport os import posixpath from multiprocessing.pool import ThreadPool import numpy as np from silx.io.dictdump import dicttonx, nxtodict from ..misc.binning import binning as image_binning from ..io.utils import get_first_hdf5_entry from ..pipeline.config_validators import optional_tuple_of_floats_validator, optional_positive_integer_validator from .cli_configs import ShrinkConfig from .utils import parse_params_values def access_nested_dict(dict_, path, default=None): items = [s for s in path.split(posixpath.sep) if len(s) > 0] if len(items) == 1: return dict_.get(items[0], default) if items[0] not in dict_: return default return access_nested_dict(dict_[items[0]], posixpath.sep.join(items[1:])) def set_nested_dict_value(dict_, path, val): dirname, basename = posixpath.split(path) sub_dict = access_nested_dict(dict_, dirname) sub_dict[basename] = val def shrink_dataset(input_file, output_file, binning=None, subsampling=None, entry=None, n_threads=1): entry = entry or get_first_hdf5_entry(input_file) data_dict = nxtodict(input_file, path=entry, dereference_links=False) to_subsample = [ "control/data", "instrument/detector/count_time", "instrument/detector/data", "instrument/detector/image_key", "instrument/detector/image_key_control", "sample/rotation_angle", "sample/x_translation", "sample/y_translation", "sample/z_translation", ] detector_data = access_nested_dict(data_dict, "instrument/detector/data") if detector_data is None: raise ValueError("No data found in %s entry %s" % (input_file, entry)) if binning is not None: def _apply_binning(img_res_tuple): img, res = img_res_tuple res[:] = image_binning(img, binning) data_binned = np.zeros( (detector_data.shape[0], detector_data.shape[1] // binning[0], detector_data.shape[2] // binning[1]), detector_data.dtype, ) with ThreadPool(n_threads) as tp: tp.map(_apply_binning, zip(detector_data, data_binned)) detector_data = data_binned set_nested_dict_value(data_dict, "instrument/detector/data", data_binned) if subsampling is not None: for item_path in to_subsample: item_val = access_nested_dict(data_dict, item_path) if item_val is not None: set_nested_dict_value(data_dict, item_path, item_val[::subsampling]) dicttonx(data_dict, output_file, h5path=entry) def shrink_cli(): args = parse_params_values(ShrinkConfig, parser_description="Shrink a NX dataset") if not (os.path.isfile(args["input_file"])): print("No such file: %s" % args["input_file"]) exit(1) if os.path.isfile(args["output_file"]): print("Output file %s already exists, not overwriting it" % args["output_file"]) exit(1) binning = optional_tuple_of_floats_validator("", "binning", args["binning"]) # pylint: disable=E1121 if binning is not None: binning = tuple(map(int, binning)) subsampling = optional_positive_integer_validator("", 
"subsampling", args["subsampling"]) # pylint: disable=E1121 shrink_dataset( args["input_file"], args["output_file"], binning=binning, subsampling=subsampling, entry=args["entry"], n_threads=args["threads"], ) if __name__ == "__main__": shrink_cli() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/app/stitching.py0000644000175000017500000000142500000000000016301 0ustar00pierrepierrefrom nabu.utils import Progress from .cli_configs import StitchingConfig from ..pipeline.config import parse_nabu_config_file from nabu.stitching.z_stitching import z_stitching from nabu.stitching.config import dict_to_config_obj from .utils import parse_params_values def main(): args = parse_params_values( StitchingConfig, parser_description="Run stitching from a configuration file. Configuration can be obtain from `stitching-config`", ) conf_dict = parse_nabu_config_file(args["input_file"], allow_no_value=True) stitching_config = dict_to_config_obj(conf_dict) progress = Progress("z-stitching") progress.set_name("initialize z-stitching") progress.set_advancement(0) z_stitching(stitching_config, progress=progress) exit(0) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/app/utils.py0000644000175000017500000000173000000000000015444 0ustar00pierrepierrefrom argparse import ArgumentParser def parse_params_values(Params, parser_description=None, program_version=None, user_args=None): parser = ArgumentParser(description=parser_description) for param_name, vals in Params.items(): if param_name[0] != "-": # It would be better to use "required" and not to pop it. # required is an accepted keyword for argparse optional = not (vals.pop("mandatory", False)) if optional: param_name = "--" + param_name parser.add_argument(param_name, **vals) if program_version is not None: parser.add_argument("--version", "-V", action="version", version=program_version) args = parser.parse_args(args=user_args) args_dict = args.__dict__ return args_dict def parse_sections(sections): sections = sections.lower() if sections == "all": return None sections = sections.replace(" ", "").split(",") return sections ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/app/validator.py0000644000175000017500000000644000000000000016274 0ustar00pierrepierre#!/usr/bin/env python # -*- coding: utf-8 -*- import argparse import sys import os import h5py import tomoscan.validator from tomoscan.esrf.scan.hdf5scan import HDF5TomoScan from tomoscan.esrf.scan.edfscan import EDFTomoScan def get_scans(path, entries: str): path = os.path.abspath(path) res = [] if EDFTomoScan.is_tomoscan_dir(path): res.append(EDFTomoScan(scan=path)) elif HDF5TomoScan.is_tomoscan_dir(path): if entries == "__all__": entries = HDF5TomoScan.get_valid_entries(path) for entry in entries: res.append(HDF5TomoScan(path, entry)) else: raise TypeError(f"{path} does not looks like a folder containing .EDF or a valid nexus file ") return res def main(): argv = sys.argv parser = argparse.ArgumentParser(description="Check if provided scan(s) seems valid to be reconstructed.") parser.add_argument("path", help="Data to validate (h5 file, edf folder)") parser.add_argument("entries", help="Entries to be validated (in the case of a h5 file)", nargs="*") parser.add_argument( "--ignore-dark", help="Do not check for dark", default=True, action="store_false", dest="check_dark", ) parser.add_argument( "--ignore-flat", help="Do 
not check for flat", default=True, action="store_false", dest="check_flat", ) parser.add_argument( "--no-phase-retrieval", help="Check scan energy, distance and pixel size", dest="check_phase_retrieval", default=True, action="store_false", ) parser.add_argument( "--check-nan", help="Check frames if contains any nan.", dest="check_nan", default=False, action="store_true", ) parser.add_argument( "--skip-links-check", "--no-link-check", help="Check frames dataset if have some broken links.", dest="check_vds", default=True, action="store_false", ) parser.add_argument( "--all-entries", help="Check all entries of the files (for HDF5 only for now)", default=False, action="store_true", ) parser.add_argument( "--extend", help="By default it only display items with issues. Extend will display them all", dest="only_issues", default=True, action="store_false", ) options = parser.parse_args(argv[1:]) if options.all_entries is True: entries = "__all__" else: if len(options.entries) == 0 and h5py.is_hdf5(options.path): entries = "__all__" else: entries = options.entries scans = get_scans(path=options.path, entries=entries) if len(scans) == 0: raise ValueError(f"No scan found from file:{options.path}, entries:{options.entries}") for scan in scans: validator = tomoscan.validator.ReconstructionValidator( scan=scan, check_phase_retrieval=options.check_phase_retrieval, check_values=options.check_nan, check_vds=options.check_vds, check_dark=options.check_dark, check_flat=options.check_flat, ) sys.stdout.write(validator.checkup(only_issues=options.only_issues)) if __name__ == "__main__": main() ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4567332 nabu-2023.1.1/nabu/cuda/0000755000175000017500000000000000000000000014065 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1622022789.0 nabu-2023.1.1/nabu/cuda/__init__.py0000644000175000017500000000000000000000000016164 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/cuda/convolution.py0000644000175000017500000003573600000000000017034 0ustar00pierrepierreimport numpy as np from os.path import dirname from silx.image.utils import gaussian_kernel from ..utils import updiv, get_cuda_srcfile from ..cuda.utils import __has_pycuda__ from ..misc.utils import ConvolutionInfos from .processing import CudaProcessing if __has_pycuda__: import pycuda.gpuarray as garray from pycuda.compiler import SourceModule class Convolution: """ A class for performing convolution on GPU with CUDA, but without using textures (unlike for example in ``silx.opencl.convolution``) """ def __init__(self, shape, kernel, axes=None, mode=None, extra_options=None, cuda_options=None): """ Constructor of Cuda Convolution. Parameters ----------- shape: tuple Shape of the array. kernel: array-like Convolution kernel (1D, 2D or 3D). axes: tuple, optional Axes along which the convolution is performed, for batched convolutions. mode: str, optional Boundary handling mode. Available modes are: - "reflect": cba|abcd|dcb - "nearest": aaa|abcd|ddd - "wrap": bcd|abcd|abc - "constant": 000|abcd|000 Default is "reflect". extra_options: dict, optional Advanced options (dict). 
Current options are: - "allocate_input_array": True - "allocate_output_array": True - "allocate_tmp_array": True - "sourcemodule_kwargs": {} - "batch_along_flat_dims": True """ self.cuda = CudaProcessing(**(cuda_options or {})) self._configure_extra_options(extra_options) self._determine_use_case(shape, kernel, axes) self._allocate_memory(mode) self._init_kernels() def _configure_extra_options(self, extra_options): self.extra_options = { "allocate_input_array": True, "allocate_output_array": True, "allocate_tmp_array": True, "sourcemodule_kwargs": {}, "batch_along_flat_dims": True, } extra_opts = extra_options or {} self.extra_options.update(extra_opts) self.sourcemodule_kwargs = self.extra_options["sourcemodule_kwargs"] def _get_dimensions(self, shape, kernel): self.shape = shape self.data_ndim = self._check_dimensions(shape=shape, name="Data") self.kernel_ndim = self._check_dimensions(arr=kernel, name="Kernel") Nx = shape[-1] if self.data_ndim >= 2: Ny = shape[-2] else: Ny = 1 if self.data_ndim >= 3: Nz = shape[-3] else: Nz = 1 self.Nx = np.int32(Nx) self.Ny = np.int32(Ny) self.Nz = np.int32(Nz) def _determine_use_case(self, shape, kernel, axes): """ Determine the convolution use case from the input/kernel shape, and axes. """ self._get_dimensions(shape, kernel) if self.kernel_ndim > self.data_ndim: raise ValueError("Kernel dimensions cannot exceed data dimensions") data_ndim = self.data_ndim kernel_ndim = self.kernel_ndim self.kernel = kernel.astype("f") convol_infos = ConvolutionInfos() k = (data_ndim, kernel_ndim) if k not in convol_infos.use_cases: raise ValueError( "Cannot find a use case for data ndim = %d and kernel ndim = %d" % (data_ndim, kernel_ndim) ) possible_use_cases = convol_infos.use_cases[k] # If some dimensions are "flat", make a batched convolution along them # Ex. data_dim = (1, Nx) -> batched 1D convolution if self.extra_options["batch_along_flat_dims"] and (1 in self.shape): axes = tuple([curr_dim for numels, curr_dim in zip(self.shape, range(len(self.shape))) if numels != 1]) # self.use_case_name = None for uc_name, uc_params in possible_use_cases.items(): if axes in convol_infos.allowed_axes[uc_name]: self.use_case_name = uc_name self.use_case_desc = uc_params["name"] self.use_case_kernels = uc_params["kernels"].copy() if self.use_case_name is None: raise ValueError( "Cannot find a use case for data ndim = %d, kernel ndim = %d and axes=%s" % (data_ndim, kernel_ndim, str(axes)) ) # TODO implement this use case if self.use_case_name == "batched_separable_2D_1D_3D": raise NotImplementedError("The use case %s is not implemented" % self.use_case_name) # self.axes = axes # Replace "axes=None" with an actual value (except for ND-ND) allowed_axes = convol_infos.allowed_axes[self.use_case_name] if len(allowed_axes) > 1: # The default choice might impact perfs self.axes = allowed_axes[0] or allowed_axes[1] self.separable = self.use_case_name.startswith("separable") self.batched = self.use_case_name.startswith("batched") def _allocate_memory(self, mode): self.mode = mode or "reflect" # The current implementation does not support kernel size bigger than data size, # except for mode="nearest" for i, dim_size in enumerate(self.shape): if min(self.kernel.shape) > dim_size and i in self.axes: print( "Warning: kernel support is too large for data dimension %d (%d). 
Forcing convolution mode to 'nearest'" % (i, dim_size) ) self.mode = "nearest" # option_array_names = { "allocate_input_array": "data_in", "allocate_output_array": "data_out", "allocate_tmp_array": "data_tmp", } # Nonseparable transforms do not need tmp array if not (self.separable): self.extra_options["allocate_tmp_array"] = False # Allocate arrays for option_name, array_name in option_array_names.items(): if self.extra_options[option_name]: value = garray.zeros(self.shape, np.float32) else: value = None setattr(self, array_name, value) if isinstance(self.kernel, np.ndarray): self.d_kernel = garray.to_gpu(self.kernel) else: if not (isinstance(self.kernel, garray.GPUArray)): raise ValueError("kernel must be either numpy array or pycuda array") self.d_kernel = self.kernel self._old_input_ref = None self._old_output_ref = None self._c_modes_mapping = { "periodic": 2, "wrap": 2, "nearest": 1, "replicate": 1, "reflect": 0, "constant": 3, } mp = self._c_modes_mapping if self.mode.lower() not in mp: raise ValueError( """ Mode %s is not available. Available modes are: %s """ % (self.mode, str(mp.keys())) ) if self.mode.lower() == "constant": raise NotImplementedError("mode='constant' is not implemented yet") self._c_conv_mode = mp[self.mode] def _init_kernels(self): if self.kernel_ndim > 1: if np.abs(np.diff(self.kernel.shape)).max() > 0: raise NotImplementedError("Non-separable convolution with non-square kernels is not implemented yet") # Compile source module compile_options = [str("-DUSED_CONV_MODE=%d" % self._c_conv_mode)] fname = get_cuda_srcfile("convolution.cu") nabu_cuda_dir = dirname(fname) include_dirs = [nabu_cuda_dir] self.sourcemodule_kwargs["options"] = compile_options self.sourcemodule_kwargs["include_dirs"] = include_dirs with open(fname) as fid: cuda_src = fid.read() self._module = SourceModule(cuda_src, **self.sourcemodule_kwargs) # Blocks, grid self._block_size = {1: (32, 1, 1), 2: (32, 32, 1), 3: (16, 8, 8)}[self.data_ndim] # TODO tune self._n_blocks = tuple([int(updiv(a, b)) for a, b in zip(self.shape[::-1], self._block_size)]) # Prepare cuda kernel calls self._cudakernel_signature = { 1: "PPPiiii", 2: "PPPiiiii", 3: "PPPiiiiii", }[self.kernel_ndim] self.cuda_kernels = {} for axis, kern_name in enumerate(self.use_case_kernels): self.cuda_kernels[axis] = self._module.get_function(kern_name) self.cuda_kernels[axis].prepare(self._cudakernel_signature) # Cuda kernel arguments kernel_args = [ self._n_blocks, self._block_size, None, None, self.d_kernel.gpudata, np.int32(self.kernel.shape[0]), self.Nx, self.Ny, self.Nz, ] if self.kernel_ndim == 2: kernel_args.insert(5, np.int32(self.kernel.shape[1])) if self.kernel_ndim == 3: kernel_args.insert(5, np.int32(self.kernel.shape[2])) kernel_args.insert(6, np.int32(self.kernel.shape[1])) self.kernel_args = tuple(kernel_args) # If self.data_tmp is allocated, separable transforms can be performed # by a series of batched transforms, without any copy, by swapping refs. self.swap_pattern = None if self.separable: if self.data_tmp is not None: self.swap_pattern = { 2: [("data_in", "data_tmp"), ("data_tmp", "data_out")], 3: [ ("data_in", "data_out"), ("data_out", "data_tmp"), ("data_tmp", "data_out"), ], } else: raise NotImplementedError("For now, data_tmp has to be allocated") def _get_swapped_arrays(self, i): """ Get the input and output arrays to use when using a "swap pattern". Swapping refs enables to avoid copies between temp. array and output. 
For example, a separable 2D->1D convolution on 2D data reads: data_tmp = convol(data_input, kernel, axis=1) # step i=0 data_out = convol(data_tmp, kernel, axis=0) # step i=1 :param i: current step number of the separable convolution """ n_batchs = len(self.axes) in_ref, out_ref = self.swap_pattern[n_batchs][i] d_in = getattr(self, in_ref) d_out = getattr(self, out_ref) return d_in, d_out def _configure_kernel_args(self, cuda_kernel_args, input_ref, output_ref): # TODO more elegant if isinstance(input_ref, garray.GPUArray): input_ref = input_ref.gpudata if isinstance(output_ref, garray.GPUArray): output_ref = output_ref.gpudata if input_ref is not None or output_ref is not None: cuda_kernel_args = list(cuda_kernel_args) if input_ref is not None: cuda_kernel_args[2] = input_ref if output_ref is not None: cuda_kernel_args[3] = output_ref cuda_kernel_args = tuple(cuda_kernel_args) return cuda_kernel_args @staticmethod def _check_dimensions(arr=None, shape=None, name="", dim_min=1, dim_max=3): if shape is not None: ndim = len(shape) elif arr is not None: ndim = arr.ndim else: raise ValueError("Please provide either arr= or shape=") if ndim < dim_min or ndim > dim_max: raise ValueError("%s dimensions should be between %d and %d" % (name, dim_min, dim_max)) return ndim def _check_array(self, arr): if not (isinstance(arr, garray.GPUArray) or isinstance(arr, np.ndarray)): raise TypeError("Expected either pycuda.gpuarray or numpy.ndarray") if arr.dtype != np.float32: raise TypeError("Data must be float32") if arr.shape != self.shape: raise ValueError("Expected data shape = %s" % str(self.shape)) def _set_arrays(self, array, output=None): # Either copy H->D or update references. if isinstance(array, np.ndarray): self.data_in[:] = array[:] else: self._old_input_ref = self.data_in self.data_in = array data_in_ref = self.data_in if output is not None: if not (isinstance(output, np.ndarray)): self._old_output_ref = self.data_out self.data_out = output # Update Cuda kernel arguments with new array references self.kernel_args = self._configure_kernel_args(self.kernel_args, data_in_ref, self.data_out) def _separable_convolution(self): assert len(self.axes) == len(self.use_case_kernels) # Separable: one kernel call per data dimension for i, axis in enumerate(self.axes): in_ref, out_ref = self._get_swapped_arrays(i) self._batched_convolution(axis, input_ref=in_ref, output_ref=out_ref) def _batched_convolution(self, axis, input_ref=None, output_ref=None): # Batched: one kernel call in total cuda_kernel = self.cuda_kernels[axis] cuda_kernel_args = self._configure_kernel_args(self.kernel_args, input_ref, output_ref) ev = cuda_kernel.prepared_call(*cuda_kernel_args) def _nd_convolution(self): assert len(self.use_case_kernels) == 1 cuda_kernel = self._module.get_function(self.use_case_kernels[0]) ev = cuda_kernel.prepared_call(*self.kernel_args) def _recover_arrays_references(self): if self._old_input_ref is not None: self.data_in = self._old_input_ref self._old_input_ref = None if self._old_output_ref is not None: self.data_out = self._old_output_ref self._old_output_ref = None self.kernel_args = self._configure_kernel_args(self.kernel_args, self.data_in, self.data_out) def _get_output(self, output): if output is None: res = self.data_out.get() else: res = output if isinstance(output, np.ndarray): output[:] = self.data_out[:] self._recover_arrays_references() return res def convolve(self, array, output=None): """ Convolve an array with the class kernel. :param array: Input array. 
Can be numpy.ndarray or pycuda.gpuarray.GPUArray. :param output: Output array. Can be numpy.ndarray or pycuda.gpuarray.GPUArray. """ self._check_array(array) self._set_arrays(array, output=output) if self.axes is not None: if self.separable: self._separable_convolution() elif self.batched: assert len(self.axes) == 1 self._batched_convolution(self.axes[0]) # else: ND-ND convol else: # ND-ND convol self._nd_convolution() res = self._get_output(output) return res __call__ = convolve ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/cuda/kernel.py0000644000175000017500000001463000000000000015723 0ustar00pierrepierrefrom ..utils import updiv import pycuda.gpuarray as garray from pycuda.compiler import SourceModule class CudaKernel: """ Helper class that wraps CUDA kernel through pycuda SourceModule. Parameters ----------- kernel_name: str Name of the CUDA kernel. filename: str, optional Path to the file name containing kernels definitions src: str, optional Source code of kernels definitions signature: str, optional Signature of kernel function. If provided, pycuda will not guess the types of kernel arguments, making the calls slightly faster. For example, a function acting on two pointers, an integer and a float32 has the signature "PPif". texrefs: list, optional List of texture references, if any automation_params: dict, optional Automation parameters, see below sourcemodule_kwargs: optional Extra arguments to provide to pycuda.compiler.SourceModule(), Automation parameters ---------------------- automation_params is a dictionary with the following keys and default values. guess_block: bool (True) If block is not specified during calls, choose a block size based on the size/dimensions of the first array. Mind that it is unlikely to be the optimal choice. guess_grid: bool (True): If the grid size is not specified during calls, choose a grid size based on the size of the first array. follow_gpuarr_ptr: bool (True) specify gpuarray.gpudata for all GPUArrays. Otherwise, raise an error. 
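    Example
    -------
    A minimal usage sketch, not taken from the original code: the kernel name
    "scale_array", the file "my_kernels.cu" and the array size below are
    purely illustrative assumptions::

        import numpy as np
        import pycuda.gpuarray as garray
        # assumes my_kernels.cu defines:
        #   __global__ void scale_array(float* arr, int n, float factor)
        kern = CudaKernel("scale_array", filename="my_kernels.cu", signature="Pif")
        d_arr = garray.to_gpu(np.ones(1024, dtype=np.float32))
        kern(d_arr, np.int32(d_arr.size), np.float32(2.0), block=(128, 1, 1), grid=(8, 1, 1))
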
""" def __init__( self, kernel_name, filename=None, src=None, signature=None, texrefs=[], automation_params=None, **sourcemodule_kwargs, ): self.check_filename_src(filename, src) self.set_automation_params(automation_params) self.compile_kernel_source(kernel_name, sourcemodule_kwargs) self.prepare(signature, texrefs) def check_filename_src(self, filename, src): err_msg = "Please provide either filename or src" if filename is None and src is None: raise ValueError(err_msg) if filename is not None and src is not None: raise ValueError(err_msg) if filename is not None: with open(filename) as fid: src = fid.read() self.filename = filename self.src = src def set_automation_params(self, automation_params): self.automation_params = { "guess_block": True, "guess_grid": True, "follow_gpuarr_ptr": True, } automation_params = automation_params or {} self.automation_params.update(automation_params) def compile_kernel_source(self, kernel_name, sourcemodule_kwargs): self.sourcemodule_kwargs = sourcemodule_kwargs self.kernel_name = kernel_name self.module = SourceModule(self.src, **self.sourcemodule_kwargs) self.func = self.module.get_function(kernel_name) def prepare(self, kernel_signature, texrefs): self.prepared = False self.kernel_signature = kernel_signature self.texrefs = texrefs if kernel_signature is not None: self.func.prepare(self.kernel_signature, texrefs=texrefs) self.prepared = True @staticmethod def guess_grid_size(shape, block_size): # python: (z, y, x) -> cuda: (x, y, z) res = tuple(map(lambda x: updiv(x[0], x[1]), zip(shape[::-1], block_size))) if len(res) == 2: res += (1,) return res @staticmethod def guess_block_size(shape): """ Guess a block size based on the shape of an array. """ ndim = len(shape) if ndim == 1: return (128, 1, 1) if ndim == 2: return (32, 32, 1) else: return (16, 8, 8) def get_block_grid(self, *args, **kwargs): block = None grid = None if ("block" not in kwargs) or (kwargs["block"] is None): if self.automation_params["guess_block"]: block = self.guess_block_size(args[0].shape) else: raise ValueError("Please provide block size") else: block = kwargs["block"] if ("grid" not in kwargs) or (kwargs["grid"] is None): if self.automation_params["guess_grid"]: grid = self.guess_grid_size(args[0].shape, block) else: raise ValueError("Please provide block grid") else: grid = kwargs["grid"] self.last_block_size = block self.last_grid_size = grid return block, grid def follow_gpu_arr(self, args): args = list(args) # Replace GPUArray with GPUArray.gpudata for i, arg in enumerate(args): if isinstance(arg, garray.GPUArray): args[i] = arg.gpudata return tuple(args) def get_last_kernel_time(self): """ Return the execution time (in seconds) of the last called kernel. The last called kernel should have been called with time_kernel=True. """ if self.last_kernel_time is not None: return self.last_kernel_time() else: return None def call(self, *args, **kwargs): block, grid = self.get_block_grid(*args, **kwargs) # pycuda crashes when any element of block/grid is not a python int (ex. numpy.int64). # A weird behavior once observed is "data.shape" returning (np.int64, int, int) (!). # Ensure that everything is a python integer. 
block = tuple(int(x) for x in block) grid = tuple(int(x) for x in grid) # args = self.follow_gpu_arr(args) if self.prepared: func_call = self.func.prepared_call if "time_kernel" in kwargs: func_call = self.func.prepared_timed_call kwargs.pop("time_kernel") if "block" in kwargs: kwargs.pop("block") if "grid" in kwargs: kwargs.pop("grid") t = func_call(grid, block, *args, **kwargs) else: kwargs["block"] = block kwargs["grid"] = grid t = self.func(*args, **kwargs) # ~ return t # TODO return event like in OpenCL ? self.last_kernel_time = t # list ? __call__ = call ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/cuda/medfilt.py0000644000175000017500000001323300000000000016065 0ustar00pierrepierreimport numpy as np from os.path import dirname from silx.image.utils import gaussian_kernel from ..utils import updiv, get_cuda_srcfile from .processing import CudaProcessing import pycuda.gpuarray as garray from pycuda.compiler import SourceModule class MedianFilter: """ A class for performing median filter on GPU with CUDA """ def __init__( self, shape, footprint=(3, 3), mode="reflect", threshold=None, cuda_options=None, abs_diff=False, ): """Constructor of Cuda Median Filter. Parameters ---------- shape: tuple Shape of the array, in the format (n_rows, n_columns) footprint: tuple Size of the median filter, in the format (y, x). mode: str Boundary handling mode. Available modes are: - "reflect": cba|abcd|dcb - "nearest": aaa|abcd|ddd - "wrap": bcd|abcd|abc - "constant": 000|abcd|000 Default is "reflect". threshold: float, optional Threshold for the "thresholded median filter". A thresholded median filter only replaces a pixel value by the median if this pixel value is greater or equal than median + threshold. abs_diff: bool, optional Whether to perform conditional threshold as abs(value - median) Notes ------ Please refer to the documentation of the CudaProcessing class for the other parameters. """ self.cuda_processing = CudaProcessing(**(cuda_options or {})) self._set_params(shape, footprint, mode, threshold, abs_diff) self.cuda_processing.init_arrays_to_none(["d_input", "d_output"]) self._init_kernels() def _set_params(self, shape, footprint, mode, threshold, abs_diff): self.data_ndim = len(shape) if self.data_ndim == 2: ny, nx = shape nz = 1 elif self.data_ndim == 3: nz, ny, nx = shape else: raise ValueError("Expected 2D or 3D data") self.shape = shape self.Nx = np.int32(nx) self.Ny = np.int32(ny) self.Nz = np.int32(nz) if len(footprint) != 2: raise ValueError("3D median filter is not implemented yet") if not ((footprint[0] & 1) and (footprint[1] & 1)): raise ValueError("Must have odd-sized footprint") self.footprint = footprint self._set_boundary_mode(mode) self.do_threshold = False self.abs_diff = abs_diff if threshold is not None: self.threshold = np.float32(threshold) self.do_threshold = True else: self.threshold = np.float32(0) def _set_boundary_mode(self, mode): self.mode = mode # Some code duplication from convolution self._c_modes_mapping = { "periodic": 2, "wrap": 2, "nearest": 1, "replicate": 1, "reflect": 0, "constant": 3, } mp = self._c_modes_mapping if self.mode.lower() not in mp: raise ValueError( """ Mode %s is not available. 
Available modes are: %s """ % (self.mode, str(mp.keys())) ) if self.mode.lower() == "constant": raise NotImplementedError("mode='constant' is not implemented yet") self._c_conv_mode = mp[self.mode] def _init_kernels(self): # Compile source module compile_options = [ "-DUSED_CONV_MODE=%d" % self._c_conv_mode, "-DMEDFILT_X=%d" % self.footprint[1], "-DMEDFILT_Y=%d" % self.footprint[0], "-DDO_THRESHOLD=%d" % (int(self.do_threshold) + int(self.abs_diff)), ] fname = get_cuda_srcfile("medfilt.cu") nabu_cuda_dir = dirname(fname) include_dirs = [nabu_cuda_dir] self.sourcemodule_kwargs = {} self.sourcemodule_kwargs["options"] = compile_options self.sourcemodule_kwargs["include_dirs"] = include_dirs with open(fname) as fid: cuda_src = fid.read() self._module = SourceModule(cuda_src, **self.sourcemodule_kwargs) self.cuda_kernel_2d = self._module.get_function("medfilt2d") # Blocks, grid self._block_size = {2: (32, 32, 1), 3: (16, 8, 8)}[self.data_ndim] # TODO tune self._n_blocks = tuple([updiv(a, b) for a, b in zip(self.shape[::-1], self._block_size)]) def medfilt2(self, image, output=None): """ Perform a median filter on an image (or batch of images). Parameters ----------- images: numpy.ndarray or pycuda.gpuarray 2D image or 3D stack of 2D images output: numpy.ndarray or pycuda.gpuarray, optional Output of filtering. If provided, it must have the same shape as the input array. """ self.cuda_processing.set_array("d_input", image, self.shape) if output is not None: self.cuda_processing.set_array("d_output", output, self.shape) else: self.cuda_processing.allocate_array("d_output", self.shape) self.cuda_kernel_2d( self.cuda_processing.d_input, self.cuda_processing.d_output, self.Nx, self.Ny, self.Nz, self.threshold, grid=self._n_blocks, block=self._block_size, ) self.cuda_processing.recover_arrays_references(["d_input", "d_output"]) if output is None: return self.cuda_processing.d_output.get() else: return output ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/cuda/padding.py0000644000175000017500000001126000000000000016045 0ustar00pierrepierreimport numpy as np from ..utils import get_cuda_srcfile, updiv, check_supported from .kernel import CudaKernel from .processing import CudaProcessing import pycuda.gpuarray as garray class CudaPadding: """ A class for performing padding on GPU """ supported_modes = ["constant", "edge", "reflect", "symmetric", "wrap"] def __init__(self, shape, pad_width, mode="constant", cuda_options=None, **kwargs): """ Initialize a CudaPadding object. Parameters ---------- shape: tuple Image shape pad_width: tuple Padding width for each axis. Please see the documentation of numpy.pad(). It can also be a tuple of two numpy arrays for generic coordinate transform. mode: str Padding mode Other parameters ---------------- constant_values: tuple Tuple containing the values to fill when mode="constant". 
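    Example
    -------
    An illustrative sketch, not part of the original docstring; the image
    shape and padding widths below are arbitrary assumptions::

        import numpy as np
        import pycuda.gpuarray as garray
        padder = CudaPadding((512, 512), ((16, 16), (32, 32)), mode="reflect")
        d_img = garray.to_gpu(np.ones((512, 512), dtype=np.float32))
        d_padded = padder.pad(d_img)  # GPU array of shape (544, 576)
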
""" if len(shape) != 2: raise ValueError("This class only works on images") self.shape = shape self._set_mode(mode, **kwargs) self.cuda_processing = CudaProcessing(**(cuda_options or {})) self._get_padding_arrays(pad_width) self._init_cuda_coordinate_transform() def _set_mode(self, mode, **kwargs): if mode == "edges": mode = "edge" check_supported(mode, self.supported_modes, "padding mode") self.mode = mode self._kwargs = kwargs def _get_padding_arrays(self, pad_width): self.pad_width = pad_width if isinstance(pad_width, tuple) and isinstance(pad_width[0], np.ndarray): # user-defined coordinate transform if len(pad_width) != 2: raise ValueError( "pad_width must be either a scalar, a tuple in the form ((a, b), (c, d)), or a tuple of two numpy arrays" ) if self.mode == "constant": raise ValueError("Custom coordinate transform does not work with mode='constant'") self.mode = "custom" self.coords_rows, self.coords_cols = pad_width else: if self.mode == "constant": # no need for coordinate transform here constant_values = self._kwargs.get("constant_values", 0) self.padded_array_constant = np.pad( np.zeros(self.shape, dtype="f"), self.pad_width, mode="constant", constant_values=constant_values ) self.padded_shape = self.padded_array_constant.shape return R, C = np.indices(self.shape, dtype=np.int32) self.coords_rows = np.pad(R, self.pad_width, mode=self.mode) self.coords_cols = np.pad(C, self.pad_width, mode=self.mode) self.padded_shape = self.coords_rows.shape def _init_cuda_coordinate_transform(self): if self.mode == "constant": self.d_padded_array_constant = garray.to_gpu(self.padded_array_constant) return self._coords_transform_kernel = CudaKernel( "coordinate_transform", filename=get_cuda_srcfile("padding.cu"), signature="PPPPiii", ) self._coords_transform_block = (32, 32, 1) self._coords_transform_grid = [ updiv(a, b) for a, b in zip(self.padded_shape[::-1], self._coords_transform_block) ] self.d_coords_rows = garray.to_gpu(self.coords_rows) self.d_coords_cols = garray.to_gpu(self.coords_cols) def _pad_constant(self, image, output): pad_y, pad_x = self.pad_width self.d_padded_array_constant[pad_y[0] : pad_y[0] + self.shape[0], pad_x[0] : pad_x[0] + self.shape[1]] = image[ : ] output[:] = self.d_padded_array_constant[:] return output def pad(self, image, output=None): """ Pad an array. Parameters ---------- image: pycuda.gpuarray.GPUArray Image to pad output: pycuda.gpuarray.GPUArray, optional Output image. If provided, must be in the expected shape. """ if output is None: output = self.cuda_processing.allocate_array("d_output", self.padded_shape) if self.mode == "constant": return self._pad_constant(image, output) self._coords_transform_kernel( image, output, self.d_coords_cols, self.d_coords_rows, np.int32(self.shape[1]), np.int32(self.padded_shape[1]), np.int32(self.padded_shape[0]), grid=self._coords_transform_grid, block=self._coords_transform_block, ) return output ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/cuda/processing.py0000644000175000017500000001257400000000000016624 0ustar00pierrepierreimport numpy as np from .utils import get_cuda_context, __has_pycuda__ if __has_pycuda__: import pycuda.driver as cuda import pycuda.gpuarray as garray dev_attrs = cuda.device_attribute # NB: we must detach from a context before creating another context class CudaProcessing: def __init__(self, device_id=None, ctx=None, stream=None, cleanup_at_exit=True): """ Initialie a CudaProcessing instance. 
CudaProcessing is a base class for all CUDA-based processings. This class provides utilities for context/device management, and arrays allocation. Parameters ---------- device_id: int, optional ID of the cuda device to use (those of the `nvidia-smi` command). Ignored if ctx is not None. ctx: pycuda.driver.Context, optional Existing context to use. If provided, do not create a new context. stream: pycudacuda.driver.Stream, optional Cuda stream. If not provided, will use the default stream cleanup_at_exit: bool, optional Whether to clean-up the context at exit. Ignored if ctx is not None. """ if ctx is None: self.ctx = get_cuda_context(device_id=device_id, cleanup_at_exit=cleanup_at_exit) else: self.ctx = ctx self.stream = stream self.device = self.ctx.get_device() self.device_name = self.device.name() self.device_id = self.device.get_attribute(dev_attrs.MULTI_GPU_BOARD_GROUP_ID) self._allocated = {} def push_context(self): self.ctx.push() return self.ctx def pop_context(self): self.ctx.pop() def init_arrays_to_none(self, arrays_names): """ Initialize arrays to None. After calling this method, the current instance will have self.array_name = None, and self._old_array_name = None. Parameters ---------- arrays_names: list of str List of arrays names. """ for array_name in arrays_names: setattr(self, array_name, None) setattr(self, "_old_" + array_name, None) self._allocated[array_name] = False def recover_arrays_references(self, arrays_names): """ Performs self._array_name = self._old_array_name, for each array_name in arrays_names. Parameters ---------- arrays_names: list of str List of array names """ for array_name in arrays_names: old_arr = getattr(self, "_old_" + array_name, None) if old_arr is not None: setattr(self, array_name, old_arr) def allocate_array(self, array_name, shape, dtype=np.float32): """ Allocate a GPU array on the current context/stream/device, and set 'self.array_name' to this array. Parameters ---------- array_name: str Name of the array (for book-keeping) shape: tuple of int Array shape dtype: numpy.dtype, optional Data type. Default is float32. """ if not self._allocated.get(array_name, False): new_gpuarr = garray.zeros(shape, dtype=dtype) setattr(self, array_name, new_gpuarr) self._allocated[array_name] = True return getattr(self, array_name) def set_array(self, array_name, array_ref, shape, dtype=np.float32): """ Set the content of a device array. Parameters ---------- array_name: str Array name. This method will look for self.array_name. array_ref: array (numpy or GPU array) Array containing the data to copy to 'array_name'. shape: tuple of int Array shape dtype: numpy.dtype, optional Data type. Default is float32. """ if isinstance(array_ref, garray.GPUArray): current_arr = getattr(self, array_name, None) setattr(self, "_old_" + array_name, current_arr) setattr(self, array_name, array_ref) elif isinstance(array_ref, np.ndarray): self.allocate_array(array_name, shape, dtype=dtype) getattr(self, array_name).set(array_ref) else: raise ValueError("Expected numpy array or pycuda array") def get_array(self, array_name): return getattr(self, array_name, None) # COMPAT. 
_init_arrays_to_none = init_arrays_to_none _recover_arrays_references = recover_arrays_references _allocate_array = allocate_array _set_array = set_array @staticmethod def check_array(arr, expected_shape, expected_dtype="f", check_contiguous=True): """ Check whether a given array is suitable for being processed (shape, dtype, contiguous) """ if arr.shape != expected_shape: raise ValueError("Expected shape %s but got %s" % (str(expected_shape), str(arr.shape))) if arr.dtype != np.dtype(expected_dtype): raise ValueError("Expected data type %s but got %s" % (str(expected_dtype), str(arr.dtype))) if check_contiguous: if isinstance(arr, np.ndarray) and not (arr.flags["C_CONTIGUOUS"]): raise ValueError("Expected C-contiguous array") if isinstance(arr, garray.GPUArray) and not arr.flags.c_contiguous: raise ValueError("Expected C-contiguous array") ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4567332 nabu-2023.1.1/nabu/cuda/src/0000755000175000017500000000000000000000000014654 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/cuda/src/ElementOp.cu0000644000175000017500000001124200000000000017075 0ustar00pierrepierre#include typedef pycuda::complex complex; // Generic operations #define OP_ADD 0 #define OP_SUB 1 #define OP_MUL 2 #define OP_DIV 3 // #ifndef GENERIC_OP #define GENERIC_OP OP_ADD #endif // arr2D *= arr1D (line by line, i.e along fast dim) __global__ void inplace_complex_mul_2Dby1D(complex* arr2D, complex* arr1D, int width, int height) { int x = blockDim.x * blockIdx.x + threadIdx.x; int y = blockDim.y * blockIdx.y + threadIdx.y; if ((x >= width) || (y >= height)) return; // This does not seem to work // Use cuCmulf of cuComplex.h ? //~ arr2D[y*width + x] *= arr1D[x]; int i = y*width + x; complex a = arr2D[i]; complex b = arr1D[x]; arr2D[i]._M_re = a._M_re * b._M_re - a._M_im * b._M_im; arr2D[i]._M_im = a._M_im * b._M_re + a._M_re * b._M_im; } __global__ void inplace_generic_op_2Dby2D(float* arr2D, float* arr2D_other, int width, int height) { int x = blockDim.x * blockIdx.x + threadIdx.x; int y = blockDim.y * blockIdx.y + threadIdx.y; if ((x >= width) || (y >= height)) return; int i = y*width + x; #if GENERIC_OP == OP_ADD arr2D[i] += arr2D_other[i]; #elif GENERIC_OP == OP_SUB arr2D[i] -= arr2D_other[i]; #elif GENERIC_OP == OP_MUL arr2D[i] *= arr2D_other[i]; #elif GENERIC_OP == OP_DIV arr2D[i] /= arr2D_other[i]; #endif } // arr3D *= arr1D (along fast dim) __global__ void inplace_complex_mul_3Dby1D(complex* arr3D, complex* arr1D, int width, int height, int depth) { int x = blockDim.x * blockIdx.x + threadIdx.x; int y = blockDim.y * blockIdx.y + threadIdx.y; int z = blockDim.z * blockIdx.z + threadIdx.z; if ((x >= width) || (y >= height) || (z >= depth)) return; // This does not seem to work // Use cuCmulf of cuComplex.h ? 
//~ arr3D[(z*height + y)*width + x] *= arr1D[x]; int i = (z*height + y)*width + x; complex a = arr3D[i]; complex b = arr1D[x]; arr3D[i]._M_re = a._M_re * b._M_re - a._M_im * b._M_im; arr3D[i]._M_im = a._M_im * b._M_re + a._M_re * b._M_im; } // arr2D *= arr2D __global__ void inplace_complex_mul_2Dby2D(complex* arr2D_out, complex* arr2D_other, int width, int height) { int x = blockDim.x * blockIdx.x + threadIdx.x; int y = blockDim.y * blockIdx.y + threadIdx.y; if ((x >= width) || (y >= height)) return; int i = y*width + x; complex a = arr2D_out[i]; complex b = arr2D_other[i]; arr2D_out[i]._M_re = a._M_re * b._M_re - a._M_im * b._M_im; arr2D_out[i]._M_im = a._M_im * b._M_re + a._M_re * b._M_im; } // arr2D *= arr2D __global__ void inplace_complexreal_mul_2Dby2D(complex* arr2D_out, float* arr2D_other, int width, int height) { int x = blockDim.x * blockIdx.x + threadIdx.x; int y = blockDim.y * blockIdx.y + threadIdx.y; if ((x >= width) || (y >= height)) return; int i = y*width + x; complex a = arr2D_out[i]; float b = arr2D_other[i]; arr2D_out[i]._M_re *= b; arr2D_out[i]._M_im *= b; } /* Kernel used for CTF phase retrieval img_f = img_f * filter_num img_f[0, 0] -= mean_scale_factor * filter_num[0,0] img_f = img_f * filter_denom where mean_scale_factor = Nx*Ny */ __global__ void CTF_kernel( complex* image, float* filter_num, float* filter_denom, float mean_scale_factor, int Nx, int Ny ) { uint x = blockDim.x * blockIdx.x + threadIdx.x; uint y = blockDim.y * blockIdx.y + threadIdx.y; if ((x >= Nx) || (y >= Ny)) return; uint idx = y*Nx + x; image[idx] *= filter_num[idx]; if (idx == 0) image[idx] -= mean_scale_factor; image[idx] *= filter_denom[idx]; } #ifndef DO_CLIP_MIN #define DO_CLIP_MIN 0 #endif #ifndef DO_CLIP_MAX #define DO_CLIP_MAX 0 #endif // arr = -log(arr) __global__ void nlog(float* array, int Nx, int Ny, int Nz, float clip_min, float clip_max) { size_t x = blockDim.x * blockIdx.x + threadIdx.x; size_t y = blockDim.y * blockIdx.y + threadIdx.y; size_t z = blockDim.z * blockIdx.z + threadIdx.z; if ((x >= Nx) || (y >= Ny) || (z >= Nz)) return; size_t pos = (z*Ny + y)*Nx + x; float val = array[pos]; #if DO_CLIP_MIN val = fmaxf(val, clip_min); #endif #if DO_CLIP_MAX val = fminf(val, clip_max); #endif array[pos] = -logf(val); } // Reverse elements of a 2D array along "x", i.e: // arr = arr[:, ::-1] // launched with grid (Nx/2, Ny) __global__ void reverse2D_x(float* array, int Nx, int Ny) { uint x = blockDim.x * blockIdx.x + threadIdx.x; uint y = blockDim.y * blockIdx.y + threadIdx.y; if ((x >= Nx/2) || (y >= Ny)) return; uint pos = y*Nx + x; uint pos2 = y*Nx + (Nx - 1 - x); float tmp = array[pos]; array[pos] = array[pos2]; array[pos2] = tmp; } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1641475283.0 nabu-2023.1.1/nabu/cuda/src/backproj.cu0000644000175000017500000001204200000000000016777 0ustar00pierrepierre#ifndef SHARED_SIZE #define SHARED_SIZE 256 #endif texture tex_projections; #ifdef CLIP_OUTER_CIRCLE inline __device__ int is_in_circle(float x, float y, float center_x, float center_y, int radius2) { return (((x - center_x)*(x - center_x) + (y - center_y)*(y - center_y)) <= radius2); } #endif /** Implementation details ----------------------- This implementation uses two pre-computed arrays in global memory: cos(theta) -> d_cos -sin(theta) -> d_msin As the backprojection is voxel-driven, each thread will, at some point, need cos(theta) and -sin(theta) for *all* theta. Thus, we need to pre-fetch d_cos and d_msin in the fastest cached memory. 
Here we use the shared memory (faster than constant memory and texture). Each thread group will pre-fetch values from d_cos and d_msin to shared memory Initially, we fetched as much values as possible, ending up in a block of 1024 threads (32, 32). However, it turns out that performances are best with (16, 16) blocks. **/ // Backproject one sinogram // One thread handles up to 4 pixels in the output slice __global__ void backproj( float* d_slice, int num_projs, int num_bins, float axis_position, int n_x, int n_y, int offset_x, int offset_y, float* d_cos, float* d_msin, #ifdef DO_AXIS_CORRECTION float* d_axis_corr, #endif float scale_factor ) { int x = offset_x + blockDim.x * blockIdx.x + threadIdx.x; int y = offset_y + blockDim.y * blockIdx.y + threadIdx.y; int Gx = blockDim.x * gridDim.x; int Gy = blockDim.y * gridDim.y; // (xr, yr) (xrp, yr) // (xr, yrp) (xrp, yrp) float xr = x - axis_position, yr = y - axis_position; float xrp = xr + Gx, yrp = yr + Gy; /*volatile*/ __shared__ float s_cos[SHARED_SIZE]; /*volatile*/ __shared__ float s_msin[SHARED_SIZE]; #ifdef DO_AXIS_CORRECTION /*volatile*/ __shared__ float s_axis[SHARED_SIZE]; float axcorr; #endif int next_fetch = 0; int tid = threadIdx.y * blockDim.x + threadIdx.x; float costheta, msintheta; float h1, h2, h3, h4; float sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f, sum4 = 0.0f; for (int proj = 0; proj < num_projs; proj++) { if (proj == next_fetch) { // Fetch SHARED_SIZE values to shared memory __syncthreads(); if (next_fetch + tid < num_projs) { s_cos[tid] = d_cos[next_fetch + tid]; s_msin[tid] = d_msin[next_fetch + tid]; #ifdef DO_AXIS_CORRECTION s_axis[tid] = d_axis_corr[next_fetch + tid]; #endif } next_fetch += SHARED_SIZE; __syncthreads(); } costheta = s_cos[proj - (next_fetch - SHARED_SIZE)]; msintheta = s_msin[proj - (next_fetch - SHARED_SIZE)]; #ifdef DO_AXIS_CORRECTION axcorr = s_axis[proj - (next_fetch - SHARED_SIZE)]; #endif float c1 = fmaf(costheta, xr, axis_position); // cos(theta)*xr + axis_pos float c2 = fmaf(costheta, xrp, axis_position); // cos(theta)*(xr + Gx) + axis_pos float s1 = fmaf(msintheta, yr, 0.0f); // -sin(theta)*yr float s2 = fmaf(msintheta, yrp, 0.0f); // -sin(theta)*(yr + Gy) h1 = c1 + s1; h2 = c2 + s1; h3 = c1 + s2; h4 = c2 + s2; #ifdef DO_AXIS_CORRECTION h1 += axcorr; h2 += axcorr; h3 += axcorr; h4 += axcorr; #endif if (h1 >= 0 && h1 < num_bins) sum1 += tex2D(tex_projections, h1 + 0.5f, proj + 0.5f); if (h2 >= 0 && h2 < num_bins) sum2 += tex2D(tex_projections, h2 + 0.5f, proj + 0.5f); if (h3 >= 0 && h3 < num_bins) sum3 += tex2D(tex_projections, h3 + 0.5f, proj + 0.5f); if (h4 >= 0 && h4 < num_bins) sum4 += tex2D(tex_projections, h4 + 0.5f, proj + 0.5f); } x -= offset_x; y -= offset_y; int write_topleft = 1, write_topright = 1, write_botleft = 1, write_botright = 1; #ifdef CLIP_OUTER_CIRCLE float center_x = (n_x - 1)/2.0f, center_y = (n_y - 1)/2.0f; int radius2 = min(n_x/2, n_y/2); radius2 *= radius2; write_topleft = is_in_circle(x, y, center_x, center_y, radius2); write_topright = is_in_circle(x + Gx, y, center_x, center_y, radius2); write_botleft = is_in_circle(x, y + Gy, center_x, center_y, radius2); write_botright = is_in_circle(x + Gy, y + Gy, center_x, center_y, radius2); #endif // useful only if n_x < blocksize_x or n_y < blocksize_y if (x >= n_x) return; if (y >= n_y) return; // Pixels in top-left quadrant if (write_topleft) d_slice[y*(n_x) + x] = sum1 * scale_factor; // Pixels in top-right quadrant if ((Gx + x < n_x) && (write_topright)) { d_slice[y*(n_x) + Gx + x] = sum2 * scale_factor; } if (Gy + y < 
n_y) { // Pixels in bottom-left quadrant if (write_botleft) d_slice[(y+Gy)*(n_x) + x] = sum3 * scale_factor; // Pixels in bottom-right quadrant if ((Gx + x < n_x) && (write_botright)) d_slice[(y+Gy)*(n_x) + Gx + x] = sum4 * scale_factor; } } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1621525307.0 nabu-2023.1.1/nabu/cuda/src/backproj_polar.cu0000644000175000017500000000347300000000000020204 0ustar00pierrepierre#ifndef SHARED_SIZE #define SHARED_SIZE 256 #endif texture tex_projections; __global__ void backproj_polar( float* d_slice, int num_projs, int num_bins, float axis_position, int n_x, int n_y, int offset_x, int offset_y, float* d_cos, float* d_msin, float scale_factor ) { int i_r = offset_x + blockDim.x * blockIdx.x + threadIdx.x; int i_theta = offset_y + blockDim.y * blockIdx.y + threadIdx.y; float r = i_r - axis_position; float x = r * d_cos[i_theta]; float y = - r * d_msin[i_theta]; /*volatile*/ __shared__ float s_cos[SHARED_SIZE]; /*volatile*/ __shared__ float s_msin[SHARED_SIZE]; int next_fetch = 0; int tid = threadIdx.y * blockDim.x + threadIdx.x; float costheta, msintheta; float h1; float sum1 = 0.0f; for (int proj = 0; proj < num_projs; proj++) { if (proj == next_fetch) { // Fetch SHARED_SIZE values to shared memory __syncthreads(); if (next_fetch + tid < num_projs) { s_cos[tid] = d_cos[next_fetch + tid]; s_msin[tid] = d_msin[next_fetch + tid]; } next_fetch += SHARED_SIZE; __syncthreads(); } costheta = s_cos[proj - (next_fetch - SHARED_SIZE)]; msintheta = s_msin[proj - (next_fetch - SHARED_SIZE)]; float c1 = fmaf(costheta, x, axis_position); // cos(theta)*xr + axis_pos float s1 = fmaf(msintheta, y, 0.0f); // -sin(theta)*yr h1 = c1 + s1; if (h1 >= 0 && h1 < num_bins) sum1 += tex2D(tex_projections, h1 + 0.5f, proj + 0.5f); } // useful only if n_x < blocksize_x or n_y < blocksize_y if (i_r >= n_x) return; if (i_theta >= n_y) return; d_slice[i_theta*(n_x) + i_r] = sum1 * scale_factor; } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1581878491.0 nabu-2023.1.1/nabu/cuda/src/boundary.h0000644000175000017500000000552200000000000016654 0ustar00pierrepierre#ifndef BOUNDARY_H #define BOUNDARY_H // Get the center index of the filter, // and the "half-Left" and "half-Right" lengths. // In the case of an even-sized filter, the center is shifted to the left. 
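// Example: hlen = 3 gives c = 1, hL = 1, hR = 1; hlen = 4 gives c = 1, hL = 1, hR = 2.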
#define GET_CENTER_HL(hlen){\ if (hlen & 1) {\ c = hlen/2;\ hL = c;\ hR = c;\ }\ else {\ c = hlen/2 - 1;\ hL = c;\ hR = c+1;\ }\ }\ // Boundary handling modes #define CONV_MODE_REFLECT 0 // cba|abcd|dcb #define CONV_MODE_NEAREST 1 // aaa|abcd|ddd #define CONV_MODE_WRAP 2 // bcd|abcd|abc #define CONV_MODE_CONSTANT 3 // 000|abcd|000 #ifndef USED_CONV_MODE #define USED_CONV_MODE CONV_MODE_NEAREST #endif #define CONV_PERIODIC_IDX_X int idx_x = gidx - c + jx; if (idx_x < 0) idx_x += Nx; if (idx_x >= Nx) idx_x -= Nx; #define CONV_PERIODIC_IDX_Y int idx_y = gidy - c + jy; if (idx_y < 0) idx_y += Ny; if (idx_y >= Ny) idx_y -= Ny; #define CONV_PERIODIC_IDX_Z int idx_z = gidz - c + jz; if (idx_z < 0) idx_z += Nz; if (idx_z >= Nz) idx_z -= Nz; // clamp not in cuda __device__ int clamp(int x, int min_, int max_) { return min(max(x, min_), max_); } #define CONV_NEAREST_IDX_X int idx_x = clamp((int) (gidx - c + jx), 0, Nx-1); #define CONV_NEAREST_IDX_Y int idx_y = clamp((int) (gidy - c + jy), 0, Ny-1); #define CONV_NEAREST_IDX_Z int idx_z = clamp((int) (gidz - c + jz), 0, Nz-1); #define CONV_REFLECT_IDX_X int idx_x = gidx - c + jx; if (idx_x < 0) idx_x = -idx_x-1; if (idx_x >= Nx) idx_x = Nx-(idx_x-(Nx-1)); #define CONV_REFLECT_IDX_Y int idx_y = gidy - c + jy; if (idx_y < 0) idx_y = -idx_y-1; if (idx_y >= Ny) idx_y = Ny-(idx_y-(Ny-1)); #define CONV_REFLECT_IDX_Z int idx_z = gidz - c + jz; if (idx_z < 0) idx_z = -idx_z-1; if (idx_z >= Nz) idx_z = Nz-(idx_z-(Nz-1)); #if USED_CONV_MODE == CONV_MODE_REFLECT #define CONV_IDX_X CONV_REFLECT_IDX_X #define CONV_IDX_Y CONV_REFLECT_IDX_Y #define CONV_IDX_Z CONV_REFLECT_IDX_Z #elif USED_CONV_MODE == CONV_MODE_NEAREST #define CONV_IDX_X CONV_NEAREST_IDX_X #define CONV_IDX_Y CONV_NEAREST_IDX_Y #define CONV_IDX_Z CONV_NEAREST_IDX_Z #elif USED_CONV_MODE == CONV_MODE_WRAP #define CONV_IDX_X CONV_PERIODIC_IDX_X #define CONV_IDX_Y CONV_PERIODIC_IDX_Y #define CONV_IDX_Z CONV_PERIODIC_IDX_Z #elif USED_CONV_MODE == CONV_MODE_CONSTANT #error "constant not implemented yet" #else #error "Unknown convolution mode" #endif // Image access patterns #define READ_IMAGE_1D_X input[(gidz*Ny + gidy)*Nx + idx_x] #define READ_IMAGE_1D_Y input[(gidz*Ny + idx_y)*Nx + gidx] #define READ_IMAGE_1D_Z input[(idx_z*Ny + gidy)*Nx + gidx] #define READ_IMAGE_2D_XY input[(gidz*Ny + idx_y)*Nx + idx_x] #define READ_IMAGE_2D_XZ input[(idx_z*Ny + gidy)*Nx + idx_x] #define READ_IMAGE_2D_YZ input[(idx_z*Ny + idx_y)*Nx + gidx] #define READ_IMAGE_3D_XYZ input[(idx_z*Ny + idx_y)*Nx + idx_x] #endif ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1596638264.0 nabu-2023.1.1/nabu/cuda/src/convolution.cu0000644000175000017500000001627100000000000017573 0ustar00pierrepierre/* * Convolution (without textures) * Adapted from OpenCL code of the the silx project * */ #include "boundary.h" typedef unsigned int uint; /******************************************************************************/ /**************************** 1D Convolution **********************************/ /******************************************************************************/ // Convolution with 1D kernel along axis "X" (fast dimension) // Works for batched 1D on 2D and batched 2D on 3D, along axis "X". 
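// The filter is read as filter[L-1 - jx], i.e. flipped, so this computes a true convolution
// (not a correlation). Out-of-bound reads are resolved according to USED_CONV_MODE
// (reflect / nearest / wrap), defined in boundary.h.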
__global__ void convol_1D_X( float * input, float * output, float * filter, int L, // filter size int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz // input/output depth ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; uint gidz = blockDim.z * blockIdx.z + threadIdx.z; if ((gidx >= Nx) || (gidy >= Ny) || (gidz >= Nz)) return; int c, hL, hR; GET_CENTER_HL(L); float sum = 0.0f; for (int jx = 0; jx <= hR+hL; jx++) { CONV_IDX_X; // Get index "x" sum += READ_IMAGE_1D_X * filter[L-1 - jx]; } output[(gidz*Ny + gidy)*Nx + gidx] = sum; } // Convolution with 1D kernel along axis "Y" // Works for batched 1D on 2D and batched 2D on 3D, along axis "Y". __global__ void convol_1D_Y( float * input, float * output, float * filter, int L, // filter size int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz // input/output depth ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; uint gidz = blockDim.z * blockIdx.z + threadIdx.z; if ((gidx >= Nx) || (gidy >= Ny) || (gidz >= Nz)) return; int c, hL, hR; GET_CENTER_HL(L); float sum = 0.0f; for (int jy = 0; jy <= hR+hL; jy++) { CONV_IDX_Y; // Get index "y" sum += READ_IMAGE_1D_Y * filter[L-1 - jy]; } output[(gidz*Ny + gidy)*Nx + gidx] = sum; } // Convolution with 1D kernel along axis "Z" // Works for batched 1D on 2D and batched 2D on 3D, along axis "Z". __global__ void convol_1D_Z( float * input, float * output, float * filter, int L, // filter size int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz // input/output depth ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; uint gidz = blockDim.z * blockIdx.z + threadIdx.z; if ((gidx >= Nx) || (gidy >= Ny) || (gidz >= Nz)) return; int c, hL, hR; GET_CENTER_HL(L); float sum = 0.0f; for (int jz = 0; jz <= hR+hL; jz++) { CONV_IDX_Z; // Get index "z" sum += READ_IMAGE_1D_Z * filter[L-1 - jz]; } output[(gidz*Ny + gidy)*Nx + gidx] = sum; } /******************************************************************************/ /**************************** 2D Convolution **********************************/ /******************************************************************************/ // Convolution with 2D kernel // Works for batched 2D on 3D. __global__ void convol_2D_XY( float * input, float * output, float * filter, int Lx, // filter number of columns, int Ly, // filter number of rows, int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz // input/output depth ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; uint gidz = blockDim.z * blockIdx.z + threadIdx.z; if ((gidx >= Nx) || (gidy >= Ny) || (gidz >= Nz)) return; int c, hL, hR; GET_CENTER_HL(Lx); float sum = 0.0f; for (int jy = 0; jy <= hR+hL; jy++) { CONV_IDX_Y; // Get index "y" for (int jx = 0; jx <= hR+hL; jx++) { CONV_IDX_X; // Get index "x" sum += READ_IMAGE_2D_XY * filter[(Ly-1-jy)*Lx + (Lx-1 - jx)]; } } output[(gidz*Ny + gidy)*Nx + gidx] = sum; } // Convolution with 2D kernel // Works for batched 2D on 3D. 
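// Note: as in convol_2D_XY above, the half-lengths (hL, hR) are computed from the first
// filter dimension only, so a square filter footprint (here Lx == Lz) appears to be assumed.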
__global__ void convol_2D_XZ( float * input, float * output, float * filter, int Lx, // filter number of columns, int Lz, // filter number of rows, int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz // input/output depth ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; uint gidz = blockDim.z * blockIdx.z + threadIdx.z; if ((gidx >= Nx) || (gidy >= Ny) || (gidz >= Nz)) return; int c, hL, hR; GET_CENTER_HL(Lx); float sum = 0.0f; for (int jz = 0; jz <= hR+hL; jz++) { CONV_IDX_Z; // Get index "z" for (int jx = 0; jx <= hR+hL; jx++) { CONV_IDX_X; // Get index "x" sum += READ_IMAGE_2D_XZ * filter[(Lz-1-jz)*Lx + (Lx-1 - jx)]; } } output[(gidz*Ny + gidy)*Nx + gidx] = sum; } // Convolution with 2D kernel // Works for batched 2D on 3D. __global__ void convol_2D_YZ( float * input, float * output, float * filter, int Ly, // filter number of columns, int Lz, // filter number of rows, int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz // input/output depth ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; uint gidz = blockDim.z * blockIdx.z + threadIdx.z; if ((gidx >= Nx) || (gidy >= Ny) || (gidz >= Nz)) return; int c, hL, hR; GET_CENTER_HL(Ly); float sum = 0.0f; for (int jz = 0; jz <= hR+hL; jz++) { CONV_IDX_Z; // Get index "z" for (int jy = 0; jy <= hR+hL; jy++) { CONV_IDX_Y; // Get index "y" sum += READ_IMAGE_2D_YZ * filter[(Lz-1-jz)*Ly + (Ly-1 - jy)]; } } output[(gidz*Ny + gidy)*Nx + gidx] = sum; } /******************************************************************************/ /**************************** 3D Convolution **********************************/ /******************************************************************************/ // Convolution with 3D kernel __global__ void convol_3D_XYZ( float * input, float * output, float * filter, int Lx, // filter number of columns, int Ly, // filter number of rows, int Lz, // filter number of rows, int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz // input/output depth ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; uint gidz = blockDim.z * blockIdx.z + threadIdx.z; if ((gidx >= Nx) || (gidy >= Ny) || (gidz >= Nz)) return; int c, hL, hR; GET_CENTER_HL(Lx); float sum = 0.0f; for (int jz = 0; jz <= hR+hL; jz++) { CONV_IDX_Z; // Get index "z" for (int jy = 0; jy <= hR+hL; jy++) { CONV_IDX_Y; // Get index "y" for (int jx = 0; jx <= hR+hL; jx++) { CONV_IDX_X; // Get index "x" sum += READ_IMAGE_3D_XYZ * filter[((Lz-1-jz)*Ly + (Ly-1-jy))*Lx + (Lx-1 - jx)]; } } } output[(gidz*Ny + gidy)*Nx + gidx] = sum; } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1581878491.0 nabu-2023.1.1/nabu/cuda/src/fftshift.cu0000644000175000017500000000425400000000000017027 0ustar00pierrepierre#include #define BLOCK_SIZE 16 __global__ void dfi_cuda_swap_quadrants_complex(cufftComplex *input, cufftComplex *output, int dim_x) { int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; int idy = blockIdx.y * BLOCK_SIZE + threadIdx.y; const int dim_y = gridDim.y * blockDim.y; //a half of real length output[idy * dim_x + idx] = input[(dim_y + idy) * dim_x + idx + 1]; output[(dim_y + idy) * dim_x + idx] = input[idy * dim_x + idx + 1]; } __global__ void dfi_cuda_swap_quadrants_real(cufftReal *output) { int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; int idy = blockIdx.y * BLOCK_SIZE + 
threadIdx.y; const int dim_x = gridDim.x * blockDim.x; int dim_x2 = dim_x/2, dim_y2 = dim_x2; long sw_idx1, sw_idx2; sw_idx1 = idy * dim_x + idx; cufftReal temp = output[sw_idx1]; if (idx < dim_x2) { sw_idx2 = (dim_y2 + idy) * dim_x + (dim_x2 + idx); output[sw_idx1] = output[sw_idx2]; output[sw_idx2] = temp; } else { sw_idx2 = (dim_y2 + idy) * dim_x + (idx - dim_x2); output[sw_idx1] = output[sw_idx2]; output[sw_idx2] = temp; } } __global__ void swap_full_quadrants_complex(cufftComplex *output) { int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; int idy = blockIdx.y * BLOCK_SIZE + threadIdx.y; const int dim_x = gridDim.x * blockDim.x; int dim_x2 = dim_x/2, dim_y2 = dim_x2; long sw_idx1, sw_idx2; sw_idx1 = idy * dim_x + idx; cufftComplex temp = output[sw_idx1]; if (idx < dim_x2) { sw_idx2 = (dim_y2 + idy) * dim_x + (dim_x2 + idx); output[sw_idx1] = output[sw_idx2]; output[sw_idx2] = temp; } else { sw_idx2 = (dim_y2 + idy) * dim_x + (idx - dim_x2); output[sw_idx1] = output[sw_idx2]; output[sw_idx2] = temp; } } __global__ void dfi_cuda_crop_roi(cufftReal *input, int x, int y, int roi_x, int roi_y, int raster_size, float scale, cufftReal *output) { int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; int idy = blockIdx.y * BLOCK_SIZE + threadIdx.y; if (idx < roi_x && idy < roi_y) { output[idy * roi_x + idx] = input[(idy + y) * raster_size + (idx + x)] * scale; } } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/cuda/src/flatfield.cu0000644000175000017500000000401200000000000017134 0ustar00pierrepierre#ifndef N_FLATS #error "Please provide the N_FLATS variable" #endif #ifndef N_DARKS #error "Please provide the N_FLATS variable" #endif /** * In-place flat-field normalization with linear interpolation. * This kernel assumes that all the radios are loaded into memory * (although not necessarily the full radios images) * and in radios[x, y z], z in the radio index * * radios: 3D array * flats: 3D array * darks: 3D array * Nx: number of pixel horizontally in the radios * Nx: number of pixel vertically in the radios * Nx: number of radios * flats_indices: indices of flats to fetch for each radio * flats_weights: weights of flats for each radio * darks_indices: indices of darks, in sorted order **/ __global__ void flatfield_normalization( float* radios, float* flats, float* darks, int Nx, int Ny, int Nz, int* flats_indices, float* flats_weights ) { size_t x = blockDim.x * blockIdx.x + threadIdx.x; size_t y = blockDim.y * blockIdx.y + threadIdx.y; size_t z = blockDim.z * blockIdx.z + threadIdx.z; if ((x >= Nx) || (y >= Ny) || (z >= Nz)) return; size_t pos = (z*Ny+y)*Nx + x; float dark_val = 0.0f, flat_val = 1.0f; #if N_FLATS == 1 flat_val = flats[y*Nx + x]; #else int prev_idx = flats_indices[z*2 + 0]; int next_idx = flats_indices[z*2 + 1]; float w1 = flats_weights[z*2 + 0]; float w2 = flats_weights[z*2 + 1]; if (next_idx == -1) { flat_val = flats[(prev_idx*Ny+y)*Nx + x]; } else { flat_val = w1 * flats[(prev_idx*Ny+y)*Nx + x] + w2 * flats[(next_idx*Ny+y)*Nx + x]; } #endif #if (N_DARKS == 1) dark_val = darks[y*Nx + x]; #else // TODO interpolate between darks // Same as above... 
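    // A possible (untested) sketch, mirroring the flats branch above and assuming
    // hypothetical darks_indices / darks_weights arrays passed as extra kernel arguments:
    //   int p = darks_indices[z*2 + 0], n = darks_indices[z*2 + 1];
    //   dark_val = (n == -1) ? darks[(p*Ny+y)*Nx + x]
    //                        : darks_weights[z*2+0] * darks[(p*Ny+y)*Nx + x]
    //                          + darks_weights[z*2+1] * darks[(n*Ny+y)*Nx + x];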
#error "N_DARKS > 1 is not supported yet" #endif float val = (radios[pos] - dark_val) / (flat_val - dark_val); #ifdef NAN_VALUE if (flat_val == dark_val) val = NAN_VALUE; #endif radios[pos] = val; } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1581878491.0 nabu-2023.1.1/nabu/cuda/src/fourier_wavelets.cu0000644000175000017500000000105100000000000020567 0ustar00pierrepierre/** Damping kernel used in the Fourier-Wavelets sinogram destriping method. */ __global__ void kern_fourierwavelets(float2* sinoF, int Nx, int Ny, float wsigma) { int gidx = threadIdx.x + blockIdx.x*blockDim.x; int gidy = threadIdx.y + blockIdx.y*blockDim.y; int Nfft = Ny/2+1; if (gidx >= Nx || gidy >= Nfft) return; float m = gidy/wsigma; float factor = 1.0f - expf(-(m * m)/2); int tid = gidy*Nx + gidx; // do not forget the scale factor (here Ny) sinoF[tid].x *= factor/Ny; sinoF[tid].y *= factor/Ny; } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1638967584.0 nabu-2023.1.1/nabu/cuda/src/halftomo.cu0000644000175000017500000000270400000000000017021 0ustar00pierrepierre/* Perform a "half tomography" sinogram conversion. A 360 degrees sinogram is converted to a 180 degrees sinogram with a field of view extended (at most) twice". * Parameters: * sinogram: the 360 degrees sinogram, shape (n_angles, n_x) * output: the 160 degrees sinogram, shape (n_angles/2, rotation_axis_position * 2) * weights: an array of weight, size n_x - rotation_axis_position */ __global__ void halftomo_kernel( float* sinogram, float* output, float* weights, int n_angles, int n_x, int rotation_axis_position ) { int x = blockDim.x * blockIdx.x + threadIdx.x; int y = blockDim.y * blockIdx.y + threadIdx.y; int n_a2 = (n_angles + 1) / 2; int d = n_x - rotation_axis_position; int n_x2 = 2 * rotation_axis_position; int r = rotation_axis_position; if ((x >= n_x2) || (y >= n_a2)) return; // output[:, :r - d] = sino[:n_a2, :r - d] if (x < r - d) { output[y * n_x2 + x] = sinogram[y * n_x + x]; } // output[:, r-d:r+d] = (1 - weights) * sino[:n_a2, r-d:] else if (x < r+d) { float w = weights[x - (r - d)]; output[y * n_x2 + x] = (1.0f - w) * sinogram[y*n_x + x] \ + w * sinogram[(n_a2 + y)*n_x + (n_x2 - 1 - x)]; } // output[:, nx:] = sino[n_a2:, ::-1][:, 2 * d :] = sino[n_a2:, -2*d-1:-n_x-1:-1] else { output[y * n_x2 + x] = sinogram[(n_a2 + y)*n_x + (n_x2 - 1 - x)]; } } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/cuda/src/helical_padding.cu0000644000175000017500000000373600000000000020305 0ustar00pierrepierre // see nabu/pipeline/helical/filtering.py for details __global__ void padding( float* data, int* mirror_indexes, #if defined(MIRROR_CONSTANT_VARIABLE_ROT_POS) || defined(MIRROR_EDGES_VARIABLE_ROT_POS) int *rot_axis_pos, #else int rot_axis_pos, #endif int Nx, int Ny, int Nx_padded, int pad_left_len, int pad_right_len #if defined(MIRROR_CONSTANT) || defined(MIRROR_CONSTANT_VARIABLE_ROT_POS) ,float pad_left_val, float pad_right_val #endif ) { int x = blockDim.x * blockIdx.x + threadIdx.x; int y = blockDim.y * blockIdx.y + threadIdx.y; if ((x >= Nx_padded) || (y >= Ny) || x < Nx) return; int idx = y*Nx_padded + x; int y_mirror = mirror_indexes[y]; int x_mirror =0 ; #if defined(MIRROR_CONSTANT_VARIABLE_ROT_POS) || defined(MIRROR_EDGES_VARIABLE_ROT_POS) int two_rots = rot_axis_pos[y] + rot_axis_pos[y_mirror]; #else int two_rots = 2*rot_axis_pos ; #endif if( two_rots > Nx) { x_mirror = two_rots - x ; if (x_mirror < 0 ) { #if 
defined(MIRROR_CONSTANT) || defined(MIRROR_CONSTANT_VARIABLE_ROT_POS) if( x < Nx_padded - pad_left_len) { data[idx] = pad_left_val; } else { data[idx] = pad_right_val; } #else if( x < Nx_padded - pad_left_len) { data[idx] = data[y_mirror*Nx_padded + 0]; } else { data[idx] = data[y*Nx_padded + 0]; } #endif } else { data[idx] = data[y_mirror*Nx_padded + x_mirror]; } } else { x_mirror = two_rots - (x - Nx_padded) ; if (x_mirror > Nx-1 ) { #if defined(MIRROR_CONSTANT) || defined(MIRROR_CONSTANT_VARIABLE_ROT_POS) if( x < Nx_padded - pad_left_len) { data[idx] = pad_left_val ; } else { data[idx] = pad_right_val; } #else if( x < Nx_padded - pad_left_len) { data[idx] = data[y*Nx_padded + Nx - 1 ]; } else { data[idx] = data[y_mirror*Nx_padded + Nx-1]; } #endif } else { data[idx] = data[y_mirror*Nx_padded + x_mirror]; } } return; } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1603795396.0 nabu-2023.1.1/nabu/cuda/src/histogram.cu0000644000175000017500000000145000000000000017202 0ustar00pierrepierretypedef unsigned int uint; __global__ void histogram( float * array, int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz, // input/output depth float arr_min, // array minimum value float arr_max, // array maximum value uint* hist, // histogram int nbins // histogram size (number of bins) ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; uint gidz = blockDim.z * blockIdx.z + threadIdx.z; if ((gidx >= Nx) || (gidy >= Ny) || (gidz >= Nz)) return; float val = array[(gidz*Ny + gidy)*Nx + gidx]; float bin_pos = nbins * ((val - arr_min) / (arr_max - arr_min)); uint bin_left = min((uint) bin_pos, nbins-1); atomicAdd(hist + bin_left, 1); } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1581878491.0 nabu-2023.1.1/nabu/cuda/src/hst_backproj.cu0000644000175000017500000000454500000000000017666 0ustar00pierrepierretexture texProj; //~ cudaChannelFormatDesc floatTex = cudaCreateChannelDesc(); __global__ void backproj( int num_proj, int num_bins, float axis_position, float *d_SLICE, float gpu_offset_x, float gpu_offset_y, float * d_cos_s, float * d_sin_s, float * d_axis_s) { const int tidx = threadIdx.x; const int bidx = blockIdx.x; const int tidy = threadIdx.y; const int bidy = blockIdx.y; __shared__ float shared[768]; float * sh_sin = shared; float * sh_cos = shared+256; float * sh_axis = sh_cos+256; float pcos, psin; float h0, h1, h2, h3; const float apos_off_x = gpu_offset_x - axis_position ; const float apos_off_y = gpu_offset_y - axis_position ; float acorr05; float res0 = 0, res1 = 0, res2 = 0, res3 = 0; const float bx00 = (32 * bidx + 2 * tidx + apos_off_x); const float by00 = (32 * bidy + 2 * tidy + apos_off_y); int read=0; for (int proj=0; proj=read) { __syncthreads(); int ip = tidy*16+tidx; if( read+ip < num_proj) { sh_cos [ip] = d_cos_s[read+ip]; sh_sin [ip] = d_sin_s[read+ip]; sh_axis[ip] = d_axis_s[read+ip]; } read = read + 256; // 256=16*16 block size __syncthreads(); } pcos = sh_cos[256 - read + proj] ; psin = sh_sin[256 - read + proj] ; acorr05 = sh_axis[256 - read + proj]; h0 = acorr05 + bx00*pcos - by00*psin; h1 = acorr05 + bx00*pcos - (by00+1)*psin; h2 = acorr05 + (bx00+1)*pcos - by00*psin; h3 = acorr05 + (bx00+1)*pcos - (by00+1)*psin; if(h0 >= 0 && h0 < num_bins) res0 += tex2D(texProj, h0 + 0.5f, proj + 0.5f); if(h1>=0 && h1=0 && h2=0 && h3= Nx) || (i >= Ny)) return; int extrapolate_side = x_new[0] > x[0] ? 
1 : 0; // 0: left, 1: right float dx, dy; if (i == 0) extrapolate_side = 1; else if (i == Ny - 1) extrapolate_side = 0; int extrapolate_side_compl = 1 - extrapolate_side; dx = x[i+extrapolate_side] - x[i - extrapolate_side_compl]; dy = arr2D[(i + extrapolate_side) * Nx + c] - arr2D[(i - extrapolate_side_compl) * Nx + c]; out[i * Nx + c] = (dy / dx) * (x_new[i] - x[i]) + arr2D[i * Nx + c]; } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1628752049.0 nabu-2023.1.1/nabu/cuda/src/medfilt.cu0000644000175000017500000000462100000000000016634 0ustar00pierrepierre#include "boundary.h" typedef unsigned int uint; #ifndef MEDFILT_X #define MEDFILT_X 3 #endif #ifndef MEDFILT_Y #define MEDFILT_Y 3 #endif #ifndef DO_THRESHOLD #define DO_THRESHOLD 0 #endif // General-purpose 2D (or batched 2D) median filter with a square footprint. // Boundary handling is customized via the USED_CONV_MODE macro (see boundary.h) // Most of the time is spent computing the median, so this kernel can be sped up by // - creating dedicated kernels for 3x3, 5x5 (see http://ndevilla.free.fr/median/median/src/optmed.c) // - Using a quickselect algorithm instead of sorting (see http://ndevilla.free.fr/median/median/src/quickselect.c) __global__ void medfilt2d( float * input, float * output, int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz, // input/output depth float threshold // threshold for thresholded median filter ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; uint gidz = blockDim.z * blockIdx.z + threadIdx.z; if ((gidx >= Nx) || (gidy >= Ny) || (gidz >= Nz)) return; int c, hL, hR; GET_CENTER_HL(MEDFILT_X); // Get elements in a 3x3 neighborhood float elements[MEDFILT_X*MEDFILT_Y] = {0}; for (int jy = 0; jy <= hR+hL; jy++) { CONV_IDX_Y; // Get index "y" for (int jx = 0; jx <= hR+hL; jx++) { CONV_IDX_X; // Get index "x" elements[jy*MEDFILT_Y+jx] = READ_IMAGE_2D_XY; } } // Sort the elements with insertion sort // TODO quickselect ? int i = 1, j; while (i < MEDFILT_X*MEDFILT_Y) { j = i; while (j > 0 && elements[j-1] > elements[j]) { float tmp = elements[j]; elements[j] = elements[j-1]; elements[j-1] = tmp; j--; } i++; } float median = elements[MEDFILT_X*MEDFILT_Y/2]; #if DO_THRESHOLD == 1 float out_val = 0.0f; uint idx = (gidz*Ny + gidy)*Nx + gidx; if (input[idx] >= median + threshold) out_val = median; else out_val = input[idx]; output[idx] = out_val; #elif DO_THRESHOLD == 2 float out_val = 0.0f; uint idx = (gidz*Ny + gidy)*Nx + gidx; if (fabsf(input[idx] - median) > threshold) out_val = median; else out_val = input[idx]; output[idx] = out_val; #else output[(gidz*Ny + gidy)*Nx + gidx] = median; #endif } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1600668647.0 nabu-2023.1.1/nabu/cuda/src/normalization.cu0000644000175000017500000000257700000000000020106 0ustar00pierrepierretypedef unsigned int uint; /** * Chebyshev background removal. * This kernel does a degree 2 polynomial estimation of each line of an array, * and then subtracts the estimation from each line. * This process is done in-place. 
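 * Concretely, each line is fitted against the basis {1, x, 3x^2 - 1} with x sampled in
 * [-1, 1), and the fitted components are subtracted; the three basis functions are
 * treated as mutually orthogonal over the sampled points.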
*/ __global__ void normalize_chebyshev( float * array, int Nx, // input/output number of columns int Ny, // input/output number of rows int Nz // input/output depth ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; uint gidz = blockDim.z * blockIdx.z + threadIdx.z; if ((gidx >= 1) || (gidy >= Ny) || (gidz >= Nz)) return; float ff0=0.0f, ff1=0.0f, ff2=0.0f; float sum0=0.0f, sum1=0.0f, sum2=0.0f; float f0, f1, f2, x; for (int j=0; j < Nx; j++) { uint pos = (gidz*Ny + gidy)*Nx + j; float arr_val = array[pos]; x = 2.0f*(j + 0.5f - Nx/2.0f)/Nx; f0 = 1.0f; f1 = x; f2 = (3.0f*x*x-1.0f); ff0 = ff0 + f0 * arr_val; ff1 = ff1 + f1 * arr_val; ff2 = ff2 + f2 * arr_val; sum0 += f0 * f0; sum1 += f1 * f1; sum2 += f2 * f2; } for (int j=0; j< Nx; j++) { uint pos = (gidz*Ny + gidy)*Nx + j; x = 2.0f*(j+0.5f-Nx/2.0f)/Nx; f0 = 1.0f; f1 = x; f2 = (3.0f*x*x-1.0f); array[pos] -= ff0*f0/sum0 + ff1*f1/sum1 + ff2*f2/sum2; } } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1671864494.0 nabu-2023.1.1/nabu/cuda/src/padding.cu0000644000175000017500000001051000000000000016610 0ustar00pierrepierre#include typedef pycuda::complex complex; typedef unsigned int uint; /** This function padds in-place a 2D array with constant values. It is designed to leave the data in the "FFT layout", i.e the data is *not* in the center of the extended/padded data. In one dimension: <--------------- N0 ----------------> | original data | padded values | <----- N -------- ><---- Pl+Pr -----> N0: width of data Pl, Pr: left/right padding lengths ASSUMPTIONS: - data is already extended before padding (its size is Nx_padded * Ny_padded) - the original data lies in the top-left quadrant. **/ __global__ void padding_constant( float* data, int Nx, int Ny, int Nx_padded, int Ny_padded, int pad_left_len, int pad_right_len, int pad_top_len, int pad_bottom_len, float pad_left_val, float pad_right_val, float pad_top_val, float pad_bottom_val ) { int x = blockDim.x * blockIdx.x + threadIdx.x; int y = blockDim.y * blockIdx.y + threadIdx.y; if ((x >= Nx_padded) || (y >= Ny_padded)) return; int idx = y*Nx_padded + x; // data[s0:s0+Pd, :s1] = pad_bottom_val if ((Ny <= y) && (y < Ny+pad_bottom_len) && (x < Nx)) data[idx] = pad_bottom_val; // data[s0+Pd:s0+Pd+Pu, :s1] = pad_top_val else if ((Ny + pad_bottom_len <= y) && (y < Ny+pad_bottom_len+pad_top_len) && (x < Nx)) data[idx] = pad_top_val; // data[:, s1:s1+Pr] = pad_right_val else if ((Nx <= x) && (x < Nx+pad_right_len)) data[idx] = pad_right_val; // data[:, s1+Pr:s1+Pr+Pl] = pad_left_val else if ((Nx+pad_right_len <= x) && (x < Nx+pad_right_len+pad_left_len)) data[idx] = pad_left_val; // top-left quadrant else return; } __global__ void padding_edge( float* data, int Nx, int Ny, int Nx_padded, int Ny_padded, int pad_left_len, int pad_right_len, int pad_top_len, int pad_bottom_len ) { int x = blockDim.x * blockIdx.x + threadIdx.x; int y = blockDim.y * blockIdx.y + threadIdx.y; if ((x >= Nx_padded) || (y >= Ny_padded)) return; int idx = y*Nx_padded + x; // // This kernel can be optimized: // - Optimize the logic to use less comparisons // - Store the values data[0], data[s0-1, 0], data[0, s1-1], data[s0-1, s1-1] // into shared memory to read only once from global mem. 
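    // As for padding_constant above, the input is assumed to already sit in the extended
    // (Nx_padded x Ny_padded) buffer, with the original data in the top-left quadrant
    // ("FFT layout").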
// // data[s0:s0+Pd, :s1] = data[s0, :s1] if ((Ny <= y) && (y < Ny+pad_bottom_len) && (x < Nx)) data[idx] = data[(Ny-1)*Nx_padded+x]; // data[s0+Pd:s0+Pd+Pu, :s1] = data[0, :s1] else if ((Ny + pad_bottom_len <= y) && (y < Ny+pad_bottom_len+pad_top_len) && (x < Nx)) data[idx] = data[x]; // data[:s0, s1:s1+Pr] = data[:s0, s1] else if ((y < Ny) && (Nx <= x) && (x < Nx+pad_right_len)) data[idx] = data[y*Nx_padded + Nx-1]; // data[:s0, s1+Pr:s1+Pr+Pl] = data[:s0, 0] else if ((y < Ny) && (Nx+pad_right_len <= x) && (x < Nx+pad_right_len+pad_left_len)) data[idx] = data[y*Nx_padded]; // data[s0:s0+Pb, s1:s1+Pr] = data[s0-1, s1-1] else if ((Ny <= y && y < Ny + pad_bottom_len) && (Nx <= x && x < Nx + pad_right_len)) data[idx] = data[(Ny-1)*Nx_padded + Nx-1]; // data[s0:s0+Pb, s1+Pr:s1+Pr+Pl] = data[s0-1, 0] else if ((Ny <= y && y < Ny + pad_bottom_len) && (Nx+pad_right_len <= x && x < Nx + pad_right_len+pad_left_len)) data[idx] = data[(Ny-1)*Nx_padded]; // data[s0+Pb:s0+Pb+Pu, s1:s1+Pr] = data[0, s1-1] else if ((Ny+pad_bottom_len <= y && y < Ny + pad_bottom_len+pad_top_len) && (Nx <= x && x < Nx + pad_right_len)) data[idx] = data[Nx-1]; // data[s0+Pb:s0+Pb+Pu, s1+Pr:s1+Pr+Pl] = data[0, 0] else if ((Ny+pad_bottom_len <= y && y < Ny + pad_bottom_len+pad_top_len) && (Nx+pad_right_len <= x && x < Nx + pad_right_len+pad_left_len)) data[idx] = data[0]; // top-left quadrant else return; } __global__ void coordinate_transform( float* array_in, float* array_out, int* cols_inds, int* rows_inds, int Nx, int Nx_padded, int Ny_padded ) { uint x = blockDim.x * blockIdx.x + threadIdx.x; uint y = blockDim.y * blockIdx.y + threadIdx.y; if ((x >= Nx_padded) || (y >= Ny_padded)) return; uint idx = y*Nx_padded + x; int x2 = cols_inds[idx]; int y2 = rows_inds[idx]; array_out[idx] = array_in[y2*Nx + x2]; } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1614847836.0 nabu-2023.1.1/nabu/cuda/src/rotation.cu0000644000175000017500000000122100000000000017040 0ustar00pierrepierretypedef unsigned int uint; texture tex_image; __global__ void rotate( float* output, int Nx, int Ny, float cos_angle, float sin_angle, float rotc_x, float rotc_y ) { uint gidx = blockDim.x * blockIdx.x + threadIdx.x; uint gidy = blockDim.y * blockIdx.y + threadIdx.y; if (gidx >= Nx || gidy >= Ny) return; float x = (gidx - rotc_x)*cos_angle - (gidy - rotc_y)*sin_angle; float y = (gidx - rotc_x)*sin_angle + (gidy - rotc_y)*cos_angle; x += rotc_x; y += rotc_y; float out_val = tex2D(tex_image, x + 0.5f, y + 0.5f); output[gidy * Nx + gidx] = out_val; } ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4567332 nabu-2023.1.1/nabu/cuda/tests/0000755000175000017500000000000000000000000015227 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1581878491.0 nabu-2023.1.1/nabu/cuda/tests/__init__.py0000644000175000017500000000000100000000000017327 0ustar00pierrepierre ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/cuda/tests/test_medfilt.py0000644000175000017500000000524100000000000020266 0ustar00pierrepierreimport pytest import numpy as np from scipy.misc import ascent from scipy.ndimage import median_filter from nabu.testutils import generate_tests_scenarios from nabu.cuda.utils import get_cuda_context, __has_pycuda__ if __has_pycuda__: from nabu.cuda.medfilt import MedianFilter import pycuda.gpuarray as garray scenarios = generate_tests_scenarios( { "input_on_gpu": [False, 
True], "output_on_gpu": [False, True], "footprint": [(3, 3), (5, 5)], "mode": ["reflect", "nearest", "wrap"], "batched_2d": [False, True], } ) @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.data = np.ascontiguousarray(ascent()[::2, ::2][:-1, :], dtype="f") cls.tol = 1e-7 cls.ctx = get_cuda_context(cleanup_at_exit=False) cls.allocate_numpy_arrays() cls.allocate_cuda_arrays() yield cls.ctx.pop() @pytest.mark.skipif(not (__has_pycuda__), reason="Need Cuda/pycuda for this test") @pytest.mark.usefixtures("bootstrap") class TestMedianFilter(object): @classmethod def allocate_numpy_arrays(cls): shape = cls.data.shape cls.input = cls.data cls.input3d = np.tile(cls.input, (2, 1, 1)) @classmethod def allocate_cuda_arrays(cls): shape = cls.data.shape cls.d_input = garray.to_gpu(cls.input) cls.d_output = garray.zeros_like(cls.d_input) cls.d_input3d = garray.to_gpu(cls.input3d) cls.d_output3d = garray.zeros_like(cls.d_input3d) # parametrize on a class method will use the same class, and launch this # method with different scenarios. @pytest.mark.parametrize("config", scenarios) def testMedfilt(self, config): if config["input_on_gpu"]: input_data = self.d_input if not (config["batched_2d"]) else self.d_input3d else: input_data = self.input if not (config["batched_2d"]) else self.input3d if config["output_on_gpu"]: output_data = self.d_output if not (config["batched_2d"]) else self.d_output3d else: output_data = None # Cuda median filter medfilt = MedianFilter( input_data.shape, footprint=config["footprint"], mode=config["mode"], cuda_options={"ctx": self.ctx}, ) res = medfilt.medfilt2(input_data, output=output_data) if config["output_on_gpu"]: res = res.get() # Reference (scipy) ref = median_filter(self.input, config["footprint"][0], mode=config["mode"]) max_absolute_error = np.max(np.abs(res - ref)) assert max_absolute_error < self.tol, "Something wrong with configuration %s" % str(config) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/cuda/tests/test_padding.py0000644000175000017500000001567500000000000020264 0ustar00pierrepierreimport numpy as np import pytest from scipy.misc import ascent from nabu.cuda.utils import get_cuda_context, __has_pycuda__ from nabu.utils import calc_padding_lengths, get_cuda_srcfile from nabu.testutils import get_data, generate_tests_scenarios if __has_pycuda__: import pycuda.gpuarray as garray from nabu.cuda.kernel import CudaKernel from nabu.cuda.padding import CudaPadding scenarios_legacy = [ { "shape": (512, 501), "shape_padded": (1023, 1022), "constant_values": ((1.0, 2.0), (3.0, 4.0)), }, ] # parametrize with fixture and "params=" will launch a new class for each scenario. # the attributes set to "cls" will remain for all the tests done in this class # with the current scenario. 
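# (Contrast with test_medfilt.py above, where @pytest.mark.parametrize on a test method
# runs all scenarios against a single class-scoped fixture.)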
@pytest.fixture(scope="class", params=scenarios_legacy) def bootstrap_legacy(request): cls = request.cls cls.data = get_data("mri_proj_astra.npz")["data"] cls.tol = 1e-7 cls.params = request.param cls.ctx = get_cuda_context(cleanup_at_exit=False) cls._calc_pad() cls._init_kernels() yield cls.ctx.pop() @pytest.mark.skipif(not (__has_pycuda__), reason="Need Cuda and pycuda for this test") @pytest.mark.usefixtures("bootstrap_legacy") class TestPaddingLegacy: @classmethod def _calc_pad(cls): cls.shape = cls.params["shape"] cls.data = np.ascontiguousarray(cls.data[: cls.shape[0], : cls.shape[1]]) cls.shape_padded = cls.params["shape_padded"] ((pt, pb), (pl, pr)) = calc_padding_lengths(cls.shape, cls.shape_padded) cls.pad_top_len = pt cls.pad_bottom_len = pb cls.pad_left_len = pl cls.pad_right_len = pr @classmethod def _init_kernels(cls): cls.pad_kern = CudaKernel( "padding_constant", filename=get_cuda_srcfile("padding.cu"), signature="Piiiiiiiiffff", ) cls.pad_edge_kern = CudaKernel( "padding_edge", filename=get_cuda_srcfile("padding.cu"), signature="Piiiiiiii", ) cls.d_data_padded = garray.zeros(cls.shape_padded, "f") def _init_padding(self, arr=None): arr = arr or self.data self.d_data_padded.fill(0) Ny, Nx = self.shape self.d_data_padded[:Ny, :Nx] = self.data def _pad_numpy(self, arr=None, **np_pad_kwargs): arr = arr or self.data data_padded_ref = np.pad( arr, ((self.pad_top_len, self.pad_bottom_len), (self.pad_left_len, self.pad_right_len)), **np_pad_kwargs ) # Put in the FFT layout data_padded_ref = np.roll(data_padded_ref, (-self.pad_top_len, -self.pad_left_len), axis=(0, 1)) return data_padded_ref def test_constant_padding(self): self._init_padding() # Pad using the cuda kernel ((val_top, val_bottom), (val_left, val_right)) = self.params["constant_values"] Ny, Nx = self.shape Nyp, Nxp = self.shape_padded self.pad_kern( self.d_data_padded, Nx, Ny, Nxp, Nyp, self.pad_left_len, self.pad_right_len, self.pad_top_len, self.pad_bottom_len, val_left, val_right, val_top, val_bottom, ) # Pad using numpy data_padded_ref = self._pad_numpy(mode="constant", constant_values=self.params["constant_values"]) # Compare errmax = np.max(np.abs(self.d_data_padded.get() - data_padded_ref)) assert errmax < self.tol, "Max error is too high" def test_edge_padding(self): self._init_padding() # Pad using the cuda kernel ((val_top, val_bottom), (val_left, val_right)) = self.params["constant_values"] Ny, Nx = self.shape Nyp, Nxp = self.shape_padded self.pad_edge_kern( self.d_data_padded, Nx, Ny, Nxp, Nyp, self.pad_left_len, self.pad_right_len, self.pad_top_len, self.pad_bottom_len, ) # Pad using numpy data_padded_ref = self._pad_numpy(mode="edge") # Compare errmax = np.max(np.abs(self.d_data_padded.get() - data_padded_ref)) assert errmax < self.tol, "Max error is too high" scenarios = generate_tests_scenarios( { "shape": [(511, 512), (512, 511)], "pad_width": [((256, 255), (128, 127)), ((0, 0), (6, 7))], "mode": CudaPadding.supported_modes if __has_pycuda__ else [], "constant_values": [0, ((1.0, 2.0), (3.0, 4.0))], "output_is_none": [True, False], } ) @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.data = ascent().astype("f") cls.tol = 1e-7 cls.ctx = get_cuda_context(cleanup_at_exit=False) yield cls.ctx.pop() @pytest.mark.skipif(not (__has_pycuda__), reason="Need Cuda and pycuda for this test") @pytest.mark.usefixtures("bootstrap") class TestCudaPadding: @pytest.mark.parametrize("config", scenarios) def test_padding(self, config): shape = config["shape"] data = self.data[: shape[0], : 
shape[1]] kwargs = {} if config["mode"] == "constant": kwargs["constant_values"] = config["constant_values"] ref = np.pad(data, config["pad_width"], mode=config["mode"], **kwargs) if config["output_is_none"]: output = None else: output = garray.zeros(ref.shape, "f") cuda_padding = CudaPadding( config["shape"], config["pad_width"], mode=config["mode"], constant_values=config["constant_values"], cuda_options={"ctx": self.ctx}, ) d_img = garray.to_gpu(np.ascontiguousarray(data, dtype="f")) res = cuda_padding.pad(d_img, output=output) err_max = np.max(np.abs(res.get() - ref)) assert err_max < self.tol, str("Something wrong with padding for configuration %s" % (str(config))) def test_custom_coordinate_transform(self): data = self.data R, C = np.indices(data.shape, dtype=np.int32) pad_width = ((256, 255), (254, 251)) mode = "reflect" coords_R = np.pad(R, pad_width, mode=mode) coords_C = np.pad(C, pad_width, mode=mode) # Further transform of coordinates - here FFT layout coords_R = np.roll(coords_R, (-pad_width[0][0], -pad_width[1][0]), axis=(0, 1)) coords_C = np.roll(coords_C, (-pad_width[0][0], -pad_width[1][0]), axis=(0, 1)) cuda_padding = CudaPadding(data.shape, (coords_R, coords_C), mode=mode, cuda_options={"ctx": self.ctx}) d_img = garray.to_gpu(data) d_out = garray.zeros(cuda_padding.padded_shape, "f") res = cuda_padding.pad(d_img, output=d_out) ref = np.roll(np.pad(data, pad_width, mode=mode), (-pad_width[0][0], -pad_width[1][0]), axis=(0, 1)) err_max = np.max(np.abs(d_out.get() - ref)) assert err_max < self.tol, "Something wrong with custom padding" ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/cuda/utils.py0000644000175000017500000002030000000000000015572 0ustar00pierrepierreimport atexit from math import ceil import numpy as np from ..resources.gpu import GPUDescription try: import pycuda import pycuda.driver as cuda from pycuda import gpuarray as garray from pycuda.tools import make_default_context from pycuda.tools import clear_context_caches __has_pycuda__ = True __pycuda_error_msg__ = None if pycuda.VERSION[0] < 2020: print("Error: need pycuda >= 2020.1") __has_pycuda__ = False except ImportError as err: __has_pycuda__ = False __pycuda_error_msg__ = str(err) try: import skcuda __has_cufft__ = True except ImportError: __has_cufft__ = False def get_cuda_context(device_id=None, cleanup_at_exit=True): """ Create or get a CUDA context. """ current_ctx = cuda.Context.get_current() # If a context already exists, use this one # TODO what if the device used is different from device_id ? if current_ctx is not None: return current_ctx # Otherwise create a new context cuda.init() if device_id is None: device_id = 0 # Use the Context obtained by retaining the device's primary context, # which is the one used by the CUDA runtime API (ex. scikit-cuda). # Unlike Context.make_context(), the newly-created context is not made current. context = cuda.Device(device_id).retain_primary_context() context.push() # Register a clean-up function at exit def _finish_up(context): if context is not None: context.pop() context = None clear_context_caches() if cleanup_at_exit: atexit.register(_finish_up, context) return context def count_cuda_devices(): if cuda.Context.get_current() is None: cuda.init() return cuda.Device.count() def get_gpu_memory(device_id): """ Return the total memory (in GigaBytes) of a device. 
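    Parameters
    ----------
    device_id: int
        ID of the CUDA device.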
""" cuda.init() return cuda.Device(device_id).total_memory() / 1e9 def is_gpu_usable(): """ Test whether at least one Nvidia GPU is available. """ try: n_gpus = count_cuda_devices() except Exception as exc: # Fragile if exc.__str__() != "cuInit failed: no CUDA-capable device is detected": raise n_gpus = 0 res = n_gpus > 0 return res def detect_cuda_gpus(): """ Detect the available Nvidia CUDA GPUs on the current host. Returns -------- gpus: dict Dictionary where the key is the GPU ID, and the value is a `pycuda.driver.Device` object. error_msg: str In the case where there is an error, the message is returned in this item. Otherwise, it is a None object. """ gpus = {} error_msg = None if not (__has_pycuda__): return {}, __pycuda_error_msg__ try: cuda.init() except Exception as exc: error_msg = str(exc) if error_msg is not None: return {}, error_msg try: n_gpus = cuda.Device.count() except Exception as exc: error_msg = str(exc) if error_msg is not None: return {}, error_msg for i in range(n_gpus): gpus[i] = cuda.Device(i) return gpus, None def collect_cuda_gpus(): """ Return a dictionary of GPU ids and brief description of each CUDA-compatible GPU with a few fields. """ gpus, error_msg = detect_cuda_gpus() if error_msg is not None: return None cuda_gpus = {} for gpu_id, gpu in gpus.items(): cuda_gpus[gpu_id] = GPUDescription(gpu).get_dict() return cuda_gpus """ pycuda/driver.py np.complex64: SIGNED_INT32, num_channels = 2 np.float64: SIGNED_INT32, num_channels = 2 np.complex128: array_format.SIGNED_INT32, num_channels = 4 double precision: pycuda-helpers.hpp: typedef float fp_tex_float; // --> float32 typedef int2 fp_tex_double; // --> float64 typedef uint2 fp_tex_cfloat; // --> complex64 typedef int4 fp_tex_cdouble; // --> complex128 """ def cuarray_format_to_dtype(cuarr_fmt): # reverse of cuda.dtype_to_array_format fmt = cuda.array_format mapping = { fmt.UNSIGNED_INT8: np.uint8, fmt.UNSIGNED_INT16: np.uint16, fmt.UNSIGNED_INT32: np.uint32, fmt.SIGNED_INT8: np.int8, fmt.SIGNED_INT16: np.int16, fmt.SIGNED_INT32: np.int32, fmt.FLOAT: np.float32, } if cuarr_fmt not in mapping: raise TypeError("Unknown format %s" % cuarr_fmt) return mapping[cuarr_fmt] def cuarray_shape_dtype(cuarray): desc = cuarray.get_descriptor_3d() shape = (desc.height, desc.width) if desc.depth > 0: shape = (desc.depth,) + shape dtype = cuarray_format_to_dtype(desc.format) return shape, dtype def get_shape_dtype(arr): if isinstance(arr, garray.GPUArray) or isinstance(arr, np.ndarray): return arr.shape, arr.dtype elif isinstance(arr, cuda.Array): return cuarray_shape_dtype(arr) else: raise ValueError("Unknown array type %s" % str(type(arr))) def copy_array(dst, src, check=False, src_dtype=None): """ Copy a source array to a destination array. Source and destination can be either numpy.ndarray, pycuda.Driver.Array, or pycuda.gpuarray.GPUArray. Parameters ----------- dst: pycuda.driver.Array or pycuda.gpuarray.GPUArray or numpy.ndarray Destination array. Its content will be overwritten by copy. src: pycuda.driver.Array or pycuda.gpuarray.GPUArray or numpy.ndarray Source array. check: bool, optional Whether to check src and dst shape and data type. 
""" shape_src, dtype_src = get_shape_dtype(src) shape_dst, dtype_dst = get_shape_dtype(dst) dtype_src = src_dtype or dtype_src if check: if shape_src != shape_dst: raise ValueError("shape_src != shape_dst : have %s and %s" % (str(shape_src), str(shape_dst))) if dtype_src != dtype_dst: raise ValueError("dtype_src != dtype_dst : have %s and %s" % (str(dtype_src), str(dtype_dst))) if len(shape_src) == 2: copy = cuda.Memcpy2D() h, w = shape_src elif len(shape_src) == 3: copy = cuda.Memcpy3D() d, h, w = shape_src copy.depth = d else: raise ValueError("Expected arrays with 2 or 3 dimensions") if isinstance(src, cuda.Array): copy.set_src_array(src) elif isinstance(src, garray.GPUArray): copy.set_src_device(src.gpudata) else: # numpy copy.set_src_host(src) if isinstance(dst, cuda.Array): copy.set_dst_array(dst) elif isinstance(dst, garray.GPUArray): copy.set_dst_device(dst.gpudata) else: # numpy copy.set_dst_host(dst) copy.width_in_bytes = copy.dst_pitch = w * np.dtype(dtype_src).itemsize copy.dst_height = copy.height = h # ?? if len(shape_src) == 2: copy(True) else: copy() ### def copy_big_gpuarray(dst, src, itemsize=4, checks=False): """ Copy a big `pycuda.gpuarray.GPUArray` into another. Transactions of more than 2**32 -1 octets fail, so are doing several partial copies of smaller arrays. """ d2h = isinstance(dst, np.ndarray) if checks: assert dst.dtype == src.dtype assert dst.shape == src.shape limit = 2**32 - 1 if np.prod(dst.shape) * itemsize < limit: if d2h: src.get(ary=dst) else: dst[:] = src[:] return def get_shape2(shape): shape2 = list(shape) while np.prod(shape2) * 4 > limit: shape2[0] //= 2 return tuple(shape2) shape2 = get_shape2(dst.shape) nz0 = dst.shape[0] nz = shape2[0] n_transfers = ceil(nz0 / nz) for i in range(n_transfers): zmax = min((i + 1) * nz, nz0) if d2h: src[i * nz : zmax].get(ary=dst[i * nz : zmax]) else: dst[i * nz : zmax] = src[i * nz : zmax] def replace_array_memory(arr, new_shape): """ Replace the underlying buffer data of a `pycuda.gpuarray.GPUArray`. This function is dangerous ! It should merely be used to clear memory, the array should not be used afterwise. 
""" arr.gpudata.free() arr.gpudata = arr.allocator(int(np.prod(new_shape) * arr.dtype.itemsize)) arr.shape = new_shape # TODO re-compute strides return arr ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4567332 nabu-2023.1.1/nabu/estimation/0000755000175000017500000000000000000000000015325 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/estimation/__init__.py0000644000175000017500000000057300000000000017443 0ustar00pierrepierrefrom .alignment import AlignmentBase from .cor import ( CenterOfRotation, CenterOfRotationSlidingWindow, CenterOfRotationGrowingWindow, CenterOfRotationAdaptiveSearch, ) from .cor_sino import SinoCor from .distortion import estimate_flat_distortion from .focus import CameraFocus from .tilt import CameraTilt from .translation import DetectorTranslationAlongBeam ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682590397.0 nabu-2023.1.1/nabu/estimation/alignment.py0000644000175000017500000005140500000000000017662 0ustar00pierrepierre# import math import logging import numpy as np from numpy.polynomial.polynomial import Polynomial from scipy.ndimage import median_filter import scipy.fft # pylint: disable=E0611 from ..utils import previouspow2 from ..misc import fourier_filters from ..resources.logger import LoggerOrPrint try: import matplotlib.pyplot as plt __have_matplotlib__ = True except ImportError: logging.getLogger(__name__).warning("Matplotlib not available. Plotting disabled") plt = None __have_matplotlib__ = False try: from tqdm import tqdm def progress_bar(x, verbose=True): if verbose: return tqdm(x) else: return x except ImportError: def progress_bar(x, verbose=True): return x local_fftn = scipy.fft.rfftn local_ifftn = scipy.fft.irfftn class AlignmentBase: default_extra_options = {"blocking_plots": False} def __init__( self, vert_fft_width=False, horz_fft_width=False, verbose=False, logger=None, data_type=np.float32, extra_options=None, ): """ Alignment basic functions. Parameters ---------- vert_fft_width: boolean, optional If True, restrict the vertical size to a power of 2: >>> new_v_dim = 2 ** math.floor(math.log2(v_dim)) horz_fft_width: boolean, optional If True, restrict the horizontal size to a power of 2: >>> new_h_dim = 2 ** math.floor(math.log2(h_dim)) verbose: boolean, optional When True it will produce verbose output, including plots. data_type: `numpy.float32` Computation data type. """ self._init_parameters(vert_fft_width, horz_fft_width, verbose, logger, data_type, extra_options=extra_options) self._plot_windows = {} def _init_parameters(self, vert_fft_width, horz_fft_width, verbose, logger, data_type, extra_options=None): self.logger = LoggerOrPrint(logger) self.truncate_vert_pow2 = vert_fft_width self.truncate_horz_pow2 = horz_fft_width if verbose and not __have_matplotlib__: self.logger.warning("Matplotlib not available. Plotting disabled, despite being activated by user") verbose = False self.verbose = verbose self.data_type = data_type self.extra_options = self.default_extra_options.copy() self.extra_options.update(extra_options or {}) @staticmethod def _check_img_stack_size(img_stack: np.ndarray, img_pos: np.ndarray): shape_stack = np.squeeze(img_stack).shape shape_pos = np.squeeze(img_pos).shape if not len(shape_stack) == 3: raise ValueError( "A stack of 2-dimensional images is required. 
Shape of stack: %s" % (" ".join(("%d" % x for x in shape_stack))) ) if not len(shape_pos) == 1: raise ValueError( "Positions need to be a 1-dimensional array. Shape of the positions variable: %s" % (" ".join(("%d" % x for x in shape_pos))) ) if not shape_stack[0] == shape_pos[0]: raise ValueError( "The same number of images and positions is required." + " Shape of stack: %s, shape of positions variable: %s" % ( " ".join(("%d" % x for x in shape_stack)), " ".join(("%d" % x for x in shape_pos)), ) ) @staticmethod def _check_img_pair_sizes(img_1: np.ndarray, img_2: np.ndarray): shape_1 = np.squeeze(img_1).shape shape_2 = np.squeeze(img_2).shape if not len(shape_1) == 2: raise ValueError( "Images need to be 2-dimensional. Shape of image #1: %s" % (" ".join(("%d" % x for x in shape_1))) ) if not len(shape_2) == 2: raise ValueError( "Images need to be 2-dimensional. Shape of image #2: %s" % (" ".join(("%d" % x for x in shape_2))) ) if not np.all(shape_1 == shape_2): raise ValueError( "Images need to be of the same shape. Shape of image #1: %s, image #2: %s" % ( " ".join(("%d" % x for x in shape_1)), " ".join(("%d" % x for x in shape_2)), ) ) @staticmethod def refine_max_position_2d(f_vals: np.ndarray, fy=None, fx=None): """Computes the sub-pixel max position of the given function sampling. Parameters ---------- f_vals: numpy.ndarray Function values of the sampled points fy: numpy.ndarray, optional Vertical coordinates of the sampled points fx: numpy.ndarray, optional Horizontal coordinates of the sampled points Raises ------ ValueError In case position and values do not have the same size, or in case the fitted maximum is outside the fitting region. Returns ------- tuple(float, float) Estimated (vertical, horizontal) function max, according to the coordinates in fy and fx. """ if not (len(f_vals.shape) == 2): raise ValueError( "The fitted values should form a 2-dimensional array. Array of shape: [%s] was given." % (" ".join(("%d" % s for s in f_vals.shape))) ) if fy is None: fy_half_size = (f_vals.shape[0] - 1) / 2 fy = np.linspace(-fy_half_size, fy_half_size, f_vals.shape[0]) elif not (len(fy.shape) == 1 and np.all(fy.size == f_vals.shape[0])): raise ValueError( "Vertical coordinates should have the same length as values matrix. Sizes of fy: %d, f_vals: [%s]" % (fy.size, " ".join(("%d" % s for s in f_vals.shape))) ) if fx is None: fx_half_size = (f_vals.shape[1] - 1) / 2 fx = np.linspace(-fx_half_size, fx_half_size, f_vals.shape[1]) elif not (len(fx.shape) == 1 and np.all(fx.size == f_vals.shape[1])): raise ValueError( "Horizontal coordinates should have the same length as values matrix. Sizes of fx: %d, f_vals: [%s]" % (fx.size, " ".join(("%d" % s for s in f_vals.shape))) ) fy, fx = np.meshgrid(fy, fx, indexing="ij") fy = fy.flatten() fx = fx.flatten() coords = np.array([np.ones(f_vals.size), fy, fx, fy * fx, fy**2, fx**2]) coeffs = np.linalg.lstsq(coords.T, f_vals.flatten(), rcond=None)[0] # For a 1D parabola `f(x) = ax^2 + bx + c`, the vertex position is: # x_v = -b / 2a. 
For a 2D parabola, the vertex position is: # (y, x)_v = - b / A, where: A = [[2 * coeffs[4], coeffs[3]], [coeffs[3], 2 * coeffs[5]]] b = coeffs[1:3] vertex_yx = np.linalg.lstsq(A, -b, rcond=None)[0] vertex_min_yx = [np.min(fy), np.min(fx)] vertex_max_yx = [np.max(fy), np.max(fx)] if np.any(vertex_yx < vertex_min_yx) or np.any(vertex_yx > vertex_max_yx): raise ValueError( "Fitted (y: {}, x: {}) positions are outside the input margins y: [{}, {}], and x: [{}, {}]".format( vertex_yx[0], vertex_yx[1], vertex_min_yx[0], vertex_max_yx[0], vertex_min_yx[1], vertex_max_yx[1], ) ) return vertex_yx @staticmethod def refine_max_position_1d(f_vals, fx=None, return_vertex_val=False, return_all_coeffs=False): """Computes the sub-pixel max position of the given function sampling. Parameters ---------- f_vals: numpy.ndarray Function values of the sampled points fx: numpy.ndarray, optional Coordinates of the sampled points return_vertex_val: boolean, option Enables returning the vertex values. Defaults to False. Raises ------ ValueError In case position and values do not have the same size, or in case the fitted maximum is outside the fitting region. Returns ------- float Estimated function max, according to the coordinates in fx. """ if not len(f_vals.shape) in (1, 2): raise ValueError( "The fitted values should be either one or a collection of 1-dimensional arrays. Array of shape: [%s] was given." % (" ".join(("%d" % s for s in f_vals.shape))) ) num_vals = f_vals.shape[0] if fx is None: fx_half_size = (num_vals - 1) / 2 fx = np.linspace(-fx_half_size, fx_half_size, num_vals) else: fx = np.squeeze(fx) if not (len(fx.shape) == 1 and np.all(fx.size == num_vals)): raise ValueError( "Base coordinates should have the same length as values array. Sizes of fx: %d, f_vals: %d" % (fx.size, num_vals) ) if len(f_vals.shape) == 1: # using Polynomial.fit, because supposed to be more numerically # stable than previous solutions (according to numpy). poly = Polynomial.fit(fx, f_vals, deg=2) coeffs = poly.convert().coef else: coords = np.array([np.ones(num_vals), fx, fx**2]) coeffs = np.linalg.lstsq(coords.T, f_vals, rcond=None)[0] # For a 1D parabola `f(x) = c + bx + ax^2`, the vertex position is: # x_v = -b / 2a. vertex_x = -coeffs[1, :] / (2 * coeffs[2, :]) if not return_all_coeffs: vertex_x = vertex_x[0] vertex_min_x = np.min(fx) vertex_max_x = np.max(fx) lower_bound_ok = vertex_min_x < vertex_x upper_bound_ok = vertex_x < vertex_max_x if not np.all(lower_bound_ok * upper_bound_ok): if len(f_vals.shape) == 1: message = "Fitted position {} is outide the input margins [{}, {}]".format( vertex_x, vertex_min_x, vertex_max_x ) else: message = "Fitted positions outide the input margins [{}, {}]: {} below and {} above".format( vertex_min_x, vertex_max_x, np.sum(1 - lower_bound_ok), np.sum(1 - upper_bound_ok), ) raise ValueError(message) if return_vertex_val: vertex_val = coeffs[0, :] + vertex_x * coeffs[1, :] / 2 return vertex_x, vertex_val else: return vertex_x @staticmethod def extract_peak_region_2d(cc, peak_radius=1, cc_vs=None, cc_hs=None): """ Extracts a region around the maximum value. Parameters ---------- cc: numpy.ndarray Correlation image. peak_radius: int, optional The l_inf radius of the area to extract around the peak. The default is 1. cc_vs: numpy.ndarray, optional The vertical coordinates of `cc`. The default is None. cc_hs: numpy.ndarray, optional The horizontal coordinates of `cc`. The default is None. Returns ------- f_vals: numpy.ndarray The extracted function values. 
fv: numpy.ndarray The vertical coordinates of the extracted values. fh: numpy.ndarray The horizontal coordinates of the extracted values. """ img_shape = np.array(cc.shape) # get pixel having the maximum value of the correlation array pix_max_corr = np.argmax(cc) pv, ph = np.unravel_index(pix_max_corr, img_shape) # select a n x n neighborhood for the sub-pixel fitting (with wrapping) pv = np.arange(pv - peak_radius, pv + peak_radius + 1) % img_shape[-2] ph = np.arange(ph - peak_radius, ph + peak_radius + 1) % img_shape[-1] # extract the (v, h) pixel coordinates fv = None if cc_vs is None else cc_vs[pv] fh = None if cc_hs is None else cc_hs[ph] # extract the correlation values pv, ph = np.meshgrid(pv, ph, indexing="ij") f_vals = cc[pv, ph] return (f_vals, fv, fh) @staticmethod def extract_peak_regions_1d(cc, axis=-1, peak_radius=1, cc_coords=None): """ Extracts a region around the maximum value. Parameters ---------- cc: numpy.ndarray Correlation image. axis: int, optional Find the max values along the specified direction. The default is -1. peak_radius: int, optional The l_inf radius of the area to extract around the peak. The default is 1. cc_coords: numpy.ndarray, optional The coordinates of `cc` along the selected axis. The default is None. Returns ------- f_vals: numpy.ndarray The extracted function values. fc_ax: numpy.ndarray The coordinates of the extracted values, along the selected axis. """ if len(cc.shape) == 1: cc = cc[None, ...] img_shape = np.array(cc.shape) if not (len(img_shape) == 2): raise ValueError( "The input image should be either a 1 or 2-dimensional array. Array of shape: [%s] was given." % (" ".join(("%d" % s for s in cc.shape))) ) other_axis = (axis + 1) % 2 # get pixel having the maximum value of the correlation array pix_max = np.argmax(cc, axis=axis) # select a n neighborhood for the many 1D sub-pixel fittings (with wrapping) p_ax_range = np.arange(-peak_radius, +peak_radius + 1) p_ax = (pix_max[None, :] + p_ax_range[:, None]) % img_shape[axis] p_ln = np.tile(np.arange(0, img_shape[other_axis])[None, :], [2 * peak_radius + 1, 1]) # extract the pixel coordinates along the axis fc_ax = None if cc_coords is None else cc_coords[p_ax.flatten()].reshape(p_ax.shape) # extract the correlation values if other_axis == 0: f_vals = cc[p_ln, p_ax] else: f_vals = cc[p_ax, p_ln] return (f_vals, fc_ax) def _determine_roi(self, img_shape, roi_yxhw): if roi_yxhw is None: # vertical and horizontal window sizes are reduced to a power of 2 # to accelerate fft if requested. Default is not. roi_yxhw = previouspow2(img_shape) if not self.truncate_vert_pow2: roi_yxhw[0] = img_shape[0] if not self.truncate_horz_pow2: roi_yxhw[1] = img_shape[1] roi_yxhw = np.array(roi_yxhw, dtype=np.intp) if len(roi_yxhw) == 2: # Convert centered 2-element roi into 4-element roi_yxhw = np.concatenate(((img_shape - roi_yxhw) // 2, roi_yxhw)) return roi_yxhw def _prepare_image( self, img, invalid_val=1e-5, roi_yxhw=None, median_filt_shape=None, low_pass=None, high_pass=None, ): """ Prepare and returns a cropped and filtered image, or array of filtered images if the input is an array of images. 
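        The processing steps are applied in this order: optional cropping to
        `roi_yxhw`, replacement of NaN/Inf values by `invalid_val`, optional
        Fourier band-pass filtering (`low_pass` / `high_pass`), and optional
        median filtering with a window of shape `median_filt_shape`.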
Parameters ---------- img: numpy.ndarray image or stack of images invalid_val: float value to be used in replacement of nan and inf values median_filt_shape: int or sequence of int the width or the widths of the median window low_pass: float or sequence of two floats Low-pass filter properties, as described in `nabu.misc.fourier_filters` high_pass: float or sequence of two floats High-pass filter properties, as described in `nabu.misc.fourier_filters` Returns ------- numpy.array_like The computed filter """ img = np.squeeze(img) # Removes singleton dimensions, but does a shallow copy img = np.ascontiguousarray(img, dtype=self.data_type) if roi_yxhw is not None: img = img[ ..., roi_yxhw[0] : roi_yxhw[0] + roi_yxhw[2], roi_yxhw[1] : roi_yxhw[1] + roi_yxhw[3], ] img = img.copy() img[np.isnan(img)] = invalid_val img[np.isinf(img)] = invalid_val if high_pass is not None or low_pass is not None: img_filter = fourier_filters.get_bandpass_filter( img.shape[-2:], cutoff_lowpass=low_pass, cutoff_highpass=high_pass, use_rfft=True, data_type=self.data_type, ) # fft2 and iff2 use axes=(-2, -1) by default img = local_ifftn(local_fftn(img, axes=(-2, -1)) * img_filter, axes=(-2, -1)).real if median_filt_shape is not None: img_shape = img.shape # expanding filter shape with ones, to cover the stack of images # but disabling inter-image filtering median_filt_shape = np.concatenate( ( np.ones((len(img_shape) - len(median_filt_shape),), dtype=np.intp), median_filt_shape, ) ) img = median_filter(img, size=median_filt_shape) return img def _transform_to_fft( self, img_1: np.ndarray, img_2: np.ndarray, padding_mode, axes=(-2, -1), low_pass=None, high_pass=None ): do_circular_conv = padding_mode is None or padding_mode == "wrap" img_shape = img_2.shape if not do_circular_conv: pad_size = np.ceil(np.array(img_shape) / 2).astype(np.intp) pad_array = [(0,)] * len(img_shape) for a in axes: pad_array[a] = (pad_size[a],) img_1 = np.pad(img_1, pad_array, mode=padding_mode) img_2 = np.pad(img_2, pad_array, mode=padding_mode) else: pad_size = None img_shape = img_2.shape # compute fft's of the 2 images img_fft_1 = local_fftn(img_1, axes=axes) img_fft_2 = local_fftn(img_2, axes=axes) if low_pass is not None or high_pass is not None: filt = fourier_filters.get_bandpass_filter( img_shape[-2:], cutoff_lowpass=low_pass, cutoff_highpass=high_pass, use_rfft=True, data_type=self.data_type, ) else: filt = None return img_fft_1, img_fft_2, filt, pad_size def _compute_correlation_fft( self, img_1: np.ndarray, img_2: np.ndarray, padding_mode, axes=(-2, -1), low_pass=None, high_pass=None ): img_fft_1, img_fft_2, filt, pad_size = self._transform_to_fft( img_1, img_2, padding_mode=padding_mode, axes=axes, low_pass=low_pass, high_pass=high_pass ) img_prod = img_fft_1 * np.conjugate(img_fft_2) if filt is not None: img_prod *= filt # inverse fft of the product to get cross_correlation of the 2 images cc = np.real(local_ifftn(img_prod, axes=axes)) if pad_size is not None: cc_shape = cc.shape cc = np.fft.fftshift(cc, axes=axes) slicing = [slice(None)] * len(cc_shape) for a in axes: slicing[a] = slice(pad_size[a], cc_shape[a] - pad_size[a]) cc = cc[tuple(slicing)] cc = np.fft.ifftshift(cc, axes=axes) return cc def _add_plot_window(self, fig, ax=None): self._plot_windows[fig.number] = {"figure": fig, "axes": ax} def close_plot_window(self, n, errors="raise"): """ Close a plot window. Applicable only if the class was instantiated with verbose=True. Parameters ---------- n: int Figure number to close errors: str, optional What to do with errors. 
It can be either "raise", "log" or "ignore". """ if not self.verbose: return if n not in self._plot_windows: msg = "Cannot close plot window number %d: no such window" % n if errors == "raise": raise ValueError(msg) elif errors == "log": self.logger.error(msg) fig_ax = self._plot_windows.pop(n) plt.close(fig_ax["figure"].number) # would also work with the object itself def close_last_plot_windows(self, n=1): """ Close the last "n" plot windows. Applicable only if the class was instanciated with verbose=True. Parameters ----------- n: int, optional Integer indicating how many plot windows should be closed. """ figs_nums = sorted(self._plot_windows.keys(), reverse=True) n = min(n, len(figs_nums)) for i in range(n): self.close_plot_window(figs_nums[i], errors="ignore") ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/estimation/cor.py0000644000175000017500000010770200000000000016471 0ustar00pierrepierreimport math import numpy as np from ..misc import fourier_filters from .alignment import AlignmentBase, plt, progress_bar, local_fftn, local_ifftn # three possible values for the validity check, which can optionally be returned by the find_shifts methods cor_result_validity = { "unknown": "unknown", "sound": "sound", "correct": "sound", "questionable": "questionable", } class CenterOfRotation(AlignmentBase): def find_shift( self, img_1: np.ndarray, img_2: np.ndarray, shift_axis: int = -1, roi_yxhw=None, median_filt_shape=None, padding_mode=None, peak_fit_radius=1, high_pass=None, low_pass=None, return_validity=False, ): """Find the Center of Rotation (CoR), given two images. This method finds the half-shift between two opposite images, by means of correlation computed in Fourier space. The output of this function, allows to compute motor movements for aligning the sample rotation axis. Given the following values: - L1: distance from source to motor - L2: distance from source to detector - ps: physical pixel size - v: output of this function displacement of motor = (L1 / L2 * ps) * v Parameters ---------- img_1: numpy.ndarray First image img_2: numpy.ndarray Second image, it needs to have been flipped already (e.g. using numpy.fliplr). shift_axis: int Axis along which we want the shift to be computed. Default is -1 (horizontal). roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional 4 elements vector containing: vertical and horizontal coordinates of first pixel, plus height and width of the Region of Interest (RoI). Or a 2 elements vector containing: plus height and width of the centered Region of Interest (RoI). Default is None -> deactivated. median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional Shape of the median filter window. Default is None -> deactivated. padding_mode: str in numpy.pad's mode list, optional Padding mode, which determines the type of convolution. If None or 'wrap' are passed, this resorts to the traditional circular convolution. If 'edge' or 'constant' are passed, it results in a linear convolution. Default is the circular convolution. All options are: None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean' | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap' peak_fit_radius: int, optional Radius size around the max correlation pixel, for sub-pixel fitting. Minimum and default value is 1. 
low_pass: float or sequence of two floats Low-pass filter properties, as described in `nabu.misc.fourier_filters` high_pass: float or sequence of two floats High-pass filter properties, as described in `nabu.misc.fourier_filters` return_validity: a boolean, defaults to false if set to True adds a second return value which may have three string values. These values are "unknown", "sound", "questionable". It will be "uknown" if the validation method is not implemented and it will be "sound" or "questionable" if it is implemented. Raises ------ ValueError In case images are not 2-dimensional or have different sizes. Returns ------- float Estimated center of rotation position from the center of the RoI in pixels. Examples -------- The following code computes the center of rotation position for two given images in a tomography scan, where the second image is taken at 180 degrees from the first. >>> radio1 = data[0, :, :] ... radio2 = np.fliplr(data[1, :, :]) ... CoR_calc = CenterOfRotation() ... cor_position = CoR_calc.find_shift(radio1, radio2) Or for noisy images: >>> cor_position = CoR_calc.find_shift(radio1, radio2, median_filt_shape=(3, 3)) """ self._check_img_pair_sizes(img_1, img_2) if peak_fit_radius < 1: self.logger.warning("Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius) peak_fit_radius = 1 img_shape = img_2.shape roi_yxhw = self._determine_roi(img_shape, roi_yxhw) img_1 = self._prepare_image(img_1, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape) img_2 = self._prepare_image(img_2, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape) cc = self._compute_correlation_fft(img_1, img_2, padding_mode, high_pass=high_pass, low_pass=low_pass) img_shape = img_2.shape cc_vs = np.fft.fftfreq(img_shape[-2], 1 / img_shape[-2]) cc_hs = np.fft.fftfreq(img_shape[-1], 1 / img_shape[-1]) (f_vals, fv, fh) = self.extract_peak_region_2d(cc, peak_radius=peak_fit_radius, cc_vs=cc_vs, cc_hs=cc_hs) fitted_shifts_vh = self.refine_max_position_2d(f_vals, fv, fh) if return_validity: return fitted_shifts_vh[shift_axis] / 2.0, cor_result_validity["unknown"] else: return fitted_shifts_vh[shift_axis] / 2.0 class CenterOfRotationSlidingWindow(CenterOfRotation): def find_shift( self, img_1: np.ndarray, img_2: np.ndarray, side, window_width=None, roi_yxhw=None, median_filt_shape=None, padding_mode=None, peak_fit_radius=1, high_pass=None, low_pass=None, return_validity=False, ): """Semi-automatically find the Center of Rotation (CoR), given two images or sinograms. Suitable for half-aquisition scan. This method finds the half-shift between two opposite images, by minimizing difference over a moving window. The output of this function, allows to compute motor movements for aligning the sample rotation axis. Given the following values: - L1: distance from source to motor - L2: distance from source to detector - ps: physical pixel size - v: output of this function displacement of motor = (L1 / L2 * ps) * v Parameters ---------- img_1: numpy.ndarray First image img_2: numpy.ndarray Second image, it needs to have been flipped already (e.g. using numpy.fliplr). side: string Expected region of the CoR. Allowed values: 'left', 'center' or 'right'. window_width: int, optional Width of window that will slide on the other image / part of the sinogram. Default is None. roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional 4 elements vector containing: vertical and horizontal coordinates of first pixel, plus height and width of the Region of Interest (RoI). 
Or a 2 elements vector containing: plus height and width of the centered Region of Interest (RoI). Default is None -> deactivated. median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional Shape of the median filter window. Default is None -> deactivated. padding_mode: str in numpy.pad's mode list, optional Padding mode, which determines the type of convolution. If None or 'wrap' are passed, this resorts to the traditional circular convolution. If 'edge' or 'constant' are passed, it results in a linear convolution. Default is the circular convolution. All options are: None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean' | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap' peak_fit_radius: int, optional Radius size around the max correlation pixel, for sub-pixel fitting. Minimum and default value is 1. low_pass: float or sequence of two floats Low-pass filter properties, as described in `nabu.misc.fourier_filters` high_pass: float or sequence of two floats High-pass filter properties, as described in `nabu.misc.fourier_filters` return_validity: a boolean, defaults to false if set to True adds a second return value which may have three string values. These values are "unknown", "sound", "questionable". It will be "uknown" if the validation method is not implemented and it will be "sound" or "questionable" if it is implemented. Raises ------ ValueError In case images are not 2-dimensional or have different sizes. Returns ------- float Estimated center of rotation position from the center of the RoI in pixels. Examples -------- The following code computes the center of rotation position for two given images in a tomography scan, where the second image is taken at 180 degrees from the first. >>> radio1 = data[0, :, :] ... radio2 = np.fliplr(data[1, :, :]) ... CoR_calc = CenterOfRotationSlidingWindow() ... cor_position = CoR_calc.find_shift(radio1, radio2) Or for noisy images: >>> cor_position = CoR_calc.find_shift(radio1, radio2, median_filt_shape=(3, 3)) """ validity_check_result = cor_result_validity["unknown"] if side is None: raise ValueError("Side should be one of 'left', 'right', or 'center'. 'None' was given instead") self._check_img_pair_sizes(img_1, img_2) if peak_fit_radius < 1: self.logger.warning("Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius) peak_fit_radius = 1 img_shape = img_2.shape roi_yxhw = self._determine_roi(img_shape, roi_yxhw) img_1 = self._prepare_image( img_1, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape, high_pass=high_pass, low_pass=low_pass ) img_2 = self._prepare_image( img_2, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape, high_pass=high_pass, low_pass=low_pass ) img_shape = img_2.shape if window_width is None: if side.lower() == "center": window_width = round(img_shape[-1] / 4.0 * 3.0) else: window_width = round(img_shape[-1] / 10) window_shift = window_width // 2 window_width = window_shift * 2 + 1 if side.lower() == "right": win_2_start = 0 elif side.lower() == "left": win_2_start = img_shape[-1] - window_width elif side.lower() == "center": win_2_start = img_shape[-1] // 2 - window_shift else: raise ValueError( "Side should be one of 'left', 'right', or 'center'. '%s' was given instead" % side.lower() ) win_2_end = win_2_start + window_width # number of pixels where the window will "slide". 
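        # The window slides over n = img_shape[-1] - window_width offsets: a
        # fixed window of img_2 (starting at win_2_start, chosen from `side`)
        # is compared to an equally wide window of img_1 at every offset ii,
        # recording the mean and std-dev of |img_1_win - img_2_win|. The offset
        # minimizing the mean difference gives the overlap, from which the CoR
        # follows as cor_h = -(win_2_start - (win_ind_max + win_pos_max)) / 2.
        # Illustrative numbers: for a 2048-pixel wide image with side='right'
        # (default window_width ~ 205), an optimum at ii ~ 1800 corresponds to
        # a CoR of about +900 pixels with respect to the image center.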
n = img_shape[-1] - window_width diffs_mean = np.zeros((n,), dtype=img_1.dtype) diffs_std = np.zeros((n,), dtype=img_1.dtype) for ii in progress_bar(range(n), verbose=self.verbose): win_1_start, win_1_end = ii, ii + window_width img_diff = img_1[:, win_1_start:win_1_end] - img_2[:, win_2_start:win_2_end] diffs_abs = np.abs(img_diff) diffs_mean[ii] = diffs_abs.mean() diffs_std[ii] = diffs_abs.std() diffs_mean = diffs_mean.min() - diffs_mean win_ind_max = np.argmax(diffs_mean) diffs_std = diffs_std.min() - diffs_std if not win_ind_max == np.argmax(diffs_std): self.logger.warning( "Minimum mean difference and minimum std-dev of differences do not coincide. " + "This means that the validity of the found solution might be questionable." ) validity_check_result = cor_result_validity["questionable"] else: validity_check_result = cor_result_validity["sound"] (f_vals, f_pos) = self.extract_peak_regions_1d(diffs_mean, peak_radius=peak_fit_radius) win_pos_max, win_val_max = self.refine_max_position_1d(f_vals, return_vertex_val=True) cor_h = -(win_2_start - (win_ind_max + win_pos_max)) / 2.0 if (side.lower() == "right" and win_ind_max == 0) or (side.lower() == "left" and win_ind_max == n): self.logger.warning("Sliding window width %d might be too large!" % window_width) if self.verbose: cor_pos = -(win_2_start - np.arange(n)) / 2.0 print("Lowest difference window: index=%d, range=[0, %d]" % (win_ind_max, n)) print("CoR tested for='%s', found at voxel=%g (from center)" % (side, cor_h)) f, ax = plt.subplots(1, 1) self._add_plot_window(f, ax=ax) ax.stem(cor_pos, diffs_mean, label="Mean difference") ax.stem(cor_h, win_val_max, linefmt="C1-", markerfmt="C1o", label="Best mean difference") ax.stem(cor_pos, -diffs_std, linefmt="C2-", markerfmt="C2o", label="Std-dev difference") ax.set_title("Window dispersions") plt.show(block=False) if return_validity: return cor_h, validity_check_result else: return cor_h class CenterOfRotationGrowingWindow(CenterOfRotation): def find_shift( self, img_1: np.ndarray, img_2: np.ndarray, side="all", min_window_width=11, roi_yxhw=None, median_filt_shape=None, padding_mode=None, peak_fit_radius=1, high_pass=None, low_pass=None, return_validity=False, ): """Automatically find the Center of Rotation (CoR), given two images or sinograms. Suitable for half-aquisition scan. This method finds the half-shift between two opposite images, by minimizing difference over a moving window. The output of this function, allows to compute motor movements for aligning the sample rotation axis. Given the following values: - L1: distance from source to motor - L2: distance from source to detector - ps: physical pixel size - v: output of this function displacement of motor = (L1 / L2 * ps) * v Parameters ---------- img_1: numpy.ndarray First image img_2: numpy.ndarray Second image, it needs to have been flipped already (e.g. using numpy.fliplr). side: string, optional Expected region of the CoR. Allowed values: 'left', 'center', 'right', or 'all'. Default is 'all'. min_window_width: int, optional Minimum window width that covers the common region of the two images / sinograms. Default is 11. roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional 4 elements vector containing: vertical and horizontal coordinates of first pixel, plus height and width of the Region of Interest (RoI). Or a 2 elements vector containing: plus height and width of the centered Region of Interest (RoI). Default is None -> deactivated. 
median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional Shape of the median filter window. Default is None -> deactivated. padding_mode: str in numpy.pad's mode list, optional Padding mode, which determines the type of convolution. If None or 'wrap' are passed, this resorts to the traditional circular convolution. If 'edge' or 'constant' are passed, it results in a linear convolution. Default is the circular convolution. All options are: None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean' | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap' peak_fit_radius: int, optional Radius size around the max correlation pixel, for sub-pixel fitting. Minimum and default value is 1. low_pass: float or sequence of two floats Low-pass filter properties, as described in `nabu.misc.fourier_filters` high_pass: float or sequence of two floats High-pass filter properties, as described in `nabu.misc.fourier_filters` return_validity: a boolean, defaults to false if set to True adds a second return value which may have three string values. These values are "unknown", "sound", "questionable". It will be "uknown" if the validation method is not implemented and it will be "sound" or "questionable" if it is implemented. Raises ------ ValueError In case images are not 2-dimensional or have different sizes. Returns ------- float Estimated center of rotation position from the center of the RoI in pixels. Examples -------- The following code computes the center of rotation position for two given images in a tomography scan, where the second image is taken at 180 degrees from the first. >>> radio1 = data[0, :, :] ... radio2 = np.fliplr(data[1, :, :]) ... CoR_calc = CenterOfRotationGrowingWindow() ... cor_position = CoR_calc.find_shift(radio1, radio2) Or for noisy images: >>> cor_position = CoR_calc.find_shift(radio1, radio2, median_filt_shape=(3, 3)) """ validity_check_result = cor_result_validity["unknown"] self._check_img_pair_sizes(img_1, img_2) if peak_fit_radius < 1: self.logger.warning("Parameter peak_fit_radius should be at least 1, given: %d instead." 
% peak_fit_radius) peak_fit_radius = 1 img_shape = img_2.shape roi_yxhw = self._determine_roi(img_shape, roi_yxhw) img_1 = self._prepare_image( img_1, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape, high_pass=high_pass, low_pass=low_pass ) img_2 = self._prepare_image( img_2, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape, high_pass=high_pass, low_pass=low_pass ) img_shape = img_2.shape def window_bounds(mid_point, window_max_width=img_shape[-1]): return ( np.fmax(np.ceil(mid_point - window_max_width / 2), 0).astype(np.intp), np.fmin(np.ceil(mid_point + window_max_width / 2), img_shape[-1]).astype(np.intp), ) img_lower_half_size = np.floor(img_shape[-1] / 2).astype(np.intp) img_upper_half_size = np.ceil(img_shape[-1] / 2).astype(np.intp) if side.lower() == "right": win_1_mid_start = img_lower_half_size win_1_mid_end = np.floor(img_shape[-1] * 3 / 2).astype(np.intp) - min_window_width win_2_mid_start = -img_upper_half_size + min_window_width win_2_mid_end = img_upper_half_size elif side.lower() == "left": win_1_mid_start = -img_lower_half_size + min_window_width win_1_mid_end = img_lower_half_size win_2_mid_start = img_upper_half_size win_2_mid_end = np.ceil(img_shape[-1] * 3 / 2).astype(np.intp) - min_window_width elif side.lower() == "center": win_1_mid_start = 0 win_1_mid_end = img_shape[-1] win_2_mid_start = 0 win_2_mid_end = img_shape[-1] elif side.lower() == "all": win_1_mid_start = -img_lower_half_size + min_window_width win_1_mid_end = np.floor(img_shape[-1] * 3 / 2).astype(np.intp) - min_window_width win_2_mid_start = -img_upper_half_size + min_window_width win_2_mid_end = np.ceil(img_shape[-1] * 3 / 2).astype(np.intp) - min_window_width else: raise ValueError( "Side should be one of 'left', 'right', or 'center'. '%s' was given instead" % side.lower() ) n1 = win_1_mid_end - win_1_mid_start n2 = win_2_mid_end - win_2_mid_start if not n1 == n2: raise ValueError( "Internal error: the number of window steps for the two images should be the same." + "Found the following configuration instead => Side: %s, #1: %d, #2: %d" % (side, n1, n2) ) diffs_mean = np.zeros((n1,), dtype=img_1.dtype) diffs_std = np.zeros((n1,), dtype=img_1.dtype) for ii in progress_bar(range(n1), verbose=self.verbose): win_1 = window_bounds(win_1_mid_start + ii) win_2 = window_bounds(win_2_mid_end - ii) img_diff = img_1[:, win_1[0] : win_1[1]] - img_2[:, win_2[0] : win_2[1]] diffs_abs = np.abs(img_diff) diffs_mean[ii] = diffs_abs.mean() diffs_std[ii] = diffs_abs.std() diffs_mean = diffs_mean.min() - diffs_mean win_ind_max = np.argmax(diffs_mean) diffs_std = diffs_std.min() - diffs_std if not win_ind_max == np.argmax(diffs_std): self.logger.warning( "Minimum mean difference and minimum std-dev of differences do not coincide. " + "This means that the validity of the found solution might be questionable." ) validity_check_result = cor_result_validity["questionable"] else: validity_check_result = cor_result_validity["sound"] (f_vals, f_pos) = self.extract_peak_regions_1d(diffs_mean, peak_radius=peak_fit_radius) win_pos_max, win_val_max = self.refine_max_position_1d(f_vals, return_vertex_val=True) cor_h = (win_1_mid_start + (win_ind_max + win_pos_max) - img_upper_half_size) / 2.0 if (side.lower() == "right" and win_ind_max == 0) or (side.lower() == "left" and win_ind_max == n1): self.logger.warning("Minimum growing window width %d might be too large!" 
% min_window_width) if self.verbose: cor_pos = (win_1_mid_start + np.arange(n1) - img_upper_half_size) / 2.0 self.logger.info("Lowest difference window: index=%d, range=[0, %d]" % (win_ind_max, n1)) self.logger.info("CoR tested for='%s', found at voxel=%g (from center)" % (side, cor_h)) f, ax = plt.subplots(1, 1) self._add_plot_window(f, ax=ax) ax.stem(cor_pos, diffs_mean, label="Mean difference") ax.stem(cor_h, win_val_max, linefmt="C1-", markerfmt="C1o", label="Best mean difference") ax.stem(cor_pos, -diffs_std, linefmt="C2-", markerfmt="C2o", label="Std-dev difference") ax.set_title("Window dispersions") plt.show(block=False) if return_validity: return cor_h, validity_check_result else: return cor_h class CenterOfRotationAdaptiveSearch(CenterOfRotation): """This adaptive method works by applying a gaussian which highlights, by apodisation, a region which can possibly contain the good center of rotation. The whole image is spanned during several applications of the apodisation. At each application the apodisation function, which is a gaussian, is moved to a new guess position. The lenght of the step, by which the gaussian is moved, and its sigma are obtained by multiplying the shortest distance from the left or right border with a self.step_fraction and self.sigma_fraction factors which ensure global overlapping. for each step a region around the CoR of each image is selected, and the regions of the two images are compared to calculate a cost function. The value of the cost function, at its minimum is used to select the best step at which the CoR is taken as final result. The option filtered_cost= True (default) triggers the filtering (according to low_pass and high_pass) of the two images which are used for he cost function. ( Note: the low_pass and high_pass options are used, if given, also without the filtered_cost option, by being passed to the base class CenterOfRotation ) """ sigma_fraction = 1.0 / 4.0 step_fraction = 1.0 / 6.0 def find_shift( self, img_1: np.ndarray, img_2: np.ndarray, roi_yxhw=None, median_filt_shape=None, padding_mode=None, high_pass=None, low_pass=None, margins=None, filtered_cost=True, return_validity=False, ): """Find the Center of Rotation (CoR), given two images. This method finds the half-shift between two opposite images, by means of correlation computed in Fourier space. A global search is done on on the detector span (minus a margin) without assuming centered scan conditions. The output of this function, allows to compute motor movements for aligning the sample rotation axis. Given the following values: - L1: distance from source to motor - L2: distance from source to detector - ps: physical pixel size - v: output of this function displacement of motor = (L1 / L2 * ps) * v Parameters ---------- img_1: numpy.ndarray First image img_2: numpy.ndarray Second image, it needs to have been flipped already (e.g. using numpy.fliplr). roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional 4 elements vector containing: vertical and horizontal coordinates of first pixel, plus height and width of the Region of Interest (RoI). Or a 2 elements vector containing: plus height and width of the centered Region of Interest (RoI). Default is None -> deactivated. median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional Shape of the median filter window. Default is None -> deactivated. padding_mode: str in numpy.pad's mode list, optional Padding mode, which determines the type of convolution. 
If None or 'wrap' are passed, this resorts to the traditional circular convolution. If 'edge' or 'constant' are passed, it results in a linear convolution. Default is the circular convolution. All options are: None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean' | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap' low_pass: float or sequence of two floats. Low-pass filter properties, as described in `nabu.misc.fourier_filters` high_pass: float or sequence of two floats High-pass filter properties, as described in `nabu.misc.fourier_filters`. margins: None or a couple of floats or ints if margins is None or in the form of (margin1,margin2) the search is done between margin1 and dim_x-1-margin2. If left to None then by default (margin1,margin2) = ( 10, 10 ). filtered_cost: boolean. True by default. It triggers the use of filtered images in the calculation of the cost function. return_validity: a boolean, defaults to false if set to True adds a second return value which may have three string values. These values are "unknown", "sound", "questionable". It will be "uknown" if the validation method is not implemented and it will be "sound" or "questionable" if it is implemented. Raises ------ ValueError In case images are not 2-dimensional or have different sizes. Returns ------- float Estimated center of rotation position from the center of the RoI in pixels. Examples -------- The following code computes the center of rotation position for two given images in a tomography scan, where the second image is taken at 180 degrees from the first. >>> radio1 = data[0, :, :] ... radio2 = np.fliplr(data[1, :, :]) ... CoR_calc = CenterOfRotationAdaptiveSearch() ... cor_position = CoR_calc.find_shift(radio1, radio2) Or for noisy images: >>> cor_position = CoR_calc.find_shift(radio1, radio2, median_filt_shape=(3, 3), high_pass=20, low_pass=1 ) """ validity_check_result = cor_result_validity["unknown"] self._check_img_pair_sizes(img_1, img_2) used_type = img_1.dtype roi_yxhw = self._determine_roi(img_1.shape, roi_yxhw) if filtered_cost and (low_pass is not None or high_pass is not None): img_filter = fourier_filters.get_bandpass_filter( img_1.shape[-2:], cutoff_lowpass=low_pass, cutoff_highpass=high_pass, use_rfft=True, data_type=self.data_type, ) # fft2 and iff2 use axes=(-2, -1) by default img_filtered_1 = local_ifftn(local_fftn(img_1, axes=(-2, -1)) * img_filter, axes=(-2, -1)).real img_filtered_2 = local_ifftn(local_fftn(img_2, axes=(-2, -1)) * img_filter, axes=(-2, -1)).real else: img_filtered_1 = img_1 img_filtered_2 = img_2 img_1 = self._prepare_image(img_1, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape) img_2 = self._prepare_image(img_2, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape) img_filtered_1 = self._prepare_image(img_filtered_1, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape) img_filtered_2 = self._prepare_image(img_filtered_2, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape) dim_radio = img_1.shape[1] if margins is None: lim_1, lim_2 = 10, dim_radio - 1 - 10 else: lim_1, lim_2 = margins lim_2 = dim_radio - 1 - lim_2 if lim_1 < 1: lim_1 = 1 if lim_2 > dim_radio - 2: lim_2 = dim_radio - 2 if lim_2 <= lim_1: message = ( "Image shape or cropped selection too small for global search." + " After removal of the margins the search limits collide." 
+ " The cropped size is %d\n" % (dim_radio) ) raise ValueError(message) found_centers = [] x_cor = lim_1 while x_cor < lim_2: tmp_sigma = ( min( (img_1.shape[1] - x_cor), (x_cor), ) * self.sigma_fraction ) tmp_x = (np.arange(img_1.shape[1]) - x_cor) / tmp_sigma apodis = np.exp(-tmp_x * tmp_x / 2.0) x_cor_rel = x_cor - (img_1.shape[1] // 2) img_1_apodised = img_1 * apodis try: cor_position = CenterOfRotation.find_shift( self, img_1_apodised.astype(used_type), img_2.astype(used_type), low_pass=low_pass, high_pass=high_pass, roi_yxhw=roi_yxhw, ) except ValueError as err: if "positions are outside the input margins" in str(err): x_cor = min(x_cor + x_cor * self.step_fraction, x_cor + (dim_radio - x_cor) * self.step_fraction) continue except: message = "Unexpected error from base class CenterOfRotation.find_shift in CenterOfRotationAdaptiveSearch.find_shift : {err}".format( err=err ) self.logger.error(message) raise p_1 = cor_position * 2 if cor_position < 0: p_2 = img_2.shape[1] + cor_position * 2 else: p_2 = -img_2.shape[1] + cor_position * 2 if abs(x_cor_rel - p_1 / 2) < abs(x_cor_rel - p_2 / 2): cor_position = p_1 / 2 else: cor_position = p_2 / 2 cor_in_img = img_1.shape[1] // 2 + cor_position tmp_sigma = ( min( (img_1.shape[1] - cor_in_img), (cor_in_img), ) * self.sigma_fraction ) M1 = int(round(cor_position + img_1.shape[1] // 2)) - int(round(tmp_sigma)) M2 = int(round(cor_position + img_1.shape[1] // 2)) + int(round(tmp_sigma)) piece_1 = img_filtered_1[:, M1:M2] piece_2 = img_filtered_2[:, img_1.shape[1] - M2 : img_1.shape[1] - M1] if piece_1.size and piece_2.size: piece_1 = piece_1 - piece_1.mean() piece_2 = piece_2 - piece_2.mean() energy = np.array(piece_1 * piece_1 + piece_2 * piece_2, "d").sum() diff_energy = np.array((piece_1 - piece_2) * (piece_1 - piece_2), "d").sum() cost = diff_energy / energy if not np.isnan(cost): if tmp_sigma * 2 > abs(x_cor_rel - cor_position): found_centers.append([cost, abs(x_cor_rel - cor_position), cor_position, energy]) x_cor = min(x_cor + x_cor * self.step_fraction, x_cor + (dim_radio - x_cor) * self.step_fraction) if len(found_centers) == 0: message = "Unable to find any valid CoR candidate in {my_class}.find_shift ".format( my_class=self.__class__.__name__ ) raise ValueError(message) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Now build the neigborhood of the minimum as a list of five elements: # the minimum in the middle of the two before, and the two after filtered_found_centers = [] for i in range(len(found_centers)): if i > 0: if abs(found_centers[i][2] - found_centers[i - 1][2]) < 0.5: filtered_found_centers.append(found_centers[i]) continue if i + 1 < len(found_centers): if abs(found_centers[i][2] - found_centers[i + 1][2]) < 0.5: filtered_found_centers.append(found_centers[i]) continue if len(filtered_found_centers): found_centers = filtered_found_centers min_choice = min(found_centers) index_min_choice = found_centers.index(min_choice) min_neighborood = [ found_centers[i][2] if (i >= 0 and i < len(found_centers)) else math.nan for i in range(index_min_choice - 2, index_min_choice + 2 + 1) ] score_right = 0 for i_pos in [3, 4]: if abs(min_neighborood[i_pos] - min_neighborood[2]) < 0.5: score_right += 1 else: break score_left = 0 for i_pos in [1, 0]: if abs(min_neighborood[i_pos] - min_neighborood[2]) < 0.5: score_left += 1 else: break if score_left + score_right >= 2: validity_check_result = cor_result_validity["sound"] else: self.logger.warning( "Minimum mean difference and minimum std-dev of differences do not 
coincide. " + "This means that the validity of the found solution might be questionable." ) validity_check_result = cor_result_validity["questionable"] # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # An informative message in case one wish to look at how it has gone informative_message = " ".join( ["CenterOfRotationAdaptiveSearch found this neighborood of the optimal position:"] + [str(t) if not math.isnan(t) else "N.A." for t in min_neighborood] ) self.logger.debug(informative_message) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # The return value is the optimum which had been placed in the middle of the neighborood cor_position = min_neighborood[2] if return_validity: return cor_position, validity_check_result else: return cor_position __call__ = find_shift ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/estimation/cor_sino.py0000644000175000017500000001453600000000000017523 0ustar00pierrepierre""" This module provides global definitions and methods to compute COR in extrem Half Acquisition mode """ __authors__ = ["C. Nemoz", "H.Payno"] __license__ = "MIT" __date__ = "13/04/2021" import numpy as np from scipy.signal import convolve2d from ..resources.logger import LoggerOrPrint def schift(mat, val): ker = np.zeros((3, 3)) s = 1.0 if val < 0: s = -1.0 val = s * val ker[1, 1] = 1 - val if s > 0: ker[1, 2] = val else: ker[1, 0] = val mat = convolve2d(mat, ker, mode="same") return mat class SinoCor: """ This class has 2 methods: - overlap. Find a rough estimate of COR - accurate. Try to refine COR to 1/10 pixel """ def __init__(self, sinogram, logger=None): """ """ self.logger = LoggerOrPrint(logger) self.sx = sinogram.shape[1] self.sy = sinogram.shape[0] # algorithm cannot accept odd number of projs nproj2 = int((self.sy - (self.sy % 2)) / 2) # extract upper and lower part of sinogram, flipping H the upper part self.data1 = sinogram[0:nproj2, :] self.data2 = np.fliplr(sinogram[nproj2:, :].copy()) self.rcor_abs = round(self.sx / 2.0) self.cor_acc = round(self.sx / 2.0) # parameters for overlap sino - rough estimation # default sliding ROI is 20% of the width of the detector # the maximum size of ROI in the "right" case is 2*(self.sx - COR) # ex: 2048 pixels, COR= 2000, window_width should not exceed 96! self.window_width = round(self.sx / 5) def overlap(self, side="right", window_width=None): """ Compute COR by minimizing difference of circulating ROI - side: preliminary knowledge if the COR is on right or left - window_width: width of ROI that will slide on the other part of the sinogram by default, 20% of the width of the detector. """ if window_width is None: window_width = self.window_width if not (window_width & 1): window_width -= 1 # number of pixels where the window will "slide". n = self.sx - int(window_width) nr = range(n) dmax = 1000000000.0 imax = 0 # Should we do both right and left and take the minimum "diff" of the 2 ? # windows self.data2 moves over self.data1, measure the width of the histogram and retains the smaller one. 
if side == "right": for i in nr: imout = self.data1[:, n - i : n - i + window_width] - self.data2[:, 0:window_width] diff = imout.max() - imout.min() if diff < dmax: dmax = diff imax = i self.cor_abs = self.sx - (imax + window_width + 1.0) / 2.0 self.cor_rel = self.sx / 2 - (imax + window_width + 1.0) / 2.0 else: for i in nr: imout = self.data1[:, i : i + window_width] - self.data2[:, self.sx - window_width : self.sx] diff = imout.max() - imout.min() if diff < dmax: dmax = diff imax = i self.cor_abs = (imax + window_width - 1.0) / 2 self.cor_rel = self.cor_abs - self.sx / 2.0 - 1 if imax < 1: self.logger.warning("sliding width %d seems too large!" % window_width) self.rcor_abs = round(self.cor_abs) return self.rcor_abs def accurate(self, neighborhood=7, shift_value=0.1): """ refine the calculation around COR integer pre-calculated value The search will be executed in the defined neighborhood Parameters ----------- neighborhood: int Parameter for accurate calculation in the vicinity of the rough estimate. It must be an odd number. 0.1 pixels float shifts will be performed over this number of pixel """ # define the H-size (odd) of the window one can use to find the best overlap moving finely over ng pixels if not (neighborhood & 1): neighborhood += 1 ng2 = int(neighborhood / 2) # pleft and pright are the number of pixels available on the left and the right of the cor position # to slide a window pleft = self.rcor_abs - ng2 pright = self.sx - self.rcor_abs - ng2 - 1 # the maximum window to slide is restricted by the smaller side if pleft > pright: p_sign = 1 xwin = 2 * (self.sx - self.rcor_abs - ng2) - 1 else: p_sign = -1 xwin = 2 * (self.rcor_abs - ng2) + 1 # Note that xwin is odd xc1 = self.rcor_abs - int(xwin / 2) xc2 = self.sx - self.rcor_abs - int(xwin / 2) - 1 im1 = self.data1[:, xc1 : xc1 + xwin] im2 = self.data2[:, xc2 : xc2 + xwin] pixs = p_sign * (np.arange(neighborhood) - ng2) diff0 = 1000000000.0 isfr = shift_value * np.arange(10) self.cor_acc = self.rcor_abs for pix in pixs: x0 = xc1 + pix for isf in isfr: if isf != 0: ims = schift(self.data1[:, x0 : x0 + xwin].copy(), -p_sign * isf) else: ims = self.data1[:, x0 : x0 + xwin] imout = ims - self.data2[:, xc2 : xc2 + xwin] diff = imout.max() - imout.min() if diff < diff0: self.cor_acc = self.rcor_abs + (pix + p_sign * isf) / 2.0 diff0 = diff return self.cor_acc # Aliases estimate_cor_coarse = overlap estimate_cor_fine = accurate class SinoCorInterface: """ A class that mimics the interface of CenterOfRotation, while calling SinoCor """ def __init__(self, logger=None, **kwargs): self._logger = logger def find_shift(self, img_1, img_2, side="right", window_width=None, neighborhood=7, shift_value=0.1, **kwargs): sinogram = np.vstack([img_1, img_2]) cor_finder = SinoCor(sinogram, logger=self._logger) cor_finder.estimate_cor_coarse(side=side, window_width=window_width) cor = cor_finder.estimate_cor_fine(neighborhood=neighborhood, shift_value=shift_value) # offset will be added later - keep compatibility with result from AlignmentBase.find_shift() return cor - img_1.shape[1] / 2 ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/estimation/distortion.py0000644000175000017500000001113500000000000020076 0ustar00pierrepierreimport numpy as np import scipy.interpolate from .translation import DetectorTranslationAlongBeam from ..misc.filters import correct_spikes from ..resources.logger import LoggerOrPrint def estimate_flat_distortion( flat, image, tile_size=100, 
interpolation_kind="linear", padding_mode="edge", correction_spike_threshold=None, logger=None, ): """ Estimate the wavefront distortion on a flat image, from another image. Parameters ---------- flat: np.array The flat-field image to be corrected image: np.ndarray The image to correlate the flat against. tile_size: int The wavefront corrections are calculated by correlating the image to the flat, region by region. The regions are tiles of size tile_size interpolation_kind: "linear" or "cubic" The interpolation method used for interpolation padding_mode: string Padding mode. Must be valid for np.pad when wavefront correction is applied, the corrections are first found for the tiles, which gives the shift at the center of each tiled. Then, to interpolate the corrections, at the positions f every pixel, on must add also the border of the extremal tiles. This is done by padding with a width of 1, and using the mode given 'padding_mode'. correction_spike_threshold: float, optional By default it is None and no spike correction is performed on the shifts grid which is found by correlation. If set to a float, a spike removal will be applied using such threshold Returns -------- coordinates: np.ndarray An array having dimensions (flat.shape[0], flat.shape[1], 2) where each coordinates[i,j] contains the coordinates of the position in the image "flat" which correlates to the pixel (i,j) in the image "im2". """ logger = LoggerOrPrint(logger) starts_r = np.array(range(0, image.shape[0] - tile_size, tile_size)) starts_c = np.array(range(0, image.shape[1] - tile_size, tile_size)) cor1 = np.zeros([len(starts_r), len(starts_c)], np.float32) cor2 = np.zeros([len(starts_r), len(starts_c)], np.float32) shift_finder = DetectorTranslationAlongBeam() for ir, r in enumerate(starts_r): for ic, c in enumerate(starts_c): try: coeff_v, coeff_h, shifts_vh_per_img = shift_finder.find_shift( np.array([image[r : r + tile_size, c : c + tile_size], flat[r : r + tile_size, c : c + tile_size]]), np.array([0, 1]), return_shifts=True, low_pass=(1.0, 0.3), high_pass=(tile_size, tile_size * 0.3), ) cor1[ir, ic], cor2[ir, ic] = shifts_vh_per_img[1] except ValueError as e: if "positions are outside" in str(e): logger.debug(str(e)) cor1[ir, ic], cor2[ir, ic] = (0, 0) else: raise cor1[np.isnan(cor1)] = 0 cor2[np.isnan(cor2)] = 0 if correction_spike_threshold is not None: cor1 = correct_spikes(cor1, correction_spike_threshold) cor2 = correct_spikes(cor2, correction_spike_threshold) # TODO implement the previous spikes correction in CCDCorrection - median_clip # spikes_corrector = CCDCorrection(cor1.shape, median_clip_thresh=3, abs_diff=True, preserve_borders=True) # cor1 = spikes_corrector.median_clip_correction(cor1) # cor2 = spikes_corrector.median_clip_correction(cor2) cor1 = np.pad(cor1, ((1, 1), (1, 1)), mode=padding_mode) cor2 = np.pad(cor2, ((1, 1), (1, 1)), mode=padding_mode) hp = np.concatenate([[0.0], starts_c + tile_size * 0.5, [image.shape[1]]]) vp = np.concatenate([[0.0], starts_r + tile_size * 0.5, [image.shape[0]]]) h_ticks = np.arange(image.shape[1]).astype(np.float32) v_ticks = np.arange(image.shape[0]).astype(np.float32) spline_degree = {"linear": 1, "cubic": 3}[interpolation_kind] interpolator = scipy.interpolate.RectBivariateSpline(vp, hp, cor1, kx=spline_degree, ky=spline_degree) cor1 = interpolator(h_ticks, v_ticks) interpolator = scipy.interpolate.RectBivariateSpline(vp, hp, cor2, kx=spline_degree, ky=spline_degree) cor2 = interpolator(h_ticks, v_ticks) hh = np.arange(image.shape[1]).astype(np.float32) vv = 
np.arange(image.shape[0]).astype(np.float32) unshifted_v, unshifted_h = np.meshgrid(vv, hh, indexing="ij") shifted_v = unshifted_v - cor1 shifted_h = unshifted_h - cor2 coordinates = np.transpose(np.array([shifted_v, shifted_h]), axes=[1, 2, 0]) return coordinates ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/estimation/focus.py0000644000175000017500000003662500000000000017032 0ustar00pierrepierreimport numpy as np from .alignment import plt from .cor import CenterOfRotation class CameraFocus(CenterOfRotation): def find_distance( self, img_stack: np.ndarray, img_pos: np.array, metric="std", roi_yxhw=None, median_filt_shape=None, padding_mode=None, peak_fit_radius=1, high_pass=None, low_pass=None, ): """Find the focal distance of the camera system. This routine computes the motor position that corresponds to having the scintillator on the focal plain of the camera system. Parameters ---------- img_stack: numpy.ndarray A stack of images at different distances. img_pos: numpy.ndarray Position of the images along the translation axis metric: string, optional The property, whose maximize occurs at the focal position. Defaults to 'std' (standard deviation). All options are: 'std' | 'grad' | 'psd' roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional 4 elements vector containing: vertical and horizontal coordinates of first pixel, plus height and width of the Region of Interest (RoI). Or a 2 elements vector containing: plus height and width of the centered Region of Interest (RoI). Default is None -> deactivated. median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional Shape of the median filter window. Default is None -> deactivated. padding_mode: str in numpy.pad's mode list, optional Padding mode, which determines the type of convolution. If None or 'wrap' are passed, this resorts to the traditional circular convolution. If 'edge' or 'constant' are passed, it results in a linear convolution. Default is the circular convolution. All options are: None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean' | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap' peak_fit_radius: int, optional Radius size around the max correlation pixel, for sub-pixel fitting. Minimum and default value is 1. low_pass: float or sequence of two floats Low-pass filter properties, as described in `nabu.misc.fourier_filters`. high_pass: float or sequence of two floats High-pass filter properties, as described in `nabu.misc.fourier_filters`. Returns ------- focus_pos: float Estimated position of the focal plane of the camera system. focus_ind: float Image index of the estimated position of the focal plane of the camera system (starting from 1!). Examples -------- Given the focal stack associated to multiple positions of the camera focus motor called `img_stack`, and the associated positions `img_pos`, the following code computes the highest focus position: >>> focus_calc = alignment.CameraFocus() ... focus_pos, focus_ind = focus_calc.find_distance(img_stack, img_pos) where `focus_pos` is the corresponding motor position, and `focus_ind` is the associated image position (starting from 1). """ self._check_img_stack_size(img_stack, img_pos) if peak_fit_radius < 1: self.logger.warning("Parameter peak_fit_radius should be at least 1, given: %d instead." 
% peak_fit_radius) peak_fit_radius = 1 num_imgs = img_stack.shape[0] img_shape = img_stack.shape[-2:] roi_yxhw = self._determine_roi(img_shape, roi_yxhw) img_stack = self._prepare_image( img_stack, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape, low_pass=low_pass, high_pass=high_pass, ) img_stds = np.std(img_stack, axis=(-2, -1)) / np.mean(img_stack, axis=(-2, -1)) # assuming images are equispaced focus_step = (img_pos[-1] - img_pos[0]) / (num_imgs - 1) img_inds = np.arange(num_imgs) (f_vals, f_pos) = self.extract_peak_regions_1d(img_stds, peak_radius=peak_fit_radius, cc_coords=img_inds) focus_ind, img_std_max = self.refine_max_position_1d(f_vals, return_vertex_val=True) focus_ind += f_pos[1, :] focus_pos = img_pos[0] + focus_step * focus_ind focus_ind += 1 if self.verbose: self.logger.info( "Fitted focus motor position:", focus_pos, "and corresponding image position:", focus_ind, ) f, ax = plt.subplots(1, 1) self._add_plot_window(f, ax=ax) ax.stem(img_pos, img_stds) ax.stem(focus_pos, img_std_max, linefmt="C1-", markerfmt="C1o") ax.set_title("Images std") plt.show(block=False) return focus_pos, focus_ind def _check_img_block_size(self, img_shape, regions_number, suggest_new_shape=True): img_shape = np.array(img_shape) new_shape = img_shape if not len(img_shape) == 2: raise ValueError( "Images need to be square 2-dimensional and with shape multiple of the number of assigned regions.\n" " Image shape: %s, regions number: %d" % (img_shape, regions_number) ) if not (img_shape[0] == img_shape[1] and np.all((np.array(img_shape) % regions_number) == 0)): new_shape = (img_shape // regions_number) * regions_number new_shape = np.fmin(new_shape, new_shape.min()) message = ( "Images need to be square 2-dimensional and with shape multiple of the number of assigned regions.\n" " Image shape: %s, regions number: %d. Cropping to image shape: %s" % (img_shape, regions_number, new_shape) ) if suggest_new_shape: self.logger.info(message) else: raise ValueError(message) return new_shape @staticmethod def _fit_plane(f_vals): f_vals_half_shape = (np.array(f_vals.shape) - 1) / 2 fy = np.linspace(-f_vals_half_shape[-2], f_vals_half_shape[-2], f_vals.shape[-2]) fx = np.linspace(-f_vals_half_shape[-1], f_vals_half_shape[-1], f_vals.shape[-1]) fy, fx = np.meshgrid(fy, fx, indexing="ij") coords = np.array([np.ones(f_vals.size), fy.flatten(), fx.flatten()]) return np.linalg.lstsq(coords.T, f_vals.flatten(), rcond=None)[0], fy, fx def find_scintillator_tilt( self, img_stack: np.ndarray, img_pos: np.array, regions_number=4, metric="std", roi_yxhw=None, median_filt_shape=None, padding_mode=None, peak_fit_radius=1, high_pass=None, low_pass=None, ): """Finds the scintillator tilt and focal distance of the camera system. This routine computes the mounting tilt of the scintillator and the motor position that corresponds to having the scintillator on the focal plain of the camera system. The input is supposed to be a stack of square images, whose sizes are multiples of the `regions_number` parameter. If images with a different size are passed, this function will crop the images. This also generates a warning. To suppress the warning, it is suggested to specify a ROI that satisfies those criteria (see examples). The computed tilts `tilt_vh` are in unit-length per pixel-size. 
To obtain the tilts it is necessary to divide by the pixel-size: >>> tilt_vh_deg = np.rad2deg(np.arctan(tilt_vh / pixel_size)) The correction to be applied is: >>> tilt_corr_vh_deg = - np.rad2deg(np.arctan(tilt_vh / pixel_size)) The legacy octave macros computed the approximation of these values in radians: >>> tilt_corr_vh_rad = - tilt_vh / pixel_size Note that `pixel_size` should be in the same unit scale as `img_pos`. Parameters ---------- img_stack: numpy.ndarray A stack of images at different distances. img_pos: numpy.ndarray Position of the images along the translation axis regions_number: int, optional The number of regions to subdivide the image into, along each direction. Defaults to 4. metric: string, optional The property, whose maximize occurs at the focal position. Defaults to 'std' (standard deviation). All options are: 'std' | 'grad' | 'psd' roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional 4 elements vector containing: vertical and horizontal coordinates of first pixel, plus height and width of the Region of Interest (RoI). Or a 2 elements vector containing: plus height and width of the centered Region of Interest (RoI). Default is None -> auto-suggest correct size. median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional Shape of the median filter window. Default is None -> deactivated. padding_mode: str in numpy.pad's mode list, optional Padding mode, which determines the type of convolution. If None or 'wrap' are passed, this resorts to the traditional circular convolution. If 'edge' or 'constant' are passed, it results in a linear convolution. Default is the circular convolution. All options are: None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean' | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap' peak_fit_radius: int, optional Radius size around the max correlation pixel, for sub-pixel fitting. Minimum and default value is 1. low_pass: float or sequence of two floats Low-pass filter properties, as described in `nabu.misc.fourier_filters`. high_pass: float or sequence of two floats High-pass filter properties, as described in `nabu.misc.fourier_filters`. Returns ------- focus_pos: float Estimated position of the focal plane of the camera system. focus_ind: float Image index of the estimated position of the focal plane of the camera system (starting from 1!). tilts_vh: tuple(float, float) Estimated scintillator tilts in the vertical and horizontal direction respectively per unit-length per pixel-size. Examples -------- Given the focal stack associated to multiple positions of the camera focus motor called `img_stack`, and the associated positions `img_pos`, the following code computes the highest focus position: >>> focus_calc = alignment.CameraFocus() ... focus_pos, focus_ind, tilts_vh = focus_calc.find_scintillator_tilt(img_stack, img_pos) ... tilt_corr_vh_deg = - np.rad2deg(np.arctan(tilt_vh / pixel_size)) or to keep compatibility with the old octave macros: >>> tilt_corr_vh_rad = - tilt_vh / pixel_size For non square images, or images with sizes that are not multiples of the `regions_number` parameter, and no ROI is being passed, this function will try to crop the image stack to the correct size. If you want to remove the warning message, it is suggested to set a ROI like the following: >>> regions_number = 4 ... img_roi = (np.array(img_stack.shape[1:]) // regions_number) * regions_number ... img_roi = np.fmin(img_roi, img_roi.min()) ... focus_calc = alignment.CameraFocus() ... 
focus_pos, focus_ind, tilts_vh = focus_calc.find_scintillator_tilt( ... img_stack, img_pos, roi_yxhw=img_roi, regions_number=regions_number) """ self._check_img_stack_size(img_stack, img_pos) if peak_fit_radius < 1: self.logger.warning("Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius) peak_fit_radius = 1 num_imgs = img_stack.shape[0] img_shape = img_stack.shape[-2:] if roi_yxhw is None: # If no roi is being passed, we try to crop the images to the # correct size, if needed roi_yxhw = self._check_img_block_size(img_shape, regions_number, suggest_new_shape=True) roi_yxhw = self._determine_roi(img_shape, roi_yxhw) else: # If a roi is being passed, and the images don't have the correct # shape, we raise an error roi_yxhw = self._determine_roi(img_shape, roi_yxhw) self._check_img_block_size(roi_yxhw[2:], regions_number, suggest_new_shape=False) img_stack = self._prepare_image( img_stack, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape, low_pass=low_pass, high_pass=high_pass, ) img_shape = img_stack.shape[-2:] block_size = np.array(img_shape) / regions_number block_stack_size = np.array( [num_imgs, regions_number, block_size[-2], regions_number, block_size[-1]], dtype=np.intp, ) img_stack = np.reshape(img_stack, block_stack_size) img_stds = np.std(img_stack, axis=(-3, -1)) / np.mean(img_stack, axis=(-3, -1)) img_stds = np.reshape(img_stds, [num_imgs, -1]).transpose() # assuming images are equispaced focus_step = (img_pos[-1] - img_pos[0]) / (num_imgs - 1) img_inds = np.arange(num_imgs) (f_vals, f_pos) = self.extract_peak_regions_1d(img_stds, peak_radius=peak_fit_radius, cc_coords=img_inds) focus_inds = self.refine_max_position_1d(f_vals) focus_inds += f_pos[1, :] focus_poss = img_pos[0] + focus_step * focus_inds # Fitting the plane focus_poss = np.reshape(focus_poss, [regions_number, regions_number]) coeffs, fy, fx = self._fit_plane(focus_poss) focus_pos, tg_v, tg_h = coeffs # The angular coefficient along x is the tilt around the y axis and vice-versa tilts_vh = np.array([tg_h, tg_v]) / block_size focus_ind = np.mean(focus_inds) + 1 if self.verbose: self.logger.info( "Fitted focus motor position:", focus_pos, "and corresponding image position:", focus_ind, ) self.logger.info("Fitted tilts (to be divided by pixel size, and converted to deg): (v, h) %s" % tilts_vh) fig = plt.figure() ax = fig.add_subplot(111, projection="3d") self._add_plot_window(fig, ax=ax) ax.plot_wireframe(fx, fy, focus_poss) regions_half_shape = (regions_number - 1) / 2 base_points = np.linspace(-regions_half_shape, regions_half_shape, regions_number) ax.plot( np.zeros((regions_number,)), base_points, np.polyval([tg_v, focus_pos], base_points), "C2", ) ax.plot( base_points, np.zeros((regions_number,)), np.polyval([tg_h, focus_pos], base_points), "C2", ) ax.scatter(0, 0, focus_pos, marker="o", c="C1") ax.set_title("Images std") plt.show(block=False) return focus_pos, focus_ind, tilts_vh ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4567332 nabu-2023.1.1/nabu/estimation/tests/0000755000175000017500000000000000000000000016467 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1621502133.0 nabu-2023.1.1/nabu/estimation/tests/__init__.py0000644000175000017500000000000000000000000020566 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 
nabu-2023.1.1/nabu/estimation/tests/test_alignment.py0000644000175000017500000000467300000000000022070 0ustar00pierrepierreimport numpy as np import pytest from nabu.estimation.alignment import AlignmentBase @pytest.fixture(scope="class") def bootstrap_base(request): cls = request.cls cls.abs_tol = 2.5e-2 @pytest.mark.usefixtures("bootstrap_base") class TestAlignmentBase: def test_peak_fitting_2d_3x3(self): # Fit a 3 x 3 grid fy = np.linspace(-1, 1, 3) fx = np.linspace(-1, 1, 3) yy, xx = np.meshgrid(fy, fx, indexing="ij") peak_pos_yx = np.random.rand(2) * 1.6 - 0.8 f_vals = np.exp(-((yy - peak_pos_yx[0]) ** 2 + (xx - peak_pos_yx[1]) ** 2) / 100) fitted_peak_pos_yx = AlignmentBase.refine_max_position_2d(f_vals, fy, fx) message = ( "Computed peak position: (%f, %f) " % (*fitted_peak_pos_yx,) + " and real peak position (%f, %f) do not coincide." % (*peak_pos_yx,) + " Difference: (%f, %f)," % (*(fitted_peak_pos_yx - peak_pos_yx),) + " tolerance: %f" % self.abs_tol ) assert np.all(np.isclose(peak_pos_yx, fitted_peak_pos_yx, atol=self.abs_tol)), message def test_peak_fitting_2d_error_checking(self): # Fit a 3 x 3 grid fy = np.linspace(-1, 1, 3) fx = np.linspace(-1, 1, 3) yy, xx = np.meshgrid(fy, fx, indexing="ij") peak_pos_yx = np.random.rand(2) + 1.5 f_vals = np.exp(-((yy - peak_pos_yx[0]) ** 2 + (xx - peak_pos_yx[1]) ** 2) / 100) with pytest.raises(ValueError) as ex: AlignmentBase.refine_max_position_2d(f_vals, fy, fx) message = ( "Error should have been raised about the peak being fitted outside margins, " + "other error raised instead:\n%s" % str(ex.value) ) assert "positions are outside the input margins" in str(ex.value), message def test_extract_peak_regions_1d(self): img = np.random.randint(0, 10, size=(8, 8)) peaks_pos = np.argmax(img, axis=-1) peaks_val = np.max(img, axis=-1) cc_coords = np.arange(0, 8) ( found_peaks_val, found_peaks_pos, ) = AlignmentBase.extract_peak_regions_1d(img, axis=-1, cc_coords=cc_coords) message = ( "The found peak positions do not correspond to the expected peak positions:\n Expected: %s\n Found: %s" % ( peaks_pos, found_peaks_pos[1, :], ) ) assert np.all(peaks_val == found_peaks_val[1, :]), message ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/estimation/tests/test_cor.py0000644000175000017500000004133700000000000020673 0ustar00pierrepierreimport os import numpy as np import pytest import scipy.ndimage import h5py from nabu.testutils import utilstest, __do_long_tests__ from nabu.testutils import get_data as nabu_get_data from nabu.estimation.cor import ( CenterOfRotation, CenterOfRotationAdaptiveSearch, CenterOfRotationGrowingWindow, CenterOfRotationSlidingWindow, ) from nabu.estimation.cor_sino import SinoCor @pytest.fixture(scope="class") def bootstrap_cor(request): cls = request.cls cls.abs_tol = 0.2 cls.data, calib_data = get_cor_data_h5("test_alignment_cor.h5") cls.cor_gl_pix, cls.cor_hl_pix, cls.tilt_deg = calib_data @pytest.fixture(scope="class") def bootstrap_cor_win(request): cls = request.cls cls.abs_tol = 0.2 cls.data_ha_proj, cls.cor_ha_pr_pix = get_cor_win_proj_data_h5("ha_autocor_radios.npz") cls.data_ha_sino, cls.cor_ha_sn_pix = get_cor_win_sino_data_h5("halftomo_1_sino.npz") def get_cor_data_h5(*dataset_path): """ Get a dataset file from silx.org/pub/nabu/data dataset_args is a list describing a nested folder structures, ex. 
["path", "to", "my", "dataset.h5"] """ dataset_relpath = os.path.join(*dataset_path) dataset_downloaded_path = utilstest.getfile(dataset_relpath) with h5py.File(dataset_downloaded_path, "r") as hf: data = hf["/entry/instrument/detector/data"][()] cor_global_pix = hf["/calibration/alignment/global/x_rotation_axis_pixel_position"][()] cor_highlow_pix = hf["/calibration/alignment/highlow/x_rotation_axis_pixel_position"][()] tilt_deg = hf["/calibration/alignment/highlow/z_camera_tilt"][()] return data, (cor_global_pix, cor_highlow_pix, tilt_deg) def get_cor_win_proj_data_h5(*dataset_path): """ Get a dataset file from silx.org/pub/nabu/data dataset_args is a list describing a nested folder structures, ex. ["path", "to", "my", "dataset.h5"] """ dataset_relpath = os.path.join(*dataset_path) dataset_downloaded_path = utilstest.getfile(dataset_relpath) data = np.load(dataset_downloaded_path) radios = np.stack((data["radio1"], data["radio2"]), axis=0) return radios, data["cor_pos"] def get_cor_win_sino_data_h5(*dataset_path): """ Get a dataset file from silx.org/pub/nabu/data dataset_args is a list describing a nested folder structures, ex. ["path", "to", "my", "dataset.h5"] """ dataset_relpath = os.path.join(*dataset_path) dataset_downloaded_path = utilstest.getfile(dataset_relpath) data = np.load(dataset_downloaded_path) sino_shape = data["sino"].shape sinos = np.stack((data["sino"][: sino_shape[0] // 2], data["sino"][sino_shape[0] // 2 :]), axis=0) return sinos, data["cor"] - sino_shape[1] / 2 @pytest.mark.usefixtures("bootstrap_cor") class TestCor: def test_cor_posx(self): radio1 = self.data[0, :, :] radio2 = np.fliplr(self.data[1, :, :]) CoR_calc = CenterOfRotation() cor_position = CoR_calc.find_shift(radio1, radio2) assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and real CoR %f do not coincide" % self.cor_gl_pix assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message # testing again with the validity return value cor_position, result_validity = CoR_calc.find_shift(radio1, radio2, return_validity=True) assert np.isscalar(cor_position) message = ( "returned result_validity is %s " % result_validity + " while it should be unknown because the validity check is not yet implemented" ) assert result_validity == "unknown", message def test_noisy_cor_posx(self): radio1 = np.fmax(self.data[0, :, :], 0) radio2 = np.fmax(self.data[1, :, :], 0) radio1 = np.random.poisson(radio1 * 400) radio2 = np.random.poisson(np.fliplr(radio2) * 400) CoR_calc = CenterOfRotation() cor_position = CoR_calc.find_shift(radio1, radio2, median_filt_shape=(3, 3)) assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and real CoR %f do not coincide" % self.cor_gl_pix assert np.isscalar(cor_position) assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message def test_noisyHF_cor_posx(self): """test with noise at high frequencies""" radio1 = self.data[0, :, :] radio2 = np.fliplr(self.data[1, :, :]) noise_level = radio1.max() / 16.0 noise_ima1 = np.random.normal(0.0, size=radio1.shape) * noise_level noise_ima2 = np.random.normal(0.0, size=radio2.shape) * noise_level noise_ima1 = noise_ima1 - scipy.ndimage.gaussian_filter(noise_ima1, 2.0) noise_ima2 = noise_ima2 - scipy.ndimage.gaussian_filter(noise_ima2, 2.0) radio1 = radio1 + noise_ima1 radio2 = radio2 + noise_ima2 CoR_calc = CenterOfRotation() 
# cor_position = CoR_calc.find_shift(radio1, radio2) cor_position = CoR_calc.find_shift(radio1, radio2, low_pass=(6.0, 0.3)) assert np.isscalar(cor_position), f"cor_position expected to be a scalar, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and real CoR %f do not coincide" % self.cor_gl_pix assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message @pytest.mark.skipif(not (__do_long_tests__), reason="need environment variable NABU_LONG_TESTS=1") def test_half_tomo_cor_exp(self): """test the half_tomo algorithm on experimental data""" radios = nabu_get_data("ha_autocor_radios.npz") radio1 = radios["radio1"] radio2 = radios["radio2"] cor_pos = radios["cor_pos"] radio2 = np.fliplr(radio2) CoR_calc = CenterOfRotationAdaptiveSearch() cor_position = CoR_calc.find_shift(radio1, radio2, low_pass=1, high_pass=20, filtered_cost=True) assert np.isscalar(cor_position), f"cor_position expected to be a scalar, {type(cor_position)} returned" message = ( "Computed CoR %f " % cor_position + " and real CoR %f should coincide when using the halftomo algorithm with half tomo data" % cor_pos ) assert np.isclose(cor_pos, cor_position, atol=self.abs_tol), message @pytest.mark.skipif(not (__do_long_tests__), reason="need environment variable NABU_LONG_TESTS=1") def test_half_tomo_cor_exp_limited(self): """test the half_tomo algorithm on experimental data and global search with limits""" radios = nabu_get_data("ha_autocor_radios.npz") radio1 = radios["radio1"] radio2 = radios["radio2"] cor_pos = radios["cor_pos"] radio2 = np.fliplr(radio2) CoR_calc = CenterOfRotationAdaptiveSearch() cor_position, result_validity = CoR_calc.find_shift( radio1, radio2, low_pass=1, high_pass=20, margins=(100, 10), filtered_cost=False, return_validity=True ) assert np.isscalar(cor_position), f"cor_position expected to be a scalar, {type(cor_position)} returned" message = ( "Computed CoR %f " % cor_position + " and real CoR %f should coincide when using the halftomo algorithm with half tomo data" % cor_pos ) assert np.isclose(cor_pos, cor_position, atol=self.abs_tol), message message = "returned result_validity is %s " % result_validity + " while it should be sound" assert result_validity == "sound", message def test_cor_posx_linear(self): radio1 = self.data[0, :, :] radio2 = np.fliplr(self.data[1, :, :]) CoR_calc = CenterOfRotation() cor_position = CoR_calc.find_shift(radio1, radio2, padding_mode="edge") assert np.isscalar(cor_position), f"cor_position expected to be a scalar, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and real CoR %f do not coincide" % self.cor_gl_pix assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message def test_error_checking_001(self): CoR_calc = CenterOfRotation() radio1 = self.data[0, :, :1:] radio2 = self.data[1, :, :] with pytest.raises(ValueError) as ex: CoR_calc.find_shift(radio1, radio2) message = "Error should have been raised about img #1 shape, other error raised instead:\n%s" % str(ex.value) assert "Images need to be 2-dimensional. Shape of image #1" in str(ex.value), message def test_error_checking_002(self): CoR_calc = CenterOfRotation() radio1 = self.data[0, :, :] radio2 = self.data with pytest.raises(ValueError) as ex: CoR_calc.find_shift(radio1, radio2) message = "Error should have been raised about img #2 shape, other error raised instead:\n%s" % str(ex.value) assert "Images need to be 2-dimensional.
Shape of image #2" in str(ex.value), message def test_error_checking_003(self): CoR_calc = CenterOfRotation() radio1 = self.data[0, :, :] radio2 = self.data[1, :, 0:10] with pytest.raises(ValueError) as ex: CoR_calc.find_shift(radio1, radio2) message = ( "Error should have been raised about different image shapes, " + "other error raised instead:\n%s" % str(ex.value) ) assert "Images need to be of the same shape" in str(ex.value), message @pytest.mark.usefixtures("bootstrap_cor", "bootstrap_cor_win") class TestCorWindowSlide: def test_proj_center_axis_lft(self): radio1 = self.data[0, :, :] radio2 = np.fliplr(self.data[1, :, :]) CoR_calc = CenterOfRotationSlidingWindow() cor_position = CoR_calc.find_shift( radio1, radio2, side="left", window_width=round(radio1.shape[-1] / 4.0 * 3.0) ) assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_gl_pix assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message cor_position, result_validity = CoR_calc.find_shift( radio1, radio2, side="left", window_width=round(radio1.shape[-1] / 4.0 * 3.0), return_validity=True ) message = "returned result_validity is %s " % result_validity + " while it should be sound" assert result_validity == "sound", message def test_proj_center_axis_cen(self): radio1 = self.data[0, :, :] radio2 = np.fliplr(self.data[1, :, :]) CoR_calc = CenterOfRotationSlidingWindow() cor_position = CoR_calc.find_shift(radio1, radio2, side="center") assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_gl_pix assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message def test_proj_right_axis_rgt(self): radio1 = self.data_ha_proj[0, :, :] radio2 = np.fliplr(self.data_ha_proj[1, :, :]) CoR_calc = CenterOfRotationSlidingWindow() cor_position = CoR_calc.find_shift(radio1, radio2, side="right") assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_ha_pr_pix assert np.isclose(self.cor_ha_pr_pix, cor_position, atol=self.abs_tol), message def test_proj_left_axis_lft(self): radio1 = np.fliplr(self.data_ha_proj[0, :, :]) radio2 = self.data_ha_proj[1, :, :] CoR_calc = CenterOfRotationSlidingWindow() cor_position = CoR_calc.find_shift(radio1, radio2, side="left") assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % -self.cor_ha_pr_pix assert np.isclose(-self.cor_ha_pr_pix, cor_position, atol=self.abs_tol), message def test_sino_right_axis_rgt(self): sino1 = self.data_ha_sino[0, :, :] sino2 = np.fliplr(self.data_ha_sino[1, :, :]) CoR_calc = CenterOfRotationSlidingWindow() cor_position = CoR_calc.find_shift(sino1, sino2, side="right") assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_ha_sn_pix assert np.isclose(self.cor_ha_sn_pix, cor_position, atol=self.abs_tol * 5), message @pytest.mark.usefixtures("bootstrap_cor", "bootstrap_cor_win") class TestCorWindowGrow: def test_proj_center_axis_cen(self): radio1 = 
self.data[0, :, :] radio2 = np.fliplr(self.data[1, :, :]) CoR_calc = CenterOfRotationGrowingWindow() cor_position = CoR_calc.find_shift(radio1, radio2, side="center") assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_gl_pix assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message def test_proj_right_axis_rgt(self): radio1 = self.data_ha_proj[0, :, :] radio2 = np.fliplr(self.data_ha_proj[1, :, :]) CoR_calc = CenterOfRotationGrowingWindow() cor_position = CoR_calc.find_shift(radio1, radio2, side="right") assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_ha_pr_pix assert np.isclose(self.cor_ha_pr_pix, cor_position, atol=self.abs_tol), message def test_proj_left_axis_lft(self): radio1 = np.fliplr(self.data_ha_proj[0, :, :]) radio2 = self.data_ha_proj[1, :, :] CoR_calc = CenterOfRotationGrowingWindow() cor_position = CoR_calc.find_shift(radio1, radio2, side="left") assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % -self.cor_ha_pr_pix assert np.isclose(-self.cor_ha_pr_pix, cor_position, atol=self.abs_tol), message cor_position, result_validity = CoR_calc.find_shift(radio1, radio2, side="left", return_validity=True) message = "returned result_validity is %s " % result_validity + " while it should be sound" assert result_validity == "sound", message def test_proj_right_axis_all(self): radio1 = self.data_ha_proj[0, :, :] radio2 = np.fliplr(self.data_ha_proj[1, :, :]) CoR_calc = CenterOfRotationGrowingWindow() cor_position = CoR_calc.find_shift(radio1, radio2, side="all") assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_ha_pr_pix assert np.isclose(self.cor_ha_pr_pix, cor_position, atol=self.abs_tol), message def test_sino_right_axis_rgt(self): sino1 = self.data_ha_sino[0, :, :] sino2 = np.fliplr(self.data_ha_sino[1, :, :]) CoR_calc = CenterOfRotationGrowingWindow() cor_position = CoR_calc.find_shift(sino1, sino2, side="right") assert np.isscalar(cor_position), f"cor_position expected to be a scale, {type(cor_position)} returned" message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_ha_sn_pix assert np.isclose(self.cor_ha_sn_pix, cor_position, atol=self.abs_tol * 4), message @pytest.mark.usefixtures("bootstrap_cor_win") class TestCoarseToFineSinoCor: def test_coarse_to_fine(self): """ Test nabu.estimation.cor_sino.SinoCor """ sino_halftomo = np.vstack([self.data_ha_sino[0], self.data_ha_sino[1]]) sino_cor = SinoCor(sino_halftomo) cor_coarse = sino_cor.estimate_cor_coarse() assert np.isscalar(cor_coarse), f"cor_position expected to be a scale, {type(cor_position)} returned" cor_fine = sino_cor.estimate_cor_fine() assert np.isscalar(cor_fine), f"cor_position expected to be a scale, {type(cor_position)} returned" cor_ref = self.cor_ha_sn_pix + sino_halftomo.shape[-1] / 2.0 message = "Computed CoR %f " % cor_fine + " and expected CoR %f do not coincide" % cor_ref assert abs(cor_fine - cor_ref) < self.abs_tol * 2, message 
././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/estimation/tests/test_focus.py0000644000175000017500000001011200000000000021212 0ustar00pierrepierreimport os import numpy as np import pytest import h5py from nabu.testutils import utilstest, __do_long_tests__ from nabu.estimation.focus import CameraFocus @pytest.fixture(scope="class") def bootstrap_fcs(request): cls = request.cls cls.abs_tol_dist = 1e-2 cls.abs_tol_tilt = 2.5e-4 ( cls.data, cls.img_pos, cls.pixel_size, (calib_data_std, calib_data_angle), ) = get_focus_data("test_alignment_focus.h5") ( cls.angle_best_ind, cls.angle_best_pos, cls.angle_tilt_v, cls.angle_tilt_h, ) = calib_data_angle cls.std_best_ind, cls.std_best_pos = calib_data_std def get_focus_data(*dataset_path): """ Get a dataset file from silx.org/pub/nabu/data dataset_args is a list describing a nested folder structures, ex. ["path", "to", "my", "dataset.h5"] """ dataset_relpath = os.path.join(*dataset_path) dataset_downloaded_path = utilstest.getfile(dataset_relpath) with h5py.File(dataset_downloaded_path, "r") as hf: data = hf["/entry/instrument/detector/data"][()] img_pos = hf["/entry/instrument/detector/distance"][()] pixel_size = np.mean( [ hf["/entry/instrument/detector/x_pixel_size"][()], hf["/entry/instrument/detector/y_pixel_size"][()], ] ) angle_best_ind = hf["/calibration/focus/angle/best_img"][()] angle_best_pos = hf["/calibration/focus/angle/best_pos"][()] angle_tilt_v = hf["/calibration/focus/angle/tilt_v_rad"][()] angle_tilt_h = hf["/calibration/focus/angle/tilt_h_rad"][()] std_best_ind = hf["/calibration/focus/std/best_img"][()] std_best_pos = hf["/calibration/focus/std/best_pos"][()] calib_data_angle = (angle_best_ind, angle_best_pos, angle_tilt_v, angle_tilt_h) calib_data_std = (std_best_ind, std_best_pos) return data, img_pos, pixel_size, (calib_data_std, calib_data_angle) @pytest.mark.skipif(not (__do_long_tests__), reason="need environment variable NABU_LONG_TESTS=1") @pytest.mark.usefixtures("bootstrap_fcs") class TestFocus: def test_find_distance(self): focus_calc = CameraFocus() focus_pos, focus_ind = focus_calc.find_distance(self.data, self.img_pos) message = ( "Computed focus motor position %f " % focus_pos + " and expected %f do not coincide" % self.std_best_pos ) assert np.isclose(self.std_best_pos, focus_pos, atol=self.abs_tol_dist), message message = "Computed focus image index %f " % focus_ind + " and expected %f do not coincide" % self.std_best_ind assert np.isclose(self.std_best_ind, focus_ind, atol=self.abs_tol_dist), message def test_find_scintillator_tilt(self): focus_calc = CameraFocus() focus_pos, focus_ind, tilts_vh = focus_calc.find_scintillator_tilt(self.data, self.img_pos) message = ( "Computed focus motor position %f " % focus_pos + " and expected %f do not coincide" % self.angle_best_pos ) assert np.isclose(self.angle_best_pos, focus_pos, atol=self.abs_tol_dist), message message = ( "Computed focus image index %f " % focus_ind + " and expected %f do not coincide" % self.angle_best_ind ) assert np.isclose(self.angle_best_ind, focus_ind, atol=self.abs_tol_dist), message expected_tilts_vh = np.squeeze(np.array([self.angle_tilt_v, self.angle_tilt_h])) computed_tilts_vh = -tilts_vh / (self.pixel_size / 1000) message = "Computed tilts %s and expected %s do not coincide" % ( computed_tilts_vh, expected_tilts_vh, ) assert np.all(np.isclose(computed_tilts_vh, expected_tilts_vh, atol=self.abs_tol_tilt)), message def test_size_determination(self): inp_shape = [2162, 2560] 
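# With regions_number=4, _check_img_block_size crops each dimension to a multiple of 4 and then to a square: 2162 -> 2160, 2560 -> 2560, both clamped to the common minimum, hence the expected shape (2160, 2160).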
exp_shape = np.array([2160, 2160]) new_shape = CameraFocus()._check_img_block_size(inp_shape, 4, suggest_new_shape=True) message = "New suggested shape: %s and expected: %s do not coincide" % (new_shape, exp_shape) assert np.all(new_shape == exp_shape), message ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/estimation/tests/test_tilt.py0000644000175000017500000000317200000000000021057 0ustar00pierrepierreimport pytest import numpy as np from nabu.estimation.tilt import CameraTilt from nabu.estimation.tests.test_cor import bootstrap_cor try: import skimage.transform as skt # noqa: F401 __have_skimage__ = True except ImportError: __have_skimage__ = False @pytest.mark.usefixtures("bootstrap_cor") class TestCameraTilt: def test_1dcorrelation(self): radio1 = self.data[0, :, :] radio2 = np.fliplr(self.data[1, :, :]) tilt_calc = CameraTilt() cor_position, camera_tilt = tilt_calc.compute_angle(radio1, radio2) message = "Computed tilt %f " % camera_tilt + " and real tilt %f do not coincide" % self.tilt_deg assert np.isclose(self.tilt_deg, camera_tilt, atol=self.abs_tol), message message = "Computed CoR %f " % cor_position + " and real CoR %f do not coincide" % self.cor_hl_pix assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message @pytest.mark.skipif(not (__have_skimage__), reason="need scikit-image for this test") def test_fftpolar(self): radio1 = self.data[0, :, :] radio2 = np.fliplr(self.data[1, :, :]) tilt_calc = CameraTilt() cor_position, camera_tilt = tilt_calc.compute_angle(radio1, radio2, method="fft-polar") message = "Computed tilt %f " % camera_tilt + " and real tilt %f do not coincide" % self.tilt_deg assert np.isclose(self.tilt_deg, camera_tilt, atol=self.abs_tol), message message = "Computed CoR %f " % cor_position + " and real CoR %f do not coincide" % self.cor_hl_pix assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/estimation/tests/test_translation.py0000644000175000017500000000566700000000000022454 0ustar00pierrepierreimport os import numpy as np import pytest import h5py from nabu.testutils import utilstest from nabu.estimation.translation import DetectorTranslationAlongBeam import scipy.ndimage def get_alignxc_data(*dataset_path): """ Get a dataset file from silx.org/pub/nabu/data dataset_args is a list describing a nested folder structures, ex. 
["path", "to", "my", "dataset.h5"] """ dataset_relpath = os.path.join(*dataset_path) dataset_downloaded_path = utilstest.getfile(dataset_relpath) with h5py.File(dataset_downloaded_path, "r") as hf: data = hf["/entry/instrument/detector/data"][()] img_pos = hf["/entry/instrument/detector/distance"][()] unit_length_shifts_vh = [ hf["/calibration/alignxc/y_pixel_shift_unit"][()], hf["/calibration/alignxc/x_pixel_shift_unit"][()], ] all_shifts_vh = hf["/calibration/alignxc/yx_pixel_offsets"][()] return data, img_pos, (unit_length_shifts_vh, all_shifts_vh) @pytest.fixture(scope="class") def bootstrap_dtr(request): cls = request.cls cls.abs_tol = 1e-1 cls.data, cls.img_pos, calib_data = get_alignxc_data("test_alignment_alignxc.h5") cls.expected_shifts_vh, cls.all_shifts_vh = calib_data @pytest.mark.usefixtures("bootstrap_dtr") class TestDetectorTranslation: def test_alignxc(self): T_calc = DetectorTranslationAlongBeam() shifts_v, shifts_h, found_shifts_list = T_calc.find_shift(self.data, self.img_pos, return_shifts=True) message = "Computed shifts coefficients %s and expected %s do not coincide" % ( (shifts_v, shifts_h), self.expected_shifts_vh, ) assert np.all(np.isclose(self.expected_shifts_vh, [shifts_v, shifts_h], atol=self.abs_tol)), message message = "Computed shifts %s and expected %s do not coincide" % ( found_shifts_list, self.all_shifts_vh, ) assert np.all(np.isclose(found_shifts_list, self.all_shifts_vh, atol=self.abs_tol)), message def test_alignxc_synth(self): T_calc = DetectorTranslationAlongBeam() stack = np.zeros([4, 512, 512], "d") for i in range(4): stack[i, 200 - i * 10, 200 - i * 10] = 1 stack = scipy.ndimage.gaussian_filter(stack, [0, 10, 10.0]) * 100 x, y = np.meshgrid(np.arange(stack.shape[-1]), np.arange(stack.shape[-2])) for i in range(4): xc = x - (250 + i * 1.234) yc = y - (250 + i * 1.234 * 2) stack[i] += np.exp(-(xc * xc + yc * yc) * 0.5) shifts_v, shifts_h, found_shifts_list = T_calc.find_shift( stack, np.array([0.0, 1, 2, 3]), high_pass=1.0, return_shifts=True ) message = "Found shifts per units %s and reference %s do not coincide" % ( (shifts_v, shifts_h), (-1.234 * 2, -1.234), ) assert np.all(np.isclose((shifts_v, shifts_h), (-1.234 * 2, -1.234), atol=self.abs_tol)), message ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682590397.0 nabu-2023.1.1/nabu/estimation/tilt.py0000644000175000017500000002071000000000000016653 0ustar00pierrepierreimport numpy as np from numpy.polynomial.polynomial import Polynomial, polyval from .alignment import median_filter, plt from .cor import CenterOfRotation try: import skimage.transform as skt __have_skimage__ = True except ImportError: __have_skimage__ = False class CameraTilt(CenterOfRotation): def compute_angle( self, img_1: np.ndarray, img_2: np.ndarray, method="1d-correlation", roi_yxhw=None, median_filt_shape=None, padding_mode=None, peak_fit_radius=1, high_pass=None, low_pass=None, ): """Find the camera tilt, given two opposite images. This method finds the tilt between the camera pixel columns and the rotation axis, by performing a 1-dimensional correlation between two opposite images. The output of this function, allows to compute motor movements for aligning the camera tilt. Parameters ---------- img_1: numpy.ndarray First image img_2: numpy.ndarray Second image, it needs to have been flipped already (e.g. using numpy.fliplr). method: str Tilt angle computation method. Default is "1d-correlation" (traditional). 
All options are: - "1d-correlation": fastest, but works best for small tilts - "fft-polar": slower, but works well on all ranges of tilts roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional 4 elements vector containing: vertical and horizontal coordinates of first pixel, plus height and width of the Region of Interest (RoI). Or a 2 elements vector containing: plus height and width of the centered Region of Interest (RoI). Default is None -> deactivated. median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional Shape of the median filter window. Default is None -> deactivated. padding_mode: str in numpy.pad's mode list, optional Padding mode, which determines the type of convolution. If None or 'wrap' are passed, this resorts to the traditional circular convolution. If 'edge' or 'constant' are passed, it results in a linear convolution. Default is the circular convolution. All options are: None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean' | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap' peak_fit_radius: int, optional Radius size around the max correlation pixel, for sub-pixel fitting. Minimum and default value is 1. low_pass: float or sequence of two floats Low-pass filter properties, as described in `nabu.misc.fourier_filters` high_pass: float or sequence of two floats High-pass filter properties, as described in `nabu.misc.fourier_filters` Raises ------ ValueError In case images are not 2-dimensional or have different sizes. Returns ------- cor_offset_pix: float Estimated center of rotation position from the center of the RoI in pixels. tilt_deg: float Estimated camera tilt angle in degrees. Examples -------- The following code computes the center of rotation position for two given images in a tomography scan, where the second image is taken at 180 degrees from the first. >>> radio1 = data[0, :, :] ... radio2 = np.fliplr(data[1, :, :]) ... tilt_calc = CameraTilt() ... cor_offset, camera_tilt = tilt_calc.compute_angle(radio1, radio2) Or for noisy images: >>> cor_offset, camera_tilt = tilt_calc.compute_angle(radio1, radio2, median_filt_shape=(3, 3)) """ self._check_img_pair_sizes(img_1, img_2) if peak_fit_radius < 1: self.logger.warning("Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius) peak_fit_radius = 1 img_shape = img_2.shape roi_yxhw = self._determine_roi(img_shape, roi_yxhw) img_1 = self._prepare_image(img_1, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape) img_2 = self._prepare_image(img_2, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape) if method.lower() == "1d-correlation": return self._compute_angle_1dcorrelation( img_1, img_2, padding_mode, peak_fit_radius=peak_fit_radius, high_pass=high_pass, low_pass=low_pass ) elif method.lower() == "fft-polar": if not __have_skimage__: raise ValueError( 'Camera tilt calculation using "fft-polar" is only available with scikit-image.' " Please install the package to use this option." ) return self._compute_angle_fftpolar( img_1, img_2, padding_mode, peak_fit_radius=peak_fit_radius, high_pass=high_pass, low_pass=low_pass ) else: raise ValueError('Invalid method: %s. 
Valid options are: "1d-correlation" | "fft-polar"' % method) def _compute_angle_1dcorrelation( self, img_1: np.ndarray, img_2: np.ndarray, padding_mode=None, peak_fit_radius=1, high_pass=None, low_pass=None ): cc = self._compute_correlation_fft( img_1, img_2, padding_mode, axes=(-1,), high_pass=high_pass, low_pass=low_pass ) img_shape = img_2.shape cc_h_coords = np.fft.fftfreq(img_shape[-1], 1 / img_shape[-1]) (f_vals, fh) = self.extract_peak_regions_1d(cc, peak_radius=peak_fit_radius, cc_coords=cc_h_coords) fitted_shifts_h = self.refine_max_position_1d(f_vals, return_all_coeffs=True) fitted_shifts_h += fh[1, :] # Computing tilt fitted_shifts_h = median_filter(fitted_shifts_h, size=3) half_img_size = (img_shape[-2] - 1) / 2 cc_v_coords = np.linspace(-half_img_size, half_img_size, img_shape[-2]) coeffs_h = Polynomial.fit(cc_v_coords, fitted_shifts_h, deg=1).convert().coef tilt_deg = np.rad2deg(-coeffs_h[1] / 2) cor_offset_pix = coeffs_h[0] / 2 if self.verbose: self.logger.info( "Fitted center of rotation (pixels):", cor_offset_pix, "and camera tilt (degrees):", tilt_deg, ) f, ax = plt.subplots(1, 1) self._add_plot_window(f, ax=ax) ax.plot(cc_v_coords, fitted_shifts_h) ax.plot(cc_v_coords, polyval(cc_v_coords, coeffs_h), "-C1") ax.set_title("Correlation peaks") plt.show(block=self.extra_options["blocking_plots"]) return cor_offset_pix, tilt_deg def _compute_angle_fftpolar( self, img_1: np.ndarray, img_2: np.ndarray, padding_mode=None, peak_fit_radius=1, high_pass=None, low_pass=None ): img_shape = img_2.shape img_fft_1, img_fft_2, filt, _ = self._transform_to_fft( img_1, img_2, padding_mode=padding_mode, axes=(-2, -1), low_pass=low_pass, high_pass=high_pass ) if filt is not None: img_fft_1 *= filt img_fft_2 *= filt # abs removes the translation component img_fft_1 = np.abs(np.fft.fftshift(img_fft_1, axes=(-2, -1))) img_fft_2 = np.abs(np.fft.fftshift(img_fft_2, axes=(-2, -1))) # transform to polar coordinates img_fft_1 = skt.warp_polar(img_fft_1, scaling="linear", output_shape=img_shape) img_fft_2 = skt.warp_polar(img_fft_2, scaling="linear", output_shape=img_shape) # only use half of the fft domain img_fft_1 = img_fft_1[..., : img_fft_1.shape[-2] // 2, :] img_fft_2 = img_fft_2[..., : img_fft_2.shape[-2] // 2, :] tilt_pix = self.find_shift(img_fft_1, img_fft_2, shift_axis=-2) tilt_deg = -(360 / img_shape[0]) * tilt_pix img_1 = skt.rotate(img_1, tilt_deg) img_2 = skt.rotate(img_2, -tilt_deg) cor_offset_pix = self.find_shift( img_1, img_2, padding_mode=padding_mode, peak_fit_radius=peak_fit_radius, high_pass=high_pass, low_pass=low_pass, ) if self.verbose: print( "Fitted center of rotation (pixels):", cor_offset_pix, "and camera tilt (degrees):", tilt_deg, ) return cor_offset_pix, tilt_deg ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/estimation/translation.py0000644000175000017500000002015500000000000020240 0ustar00pierrepierreimport numpy as np from numpy.polynomial.polynomial import Polynomial, polyval from .alignment import AlignmentBase, plt class DetectorTranslationAlongBeam(AlignmentBase): def find_shift( self, img_stack: np.ndarray, img_pos: np.array, roi_yxhw=None, median_filt_shape=None, padding_mode=None, peak_fit_radius=1, high_pass=None, low_pass=None, return_shifts=False, use_adjacent_imgs=False, ): """Find the vertical and horizontal shifts for translations of the detector along the beam direction. 
These shifts are in pixels-per-unit-translation, and they are due to the misalignment of the translation stage, with respect to the beam propagation direction. To compute the vertical and horizontal tilt angles from the obtained `shift_pix`: >>> tilt_deg = np.rad2deg(np.arctan(shift_pix * pixel_size)) where `pixel_size` and and the input parameter `img_pos` have to be expressed in the same units. Parameters ---------- img_stack: numpy.ndarray A stack of images (usually 4) at different distances img_pos: numpy.ndarray Position of the images along the translation axis roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional 4 elements vector containing: vertical and horizontal coordinates of first pixel, plus height and width of the Region of Interest (RoI). Or a 2 elements vector containing: plus height and width of the centered Region of Interest (RoI). Default is None -> deactivated. median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional Shape of the median filter window. Default is None -> deactivated. padding_mode: str in numpy.pad's mode list, optional Padding mode, which determines the type of convolution. If None or 'wrap' are passed, this resorts to the traditional circular convolution. If 'edge' or 'constant' are passed, it results in a linear convolution. Default is the circular convolution. All options are: None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean' | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap' peak_fit_radius: int, optional Radius size around the max correlation pixel, for sub-pixel fitting. Minimum and default value is 1. low_pass: float or sequence of two floats Low-pass filter properties, as described in `nabu.misc.fourier_filters`. high_pass: float or sequence of two floats High-pass filter properties, as described in `nabu.misc.fourier_filters`. return_shifts: boolean, optional Adds a third returned argument, containing the pixel shifts of each image with respect to the first one in the stack. Defaults to False. use_adjacent_imgs: boolean, optional Compute correlation between adjacent images. It can be used when dealing with large shifts, to avoid overflowing the shift. This option allows to replicate the behavior of the reference function `alignxc.m` However, it is detrimental to shift fitting accuracy. Defaults to False. Returns ------- coeff_v: float Estimated vertical shift in pixel per unit-distance of the detector translation. coeff_h: float Estimated horizontal shift in pixel per unit-distance of the detector translation. shifts_vh: list, optional The pixel shifts of each image with respect to the first image in the stack. Activated if return_shifts is True. Examples -------- The following example creates a stack of shifted images, and retrieves the computed shift. Here we use a high-pass filter, due to the presence of some low-frequency noise component. >>> import numpy as np ... import scipy as sp ... import scipy.ndimage ... from nabu.preproc.alignment import DetectorTranslationAlongBeam ... ... tr_calc = DetectorTranslationAlongBeam() ... ... stack = np.zeros([4, 512, 512]) ... ... # Add low frequency spurious component ... for i in range(4): ... stack[i, 200 - i * 10, 200 - i * 10] = 1 ... stack = sp.ndimage.filters.gaussian_filter(stack, [0, 10, 10.0]) * 100 ... ... # Add the feature ... x, y = np.meshgrid(np.arange(stack.shape[-1]), np.arange(stack.shape[-2])) ... for i in range(4): ... xc = x - (250 + i * 1.234) ... yc = y - (250 + i * 1.234 * 2) ... stack[i] += np.exp(-(xc * xc + yc * yc) * 0.5) ... ... 
# Image translation along the beam ... img_pos = np.arange(4) ... ... # Find the shifts from the features ... shifts_v, shifts_h = tr_calc.find_shift(stack, img_pos, high_pass=1.0) ... print(shifts_v, shifts_h) >>> ( -2.47 , -1.236 ) and the following commands convert the shifts in angular tilts: >>> tilt_v_deg = np.rad2deg(np.arctan(shifts_v * pixel_size)) >>> tilt_h_deg = np.rad2deg(np.arctan(shifts_h * pixel_size)) To enable the legacy behavior of `alignxc.m` (correlation between adjacent images): >>> shifts_v, shifts_h = tr_calc.find_shift(stack, img_pos, use_adjacent_imgs=True) To plot the correlation shifts and the fitted straight lines for both directions: >>> tr_calc = DetectorTranslationAlongBeam(verbose=True) ... shifts_v, shifts_h = tr_calc.find_shift(stack, img_pos) """ self._check_img_stack_size(img_stack, img_pos) if peak_fit_radius < 1: self.logger.warning("Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius) peak_fit_radius = 1 num_imgs = img_stack.shape[0] img_shape = img_stack.shape[-2:] roi_yxhw = self._determine_roi(img_shape, roi_yxhw) img_stack = self._prepare_image(img_stack, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape) # do correlations ccs = [ self._compute_correlation_fft( img_stack[ii - 1 if use_adjacent_imgs else 0, ...], img_stack[ii, ...], padding_mode, high_pass=high_pass, low_pass=low_pass, ) for ii in range(1, num_imgs) ] img_shape = img_stack.shape[-2:] cc_vs = np.fft.fftfreq(img_shape[-2], 1 / img_shape[-2]) cc_hs = np.fft.fftfreq(img_shape[-1], 1 / img_shape[-1]) shifts_vh = np.zeros((num_imgs, 2)) for ii, cc in enumerate(ccs): (f_vals, fv, fh) = self.extract_peak_region_2d(cc, peak_radius=peak_fit_radius, cc_vs=cc_vs, cc_hs=cc_hs) shifts_vh[ii + 1, :] = self.refine_max_position_2d(f_vals, fv, fh) if use_adjacent_imgs: shifts_vh = np.cumsum(shifts_vh, axis=0) # Polynomial.fit is supposed to be more numerically stable than polyfit # (according to numpy) coeffs_v = Polynomial.fit(img_pos, shifts_vh[:, 0], deg=1).convert().coef coeffs_h = Polynomial.fit(img_pos, shifts_vh[:, 1], deg=1).convert().coef if self.verbose: self.logger.info( "Fitted pixel shifts per unit-length: vertical = %f, horizontal = %f" % (coeffs_v[1], coeffs_h[1]) ) f, axs = plt.subplots(1, 2) self._add_plot_window(f, ax=axs) axs[0].scatter(img_pos, shifts_vh[:, 0]) axs[0].plot(img_pos, polyval(img_pos, coeffs_v), "-C1") axs[0].set_title("Vertical shifts") axs[1].scatter(img_pos, shifts_vh[:, 1]) axs[1].plot(img_pos, polyval(img_pos, coeffs_h), "-C1") axs[1].set_title("Horizontal shifts") plt.show(block=False) if return_shifts: return coeffs_v[1], coeffs_h[1], shifts_vh else: return coeffs_v[1], coeffs_h[1] ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/estimation/utils.py0000644000175000017500000000257100000000000017044 0ustar00pierrepierreimport numpy as np from scipy.signal import find_peaks def is_fullturn_scan(angles): """ Return True if the angles correspond to a full-turn (360 degrees) scan. """ angles = angles % (2 * np.pi) angles -= angles.min() sin_angles = np.sin(angles) min_dist = 5 # TODO find a more robust angles-based min distance, though this should cover most of the cases maxima = find_peaks(sin_angles, distance=min_dist)[0] minima = find_peaks(-sin_angles, distance=min_dist)[0] n_max = maxima.size n_min = minima.size # abs(n_max - n_min) actually means the following: # * 0: All turns are full (eg. 2pi, 4pi) # * 1: At least one half-turn remains (eg. 
pi, 3pi) if abs(n_max - n_min) == 0: return True else: return False def get_halfturn_indices(angles): angles = angles % (2 * np.pi) angles -= angles.min() sin_angles = np.sin(angles) min_dist = 5 # TODO find a more robust angles-based min distance, though this should cover most of the cases maxima = find_peaks(sin_angles, distance=min_dist)[0] minima = find_peaks(-sin_angles, distance=min_dist)[0] extrema = np.sort(np.hstack([maxima, minima])) extrema -= extrema.min() extrema = np.hstack([extrema, [angles.size - 1]]) res = [] for i in range(extrema.size - 1): res.append((extrema[i], extrema[i + 1])) return res ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4607332 nabu-2023.1.1/nabu/io/0000755000175000017500000000000000000000000013560 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/io/__init__.py0000644000175000017500000000026000000000000015667 0ustar00pierrepierrefrom .reader import NPReader, EDFReader, HDF5File, HDF5Loader, ChunkReader, Readers from .writer import NXProcessWriter, TIFFWriter, EDFWriter, JP2Writer, NPYWriter, NPZWriter ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682590505.0 nabu-2023.1.1/nabu/io/cast_volume.py0000644000175000017500000003474100000000000016464 0ustar00pierrepierreimport os from nabu.misc.utils import rescale_data from nabu.pipeline.params import files_formats from tomoscan.volumebase import VolumeBase from tomoscan.scanbase import TomoScanBase from tomoscan.esrf.volume import ( EDFVolume, HDF5Volume, JP2KVolume, MultiTIFFVolume, TIFFVolume, ) from tomoscan.io import HDF5File from silx.io.utils import get_data from silx.utils.enum import Enum as _Enum import numpy from silx.io.url import DataUrl from typing import Optional import logging _logger = logging.getLogger(__name__) __all__ = ["get_default_output_volume", "cast_volume"] _DEFAULT_OUTPUT_DIR = "vol_cast" RESCALE_MIN_PERCENTILE = 10 RESCALE_MAX_PERCENTILE = 90 def get_default_output_volume( input_volume: VolumeBase, output_type: str, output_dir: str = _DEFAULT_OUTPUT_DIR ) -> VolumeBase: """ For a given input volume and output type return output volume as an instance of VolumeBase :param VolumeBase intput_volume: volume for which we want to get the resulting output volume for a cast :param str output_type: output_type of the volume (edf, tiff, hdf5...) :param str output_dir: output dir to save the cast volume """ if not isinstance(input_volume, VolumeBase): raise TypeError(f"input_volume is expected to be an instance of {VolumeBase}") valid_file_formats = set(files_formats.values()) if not output_type in valid_file_formats: raise ValueError(f"output_type is not a valid value ({output_type}). 
Valid values are {valid_file_formats}") if isinstance(input_volume, (EDFVolume, TIFFVolume, JP2KVolume)): if output_type == "hdf5": file_path = os.path.join( input_volume.data_url.file_path(), output_dir, input_volume.get_volume_basename() + ".hdf5", ) volume = HDF5Volume( file_path=file_path, data_path="/volume", ) assert volume.get_identifier() is not None, "volume should be able to create an identifier" return volume elif output_type in ("tiff", "edf", "jp2"): if output_type == "tiff": Constructor = TIFFVolume elif output_type == "edf": Constructor = EDFVolume elif output_type == "jp2": Constructor = JP2KVolume return Constructor( folder=os.path.join( os.path.dirname(input_volume.data_url.file_path()), output_dir, ), volume_basename=input_volume.get_volume_basename(), ) else: raise NotImplementedError elif isinstance(input_volume, (HDF5Volume, MultiTIFFVolume)): if output_type == "hdf5": data_file_parent_path, data_file_name = os.path.split(input_volume.data_url.file_path()) # replace extension: data_file_name = ".".join( [ os.path.splitext(data_file_name)[0], "hdf5", ] ) if isinstance(input_volume, HDF5Volume): data_data_path = input_volume.data_url.data_path() metadata_data_path = input_volume.metadata_url.data_path() try: data_path = os.path.commonprefix([data_data_path, metadata_data_path]) except Exception: data_path = "volume" else: data_data_path = HDF5Volume.DATA_DATASET_NAME metadata_data_path = HDF5Volume.METADATA_GROUP_NAME file_path = data_file_name data_path = "volume" volume = HDF5Volume( file_path=os.path.join(data_file_parent_path, output_dir, data_file_name), data_path=data_path, ) assert volume.get_identifier() is not None, "volume should be able to create an identifier" return volume elif output_type in ("tiff", "edf", "jp2"): if output_type == "tiff": Constructor = TIFFVolume elif output_type == "edf": Constructor = EDFVolume elif output_type == "jp2": Constructor = JP2KVolume file_parent_path, file_name = os.path.split(input_volume.data_url.file_path()) file_name = os.path.splitext(file_name)[0] return Constructor( folder=os.path.join( file_parent_path, output_dir, os.path.basename(file_name), ) ) else: raise NotImplementedError else: raise NotImplementedError def cast_volume( input_volume: VolumeBase, output_volume: VolumeBase, output_data_type: numpy.dtype, data_min=None, data_max=None, scan: Optional[TomoScanBase] = None, rescale_min_percentile=RESCALE_MIN_PERCENTILE, rescale_max_percentile=RESCALE_MAX_PERCENTILE, save=True, store=False, ) -> VolumeBase: """ cast givent volume to output_volume of 'output_data_type' type :param VolumeBase input_volume: :param VolumeBase output_volume: :param numpy.dtype output_data_type: output data type :param number data_min: `data` min value to clamp to new_min. Any lower value will also be clamp to new_min. :param number data_max: `data` max value to clamp to new_max. Any hight value will also be clamp to new_max. :param TomoScanBase scan: source scan that produced input_volume. Can be used to find histogram for example. :param rescale_min_percentile: if `data_min` is None will set data_min to 'rescale_min_percentile' :param rescale_max_percentile: if `data_max` is None will set data_min to 'rescale_max_percentile' :param bool save: if True dump the slice on disk (one by one) :param bool store: if True once the volume is cast then set `output_volume.data` :return: output_volume with data and metadata set .. warning:: the created will volume will not be saved in this processing. 
If you want to save the cast volume you must do it yourself. """ if not isinstance(input_volume, VolumeBase): raise TypeError(f"input_volume is expected to be a {VolumeBase}. {type(input_volume)} provided") if not isinstance(output_volume, VolumeBase): raise TypeError(f"output_volume is expected to be a {VolumeBase}. {type(output_volume)} provided") if not isinstance(output_data_type, numpy.dtype): raise TypeError(f"output_data_type is expected to be a {numpy.dtype}. {type(output_data_type)} provided") # start processing # check for data_min and data_max if data_min is None or data_max is None: found_data_min, found_data_max = _try_to_find_min_max_from_histo( input_volume=input_volume, scan=scan, rescale_min_percentile=rescale_min_percentile, rescale_max_percentile=rescale_max_percentile, ) if found_data_min is None or found_data_max is None: _logger.warning("couldn't find histogram, recompute volume min and max values") data_min, data_max = input_volume.get_min_max() _logger.info(f"min and max found ({data_min} ; {data_max})") data_min = data_min if data_min is not None else found_data_min data_max = data_max if data_max is not None else found_data_max data = [] for input_slice, frame_dumper in zip( input_volume.browse_slices(), output_volume.data_file_saver_generator( input_volume.get_volume_shape()[0], data_url=output_volume.data_url, overwrite=output_volume.overwrite, ), ): if numpy.issubdtype(output_data_type, numpy.integer): new_min = numpy.iinfo(output_data_type).min new_max = numpy.iinfo(output_data_type).max output_slice = clamp_and_rescale_data( data=input_slice, new_min=new_min, new_max=new_max, data_min=data_min, data_max=data_max, rescale_min_percentile=rescale_min_percentile, rescale_max_percentile=rescale_max_percentile, ).astype(output_data_type) else: output_slice = input_slice.astype(output_data_type) if save: frame_dumper[:] = output_slice if store: # only keep data in cache if not dump to disk data.append(output_slice) if store: output_volume.data = numpy.asarray(data) # try also to append some metadata to it try: output_volume.metadata = input_volume.metadata or input_volume.load_metadata() except (OSError, KeyError): # if no metadata provided and or saved in disk or if some key are missing pass return output_volume def clamp_and_rescale_data( data: numpy.ndarray, new_min, new_max, data_min=None, data_max=None, rescale_min_percentile=RESCALE_MIN_PERCENTILE, rescale_max_percentile=RESCALE_MAX_PERCENTILE, ): """ rescale data to 'new_min', 'new_max' :param numpy.ndarray data: data to be rescaled :param dtype output_dtype: output dtype :param new_min: rescaled data new min (clamp min value) :param new_max: rescaled data new max (clamp max value) :param data_min: `data` min value to clamp to new_min. Any lower value will also be clamp to new_min. :param data_max: `data` max value to clamp to new_max. Any hight value will also be clamp to new_max. 
:param rescale_min_percentile: if `data_min` is None will set data_min to 'rescale_min_percentile' :param rescale_max_percentile: if `data_max` is None will set data_min to 'rescale_max_percentile' """ if data_min is None: data_min = numpy.percentile(data, rescale_min_percentile) if data_max is None: data_max = numpy.percentile(data, rescale_max_percentile) # rescale data rescaled_data = rescale_data(data, new_min=new_min, new_max=new_max, data_min=data_min, data_max=data_max) # clamp data rescaled_data[rescaled_data < new_min] = new_min rescaled_data[rescaled_data > new_max] = new_max return rescaled_data def find_histogram(volume: VolumeBase, scan: Optional[TomoScanBase] = None) -> Optional[DataUrl]: """ Look for histogram of the provided url. If found one return the DataUrl of the nabu histogram """ if not isinstance(volume, VolumeBase): raise TypeError(f"volume is expected to be an instance of {VolumeBase} not {type(volume)}") elif isinstance(volume, HDF5Volume): histogram_file = volume.data_url.file_path() if volume.url is not None: data_path = volume.url.data_path() if data_path.endswith("reconstruction"): data_path = "/".join( [ *data_path.split("/")[:-1], "histogram/results/data", ] ) else: data_path = "/".join((volume.url.data_path(), "histogram/results/data")) else: # TODO: FIXME: in some case (if the users provides the full data_url and if the 'DATA_DATASET_NAME' is not used we # will endup with an invalid data_path. Hope this case will not happen. Anyway this is a case that we can't handle.) # if trouble: check if data_path exists. If not raise an error saying this we can't find an histogram for this volume data_path = volume.data_url.data_path().replace(HDF5Volume.DATA_DATASET_NAME, "histogram/results/data") elif isinstance(volume, (EDFVolume, JP2KVolume, TIFFVolume, MultiTIFFVolume)): if isinstance(volume, (EDFVolume, JP2KVolume, TIFFVolume)): # TODO: check with pierre what is the policy of histogram files names histogram_file = os.path.join( volume.data_url.file_path(), volume.get_volume_basename() + "histogram.hdf5", ) else: # TODO: check with pierre what is the policy of histogram files names file_path, _ = os.path.splitext(volume.data_url.file_path()) histogram_file = os.path.join(file_path + "histogram.hdf5") if scan is not None: data_path = getattr(scan, "entry", "entry") else: # TODO: FIXME: how to get the entry name in every case ? 
# possible solutions are: # * look at the different entries and check for histogram: will work if only one histogram in the file # * Add a histogram request so the user can provide it (can be done at tomoscan level or nabu if we think this is specific to nabu) _logger.info("histogram file found but unable to find relevant histogram") return None else: raise NotImplementedError(f"volume {type(volume)} not handled") if not os.path.exists(histogram_file): _logger.info(f"{histogram_file} not found") return None with HDF5File(histogram_file, mode="r") as h5f: if not data_path in h5f: _logger.info(f"{data_path} in {histogram_file} not found") return None else: _logger.info(f"Found histogram {histogram_file}::/{data_path}") return DataUrl( file_path=histogram_file, data_path=data_path, scheme="silx", ) def _get_hst_saturations(hist, bins, rescale_min_percentile, rescale_max_percentile): hist_cum = numpy.cumsum(hist) bin_index_min = numpy.searchsorted(hist_cum, numpy.percentile(hist_cum, rescale_min_percentile)) bin_index_max = numpy.searchsorted(hist_cum, numpy.percentile(hist_cum, rescale_max_percentile)) return bins[bin_index_min], bins[bin_index_max] def _try_to_find_min_max_from_histo( input_volume: VolumeBase, rescale_min_percentile, rescale_max_percentile, scan=None ) -> tuple: """ util to interpret nabu histogram and deduce data_min and data_max to be used for rescaling a volume """ histogram_res_url = find_histogram(input_volume, scan=scan) data_min = data_max = None if histogram_res_url is not None: try: histogram = get_data(histogram_res_url) except Exception as e: _logger.error(f"Fail to load histogram from {histogram_res_url}. Reason is {e}") else: bins = histogram[1] hist = histogram[0] data_min, data_max = _get_hst_saturations(hist, bins, rescale_min_percentile, rescale_max_percentile) return data_min, data_max ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/io/detector_distortion.py0000644000175000017500000002674000000000000020232 0ustar00pierrepierreimport numpy as np from scipy import sparse class DetectorDistortionBase: """ """ def __init__(self, detector_full_shape_vh=(0, 0)): """This is the basis class. A simple identity transformation which has the only merit to show how it works.Reimplement this function to have more parameters for other transformations """ self._build_full_transformation(detector_full_shape_vh) def transform(self, source_data, do_full=False): """performs the transformation""" if do_full: result = self.transformation_matrix_full @ source_data.flat result.shape = source_data.shape else: result = self.transformation_matrix @ source_data.flat result.shape = self.target_shape return result def _build_full_transformation(self, detector_full_shape_vh): """A simple identity. 
Reimplement this function to have more parameters for other transformations """ self.detector_full_shape_vh = detector_full_shape_vh sz, sx = detector_full_shape_vh # A simple identity matrix in sparse coordinates format self.total_detector_npixs = detector_full_shape_vh[0] * detector_full_shape_vh[1] I_tmp = np.arange(self.total_detector_npixs) J_tmp = np.arange(self.total_detector_npixs) V_tmp = np.ones([self.total_detector_npixs], "f") coo_tmp = sparse.coo_matrix((V_tmp, (I_tmp, J_tmp)), shape=(sz * sx, sz * sx)) csr_tmp = coo_tmp.tocsr() ## The following arrays are kept for future usage ## when, according to the "sub_region" parameter of the moment, ## they will be used to extract a "slice" of them ## which will map an appropriate data region corresponding to "sub_region_source" ## to the target "sub_region" of the moment self.full_csr_data = csr_tmp.data self.full_csr_indices = csr_tmp.indices self.full_csr_indptr = csr_tmp.indptr ## This will be used to save time if the same sub_region argument is requested several time in a row self._status = None def get_adapted_subregion(self, sub_region_xz): if sub_region_xz is not None: start_x, end_x, start_z, end_z = sub_region_xz else: start_x = 0 end_x = None start_z = 0 end_z = None (start_x, end_x, start_z, end_z) = self.set_sub_region_transformation((start_x, end_x, start_z, end_z)) return (start_x, end_x, start_z, end_z) def set_sub_region_transformation(self, target_sub_region=None): """must return a source sub_region. It sets internally an object (a practical implementation would be a sparse matrice) which can be reused in further applications of "transform" method for transforming the source sub_region data into the target sub_region """ if target_sub_region is None: target_sub_region = (None, None, 0, None) if self._status is not None and self._status["target_sub_region"] == target_sub_region: return self._status["source_sub_region"] else: self._status = None return self._set_sub_region_transformation(target_sub_region) def set_full_transformation(self): self._set_sub_region_transformation(do_full=True) def get_actual_shapes_source_target(self): if self._status is None: return None, None else: return self._status["source_sub_region"], self._status["target_sub_region"] def _set_sub_region_transformation(self, target_sub_region=None, do_full=False): """to be reimplemented in the derived classes""" if target_sub_region is None or do_full: target_sub_region = (None, None, 0, None) (x_start, x_end, z_start, z_end) = target_sub_region if z_start is None: z_start = 0 if z_end is None: z_end = self.detector_full_shape_vh[0] if (x_start, x_end) not in [(None, None), (0, None), (0, self.detector_full_shape_vh[1])]: message = f""" In the base class DetectorDistortionBase only vertical slicing is accepted. 
The sub_region contained (x_start, x_end)={(x_start, x_end)} which would slice the full horizontal size which is {self.detector_full_shape_vh[1]} """ raise ValueError() x_start, x_end = 0, self.detector_full_shape_vh[1] row_ptr_start = z_start * self.detector_full_shape_vh[1] row_ptr_end = z_end * self.detector_full_shape_vh[1] indices_start = self.full_csr_indptr[row_ptr_start] indices_end = self.full_csr_indptr[row_ptr_end] indices_offset = self.full_csr_indptr[row_ptr_start] source_offset = target_sub_region[2] * self.detector_full_shape_vh[1] data_tmp = self.full_csr_data[indices_start:indices_end] indices_tmp = self.full_csr_indices[indices_start:indices_end] - source_offset indptr_tmp = self.full_csr_indptr[row_ptr_start : row_ptr_end + 1] - indices_offset target_size = (z_end - z_start) * self.detector_full_shape_vh[1] source_size = (z_end - z_start) * self.detector_full_shape_vh[1] tmp_transformation_matrix = sparse.csr_matrix( (data_tmp, indices_tmp, indptr_tmp), shape=(target_size, source_size) ) if do_full: self.transformation_matrix_full = tmp_transformation_matrix return None else: self.transformation_matrix = tmp_transformation_matrix self.target_shape = ((z_end - z_start), self.detector_full_shape_vh[1]) ## For the identity matrix the source and the target have the same size. ## The two following lines are trivial. ## For this identity transformation only the slicing of the appropriate part ## of the identity sparse matrix is slightly laborious. ## Practical case will be more complicated and source_sub_region ## will be in general larger than the target_sub_region self._status = { "target_sub_region": ((x_start, x_end, z_start, z_end)), "source_sub_region": ((x_start, x_end, z_start, z_end)), } return self._status["source_sub_region"] class DetectorDistortionMapsXZ(DetectorDistortionBase): def __init__(self, map_x, map_z): """ This class implements the distortion correction from the knowledge of two arrays, map_x and map_z. Pixel (i,j) of the corrected image is obtained by interpolating the raw data at position ( map_z(i,j), map_x(i,j) ). 
Parameters: map_x : float 2D array map_z : float 2D array """ self._build_full_transformation(map_x, map_z) def _build_full_transformation(self, map_x, map_z): """ """ detector_full_shape_vh = map_x.shape if detector_full_shape_vh != map_z.shape: message = f""" map_x and map_z must have the same shape but the dimensions were {map_x.shape} and {map_z.shape} """ raise ValueError(message) coordinates = np.array([map_z, map_x]) # padding sz, sx = detector_full_shape_vh total_detector_npixs = sz * sx xs = np.clip(np.array(coordinates[1].flat), [[0]], [[sx - 1]]) zs = np.clip(np.array(coordinates[0].flat), [[0]], [[sz - 1]]) ix0s = np.floor(xs) ix1s = np.ceil(xs) fx = xs - ix0s iz0s = np.floor(zs) iz1s = np.ceil(zs) fz = zs - iz0s I_tmp = np.empty([4 * sz * sx], np.int64) J_tmp = np.empty([4 * sz * sx], np.int64) V_tmp = np.ones([4 * sz * sx], "f") I_tmp[:] = np.arange(sz * sx * 4) // 4 J_tmp[0::4] = iz0s * sx + ix0s J_tmp[1::4] = iz0s * sx + ix1s J_tmp[2::4] = iz1s * sx + ix0s J_tmp[3::4] = iz1s * sx + ix1s V_tmp[0::4] = (1 - fz) * (1 - fx) V_tmp[1::4] = (1 - fz) * fx V_tmp[2::4] = fz * (1 - fx) V_tmp[3::4] = fz * fx self.detector_full_shape_vh = detector_full_shape_vh coo_tmp = sparse.coo_matrix((V_tmp.astype("f"), (I_tmp, J_tmp)), shape=(sz * sx, sz * sx)) csr_tmp = coo_tmp.tocsr() self.full_csr_data = csr_tmp.data self.full_csr_indices = csr_tmp.indices self.full_csr_indptr = csr_tmp.indptr ## This will be used to save time if the same sub_region argument is requested several time in a row self._status = None def _set_sub_region_transformation( self, target_sub_region=( ( None, None, 0, 0, ), ), do_full=False, ): if target_sub_region is None or do_full: target_sub_region = (None, None, 0, None) (x_start, x_end, z_start, z_end) = target_sub_region if z_start is None: z_start = 0 if z_end is None: z_end = self.detector_full_shape_vh[0] if (x_start, x_end) not in [(None, None), (0, None), (0, self.detector_full_shape_vh[1])]: message = f""" In the base class DetectorDistortionRotation only vertical slicing is accepted. The sub_region contained (x_start, x_end)={(x_start, x_end)} which would slice the full horizontal size which is {self.detector_full_shape_vh[1]} """ raise ValueError() x_start, x_end = 0, self.detector_full_shape_vh[1] row_ptr_start = z_start * self.detector_full_shape_vh[1] row_ptr_end = z_end * self.detector_full_shape_vh[1] indices_start = self.full_csr_indptr[row_ptr_start] indices_end = self.full_csr_indptr[row_ptr_end] data_tmp = self.full_csr_data[indices_start:indices_end] target_offset = self.full_csr_indptr[row_ptr_start] indptr_tmp = self.full_csr_indptr[row_ptr_start : row_ptr_end + 1] - target_offset indices_tmp = self.full_csr_indices[indices_start:indices_end] iz_source = (indices_tmp) // self.detector_full_shape_vh[1] z_start_source = iz_source.min() z_end_source = iz_source.max() + 1 source_offset = z_start_source * self.detector_full_shape_vh[1] indices_tmp = indices_tmp - source_offset target_size = (z_end - z_start) * self.detector_full_shape_vh[1] source_size = (z_end_source - z_start_source) * self.detector_full_shape_vh[1] tmp_transformation_matrix = sparse.csr_matrix( (data_tmp, indices_tmp, indptr_tmp), shape=(target_size, source_size) ) if do_full: self.transformation_matrix_full = tmp_transformation_matrix return None else: self.transformation_matrix = tmp_transformation_matrix self.target_shape = ((z_end - z_start), self.detector_full_shape_vh[1]) ## For the identity matrix the source and the target have the same size. 
## The two following lines are trivial. ## For this identity transformation only the slicing of the appropriate part ## of the identity sparse matrix is slightly laborious. ## Practical case will be more complicated and source_sub_region ## will be in general larger than the target_sub_region self._status = { "target_sub_region": ((x_start, x_end, z_start, z_end)), "source_sub_region": ((x_start, x_end, z_start_source, z_end_source)), } return self._status["source_sub_region"] ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/io/reader.py0000644000175000017500000005403300000000000015401 0ustar00pierrepierreimport os from math import ceil from multiprocessing.pool import ThreadPool import numpy as np from silx.io.dictdump import h5todict from tomoscan.io import HDF5File from .utils import get_compacted_dataslices, convert_dict_values from ..misc.binning import binning as image_binning from ..utils import subsample_dict, get_3D_subregion, get_num_threads try: from silx.third_party.EdfFile import EdfFile except ImportError: EdfFile = None class Reader: """ Abstract class for various file readers. """ def __init__(self, sub_region=None): """ Parameters ---------- sub_region: tuple, optional Coordinates in the form (start_x, end_x, start_y, end_y), to read a subset of each frame. It can be used for Regions of Interest (ROI). Indices start at zero ! """ self._set_default_parameters(sub_region) def _set_default_parameters(self, sub_region): self._set_subregion(sub_region) def _set_subregion(self, sub_region): self.sub_region = sub_region if sub_region is not None: start_x, end_x, start_y, end_y = sub_region self.start_x = start_x self.end_x = end_x self.start_y = start_y self.end_y = end_y else: self.start_x = 0 self.end_x = None self.start_y = 0 self.end_y = None def get_data(self, data_url): """ Get data from a silx.io.url.DataUrl """ raise ValueError("Base class") def release(self): """ Release the file if needed. """ pass class NPReader(Reader): multi_load = True def __init__(self, sub_region=None, mmap=True): """ Reader for NPY/NPZ files. Mostly used for internal development. Please refer to the documentation of nabu.io.reader.Reader """ super().__init__(sub_region=sub_region) self._file_desc = {} self._set_mmap(mmap) def _set_mmap(self, mmap): self.mmap_mode = "r" if mmap else None def _open(self, data_url): file_path = data_url.file_path() file_ext = self._get_file_type(file_path) if file_ext == "npz": if file_path not in self._file_desc: self._file_desc[file_path] = np.load(file_path, mmap_mode=self.mmap_mode) data_ref = self._file_desc[file_path][data_url.data_path()] else: data_ref = np.load(file_path, mmap_mode=self.mmap_mode) return data_ref @staticmethod def _get_file_type(fname): if fname.endswith(".npy"): return "npy" elif fname.endswith(".npz"): return "npz" else: raise ValueError("Not a numpy file: %s" % fname) def get_data(self, data_url): data_ref = self._open(data_url) data_slice = data_url.data_slice() if data_slice is None: res = data_ref[self.start_y : self.end_y, self.start_x : self.end_x] else: res = data_ref[data_slice, self.start_y : self.end_y, self.start_x : self.end_x] return res def release(self): for fname, fdesc in self._file_desc.items(): if fdesc is not None: fdesc.close() self._file_desc[fname] = None def __del__(self): self.release() class EDFReader(Reader): multi_load = False # not implemented def __init__(self, sub_region=None): """ A class for reading series of EDF Files. 
Multi-frames EDF are not supported. """ if EdfFile is None: raise ImportError("Need EdfFile to use this reader") super().__init__(sub_region=sub_region) def read(self, fname): E = EdfFile(fname, "r") if self.sub_region is not None: data = E.GetData( 0, Pos=(self.start_x, self.start_y), Size=(self.end_x - self.start_x, self.end_y - self.start_y), ) else: data = E.GetData(0) E.File.close() return data def get_data(self, data_url): return self.read(data_url.file_path()) class HDF5Reader(Reader): multi_load = True def __init__(self, sub_region=None): """ A class for reading a HDF5 File. """ super().__init__(sub_region=sub_region) self._file_desc = {} def _open(self, file_path): if file_path not in self._file_desc: self._file_desc[file_path] = HDF5File(file_path, "r", swmr=True) def get_data(self, data_url): file_path = data_url.file_path() self._open(file_path) h5dataset = self._file_desc[file_path][data_url.data_path()] data_slice = data_url.data_slice() if data_slice is None: res = h5dataset[self.start_y : self.end_y, self.start_x : self.end_x] else: res = h5dataset[data_slice, self.start_y : self.end_y, self.start_x : self.end_x] return res def release(self): for fname, fdesc in self._file_desc.items(): if fdesc is not None: try: fdesc.close() self._file_desc[fname] = None except Exception as exc: print("Error while closing %s: %s" % (fname, str(exc))) def __del__(self): self.release() class HDF5Loader: """ An alternative class to HDF5Reader where information is first passed at class instantiation """ def __init__(self, fname, data_path, sub_region=None, data_buffer=None, pre_allocate=True, dtype="f"): self.fname = fname self.data_path = data_path self._set_subregion(sub_region) if not ((data_buffer is not None) ^ (pre_allocate is True)): raise ValueError("Please provide either 'data_buffer' or 'pre_allocate'") self.data = data_buffer self._loaded = False if pre_allocate: expected_shape = get_hdf5_dataset_shape(fname, data_path, sub_region=sub_region) self.data = np.zeros(expected_shape, dtype=dtype) def _set_subregion(self, sub_region): self.sub_region = sub_region if sub_region is not None: start_z, end_z, start_y, end_y, start_x, end_x = sub_region self.start_x, self.end_x = start_x, end_x self.start_y, self.end_y = start_y, end_y self.start_z, self.end_z = start_z, end_z else: self.start_x, self.end_x = None, None self.start_y, self.end_y = None, None self.start_z, self.end_z = None, None def load_data(self, force_load=False): if self._loaded and not force_load: return self.data with HDF5File(self.fname, "r") as fdesc: if self.data is None: self.data = fdesc[self.data_path][ self.start_z : self.end_z, self.start_y : self.end_y, self.start_x : self.end_x ] else: self.data[:] = fdesc[self.data_path][ self.start_z : self.end_z, self.start_y : self.end_y, self.start_x : self.end_x ] self._loaded = True return self.data class ChunkReader: """ A reader of chunk of images. """ def __init__( self, files, sub_region=None, detector_corrector=None, pre_allocate=True, data_buffer=None, convert_float=False, shape=None, dtype=None, binning=None, dataset_subsampling=None, num_threads=None, ): """ Initialize a "ChunkReader". A chunk is a stack of images. Parameters ---------- files: dict Dictionary where the key is the file/data index, and the value is a silx.io.url.DataUrl pointing to the data. The dict must contain only the files which shall be used ! Note: the shape and data type is infered from the first data file. 
sub_region: tuple, optional If provided, this must be a tuple in the form (start_x, end_x, start_y, end_y). Each image will be cropped to this region. This is used to specify a chunk of files. Each of the parameters can be None, in this case the default start and end are taken in each dimension. pre_allocate: bool Whether to pre-allocate data before reading. data_buffer: array-like, optional If `pre_allocate` is set to False, this parameter has to be provided. It is an array-like object which will hold the data. convert_float: bool Whether to convert data to float32, regardless of the input data type. shape: tuple, optional Shape of each image. If not provided, it is inferred from the first image in the collection. dtype: `numpy.dtype`, optional Data type of each image. If not provided, it is inferred from the first image in the collection. binning: int or tuple of int, optional Whether to bin the data. If multi-dimensional binning is done, the parameter must be in the form (binning_x, binning_y). Each image will be binned by these factors. dataset_subsampling: int, optional Whether to subsample the dataset. If an integer `n` is provided, then one image out of `n` will be read. num_threads: int, optional Number of threads to use for binning the data. Default is to use all available threads. This parameter has no effect when binning is disabled. Notes ------ The files are provided as a collection of `silx.io.DataURL`. The file type is inferred from the extension. Binning is different from subsampling. Using binning will not speed up the data retrieval (quite the opposite), since the whole (subregion of) data is read and then binning is performed. """ self.detector_corrector = detector_corrector self._get_reader_class(files) self.dataset_subsampling = dataset_subsampling self.num_threads = get_num_threads(num_threads) self._set_files(files) self._get_shape_and_dtype(shape, dtype, binning) self._set_subregion(sub_region) self._init_reader() self._loaded = False self.convert_float = convert_float if convert_float: self.out_dtype = np.float32 else: self.out_dtype = self.dtype if not ((data_buffer is not None) ^ (pre_allocate is True)): raise ValueError("Please provide either 'data_buffer' or 'pre_allocate'") self.files_data = data_buffer if data_buffer is not None: # overwrite out_dtype self.out_dtype = data_buffer.dtype if data_buffer.shape != self.chunk_shape: raise ValueError("Expected shape %s but got %s" % (self.shape, data_buffer.shape)) if pre_allocate: self.files_data = np.zeros(self.chunk_shape, dtype=self.out_dtype) if (self.binning is not None) and (np.dtype(self.out_dtype).kind in ["u", "i"]): raise ValueError( "Output datatype cannot be integer when using binning. Please set the 'convert_float' parameter to True or specify a 'data_buffer'." ) def _set_files(self, files): if len(files) == 0: raise ValueError("Expected at least one data file") self.n_files = len(files) self.files = files self._sorted_files_indices = sorted(files.keys()) self._fileindex_to_idx = dict.fromkeys(self._sorted_files_indices) self._configure_subsampling() def _infer_file_type(self, files): fname = files[sorted(files.keys())[0]].file_path() ext = os.path.splitext(fname)[-1].replace(".", "") if ext not in Readers: raise ValueError("Unknown file format %s. 
Supported formats are: %s" % (ext, str(Readers.keys()))) return ext def _get_reader_class(self, files): ext = self._infer_file_type(files) reader_class = Readers[ext] self._reader_class = reader_class def _get_shape_and_dtype(self, shape, dtype, binning): if shape is None or dtype is None: shape, dtype = self._infer_shape_and_dtype() assert len(shape) == 2, "Expected the shape of an image (2-tuple)" self.shape_total = shape self.dtype = dtype self._set_binning(binning) def _configure_subsampling(self): dataset_subsampling = self.dataset_subsampling self.files_subsampled = self.files if dataset_subsampling is not None and dataset_subsampling > 1: self.files_subsampled = subsample_dict(self.files, dataset_subsampling) self.n_files = len(self.files_subsampled) if not (self._reader_class.multi_load): # 3D loading not supported for this reader. # Data is loaded frames by frame, so subsample directly self.files self.files = self.files_subsampled self._sorted_files_indices = sorted(self.files.keys()) self._fileindex_to_idx = dict.fromkeys(self._sorted_files_indices) def _infer_shape_and_dtype(self): self._reader_entire_image = self._reader_class() first_file_dataurl = self.files[self._sorted_files_indices[0]] first_file_data = self._reader_entire_image.get_data(first_file_dataurl) return first_file_data.shape, first_file_data.dtype def _set_subregion(self, sub_region): sub_region = sub_region or (None, None, None, None) start_x, end_x, start_y, end_y = sub_region if start_x is None: start_x = 0 if start_y is None: start_y = 0 if end_x is None: end_x = self.shape_total[1] if end_y is None: end_y = self.shape_total[0] self.sub_region = (start_x, end_x, start_y, end_y) self.shape = (end_y - start_y, end_x - start_x) if self.binning is not None: self.shape = (self.shape[0] // self.binning[1], self.shape[1] // self.binning[0]) self.chunk_shape = (self.n_files,) + self.shape if self.detector_corrector is not None: self.detector_corrector.set_sub_region_transformation(target_sub_region=self.sub_region) def _init_reader(self): # instantiate reader with user params if self.detector_corrector is not None: adapted_subregion = self.detector_corrector.get_adapted_subregion(self.sub_region) else: adapted_subregion = self.sub_region self.file_reader = self._reader_class(sub_region=adapted_subregion) def _set_binning(self, binning): if binning is None: self.binning = None return if np.isscalar(binning): binning = (binning, binning) else: assert len(binning) == 2, "Expected binning in the form (binning_x, binning_y)" if binning[0] == 1 and binning[1] == 1: self.binning = None return for b in binning: if int(b) != b: raise ValueError("Expected an integer number for binning values, but got %s" % binning) self.binning = binning def get_data(self, file_url): """ Get the data associated to a file url. 
""" arr = self.file_reader.get_data(file_url) if arr.ndim == 2: if self.detector_corrector is not None: arr = self.detector_corrector.transform(arr) if self.binning is not None: arr = image_binning(arr, self.binning[::-1]) else: if self.detector_corrector is not None: if self.detector_corrector is not None: _, ( src_x_start, src_x_end, src_z_start, src_z_end, ) = self.detector_corrector.get_actual_shapes_source_target() arr_target = np.empty([len(arr), src_z_end - src_z_start, src_x_end - src_x_start], "f") def apply_corrector(i_img_tuple): i, img = i_img_tuple arr_target[i] = self.detector_corrector.transform(img) with ThreadPool(self.num_threads) as tp: tp.map(apply_corrector, enumerate(arr)) arr = arr_target if self.binning is not None: nz = arr.shape[0] res = np.zeros((nz,) + image_binning(arr[0], self.binning[::-1]).shape, dtype="f") def apply_binning(img_res_tuple): img, res = img_res_tuple res[:] = image_binning(img, self.binning[::-1]) with ThreadPool(self.num_threads) as tp: tp.map(apply_binning, zip(arr, res)) arr = res return arr def _load_single(self): for i, fileidx in enumerate(self._sorted_files_indices): file_url = self.files[fileidx] self.files_data[i] = self.get_data(file_url) self._fileindex_to_idx[fileidx] = i def _load_multi(self): urls_compacted = get_compacted_dataslices(self.files, subsampling=self.dataset_subsampling) loaded = {} start_idx = 0 for idx in self._sorted_files_indices: url = urls_compacted[idx] url_str = str(url) is_loaded = loaded.get(url_str, False) if is_loaded: continue ds = url.data_slice() delta_z = ds.stop - ds.start if ds.step is not None and ds.step > 1: delta_z = ceil(delta_z / ds.step) end_idx = start_idx + delta_z self.files_data[start_idx:end_idx] = self.get_data(url) start_idx += delta_z loaded[url_str] = True def load_files(self, overwrite: bool = False): """ Load the files whose links was provided at class instantiation. Parameters ----------- overwrite: bool, optional Whether to force reloading the files if already loaded. """ if self._loaded and not (overwrite): raise ValueError("Radios were already loaded. Call load_files(overwrite=True) to force reloading") if self.file_reader.multi_load: self._load_multi() else: self._load_single() self._loaded = True load_data = load_files @property def data(self): return self.files_data Readers = { "edf": EDFReader, "hdf5": HDF5Reader, "h5": HDF5Reader, "nx": HDF5Reader, "npz": NPReader, "npy": NPReader, } def load_images_from_dataurl_dict(data_url_dict, **chunk_reader_kwargs): """ Load a dictionary of dataurl into numpy arrays. Parameters ---------- data_url_dict: dict A dictionary where the keys are integers (the index of each image in the dataset), and the values are numpy.ndarray (data_url_dict). Other parameters ----------------- chunk_reader_kwargs: params Named parameters passed to `nabu.io.reader.ChunkReader`. Returns -------- res: dict A dictionary where the keys are the same as `data_url_dict`, and the values are numpy arrays. """ chunk_reader = ChunkReader(data_url_dict, **chunk_reader_kwargs) img_dict = {} for img_idx, img_url in data_url_dict.items(): img_dict[img_idx] = chunk_reader.get_data(img_url) return img_dict def load_images_stack_from_hdf5(fname, h5_data_path, sub_region=None): """ Load a 3D dataset from a HDF5 file. 
Parameters ----------- fname: str File path h5_data_path: str Data path within the HDF5 file sub_region: tuple, optional Tuple indicating which sub-volume to load, in the form (xmin, xmax, ymin, ymax, zmin, zmax) where the 3D dataset has the python shape (N_Z, N_Y, N_X). This means that the data will be loaded as `data[zmin:zmax, ymin:ymax, xmin:xmax]`. """ xmin, xmax, ymin, ymax, zmin, zmax = get_3D_subregion(sub_region) with HDF5File(fname, "r") as f: d_ptr = f[h5_data_path] data = d_ptr[zmin:zmax, ymin:ymax, xmin:xmax] return data def get_hdf5_dataset_shape(fname, h5_data_path, sub_region=None): zmin, zmax, ymin, ymax, xmin, xmax = get_3D_subregion(sub_region) with HDF5File(fname, "r") as f: d_ptr = f[h5_data_path] shape = d_ptr.shape n_z, n_y, n_x = shape # perhaps there is more elegant res_shape = [] for n, bounds in zip([n_z, n_y, n_x], ((zmin, zmax), (ymin, ymax), (xmin, xmax))): res_shape.append(np.arange(n)[bounds[0] : bounds[1]].size) return tuple(res_shape) def check_virtual_sources_exist(fname, data_path): with HDF5File(fname, "r") as f: if data_path not in f: print("No dataset %s in file %s" % (data_path, fname)) return False dptr = f[data_path] if not dptr.is_virtual: return True for vsource in dptr.virtual_sources(): vsource_fname = os.path.join(os.path.dirname(dptr.file.filename), vsource.file_name) if not os.path.isfile(vsource_fname): print("No such file: %s" % vsource_fname) return False elif not check_virtual_sources_exist(vsource_fname, vsource.dset_name): print("Error with virtual source %s" % vsource_fname) return False return True def import_h5_to_dict(h5file, h5path, asarray=False): """ Wrapper on top of silx.io.dictdump.dicttoh5 replacing "None" with None Parameters ----------- h5file: str File name h5path: str Path in the HDF5 file asarray: bool, optional Whether to convert each numeric value to an 0D array. Default is False. """ dic = h5todict(h5file, path=h5path, asarray=asarray) modified_dic = convert_dict_values(dic, {"None": None}, bytes_tostring=True) return modified_dic ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/io/reader_helical.py0000644000175000017500000001057500000000000017065 0ustar00pierrepierrefrom .reader import * class ChunkReaderHelical(ChunkReader): """implements reading by projections subsets""" def _set_subregion(self, sub_region): super()._set_subregion(sub_region) ########### # undo the chun_size setting of the base class # to avoid allocation of Tera bytes in the helical case self.chunk_shape = (1,) + self.shape def set_data_buffer(self, data_buffer, pre_allocate=False): if data_buffer is not None: # overwrite out_dtype self.files_data = data_buffer self.out_dtype = data_buffer.dtype if data_buffer.shape[1:] != self.shape: raise ValueError("Expected shape %s but got %s" % (self.shape, data_buffer.shape)) if pre_allocate: self.files_data = np.zeros(self.chunk_shape, dtype=self.out_dtype) if (self.binning is not None) and (np.dtype(self.out_dtype).kind in ["u", "i"]): raise ValueError( "Output datatype cannot be integer when using binning. Please set the 'convert_float' parameter to True or specify a 'data_buffer'." 
) def get_binning(self): if self.binning is None: return 1, 1 else: return self.binning def _load_single(self, sub_total_prange_slice=slice(None, None)): if sub_total_prange_slice == slice(None, None): sorted_files_indices = self._sorted_files_indices else: sorted_files_indices = self._sorted_files_indices[sub_total_prange_slice] for i, fileidx in enumerate(sorted_files_indices): file_url = self.files[fileidx] self.files_data[i] = self.get_data(file_url) self._fileindex_to_idx[fileidx] = i def _apply_subsample_fact(self, t): if t is not None: t = t * self.dataset_subsampling return t def _load_multi(self, sub_total_prange_slice=slice(None, None)): if sub_total_prange_slice == slice(None, None): files_to_be_compacted_dict = self.files sorted_files_indices = self._sorted_files_indices else: if self.dataset_subsampling > 1: start, stop, step = list( map( self._apply_subsample_fact, [sub_total_prange_slice.start, sub_total_prange_slice.stop, sub_total_prange_slice.step], ) ) sub_total_prange_slice = slice(start, stop, step) sorted_files_indices = self._sorted_files_indices[sub_total_prange_slice] files_to_be_compacted_dict = dict( zip(sorted_files_indices, [self.files[idx] for idx in sorted_files_indices]) ) urls_compacted = get_compacted_dataslices(files_to_be_compacted_dict, subsampling=self.dataset_subsampling) loaded = {} start_idx = 0 for idx in sorted_files_indices: url = urls_compacted[idx] url_str = str(url) is_loaded = loaded.get(url_str, False) if is_loaded: continue ds = url.data_slice() delta_z = ds.stop - ds.start if ds.step is not None and ds.step > 1: delta_z //= ds.step end_idx = start_idx + delta_z self.files_data[start_idx:end_idx] = self.get_data(url) start_idx += delta_z loaded[url_str] = True def load_files(self, overwrite: bool = False, sub_total_prange_slice=slice(None, None)): """ Load the files whose links was provided at class instantiation. Parameters ----------- overwrite: bool, optional Whether to force reloading the files if already loaded. """ if self._loaded and not (overwrite): raise ValueError("Radios were already loaded. 
Call load_files(overwrite=True) to force reloading") if self.file_reader.multi_load: self._load_multi(sub_total_prange_slice) else: if self.dataset_subsampling > 1: assert ( False ), " in helica pipeline, load file _load_single has not yet been adapted to angular subsampling " self._load_single(sub_total_prange_slice) self._loaded = True load_data = load_files @property def data(self): return self.files_data ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4607332 nabu-2023.1.1/nabu/io/tests/0000755000175000017500000000000000000000000014722 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1596607942.0 nabu-2023.1.1/nabu/io/tests/__init__.py0000644000175000017500000000000000000000000017021 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/io/tests/test_cast_volume.py0000644000175000017500000002126300000000000020660 0ustar00pierrepierreimport numpy from nabu.io.cast_volume import ( cast_volume, clamp_and_rescale_data, find_histogram, get_default_output_volume, ) from tomoscan.esrf.volume import ( EDFVolume, HDF5Volume, JP2KVolume, MultiTIFFVolume, TIFFVolume, ) from nabu.io.writer import __have_jp2k__ from tomoscan.esrf.scan.edfscan import EDFTomoScan from tomoscan.esrf.scan.hdf5scan import HDF5TomoScan import pytest import h5py import os from silx.io.url import DataUrl @pytest.mark.skipif(not __have_jp2k__, reason="need jp2k (glymur) for this test") def test_get_default_output_volume(): """ insure nabu.io.cast_volume is working properly """ with pytest.raises(TypeError): # test input_volume type get_default_output_volume(input_volume="dsad/dsad/", output_type="jp2") with pytest.raises(ValueError): # test output value get_default_output_volume(input_volume=EDFVolume(folder="test"), output_type="toto") # test edf to jp2 input_volume = EDFVolume( folder="/path/to/my_folder", ) output_volume = get_default_output_volume( input_volume=input_volume, output_type="jp2", ) assert isinstance(output_volume, JP2KVolume) assert output_volume.data_url.file_path() == "/path/to/vol_cast" assert output_volume.get_volume_basename() == "my_folder" # test hdf5 to tiff input_volume = HDF5Volume( file_path="/path/to/my_file.hdf5", data_path="entry0012", ) output_volume = get_default_output_volume( input_volume=input_volume, output_type="tiff", ) assert isinstance(output_volume, TIFFVolume) assert output_volume.data_url.file_path() == "/path/to/vol_cast/my_file" assert output_volume.get_volume_basename() == "my_file" # test Multitiff to hdf5 input_volume = MultiTIFFVolume( file_path="my_file.tiff", ) output_volume = get_default_output_volume( input_volume=input_volume, output_type="hdf5", ) assert isinstance(output_volume, HDF5Volume) assert output_volume.data_url.file_path() == "vol_cast/my_file.hdf5" assert output_volume.data_url.data_path() == HDF5Volume.DATA_DATASET_NAME assert output_volume.metadata_url.file_path() == "vol_cast/my_file.hdf5" assert output_volume.metadata_url.data_path() == HDF5Volume.METADATA_GROUP_NAME # test jp2 to hdf5 input_volume = JP2KVolume( folder="folder", volume_basename="basename", ) output_volume = get_default_output_volume( input_volume=input_volume, output_type="hdf5", ) assert isinstance(output_volume, HDF5Volume) assert output_volume.data_url.file_path() == "folder/vol_cast/basename.hdf5" assert output_volume.data_url.data_path() == f"/volume/{HDF5Volume.DATA_DATASET_NAME}" assert 
output_volume.metadata_url.file_path() == "folder/vol_cast/basename.hdf5" assert output_volume.metadata_url.data_path() == f"/volume/{HDF5Volume.METADATA_GROUP_NAME}" def test_find_histogram_hdf5_volume(tmp_path): """ test find_histogram function with hdf5 volume """ h5_file = os.path.join(tmp_path, "test_file") with h5py.File(h5_file, mode="w") as h5f: h5f.require_group("myentry/histogram/results/data") # if volume url provided then can find it assert find_histogram(volume=HDF5Volume(file_path=h5_file, data_path="myentry")) == DataUrl( file_path=h5_file, data_path="myentry/histogram/results/data", scheme="silx", ) assert find_histogram(volume=HDF5Volume(file_path=h5_file, data_path="entry")) == None def test_find_histogram_single_frame_volume(tmp_path): """ test find_histogram function with single frame volume TODO: improve: for now histogram file are created manually. If this can be more coupled with the "real" histogram generation it would be way better """ # create volume and histogram volume = EDFVolume( folder=tmp_path, volume_basename="volume", ) histogram_file = os.path.join(tmp_path, "volumehistogram.hdf5") with h5py.File(histogram_file, mode="w") as h5f: h5f.require_group("entry/histogram/results/data") # check behavior assert find_histogram(volume=volume) == None assert find_histogram( volume=volume, scan=EDFTomoScan(scan=str(tmp_path)), ) == DataUrl( file_path=histogram_file, data_path="entry", scheme="silx", ) assert find_histogram( volume=volume, scan=HDF5TomoScan(scan=str(tmp_path), entry="entry"), ) == DataUrl( file_path=histogram_file, data_path="entry", scheme="silx", ) def test_find_histogram_multi_tiff_volume(tmp_path): """ test find_histogram function with multi tiff frame volume TODO: improve: for now histogram file are created manually. 
If this can be more coupled with the "real" histogram generation it would be way better """ # create volume and histogram tiff_file = os.path.join(tmp_path, "my_tiff.tif") volume = MultiTIFFVolume( file_path=tiff_file, ) histogram_file = os.path.join(tmp_path, "my_tiffhistogram.hdf5") with h5py.File(histogram_file, mode="w") as h5f: h5f.require_group("entry/histogram/results/data") # check behavior assert find_histogram(volume=volume) == None assert find_histogram( volume=volume, scan=EDFTomoScan(scan=str(tmp_path)), ) == DataUrl( file_path=histogram_file, data_path="entry", scheme="silx", ) assert find_histogram( volume=volume, scan=HDF5TomoScan(scan=str(tmp_path), entry="entry"), ) == DataUrl( file_path=histogram_file, data_path="entry", scheme="silx", ) @pytest.mark.parametrize("input_dtype", (numpy.float32, numpy.float64, numpy.uint8, numpy.uint16)) def test_clamp_and_rescale_data(input_dtype): """ test 'rescale_data' function """ array = numpy.linspace( start=1, stop=100, num=100, endpoint=True, dtype=input_dtype, ).reshape(10, 10) rescaled_array = clamp_and_rescale_data( data=array, new_min=10, new_max=90, rescale_min_percentile=20, # provided to insure they will be ignored rescale_max_percentile=80, # provided to insure they will be ignored ) assert rescaled_array.min() == 10 assert rescaled_array.max() == 90 numpy.testing.assert_equal(rescaled_array.flatten()[0:10], numpy.array([10] * 10)) numpy.testing.assert_equal(rescaled_array.flatten()[90:100], numpy.array([90] * 10)) def test_cast_volume(tmp_path): """ test cast_volume """ raw_data = numpy.linspace( start=1, stop=100, num=100, endpoint=True, dtype=numpy.float64, ).reshape(1, 10, 10) volume_hdf5_file_path = os.path.join(tmp_path, "myvolume.hdf5") volume_hdf5 = HDF5Volume( file_path=volume_hdf5_file_path, data_path="myentry", data=raw_data, ) volume_edf = EDFVolume( folder=os.path.join(tmp_path, "volume_folder"), ) # test when no histogram existing cast_volume( input_volume=volume_hdf5, output_volume=volume_edf, output_data_type=numpy.dtype(numpy.uint16), rescale_min_percentile=10, rescale_max_percentile=90, save=True, store=True, ) # if percentiles 10 and 90 provided, no data_min and data_max then they will be computed from data min / max # append histogram with h5py.File(volume_hdf5_file_path, mode="a") as h5s: hist = numpy.array([20, 20, 20, 20, 20, 20]) bins = numpy.array([0, 20, 40, 60, 80, 100]) h5s["myentry/histogram/results/data"] = numpy.vstack((hist, bins)) # and test it again volume_edf.overwrite = True cast_volume( input_volume=volume_hdf5, output_volume=volume_edf, output_data_type=numpy.dtype(numpy.uint16), rescale_min_percentile=20, rescale_max_percentile=60, save=True, store=True, ) # test to cast the already cast volumes volume_tif = EDFVolume( folder=os.path.join(tmp_path, "second_volume_folder"), ) volume_tif.overwrite = True cast_volume( input_volume=volume_edf, output_volume=volume_tif, output_data_type=numpy.dtype(numpy.uint8), save=True, store=True, ) assert volume_tif.data.dtype == numpy.uint8 volume_tif.overwrite = False with pytest.raises(OSError): cast_volume( input_volume=volume_edf, output_volume=volume_tif, output_data_type=numpy.dtype(numpy.uint8), save=True, store=True, ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/io/tests/test_detector_distortion.py0000644000175000017500000001427700000000000022435 0ustar00pierrepierreimport numpy as np import scipy.ndimage from nabu.io.detector_distortion import DetectorDistortionBase from 
scipy import sparse from nabu.misc.rotation import Rotation, __have__skimage__ import pytest if __have__skimage__: import skimage @pytest.mark.skipif(not (__have__skimage__), reason="Need scikit-image for rotation") def test_detector_distortion(): image = scipy.ndimage.gaussian_filter(np.random.random([379, 1357]), 3.0) center_xz = ((image.shape[1] - 1) / 2, (image.shape[0] - 1) / 2) part_to_be_retrieved = image[100:279] rotated_image = skimage.transform.rotate(image, angle=5.0, center=center_xz[::]) corrector = DetectorDistortionRotation(detector_full_shape_vh=image.shape, center_xz=center_xz, angle_deg=-5.0) start_x, end_x, start_z, end_z = corrector.set_sub_region_transformation( target_sub_region=( None, None, 100, 279, ) ) source = rotated_image[start_z:end_z, start_x:end_x] retrieved = corrector.transform(source) diff = (retrieved - part_to_be_retrieved)[:, 20:-20] assert abs(diff).std() < 1e-3 class DetectorDistortionRotation(DetectorDistortionBase): """ """ def __init__(self, detector_full_shape_vh=(0, 0), center_xz=(0, 0), angle_deg=0.0): """This is the basis class. A simple identity transformation which has the only merit to show how it works.Reimplement this function to have more parameters for other transformations """ self._build_full_transformation(detector_full_shape_vh, center_xz, angle_deg) def _build_full_transformation(self, detector_full_shape_vh, center_xz, angle_deg): """A simple identity. Reimplement this function to have more parameters for other transformations """ indices = np.indices(detector_full_shape_vh) center_x, center_z = center_xz coordinates = (indices.T - np.array([center_z, center_x])).T c = np.cos(np.deg2rad(angle_deg)) s = np.sin(np.deg2rad(angle_deg)) rot_mat = np.array([[c, s], [-s, c]]) coordinates = np.tensordot(rot_mat, coordinates, axes=[1, 0]) # padding sz, sx = detector_full_shape_vh total_detector_npixs = sz * sx xs = np.clip(np.array(coordinates[1].flat) + center_x, [[0]], [[sx - 1]]) zs = np.clip(np.array(coordinates[0].flat) + center_z, [[0]], [[sz - 1]]) ix0s = np.floor(xs) ix1s = np.ceil(xs) fx = xs - ix0s iz0s = np.floor(zs) iz1s = np.ceil(zs) fz = zs - iz0s I_tmp = np.empty([4 * sz * sx], np.int64) J_tmp = np.empty([4 * sz * sx], np.int64) V_tmp = np.ones([4 * sz * sx], "f") I_tmp[:] = np.arange(sz * sx * 4) // 4 J_tmp[0::4] = iz0s * sx + ix0s J_tmp[1::4] = iz0s * sx + ix1s J_tmp[2::4] = iz1s * sx + ix0s J_tmp[3::4] = iz1s * sx + ix1s V_tmp[0::4] = (1 - fz) * (1 - fx) V_tmp[1::4] = (1 - fz) * fx V_tmp[2::4] = fz * (1 - fx) V_tmp[3::4] = fz * fx self.detector_full_shape_vh = detector_full_shape_vh coo_tmp = sparse.coo_matrix((V_tmp.astype("f"), (I_tmp, J_tmp)), shape=(sz * sx, sz * sx)) csr_tmp = coo_tmp.tocsr() self.full_csr_data = csr_tmp.data self.full_csr_indices = csr_tmp.indices self.full_csr_indptr = csr_tmp.indptr ## This will be used to save time if the same sub_region argument is requested several time in a row self._status = None def _set_sub_region_transformation( self, target_sub_region=( ( None, None, 0, 0, ), ), ): (x_start, x_end, z_start, z_end) = target_sub_region if z_start is None: z_start = 0 if z_end is None: z_end = self.detector_full_shape_vh[0] if (x_start, x_end) not in [(None, None), (0, None), (0, self.detector_full_shape_vh[1])]: message = f""" In the base class DetectorDistortionRotation only vertical slicing is accepted. 
The sub_region contained (x_start, x_end)={(x_start, x_end)} which would slice the full horizontal size which is {self.detector_full_shape_vh[1]} """ raise ValueError() x_start, x_end = 0, self.detector_full_shape_vh[1] row_ptr_start = z_start * self.detector_full_shape_vh[1] row_ptr_end = z_end * self.detector_full_shape_vh[1] indices_start = self.full_csr_indptr[row_ptr_start] indices_end = self.full_csr_indptr[row_ptr_end] data_tmp = self.full_csr_data[indices_start:indices_end] target_offset = self.full_csr_indptr[row_ptr_start] indptr_tmp = self.full_csr_indptr[row_ptr_start : row_ptr_end + 1] - target_offset indices_tmp = self.full_csr_indices[indices_start:indices_end] iz_source = (indices_tmp) // self.detector_full_shape_vh[1] z_start_source = iz_source.min() z_end_source = iz_source.max() + 1 source_offset = z_start_source * self.detector_full_shape_vh[1] indices_tmp = indices_tmp - source_offset target_size = (z_end - z_start) * self.detector_full_shape_vh[1] source_size = (z_end_source - z_start_source) * self.detector_full_shape_vh[1] self.transformation_matrix = sparse.csr_matrix( (data_tmp, indices_tmp, indptr_tmp), shape=(target_size, source_size) ) self.target_shape = ((z_end - z_start), self.detector_full_shape_vh[1]) ## For the identity matrix the source and the target have the same size. ## The two following lines are trivial. ## For this identity transformation only the slicing of the appropriate part ## of the identity sparse matrix is slightly laborious. ## Practical case will be more complicated and source_sub_region ## will be in general larger than the target_sub_region self._status = { "target_sub_region": ((x_start, x_end, z_start, z_end)), "source_sub_region": ((x_start, x_end, z_start_source, z_end_source)), } return self._status["source_sub_region"] ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/io/tests/test_writers.py0000644000175000017500000001602400000000000020035 0ustar00pierrepierrefrom os import path from tempfile import TemporaryDirectory import pytest import numpy as np from tifffile import TiffReader from silx.io.dictdump import h5todict, dicttoh5 from nabu.misc.utils import psnr from nabu.io.writer import NXProcessWriter, TIFFWriter, JP2Writer, __have_jp2k__ from nabu.io.reader import import_h5_to_dict from nabu.testutils import get_data if __have_jp2k__: from glymur import Jp2k @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls._tmpdir = TemporaryDirectory(prefix="nabu_") cls.tempdir = cls._tmpdir.name cls.sino_data = get_data("mri_sino500.npz")["data"].astype(np.uint16) cls.data = cls.sino_data yield # teardown cls._tmpdir.cleanup() @pytest.mark.usefixtures("bootstrap") class TestTiff: def _check_tif_file(self, fname, expected_data, n_expected_images): with TiffReader(fname) as tif: assert len(tif.pages) == n_expected_images for i in range(n_expected_images): data_read = tif.pages[i].asarray() if expected_data.ndim == 3: expected_data_ = expected_data[i] else: expected_data_ = expected_data assert np.allclose(data_read, expected_data_) def test_2D(self): # TODO use start_index=None (by default) in TIFFWriter # if start_index is None: # url = DataUrl(file_path=dirname(self.fname), data_path="basename(self.fname).tiff", scheme="tifffile") # volume = TIFFVolume(volume_basename=basename(self.fname), data_url=url) # volume.data = img # volume.save() pytest.skip("Writing a single 2D tiff is disabled for now") fname = path.join(self.tempdir, "test_tiff2D.tif") 
data = np.arange(100 * 101, dtype="f").reshape((100, 101)) nabu_tif = TIFFWriter(fname) nabu_tif.write(data) self._check_tif_file(fname, data, 1) def test_3D_data_split_in_multiple_files(self): fname = path.join(self.tempdir, "test_tiff3D_single.tif") data = np.arange(11 * 100 * 101, dtype="f").reshape((11, 100, 101)) nabu_tif = TIFFWriter(fname, multiframe=False, start_index=500) nabu_tif.write(data) assert not (path.isfile(fname)), "found %s" % fname prefix, ext = path.splitext(fname) for i in range(data.shape[0]): curr_rel_fname = prefix + str("_%06d" % (i + nabu_tif.start_index)) + ext curr_fname = path.join(self.tempdir, curr_rel_fname) self._check_tif_file(curr_fname, data[i], 1) def test_3D_data_in_one_file(self): fname = path.join(self.tempdir, "test_tiff3D_multi.tif") data = np.arange(11 * 100 * 101, dtype="f").reshape((11, 100, 101)) nabu_tif = TIFFWriter(fname, multiframe=True) nabu_tif.write(data) assert path.isfile(fname) self._check_tif_file(fname, data, data.shape[0]) @pytest.mark.skipif(not (__have_jp2k__), reason="Need openjpeg2000/glymur for this test") @pytest.mark.usefixtures("bootstrap") class TestJP2: def _check_jp2_file(self, fname, expected_data, expected_psnr=None): data = Jp2k(fname)[:] if expected_psnr is None: assert np.allclose(data, expected_data) else: computed_psnr = psnr(data, expected_data) assert np.abs(computed_psnr - expected_psnr) < 1 def test_2D_lossless(self): data = get_data("mri_sino500.npz")["data"].astype(np.uint16) fname = path.join(self.tempdir, "sino500.jp2") nabu_jp2 = JP2Writer(fname, psnr=[0]) nabu_jp2.write(data) self._check_jp2_file(fname, data) def test_2D_lossy(self): fname = path.join(self.tempdir, "sino500_lossy.jp2") nabu_jp2 = JP2Writer(fname, psnr=[80]) nabu_jp2.write(self.sino_data) self._check_jp2_file(fname, self.sino_data, expected_psnr=80) def test_3D(self): fname = path.join(self.tempdir, "sino500_multi.jp2") n_images = 5 data = np.tile(self.sino_data, (n_images, 1, 1)) for i in range(n_images): data[i] += i nabu_jp2 = JP2Writer(fname, start_index=10) nabu_jp2.write(data) assert not (path.isfile(fname)) prefix, ext = path.splitext(fname) for i in range(data.shape[0]): curr_rel_fname = prefix + str("_%06d" % (i + nabu_jp2.start_index)) + ext curr_fname = path.join(self.tempdir, curr_rel_fname) self._check_jp2_file(curr_fname, data[i]) @pytest.fixture(scope="class") def bootstrap_h5(request): cls = request.cls cls._tmpdir = TemporaryDirectory(prefix="nabu_") cls.tempdir = cls._tmpdir.name cls.data = get_data("mri_sino500.npz")["data"].astype(np.uint16) cls.h5_config = { "key1": "value1", "some_int": 1, "some_float": 1.0, "some_dict": { "numpy_array": np.ones((5, 6), dtype="f"), "key2": "value2", }, } yield # teardown cls._tmpdir.cleanup() @pytest.mark.usefixtures("bootstrap_h5") class TestNXWriter: def test_write_simple(self): fname = path.join(self.tempdir, "sino500.h5") writer = NXProcessWriter(fname, entry="entry0000") writer.write(self.data, "test_write_simple") def test_write_with_config(self): fname = path.join(self.tempdir, "sino500_cfg.h5") writer = NXProcessWriter(fname, entry="entry0000") writer.write(self.data, "test_write_with_config", config=self.h5_config) def test_overwrite(self): fname = path.join(self.tempdir, "sino500_overwrite.h5") writer = NXProcessWriter(fname, entry="entry0000", overwrite=True) writer.write(self.data, "test_overwrite", config=self.h5_config) writer2 = NXProcessWriter(fname, entry="entry0001", overwrite=True) writer2.write(self.data, "test_overwrite", config=self.h5_config) # overwrite 
entry0000 writer3 = NXProcessWriter(fname, entry="entry0000", overwrite=True) new_data = self.data.copy() new_data += 1 new_config = self.h5_config.copy() new_config["key1"] = "modified value" writer3.write(new_data, "test_overwrite", config=new_config) res = import_h5_to_dict(fname, "/") assert "entry0000" in res assert "entry0001" in res assert np.allclose(res["entry0000"]["test_overwrite"]["results"]["data"], self.data + 1) rec_cfg = res["entry0000"]["test_overwrite"]["configuration"] assert rec_cfg["key1"] == "modified value" def test_no_overwrite(self): fname = path.join(self.tempdir, "sino500_no_overwrite.h5") writer = NXProcessWriter(fname, entry="entry0000", overwrite=False) writer.write(self.data, "test_no_overwrite") writer2 = NXProcessWriter(fname, entry="entry0000", overwrite=False) with pytest.raises((RuntimeError, OSError)) as ex: writer2.write(self.data, "test_no_overwrite") message = "Error should have been raised for trying to overwrite, but got the following: %s" % str(ex.value) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/io/tiffwriter_zmm.py0000644000175000017500000000740100000000000017204 0ustar00pierrepierrefrom . import * from .writer import TIFFWriter as StandardTIFFWriter from os import path from tifffile import TiffWriter import numpy as np class TIFFWriter(StandardTIFFWriter): # pylint: disable=E0102 def __init__( self, fname, multiframe=False, start_index=0, heights_above_stage_mm=None, filemode=None, append=False, big_tiff=None, ): """ Tiff writer. Parameters ----------- fname: str Path to the output file name multiframe: bool, optional Whether to write all data in one single file. Default is False. start_index: int, optional When writing a stack of images, each image is written in a dedicated file (unless multiframe is set to True). In this case, the output is a series of files `filename_0000.tif`, `filename_0001.tif`, etc. This parameter is the starting index for file names. This option is ignored when multiframe is True. heights_above_stage_mm: None or a list of heights if this parameters is given, the file names will be indexed with the height filemode: str, optional DEPRECATED. Will be ignored. Please refer to 'append' append: bool, optional Whether to append data to the file rather than overwriting. Default is False. big_tiff: bool, optional Whether to write in "big tiff" format: https://www.awaresystems.be/imaging/tiff/bigtiff.html Default is True when multiframe is True. Note that default "standard" tiff cannot exceed 4 GB. Notes ------ If multiframe is False (default), then each image will be written in a dedicated tiff file. 
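        Examples
        --------
        Minimal illustrative sketch (paths and values are made up, not taken from a real scan):
        with `heights_above_stage_mm` set, one file is written per image and the height is
        encoded in the file name as millimeters, micrometers and nanometers:

        >>> writer = TIFFWriter("/tmp/rec.tif", heights_above_stage_mm=[1.5, -0.25])
        ... writer.write(np.zeros((2, 10, 10), dtype="f"))
        ... # writes /tmp/rec_000001p500000.tif and /tmp/rec_-000000p250000.tif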
""" super().__init__( fname, multiframe=multiframe, start_index=start_index, filemode=filemode, append=append, big_tiff=big_tiff ) self.heights_above_stage_mm = heights_above_stage_mm def _write_tiff(self, data, config=None, filename=None): # TODO metadata filename = filename or self.fname with TiffWriter(filename, bigtiff=self.big_tiff, append=self.append) as tif: tif.write(data) def write(self, data, *args, config=None, **kwargs): # Single image, or multiple image in the same file if self.multiframe: self._write_tiff(data, config=config) # Multiple image, one file per image else: if len(data.shape) == 2: data = np.array([data]) dirname, rel_filename = path.split(self.fname) prefix, ext = path.splitext(rel_filename) for i in range(data.shape[0]): if self.heights_above_stage_mm is None: curr_rel_filename = prefix + str("_%06d" % (self.start_index + i)) + ext else: value_mm = self.heights_above_stage_mm[i] if value_mm < 0: sign = "-" value_mm = -value_mm else: sign = "" part_mm = int(value_mm) rest_um = (value_mm - part_mm) * 1000 part_um = int(rest_um) rest_nm = (rest_um - part_um) * 1000 part_nm = int(rest_nm) curr_rel_filename = prefix + "_{}{:06d}p{:03d}{:03d}".format(sign, part_mm, part_um, part_nm) + ext fname = path.join(dirname, curr_rel_filename) self._write_tiff(data[i], filename=fname, config=None) def get_filename(self): if self.multiframe: return self.fname else: return path.dirname(self.fname) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/io/utils.py0000644000175000017500000002124300000000000015274 0ustar00pierrepierreimport os from typing import Optional import warnings import contextlib import h5py import numpy as np from silx.io.url import DataUrl from tomoscan.volumebase import VolumeBase from tomoscan.esrf import EDFVolume, HDF5Volume, TIFFVolume, JP2KVolume, MultiTIFFVolume from tomoscan.io import HDF5File # This function might be moved elsewhere def get_compacted_dataslices(urls, subsampling=None): """ Regroup urls to get the data more efficiently. Build a structure mapping files indices to information on how to load the data: `{indices_set: data_location}` where `data_location` contains contiguous indices. Parameters ----------- urls: dict Dictionary where the key is an integer and the value is a silx `DataUrl`. subsampling: int, optional Subsampling factor when reading the frames. If an integer `n` is provided, then one frame out of `n` will be read. Returns -------- merged_urls: dict Dictionary with the same keys as the `urls` parameter, and where the values are the corresponding `silx.io.url.DataUrl` with merged data_slice. 
""" subsampling = subsampling or 1 def _convert_to_slice(idx): if np.isscalar(idx): return slice(idx, idx + 1) # otherwise, assume already slice object return idx def is_contiguous_slice(slice1, slice2, step=1): if np.isscalar(slice1): slice1 = slice(slice1, slice1 + step) if np.isscalar(slice2): slice2 = slice(slice2, slice2 + step) return slice2.start == slice1.stop def merge_slices(slice1, slice2, step=1): return slice(slice1.start, slice2.stop, step) if len(urls) == 0: return urls sorted_files_indices = sorted(urls.keys()) idx0 = sorted_files_indices[0] first_url = urls[idx0] merged_indices = [[idx0]] # location = (file_path, data_path, slice) data_location = [[first_url.file_path(), first_url.data_path(), _convert_to_slice(first_url.data_slice())]] pos = 0 curr_fp, curr_dp, curr_slice = data_location[pos] for idx in sorted_files_indices[1:]: url = urls[idx] next_slice = _convert_to_slice(url.data_slice()) if ( (url.file_path() == curr_fp) and (url.data_path() == curr_dp) and is_contiguous_slice(curr_slice, next_slice, step=subsampling) ): merged_indices[pos].append(idx) merged_slices = merge_slices(curr_slice, next_slice, step=subsampling) data_location[pos][-1] = merged_slices curr_slice = merged_slices else: # "jump" pos += 1 merged_indices.append([idx]) data_location.append([url.file_path(), url.data_path(), _convert_to_slice(url.data_slice())]) curr_fp, curr_dp, curr_slice = data_location[pos] # Format result res = {} for ind, dl in zip(merged_indices, data_location): res.update(dict.fromkeys(ind, DataUrl(file_path=dl[0], data_path=dl[1], data_slice=dl[2]))) return res def get_first_hdf5_entry(fname): with HDF5File(fname, "r") as fid: entry = list(fid.keys())[0] return entry def hdf5_entry_exists(fname, entry): with HDF5File(fname, "r") as fid: res = fid.get(entry, None) is not None return res def get_h5_value(fname, h5_path, default_ret=None): with HDF5File(fname, "r") as fid: try: val_ptr = fid[h5_path][()] except KeyError: val_ptr = default_ret return val_ptr def get_h5_str_value(dataset_ptr): """ Get a HDF5 field which can be bytes or str (depending on h5py version !). """ data = dataset_ptr[()] if isinstance(data, str): return data else: return bytes.decode(data) def create_dict_of_indices(images_stack, images_indices): """ From an image stack with the images indices, create a dictionary where each index is the image index, and the value is the corresponding image. Parameters ---------- images_stack: numpy.ndarray A 3D numpy array in the layout (n_images, n_y, n_x) images_indices: array or list of int Array containing the indices of images in the stack Examples -------- Given a simple array stack: >>> images_stack = np.arange(3*4*5).reshape((3,4,5)) ... images_indices = [2, 7, 1] ... create_dict_of_indices(images_stack, images_indices) ... # returns {2: array1, 7: array2, 1: array3} """ if images_stack.ndim != 3: raise ValueError("Expected a 3D array") if len(images_indices) != images_stack.shape[0]: raise ValueError("images_stack must have as many images as the length of images_indices") res = {} for i in range(len(images_indices)): res[images_indices[i]] = images_stack[i] return res def convert_dict_values(dic, val_replacements, bytes_tostring=False): """ Modify a dictionary to be able to export it with silx.io.dicttoh5 """ modified_dic = {} for key, value in dic.items(): if isinstance(key, int): # np.isscalar ? 
key = str(key) if isinstance(value, bytes) and bytes_tostring: value = bytes.decode(value.tostring()) if isinstance(value, dict): value = convert_dict_values(value, val_replacements, bytes_tostring=bytes_tostring) else: if isinstance(value, DataUrl): value = value.path() elif value.__hash__ is not None and value in val_replacements: value = val_replacements[value] modified_dic[key] = value return modified_dic class _BaseReader(contextlib.AbstractContextManager): def __init__(self, url: DataUrl): if not isinstance(url, DataUrl): raise TypeError("url should be an instance of DataUrl") if url.scheme() not in ("silx", "h5py"): raise ValueError("Valid scheme are silx and h5py") if url.data_slice() is not None: raise ValueError("Data slices are not managed. Data path should " "point to a bliss node (h5py.Group)") self._url = url self._file_handler = None def __exit__(self, *exc): return self._file_handler.close() class EntryReader(_BaseReader): """Context manager used to read a bliss node""" def __enter__(self): self._file_handler = HDF5File(filename=self._url.file_path(), mode="r") if self._url.data_path() == "": entry = self._file_handler else: entry = self._file_handler[self._url.data_path()] if not isinstance(entry, h5py.Group): raise ValueError("Data path should point to a bliss node (h5py.Group)") return entry class DatasetReader(_BaseReader): """Context manager used to read a bliss node""" def __enter__(self): self._file_handler = HDF5File(filename=self._url.file_path(), mode="r") entry = self._file_handler[self._url.data_path()] if not isinstance(entry, h5py.Dataset): raise ValueError("Data path ({}) should point to a dataset (h5py.Dataset)".format(self._url.path())) return entry # TODO: require some utils function to deduce type. And insure homogeneity. Might be moved in tomoscan ? def file_format_is_edf(file_format: str): return file_format.lower().lstrip(".") == "edf" def file_format_is_jp2k(file_format: str): return file_format.lower().lstrip(".") in ("jp2k", "jp2") def file_format_is_tiff(file_format: str): return file_format.lower().lstrip(".") in ("tiff", "tif") def file_format_is_hdf5(file_format: str): return file_format.lower().lstrip(".") in ("hdf5", "hdf", "nx", "nexus") def get_output_volume(location: str, file_prefix: Optional[str], file_format: str, multitiff=False) -> VolumeBase: # TODO: see strategy. what if user provide a .nx ... ? 
# this function should be more generic location, extension = os.path.splitext(location) if extension == "": extension = file_format if file_format_is_edf(extension): return EDFVolume(folder=location, volume_basename=file_prefix) elif file_format_is_jp2k(extension): return JP2KVolume(folder=location, volume_basename=file_prefix) elif file_format_is_hdf5(file_format=extension): if extension is None: if file_prefix is None: location = ".".join([location, extension]) else: location = os.path.join(location, ".".join([file_prefix, extension])) return HDF5Volume(file_path=location) elif file_format_is_tiff(extension): if multitiff: return MultiTIFFVolume(file_path=location) else: return TIFFVolume(folder=location, volume_basename=file_prefix) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/io/writer.py0000644000175000017500000006752700000000000015467 0ustar00pierrepierrefrom glob import glob from os import path, getcwd, chdir from posixpath import join as posix_join from datetime import datetime import numpy as np from h5py import VirtualSource, VirtualLayout from silx.io.dictdump import dicttoh5, dicttonx from silx.io.url import DataUrl try: from tomoscan.io import HDF5File except: from h5py import File as HDF5File from tomoscan.esrf import EDFVolume, HDF5Volume, TIFFVolume, MultiTIFFVolume, JP2KVolume, RawVolume from tomoscan.esrf.volume.jp2kvolume import has_glymur as __have_jp2k__ from .. import version as nabu_version from ..utils import merged_shape, deprecation_warning from ..misc.utils import rescale_data from .utils import convert_dict_values def get_datetime(): """ Function used by some writers to indicate the current date. """ return datetime.now().replace(microsecond=0).isoformat() class Writer: """ Base class for all writers. """ def __init__(self, fname): self.fname = fname def get_filename(self): return self.fname class TomoscanNXProcessWriter(Writer): """ A class to write Nexus file with a processing result - using tomoscan.volumes as a backend """ def __init__(self, fname, entry=None, filemode="a", overwrite=False): """ Initialize a NXProcessWriter. Parameters ----------- fname: str Path to the HDF5 file. entry: str, optional Entry in the HDF5 file. 
Default is "entry" """ super().__init__(fname) self._set_entry(entry) self._filemode = filemode # TODO: notify file mode is deprecated self.overwrite = overwrite def _set_entry(self, entry): self.entry = entry or "entry" data_path = posix_join("/", self.entry) self.data_path = data_path def _write_npadday(self, result, volume, nx_info): if result.ndim == 2: result = result.reshape(1, result.shape[0], result.shape[1]) volume.data = result self._update_volume_metadata(volume) volume.save() results_path = posix_join(nx_info["nx_process_path"], "results", nx_info["data_name"]) process_name = nx_info["process_name"] process_info = nx_info["process_info"] process_info.update( { f"{process_name}/results@NX_class": "NXdata", f"{process_name}/results@signal": nx_info["data_name"], } ) if nx_info.get("is_frames_stack", True): process_info.update({f"{process_name}/results@interpretation": "image"}) if nx_info.get("direct_access", False): # prepare the direct access plots process_info.update( { f"{process_name}@default": "results", "@default": f"{process_name}/results", } ) return results_path def _write_dict(self, result, volume, nx_info): self._update_volume_metadata(volume) volume.save_metadata() # if result is a dictionary then we only have some metadata to be saved results_path = posix_join(nx_info["nx_process_path"], "results") proc_result_key = posix_join(nx_info["process_name"], "results") proc_result = convert_dict_values(result, {None: "None"}) process_info = nx_info["process_info"] process_info.update({proc_result_key: proc_result}) return results_path def _write_virtual_layout(self, result, volume, nx_info): # TODO: add test on tomoscan to ensure this use case is handled volume.data = result self._update_volume_metadata(volume) volume.save() results_path = posix_join(nx_info["nx_process_path"], "results", nx_info["data_name"]) return results_path @staticmethod def _update_volume_metadata(volume): if volume.metadata is not None: volume.metadata = convert_dict_values( volume.metadata, {None: "None"}, ) def write( self, result, process_name, processing_index=0, config=None, data_name="data", is_frames_stack=True, direct_access=True, ) -> str: """ Write the result in the current NXProcess group. Parameters ---------- result: numpy.ndarray Array containing the processing result process_name: str Name of the processing processing_index: int Index of the processing (in a pipeline) config: dict, optional Dictionary containing the configuration. 
""" entry_path = self.data_path nx_process_path = "/".join([entry_path, process_name]) if config is not None: config.update({"@NX_class": "NXcollection"}) nabu_process_info = { "@NX_class": "NXentry", f"{process_name}@NX_class": "NXprocess", f"{process_name}/program": "nabu", f"{process_name}/version": nabu_version, f"{process_name}/date": get_datetime(), f"{process_name}/sequence_index": np.int32(processing_index), } # Create HDF5Volume object with initial information volume = HDF5Volume( data_url=DataUrl( file_path=self.fname, data_path=f"{nx_process_path}/results/{data_name}", scheme="silx", ), metadata_url=DataUrl( file_path=self.fname, data_path=f"{nx_process_path}/configuration", ), metadata=config, overwrite=self.overwrite, ) if isinstance(result, dict): write_method = self._write_dict elif isinstance(result, np.ndarray): write_method = self._write_npadday elif isinstance(result, VirtualLayout): write_method = self._write_virtual_layout else: raise TypeError(f"'result' must be a dict, numpy array or h5py.VirtualLayout, not {type(result)}") nx_info = { "process_name": process_name, "nx_process_path": nx_process_path, "process_info": nabu_process_info, "data_name": data_name, "is_frames_stack": is_frames_stack, "direct_access": direct_access, } results_path = write_method(result, volume, nx_info) dicttonx( nabu_process_info, h5file=self.fname, h5path=entry_path, update_mode="replace", mode="a", ) return results_path ################################################################################################### ## Nabu original code for NXProcessWriter - also works for non-3D data, does not depend on tomoscan ################################################################################################### def h5_write_object(h5group, key, value, overwrite=False, default_val=None): existing_val = h5group.get(key, default_val) if existing_val is not default_val: if not overwrite: raise OSError("Unable to create link (name already exists): %s" % h5group.name) else: h5group.pop(key) h5group[key] = value class NXProcessWriter(Writer): """ A class to write Nexus file with a processing result. """ def __init__(self, fname, entry=None, filemode="a", overwrite=False): """ Initialize a NXProcessWriter. Parameters ----------- fname: str Path to the HDF5 file. entry: str, optional Entry in the HDF5 file. Default is "entry" """ super().__init__(fname) self._set_entry(entry) self._filemode = filemode self.overwrite = overwrite def _set_entry(self, entry): self.entry = entry or "entry" data_path = posix_join("/", self.entry) self.data_path = data_path def write( self, result, process_name, processing_index=0, config=None, data_name="data", is_frames_stack=True, direct_access=True, ): """ Write the result in the current NXProcess group. Parameters ---------- result: numpy.ndarray Array containing the processing result process_name: str Name of the processing processing_index: int Index of the processing (in a pipeline) config: dict, optional Dictionary containing the configuration. 
""" swmr = self._filemode == "r" with HDF5File(self.fname, self._filemode, swmr=swmr) as fid: nx_entry = fid.require_group(self.data_path) if "NX_class" not in nx_entry.attrs: nx_entry.attrs["NX_class"] = "NXentry" nx_process = nx_entry.require_group(process_name) nx_process.attrs["NX_class"] = "NXprocess" metadata = { "program": "nabu", "version": nabu_version, "date": get_datetime(), "sequence_index": np.int32(processing_index), } for key, val in metadata.items(): h5_write_object(nx_process, key, val, overwrite=self.overwrite) if config is not None: export_dict_to_h5( config, self.fname, posix_join(nx_process.name, "configuration"), overwrite_data=True, mode="a" ) nx_process["configuration"].attrs["NX_class"] = "NXcollection" if isinstance(result, dict): results_path = posix_join(nx_process.name, "results") export_dict_to_h5(result, self.fname, results_path, overwrite_data=self.overwrite, mode="a") else: nx_data = nx_process.require_group("results") results_path = nx_data.name nx_data.attrs["NX_class"] = "NXdata" nx_data.attrs["signal"] = data_name results_data_path = posix_join(results_path, data_name) if self.overwrite and results_data_path in fid: del fid[results_data_path] if isinstance(result, VirtualLayout): nx_data.create_virtual_dataset(data_name, result) else: # assuming array-like nx_data[data_name] = result if is_frames_stack: nx_data[data_name].attrs["interpretation"] = "image" nx_data.attrs["signal"] = data_name # prepare the direct access plots if direct_access: nx_process.attrs["default"] = "results" if "default" not in nx_entry.attrs: nx_entry.attrs["default"] = posix_join(nx_process.name, "results") # Return the internal path to "results" return results_path # COMPAT. LegacyNXProcessWriter = NXProcessWriter # ######################################################################################## ######################################################################################## ######################################################################################## def export_dict_to_h5(dic, h5file, h5path, overwrite_data=True, mode="a"): """ Wrapper on top of silx.io.dictdump.dicttoh5 replacing None with "None" Parameters ----------- dic: dict Dictionary containing the options h5file: str File name h5path: str Path in the HDF5 file overwrite_data: bool, optional Whether to overwrite data when writing HDF5. Default is True mode: str, optional File mode. Default is "a" (append). """ modified_dic = convert_dict_values( dic, {None: "None"}, ) update_mode = {True: "modify", False: "add"}[bool(overwrite_data)] return dicttoh5(modified_dic, h5file=h5file, h5path=h5path, update_mode=update_mode, mode=mode) def create_virtual_layout(files_or_pattern, h5_path, base_dir=None, axis=0): """ Create a HDF5 virtual layout. Parameters ---------- files_or_pattern: str or list A list of file names, or a wildcard pattern. If a list is provided, it will not be sorted! This will have to be done before calling this function. h5_path: str Path inside the HDF5 input file(s) base_dir: str, optional Base directory when using relative file names. axis: int, optional Data axis to merge. Default is 0. 
""" prev_cwd = None if base_dir is not None: prev_cwd = getcwd() chdir(base_dir) if isinstance(files_or_pattern, str): files_list = glob(files_or_pattern) files_list.sort() else: # list files_list = files_or_pattern if files_list == []: raise ValueError("Nothing found as pattern %s" % files_or_pattern) virtual_sources = [] shapes = [] for fname in files_list: with HDF5File(fname, "r", swmr=True) as fid: shape = fid[h5_path].shape vsource = VirtualSource(fname, name=h5_path, shape=shape) virtual_sources.append(vsource) shapes.append(shape) total_shape = merged_shape(shapes, axis=axis) virtual_layout = VirtualLayout(shape=total_shape, dtype="f") start_idx = 0 for vsource, shape in zip(virtual_sources, shapes): n_imgs = shape[axis] # Perhaps there is more elegant if axis == 0: virtual_layout[start_idx : start_idx + n_imgs] = vsource elif axis == 1: virtual_layout[:, start_idx : start_idx + n_imgs, :] = vsource elif axis == 2: virtual_layout[:, :, start_idx : start_idx + n_imgs] = vsource else: raise ValueError("Only axis 0,1,2 are supported") # start_idx += n_imgs if base_dir is not None: chdir(prev_cwd) return virtual_layout def merge_hdf5_files( files_or_pattern, h5_path, output_file, process_name, output_entry=None, output_filemode="a", data_name="data", processing_index=0, config=None, base_dir=None, axis=0, overwrite=False, ): """ Parameters ----------- files_or_pattern: str or list A list of file names, or a wildcard pattern. If a list is provided, it will not be sorted! This will have to be done before calling this function. h5_path: str Path inside the HDF5 input file(s) output_file: str Path of the output file process_name: str Name of the process output_entry: str, optional Output HDF5 root entry (default is "/entry") output_filemode: str, optional File mode for output file. Default is "a" (append) processing_index: int, optional Processing index for the output file. Default is 0. config: dict, optional Dictionary describing the configuration needed to get the results. base_dir: str, optional Base directory when using relative file names. axis: int, optional Data axis to merge. Default is 0. overwrite: bool, optional Whether to overwrite already existing data in the final file. Default is False. """ if base_dir is not None: prev_cwd = getcwd() virtual_layout = create_virtual_layout(files_or_pattern, h5_path, base_dir=base_dir, axis=axis) nx_file = NXProcessWriter(output_file, entry=output_entry, filemode=output_filemode, overwrite=overwrite) nx_file.write( virtual_layout, process_name, processing_index=processing_index, config=config, data_name=data_name, is_frames_stack=True, ) if base_dir is not None and prev_cwd != getcwd(): chdir(prev_cwd) class TIFFWriter(Writer): def __init__(self, fname, multiframe=False, start_index=0, filemode=None, append=False, big_tiff=None): """ Tiff writer. Parameters ----------- fname: str Path to the output file name multiframe: bool, optional Whether to write all data in one single file. Default is False. start_index: int, optional When writing a stack of images, each image is written in a dedicated file (unless multiframe is set to True). In this case, the output is a series of files `filename_0000.tif`, `filename_0001.tif`, etc. This parameter is the starting index for file names. This option is ignored when multiframe is True. filemode: str, optional DEPRECATED. Will be ignored. Please refer to 'append' append: bool, optional Whether to append data to the file rather than overwriting. Default is False. 
big_tiff: bool, optional Whether to write in "big tiff" format: https://www.awaresystems.be/imaging/tiff/bigtiff.html Default is True when multiframe is True. Note that default "standard" tiff cannot exceed 4 GB. Notes ------ If multiframe is False (default), then each image will be written in a dedicated tiff file. """ super().__init__(fname) self.multiframe = multiframe self.start_index = start_index self.append = append if big_tiff is None: big_tiff = multiframe if multiframe and not big_tiff: # raise error ? print("big_tiff was set to False while multiframe was set to True. This will probably be problematic.") self.big_tiff = big_tiff # Compat. self.filemode = filemode if filemode is not None: deprecation_warning("Ignored parameter 'filemode'. Please use the 'append' parameter") def write(self, data, *args, config=None, **kwargs): ext = None if not isinstance(data, np.ndarray): raise TypeError(f"data is expected to be a numpy array and not {type(data)}") # Single image, or multiple image in the same file if self.multiframe: volume = MultiTIFFVolume( self.fname, data=data, metadata={ "config": config, }, append=self.append, ) file_path = self.fname # Multiple image, one file per image else: if data.ndim == 2: data = data.reshape(1, data.shape[0], data.shape[1]) file_path, ext = path.splitext(self.fname) volume = TIFFVolume( path.dirname(file_path), volume_basename=path.basename(file_path), data=data, metadata={ "config": config, }, start_index=self.start_index, data_extension=ext.lstrip("."), overwrite=True, ) volume.save() class EDFWriter(Writer): def __init__(self, fname, start_index=0, filemode="w"): """ EDF (ESRF Data Format) writer. Parameters ----------- fname: str Path to the output file name start_index: int, optional When writing a stack of images, each image is written in a dedicated file In this case, the output is a series of files `filename_0000.tif`, `filename_0001.edf`, etc. This parameter is the starting index for file names. """ super().__init__(fname) self.filemode = filemode self.start_index = start_index def write(self, data, *args, config=None, **kwargs): if not isinstance(data, np.ndarray): raise TypeError(f"data is expected to be a numpy array and not {type(data)}") header = { "software": "nabu", "data": get_datetime(), } if data.ndim == 2: data = data.reshape(1, data.shape[0], data.shape[1]) volume = EDFVolume(path.dirname(self.fname), data=data, start_index=self.start_index, header=header) volume.save() class JP2Writer(Writer): def __init__( self, fname, start_index=0, filemode="wb", psnr=None, cratios=None, auto_convert=True, float_clip_values=None, n_threads=None, overwrite=False, single_file=True, ): """ JPEG2000 writer. This class requires the python package `glymur` and the library `libopenjp2`. Parameters ----------- fname: str Path to the output file name start_index: int, optional When writing a stack of images, each image is written in a dedicated file The output is a series of files `filename_0000.tif`, `filename_0001.tif`, etc. This parameter is the starting index for file names. psnr: list of int, optional The PSNR (Peak Signal-to-Noise ratio) for each jpeg2000 layer. This defines a quality metric for lossy compression. The number "0" stands for lossless compression. cratios: list of int, optional Compression ratio for each jpeg2000 layer auto_convert: bool, optional Whether to automatically cast floating point data to uint16. Default is True. 
float_clip_values: tuple of floats, optional If set to a tuple of two values (min, max), then each image values will be clipped to these minimum and maximum values. n_threads: int, optional Number of threads to use for encoding. Default is the number of available threads. Needs libopenjpeg >= 2.4.0. """ super().__init__(fname) if not (__have_jp2k__): raise ValueError("Need glymur python package and libopenjp2 library") self.n_threads = n_threads # self.setup_multithread_encoding(n_threads=n_threads, what_if_not_available="ignore") self.filemode = filemode self.start_index = start_index self.single_file = single_file self.auto_convert = auto_convert if psnr is not None and np.isscalar(psnr): psnr = [psnr] self.psnr = psnr self.cratios = cratios self._vmin = None self._vmax = None self.overwrite = overwrite self.clip_float = False if float_clip_values is not None: self._float_clip_min, self._float_clip_max = float_clip_values self.clip_float = True def write(self, data, *args, **kwargs): if not isinstance(data, np.ndarray): raise TypeError(f"data is expected to be a numpy array and not {type(data)}") if data.ndim == 2: data = data.reshape(1, data.shape[0], data.shape[1]) if self.single_file and data.ndim == 3 and data.shape[0] == 1: # case we will have a single file as output data_url = DataUrl( file_path=path.dirname(self.fname), data_path=self.fname, scheme=JP2KVolume.DEFAULT_DATA_SCHEME, ) metadata_url = DataUrl( file_path=path.dirname(self.fname), data_path=f"{path.dirname(self.fname)}/{path.basename(self.fname)}_info.txt", scheme=JP2KVolume.DEFAULT_METADATA_SCHEME, ) volume_basename = None folder = None extension = None else: # case we need to save it as set of file file_path, ext = path.splitext(self.fname) data_url = None metadata_url = None volume_basename = path.basename(file_path) folder = path.dirname(self.fname) extension = ext.lstrip(".") volume = JP2KVolume( folder=folder, start_index=self.start_index, cratios=self.cratios, psnr=self.psnr, n_threads=self.n_threads, volume_basename=volume_basename, data_url=data_url, metadata_url=metadata_url, data_extension=extension, overwrite=self.overwrite, ) if data.dtype != np.uint16 and self.auto_convert: if self.clip_float: data = np.clip(data, self._float_clip_min, self._float_clip_max) data = rescale_data(data, 0, 65535, data_min=self._vmin, data_max=self._vmax) data = data.astype(np.uint16) volume.data = data config = kwargs.get("config", None) if config is not None: volume.metadata = {"config": config} volume.save() class NPYWriter(Writer): def write(self, result, *args, **kwargs): np.save(self.fname, result) class NPZWriter(Writer): def write(self, result, *args, **kwargs): save_args = {"result": result} config = kwargs.get("config", None) if config is not None: save_args["configuration"] = config np.savez(self.fname, **save_args) class HSTVolWriter(Writer): """ A writer to mimic PyHST2 ".vol" files """ def __init__(self, fname, append=False, **kwargs): super().__init__(fname) self.append = append self._vol_writer = RawVolume(fname, overwrite=True, append=append) @staticmethod def generate_metadata(data, **kwargs): n_z, n_y, n_x = data.shape metadata = { "NUM_X": n_x, "NUM_Y": n_y, "NUM_Z": n_z, "voxelSize": kwargs.get("voxelSize", 40.0), "BYTEORDER": "LOWBYTEFIRST", "ValMin": kwargs.get("ValMin", 0.0), "ValMax": kwargs.get("ValMin", 1.0), "s1": 0.0, "s2": 0.0, "S1": 0.0, "S2": 0.0, } return metadata @staticmethod def sanitize_metadata(metadata): # To be fixed in RawVolume for what in ["NUM_X", "NUM_Y", "NUM_Z"]: metadata[what] = 
int(metadata[what]) for what in ["voxelSize", "ValMin", "ValMax", "s1", "s2", "S1", "S2"]: metadata[what] = float(metadata[what]) def write(self, data, *args, config=None, **kwargs): existing_metadata = self._vol_writer.load_metadata() new_metadata = self.generate_metadata(data) if len(existing_metadata) == 0 or not (self.append): # first write or append==False metadata = new_metadata else: # append write ; update metadata metadata = existing_metadata.copy() self.sanitize_metadata(metadata) metadata["NUM_Z"] += new_metadata["NUM_Z"] self._vol_writer.data = data self._vol_writer.metadata = metadata self._vol_writer.save() # Also save .xml self._vol_writer.save_metadata( url=DataUrl( scheme="lxml", file_path=self._vol_writer.metadata_url.file_path().replace(".info", ".xml"), ) ) class HSTVolVolume(HSTVolWriter): """ An interface to HSTVolWriter with the same API than tomoscan.esrf.volume. This is really not ideal, see nabu:#381 """ def __init__(self, **kwargs): file_path = kwargs.pop("file_path", None) if file_path is None: raise ValueError("Missing mandatory 'file_path' parameter") super().__init__(file_path, append=kwargs.pop("append", False), **kwargs) self.data = None self.metadata = None def save(self): if self.data is None: raise ValueError("Must set data first") self.write(self.data) def save_metadata(self): pass # already done for HST part - proper metadata is not supported # Unused - kept for compat. Writers = { "h5": NXProcessWriter, "hdf5": NXProcessWriter, "nx": NXProcessWriter, "nexus": NXProcessWriter, "npy": NPYWriter, "npz": NPZWriter, "tif": TIFFWriter, "tiff": TIFFWriter, "j2k": JP2Writer, "jp2": JP2Writer, "jp2k": JP2Writer, "edf": EDFWriter, "vol": HSTVolWriter, } ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4607332 nabu-2023.1.1/nabu/misc/0000755000175000017500000000000000000000000014104 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1581878491.0 nabu-2023.1.1/nabu/misc/__init__.py0000644000175000017500000000000000000000000016203 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/misc/binning.py0000644000175000017500000000555000000000000016107 0ustar00pierrepierreimport numpy as np from ..utils import deprecated def binning(img, bin_factor, out_dtype=np.float32): """ Bin an image by a factor of "bin_factor". Parameters ---------- bin_factor: tuple of int Binning factor in each axis. out_dtype: dtype, optional Output data type. Default is float32. Notes ----- If the image original size is not a multiple of the binning factor, the last items (in the considered axis) will be dropped. The resulting shape is (img.shape[0] // bin_factor[0], img.shape[1] // bin_factor[1]) """ s = img.shape n0, n1 = bin_factor shp = (s[0] - (s[0] % n0), s[1] - (s[1] % n1)) sub_img = img[: shp[0], : shp[1]] out_shp = (shp[0] // n0, shp[1] // n1) res = np.zeros(out_shp, dtype=out_dtype) for i in range(n0): for j in range(n1): res[:] += sub_img[i::n0, j::n1] res /= n0 * n1 return res def binning_n_alt(img, bin_factor, out_dtype=np.float32): """ Alternate, "clever" but slower implementation """ n0, n1 = bin_factor new_shape = tuple(s - (s % n) for s, n in zip(img.shape, bin_factor)) sub_img = img[: new_shape[0], : new_shape[1]] img_view_4d = sub_img.reshape((new_shape[0] // n0, n0, new_shape[1] // n1, n1)) return img_view_4d.astype(out_dtype).mean(axis=1).mean(axis=-1) # # COMPAT. 
# @deprecated("Please use binning()", do_print=True) def binning2(img, out_dtype=np.float32): return binning(img, (2, 2), out_dtype=out_dtype) @deprecated("Please use binning()", do_print=True) def binning2_horiz(img, out_dtype=np.float32): return binning(img, (1, 2), out_dtype=out_dtype) @deprecated("Please use binning()", do_print=True) def binning2_vertic(img, out_dtype=np.float32): return binning(img, (2, 1), out_dtype=out_dtype) @deprecated("Please use binning()", do_print=True) def binning3(img, out_dtype=np.float32): return binning(img, (3, 3), out_dtype=out_dtype) @deprecated("Please use binning()", do_print=True) def binning3_horiz(img, out_dtype=np.float32): return binning(img, (1, 3), out_dtype=out_dtype) @deprecated("Please use binning()", do_print=True) def binning3_vertic(img, out_dtype=np.float32): return binning(img, (3, 1), out_dtype=out_dtype) @deprecated("Please use binning()", do_print=True) def get_binning_function(binning_factor): """ Determine the binning function to use. """ binning_functions = { (2, 2): binning2, (2, 1): binning2_vertic, (1, 2): binning2_horiz, (3, 3): binning3, (3, 1): binning3_vertic, (1, 3): binning3_horiz, (2, 3): None, # was a limitation (3, 2): None, # was a limitation } if binning_factor not in binning_functions: raise ValueError("Could not get a function for binning factor %s" % binning_factor) return binning_functions[binning_factor] ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/misc/filters.py0000644000175000017500000000113400000000000016125 0ustar00pierrepierreimport numpy as np import scipy.signal def correct_spikes(image, threshold): """ Perform a conditional median filtering The filtering is done in-place, meaning that the array content is modified. Parameters ---------- image: numpy.ndarray Image to filter threshold: float Median filter threshold """ m_im = scipy.signal.medfilt2d(image) fixed_part = np.array(image[[0, 0, -1, -1], [0, -1, 0, -1]]) where = abs(image - m_im) > threshold image[where] = m_im[where] image[[0, 0, -1, -1], [0, -1, 0, -1]] = fixed_part return image ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/misc/fourier_filters.py0000644000175000017500000001307300000000000017665 0ustar00pierrepierre# -*- coding: utf-8 -*- """ Fourier filters. """ from functools import lru_cache import numpy as np import scipy.special as spspe @lru_cache(maxsize=10) def get_lowpass_filter(img_shape, cutoff_par=None, use_rfft=False, data_type=np.float64): """Computes a low pass filter using the erfc function. Parameters ---------- img_shape: tuple Shape of the image cutoff_par: float or sequence of two floats Position of the cut off in pixels, if a sequence is given the second float expresses the width of the transition region which is given as a fraction of the cutoff frequency. When only one float is given for this argument a gaussian is applied whose sigma is the parameter. When a sequence of two numbers is given then the filter is 1 ( no filtering) till the cutoff frequency while a smooth erfc transition to zero is done use_rfft: boolean, optional Creates a filter to be used with the result of a rfft type of Fourier transform. Defaults to False. data_type: `numpy.dtype`, optional Specifies the data type of the computed filter. 
It defaults to `numpy.float64` Raises ------ ValueError In case of malformed cutoff_par Returns ------- numpy.array_like The computed filter """ if cutoff_par is None: return 1 elif isinstance(cutoff_par, (int, float)): cutoff_pix = cutoff_par cutoff_trans_fact = None else: try: cutoff_pix, cutoff_trans_fact = cutoff_par except ValueError: raise ValueError( "Argument cutoff_par (which specifies the pass filter shape) must be either a scalar or a" " sequence of two scalars" ) if (not isinstance(cutoff_pix, (int, float))) or (not isinstance(cutoff_trans_fact, (int, float))): raise ValueError( "Argument cutoff_par (which specifies the pass filter shape) must be one number or a sequence" "of two numbers" ) coords = [np.fft.fftfreq(s, 1) for s in img_shape] coords = np.meshgrid(*coords, indexing="ij") r = np.sqrt(np.sum(np.array(coords, dtype=data_type) ** 2, axis=0)) if cutoff_trans_fact is not None: k_cut = 0.5 / cutoff_pix k_cut_width = k_cut * cutoff_trans_fact k_pos_rescaled = (r - k_cut) / k_cut_width res = spspe.erfc(k_pos_rescaled) / 2 else: res = np.exp(-(np.pi**2) * (r**2) * (cutoff_pix**2) * 2) # Making sure to force result to chosen data type res = res.astype(data_type) if use_rfft: slicelist = [slice(None)] * (len(res.shape) - 1) + [slice(0, res.shape[-1] // 2 + 1)] return res[tuple(slicelist)] else: return res def get_highpass_filter(img_shape, cutoff_par=None, use_rfft=False, data_type=np.float64): """Computes a high pass filter using the erfc function. Parameters ---------- img_shape: tuple Shape of the image cutoff_par: float or sequence of two floats Position of the cut off in pixels, if a sequence is given the second float expresses the width of the transition region which is given as a fraction of the cutoff frequency. When only one float is given for this argument a gaussian is applied whose sigma is the parameter, and the result is subtracted from 1 to obtain the high pass filter When a sequence of two numbers is given then the filter is 1 ( no filtering) above the cutoff frequency and then a smooth transition to zero is done for smaller frequency use_rfft: boolean, optional Creates a filter to be used with the result of a rfft type of Fourier transform. Defaults to False. data_type: `numpy.dtype`, optional Specifies the data type of the computed filter. It defaults to `numpy.float64` Raises ------ ValueError In case of malformed cutoff_par Returns ------- numpy.array_like The computed filter """ if cutoff_par is None: return 1 else: return 1 - get_lowpass_filter(img_shape, cutoff_par, use_rfft=use_rfft, data_type=data_type) def get_bandpass_filter(img_shape, cutoff_lowpass=None, cutoff_highpass=None, use_rfft=False, data_type=np.float64): """Computes a band pass filter using the erfc function. The cutoff structures should be formed as follows: - tuple of two floats: the first indicates the cutoff frequency, the second \ determines the width of the transition region, as fraction of the cutoff frequency. - one float -> it represents the sigma of a gaussian which acts as a filter or anti-filter (1 - filter). Parameters ---------- img_shape: tuple Shape of the image cutoff_lowpass: float or sequence of two floats Cutoff parameters for the low-pass filter cutoff_highpass: float or sequence of two floats Cutoff parameters for the high-pass filter use_rfft: boolean, optional Creates a filter to be used with the result of a rfft type of Fourier transform. Defaults to False. data_type: `numpy.dtype`, optional Specifies the data type of the computed filter. 
It defaults to `numpy.float64` Raises ------ ValueError In case of malformed cutoff_par Returns ------- numpy.array_like The computed filter """ return get_lowpass_filter( img_shape, cutoff_par=cutoff_lowpass, use_rfft=use_rfft, data_type=data_type ) * get_highpass_filter(img_shape, cutoff_par=cutoff_highpass, use_rfft=use_rfft, data_type=data_type) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/misc/histogram.py0000644000175000017500000002606400000000000016463 0ustar00pierrepierrefrom math import log2, ceil import numpy as np from silx.math import Histogramnd from tomoscan.io import HDF5File from ..utils import check_supported from ..resources.logger import LoggerOrPrint class PartialHistogram: """ A class for computing histogram progressively. In certain cases, it is cumbersome to compute a histogram directly on a big chunk of data (ex. data not fitting in memory, disk access too slow) while some parts of the data are readily available in-memory. """ histogram_methods = ["fixed_bins_width", "fixed_bins_number"] bin_width_policies = ["uint16"] backends = ["numpy", "silx"] def __init__(self, method="fixed_bins_width", bin_width="uint16", num_bins=None, min_bins=None, backend="silx"): """ Initialize a PartialHistogram class. Parameters ---------- method: str, optional Partial histogram computing method. Available are: - `fixed_bins_width`: all the histograms are computed with the same bin width. The class adapts to the data range and computes the number of bins accordingly. - `fixed_bins_number`: all the histograms are computed with the same number of bins. The class adapts to the data range and computes the bin width accordingly. Default is "fixed_bins_width" bin_width: str or float, optional Policy for histogram bins when method="fixed_bins_width". Available are: - "uint16": The bin width is computed so that floating-point elements `f1` and `f2` satisfying `|f1 - f2| < bin_width` implies `f1_converted - f2_converted < 1` once cast to uint16. - A number: all the bins have this fixed width. Default is "uint16" num_bins: int, optional Number of bins when method = 'fixed_bins_number'. min_bins: int, optional Minimum number of bins when method = 'fixed_bins_width'. backend: str, optional Which histogram backend to use for computations. Available are "silx", "numpy". Fastest is "silx". 
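        Examples
        --------
        Minimal sketch (`chunk1` and `chunk2` stand for any two numpy arrays holding
        different parts of the same volume): compute one partial histogram per chunk,
        then merge them into a single histogram.

        >>> histogrammer = PartialHistogram(method="fixed_bins_number", num_bins=1000)
        ... h1 = histogrammer.compute_histogram(chunk1)
        ... h2 = histogrammer.compute_histogram(chunk2)
        ... hist, bin_edges = histogrammer.merge_histograms([h1, h2])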
""" check_supported(method, self.histogram_methods, "histogram computing method") self.method = method check_supported(backend, self.backends, "histogram backend") self.backend = backend self._set_bin_width(bin_width) self._set_num_bins(num_bins) self.min_bins = min_bins self._set_histogram_methods() def _set_bin_width(self, bin_width): if self.method == "fixed_bins_number": self.bin_width = None return if isinstance(bin_width, str): check_supported(bin_width, self.bin_width_policies, "bin width policy") self._fixed_bw = False else: bin_width = float(bin_width) self._fixed_bw = True self.bin_width = bin_width def _set_num_bins(self, num_bins): if self.method == "fixed_bins_width": self.num_bins = None return if self.method == "fixed_bins_number" and num_bins is None: raise ValueError("Need to specify num_bins for method='fixed_bins_number'") self.num_bins = int(num_bins) def _set_histogram_methods(self): self._histogram_methods = { "fixed_bins_number": { "compute": self._compute_histogram_fixed_nbins, "merge": self._merge_histograms_fixed_nbins, }, "fixed_bins_width": { "compute": self._compute_histogram_fixed_bw, "merge": self._merge_histograms_fixed_bw, }, } assert set(self._histogram_methods.keys()) == set(self.histogram_methods) @staticmethod def _get_histograms_and_bins(histograms, center=False, dont_truncate_bins=False): histos = [h[0] for h in histograms] if dont_truncate_bins: bins = [h[1] for h in histograms] else: if center: bins = [0.5 * (h[1][1:] + h[1][:-1]) for h in histograms] else: bins = [h[1][:-1] for h in histograms] return histos, bins # # Histogram with fixed number of bins # def _compute_histogram_fixed_nbins(self, data, data_range=None): if data.ndim > 1: data = data.ravel() dmin, dmax = data.min(), data.max() if data_range is None else data_range if self.backend == "numpy": res = np.histogram(data, bins=self.num_bins) elif self.backend == "silx": histogrammer = Histogramnd(data, n_bins=self.num_bins, histo_range=(dmin, dmax), last_bin_closed=True) res = histogrammer.histo, histogrammer.edges[0] # pylint: disable=E1136 return res def _merge_histograms_fixed_nbins(self, histograms, dont_truncate_bins=False): histos, bins = self._get_histograms_and_bins(histograms, dont_truncate_bins=dont_truncate_bins) res = np.histogram( np.hstack(bins), weights=np.hstack(histos), bins=self.num_bins, ) return res # # Histogram with fixed bin width # def _bin_width_u16(self, dmin, dmax): return (dmax - dmin) / 65535.0 def _bin_width_fixed(self, dmin, dmax): return self.bin_width def get_bin_width(self, dmin, dmax): if self._fixed_bw: return self._bin_width_fixed(dmin, dmax) elif self.bin_width == "uint16": return self._bin_width_u16(dmin, dmax) else: raise ValueError() def _compute_histogram_fixed_bw(self, data, data_range=None): dmin, dmax = data.min(), data.max() if data_range is None else data_range min_bins = self.min_bins or 1 bw_max = self.get_bin_width(dmin, dmax) nbins = 0 bw_factor = 1 while nbins < min_bins: bw = 2 ** round(log2(bw_max)) / bw_factor nbins = int((dmax - dmin) / bw) bw_factor *= 2 res = np.histogram(data, bins=nbins) return res def _merge_histograms_fixed_bw(self, histograms, **kwargs): histos, bins = self._get_histograms_and_bins(histograms, center=False) dmax = max([b[-1] for b in bins]) dmin = min([b[0] for b in bins]) bw_max = max([b[1] - b[0] for b in bins]) res = np.histogram(np.hstack(bins), weights=np.hstack(histos), bins=int((dmax - dmin) / bw_max)) return res # # Dispatch methods # def compute_histogram(self, data, data_range=None): compute_hist_func 
= self._histogram_methods[self.method]["compute"] return compute_hist_func(data, data_range=data_range) def merge_histograms(self, histograms, **kwargs): merge_hist_func = self._histogram_methods[self.method]["merge"] return merge_hist_func(histograms, **kwargs) class VolumeHistogram: """ A class for computing the histogram of an entire volume. Unless explicitly specified, histogram is computed in several passes so that not all the volume is loaded in memory. """ def __init__(self, data_url, chunk_size_slices=100, chunk_size_GB=None, nbins=1e6, logger=None): """ Initialize a VolumeHistogram object. Parameters ---------- fname: DataUrl DataUrl to the HDF5 file. chunk_size_slices: int, optional Compute partial histograms of groups of slices. This is the default behavior, where the groups size is 100 slices. This parameter is mutually exclusive with 'chunk_size_GB'. chunk_size_GB: float, optional Maximum memory (in GB) to use when computing the histogram by group of slices. This parameter is mutually exclusive with 'chunk_size_slices'. nbins: int, optional Histogram number of bins. Default is 1e6. """ self.data_url = data_url self.logger = LoggerOrPrint(logger) self._get_data_info() self._set_chunk_size(chunk_size_slices, chunk_size_GB) self.nbins = int(nbins) self._init_histogrammer() def _get_data_info(self): self.fname = self.data_url.file_path() self.data_path = self.data_url.data_path() with HDF5File(self.fname, "r") as fid: try: data_ptr = fid[self.data_path] except KeyError: msg = str( "Could not access HDF5 path %s in file %s. Please check that this file \ actually contains a reconstruction and that the HDF5 path is correct" % (self.data_path, self.fname) ) self.logger.fatal(msg) raise ValueError(msg) if data_ptr.ndim != 3: msg = "Expected data to have 3 dimensions, got %d" % data_ptr.ndim raise ValueError(msg) self.data_shape = data_ptr.shape self.data_dtype = data_ptr.dtype self.data_nbytes_GB = np.prod(data_ptr.shape) * data_ptr.dtype.itemsize / 1e9 def _set_chunk_size(self, chunk_size_slices, chunk_size_GB): if not ((chunk_size_slices is not None) ^ (chunk_size_GB is not None)): raise ValueError("Please specify either chunk_size_slices or chunk_size_GB") if chunk_size_slices is None: chunk_size_slices = int(chunk_size_GB / (np.prod(self.data_shape[1:]) * self.data_dtype.itemsize / 1e9)) self.chunk_size = chunk_size_slices self.logger.debug("Computing histograms by groups of %d slices" % self.chunk_size) def _init_histogrammer(self): self.histogrammer = PartialHistogram(method="fixed_bins_number", num_bins=self.nbins) def _compute_histogram(self, data): return self.histogrammer.compute_histogram(data.ravel()) # 1D def compute_volume_histogram(self): n_z = self.data_shape[0] histograms = [] n_steps = ceil(n_z / self.chunk_size) with HDF5File(self.fname, "r") as fid: for chunk_id in range(n_steps): self.logger.debug("Computing histogram %d/%d" % (chunk_id + 1, n_steps)) z_slice = slice(chunk_id * self.chunk_size, (chunk_id + 1) * self.chunk_size) images_stack = fid[self.data_path][z_slice, :, :] hist = self._compute_histogram(images_stack) histograms.append(hist) res = self.histogrammer.merge_histograms(histograms) return res def hist_as_2Darray(hist, center=True, dtype="f"): hist, bins = hist if bins.size != hist.size: # assert bins.size == hist.size +1 if center: bins = 0.5 * (bins[1:] + bins[:-1]) else: bins = bins[:-1] res = np.zeros((2, hist.size), dtype=dtype) res[0] = hist res[1] = bins.astype(dtype) return res def add_last_bin(histo_bins): """ Add the last bin (max value) to 
a list of bin edges. """ res = np.zeros(histo_bins.size + 1, dtype=histo_bins.dtype) res[:-1] = histo_bins[:] res[-1] = res[-2] + (res[1] - res[0]) return res ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/misc/histogram_cuda.py0000644000175000017500000000625300000000000017455 0ustar00pierrepierreimport numpy as np from ..utils import get_cuda_srcfile, updiv from ..cuda.utils import __has_pycuda__ from .histogram import PartialHistogram, VolumeHistogram if __has_pycuda__: import pycuda.gpuarray as garray from ..cuda.kernel import CudaKernel from ..cuda.processing import CudaProcessing class CudaPartialHistogram(PartialHistogram): def __init__( self, method="fixed_bins_number", bin_width="uint16", num_bins=None, min_bins=None, cuda_options=None, ): if method == "fixed_bins_width": raise NotImplementedError("Histogram with fixed bins width is not implemented with the Cuda backend") super().__init__( method=method, bin_width=bin_width, num_bins=num_bins, min_bins=min_bins, ) self.cuda_processing = CudaProcessing(**(cuda_options or {})) self._init_cuda_histogram() def _init_cuda_histogram(self): self.cuda_hist = CudaKernel( "histogram", filename=get_cuda_srcfile("histogram.cu"), signature="PiiiffPi", ) self.d_hist = garray.zeros(self.num_bins, dtype=np.uint32) def _compute_histogram_fixed_nbins(self, data, data_range=None): if isinstance(data, np.ndarray): data = garray.to_gpu(data) if data_range is None: # Should be possible to do both in one single pass with ReductionKernel # and garray.vec.float2, but the last step in volatile shared memory # still gives errors. To be investigated... data_min = garray.min(data).get()[()] data_max = garray.max(data).get()[()] else: data_min, data_max = data_range Nz, Ny, Nx = data.shape block = (16, 16, 4) grid = ( updiv(Nx, block[0]), updiv(Ny, block[1]), updiv(Nz, block[2]), ) self.d_hist.fill(0) self.cuda_hist( data, Nx, Ny, Nz, data_min, data_max, self.d_hist, self.num_bins, grid=grid, block=block, ) # Return a result in the same format as numpy.histogram res_hist = self.d_hist.get() res_bins = np.linspace(data_min, data_max, num=self.num_bins + 1, endpoint=True) return res_hist, res_bins class CudaVolumeHistogram(VolumeHistogram): def __init__( self, data_url, chunk_size_slices=100, chunk_size_GB=None, nbins=1e6, logger=None, cuda_options=None, ): self.cuda_options = cuda_options super().__init__( data_url, chunk_size_slices=chunk_size_slices, chunk_size_GB=chunk_size_GB, nbins=nbins, logger=logger, ) def _init_histogrammer(self): self.histogrammer = CudaPartialHistogram( method="fixed_bins_number", num_bins=self.nbins, cuda_options=self.cuda_options, ) def _compute_histogram(self, data): return self.histogrammer.compute_histogram(data) # 3D ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/misc/padding.py0000644000175000017500000000523100000000000016065 0ustar00pierrepierreimport math import numpy as np def pad_interpolate(im, padded_img_shape_vh, translation_vh=None, padding_mode="reflect"): """ This function produces a centered padded image and , optionally if translation_vh is set, performs a Fourier shift of the whole image. In case of translation, the image is first padded to a larger extent, that encompasses the final padded with plus a translation margin, and then translated, and final recut to the required padded width. 
The values are translated: if a feature appear at x in the original image it will appear at pad+translation+x in the final image. Parameters ------------ im: np.ndaray the input image translation_vh: a sequence of two float the vertical and horizontal shifts """ if translation_vh is not None: pad_extra = 2 ** (1 + np.ceil(np.log2(np.maximum(1, np.ceil(abs(np.array(translation_vh)))))).astype(np.int32)) else: pad_extra = [0, 0] origy, origx = im.shape rety = padded_img_shape_vh[0] + pad_extra[0] retx = padded_img_shape_vh[1] + pad_extra[1] xpad = [0, 0] xpad[0] = math.ceil((retx - origx) / 2) xpad[1] = retx - origx - xpad[0] ypad = [0, 0] ypad[0] = math.ceil((rety - origy) / 2) ypad[1] = rety - origy - ypad[0] y2 = origy - ypad[1] x2 = origx - xpad[1] if ypad[0] + 1 > origy or xpad[0] + 1 > origx or y2 < 1 or x2 < 1: raise ValueError("Too large padding for this reflect padding type") padded_im = np.pad(im, pad_width=((ypad[0], ypad[1]), (xpad[0], xpad[1])), mode=padding_mode) if translation_vh is not None: freqs_list = list(map(np.fft.fftfreq, padded_im.shape)) shifts_list = [np.exp(-2.0j * np.pi * freqs * trans) for (freqs, trans) in zip(freqs_list, translation_vh)] shifts_2D = shifts_list[0][:, None] * shifts_list[1][None, :] padded_im = np.fft.ifft2(np.fft.fft2(padded_im) * shifts_2D).real padded_im = recut(padded_im, padded_img_shape_vh) return padded_im def recut(im, new_shape_vh): """ This method implements a centered cut which reverts the centered padding applied in the present class. Parameters ----------- im: np.ndarray A 2D image. new_shape_vh: tuple The shape of the cutted image. Returns -------- The image cutted to new_shape_vh. """ new_shape_vh = np.array(new_shape_vh) old_shape_vh = np.array(im.shape) center_vh = (old_shape_vh - 1) / 2 start_vh = np.round(0.5 + center_vh - new_shape_vh / 2).astype(np.int32) end_vh = start_vh + new_shape_vh return im[start_vh[0] : end_vh[0], start_vh[1] : end_vh[1]] ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/misc/rotation.py0000644000175000017500000000350000000000000016313 0ustar00pierrepierreimport numpy as np try: from skimage.transform import rotate __have__skimage__ = True except ImportError: __have__skimage__ = False class Rotation: supported_modes = { "constant": "constant", "zeros": "constant", "edge": "edge", "edges": "edge", "symmetric": "symmetric", "sym": "symmetric", "reflect": "reflect", "wrap": "wrap", "periodic": "wrap", } def __init__(self, shape, angle, center=None, mode="edge", reshape=False, **sk_kwargs): """ Initiate a Rotation object. Parameters ---------- shape: tuple of int Shape of the images to process angle: float Rotation angle in DEGREES center: tuple of float, optional Coordinates of the center of rotation, in the format (X, Y) (mind the non-python convention !). Default is ((Nx - 1)/2.0, (Ny - 1)/2.0) mode: str, optional Padding mode. Default is "edge". reshape: bool, optional Other Parameters ----------------- All the other parameters are passed directly to scikit image 'rotate' function: order, cval, clip, preserve_range. 
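        Examples
        --------
        Minimal sketch (`img` stands for any 2D numpy array; scikit-image must be
        installed): rotate an image by 10 degrees around its center with edge padding.

        >>> rotation = Rotation(img.shape, 10.0)
        ... img_rotated = rotation(img)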
""" self.shape = shape self.angle = angle self.center = center self.mode = mode self.reshape = reshape self.sk_kwargs = sk_kwargs def rotate(self, img, output=None): if not __have__skimage__: raise ValueError("scikit-image is needed for using rotate()") res = rotate(img, self.angle, resize=self.reshape, center=self.center, mode=self.mode, **self.sk_kwargs) if output is not None: output[:] = res[:] return output else: return res __call__ = rotate ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/misc/rotation_cuda.py0000644000175000017500000000533700000000000017321 0ustar00pierrepierreimport numpy as np from .rotation import Rotation from ..utils import get_cuda_srcfile, updiv from ..cuda.utils import __has_pycuda__, copy_array from ..cuda.processing import CudaProcessing if __has_pycuda__: from ..cuda.kernel import CudaKernel import pycuda.gpuarray as garray import pycuda.driver as cuda class CudaRotation(Rotation): def __init__(self, shape, angle, center=None, mode="edge", reshape=False, cuda_options=None, **sk_kwargs): if center is None: center = ((shape[1] - 1) / 2.0, (shape[0] - 1) / 2.0) super().__init__(shape, angle, center=center, mode=mode, reshape=reshape, **sk_kwargs) self._init_cuda_rotation(cuda_options) def _init_cuda_rotation(self, cuda_options): cuda_options = cuda_options or {} self.cuda_processing = CudaProcessing(**cuda_options) self._allocate_arrays() self._init_rotation_kernel() def _allocate_arrays(self): self._d_image_cua = cuda.np_to_array(np.zeros(self.shape, "f"), "C") self.cuda_processing.init_arrays_to_none(["d_output"]) def _init_rotation_kernel(self): self.cuda_rotation_kernel = CudaKernel("rotate", get_cuda_srcfile("rotation.cu")) self.texref_image = self.cuda_rotation_kernel.module.get_texref("tex_image") self.texref_image.set_filter_mode(cuda.filter_mode.LINEAR) # bilinear self.texref_image.set_address_mode(0, cuda.address_mode.CLAMP) # TODO tune self.texref_image.set_address_mode(1, cuda.address_mode.CLAMP) # TODO tune self.cuda_rotation_kernel.prepare("Piiffff", [self.texref_image]) self.texref_image.set_array(self._d_image_cua) self._cos_theta = np.cos(np.deg2rad(self.angle)) self._sin_theta = np.sin(np.deg2rad(self.angle)) self._Nx = np.int32(self.shape[1]) self._Ny = np.int32(self.shape[0]) self._center_x = np.float32(self.center[0]) self._center_y = np.float32(self.center[1]) self._block = (32, 32, 1) # tune ? 
self._grid = (updiv(self.shape[1], self._block[1]), updiv(self.shape[0], self._block[0]), 1) def rotate(self, img, output=None, do_checks=True): copy_array(self._d_image_cua, img, check=do_checks) if output is not None: d_out = output else: self.cuda_processing.allocate_array("d_output", self.shape, np.float32) d_out = self.cuda_processing.d_output self.cuda_rotation_kernel( d_out, self._Nx, self._Ny, self._cos_theta, self._sin_theta, self._center_x, self._center_y, grid=self._grid, block=self._block, ) return d_out __call__ = rotate ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4607332 nabu-2023.1.1/nabu/misc/tests/0000755000175000017500000000000000000000000015246 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1581878491.0 nabu-2023.1.1/nabu/misc/tests/__init__.py0000644000175000017500000000000100000000000017346 0ustar00pierrepierre ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/misc/tests/test_binning.py0000644000175000017500000000353400000000000020310 0ustar00pierrepierrefrom itertools import product import numpy as np import pytest from nabu.misc.binning import * @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.data = np.arange(100 * 99, dtype=np.uint16).reshape(100, 99) cls.tol = 1e-5 @pytest.mark.usefixtures("bootstrap") class TestBinning: def testBinning(self): """ Test the general-purpose binning function with an image defined by its indices. The test "image" is an array where entry [i, j] equals i * Nx + j (where Nx = array.shape[1]). Let (b_i, b_j) be the binning factor along each dimension, then the binned array at position [p_i, p_j] is equal to 1/(b_i*b_j) * Sum(Sum(Nx*i + j, (j, p_j, p_j+b_j-1)), (i, p_i, p_i+b_i-1)) which happens to be equal to (Nx * b_i + 2*Nx * p_i - Nx + b_j + 2*p_j - 1) /2 """ def get_reference_binned_image(img_shape, bin_factor): # reference[p_i, p_j] = 0.5 * (Nx * b_i + 2*Nx * p_i - Nx + b_j + 2*p_j - 1) Ny, Nx = img_shape img_shape_reduced = tuple(s - (s % b) for s, b in zip(img_shape, bin_factor)) b_i, b_j = bin_factor inds_i, inds_j = np.indices(img_shape_reduced) p_i = inds_i[::b0, ::b1] p_j = inds_j[::b0, ::b1] return 0.5 * (Nx * b_i + 2 * Nx * p_i - Nx + b_j + 2 * p_j - 1) # Various test settings binning_factors = [2, 3, 4, 5, 6, 8, 10] n_items = [63, 64, 65, 66, 125, 128, 130] # Yep, that's 2401 tests... 
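        # itertools.product builds the Cartesian product of image shapes and per-axis
        # binning factors: 7 * 7 * 7 * 7 = 2401 (s0, s1, b0, b1) combinations.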
params = product(n_items, n_items, binning_factors, binning_factors) for s0, s1, b0, b1 in params: img = np.arange(s0 * s1).reshape((s0, s1)) ref = get_reference_binned_image(img.shape, (b0, b1)) res = binning(img, (b0, b1)) assert np.allclose(res, ref) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/misc/tests/test_histogram.py0000644000175000017500000000442000000000000020654 0ustar00pierrepierrefrom os import path from tempfile import mkdtemp import pytest import numpy as np from nabu.testutils import get_data from nabu.misc.histogram import PartialHistogram from nabu.cuda.utils import __has_pycuda__, get_cuda_context if __has_pycuda__: from nabu.misc.histogram_cuda import CudaPartialHistogram import pycuda.gpuarray as garray @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.data = get_data("mri_rec_astra.npz")["data"] cls.data /= 10 cls.data[:100] *= 10 cls.data_part_1 = cls.data[:100] cls.data_part_2 = cls.data[100:] cls.data0 = cls.data.ravel() cls.bin_tol = 1e-5 * (cls.data0.max() - cls.data0.min()) cls.hist_rtol = 1.5e-3 if __has_pycuda__: cls.ctx = get_cuda_context(cleanup_at_exit=False) yield if __has_pycuda__: cls.ctx.pop() @pytest.mark.usefixtures("bootstrap") class TestPartialHistogram: def compare_histograms(self, hist1, hist2): errmax_bins = np.max(np.abs(hist1[1] - hist2[1])) assert errmax_bins < self.bin_tol errmax_hist = np.max(np.abs(hist1[0] - hist2[0]) / hist2[0].max()) assert errmax_hist / hist2[0].max() < self.hist_rtol def test_fixed_nbins(self): partial_hist = PartialHistogram(method="fixed_bins_number", num_bins=1e6) hist1 = partial_hist.compute_histogram(self.data_part_1.ravel()) hist2 = partial_hist.compute_histogram(self.data_part_2.ravel()) hist = partial_hist.merge_histograms([hist1, hist2]) ref = np.histogram(self.data0, bins=partial_hist.num_bins) self.compare_histograms(hist, ref) @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test") def test_fixed_nbins_cuda(self): partial_hist = CudaPartialHistogram(method="fixed_bins_number", num_bins=1e6, cuda_options={"ctx": self.ctx}) data_part_1 = garray.to_gpu(np.tile(self.data_part_1, (1, 1, 1))) data_part_2 = garray.to_gpu(np.tile(self.data_part_2, (1, 1, 1))) hist1 = partial_hist.compute_histogram(data_part_1) hist2 = partial_hist.compute_histogram(data_part_2) hist = partial_hist.merge_histograms([hist1, hist2]) ref = np.histogram(self.data0, bins=partial_hist.num_bins) self.compare_histograms(hist, ref) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/misc/tests/test_interpolation.py0000644000175000017500000000453000000000000021550 0ustar00pierrepierreimport numpy as np import pytest from scipy.misc import ascent from scipy.interpolate import interp1d from nabu.testutils import generate_tests_scenarios from nabu.utils import get_cuda_srcfile, updiv from nabu.cuda.utils import __has_pycuda__, get_cuda_context if __has_pycuda__: import pycuda.gpuarray as garray from nabu.cuda.kernel import CudaKernel img0 = ascent().astype("f") scenarios = generate_tests_scenarios( { "image": [img0, img0[:, :511], img0[:511, :]], "x_bounds": [(180, 360), (0, 180), (50, 50 + 180)], "x_to_x_new": [0.1, -0.2, 0.3], } ) @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.tol = 1e-4 if __has_pycuda__: cls.ctx = get_cuda_context(cleanup_at_exit=False) yield if __has_pycuda__: cls.ctx.pop() @pytest.mark.usefixtures("bootstrap") 
class TestInterpolation: def _get_reference_interpolation(self, img, x, x_new): interpolator = interp1d(x, img, kind="linear", axis=0, fill_value="extrapolate", copy=True) ref = interpolator(x_new) return ref def _compare(self, res, img, x, x_new): ref = self._get_reference_interpolation(img, x, x_new) mae = np.max(np.abs(res - ref)) return mae # parametrize on a class method will use the same class, and launch this # method with different scenarios. @pytest.mark.skipif(not (__has_pycuda__), reason="need pycuda for this test") @pytest.mark.parametrize("config", scenarios) def test_cuda_interpolation(self, config): img = config["image"] Ny, Nx = img.shape xmin, xmax = config["x_bounds"] x = np.linspace(xmin, xmax, num=img.shape[0], endpoint=False, dtype="f") x_new = x + config["x_to_x_new"] d_img = garray.to_gpu(img) d_out = garray.zeros_like(d_img) d_x = garray.to_gpu(x) d_x_new = garray.to_gpu(x_new) cuda_interpolator = CudaKernel( "linear_interp_vertical", get_cuda_srcfile("interpolation.cu"), signature="PPiiPP" ) cuda_interpolator(d_img, d_out, Nx, Ny, d_x, d_x_new, grid=(updiv(Nx, 16), updiv(Ny, 16), 1), block=(16, 16, 1)) err = self._compare(d_out.get(), img, x, x_new) err_msg = str("Max error is too high for this configuration: %s" % str(config)) assert err < self.tol, err_msg ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/misc/tests/test_rotation.py0000644000175000017500000000517400000000000020525 0ustar00pierrepierreimport numpy as np import pytest from nabu.testutils import generate_tests_scenarios from nabu.misc.rotation import Rotation, __have__skimage__ from nabu.cuda.utils import __has_pycuda__, get_cuda_context if __have__skimage__: from skimage.transform import rotate from skimage.data import chelsea ny, nx = chelsea().shape[:2] if __has_pycuda__: from nabu.misc.rotation_cuda import CudaRotation import pycuda.gpuarray as garray if __have__skimage__: scenarios = generate_tests_scenarios( { # ~ "output_is_none": [False, True], "mode": ["edge"], "angle": [5.0, 10.0, 45.0, 57.0, 90.0], "center": [None, ((nx - 1) / 2.0, (ny - 1) / 2.0), ((nx - 1) / 2.0, ny - 1)], } ) else: scenarios = {} @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.image = chelsea().mean(axis=-1, dtype=np.float32) if __has_pycuda__: cls.ctx = get_cuda_context(cleanup_at_exit=False) cls.d_image = garray.to_gpu(cls.image) yield if __has_pycuda__: cls.ctx.pop() @pytest.mark.skipif(not (__have__skimage__), reason="Need scikit-image for rotation") @pytest.mark.usefixtures("bootstrap") class TestRotation: def _get_reference_rotation(self, config): return rotate( self.image, config["angle"], resize=False, center=config["center"], order=1, mode=config["mode"], clip=False, # preserve_range=False, ) def _check_result(self, res, config, tol): ref = self._get_reference_rotation(config) mae = np.max(np.abs(res - ref)) err_msg = str("Max error is too high for this configuration: %s" % str(config)) assert mae < tol, err_msg # parametrize on a class method will use the same class, and launch this # method with different scenarios. 
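    # Each "config" is one dict drawn from 'scenarios' above, i.e. one combination of
    # angle, center and padding mode; the reference result comes from skimage.transform.rotate.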
@pytest.mark.parametrize("config", scenarios) def test_rotation(self, config): R = Rotation(self.image.shape, config["angle"], center=config["center"], mode=config["mode"]) res = R(self.image) self._check_result(res, config, 1e-6) @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda rotation") @pytest.mark.parametrize("config", scenarios) def test_cuda_rotation(self, config): R = CudaRotation( self.image.shape, config["angle"], center=config["center"], mode=config["mode"], cuda_options={"ctx": self.ctx}, ) d_res = R(self.d_image) res = d_res.get() self._check_result(res, config, 0.5) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/misc/tests/test_unsharp.py0000644000175000017500000000552100000000000020342 0ustar00pierrepierreimport numpy as np import pytest from scipy.misc import ascent from nabu.misc.unsharp import UnsharpMask from nabu.misc.unsharp_opencl import OpenclUnsharpMask, __have_opencl__ as __has_pyopencl__ from nabu.cuda.utils import __has_pycuda__, get_cuda_context if __has_pyopencl__: from pyopencl import CommandQueue import pyopencl.array as parray from silx.opencl.common import ocl if __has_pycuda__: import pycuda.gpuarray as garray from nabu.misc.unsharp_cuda import CudaUnsharpMask @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.data = np.ascontiguousarray(ascent()[:, :511], dtype=np.float32) cls.tol = 1e-4 cls.sigma = 1.6 cls.coeff = 0.5 if __has_pycuda__: cls.ctx = get_cuda_context(cleanup_at_exit=False) if __has_pyopencl__: cls.cl_ctx = ocl.create_context() yield if __has_pycuda__: cls.ctx.pop() @pytest.mark.usefixtures("bootstrap") class TestUnsharp: def get_reference_result(self, method, data=None): if data is None: data = self.data unsharp_mask = UnsharpMask(data.shape, self.sigma, self.coeff, method=method) return unsharp_mask.unsharp(data) def check_result(self, result, method, data=None, error_msg_prefix=None): reference = self.get_reference_result(method, data=data) mae = np.max(np.abs(result - reference)) err_msg = str( "%s: max error is too high with method=%s: %.2e > %.2e" % (error_msg_prefix or "", method, mae, self.tol) ) assert mae < self.tol, err_msg @pytest.mark.skipif(not (__has_pyopencl__), reason="Need pyopencl for this test") def testOpenclUnsharp(self): cl_queue = CommandQueue(self.cl_ctx) d_image = parray.to_device(cl_queue, self.data) d_out = parray.zeros_like(d_image) for method in OpenclUnsharpMask.avail_methods: d_image = parray.to_device(cl_queue, self.data) d_out = parray.zeros_like(d_image) opencl_unsharp = OpenclUnsharpMask(self.data.shape, self.sigma, self.coeff, method=method, ctx=self.cl_ctx) opencl_unsharp.unsharp(d_image, output=d_out) res = d_out.get() self.check_result(res, method, error_msg_prefix="OpenclUnsharpMask") @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test") def testCudaUnsharp(self): d_image = garray.to_gpu(self.data) d_out = garray.zeros_like(d_image) for method in CudaUnsharpMask.avail_methods: cuda_unsharp = CudaUnsharpMask( self.data.shape, self.sigma, self.coeff, method=method, cuda_options={"ctx": self.ctx} ) cuda_unsharp.unsharp(d_image, output=d_out) res = d_out.get() self.check_result(res, method, error_msg_prefix="CudaUnsharpMask") ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/misc/unsharp.py0000644000175000017500000000517200000000000016143 0ustar00pierrepierreimport numpy as np from scipy.ndimage import convolve1d 
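# The 2D Gaussian blur below is done with two separable 1D convolutions (convolve1d
# along each axis), using the kernel returned by silx's gaussian_kernel().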
from silx.image.utils import gaussian_kernel class UnsharpMask: """ A helper class for unsharp masking. """ avail_methods = ["gaussian", "log", "imagej"] def __init__(self, shape, sigma, coeff, mode="reflect", method="gaussian"): """ Initialize a Unsharp mask. `UnsharpedImage = (1 + coeff)*Image - coeff * ConvolutedImage` If method == "log": `UnsharpedImage = Image + coeff*ConvolutedImage` Parameters ----------- shape: tuple Shape of the image. sigma: float Standard deviation of the Gaussian kernel coeff: float Coefficient in the linear combination of unsharp mask mode: str, optional Convolution mode. Default is "reflect" method: str, optional Method of unsharp mask. Can be "gaussian" (default) or "log" (Laplacian of Gaussian), or "imagej". Notes ----- The computation is the following depending on the method: - For method="gaussian": output = (1 + coeff) * image - coeff * image_blurred - For method="log": output = image + coeff * image_blurred - For method="imagej": output = (image - coeff*image_blurred)/(1-coeff) """ self.shape = shape self.ndim = len(self.shape) self.sigma = sigma self.coeff = coeff self._set_method(method) self.mode = mode self._compute_gaussian_kernel() def _set_method(self, method): if method not in self.avail_methods: raise ValueError("Unknown unsharp method '%s'. Available are %s" % (method, str(self.avail_methods))) self.method = method def _compute_gaussian_kernel(self): self._gaussian_kernel = np.ascontiguousarray(gaussian_kernel(self.sigma), dtype=np.float32) def _blur2d(self, image): res1 = convolve1d(image, self._gaussian_kernel, axis=1, mode=self.mode) res = convolve1d(res1, self._gaussian_kernel, axis=0, mode=self.mode) return res def unsharp(self, image, output=None): """ Reference unsharp mask implementation. """ image_b = self._blur2d(image) if self.method == "gaussian": res = (1 + self.coeff) * image - self.coeff * image_b elif self.method == "log": res = image + self.coeff * image_b else: # "imagej": res = (image - self.coeff * image_b) / (1 - self.coeff) if output is not None: output[:] = res[:] return output return res ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/misc/unsharp_cuda.py0000644000175000017500000000410700000000000017134 0ustar00pierrepierrefrom ..cuda.utils import __has_pycuda__ from ..cuda.convolution import Convolution from ..cuda.processing import CudaProcessing from .unsharp import UnsharpMask if __has_pycuda__: from pycuda.elementwise import ElementwiseKernel class CudaUnsharpMask(UnsharpMask): def __init__(self, shape, sigma, coeff, mode="reflect", method="gaussian", cuda_options=None): """ Unsharp Mask, cuda backend. """ super().__init__(shape, sigma, coeff, mode=mode, method=method) self.cuda_processing = CudaProcessing(**(cuda_options or {})) self._init_convolution() self._init_mad_kernel() self.cuda_processing.init_arrays_to_none(["_d_out"]) def _init_convolution(self): self.convolution = Convolution( self.shape, self._gaussian_kernel, mode=self.mode, extra_options={ # Use the lowest amount of memory "allocate_input_array": False, "allocate_output_array": False, "allocate_tmp_array": True, }, ) def _init_mad_kernel(self): # garray.GPUArray.mul_add is out of place... 
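        # A single in-place multiply-add kernel serves the three unsharp variants; only the
        # (fac, otherfac) pair changes in unsharp():
        #   gaussian: out = (1 + coeff) * image - coeff * blurred
        #   log:      out = image + coeff * blurred
        #   imagej:   out = (image - coeff * blurred) / (1 - coeff)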
self.mad_kernel = ElementwiseKernel( "float* array, float fac, float* other, float otherfac", "array[i] = fac * array[i] + otherfac * other[i]", name="mul_add", ) def unsharp(self, image, output=None): if output is None: output = self.cuda_processing.allocate_array("_d_out", self.shape, "f") self.convolution(image, output=output) if self.method == "gaussian": self.mad_kernel(output, -self.coeff, image, 1.0 + self.coeff) elif self.method == "log": # output = output * coeff + image where output was image_blurred self.mad_kernel(output, self.coeff, image, 1.0) else: # "imagej": # output = (image - coeff*image_blurred)/(1-coeff) where output was image_blurred self.mad_kernel(output, -self.coeff / (1 - self.coeff), image, 1.0 / (1 - self.coeff)) return output ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/misc/unsharp_opencl.py0000644000175000017500000000550400000000000017502 0ustar00pierrepierreimport numpy as np try: from silx.opencl.processing import EventDescription, OpenclProcessing from silx.opencl.convolution import Convolution as CLConvolution import pyopencl.array as parray from pyopencl.elementwise import ElementwiseKernel __have_opencl__ = True except ImportError: __have_opencl__ = False from .unsharp import UnsharpMask class OpenclUnsharpMask(UnsharpMask, OpenclProcessing): def __init__( self, shape, sigma, coeff, mode="reflect", method="gaussian", ctx=None, devicetype="all", platformid=None, deviceid=None, block_size=None, memory=None, profile=False, ): """ NB: For now, this class is designed to use the lowest amount of GPU memory as possible. Therefore, the input and output image/volumes are assumed to be already on device. """ if not (__have_opencl__): raise ImportError("Need silx and pyopencl") OpenclProcessing.__init__( self, ctx=ctx, devicetype=devicetype, platformid=platformid, deviceid=deviceid, block_size=block_size, memory=memory, profile=profile, ) UnsharpMask.__init__(self, shape, sigma, coeff, mode=mode, method=method) self._init_convolution() self._init_mad_kernel() def _init_convolution(self): self.convolution = CLConvolution( self.shape, self._gaussian_kernel, mode=self.mode, ctx=self.ctx, profile=self.profile, extra_options={ # Use the lowest amount of memory "allocate_input_array": False, "allocate_output_array": False, "allocate_tmp_array": True, "dont_use_textures": True, }, ) def _init_mad_kernel(self): # parray.Array.mul_add is out of place... 
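        # Same element-wise multiply-add as the Cuda backend: updating the array in place
        # avoids allocating an extra device buffer, consistent with the low-memory design
        # stated in the class docstring.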
self.mad_kernel = ElementwiseKernel( self.ctx, "float* array, float fac, float* other, float otherfac", "array[i] = fac * array[i] + otherfac * other[i]", name="mul_add", ) def unsharp(self, image, output): # For now image and output are assumed to be already allocated on device assert isinstance(image, parray.Array) assert isinstance(output, parray.Array) self.convolution(image, output=output) if self.method == "gaussian": self.mad_kernel(output, -self.coeff, image, 1.0 + self.coeff) elif self.method == "log": self.mad_kernel(output, self.coeff, image, 1.0) else: # "imagej": self.mad_kernel(output, -self.coeff / (1 - self.coeff), image, 1.0 / (1 - self.coeff)) return output ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/misc/utils.py0000644000175000017500000000750200000000000015622 0ustar00pierrepierreimport numpy as np def rescale_data(data, new_min, new_max, data_min=None, data_max=None): if data_min is None: data_min = np.min(data) if data_max is None: data_max = np.max(data) return (new_max - new_min) / (data_max - data_min) * (data - data_min) + new_min def get_dtype_range(dtype, normalize_floats=False): if np.dtype(dtype).kind in ["u", "i"]: dtype_range = (np.iinfo(dtype).min, np.iinfo(dtype).max) else: if normalize_floats: dtype_range = (-1.0, 1.0) else: dtype_range = (np.finfo(dtype).min, np.finfo(dtype).max) return dtype_range def psnr(img1, img2): if img1.dtype != img2.dtype: raise ValueError("both images should have the same data type") dtype_range = get_dtype_range(img1.dtype, normalize_floats=True) dtype_range = dtype_range[-1] - dtype_range[0] if np.dtype(img1.dtype).kind in ["f", "c"]: img1 = rescale_data(img1, -1.0, 1.0) img2 = rescale_data(img2, -1.0, 1.0) mse = np.mean((img1.astype(np.float64) - img2) ** 2) return 10 * np.log10((dtype_range**2) / mse) # # silx.opencl.utils.ConvolutionInfos cannot be used as long as # silx.opencl instantiates the "ocl" singleton in __init__, # leaving opencl contexts all over the place in some cases. # # so for now: copypasta # class ConvolutionInfos(object): allowed_axes = { "1D": [None], "separable_2D_1D_2D": [None, (0, 1), (1, 0)], "batched_1D_2D": [(0,), (1,)], "separable_3D_1D_3D": [None, (0, 1, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0), (1, 0, 2), (0, 2, 1)], "batched_1D_3D": [(0,), (1,), (2,)], "batched_separable_2D_1D_3D": [(0,), (1,), (2,)], # unsupported (?) 
"2D": [None], "batched_2D_3D": [(0,), (1,), (2,)], "separable_3D_2D_3D": [ (1, 0), (0, 1), (2, 0), (0, 2), (1, 2), (2, 1), ], "3D": [None], } use_cases = { (1, 1): { "1D": { "name": "1D convolution on 1D data", "kernels": ["convol_1D_X"], }, }, (2, 2): { "2D": { "name": "2D convolution on 2D data", "kernels": ["convol_2D_XY"], }, }, (3, 3): { "3D": { "name": "3D convolution on 3D data", "kernels": ["convol_3D_XYZ"], }, }, (2, 1): { "separable_2D_1D_2D": { "name": "Separable (2D->1D) convolution on 2D data", "kernels": ["convol_1D_X", "convol_1D_Y"], }, "batched_1D_2D": { "name": "Batched 1D convolution on 2D data", "kernels": ["convol_1D_X", "convol_1D_Y"], }, }, (3, 1): { "separable_3D_1D_3D": { "name": "Separable (3D->1D) convolution on 3D data", "kernels": ["convol_1D_X", "convol_1D_Y", "convol_1D_Z"], }, "batched_1D_3D": { "name": "Batched 1D convolution on 3D data", "kernels": ["convol_1D_X", "convol_1D_Y", "convol_1D_Z"], }, "batched_separable_2D_1D_3D": { "name": "Batched separable (2D->1D) convolution on 3D data", "kernels": ["convol_1D_X", "convol_1D_Y", "convol_1D_Z"], }, }, (3, 2): { "separable_3D_2D_3D": { "name": "Separable (3D->2D) convolution on 3D data", "kernels": ["convol_2D_XY", "convol_2D_XZ", "convol_2D_YZ"], }, "batched_2D_3D": { "name": "Batched 2D convolution on 3D data", "kernels": ["convol_2D_XY", "convol_2D_XZ", "convol_2D_YZ"], }, }, } ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4607332 nabu-2023.1.1/nabu/opencl/0000755000175000017500000000000000000000000014431 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1581878491.0 nabu-2023.1.1/nabu/opencl/__init__.py0000644000175000017500000000000000000000000016530 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/opencl/utils.py0000644000175000017500000001206000000000000016142 0ustar00pierrepierrefrom numpy import prod try: import pyopencl as cl __has_pyopencl__ = True __pyopencl_error_msg__ = None except ImportError as err: __has_pyopencl__ = False __pyopencl_error_msg__ = str(err) from ..resources.gpu import GPUDescription def usable_opencl_devices(): """ Test the available OpenCL platforms/devices. Returns -------- platforms: dict Dictionary where the key is the platform name, and the value is a list of `silx.opencl.common.Device` object. """ platforms = {} for platform in cl.get_platforms(): platforms[platform.name] = platform.get_devices() return platforms def detect_opencl_gpus(): """ Get the available OpenCL-compatible GPUs. Returns -------- gpus: dict Nested dictionary where the keys are OpenCL platform names, values are dictionary of GPU IDs and `silx.opencl.common.Device` object. error_msg: str In the case where there is an error, the message is returned in this item. Otherwise, it is a None object. 
""" gpus = {} error_msg = None if not (__has_pyopencl__): return {}, __pyopencl_error_msg__ try: platforms = usable_opencl_devices() except Exception as exc: error_msg = str(exc) if error_msg is not None: return {}, error_msg for platform_name, devices in platforms.items(): for d_id, device in enumerate(devices): if device.type == cl.device_type.GPU: # and bool(device.available): if platform_name not in gpus: gpus[platform_name] = {} gpus[platform_name][d_id] = device return gpus, None def collect_opencl_gpus(): """ Return a dictionary of platforms and brief description of each OpenCL-compatible GPU with a few fields """ gpus, error_msg = detect_opencl_gpus() if error_msg is not None: return None opencl_gpus = {} for platform, gpus in gpus.items(): for gpu_id, gpu in gpus.items(): if platform not in opencl_gpus: opencl_gpus[platform] = {} opencl_gpus[platform][gpu_id] = GPUDescription(gpu, device_id=gpu_id).get_dict() opencl_gpus[platform][gpu_id]["platform"] = platform return opencl_gpus def collect_opencl_cpus(): """ Return a dictionary of platforms and brief description of each OpenCL-compatible CPU with a few fields """ opencl_cpus = {} platforms = usable_opencl_devices() for platform, devices in platforms.items(): if "cuda" in platform.lower(): continue opencl_cpus[platform] = {} for device_id, device in enumerate(devices): # device_id might be inaccurate if device.type != cl.device_type.CPU: continue opencl_cpus[platform][device_id] = GPUDescription(device).get_dict() opencl_cpus[platform][device_id]["platform"] = platform return opencl_cpus def create_opencl_context(platform_id, device_id, cleanup_at_exit=True): """ Create an OpenCL context. """ platforms = cl.get_platforms() platform = platforms[platform_id] devices = platform.get_devices() ctx = cl.Context(devices=[devices[device_id]]) return ctx def replace_array_memory(arr, new_shape): """ Replace the underlying buffer data of a `pyopencl.array.Array`. This function is dangerous ! It should merely be used to clear memory, the array should not be used afterwise. """ arr.data.release() arr.base_data = cl.Buffer(arr.context, cl.mem_flags.READ_WRITE, prod(new_shape) * arr.dtype.itemsize) arr.shape = new_shape # strides seems to be updated by pyopencl return arr def pick_opencl_cpu_platform(opencl_cpus): """ Pick the best OpenCL implementation for the opencl cpu. This function assume that there is only one opencl-enabled CPU on the current machine, but there might be several OpenCL implementations/vendors. Parameters ---------- opencl_cpus: dict Dictionary with the available opencl-enabled CPUs. Usually obtained with collect_opencl_cpus(). Returns ------- cpu: dict A dictionary describing the CPU. 
""" if len(opencl_cpus) == 0: raise ValueError("No CPU to pick") name2device = {} for platform, devices in opencl_cpus.items(): for device_id, device_desc in devices.items(): name2device.setdefault(device_desc["name"], []) name2device[device_desc["name"]].append(platform) if len(name2device) > 1: raise ValueError("Expected at most one CPU but got %d: %s" % (len(name2device), list(name2device.keys()))) cpu_name = list(name2device.keys())[0] platforms = name2device[cpu_name] # Several platforms for the same CPU res = opencl_cpus[platforms[0]] if len(platforms) > 1: if "intel" in cpu_name.lower(): for platform in platforms: if "intel" in platform.lower(): res = opencl_cpus[platform] # return res[list(res.keys())[0]] ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4607332 nabu-2023.1.1/nabu/pipeline/0000755000175000017500000000000000000000000014756 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1628752049.0 nabu-2023.1.1/nabu/pipeline/__init__.py0000644000175000017500000000000000000000000017055 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/pipeline/config.py0000644000175000017500000002212500000000000016577 0ustar00pierrepierrefrom os import linesep from configparser import ConfigParser from silx.io.dictdump import dicttoh5, h5todict from silx.io.url import DataUrl from ..utils import check_supported, deprecated # # option "type": # - required: always visible, user must provide a valid value # - optional: visible, but might be left blank # - advanced: optional and not visible by default # - unsupported: hidden (not implemented yet) _options_levels = { "required": 0, "optional": 1, "advanced": 2, "unsupported": 10, } class NabuConfigParser: @deprecated( "The class 'NabuConfigParser' is deprecated and will be removed in a future version. Please use parse_nabu_config_file instead", do_print=True, ) def __init__(self, fname): """ Nabu configuration file parser. Parameters ---------- fname: str File name of the configuration file """ parser = ConfigParser(inline_comment_prefixes=("#",)) # allow in-line comments with open(fname) as fid: file_content = fid.read() parser.read_string(file_content) self.parser = parser self.get_dict() self.file_content = file_content.split(linesep) def get_dict(self): # Is there an officially supported way to do this ? self.conf_dict = self.parser._sections return self.conf_dict def __str__(self): return self.conf_dict.__str__() def __repr__(self): return self.conf_dict.__repr__() def __getitem__(self, key): return self.conf_dict[key] def parse_nabu_config_file(fname, allow_no_value=False): """ Parse a configuration file and returns a dictionary. Parameters ---------- fname: str File name of the configuration file Returns ------- conf_dict: dict Dictionary with the configuration """ parser = ConfigParser( inline_comment_prefixes=("#",), # allow in-line comments allow_no_value=allow_no_value, ) with open(fname) as fid: file_content = fid.read() parser.read_string(file_content) conf_dict = parser._sections # Is there an officially supported way to do this ? return conf_dict def generate_nabu_configfile( fname, default_config, config=None, sections=None, sections_comments=None, comments=True, options_level=None, prefilled_values=None, ): """ Generate a nabu configuration file. Parameters ----------- fname: str Output file path. config: dict Configuration to save. 
If section and / or key missing will store the default value sections: list of str, optional Sections which should be included in the configuration file comments: bool, optional Whether to include comments in the configuration file options_level: str, optional Which "level" of options to embed in the file. Can be "required", "optional", "advanced". Default is "optional". """ if options_level is None: options_level = "optional" if prefilled_values is None: prefilled_values = {} check_supported(options_level, list(_options_levels.keys()), "options_level") options_level = _options_levels[options_level] if config is None: config = {} if sections is None: sections = default_config.keys() def dump_help(fid, help_sequence): for help_line in help_sequence.split(linesep): content = "# %s" % (help_line) if help_line.strip() != "" else "" content = content + linesep fid.write(content) with open(fname, "w") as fid: for section, section_content in default_config.items(): if section not in sections: continue if section != "dataset": fid.write("%s%s" % (linesep, linesep)) fid.write("[%s]%s" % (section, linesep)) if sections_comments is not None and section in sections_comments: dump_help(fid, sections_comments[section]) for key, values in section_content.items(): if options_level < _options_levels[values["type"]]: continue if comments and values["help"].strip() != "": dump_help(fid, values["help"]) value = values["default"] if section in prefilled_values and key in prefilled_values[section]: value = prefilled_values[section][key] if section in config and key in config[section]: value = config[section][key] fid.write("%s = %s%s" % (key, value, linesep)) def _extract_nabuconfig_section(section, default_config): res = {} for key, val in default_config[section].items(): res[key] = val["default"] return res def _extract_nabuconfig_keyvals(default_config): res = {} for section in default_config.keys(): res[section] = _extract_nabuconfig_section(section, default_config) return res def get_default_nabu_config(default_config): """ Return a dictionary with the default nabu configuration. """ return _extract_nabuconfig_keyvals(default_config) def _handle_modified_key(key, val, section, default_config, renamed_keys): if val is not None: return key, val, section if key in renamed_keys and renamed_keys[key]["section"] == section: info = renamed_keys[key] print(info["message"]) print("This is deprecated since version %s and will result in an error in futures versions" % (info["since"])) section = info.get("new_section", section) if info["new_name"] == "": return None, None, section # deleted key val = default_config[section].get(info["new_name"], None) return info["new_name"], val, section else: return key, None, section # unhandled renamed/deleted key def validate_config(config, default_config, renamed_keys, errors="warn"): """ Validate a configuration dictionary against a "default" configuration dict. Parameters ---------- config: dict configuration dict to be validated default_config: dict Reference configuration. Missing keys/sections from 'config' will be updated with keys from this dictionary. errors: str, optional What to do when an unknonw key/section is encountered. 
Possible actions are: - "warn": throw a warning, continue the validation - "raise": raise an error and exit """ def error(msg): if errors == "raise": raise ValueError(msg) else: print("Error: %s" % msg) res_config = {} for section, section_content in config.items(): # Ignore the "other" section if section.lower() == "other": continue if section not in default_config: error("Unknown section [%s]" % section) continue res_config[section] = _extract_nabuconfig_section(section, default_config) res_config[section].update(section_content) for key, value in res_config[section].items(): opt = default_config[section].get(key, None) key, opt, section_updated = _handle_modified_key(key, opt, section, default_config, renamed_keys) if key is None: continue # deleted key if opt is None: error("Unknown option '%s' in section [%s]" % (key, section_updated)) continue validator = default_config[section_updated][key]["validator"] if section_updated not in res_config: # missing section - handled later continue res_config[section_updated][key] = validator(section_updated, key, value) # Handle sections missing in config for section in set(default_config.keys()) - set(res_config.keys()): res_config[section] = _extract_nabuconfig_section(section, default_config) for key, value in res_config[section].items(): validator = default_config[section][key]["validator"] res_config[section][key] = validator(section, key, value) return res_config validate_nabu_config = deprecated("validate_nabu_config is renamed validate_config", do_print=True)(validate_config) def overwrite_config(conf, overwritten_params): """ Overwrite a (validated) configuration with a new parameters dict. Parameters ---------- conf: dict Configuration dictionary, usually output from validate_config() overwritten_params: dict Configuration dictionary with the same layout, containing parameters to overwrite """ overwritten_params = overwritten_params or {} for section, params in overwritten_params.items(): if section not in conf: raise ValueError("Unknown section %s" % section) current_section = conf[section] for key in params.keys(): if key not in current_section: raise ValueError("Unknown parameter '%s' in section '%s'" % (key, section)) conf[section][key] = overwritten_params[section][key] # --- return conf ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/pipeline/config_validators.py0000644000175000017500000003331100000000000021026 0ustar00pierrepierreimport os path = os.path from ..utils import is_writeable from .params import * """ A validator is a function with - input: a value - output: the input value, or a modified input value - possibly raising exceptions in case of invalid value. """ # ------------------------------------------------------------------------------ # ---------------------------- Utils ------------------------------------------- # ------------------------------------------------------------------------------ def raise_error(section, key, msg=""): raise ValueError("Invalid value for %s/%s: %s" % (section, key, msg)) def validator(func): """ Common decorator for all validator functions. It modifies the signature of the decorated functions ! 
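    After decoration, a validator is called as ``func(section, key, value)`` instead of
    ``func(value)``, and any AssertionError raised by the wrapped check is re-raised as a
    ValueError mentioning the offending section/key.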
""" def wrapper(section, key, value): try: res = func(value) except AssertionError as e: raise_error(section, key, e) return res return wrapper def convert_to_int(val): val_int = 0 try: val_int = int(val) conversion_error = None except ValueError as exc: conversion_error = exc return val_int, conversion_error def convert_to_float(val): val_float = 0.0 try: val_float = float(val) conversion_error = None except ValueError as exc: conversion_error = exc return val_float, conversion_error def convert_to_bool(val): val_int, error = convert_to_int(val) res = None if not error: res = val_int > 0 else: if val.lower() in ["yes", "true"]: res = True error = None if val.lower() in ["no", "false"]: res = False error = None return res, error def convert_to_bool_noerr(val): res, err = convert_to_bool(val) if err is not None: raise ValueError("Could not convert to boolean: %s" % str(val)) return res def name_range_checker(name, valid_names, descr, replacements=None): name = name.strip().lower() if replacements is not None and name in replacements: name = replacements[name] valid = name in valid_names assert valid, "Invalid %s '%s'. Available are %s" % (descr, name, str(valid_names)) return name # ------------------------------------------------------------------------------ # ---------------------------- Validators -------------------------------------- # ------------------------------------------------------------------------------ @validator def optional_string_validator(val): if len(val.strip()) == 0: return None return val @validator def file_name_validator(name): assert len(name) >= 1, "Name should be non-empty" return name @validator def file_location_validator(location): assert path.isfile(location), "location must be a file" return os.path.abspath(location) @validator def optional_file_location_validator(location): if len(location.strip()) > 0: assert path.isfile(location), "location must be a file" return os.path.abspath(location) return None @validator def optional_values_file_validator(location): if len(location.strip()) == 0: return None if path.splitext(location)[-1].strip() == "": # Assume path to h5 dataset. Validation is done later. 
if "://" not in location: location = "silx://" + os.path.abspath(location) else: # Assume plaintext file assert path.isfile(location), "Invalid file path" location = os.path.abspath(location) return location @validator def directory_location_validator(location): assert path.isdir(location), "location must be a directory" return os.path.abspath(location) @validator def optional_directory_location_validator(location): if len(location.strip()) > 0: assert is_writeable(location), "Directory must be writeable" return os.path.abspath(location) return None @validator def dataset_location_validator(location): if not (path.isdir(location)): assert ( path.isfile(location) and path.splitext(location)[-1].split(".")[-1].lower() in files_formats ), "Dataset location must be a directory or a HDF5 file" return os.path.abspath(location) @validator def directory_writeable_validator(location): assert is_writeable(location), "Directory must be writeable" return os.path.abspath(location) @validator def optional_output_directory_validator(location): if len(location.strip()) > 0: return directory_writeable_validator(location) return None @validator def optional_output_file_path_validator(location): if len(location.strip()) > 0: dirname, fname = path.split(location) assert os.access(dirname, os.W_OK), "Directory must be writeable" return os.path.abspath(location) return None @validator def integer_validator(val): val_int, error = convert_to_int(val) assert error is None, "number must be an integer" return val_int @validator def nonnegative_integer_validator(val): val_int, error = convert_to_int(val) assert error is None and val_int >= 0, "number must be a non-negative integer" return val_int @validator def positive_integer_validator(val): val_int, error = convert_to_int(val) assert error is None and val_int > 0, "number must be a positive integer" return val_int @validator def optional_positive_integer_validator(val): if len(val.strip()) == 0: return None val_int, error = convert_to_int(val) assert error is None and val_int > 0, "number must be a positive integer" return val_int @validator def nonzero_integer_validator(val): val_int, error = convert_to_int(val) assert error is None and val_int != 0, "number must be a non-zero integer" return val_int @validator def binning_validator(val): if val == "": val = "1" val_int, error = convert_to_int(val) assert error is None and val_int >= 0, "number must be a non-negative integer" return max(1, val_int) @validator def optional_file_name_validator(val): if len(val) > 0: assert len(val) >= 1, "Name should be non-empty" assert path.basename(val) == val, "File name should not be a path (no '/')" return val return None @validator def boolean_validator(val): res, error = convert_to_bool(val) assert error is None, "Invalid boolean value" return res @validator def boolean_or_auto_validator(val): res, error = convert_to_bool(val) if error is not None: assert val.lower() == "auto", "Valid values are 0, 1 and auto" return val return res @validator def float_validator(val): val_float, error = convert_to_float(val) assert error is None, "Invalid number" return val_float @validator def optional_float_validator(val): if isinstance(val, float): return val elif len(val.strip()) >= 1: val_float, error = convert_to_float(val) assert error is None, "Invalid number" else: val_float = None return val_float @validator def optional_nonzero_float_validator(val): if isinstance(val, float): val_float = val elif len(val.strip()) >= 1: val_float, error = convert_to_float(val) assert error is 
None, "Invalid number" else: val_float = None if val_float is not None: if abs(val_float) < 1e-6: val_float = None return val_float @validator def optional_tuple_of_floats_validator(val): if len(val.strip()) == 0: return None err_msg = "Expected a tuple of two numbers, but got %s" % val try: res = tuple(float(x) for x in val.strip("()").split(",")) except Exception as exc: raise ValueError(err_msg) if len(res) != 2: raise ValueError(err_msg) return res @validator def cor_validator(val): val_float, error = convert_to_float(val) if error is None: return val_float if len(val.strip()) == 0: return None val = name_range_checker( val.lower(), set(cor_methods.values()), "center of rotation estimation method", replacements=cor_methods ) return val @validator def tilt_validator(val): val_float, error = convert_to_float(val) if error is None: return val_float if len(val.strip()) == 0: return None val = name_range_checker( val.lower(), set(tilt_methods.values()), "automatic detector tilt estimation method", replacements=tilt_methods ) return val @validator def slice_num_validator(val): val_int, error = convert_to_int(val) if error is None: return val_int else: assert val in [ "first", "middle", "last", ], "Expected start_z and end_z to be either a number or first, middle or last" return val @validator def generic_options_validator(val): if len(val.strip()) == 0: return None return val cor_options_validator = generic_options_validator @validator def cor_slice_validator(val): if len(val) == 0: return None val_int, error = convert_to_int(val) if error: supported = ["top", "first", "bottom", "last", "middle"] assert val in supported, "Invalid value, must be a number or one of %s" % supported return val else: return val_int @validator def flatfield_enabled_validator(val): return name_range_checker(val, set(flatfield_modes.values()), "flatfield mode", replacements=flatfield_modes) @validator def phase_method_validator(val): return name_range_checker( val, set(phase_retrieval_methods.values()), "phase retrieval method", replacements=phase_retrieval_methods ) @validator def detector_distortion_correction_validator(val): return name_range_checker( val, set(detector_distortion_correction_methods.values()), "detector_distortion_correction_methods", replacements=detector_distortion_correction_methods, ) @validator def unsharp_method_validator(val): return name_range_checker( val, set(unsharp_methods.values()), "unsharp mask method", replacements=phase_retrieval_methods ) @validator def padding_mode_validator(val): return name_range_checker(val, set(padding_modes.values()), "padding mode", replacements=padding_modes) @validator def reconstruction_method_validator(val): return name_range_checker( val, set(reconstruction_methods.values()), "reconstruction method", replacements=reconstruction_methods ) @validator def fbp_filter_name_validator(val): return name_range_checker( val, set(fbp_filters.values()), "FBP filter", replacements=fbp_filters, ) @validator def iterative_method_name_validator(val): return name_range_checker( val, set(iterative_methods.values()), "iterative methods name", replacements=iterative_methods ) @validator def optimization_algorithm_name_validator(val): return name_range_checker( val, set(optim_algorithms.values()), "optimization algorithm name", replacements=iterative_methods ) @validator def output_file_format_validator(val): return name_range_checker(val, set(files_formats.values()), "output file format", replacements=files_formats) @validator def distribution_method_validator(val): val 
= name_range_checker( val, set(distribution_methods.values()), "workload distribution method", replacements=distribution_methods ) # TEMP. if val != "local": raise NotImplementedError("Computation method '%s' is not implemented yet" % val) # -- return val @validator def sino_normalization_validator(val): val = name_range_checker( val, set(sino_normalizations.values()), "sinogram normalization method", replacements=sino_normalizations ) return val @validator def sino_deringer_methods(val): val = name_range_checker( val, set(rings_methods.values()), "sinogram rings artefacts correction method", replacements=rings_methods, ) return val @validator def list_of_int_validator(val): ids = val.replace(",", " ").split() res = list(map(convert_to_int, ids)) err = list(filter(lambda x: x[1] is not None or x[0] < 0, res)) if err != []: raise ValueError("Could not convert to a list of GPU IDs: %s" % val) return list(set(map(lambda x: x[0], res))) @validator def list_of_shift_validator(values): ids = values.replace(" ", "").split(",") return [int(val) if val not in ("auto", "'auto'", '"auto"') else "auto" for val in ids] @validator def list_of_tomoscan_identifier(val): # TODO: insure those are valid tomoscan identifier return val @validator def resources_validator(val): val = val.strip() is_percentage = False if "%" in val: is_percentage = True val = val.replace("%", "") val_float, conversion_error = convert_to_float(val) assert conversion_error is None, str("Error while converting %s to float" % val) return (val_float, is_percentage) @validator def walltime_validator(val): # HH:mm:ss vals = val.strip().split(":") error_msg = "Invalid walltime format, expected HH:mm:ss" assert len(vals) == 3, error_msg hours, mins, secs = vals hours, err1 = convert_to_int(hours) mins, err2 = convert_to_int(mins) secs, err3 = convert_to_int(secs) assert err1 is None and err2 is None and err3 is None, error_msg err = hours < 0 or mins < 0 or mins > 59 or secs < 0 or secs > 59 assert err is False, error_msg return hours, mins, secs @validator def nonempty_string_validator(val): assert val != "", "Value cannot be empty" return val @validator def logging_validator(val): return name_range_checker(val, set(log_levels.values()), "logging level", replacements=log_levels) @validator def no_validator(val): return val ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1679996432.0 nabu-2023.1.1/nabu/pipeline/datadump.py0000644000175000017500000001457000000000000017136 0ustar00pierrepierrefrom os import path from ..resources.logger import LoggerOrPrint from .utils import get_subregion from .writer import WriterManager from ..io.reader import get_hdf5_dataset_shape try: import pycuda.gpuarray as garray __has_pycuda__ = True except: __has_pycuda__ = False class DataDumpManager: """ A helper class for managing data dumps, with the aim of saving/resuming the processing from a given step. """ def __init__(self, process_config, sub_region, margin=None, logger=None): """ Initialize a DataDump object. Parameters ----------- process_config: ProcessConfig ProcessConfig object sub_region: tuple of int Series of integers defining the sub-region being processed. The form is ((start_angle, end_angle), (start_z, end_z), (start_x, end_x)) margin: tuple of int, optional Margin, used when processing data, in the form ((up, down), (left, right)). Each item can be None. 
Using a margin means that a given chunk of data will eventually be cropped as `data[:, up:-down, left:-right]` logger: Logger, optional Logging object """ self.process_config = process_config self.processing_steps = process_config.processing_steps self.processing_options = process_config.processing_options self.dataset_info = process_config.dataset_info self._set_subregion_and_margin(sub_region, margin) self.logger = LoggerOrPrint(logger) self._configure_data_dumps() def _set_subregion_and_margin(self, sub_region, margin): self.sub_region = get_subregion(sub_region) self._z_sub_region = self.sub_region[1] self.z_min = self._z_sub_region[0] self.margin = get_subregion(margin, ndim=2) # ((U, D), (L, R)) self.margin_up = self.margin[0][0] self.start_index = self.z_min + self.margin_up self.delta_z = self._z_sub_region[-1] - self._z_sub_region[-2] self._grouped_processing = False iangle1, iangle2 = self.sub_region[0] if iangle1 != 0 or iangle2 < len(self.process_config.rotation_angles(subsampling=False)): self._grouped_processing = True self.start_index = self.sub_region[0][0] def _configure_dump(self, step_name, force_dump_to_fname=None): if force_dump_to_fname is not None: # Shortcut fname_full = force_dump_to_fname elif step_name in self.processing_steps: # Standard case if not self.processing_options[step_name].get("save", False): return fname_full = self.processing_options[step_name]["save_steps_file"] elif step_name == "sinogram" and self.process_config.dump_sinogram: # "sinogram" is a special keyword fname_full = self.process_config.dump_sinogram_file else: return # "fname_full" is the path to the final master file. # We also need to create partial files (in a sub-directory) fname, ext = path.splitext(fname_full) dirname, file_prefix = path.split(fname) self.data_dump[step_name] = WriterManager( dirname, file_prefix, file_format="hdf5", overwrite=True, start_index=self.start_index, logger=self.logger, metadata={ "process_name": step_name, "processing_index": 0, "config": { "processing_options": self.processing_options, # slow! "nabu_config": self.process_config.nabu_config, }, "entry": getattr(self.dataset_info.dataset_scanner, "entry", "entry"), }, ) def _configure_data_dumps(self): self.data_dump = {} for step_name in self.processing_steps: self._configure_dump(step_name) # sinogram is a special keyword: not in processing_steps, but guaranteed to be before sinogram generation if self.process_config.dump_sinogram: self._configure_dump("sinogram") def get_data_dump(self, step_name): """ Get information on where to write a given processing step. Parameters ---------- step_name: str Name of the processing step Returns ------- writer_configurator: WriterConfigurator An object with information on where to write the data for the given processing step. """ return self.data_dump.get(step_name, None) def get_read_dump_subregion(self): read_opts = self.processing_options["read_chunk"] if read_opts.get("process_file", None) is None: return None dump_start_z, dump_end_z = read_opts["dump_start_z"], read_opts["dump_end_z"] relative_start_z = self.z_min - dump_start_z relative_end_z = relative_start_z + self.delta_z # When using binning, every step after "read" results in smaller-sized data. 
# Therefore dumped data has shape (ceil(n_angles/subsampling), n_z//binning_z, n_x//binning_x) relative_start_z //= self.process_config.binning_x relative_end_z //= self.process_config.binning_x # (n_angles, n_z, n_x) subregion = (None, None, relative_start_z, relative_end_z, None, None) return subregion def _check_resume_from_step(self): read_opts = self.processing_options["read_chunk"] expected_radios_shape = get_hdf5_dataset_shape( read_opts["process_file"], read_opts["process_h5_path"], sub_region=self.get_read_dump_subregion(), ) # TODO check def dump_data_to_file(self, step_name, data): if step_name not in self.data_dump: return writer = self.data_dump[step_name] self.logger.info("Dumping data to %s" % writer.fname) if __has_pycuda__: if isinstance(data, garray.GPUArray): data = data.get() writer.write_data(data) def __repr__(self): res = "%s(%s, margin=%s)" % (self.__class__.__name__, str(self.sub_region), str(self.margin)) if len(self.data_dump) > 0: for step_name, writer_configurator in self.data_dump.items(): res += "\n- Dump %s to %s" % (step_name, writer_configurator.fname) return res ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/pipeline/dataset_validator.py0000644000175000017500000002176700000000000021037 0ustar00pierrepierreimport os from math import ceil from ..resources.logger import LoggerOrPrint from ..utils import copy_dict_items from ..reconstruction.sinogram import get_extended_sinogram_width class DatasetValidatorBase: # this in the helical derived class will be False _check_also_z = True def __init__(self, nabu_config, dataset_info, logger=None): """ Perform a coupled validation of nabu configuration against dataset information. Check the consistency of these two structures, and modify them in-place. Parameters ---------- nabu_config: dict Dictionary containing the nabu configuration, usually got from `nabu.pipeline.config.validate_config()` It will be modified ! dataset_info: `DatasetAnalyzer` instance Structure containing information on the dataset to process. It will be modified ! """ self.nabu_config = nabu_config self.dataset_info = dataset_info self.logger = LoggerOrPrint(logger) self.rec_params = copy_dict_items(self.nabu_config["reconstruction"], self.nabu_config["reconstruction"].keys()) self._validate() def _validate(self): raise ValueError("Base class") @property def is_halftomo(self): do_halftomo = self.nabu_config["reconstruction"].get("enable_halftomo", False) if do_halftomo == "auto": do_halftomo = self.dataset_info.is_halftomo if do_halftomo is None: raise ValueError( "'enable_halftomo' was set to 'auto' but unable to get the information on field of view" ) return do_halftomo def _check_not_empty(self): if len(self.dataset_info.projections) == 0: msg = "Dataset seems to be empty (no projections)" self.logger.fatal(msg) raise ValueError(msg) if self.dataset_info.n_angles is None: msg = "Could not determine the number of projections. Please check the .info or HDF5 file" self.logger.fatal(msg) raise ValueError(msg) for dim_name, n in zip(["dim_1", "dim_2"], self.dataset_info.radio_dims): if n is None: msg = "Could not determine %s. 
Please check the .info file or HDF5 file" % dim_name self.logger.fatal(msg) raise ValueError(msg) @staticmethod def _convert_negative_idx(idx, last_idx): res = idx if idx < 0: res = last_idx + idx return res def _get_nx_ny(self, binning_factor=1): nx = self.dataset_info.radio_dims[0] // binning_factor if self.is_halftomo: cor = self._get_cor(binning_factor=binning_factor) nx = get_extended_sinogram_width(nx, cor) ny = nx return nx, ny def _get_cor(self, binning_factor=1): cor = self.dataset_info.axis_position if binning_factor >= 1: # Backprojector uses middle of pixel for coordinate indices. # This means that the leftmost edge of the leftmost pixel has coordinate -0.5. # When using binning with a factor 'b', the CoR has to adapted as # cor_binned = (cor + 0.5)/b - 0.5 cor = (cor + 0.5) / binning_factor - 0.5 return cor def _convert_negative_indices(self): """ Convert any negative index to the corresponding positive index. """ nx, nz = self.dataset_info.radio_dims ny = nx if self.is_halftomo: if self.dataset_info.axis_position is None: raise ValueError( "Cannot use rotation axis position in the middle of the detector when half tomo is enabled" ) nx, ny = self._get_nx_ny() what = ( ("start_x", nx), ("end_x", nx), ("start_y", ny), ("end_y", ny), ) if self._check_also_z: what = what + ( ("start_z", nz), ("end_z", nz), ) for key, upper_bound in what: val = self.rec_params[key] if isinstance(val, str): idx_mapping = { "first": 0, "middle": upper_bound // 2, # works on both start_ and end_ since the end_ index is included "last": upper_bound - 1, # upper bound is included in the user interface (contrarily to python) } res = idx_mapping[val] else: res = self._convert_negative_idx(self.rec_params[key], upper_bound) self.rec_params[key] = res self.rec_region = copy_dict_items(self.rec_params, [w[0] for w in what]) def _get_output_filename(self): # This function modifies nabu_config ! opts = self.nabu_config["output"] dataset_path = self.nabu_config["dataset"]["location"] if opts["location"] == "" or opts["location"] is None: opts["location"] = os.path.dirname(dataset_path) if opts["file_prefix"] == "" or opts["file_prefix"] is None: if os.path.isfile(dataset_path): # hdf5 file_prefix = os.path.basename(dataset_path).split(".")[0] elif os.path.isdir(dataset_path): file_prefix = os.path.basename(dataset_path) else: raise ValueError("dataset location %s is neither a file or directory" % dataset_path) file_prefix += "_rec" # avoid overwriting dataset opts["file_prefix"] = file_prefix @staticmethod def _check_start_end_idx(start, end, n_elements, start_name="start_x", end_name="end_x"): assert start >= 0 and start < n_elements, "Invalid value %d for %s, must be >= 0 and < %d" % ( start, start_name, n_elements, ) assert end >= 0 and end < n_elements, "Invalid value for %d %s, must be >= 0 and < %d" % ( end, end_name, n_elements, ) assert start <= end, "Must have %s <= %s" % (start_name, end_name) def _handle_binning(self): """ Modify the dataset description/process config to handle binning and projections subsampling """ dataset_cfg = self.nabu_config["dataset"] self.binning = (dataset_cfg["binning"], dataset_cfg["binning_z"]) subsampling_factor = dataset_cfg["projections_subsampling"] self.subsampling_factor = subsampling_factor or 1 if self.binning != (1, 1): bin_x, bin_z = self.binning rec_cfg = self.rec_params # Update "start_xyz" rec_cfg["start_x"] //= bin_x rec_cfg["start_y"] //= bin_x rec_cfg["start_z"] //= bin_z # Update "end_xyz". 
Things are a little bit more complicated for several reasons: # - In the user interface (configuration file), end_xyz index is INCLUDED (contrarily to python). So there are +1, -1 all over the place. # - When using half tomography, n_x and n_y are less straightforward : 2*CoR(binning) instead of 2*CoR//binning # - delta = end - start [+1] should be a multiple of binning factor. This makes things much easier for processing pipeline. def ensure_multiple_of_binning(end, start, binning_factor): """ Update "end" so that end-start is a multiple of "binning_factor" Note that "end" is INCLUDED here (comes from user configuration) """ return end - ((end - start + 1) % binning_factor) end_z = ensure_multiple_of_binning(rec_cfg["end_z"], rec_cfg["start_z"], bin_z) rec_cfg["end_z"] = (end_z + 1) // bin_z - 1 nx_binned, ny_binned = self._get_nx_ny(binning_factor=bin_x) end_y = ensure_multiple_of_binning(rec_cfg["end_y"], rec_cfg["start_y"], bin_x) rec_cfg["end_y"] = min((end_y + 1) // bin_x - 1, ny_binned - 1) end_x = ensure_multiple_of_binning(rec_cfg["end_x"], rec_cfg["start_x"], bin_x) rec_cfg["end_x"] = min((end_x + 1) // bin_x - 1, nx_binned - 1) def _check_output_file(self): out_cfg = self.nabu_config["output"] out_fname = os.path.join(out_cfg["location"], out_cfg["file_prefix"] + out_cfg["file_format"]) if os.path.exists(out_fname): raise ValueError("File %s already exists" % out_fname) def _handle_processing_mode(self): mode = self.nabu_config["resources"]["method"] if mode == "preview": print( "Warning: the method 'preview' was selected. This means that the data volume will be binned so that everything fits in memory." ) # TODO automatically compute binning/subsampling factors as a function of lowest memory (GPU) self.nabu_config["dataset"]["binning"] = 2 self.nabu_config["dataset"]["binning_z"] = 2 self.nabu_config["dataset"]["projections_subsampling"] = 2 # TODO handle other modes ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/pipeline/detector_distortion_provider.py0000644000175000017500000000163000000000000023331 0ustar00pierrepierrefrom ..resources.utils import extract_parameters from ..io.detector_distortion import DetectorDistortionBase, DetectorDistortionMapsXZ import silx.io def DetectorDistortionProvider(detector_full_shape_vh=(0, 0), correction_type="", options=""): if correction_type == "identity": return DetectorDistortionBase(detector_full_shape_vh=detector_full_shape_vh) elif correction_type == "map_xz": options = options.replace("path=", "path_eq") user_params = extract_parameters(options) print(user_params, options) map_x = silx.io.get_data(user_params["map_x"].replace("path_eq", "path=")) map_z = silx.io.get_data(user_params["map_z"].replace("path_eq", "path=")) return DetectorDistortionMapsXZ(map_x=map_x, map_z=map_z) else: message = f""" Unknown correction type: {correction_type} requested """ raise ValueError(message) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/pipeline/estimators.py0000644000175000017500000007644100000000000017536 0ustar00pierrepierre""" nabu.pipeline.estimators: helper classes/functions to estimate parameters of a dataset (center of rotation, detector tilt, etc). 
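The usual entry point is `estimate_cor(method, dataset_info, do_flatfield=True, cor_options_str=None, logger=None)`, which dispatches to one of the CoR finder classes defined below according to `method`.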
""" import inspect import numpy as np import scipy.fft # pylint: disable=E0611 from silx.io import get_data from ..preproc.flatfield import FlatFieldDataUrls from ..estimation.cor import ( CenterOfRotation, CenterOfRotationAdaptiveSearch, CenterOfRotationSlidingWindow, CenterOfRotationGrowingWindow, ) from ..estimation.cor_sino import SinoCorInterface from ..estimation.tilt import CameraTilt from ..estimation.utils import is_fullturn_scan from ..resources.logger import LoggerOrPrint from ..resources.utils import extract_parameters from ..utils import check_supported, is_int from .params import tilt_methods from ..resources.dataset_analyzer import get_0_180_radios from ..misc.rotation import Rotation from ..io.reader import ChunkReader from ..preproc.ccd import Log, CCDFilter from ..misc import fourier_filters from .params import cor_methods from ..io.reader import load_images_from_dataurl_dict def estimate_cor(method, dataset_info, do_flatfield=True, cor_options_str=None, logger=None): logger = LoggerOrPrint(logger) cor_options_str = cor_options_str or "" check_supported(method, list(cor_methods.keys()), "COR estimation method") method = cor_methods[method] # Extract CoR parameters from configuration file try: cor_options = extract_parameters(cor_options_str, sep=";") except Exception as exc: msg = "Could not extract parameters from cor_options: %s" % (str(exc)) logger.fatal(msg) raise ValueError(msg) # Dispatch if method in CORFinder.search_methods: cor_finder = CORFinder(method, dataset_info, do_flatfield=do_flatfield, cor_options=cor_options, logger=logger) estimated_cor = cor_finder.find_cor() elif method in SinoCORFinder.search_methods: cor_finder = SinoCORFinder( method, dataset_info, slice_idx=cor_options.get("slice_idx", "middle"), subsampling=cor_options.get("subsampling", 10), do_flatfield=do_flatfield, cor_options=cor_options, logger=logger, ) estimated_cor = cor_finder.find_cor() else: composite_options = update_func_kwargs(CompositeCORFinder, cor_options) for what in ["cor_options", "logger"]: composite_options.pop(what, None) cor_finder = CompositeCORFinder( dataset_info, cor_options=cor_options, logger=logger, **composite_options, ) estimated_cor = cor_finder.find_cor() return estimated_cor class CORFinderBase: """ A base class for CoR estimators. It does common tasks like data reading, flatfield, etc. """ search_methods = {} def __init__(self, method, dataset_info, do_flatfield=True, cor_options=None, logger=None): """ Initialize a CORFinder object. Parameters ---------- dataset_info: `nabu.resources.dataset_analyzer.DatasetAnalyzer` Dataset information structure """ check_supported(method, self.search_methods, "CoR estimation method") self.logger = LoggerOrPrint(logger) self.dataset_info = dataset_info self.do_flatfield = do_flatfield self.shape = dataset_info.radio_dims[::-1] self._init_cor_finder(method, cor_options) def _init_cor_finder(self, method, cor_options): self.method = method if not isinstance(cor_options, (type(None), dict)): raise TypeError( f"cor_options is expected to be an optional instance of dict. 
Get {cor_options} ({type(cor_options)}) instead" ) self.cor_options = cor_options or {} cor_class = self.search_methods[method]["class"] self.cor_finder = cor_class(logger=self.logger) default_lookup_side = "right" if self.dataset_info.is_halftomo else "center" lookup_side = self.cor_options.get("side", default_lookup_side) self.cor_exec_args = [] self.cor_exec_args.extend(self.search_methods[method].get("default_args", [])) # CenterOfRotationSlidingWindow is the only class to have a mandatory argument ("side") # TODO - it would be more elegant to have it as a kwarg... if len(self.cor_exec_args) > 0 and cor_class is CenterOfRotationSlidingWindow: self.cor_exec_args[0] = lookup_side # self.cor_exec_kwargs = update_func_kwargs(self.cor_finder.find_shift, self.cor_options) class CORFinder(CORFinderBase): """ Find the Center of Rotation with methods based on two (180-degrees opposed) radios. """ search_methods = { "centered": { "class": CenterOfRotation, }, "global": { "class": CenterOfRotationAdaptiveSearch, "default_kwargs": {"low_pass": 1, "high_pass": 20}, }, "sliding-window": { "class": CenterOfRotationSlidingWindow, "default_args": ["center"], }, "growing-window": { "class": CenterOfRotationGrowingWindow, }, } def __init__(self, method, dataset_info, do_flatfield=True, cor_options=None, logger=None): """ Initialize a CORFinder object. Parameters ---------- dataset_info: `nabu.resources.dataset_analyzer.DatasetAnalyzer` Dataset information structure """ super().__init__(method, dataset_info, do_flatfield=do_flatfield, cor_options=cor_options, logger=logger) self._init_radios() self._init_flatfield() self._apply_flatfield() self._apply_tilt() def _init_radios(self): self.radios, self._radios_indices = get_0_180_radios(self.dataset_info, return_indices=True) def _init_flatfield(self): if not (self.do_flatfield): return self.flatfield = FlatFieldDataUrls( self.radios.shape, flats=self.dataset_info.flats, darks=self.dataset_info.darks, radios_indices=self._radios_indices, interpolation="linear", convert_float=True, ) def _apply_flatfield(self): if not (self.do_flatfield): return self.flatfield.normalize_radios(self.radios) def _apply_tilt(self): tilt = self.dataset_info.detector_tilt if tilt is None: return self.logger.debug("COREstimator: applying detector tilt correction of %f degrees" % tilt) rot = Rotation(self.shape, tilt) for i in range(self.radios.shape[0]): self.radios[i] = rot.rotate(self.radios[i]) def find_cor(self): """ Find the center of rotation. Returns ------- cor: float The estimated center of rotation for the current dataset. """ self.logger.info("Estimating center of rotation") self.logger.debug("%s.find_shift(%s)" % (self.cor_finder.__class__.__name__, str(self.cor_exec_kwargs))) shift = self.cor_finder.find_shift( self.radios[0], np.fliplr(self.radios[1]), *self.cor_exec_args, **self.cor_exec_kwargs ) return self.shape[1] / 2 + shift # alias COREstimator = CORFinder class SinoCORFinder(CORFinderBase): """ A class for finding Center of Rotation based on 360 degrees sinograms. This class handles the steps of building the sinogram from raw radios. """ search_methods = { "sino-coarse-to-fine": { "class": SinoCorInterface, }, "sliding-window": { "class": CenterOfRotationSlidingWindow, "default_args": ["right"], }, "growing-window": { "class": CenterOfRotationGrowingWindow, }, } def __init__( self, method, dataset_info, slice_idx="middle", subsampling=10, do_flatfield=True, cor_options=None, logger=None ): """ Initialize a SinoCORFinder object. 
Other parameters ---------------- The following keys can be set in cor_options. slice_idx: int or str Which slice index to take for building the sinogram. For example slice_idx=0 means that we extract the first line of each projection. Value can also be "first", "top", "middle", "last", "bottom". subsampling: int, float subsampling strategy when building sinograms. As building the complete sinogram from raw projections might be tedious, the reading is done with subsampling. A positive integer value means the subsampling step (i.e `projections[::subsampling]`). A negative integer value means we take -subsampling projections in total. A float value indicates the angular step in DEGREES. """ super().__init__(method, dataset_info, do_flatfield=do_flatfield, cor_options=cor_options, logger=logger) self._check_360() self._set_slice_idx(slice_idx) self._set_subsampling(subsampling) self._load_raw_sinogram() self._flatfield(do_flatfield) self._get_sinogram() def _check_360(self): if self.dataset_info.dataset_scanner.scan_range == 360: return if not is_fullturn_scan(self.dataset_info.rotation_angles): raise ValueError("Sinogram-based Center of Rotation estimation can only be used for 360 degrees scans") def _set_slice_idx(self, slice_idx): n_z = self.dataset_info.radio_dims[1] if isinstance(slice_idx, str): str_to_idx = {"top": 0, "first": 0, "middle": n_z // 2, "bottom": n_z - 1, "last": n_z - 1} check_supported(slice_idx, str_to_idx.keys(), "slice location") slice_idx = str_to_idx[slice_idx] self.slice_idx = slice_idx def _set_subsampling(self, subsampling): projs_idx = sorted(self.dataset_info.projections.keys()) if is_int(subsampling): if subsampling < 0: # Total number of angles n_angles = -subsampling indices_float = np.linspace(projs_idx[0], projs_idx[-1], n_angles, endpoint=True) self.projs_indices = np.round(indices_float).astype(np.int32).tolist() else: # Subsampling step self.projs_indices = projs_idx[::subsampling] else: # Angular step raise NotImplementedError() def _load_raw_sinogram(self): if self.slice_idx is None: raise ValueError("Unknow slice index") # Subsample projections files = {} for idx in self.projs_indices: files[idx] = self.dataset_info.projections[idx] self.files = files self.data_reader = ChunkReader( self.files, sub_region=(None, None, self.slice_idx, self.slice_idx + 1), convert_float=True, ) self.data_reader.load_files() self._radios = self.data_reader.files_data def _flatfield(self, do_flatfield): self.do_flatfield = bool(do_flatfield) if not self.do_flatfield: return flatfield = FlatFieldDataUrls( self._radios.shape, self.dataset_info.flats, self.dataset_info.darks, radios_indices=self.projs_indices, sub_region=(None, None, self.slice_idx, self.slice_idx + 1), ) flatfield.normalize_radios(self._radios) def _get_sinogram(self): log = Log(self._radios.shape, clip_min=1e-6, clip_max=10.0) sinogram = self._radios[:, 0, :].copy() log.take_logarithm(sinogram) self.sinogram = sinogram @staticmethod def _split_sinogram(sinogram): n_a_2 = sinogram.shape[0] // 2 img_1, img_2 = sinogram[:n_a_2], sinogram[n_a_2:] # "Handle" odd number of projections if img_2.shape[0] > img_1.shape[0]: img_2 = img_2[:-1, :] # return img_1, img_2 def find_cor(self): self.logger.info("Estimating center of rotation") self.logger.debug("%s.find_shift(%s)" % (self.cor_finder.__class__.__name__, str(self.cor_exec_kwargs))) img_1, img_2 = self._split_sinogram(self.sinogram) shift = self.cor_finder.find_shift(img_1, img_2, *self.cor_exec_args, **self.cor_exec_kwargs) return self.shape[1] / 2 + shift # 
alias SinoCOREstimator = SinoCORFinder class CompositeCORFinder: """ Class and method to prepare the sinogram and calculate the COR. The pseudo-sinogram is built with shrunk radios taken every theta_interval degrees. Compared to the first implementation by Christian Nemoz: - gives the same result as the original Octave script on the datasets tested so far - The parameter n_subsampling_y (alias subsampling_y) is now the number of lines taken from every radio. This is more meaningful in terms of the amount of collected information, because it does not depend on the radio size. Moreover, this is what was done in the Octave script - The spike_threshold parameter has been added, with a default of 0.04 - The angular sampling is every 5 degrees by default, as is now also the case in the Octave script - The optimal overlap is found by looping over the possible overlaps. After a first testing phase by Elodie, this part, which is the time-consuming one, can be accelerated by several orders of magnitude without modifying the final result """ _default_cor_options = {"low_pass": 0.4, "high_pass": 10, "side": "center", "near_pos": 0, "near_width": 20} def __init__( self, dataset_info, oversampling=4, theta_interval=5, n_subsampling_y=10, take_log=True, cor_options=None, spike_threshold=0.04, logger=None, ): self.dataset_info = dataset_info self.logger = LoggerOrPrint(logger) # the algorithm can work for angular ranges larger than 1.2*pi # up to an arbitrary number of turns, as is the case in helical scans self.spike_threshold = spike_threshold # the following line is necessary for multi-turn scans because the encoder value is always # in the interval 0-360 self.unwrapped_rotation_angles = np.unwrap(self.dataset_info.rotation_angles) self.angle_min = self.unwrapped_rotation_angles.min() self.angle_max = self.unwrapped_rotation_angles.max() self.sx, self.sy = self.dataset_info.radio_dims if (self.angle_max - self.angle_min) < 1.2 * np.pi: useful_span = None raise ValueError( f"""Sinogram-based Center of Rotation estimation can only be used for scans over more than 180 degrees. 
Your angular span was barely above 180 degrees, it was in fact {((self.angle_max - self.angle_min)/np.pi):.2f} x 180 and it is not considered to be enough by the discriminating condition which requires at least 1.2 half-turns """ ) else: useful_span = min(np.pi, (self.angle_max - self.angle_min) - np.pi) # readapt theta_interval accordingly if the span is smaller than pi if useful_span < np.pi: theta_interval = theta_interval * useful_span / np.pi self._get_cor_options(cor_options) self.take_log = take_log self.ovs = oversampling self.theta_interval = theta_interval target_sampling_y = np.round(np.linspace(0, self.sy - 1, n_subsampling_y + 2)).astype(int)[1:-1] if self.spike_threshold is not None: # take also one line below and on above for each line # to provide appropriate margin self.sampling_y = np.zeros([3 * len(target_sampling_y)], "i") self.sampling_y[0::3] = np.maximum(0, target_sampling_y - 1) self.sampling_y[2::3] = np.minimum(self.sy - 1, target_sampling_y + 1) self.sampling_y[1::3] = target_sampling_y self.ccd_correction = CCDFilter((len(self.sampling_y), self.sx), median_clip_thresh=self.spike_threshold) else: self.sampling_y = target_sampling_y self.nproj = self.dataset_info.n_angles my_condition = np.less(self.unwrapped_rotation_angles + np.pi, self.angle_max) * np.less( self.unwrapped_rotation_angles, self.angle_min + useful_span ) possibly_probed_angles = self.unwrapped_rotation_angles[my_condition] possibly_probed_indices = np.arange(len(self.unwrapped_rotation_angles))[my_condition] self.dproj = round(len(possibly_probed_angles) / np.rad2deg(useful_span) * self.theta_interval) self.probed_angles = possibly_probed_angles[:: self.dproj] self.probed_indices = possibly_probed_indices[:: self.dproj] self.absolute_indices = sorted(self.dataset_info.projections.keys()) my_flats = load_images_from_dataurl_dict(self.dataset_info.flats) if my_flats is not None and len(list(my_flats.keys())): self.use_flat = True self.flatfield = FlatFieldDataUrls( (len(self.absolute_indices), self.sy, self.sx), self.dataset_info.flats, self.dataset_info.darks, radios_indices=self.absolute_indices, dtype=np.float64, ) else: self.use_flat = False self.sx, self.sy = self.dataset_info.radio_dims self.mlog = Log((1,) + (self.sy, self.sx), clip_min=1e-6, clip_max=10.0) self.rcor_abs = round(self.sx / 2.0) self.cor_acc = round(self.sx / 2.0) self.nprobed = len(self.probed_angles) # initialize sinograms and radios arrays self.sino = np.zeros([2 * self.nprobed * n_subsampling_y, (self.sx - 1) * self.ovs + 1], "f") self._loaded = False self.high_pass = self.cor_options["high_pass"] img_filter = fourier_filters.get_bandpass_filter( (self.sino.shape[0] // 2, self.sino.shape[1]), cutoff_lowpass=self.cor_options["low_pass"] * self.ovs, cutoff_highpass=self.high_pass * self.ovs, use_rfft=False, # rfft changes the image dimensions lenghts to even if odd data_type=np.float64, ) # we are interested in filtering only along the x dimension only img_filter[:] = img_filter[0] self.img_filter = img_filter def _oversample(self, radio): """oversampling in the horizontal direction""" if self.ovs == 1: return radio result = np.zeros([radio.shape[0], (radio.shape[1] - 1) * self.ovs + 1], "f") result[:, :: self.ovs] = radio for i in range(1, self.ovs): f = i / self.ovs result[:, i :: self.ovs] = (1 - f) * result[:, 0 : -self.ovs : self.ovs] + f * result[ :, self.ovs :: self.ovs ] return result def _get_cor_options(self, cor_options): default_dict = self._default_cor_options.copy() if self.dataset_info.is_halftomo: 
default_dict["side"] = "right" if cor_options is None or cor_options == "": cor_options = {} if isinstance(cor_options, str): try: cor_options = extract_parameters(cor_options, sep=";") except Exception as exc: msg = "Could not extract parameters from cor_options: %s" % (str(exc)) self.logger.fatal(msg) raise ValueError(msg) default_dict.update(cor_options) cor_options = default_dict self.cor_options = cor_options def get_radio(self, image_num): # radio_dataset_idx = self.absolute_indices[image_num] radio_dataset_idx = image_num data_url = self.dataset_info.projections[radio_dataset_idx] radio = get_data(data_url).astype(np.float64) if self.use_flat: self.flatfield.normalize_single_radio(radio, radio_dataset_idx, dtype=radio.dtype) if self.take_log: self.mlog.take_logarithm(radio) radio = radio[self.sampling_y] if self.spike_threshold is not None: self.ccd_correction.median_clip_correction(radio, output=radio) radio = radio[1::3] return radio def get_sino(self, reload=False): """ Build sinogram (composite image) from the radio files """ if self._loaded and not reload: return self.sino sorting_indexes = np.argsort(self.unwrapped_rotation_angles) sorted_all_angles = self.unwrapped_rotation_angles[sorting_indexes] sorted_angle_indexes = np.arange(len(self.unwrapped_rotation_angles))[sorting_indexes] irad = 0 for prob_a, prob_i in zip(self.probed_angles, self.probed_indices): radio1 = self.get_radio(self.absolute_indices[prob_i]) other_angle = prob_a + np.pi insertion_point = np.searchsorted(sorted_all_angles, other_angle) if insertion_point > 0 and insertion_point < len(sorted_all_angles): other_i_l = sorted_angle_indexes[insertion_point - 1] other_i_h = sorted_angle_indexes[insertion_point] radio_l = self.get_radio(self.absolute_indices[other_i_l]) radio_h = self.get_radio(self.absolute_indices[other_i_h]) f = (other_angle - sorted_all_angles[insertion_point - 1]) / ( sorted_all_angles[insertion_point] - sorted_all_angles[insertion_point - 1] ) radio2 = (1 - f) * radio_l + f * radio_h else: if insertion_point == 0: other_i = sorted_angle_indexes[0] elif insertion_point == len(sorted_all_angles): other_i = sorted_angle_indexes[insertion_point - 1] radio2 = self.get_radio(self.absolute_indices[other_i]) self.sino[irad : irad + radio1.shape[0], :] = self._oversample(radio1) self.sino[ irad + self.nprobed * radio1.shape[0] : irad + self.nprobed * radio1.shape[0] + radio1.shape[0], : ] = self._oversample(radio2) irad = irad + radio1.shape[0] self.sino[np.isnan(self.sino)] = 0.0001 # ? 
return self.sino def find_cor(self, reload=False): self.logger.info("Estimating center of rotation") self.logger.debug("%s.find_shift(%s)" % (self.__class__.__name__, self.cor_options)) self.sinogram = self.get_sino(reload=reload) dim_v, dim_h = self.sinogram.shape assert dim_v % 2 == 0, " this should not happen " dim_v = dim_v // 2 radio1 = self.sinogram[:dim_v] radio2 = self.sinogram[dim_v:] orig_sy, orig_ovsd_sx = radio1.shape radio1 = scipy.fft.ifftn( scipy.fft.fftn(radio1, axes=(-2, -1)) * self.img_filter, axes=(-2, -1) ).real # TODO: convolve only along x radio2 = scipy.fft.ifftn( scipy.fft.fftn(radio2, axes=(-2, -1)) * self.img_filter, axes=(-2, -1) ).real # TODO: convolve only along x tmp_sy, ovsd_sx = radio1.shape assert orig_sy == tmp_sy and orig_ovsd_sx == ovsd_sx, "this should not happen" if self.cor_options["side"] == "center": overlap_min = max(round(ovsd_sx - ovsd_sx / 3), 4) overlap_max = min(round(ovsd_sx + ovsd_sx / 3), 2 * ovsd_sx - 4) elif self.cor_options["side"] == "right": overlap_min = max(4, self.ovs * self.high_pass * 3) overlap_max = ovsd_sx elif self.cor_options["side"] == "left": overlap_min = ovsd_sx overlap_max = min(2 * ovsd_sx - 4, 2 * ovsd_sx - self.ovs * self.ovs * self.high_pass * 3) elif self.cor_options["side"] == "all": overlap_min = max(4, self.ovs * self.high_pass * 3) overlap_max = min(2 * ovsd_sx - 4, 2 * ovsd_sx - self.ovs * self.ovs * self.high_pass * 3) elif self.cor_options["side"] == "near": near_pos = self.cor_options["near_pos"] near_width = self.cor_options["near_width"] overlap_min = max(4, ovsd_sx - 2 * self.ovs * (near_pos + near_width)) overlap_max = min(2 * ovsd_sx - 4, ovsd_sx - 2 * self.ovs * (near_pos - near_width)) else: message = f""" The cor option "side" can only take one of the values ["center", "right", "left", "all", "near"]. But it has the value "{self.cor_options["side"]}" instead """ raise ValueError(message) if overlap_min > overlap_max: message = f""" There is no safe search range in find_cor once the margins corresponding to the high_pass filter are discarded. 
Try reducing the low_pass parameter in cor_options """ raise ValueError(message) self.logger.info( "looking for overlap from min %.2f and max %.2f\n" % (overlap_min / self.ovs, overlap_max / self.ovs) ) best_overlap = overlap_min best_value = np.inf for z in range(int(overlap_min), int(overlap_max) + 1): if z <= ovsd_sx: my_z = z my_radio1 = radio1 my_radio2 = radio2 else: my_z = ovsd_sx - (z - ovsd_sx) my_radio1 = np.fliplr(radio1) my_radio2 = np.fliplr(radio2) common_left = np.fliplr(my_radio1[:, ovsd_sx - my_z :])[:, : -(self.ovs * self.high_pass * 2)] common_right = my_radio2[:, ovsd_sx - my_z : -(self.ovs * self.high_pass * 2)] common = common_right - common_left value = np.linalg.norm(common) norm_diff2 = value * value if common_right.size == 0: continue norm_right = np.linalg.norm(common_right) norm_left = np.linalg.norm(common_left) value = norm_diff2 / (norm_right * norm_left) min_value = min(best_value, value) if min_value == value: best_overlap = z best_value = min_value self.logger.debug( "testing an overlap of %.2f pixels, actual best overlap is %.2f pixels over %d\r" % (z / self.ovs, best_overlap / self.ovs, ovsd_sx / self.ovs), ) offset = (ovsd_sx - best_overlap) / self.ovs / 2 cor_abs = (self.sx - 1) / 2 + offset return cor_abs # alias CompositeCOREstimator = CompositeCORFinder # Some heavily inelegant things going on here def get_default_kwargs(func): params = inspect.signature(func).parameters res = {} for param_name, param in params.items(): if param.default != inspect._empty: res[param_name] = param.default return res def update_func_kwargs(func, options): res_options = get_default_kwargs(func) for option_name, option_val in options.items(): if option_name in res_options: res_options[option_name] = option_val return res_options def get_class_name(class_object): return str(class_object).split(".")[-1].strip(">").strip("'").strip('"') class DetectorTiltEstimator: """ Helper class for detector tilt estimation. It automatically chooses the right radios and performs flat-field. """ default_tilt_method = "1d-correlation" # Given a tilt angle "a", the maximum deviation caused by the tilt (in pixels) is # N/2 * |sin(a)| where N is the number of pixels # We ignore tilts causing less than 0.25 pixel deviation: N/2*|sin(a)| < tilt_threshold tilt_threshold = 0.25 def __init__(self, dataset_info, do_flatfield=True, logger=None, autotilt_options=None): """ Initialize a detector tilt estimator helper. Parameters ---------- dataset_info: `dataset_info` object Data structure with the dataset information. do_flatfield: bool, optional Whether to perform flat field on radios. logger: `Logger` object, optional Logger object autotilt_options: dict, optional named arguments to pass to the detector tilt estimator class. 
""" self._set_params(dataset_info, do_flatfield, logger, autotilt_options) self.radios, self.radios_indices = get_0_180_radios(dataset_info, return_indices=True) self._init_flatfield() self._apply_flatfield() def _set_params(self, dataset_info, do_flatfield, logger, autotilt_options): self.dataset_info = dataset_info self.do_flatfield = bool(do_flatfield) self.logger = LoggerOrPrint(logger) self._get_autotilt_options(autotilt_options) def _init_flatfield(self): if not (self.do_flatfield): return self.flatfield = FlatFieldDataUrls( self.radios.shape, flats=self.dataset_info.flats, darks=self.dataset_info.darks, radios_indices=self.radios_indices, interpolation="linear", convert_float=True, ) def _apply_flatfield(self): if not (self.do_flatfield): return self.flatfield.normalize_radios(self.radios) def _get_autotilt_options(self, autotilt_options): if autotilt_options is None: self.autotilt_options = None return try: autotilt_options = extract_parameters(autotilt_options) except Exception as exc: msg = "Could not extract parameters from autotilt_options: %s" % (str(exc)) self.logger.fatal(msg) raise ValueError(msg) self.autotilt_options = autotilt_options if "threshold" in autotilt_options: self.tilt_threshold = autotilt_options.pop("threshold") def find_tilt(self, tilt_method=None): """ Find the detector tilt. Parameters ---------- tilt_method: str, optional Which tilt estimation method to use. """ if tilt_method is None: tilt_method = self.default_tilt_method check_supported(tilt_method, set(tilt_methods.values()), "tilt estimation method") self.logger.info("Estimating detector tilt angle") autotilt_params = { "roi_yxhw": None, "median_filt_shape": None, "padding_mode": None, "peak_fit_radius": 1, "high_pass": None, "low_pass": None, } autotilt_params.update(self.autotilt_options or {}) self.logger.debug("%s(%s)" % ("CameraTilt", str(autotilt_params))) tilt_calc = CameraTilt() tilt_cor_position, camera_tilt = tilt_calc.compute_angle( self.radios[0], np.fliplr(self.radios[1]), method=tilt_method, **autotilt_params ) self.logger.info("Estimated detector tilt angle: %f degrees" % camera_tilt) # Ignore too small tilts max_deviation = np.max(self.dataset_info.radio_dims) * np.abs(np.sin(np.deg2rad(camera_tilt))) if self.dataset_info.is_halftomo: max_deviation *= 2 if max_deviation < self.tilt_threshold: self.logger.info( "Estimated tilt angle (%.3f degrees) results in %.2f maximum pixels shift, which is below threshold (%.2f pixel). Ignoring the tilt, no correction will be done." % (camera_tilt, max_deviation, self.tilt_threshold) ) camera_tilt = None return camera_tilt # alias TiltFinder = DetectorTiltEstimator ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/pipeline/fallback_utils.py0000644000175000017500000001370600000000000020316 0ustar00pierrepierre""" This module is meant to contain classes which are in the process of being superseed by new classes depending on recent packages with fast development cycles in order to be able to fall-back in two cases : -- the new packages, or one of their parts, break from one version to another. -- For parts of Nabu which need some extra time to adapt. """ from ..resources.logger import LoggerOrPrint from ..utils import check_supported from ..io.writer import Writers, LegacyNXProcessWriter from ..resources.utils import is_hdf5_extension from os import path, mkdir from .params import files_formats class WriterConfigurator: """No dependency on tomoscan for this class. 
The new class would be WriterManager which depend on tomoscan.""" _overwrite_warned = False def __init__( self, output_dir, file_prefix, file_format="hdf5", overwrite=False, start_index=None, logger=None, nx_info=None, write_histogram=False, histogram_entry="entry", writer_options=None, extra_options=None, ): """ Create a Writer from a set of parameters. Parameters ---------- output_dir: str Directory where the file(s) will be written. file_prefix: str File prefix (without leading path) start_index: int, optional Index to start the files numbering (filename_0123.ext). Default is 0. Ignored for HDF5 extension. logger: nabu.resources.logger.Logger, optional Logger object nx_info: dict, optional Dictionary containing the nexus information. write_histogram: bool, optional Whether to also write a histogram of data. If set to True, it will configure an additional "writer". histogram_entry: str, optional Name of the HDF5 entry for the output histogram file, if write_histogram is True. Ignored if the output format is already HDF5 : in this case, nx_info["entry"] is taken. writer_options: dict, optional Other advanced options to pass to Writer class. """ self.logger = LoggerOrPrint(logger) self.start_index = start_index self.write_histogram = write_histogram self.overwrite = overwrite writer_options = writer_options or {} self.extra_options = extra_options or {} check_supported(file_format, list(Writers.keys()), "output file format") self._set_output_dir(output_dir) self._set_file_name(file_prefix, file_format) # Init Writer writer_cls = Writers[file_format] writer_args = [self.fname] writer_kwargs = self._get_initial_writer_kwarg() self._writer_exec_args = [] self._writer_exec_kwargs = {} self._is_hdf5_output = is_hdf5_extension(file_format) if self._is_hdf5_output: writer_kwargs["entry"] = nx_info["entry"] writer_kwargs["filemode"] = "a" writer_kwargs["overwrite"] = overwrite self._writer_exec_args.append(nx_info["process_name"]) self._writer_exec_kwargs["processing_index"] = nx_info["processing_index"] self._writer_exec_kwargs["config"] = nx_info["config"] else: writer_kwargs["start_index"] = self.start_index if writer_options.get("tiff_single_file", False) and "tif" in file_format: do_append = writer_options.get("single_tiff_initialized", False) writer_kwargs.update({"multiframe": True, "append": do_append}) if files_formats.get(file_format, None) == "jp2": cratios = self.extra_options.get("jpeg2000_compression_ratio", None) if cratios is not None: cratios = [cratios] writer_kwargs["cratios"] = cratios writer_kwargs["float_clip_values"] = self.extra_options.get("float_clip_values", None) self.writer = writer_cls(*writer_args, **writer_kwargs) if self.write_histogram and not (self._is_hdf5_output): self._init_separate_histogram_writer(histogram_entry) def _get_initial_writer_kwarg(self): return {} def _set_output_dir(self, output_dir): self.output_dir = output_dir if path.exists(self.output_dir): if not (path.isdir(self.output_dir)): raise ValueError( "Unable to create directory %s: already exists and is not a directory" % self.output_dir ) else: self.logger.debug("Creating directory %s" % self.output_dir) mkdir(self.output_dir) def _set_file_name(self, file_prefix, file_format): self.file_prefix = file_prefix self.file_format = file_format self.fname = path.join(self.output_dir, file_prefix + "." + file_format) if path.exists(self.fname): err = "File already exists: %s" % self.fname if self.overwrite: if not (WriterConfigurator._overwrite_warned): self.logger.warning(err + ". 
It will be overwritten as requested in configuration") WriterConfigurator._overwrite_warned = True else: self.logger.fatal(err) raise ValueError(err) def _init_separate_histogram_writer(self, hist_entry): hist_fname = path.join(self.output_dir, "histogram_%06d.hdf5" % self.start_index) self.histogram_writer = LegacyNXProcessWriter( hist_fname, entry=hist_entry, filemode="w", overwrite=True, ) def get_histogram_writer(self): if not (self.write_histogram): return None if self._is_hdf5_output: return self.writer else: return self.histogram_writer def write_data(self, data): self.writer.write(data, *self._writer_exec_args, **self._writer_exec_kwargs) ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4647331 nabu-2023.1.1/nabu/pipeline/fullfield/0000755000175000017500000000000000000000000016724 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1628752049.0 nabu-2023.1.1/nabu/pipeline/fullfield/__init__.py0000644000175000017500000000000000000000000021023 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/pipeline/fullfield/chunked.py0000644000175000017500000010631700000000000020727 0ustar00pierrepierrefrom os import path from time import time from math import ceil import numpy as np from silx.io.url import DataUrl from ...utils import remove_items_from_list from ...resources.logger import LoggerOrPrint from ...resources.utils import extract_parameters from ...io.reader import ChunkReader, HDF5Loader from ...preproc.ccd import Log, CCDFilter from ...preproc.flatfield import FlatFieldDataUrls from ...preproc.distortion import DistortionCorrection from ...preproc.shift import VerticalShift from ...preproc.double_flatfield import DoubleFlatField from ...preproc.phase import PaganinPhaseRetrieval from ...preproc.ctf import CTFPhaseRetrieval, GeoPars from ...reconstruction.sinogram import SinoBuilder, SinoNormalization from ...misc.rotation import Rotation from ...reconstruction.rings import MunchDeringer from ...misc.unsharp import UnsharpMask from ...misc.histogram import PartialHistogram, hist_as_2Darray from ..utils import use_options, pipeline_step, get_subregion from ..datadump import DataDumpManager from ..writer import WriterManager from ..detector_distortion_provider import DetectorDistortionProvider # For now we don't have a plain python/numpy backend for reconstruction try: from ...reconstruction.fbp_opencl import Backprojector except: Backprojector = None class ChunkedPipeline: """ Pipeline for "regular" full-field tomography. Data is processed by chunks. A chunk consists in K contiguous lines of all the radios. In parallel geometry, a chunk of K radios lines gives K sinograms, and equivalently K reconstructed slices. """ backend = "numpy" FlatFieldClass = FlatFieldDataUrls DoubleFlatFieldClass = DoubleFlatField CCDCorrectionClass = CCDFilter PaganinPhaseRetrievalClass = PaganinPhaseRetrieval CTFPhaseRetrievalClass = CTFPhaseRetrieval UnsharpMaskClass = UnsharpMask ImageRotationClass = Rotation VerticalShiftClass = VerticalShift SinoDeringerClass = MunchDeringer MLogClass = Log SinoBuilderClass = SinoBuilder SinoNormalizationClass = SinoNormalization FBPClass = Backprojector ConebeamClass = None # unsupported on CPU HistogramClass = PartialHistogram _default_extra_options = {} # These steps are skipped if the reconstruction is done in two stages. 
# The first stage will skip these steps, and the second stage will perform them after merging the sinograms. _reconstruction_steps = ["sino_rings_correction", "reconstruction", "save", "histogram"] def __init__( self, process_config, chunk_shape, margin=None, logger=None, use_grouped_mode=False, extra_options=None ): """ Initialize a "Chunked" pipeline. Parameters ---------- process_config: `ProcessConfig` Process configuration. chunk_shape: tuple Shape of the chunk of data to process, in the form (n_angles, n_z, n_x). It has to account for possible cropping of the data, e.g. [:, start_z:end_z, start_x:end_x] where start_xz and/or end_xz can be other than None. margin: tuple, optional Margin to use, in the form ((up, down), (left, right)). It is used for example when performing phase retrieval or a convolution-like operation: some extra data is kept to avoid boundary issues. These boundaries are then discarded: the data volume is eventually cropped as `data[U:-D, L:-R]` where `((U, D), (L, R)) = margin` If not provided, no margin is applied. logger: `nabu.app.logger.Logger`, optional Logger class extra_options: dict, optional Advanced extra options. Notes ------ Using `margin` results in a smaller number of reconstructed slices. More specifically, if `margin = (V, H)`, then there will be `delta_z - 2*V` reconstructed slices (if the sub-region is in the middle of the volume) or `delta_z - V` reconstructed slices (if the sub-region is at the top or bottom of the volume). """ self.logger = LoggerOrPrint(logger) self._set_params(process_config, chunk_shape, extra_options, margin, use_grouped_mode) self._init_pipeline() def _set_params(self, process_config, chunk_shape, extra_options, margin, use_grouped_mode): self.process_config = process_config self.dataset_info = self.process_config.dataset_info self.processing_steps = self.process_config.processing_steps.copy() self.processing_options = self.process_config.processing_options self._set_chunk_shape(chunk_shape, use_grouped_mode) self.set_subregion(None) self._set_margin(margin) self._set_extra_options(extra_options) self._callbacks = {} self._steps_name2component = {} self._steps_component2name = {} def _set_chunk_shape(self, chunk_shape, use_grouped_mode): if len(chunk_shape) != 3: raise ValueError("Expected chunk_shape to be a tuple of length 3 in the form (n_angles, n_z, n_x)") self.chunk_shape = tuple(int(c) for c in chunk_shape) # cast to int, as numpy.int64 can make pycuda crash # TODO: sanity check (eg. compare to size of radios in dataset_info) ? # (n_a, n_z, n_x) self.radios_shape = ( ceil(self.chunk_shape[0] / self.process_config.subsampling_factor), self.chunk_shape[1] // self.process_config.binning[1], self.chunk_shape[2] // self.process_config.binning[0], ) self.n_angles = self.radios_shape[0] self.n_slices = self.radios_shape[1] self._grouped_processing = False if use_grouped_mode or self.chunk_shape[0] < len(self.process_config.rotation_angles(subsampling=False)): # TODO allow a certain tolerance in this case ? 
# Reconstruction is still possible (albeit less accurate) if delta is small self._grouped_processing = True self.logger.debug("Only a subset of angles is processed - Reconstruction will be skipped") self.processing_steps, _ = remove_items_from_list(self.processing_steps, self._reconstruction_steps) def _set_margin(self, margin): if margin is None: U, D, L, R = None, None, None, None else: ((U, D), (L, R)) = get_subregion(margin, ndim=2) # Replace "None" with zeros U, D, L, R = U or 0, D or 0, L or 0, R or 0 self.margin = ((U, D), (L, R)) self._margin_up = U self._margin_down = D self._margin_left = L self._margin_right = R self.use_margin = (U + D + L + R) > 0 self.n_recs = self.chunk_shape[1] - sum(self.margin[0]) self.radios_cropped_shape = (self.radios_shape[0], self.radios_shape[1] - U - D, self.radios_shape[2] - L - R) if self.use_margin: self.n_slices -= sum(self.margin[0]) def set_subregion(self, sub_region): """ Set the data volume sub-region to process. Note that processing margin, if any, is contained within the sub-region. Parameters ----------- sub_region: tuple Data volume sub-region, in the form ((start_a, end_a), (start_z, end_z), (start_x, end_x)) where the data volume has a layout (angles, Z, X) """ n_angles = self.dataset_info.n_angles n_x, n_z = self.dataset_info.radio_dims c_a, c_z, c_x = self.chunk_shape if sub_region is None: # By default, take the sub-region around central slice sub_region = ( (0, c_a), (n_z // 2 - c_z // 2, n_z // 2 - c_z // 2 + c_z), (n_x // 2 - c_x // 2, n_x // 2 - c_x // 2 + c_x), ) else: sub_region = get_subregion(sub_region, ndim=3) # check sub-region for i, start_end in enumerate(sub_region): start, end = start_end if start is not None and end is not None: if end - start != self.chunk_shape[i]: raise ValueError( "Invalid (start, end)=(%d, %d) for sub-region (dimension %d): chunk shape is %s, but %d-%d=%d != %d" % (start, end, i, str(self.chunk_shape), end, start, end - start, self.chunk_shape[i]) ) # self.logger.debug("Set sub-region to %s" % (str(sub_region))) self.sub_region = sub_region self._sub_region_xz = sub_region[2] + sub_region[1] def _set_extra_options(self, extra_options): self.extra_options = self._default_extra_options.copy() self.extra_options.update(extra_options or {}) # # Callbacks # def register_callback(self, step_name, callback): """ Register a callback for a pipeline processing step. Parameters ---------- step_name: str processing step name callback: callable A function. It will be executed once the processing step `step_name` is finished. The function takes only one argument: the class instance. 
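A minimal usage sketch (with a hypothetical `pipeline` instance, assuming "flatfield" is among the processing steps): `pipeline.register_callback("flatfield", lambda p: p.logger.info("flat-field done"))`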
""" if step_name not in self.processing_steps: raise ValueError("'%s' is not in processing steps %s" % (step_name, self.processing_steps)) if step_name in self._callbacks: self._callbacks[step_name].append(callback) else: self._callbacks[step_name] = [callback] # # Memory management # def _allocate_array(self, shape, dtype, name=None): return np.zeros(shape, dtype=dtype) def _allocate_recs(self, ny, nx, n_slices=None): n_slices = n_slices or self.n_slices self.recs = self._allocate_array((n_slices, ny, nx), "f", name="recs") # # Runtime attributes # @property def sub_region_xz(self): """ Return the currently processed sub-region in the form (start_x, end_x, start_z, end_z) """ return self._sub_region_xz @property def z_min(self): return self._sub_region_xz[2] @property def sino_shape(self): return self.process_config.sino_shape(binning=True, subsampling=True, after_ha=True) @property def sinos_shape(self): return (self.n_slices,) + self.sino_shape def get_slice_start_index(self): return self.z_min + self._margin_up # # Pipeline initialization # def _init_pipeline(self): self._allocate_radios() self._init_data_dump() self._init_reader() self._init_flatfield() self._init_double_flatfield() self._init_ccd_corrections() self._init_radios_rotation() self._init_phase() self._init_unsharp() self._init_radios_movements() self._init_mlog() self._init_sino_normalization() self._init_sino_rings_correction() self._init_reconstruction() self._init_histogram() self._init_writer() def _allocate_radios(self): self.radios = np.zeros(self.radios_shape, dtype=np.float32) self.data = self.radios # alias def _init_data_dump(self): self._resume_from_step = self.processing_options["read_chunk"].get("step_name", None) self.datadump_manager = DataDumpManager( self.process_config, self.sub_region, margin=self.margin, logger=self.logger ) # When using "grouped processing", sinogram has to be dumped. # If it was not specified by user, force sinogram dump # Perhaps these lines should be moved directly to DataDumpManager. if self._grouped_processing and not self.process_config.dump_sinogram: sino_dump_fname = self.process_config.get_save_steps_file("sinogram") self.datadump_manager._configure_dump("sinogram", force_dump_to_fname=sino_dump_fname) self.logger.debug("Will dump sinogram to %s" % self.datadump_manager.data_dump["sinogram"].fname) @use_options("read_chunk", "chunk_reader") def _init_reader(self): options = self.processing_options["read_chunk"] self._update_reader_configuration() process_file = options.get("process_file", None) if process_file is None: # Standard case - start pipeline from raw data if self.process_config.nabu_config["preproc"]["detector_distortion_correction"] is None: self.detector_corrector = None else: self.detector_corrector = DetectorDistortionProvider( detector_full_shape_vh=self.process_config.dataset_info.radio_dims[::-1], correction_type=self.process_config.nabu_config["preproc"]["detector_distortion_correction"], options=self.process_config.nabu_config["preproc"]["detector_distortion_correction_options"], ) # ChunkReader always take a non-subsampled dictionary "files". 
self.chunk_reader = ChunkReader( self._read_options["files"], sub_region=self.sub_region_xz, data_buffer=self.radios, pre_allocate=False, detector_corrector=self.detector_corrector, convert_float=True, binning=options["binning"], dataset_subsampling=options["dataset_subsampling"], ) else: # Resume pipeline from dumped intermediate step self.chunk_reader = HDF5Loader( process_file, options["process_h5_path"], sub_region=self.datadump_manager.get_read_dump_subregion(), data_buffer=self.radios, pre_allocate=False, ) self._resume_from_step = options["step_name"] self.logger.debug( "Load subregion %s from file %s" % (str(self.chunk_reader.sub_region), self.chunk_reader.fname) ) def _update_reader_configuration(self): """ Modify self.processing_options["read_chunk"] to select a subset of the files, if needed (i.e when processing only a subset of the images stack) """ self._read_options = self.processing_options["read_chunk"].copy() if self.n_angles == self.process_config.n_angles(subsampling=True): # Nothing to do if the full angular range is processed in one shot return if self._resume_from_step is not None: if self._resume_from_step == "sinogram": msg = "It makes no sense to use 'grouped processing' when resuming from sinogram" self.logger.fatal(msg) raise ValueError(msg) # Nothing to do if we resume the processing from a given step return input_data_files = {} files_indices = sorted(self._read_options["files"].keys()) angle_idx_start, angle_idx_end = self.sub_region[0] for i in range(angle_idx_start, angle_idx_end): idx = files_indices[i] input_data_files[idx] = self._read_options["files"][idx] self._read_options["files"] = input_data_files @use_options("flatfield", "flatfield") def _init_flatfield(self): self._ff_options = self.processing_options["flatfield"].copy() # Use chunk_reader.files instead of process_config.projs_indices(subsampling=True), because # chunk_reader might read only a subset of the files (in "grouped mode") self._ff_options["projs_indices"] = list(self.chunk_reader.files_subsampled.keys()) if self._ff_options.get("normalize_srcurrent", False): a_start_idx, a_end_idx = self.sub_region[0] subs = self.process_config.subsampling_factor self._ff_options["radios_srcurrent"] = self._ff_options["radios_srcurrent"][a_start_idx:a_end_idx:subs] distortion_correction = None if self._ff_options["do_flat_distortion"]: self.logger.info("Flats distortion correction will be applied") self.FlatFieldClass = FlatFieldDataUrls # no GPU implementation available, force this backend estimation_kwargs = {} estimation_kwargs.update(self._ff_options["flat_distortion_params"]) estimation_kwargs["logger"] = self.logger distortion_correction = DistortionCorrection( estimation_method="fft-correlation", estimation_kwargs=estimation_kwargs, correction_method="interpn" ) # FlatField parameter "radios_indices" must account for subsampling self.flatfield = self.FlatFieldClass( self.radios_shape, flats=self.dataset_info.flats, darks=self.dataset_info.darks, radios_indices=self._ff_options["projs_indices"], interpolation="linear", distortion_correction=distortion_correction, sub_region=self.sub_region_xz, detector_corrector=self.detector_corrector, binning=self._ff_options["binning"], radios_srcurrent=self._ff_options["radios_srcurrent"], flats_srcurrent=self._ff_options["flats_srcurrent"], convert_float=True, ) @use_options("double_flatfield", "double_flatfield") def _init_double_flatfield(self): options = self.processing_options["double_flatfield"] avg_is_on_log = options["sigma"] is not None result_url = 
None if options["processes_file"] not in (None, ""): result_url = DataUrl( file_path=options["processes_file"], data_path=(self.dataset_info.hdf5_entry or "entry") + "/double_flatfield/results/data", ) self.logger.info("Loading double flatfield from %s" % result_url.file_path()) if (self.n_angles < self.process_config.n_angles(subsampling=True)) and result_url is None: raise ValueError( "Cannot use double-flatfield when processing subset of radios. Please use the 'nabu-double-flatfield' command" ) self.double_flatfield = self.DoubleFlatFieldClass( self.radios_shape, result_url=result_url, sub_region=self.sub_region_xz, detector_corrector=self.detector_corrector, input_is_mlog=False, output_is_mlog=False, average_is_on_log=avg_is_on_log, sigma_filter=options["sigma"], ) @use_options("ccd_correction", "ccd_correction") def _init_ccd_corrections(self): options = self.processing_options["ccd_correction"] self.ccd_correction = self.CCDCorrectionClass( self.radios_shape[1:], median_clip_thresh=options["median_clip_thresh"] ) @use_options("rotate_projections", "projs_rot") def _init_radios_rotation(self): options = self.processing_options["rotate_projections"] center = options["center"] if center is None: nx, ny = self.radios_shape[1:][::-1] # after binning center = (nx / 2 - 0.5, ny / 2 - 0.5) center = (center[0], center[1] - self.z_min) self.projs_rot = self.ImageRotationClass( self.radios_shape[1:], options["angle"], center=center, mode="edge", reshape=False ) self._tmp_rotated_radio = self._allocate_array(self.radios_shape[1:], "f", name="tmp_rotated_radio") @use_options("radios_movements", "radios_movements") def _init_radios_movements(self): options = self.processing_options["radios_movements"] self._vertical_shifts = options["translation_movements"][:, 1] self.radios_movements = self.VerticalShiftClass(self.radios.shape, self._vertical_shifts) @use_options("phase", "phase_retrieval") def _init_phase(self): options = self.processing_options["phase"] if options["method"] == "CTF": translations_vh = getattr(self.dataset_info, "ctf_translations", None) geo_pars_params = options["ctf_geo_pars"].copy() geo_pars_params["logger"] = self.logger geo_pars = GeoPars(**geo_pars_params) self.phase_retrieval = self.CTFPhaseRetrievalClass( self.radios_shape[1:], geo_pars, options["delta_beta"], lim1=options["ctf_lim1"], lim2=options["ctf_lim2"], logger=self.logger, fftw_num_threads=None, # TODO tune in advanced params of nabu config file use_rfft=True, normalize_by_mean=options["ctf_normalize_by_mean"], translation_vh=translations_vh, ) else: self.phase_retrieval = self.PaganinPhaseRetrievalClass( self.radios_shape[1:], distance=options["distance_m"], energy=options["energy_kev"], delta_beta=options["delta_beta"], pixel_size=options["pixel_size_m"], padding=options["padding_type"], # TODO tune in advanced params of nabu config file fftw_num_threads=None, ) if self.phase_retrieval.use_fftw: self.logger.debug( "%s using FFTW with %d threads" % (self.phase_retrieval.__class__.__name__, self.phase_retrieval.fftw.num_threads) ) @use_options("unsharp_mask", "unsharp_mask") def _init_unsharp(self): options = self.processing_options["unsharp_mask"] self.unsharp_mask = self.UnsharpMaskClass( self.radios_shape[1:], options["unsharp_sigma"], options["unsharp_coeff"], mode="reflect", method=options["unsharp_method"], ) @use_options("take_log", "mlog") def _init_mlog(self): options = self.processing_options["take_log"] self.mlog = self.MLogClass( self.radios_shape, clip_min=options["log_min_clip"], 
clip_max=options["log_max_clip"] ) @use_options("sino_normalization", "sino_normalization") def _init_sino_normalization(self): options = self.processing_options["sino_normalization"] self.sino_normalization = self.SinoNormalizationClass( kind=options["method"], radios_shape=self.radios_cropped_shape, normalization_array=options["normalization_array"], ) @use_options("sino_rings_correction", "sino_deringer") def _init_sino_rings_correction(self): options = self.processing_options["sino_rings_correction"] fw_params = extract_parameters(options["user_options"]) fw_sigma = fw_params.pop("sigma", 1.0) self.sino_deringer = self.SinoDeringerClass(fw_sigma, sinos_shape=self.sinos_shape, **fw_params) @use_options("reconstruction", "reconstruction") def _init_reconstruction(self): options = self.processing_options["reconstruction"] if options["method"] == "FBP" and self.FBPClass is None: raise ValueError("No usable FBP module was found") if options["method"] == "cone" and self.ConebeamClass is None: raise ValueError("No usable cone-beam module was found") if options["method"] == "FBP": n_slices = self.n_slices radios_shape_for_sino_builder = self.radios_cropped_shape self.reconstruction = self.FBPClass( self.sinos_shape[1:], angles=options["angles"], rot_center=options["fbp_rotation_axis_position"], filter_name=options["fbp_filter_type"], slice_roi=self.process_config.rec_roi, padding_mode=options["padding_type"], extra_options={ "scale_factor": 1.0 / options["pixel_size_cm"], "axis_correction": options["axis_correction"], "centered_axis": options["centered_axis"], "clip_outer_circle": options["clip_outer_circle"], "filter_cutoff": options["fbp_filter_cutoff"], }, ) if options["fbp_filter_type"] is None: self.reconstruction.fbp = self.reconstruction.backproj if options["method"] == "cone": radios_shape_for_sino_builder = self.radios_shape n_slices = self.n_slices + sum(self.margin[0]) # For numerical stability, normalize all lengths with respect to detector pixel size pixel_size_m = self.dataset_info.pixel_size * 1e-6 source_sample_dist = options["source_sample_dist"] / pixel_size_m sample_detector_dist = options["sample_detector_dist"] / pixel_size_m self.reconstruction = self.ConebeamClass( # pylint: disable=E1102 (self.radios_shape[1],) + self.sino_shape, source_sample_dist, sample_detector_dist, angles=options["angles"], # TODO one center for each angle to handle "x translations" rot_center=options["rotation_axis_position"], pixel_size=1, ) self.sino_builder = self.SinoBuilderClass( radios_shape=radios_shape_for_sino_builder, rot_center=options["rotation_axis_position"], halftomo=options["enable_halftomo"], angles=options["angles"], ) self._allocate_recs(*self.process_config.rec_shape, n_slices=n_slices) if options["method"] == "cone": self.sinos = self._allocate_array(self.sino_builder.output_shape, "f", name="sinos") @use_options("histogram", "histogram") def _init_histogram(self): options = self.processing_options["histogram"] self.histogram = self.HistogramClass(method="fixed_bins_number", num_bins=options["histogram_bins"]) @use_options("save", "writer") def _init_writer(self): options = self.processing_options["save"] metadata = { "process_name": "reconstruction", "processing_index": 0, # TODO this one takes too much time to write, not useful for partial files # "processing_options": self.processing_options, # "nabu_config": self.process_config.nabu_config, "entry": getattr(self.dataset_info.dataset_scanner, "entry", "entry"), } writer_extra_options = { "jpeg2000_compression_ratio": 
options["jpeg2000_compression_ratio"], "float_clip_values": options["float_clip_values"], "tiff_single_file": options.get("tiff_single_file", False), "single_output_file_initialized": getattr(self.process_config, "single_output_file_initialized", False), } self.writer = WriterManager( options["location"], options["file_prefix"], file_format=options["file_format"], overwrite=options["overwrite"], start_index=self.get_slice_start_index(), logger=self.logger, metadata=metadata, histogram=("histogram" in self.processing_steps), extra_options=writer_extra_options, ) # # Pipeline execution # @pipeline_step("chunk_reader", "Reading data") def _read_data(self): self.logger.debug("Region = %s" % str(self.sub_region)) t0 = time() self.chunk_reader.load_data() el = time() - t0 shp = self.chunk_reader.data.shape self.logger.info("Read subvolume %s in %.2f s" % (str(shp), el)) @pipeline_step("flatfield", "Applying flat-field") def _flatfield(self): self.flatfield.normalize_radios(self.radios) @pipeline_step("double_flatfield", "Applying double flat-field") def _double_flatfield(self, radios=None): if radios is None: radios = self.radios self.double_flatfield.apply_double_flatfield(radios) @pipeline_step("ccd_correction", "Applying CCD corrections") def _ccd_corrections(self, radios=None): if radios is None: radios = self.radios _tmp_radio = self._allocate_array(radios.shape[1:], "f", name="tmp_ccdcorr_radio") for i in range(radios.shape[0]): self.ccd_correction.median_clip_correction(radios[i], output=_tmp_radio) radios[i][:] = _tmp_radio[:] @pipeline_step("projs_rot", "Rotating projections") def _rotate_projections(self, radios=None): if radios is None: radios = self.radios tmp_radio = self._tmp_rotated_radio for i in range(radios.shape[0]): self.projs_rot.rotate(radios[i], output=tmp_radio) radios[i][:] = tmp_radio[:] @pipeline_step("phase_retrieval", "Performing phase retrieval") def _retrieve_phase(self): for i in range(self.radios.shape[0]): self.phase_retrieval.retrieve_phase(self.radios[i], output=self.radios[i]) @pipeline_step("unsharp_mask", "Performing unsharp mask") def _apply_unsharp(self): for i in range(self.radios.shape[0]): self.radios[i] = self.unsharp_mask.unsharp(self.radios[i]) @pipeline_step("mlog", "Taking logarithm") def _take_log(self): self.mlog.take_logarithm(self.radios) @pipeline_step("radios_movements", "Applying radios movements") def _radios_movements(self, radios=None): if radios is None: radios = self.radios self.radios_movements.apply_vertical_shifts(radios, list(range(radios.shape[0]))) def _crop_radios(self): if self.use_margin: self._orig_radios = self.radios if self.processing_options.get("reconstruction", {}).get("method", None) == "cone": return ((U, D), (L, R)) = self.margin self.logger.debug( "Cropping radios from %s to %s" % (str(self.radios_shape), str(self.radios_cropped_shape)) ) U, D, L, R = U or None, -D or None, L or None, -R or None self.radios = self.radios[:, U:D, L:R] # view @pipeline_step("sino_normalization", "Normalizing sinograms") def _normalize_sinos(self, radios=None): if radios is None: radios = self.radios sinos = radios.transpose((1, 0, 2)) self.sino_normalization.normalize(sinos) def _dump_sinogram(self): if self.datadump_manager is not None: self.datadump_manager.dump_data_to_file("sinogram", self.radios) @pipeline_step("sino_deringer", "Removing rings on sinograms") def _destripe_sinos(self): sinos = np.rollaxis(self.radios, 1, 0) # view self.sino_deringer.remove_rings(sinos) # TODO check it works with non-contiguous view 
@pipeline_step("reconstruction", "Reconstruction") def _reconstruct(self): """ Reconstruction for parallel geometry. For each target slice: get the corresponding sinogram, apply some processing, then reconstruct """ if self.processing_options["reconstruction"]["method"] == "cone": self._reconstruct_cone() return for i in range(self.n_slices): sino = self.sino_builder.get_sino(self.radios, i) if self.sino_deringer is not None: self.sino_deringer.remove_rings(sino) self.reconstruction.fbp(sino, output=self.recs[i]) def _reconstruct_cone(self): """ This reconstructs the entire sinograms stack at once """ self.sino_builder.get_sinos(self.radios, output=self.sinos) z_min, z_max = self.sub_region_xz[2:] n_z = self.process_config.radio_shape(binning=True)[0] self.reconstruction.reconstruct( # pylint: disable=E1101 self.sinos, output=self.recs, relative_z_position=((z_min + z_max) / self.process_config.binning_z / 2) - n_z / 2, ) @pipeline_step("histogram", "Computing histogram") def _compute_histogram(self, data=None): if data is None: data = self.recs self.recs_histogram = self.histogram.compute_histogram(data) @pipeline_step("writer", "Saving data") def _write_data(self, data=None): if data is None: data = self.recs self.writer.write_data(data) self.logger.info("Wrote %s" % self.writer.fname) self._write_histogram() self.process_config.single_output_file_initialized = True def _write_histogram(self): if "histogram" not in self.processing_steps: return self.logger.info("Saving histogram") self.writer.write_histogram( hist_as_2Darray(self.recs_histogram), processing_index=1, config={ "file": path.basename(self.writer.fname), "bins": self.processing_options["histogram"]["histogram_bins"], }, ) def _process_finalize(self): if self.use_margin: self.radios = self._orig_radios def __repr__(self): res = "%s(%s, margin=%s)" % (self.__class__.__name__, str(self.chunk_shape), str(self.margin)) binning = self.process_config.binning subsampling = self.process_config.subsampling_factor if binning != (1, 1) or subsampling > 1: if binning != (1, 1): res += "\nImages binning: %s" % (str(binning)) if subsampling: res += "\nAngles subsampling: %d" % subsampling res += "\nRadios chunk: %s ---> %s" % (self.chunk_shape, self.radios_shape) if self.use_margin: res += "\nMargin: %s" % (str(self.margin)) res += "\nRadios chunk: %s ---> %s" % (str(self.radios_shape), str(self.radios_cropped_shape)) res += "\nCurrent subregion: %s" % (str(self.sub_region)) for step_name in self.processing_steps: res += "\n- %s" % (step_name) return res def _process_chunk(self): self._flatfield() self._double_flatfield() self._ccd_corrections() self._rotate_projections() self._retrieve_phase() self._apply_unsharp() self._take_log() self._radios_movements() self._crop_radios() self._normalize_sinos() self._dump_sinogram() # self._destripe_sinos() self._reconstruct() self._compute_histogram() self._write_data() self._process_finalize() def _reset_reader_subregion(self): if self._resume_from_step is None: # Normal mode - read data from raw radios self.chunk_reader._set_subregion(self.sub_region_xz) self.chunk_reader._init_reader() self.chunk_reader._loaded = False else: # Resume from a checkpoint. 
In this case, we have to re-initialize "datadump manager" # sooner to configure start_xyz, end_xyz self._init_data_dump() self.chunk_reader._set_subregion(self.datadump_manager.get_read_dump_subregion()) self.chunk_reader._loaded = False if self._grouped_processing: self._update_reader_configuration() self.chunk_reader._set_files(self._read_options["files"]) def _reset_sub_region(self, sub_region): self.set_subregion(sub_region) # When sub_region is changed, all components involving files reading have to be updated self._reset_reader_subregion() self._init_flatfield() # reset flatfield self._init_writer() self._init_double_flatfield() self._init_data_dump() def process_chunk(self, sub_region): self._reset_sub_region(sub_region) self._read_data() self._process_chunk() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/pipeline/fullfield/chunked_cuda.py0000644000175000017500000001243500000000000021720 0ustar00pierrepierrefrom ...preproc.ccd_cuda import CudaLog, CudaCCDFilter from ...preproc.flatfield_cuda import CudaFlatFieldDataUrls from ...preproc.shift_cuda import CudaVerticalShift from ...preproc.double_flatfield_cuda import CudaDoubleFlatField from ...preproc.phase_cuda import CudaPaganinPhaseRetrieval from ...preproc.ctf_cuda import CudaCTFPhaseRetrieval from ...reconstruction.sinogram_cuda import CudaSinoBuilder, CudaSinoNormalization from ...reconstruction.rings_cuda import CudaMunchDeringer from ...misc.unsharp_cuda import CudaUnsharpMask from ...misc.rotation_cuda import CudaRotation from ...misc.histogram_cuda import CudaPartialHistogram from ...reconstruction.fbp import Backprojector from ...reconstruction.cone import __have_astra__, ConebeamReconstructor from ...cuda.utils import get_cuda_context, __has_pycuda__, __pycuda_error_msg__ from ..utils import pipeline_step from .chunked import ChunkedPipeline if __has_pycuda__: import pycuda.gpuarray as garray if not (__have_astra__): ConebeamReconstructor = None class CudaChunkedPipeline(ChunkedPipeline): """ Cuda backend of ChunkedPipeline """ backend = "cuda" FlatFieldClass = CudaFlatFieldDataUrls DoubleFlatFieldClass = CudaDoubleFlatField CCDCorrectionClass = CudaCCDFilter PaganinPhaseRetrievalClass = CudaPaganinPhaseRetrieval CTFPhaseRetrievalClass = CudaCTFPhaseRetrieval UnsharpMaskClass = CudaUnsharpMask ImageRotationClass = CudaRotation VerticalShiftClass = CudaVerticalShift SinoDeringerClass = CudaMunchDeringer MLogClass = CudaLog SinoBuilderClass = CudaSinoBuilder FBPClass = Backprojector ConebeamClass = ConebeamReconstructor HistogramClass = CudaPartialHistogram SinoNormalizationClass = CudaSinoNormalization def __init__( self, process_config, chunk_shape, logger=None, extra_options=None, margin=None, use_grouped_mode=False, cuda_options=None, ): self._init_cuda(cuda_options) super().__init__( process_config, chunk_shape, logger=logger, extra_options=extra_options, use_grouped_mode=use_grouped_mode, margin=margin, ) self._allocate_array(self.radios.shape, "f", name="radios") self._determine_when_to_transfer_data_on_gpu() def _determine_when_to_transfer_data_on_gpu(self): # Decide when to transfer data to GPU. Normally it's right after reading the data, # But sometimes a part of the processing is done on CPU. 
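        # (For instance, when flat-field distortion correction is enabled, the flat-field
        # step runs on the host, so the host -> GPU transfer is deferred until right after
        # it, as handled below.)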
self._when_to_transfer_radios_on_gpu = "read_data" if self.flatfield is not None and self.flatfield.distortion_correction is not None: self._when_to_transfer_radios_on_gpu = "flatfield" def _init_cuda(self, cuda_options): if not (__has_pycuda__): raise ImportError(__pycuda_error_msg__) cuda_options = cuda_options or {} self.ctx = get_cuda_context(**cuda_options) self._d_radios = None self._d_sinos = None self._d_recs = None def _allocate_array(self, shape, dtype, name=None): name = name or "tmp" # should be mandatory d_name = "_d_" + name d_arr = getattr(self, d_name, None) if d_arr is None: self.logger.debug("Allocating %s: %s" % (name, str(shape))) d_arr = garray.zeros(shape, dtype) setattr(self, d_name, d_arr) return d_arr def _transfer_radios_to_gpu(self): self.logger.debug("Transfering radios to GPU") self._d_radios.set(self.radios) self._h_radios = self.radios self.radios = self._d_radios def _process_finalize(self): self.radios = self._h_radios # # Pipeline execution (class specialization) # def _read_data(self): super()._read_data() if self._when_to_transfer_radios_on_gpu == "read_data": self._transfer_radios_to_gpu() def _flatfield(self): super()._flatfield() if self._when_to_transfer_radios_on_gpu == "flatfield": self._transfer_radios_to_gpu() def _reconstruct(self): super()._reconstruct() if "reconstruction" not in self.processing_steps: return if self.processing_options["reconstruction"]["method"] == "cone": ((U, D), (L, R)) = self.margin U, D = U or None, -D or None # not sure why slicing can't be done before get() self.recs = self.recs.get()[U:D, ...] else: self.recs = self.recs.get() def _write_data(self, data=None): super()._write_data(data=data) if "reconstruction" in self.processing_steps: self.recs = self._d_recs self.radios = self._h_radios @pipeline_step("histogram", "Computing histogram") def _compute_histogram(self, data=None): if data is None: data = self._d_recs self.recs_histogram = self.histogram.compute_histogram(data) def _dump_data_to_file(self, step_name, data=None): if data is None: data = self.radios if step_name not in self.datadump_manager.data_dump: return if isinstance(data, garray.GPUArray): data = data.get() self.datadump_manager.dump_data_to_file(step_name, data=data) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1679996432.0 nabu-2023.1.1/nabu/pipeline/fullfield/computations.py0000644000175000017500000002117500000000000022031 0ustar00pierrepierrefrom math import ceil from silx.image.tomography import get_next_power from ...utils import check_supported def estimate_required_memory(process_config, delta_z=None, delta_a=None, max_mem_allocation_GB=None, debug=False): """ Estimate the memory (RAM) needed for a reconstruction. Parameters ----------- process_config: `ProcessConfig` object Data structure with the processing configuration delta_z: int, optional How many lines are to be loaded in the each projection image. Default is to load lines [start_z:end_z] (where start_z and end_z come from the user configuration file) delta_a: int, optional How many (partial) projection images to load at the same time. Default is to load all the projection images. max_mem_allocation_GB: float, optional Maximum amount of memory in GB for one single array. Returns ------- required_memory: float Total required memory (in bytes). Raises ------ ValueError if one single-array allocation exceeds "max_mem_allocation_GB" Notes ----- pycuda <= 2022.1 cannot use arrays with more than 2**32 items (i.e 17.18 GB for float32). 
This was solved in more recent versions. """ def check_memory_limit(mem_GB, name): if max_mem_allocation_GB is None: return if mem_GB > max_mem_allocation_GB: raise ValueError( "Cannot allocate array '%s' %.3f GB > max_mem_allocation_GB = %.3f GB" % (name, mem_GB, max_mem_allocation_GB) ) dataset = process_config.dataset_info processing_steps = process_config.processing_steps # The "x" dimension (last axis) is always the image width, because # - Processing is never "cut" along this axis (we either split along frames, or images lines) # - Even if we want to reconstruct a vertical slice (i.e end_x - start_x == 1), # the tomography reconstruction will need the full information along this axis. # The only case where reading a x-subregion would be useful is to crop the initial data # (i.e knowing that a part of each image will be completely unused). This is not supported yet. Nx = process_config.radio_shape(binning=True)[-1] if delta_z is not None: Nz = delta_z // process_config.binning_z else: Nz = process_config.rec_delta_z # accounting for binning if delta_a is not None: Na = ceil(delta_a / process_config.subsampling_factor) else: Na = process_config.n_angles(subsampling=True) total_memory_needed = 0 # Read data # ---------- data_image_size = Nx * Nz * 4 data_volume_size = Na * data_image_size check_memory_limit(data_volume_size / 1e9, "projections") total_memory_needed += data_volume_size # CCD processing # --------------- if "flatfield" in processing_steps: # Flat-field is done in-place, but still need to load darks/flats n_darks = len(dataset.darks) n_flats = len(dataset.flats) total_memory_needed += (n_darks + n_flats) * data_image_size if "ccd_correction" in processing_steps: # CCD filter is "batched 2D" total_memory_needed += data_image_size # Phase retrieval # --------------- if "phase" in processing_steps: # Phase retrieval is done image-wise, so near in-place, but needs to # allocate some images, fft plans, and so on Nx_p = get_next_power(2 * Nx) Nz_p = get_next_power(2 * Nz) img_size_real = 2 * 4 * Nx_p * Nz_p img_size_cplx = 2 * 8 * ((Nx_p * Nz_p) // 2 + 1) total_memory_needed += 2 * img_size_real + 3 * img_size_cplx # Sinogram de-ringing # ------------------- if "sino_rings_correction" in processing_steps: # Process is done image-wise. # Needs one Discrete Wavelets transform and one FFT/IFFT plan for each scale total_memory_needed += (Nx * Na * 4) * 5.5 # approx. # Reconstruction # --------------- reconstructed_volume_size = 0 if "reconstruction" in processing_steps: rec_config = process_config.processing_options["reconstruction"] Nx_rec = rec_config["end_x"] - rec_config["start_x"] + 1 Ny_rec = rec_config["end_y"] - rec_config["start_y"] + 1 reconstructed_volume_size = Nz * Nx_rec * Ny_rec * 4 check_memory_limit(reconstructed_volume_size / 1e9, "reconstructions") total_memory_needed += reconstructed_volume_size if process_config.rec_params["method"] == "cone": # In cone-beam reconstruction, need both sinograms and reconstruction inside GPU. # That's big! total_memory_needed += 2 * data_volume_size if debug: print( "Mem for (delta_z=%s, delta_a=%s) ==> (Na=%d, Nz=%d, Nx=%d) : %.3f GB" % (delta_z, delta_a, Na, Nz, Nx, total_memory_needed / 1e9) ) return total_memory_needed def estimate_max_chunk_size( available_memory_GB, process_config, pipeline_part="all", n_rows=None, step=10, max_mem_allocation_GB=None ): """ Estimate the maximum size of the data chunk that can be loaded in memory. 
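    The search is iterative: the candidate chunk size is increased by 'step' and passed to
    estimate_required_memory() until the predicted memory exceeds 'available_memory_GB'
    (or the dataset bounds are reached); the last value that fits is returned.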
Parameters ---------- available_memory_GB: float available memory in Giga Bytes (GB - not GiB !). process_config: ProcessConfig ProcessConfig object pipeline_part: str Which pipeline part to consider. Possible options are: - "full": Account for all the pipeline steps (reading data all the way to reconstruction). - "radios": Consider only the processing steps on projection images (ignore sinogram-based steps and reconstruction) - "sinogram": Consider only the processing steps related to sinograms and reconstruction n_rows: int, optional How many lines to load in each projection. Only accounted for pipeline_part="radios". step: int, optional Step size when doing the iterative memory estimation max_mem_allocation_GB: float, optional Maximum size (in GB) for one single array. Returns ------- n_max: int If pipeline_par is "full" or "sinos": return the maximum number of lines that can be loaded in all the projections while fitting memory, i.e `data[:, 0:n_max, :]` If pipeline_part is "radios", return the maximum number of (partial) images that can be loaded while fitting memory, i.e `data[:, zmin:zmax, 0:n_max]` Notes ----- pycuda <= 2022.1 cannot use arrays with more than 2**32 items (i.e 17.18 GB for float32). This was solved in more recent versions. """ supported_pipeline_parts = ["all", "radios", "sinos"] check_supported(pipeline_part, supported_pipeline_parts, "pipeline_part") processing_steps_bak = process_config.processing_steps.copy() reconstruction_steps = ["sino_rings_correction", "reconstruction"] if pipeline_part == "all": # load lines from all the projections delta_a = None delta_z = 0 if pipeline_part == "radios": # order should not matter process_config.processing_steps = list(set(process_config.processing_steps) - set(reconstruction_steps)) # load lines from only a subset of projections delta_a = 0 delta_z = n_rows if pipeline_part == "sinos": process_config.processing_steps = [ step for step in process_config.processing_steps if step in reconstruction_steps ] # load lines from all the projections delta_a = None delta_z = 0 mem = 0 last_valid_delta_a = delta_a last_valid_delta_z = delta_z while True: try: mem = estimate_required_memory( process_config, delta_z=delta_z, delta_a=delta_a, max_mem_allocation_GB=max_mem_allocation_GB ) except ValueError: # For very big dataset this function might return "0". # Either start at 1, or use a smaller step... break if mem / 1e9 > available_memory_GB: break if delta_a is not None and delta_a > process_config.n_angles(): break if delta_z is not None and delta_z > process_config.radio_shape()[0]: break last_valid_delta_a, last_valid_delta_z = delta_a, delta_z if pipeline_part == "radios": delta_a += step else: delta_z += step process_config.processing_steps = processing_steps_bak res = last_valid_delta_a if pipeline_part == "radios" else last_valid_delta_z # Really not ideal. For very large dataset, "step" should be very small. # Otherwise we go from 0 -> OK to 10 -> not OK, and then retain 0... 
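    # (hence the fallback below: if the search retained 0, a minimal chunk size of 1 is
    # returned so that the caller can still proceed, at the cost of many small passes)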
if res == 0: res = 1 # return res ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/pipeline/fullfield/dataset_validator.py0000644000175000017500000000575700000000000023006 0ustar00pierrepierreimport os from ..dataset_validator import DatasetValidatorBase class FullFieldDatasetValidator(DatasetValidatorBase): def _validate(self): self._check_not_empty() self._convert_negative_indices() self._get_output_filename() self._check_can_do_flatfield() self._check_can_do_phase() self._check_can_do_reconstruction() self._check_slice_indices() self._handle_processing_mode() self._handle_binning() self._check_output_file() def _check_can_do_flatfield(self): if self.nabu_config["preproc"]["flatfield"]: darks = self.dataset_info.darks assert len(darks) > 0, "Need at least one dark to perform flat-field correction" for dark_id, dark_url in darks.items(): assert os.path.isfile(dark_url.file_path()), "Dark file %s not found" % dark_url.file_path() flats = self.dataset_info.flats assert len(flats) > 0, "Need at least one flat to perform flat-field correction" for flat_id, flat_url in flats.items(): assert os.path.isfile(flat_url.file_path()), "Flat file %s not found" % flat_url.file_path() def _check_slice_indices(self): nx, nz = self.dataset_info.radio_dims rec_params = self.rec_params if rec_params["enable_halftomo"]: ny, nx = self._get_nx_ny() what = (("start_x", "end_x", nx), ("start_y", "end_y", nx), ("start_z", "end_z", nz)) for start_name, end_name, numels in what: self._check_start_end_idx( rec_params[start_name], rec_params[end_name], numels, start_name=start_name, end_name=end_name ) def _check_can_do_phase(self): if self.nabu_config["phase"]["method"] is None: return self.dataset_info.check_defined_attribute("distance") self.dataset_info.check_defined_attribute("pixel_size") def _check_can_do_reconstruction(self): rec_options = self.nabu_config["reconstruction"] if rec_options["method"] is None: return self.dataset_info.check_defined_attribute("pixel_size") if rec_options["method"] == "cone": if rec_options["source_sample_dist"] is None: err_msg = "In cone-beam reconstruction, you have to provide 'source_sample_dist' in [reconstruction]" self.logger.fatal(err_msg) raise ValueError(err_msg) if rec_options["sample_detector_dist"] is None: if self.dataset_info.distance is None: err_msg = "Cone-beam reconstruction: 'sample_detector_dist' was not provided but could not be found in the dataset metadata either. Please provide 'sample_detector_dist'" self.logger.fatal(err_msg) raise ValueError(err_msg) self.logger.warning( "Cone-beam reconstruction: 'sample_detector_dist' not provided, will use the one in dataset metadata" ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/pipeline/fullfield/nabu_config.py0000644000175000017500000006773200000000000021567 0ustar00pierrepierrefrom ..config_validators import * nabu_config = { "dataset": { "location": { "default": "", "help": "Dataset location, either a directory or a HDF5-Nexus file.", "validator": dataset_location_validator, "type": "required", }, "hdf5_entry": { "default": "", "help": "Entry in the HDF5 file, if applicable. Default is the first available entry.", "validator": optional_string_validator, "type": "advanced", }, "nexus_version": { "default": "1.0", "help": "Nexus version to use when browsing the HDF5 dataset. 
Default is 1.0.", "validator": float_validator, "type": "advanced", }, "darks_flats_dir": { "default": "", "help": "Path to a directory where XXX_flats.h5 and XXX_darks.h5 are to be found, where 'XXX' denotes the dataset basename. If these files are found, then reduced flats/darks will be loaded from them. Otherwise, reduced flats/darks will be saved to there once computed, either in the .nx directory, or in the output directory. Mind that the HDF5 entry corresponds to the one of the dataset.", "validator": optional_directory_location_validator, "type": "optional", }, "binning": { "default": "1", "help": "Binning factor in the horizontal dimension when reading the data.\nThe final slices dimensions will be divided by this factor.", "validator": binning_validator, "type": "advanced", }, "binning_z": { "default": "1", "help": "Binning factor in the vertical dimension when reading the data.\nThis results in a lesser number of reconstructed slices.", "validator": binning_validator, "type": "advanced", }, "projections_subsampling": { "default": "1", "help": "Projections subsampling factor: take one projection out of 'projection_subsampling'", "validator": binning_validator, "type": "advanced", }, "exclude_projections": { "default": "", "help": "Path to a file name containing projections to exclude (projection indices).", "validator": optional_file_location_validator, "type": "advanced", }, "overwrite_metadata": { "default": "", "help": "Which metadata to overwrite, separated by a semicolon, and with units. Example: 'energy = 19 kev; pixel_size = 1.6 um'", "validator": no_validator, "type": "advanced", }, }, "preproc": { "flatfield": { "default": "1", "help": "How to perform flat-field normalization. The parameter value can be:\n - 1 or True: enabled.\n - 0 or False: disabled\n - forced or force-load: perform flatfield regardless of the dataset by attempting to load darks/flats\n - force-compute: perform flatfield, ignore all .h5 files containing already computed darks/flats.", "validator": flatfield_enabled_validator, "type": "required", }, "flat_distortion_correction_enabled": { "default": "0", "help": "Whether to correct for flat distortion. If activated, each radio is correlated with its corresponding flat, in order to determine and correct the flat distortion.", "validator": boolean_validator, "type": "advanced", }, "flat_distortion_params": { "default": "tile_size=100; interpolation_kind='linear'; padding_mode='edge'; correction_spike_threshold=None", "help": "Advanced parameters for flat distortion correction", "validator": optional_string_validator, "type": "advanced", }, "normalize_srcurrent": { "default": "0", "help": "Whether to normalize frames with Synchrotron Current. This can correct the effect of a beam refill not taken into account by flats.", "validator": boolean_validator, "type": "advanced", }, "ccd_filter_enabled": { "default": "0", "help": "Whether to enable the CCD hotspots correction.", "validator": boolean_validator, "type": "optional", }, "ccd_filter_threshold": { "default": "0.04", "help": "If ccd_filter_enabled = 1, a median filter is applied on the 3X3 neighborhood\nof every pixel. If a pixel value exceeds the median value more than this parameter,\nthen the pixel value is replaced with the median value.", "validator": float_validator, "type": "optional", }, "detector_distortion_correction": { "default": "", "help": "Apply coordinate transformation on the raw data, at the reading stage. Default (empty) is None. 
Available are: None, identity(for testing the pipeline), map_xz. This latter method requires two URLs being passet by detector_distortion_correction_options: map_x and map_z pointing to two 2D arrays containing the position where each pixel can be interpolated at in the raw data", "validator": detector_distortion_correction_validator, "type": "advanced", }, "detector_distortion_correction_options": { "default": "", "help": "Options for sinogram rings correction methods. The parameters are separated by commas and passed as 'name=value', for example: center_xz=(1000,100); angle_deg=5. Mind the semicolon separator (;).", "validator": generic_options_validator, "type": "advanced", }, "double_flatfield_enabled": { "default": "0", "help": "Whether to enable the 'double flat-field' filetering for correcting rings artefacts.", "validator": boolean_validator, "type": "optional", }, "dff_sigma": { "default": "", "help": "Enable high-pass filtering on double flatfield with this value of 'sigma'", "validator": optional_float_validator, "type": "advanced", }, "take_logarithm": { "default": "1", "help": "Whether to take logarithm after flat-field and phase retrieval.", "validator": boolean_validator, "type": "required", }, "log_min_clip": { "default": "1e-6", "help": "After division by the FF, and before the logarithm, the is clipped to this minimum. Enabled only if take_logarithm=1", "validator": float_validator, "type": "advanced", }, "log_max_clip": { "default": "10.0", "help": "After division by the FF, and before the logarithm, the is clipped to this maximum. Enabled only if take_logarithm=1", "validator": float_validator, "type": "advanced", }, "sino_normalization": { "default": "", "help": "Sinogram normalization method. Available methods are: chebyshev, subtraction, division, none. Default is none (no normalization)", "validator": sino_normalization_validator, "type": "advanced", }, "sino_normalization_file": { "default": "", "help": "Path to the file when sino_normalization is either 'subtraction' or 'division'. To specify the path within a HDF5 file, the syntax is /path/to/file?path=/entry/data", "validator": no_validator, "type": "advanced", }, "processes_file": { "default": "", "help": "Path to the file where some operations should be stored for later use. By default it is 'xxx_nabu_processes.h5'", "validator": optional_output_file_path_validator, "type": "advanced", }, "sino_rings_correction": { "default": "", "help": "Sinogram rings removal method. Default (empty) is None. Available are: None, munch. See also: sino_rings_options", "validator": sino_deringer_methods, "type": "optional", }, "sino_rings_options": { "default": "sigma=1.0 ; levels=10", "help": "Options for sinogram rings correction methods. The parameters are separated by commas and passed as 'name=value', for example: sigma=1.0;levels=10. Mind the semicolon separator (;).", "validator": generic_options_validator, "type": "advanced", }, "rotate_projections": { "default": "", "help": "Whether to rotate each projection image with a certain angle (in degree). By default (empty) no rotation is done.", "validator": optional_nonzero_float_validator, "type": "advanced", }, "rotate_projections_center": { "default": "", "help": "Center of rotation when 'rotate_projections' is non-empty. By default the center of rotation is the middle of each radio, i.e ((Nx-1)/2.0, (Ny-1)/2.0).", "validator": optional_tuple_of_floats_validator, "type": "advanced", }, "tilt_correction": { "default": "", "help": "Detector tilt correction. 
Default (empty) means no tilt correction.\nThe following values can be provided for automatic tilt estimation, in this case, the projection images are rotated by the found tilt value:\n - A scalar value: tilt correction angle in degrees\n - 1d-correlation: auto-detect tilt with the 1D correlation method (fastest, but works best for small tilts)\n - fft-polar: auto-detect tilt with polar FFT method (slower, but works well on all ranges of tilts)", "validator": tilt_validator, "type": "advanced", }, "autotilt_options": { "default": "", "help": "Options for methods computing automatically the detector tilt. The parameters are separated by commas and passed as 'name=value', for example: low_pass=1; high_pass=20. Mind the semicolon separator (;).", "validator": generic_options_validator, "type": "advanced", }, }, "phase": { "method": { "default": "none", "help": "Phase retrieval method. Available are: Paganin, CTF, None", "validator": phase_method_validator, "type": "required", }, "delta_beta": { "default": "100.0", "help": "Single-distance phase retrieval related parameters\n----------------------------\ndelta/beta ratio for the Paganin/CTF method", "validator": float_validator, "type": "required", }, "unsharp_coeff": { "default": "0", "help": "Unsharp mask strength. The unsharped image is equal to\n UnsharpedImage = (1 + coeff)*originalPaganinImage - coeff * ConvolvedImage. Setting this coefficient to zero means that no unsharp mask will be applied.", "validator": float_validator, "type": "optional", }, "unsharp_sigma": { "default": "0", "help": "Standard deviation of the Gaussian filter when applying an unsharp mask\nafter the phase filtering. Disabled if set to 0.", "validator": float_validator, "type": "optional", }, "unsharp_method": { "default": "gaussian", "help": "Which type of unsharp mask filter to use. Available values are gaussian, laplacian and imagej. Default is gaussian.", "validator": unsharp_method_validator, "type": "optional", }, "padding_type": { "default": "edge", "help": "Padding type for the filtering step in Paganin/CTF. Available are: mirror, edge, zeros", "validator": padding_mode_validator, "type": "advanced", }, "ctf_geometry": { "default": "z1_v=None; z1_h=None; detec_pixel_size=None; magnification=True", "help": "Geometric parameters for CTF phase retrieval. Length units are in meters.", "validator": optional_string_validator, "type": "optional", }, "ctf_advanced_params": { "default": "length_scale=1e-5; lim1=1e-5; lim2=0.2; normalize_by_mean=True", "help": "Advanced parameters for CTF phase retrieval.", "validator": optional_string_validator, "type": "advanced", }, }, "reconstruction": { "method": { "default": "FBP", "help": "Reconstruction method. Possible values: FBP, cone, none. If value is 'none', no reconstruction will be done.", "validator": reconstruction_method_validator, "type": "required", }, "angles_file": { "default": "", "help": "In the case you want to override the angles found in the files metadata. The angles are in degree.", "validator": optional_file_location_validator, "type": "optional", }, "rotation_axis_position": { "default": "sliding-window", "help": "Rotation axis position. It can be a number or the name of an estimation method (empty value means the middle of the detector).\nThe following methods are available to find automatically the Center of Rotation (CoR):\n - centered : a fast and simple auto-CoR method. It only works when the CoR is not far from the middle of the detector. 
It does not work for half-tomography.\n - global : a slow but robust auto-CoR.\n - sliding-window : semi-automatically find the CoR with a sliding window. You have to specify on which side the CoR is (left, center, right). Please see the 'cor_options' parameter.\n - growing-window : automatically find the CoR with a sliding-and-growing window. You can tune the option with the parameter 'cor_options'.\n - sino-coarse-to-fine: Estimate CoR from sinogram. Only works for 360 degrees scans.\n - composite-coarse-to-fine: Estimate CoR from composite multi-angle images. Only works for 360 degrees scans.", "validator": cor_validator, "type": "required", }, "cor_options": { "default": "", "help": "Options for methods finding automatically the rotation axis position. The parameters are separated by commas and passed as 'name=value', for example: low_pass=1; high_pass=20. Mind the semicolon separator (;).", "validator": generic_options_validator, "type": "advanced", }, "cor_slice": { "default": "", "help": "Which slice to use for estimating the Center of Rotation (CoR). This parameter can be an integer or 'top', 'middle', 'bottom'.\nIf provided, the CoR will be estimated from the correspondig sinogram, and 'cor_options' can contain the parameter 'subsampling'.", "validator": cor_slice_validator, "type": "advanced", }, "axis_correction_file": { "default": "", "help": "In the case where the axis position is specified for each angle", "validator": optional_values_file_validator, "type": "advanced", }, "translation_movements_file": { "default": "", "help": "A file where each line describes the horizontal and vertical translations of the sample (or detector). The order is 'horizontal, vertical'.", "validator": optional_values_file_validator, "type": "advanced", }, "angle_offset": { "default": "0", "help": "Use this if you want to obtain a rotated reconstructed slice. The angle is in degrees.", "validator": float_validator, "type": "advanced", }, "fbp_filter_type": { "default": "ramlak", "help": "Filter type for FBP method. Available are: none, ramlak, shepp-logan, cosine, hamming, hann, tukey, lanczos, hilbert", "validator": fbp_filter_name_validator, "type": "advanced", }, "fbp_filter_cutoff": { "default": "1.", "help": "Cut-off frequency for Fourier filter used in FBP, in normalized units. Default is the Nyquist frequency 1.0", "validator": float_validator, "type": "advanced", }, "source_sample_dist": { "default": "", "help": "In cone-beam geometry, distance (in meters) between the X-ray source and the center of the sample. Default is infinity.", "validator": optional_float_validator, "type": "advanced", }, "sample_detector_dist": { "default": "", "help": "In cone-beam geometry, distance (in meters) between the center of the sample and the detector. Default is read from the input dataset.", "validator": optional_float_validator, "type": "advanced", }, "padding_type": { "default": "edges", "help": "Padding type for FBP. Available are: zeros, edges", "validator": padding_mode_validator, "type": "optional", # put "advanced" with default value "edges" ? }, "enable_halftomo": { "default": "auto", "help": "Whether to enable half-acquisition. Default is auto. 
You can enable/disable it manually by setting 1 or 0.", "validator": boolean_or_auto_validator, "type": "optional", }, "clip_outer_circle": { "default": "0", "help": "Whether to set to zero voxels falling outside of the reconstruction region", "validator": boolean_validator, "type": "optional", }, "centered_axis": { "default": "0", "help": "If set to true, the reconstructed region is centered on the rotation axis, i.e the center of the image will be the rotation axis position.", "validator": boolean_validator, "type": "optional", }, "start_x": { "default": "0", "help": "\nParameters for sub-volume reconstruction. Indices start at 0, and upper bounds are INCLUDED!\n----------------------------------------------------------------\n(x, y) are the dimension of a slice, and (z) is the 'vertical' axis\nBy default, all the volume is reconstructed slice by slice, along the axis 'z'.", "validator": nonnegative_integer_validator, "type": "optional", }, "end_x": { "default": "-1", "help": "", "validator": integer_validator, "type": "optional", }, "start_y": { "default": "0", "help": "", "validator": nonnegative_integer_validator, "type": "optional", }, "end_y": { "default": "-1", "help": "", "validator": integer_validator, "type": "optional", }, "start_z": { "default": "0", "help": "", "validator": slice_num_validator, "type": "optional", }, "end_z": { "default": "-1", "help": "", "validator": slice_num_validator, "type": "optional", }, "iterations": { "default": "200", "help": "\nParameters for iterative algorithms\n------------------------------------\nNumber of iterations", "validator": nonnegative_integer_validator, "type": "unsupported", }, "optim_algorithm": { "default": "chambolle-pock", "help": "Optimization algorithm for iterative methods", "validator": optimization_algorithm_name_validator, "type": "unsupported", }, "weight_tv": { "default": "1.0e-2", "help": "Total Variation regularization parameter for iterative methods", "validator": float_validator, "type": "unsupported", }, "preconditioning_filter": { "default": "1", "help": "Whether to enable 'filter preconditioning' for iterative methods", "validator": boolean_validator, "type": "unsupported", }, "positivity_constraint": { "default": "1", "help": "Whether to enforce a positivity constraint in the reconstruction.", "validator": boolean_validator, "type": "unsupported", }, }, "output": { "location": { "default": "", "help": "Directory where the output reconstruction is stored.", "validator": optional_directory_location_validator, "type": "required", }, "file_prefix": { "default": "", "help": "File prefix. Optional, by default it is inferred from the scanned dataset.", "validator": optional_file_name_validator, "type": "optional", }, "file_format": { "default": "hdf5", "help": "Output file format. 
Available are: hdf5, tiff, jp2, edf, vol", "validator": output_file_format_validator, "type": "optional", }, "overwrite_results": { "default": "1", "help": "What to do in the case where the output file exists.\nBy default, the output data is never overwritten and the process is interrupted if the file already exists.\nSet this option to 1 if you want to overwrite the output files.", "validator": boolean_validator, "type": "required", }, "tiff_single_file": { "default": "0", "help": "Whether to create a single large tiff file for the reconstructed volume.", "validator": boolean_validator, "type": "advanced", }, "jpeg2000_compression_ratio": { "default": "", "help": "Compression ratio for Jpeg2000 output.", "validator": optional_positive_integer_validator, "type": "advanced", }, "float_clip_values": { "default": "", "help": "Lower and upper bounds to use when converting from float32 to int. Floating point values are clipped to these (min, max) values before being cast to integer.", "validator": optional_tuple_of_floats_validator, "type": "advanced", }, }, "postproc": { "output_histogram": { "default": "0", "help": "Whether to compute a histogram of the volume.", "validator": boolean_validator, "type": "optional", }, "histogram_bins": { "default": "1000000", "help": "Number of bins for the output histogram. Default is one million. ", "validator": nonnegative_integer_validator, "type": "advanced", }, }, "resources": { "method": { "default": "local", "help": "Computations distribution method. It can be:\n - local: run the computations on the local machine\n - slurm: run the computations through SLURM\n - preview: reconstruct the slices/volume as quickly as possible, possibly doing some binning.", "validator": distribution_method_validator, "type": "required", }, "gpus": { "default": "1", "help": "Number of GPUs to use.", "validator": nonnegative_integer_validator, "type": "advanced", }, "gpu_id": { "default": "", "help": "For method = local only. List of GPU IDs to use. This parameter overwrites 'gpus'.\nIf left blank, exactly one GPU will be used, and the best one will be picked.", "validator": list_of_int_validator, "type": "advanced", }, "cpu_workers": { "default": "0", "help": "Number of 'CPU workers' for each GPU worker. It is discouraged to set this number to more than one. A value of -1 means exactly one CPU worker.", "validator": integer_validator, "type": "unsupported", }, "memory_per_node": { "default": "90%", "help": "RAM memory per computing node, either in GB or in percent of the AVAILABLE (!= total) node memory.\nIf several workers share the same node, their combined memory usage will not exceed this number.", "validator": resources_validator, "type": "unsupported", }, "threads_per_node": { "default": "100%", "help": "Number of threads to allocate on each node, either a number or a percentage of the available threads", "validator": resources_validator, "type": "unsupported", }, "queue": { "default": "gpu", "help": "\nParameters exclusive to the 'slurm' distribution method\n------------------------------------------------------\nName of the SLURM partition ('queue'). Full list is obtained with 'scontrol show partition'", "validator": nonempty_string_validator, "type": "unsupported", }, "walltime": { "default": "01:00:00", "help": "Time limit for the SLURM resource allocation, in the format Hours:Minutes:Seconds", "validator": walltime_validator, "type": "unsupported", }, }, "pipeline": { "save_steps": { "default": "", "help": "Save intermediate results. 
This is a list of comma-separated processing steps, for ex: flatfield, phase, sinogram.\nEach step generates a HDF5 file in the form name_file_prefix.hdf5 (ex. 'sinogram_file_prefix.hdf5')", "validator": optional_string_validator, "type": "optional", }, "resume_from_step": { "default": "", "help": "Resume the processing from a previously saved processing step. The corresponding file must exist in the output directory.", "validator": optional_string_validator, "type": "optional", }, "steps_file": { "default": "", "help": "File where the intermediate processing steps are written. By default it is empty, and intermediate processing steps are written in the same directory as the reconstructions, with a file prefix, ex. sinogram_mydataset.hdf5.", "validator": optional_output_file_path_validator, "type": "advanced", }, "verbosity": { "default": "2", "help": "Level of verbosity of the processing. 0 = terse, 3 = much information.", "validator": logging_validator, "type": "optional", }, }, # This section will be removed in the future (for now it is deprecated) "about": {}, } renamed_keys = { "marge": { "section": "phase", "new_name": "margin", "since": "2020.2.0", "message": "Option 'marge' has been renamed 'margin' in [phase]", }, "overwrite_results": { "section": "about", "new_name": "overwrite_results", "new_section": "output", "since": "2020.3.0", "message": "Option 'overwrite_results' was moved from section [about] to section [output]", }, "nabu_config_version": { "section": "about", "new_name": "", "new_section": "about", "since": "2020.3.1", "message": "Option 'nabu_config_version' was removed.", }, "nabu_version": { "section": "about", "new_name": "", "new_section": "about", "since": "2021.1.0", "message": "Option 'nabu_config' was removed.", }, "verbosity": { "section": "about", "new_name": "verbosity", "new_section": "pipeline", "since": "2021.1.0", "message": "Option 'verbosity' was moved from section [about] to section [pipeline]", }, "flatfield_enabled": { "section": "preproc", "new_name": "flatfield", "since": "2021.2.0", "message": "Option 'flatfield_enabled' has been renamed 'flatfield' in [preproc]", }, } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/pipeline/fullfield/processconfig.py0000644000175000017500000010514300000000000022146 0ustar00pierrepierreimport os import posixpath import numpy as np from silx.io import get_data from silx.io.url import DataUrl from ...utils import copy_dict_items, compare_dicts from ...io.utils import hdf5_entry_exists, get_h5_value from ...io.reader import import_h5_to_dict from ...resources.utils import extract_parameters, get_values_from_file from ...resources.nxflatfield import update_dataset_info_flats_darks from ...resources.utils import get_quantities_and_units from ...reconstruction.sinogram import get_extended_sinogram_width from ..estimators import estimate_cor from ..processconfig import ProcessConfigBase from .nabu_config import nabu_config, renamed_keys from .dataset_validator import FullFieldDatasetValidator class ProcessConfig(ProcessConfigBase): """ A ProcessConfig object has these main fields: - dataset_info: information about the current dataset - nabu_config: configuration from the user side - processing_steps/processing_options: configuration "ready-to use" for underlying classes It is built from the following steps. 
(1a) parse config: (conf_fname or conf_dict) --> "nabu_config" (1b) browse dataset: (nabu_config or existing dataset_info) --> dataset_info (2) update_dataset_info_with_user_config - Update flats/darks - CoR (value or estimation method) # no estimation yet - rotation angles - translations files - user sino normalization (eg. subtraction etc) (3) estimations - tilt - CoR (4) coupled validation (5) build processing steps (6) configure checkpoints (save/resume) """ default_nabu_config = nabu_config config_renamed_keys = renamed_keys _use_horizontal_translations = True _all_processing_steps = [ "read_chunk", "flatfield", "ccd_correction", "double_flatfield", "rotate_projections", "phase", "unsharp_mask", "take_log", "radios_movements", # radios are cropped after this step, if needed "sino_normalization", "sino_rings_correction", "reconstruction", "histogram", "save", ] def _update_dataset_info_with_user_config(self): """ Update the 'dataset_info' (DatasetAnalyzer class instance) data structure with options from user configuration. """ self.logger.debug("Updating dataset information with user configuration") if self.dataset_info.kind == "hdf5": update_dataset_info_flats_darks( self.dataset_info, self.nabu_config["preproc"]["flatfield"], output_dir=self.nabu_config["output"]["location"], darks_flats_dir=self.nabu_config["dataset"]["darks_flats_dir"], ) self.rec_params = self.nabu_config["reconstruction"] self._update_dataset_with_user_overwrites() self._get_rotation_axis_position() self._update_rotation_angles() self._get_translation_file("reconstruction", "translation_movements_file", "translations") self._get_user_sino_normalization() def _update_dataset_with_user_overwrites(self): user_overwrites = self.nabu_config["dataset"]["overwrite_metadata"].strip() if user_overwrites in ("", None): return possible_overwrites = {"pixel_size": 1e6, "distance": 1.0, "energy": 1.0} try: overwrites = get_quantities_and_units(user_overwrites) except ValueError: msg = ( "Something wrong in config file in 'overwrite_metadata': could not get quantities/units from '%s'. 
Please check that separators are ';' and that units are provided (separated by a space)" % user_overwrites ) self.logger.fatal(msg) raise ValueError(msg) for quantity, conversion_factor in possible_overwrites.items(): user_value = overwrites.pop(quantity, None) if user_value is not None: self.logger.info("Overwriting %s = %s" % (quantity, user_value)) user_value *= conversion_factor setattr(self.dataset_info, quantity, user_value) def _get_translation_file(self, config_section, config_key, dataset_info_attr, last_dim=2): transl_file = self.nabu_config[config_section][config_key] if transl_file in (None, ""): return translations = None if transl_file is not None and "://" not in transl_file: try: translations = get_values_from_file( transl_file, shape=(self.dataset_info.n_angles, last_dim), any_size=True ).astype(np.float32) except ValueError: print("Something wrong with translation_movements_file %s" % transl_file) raise else: try: translations = get_data(transl_file) except: print("Something wrong with translation_movements_file %s" % transl_file) raise setattr(self.dataset_info, dataset_info_attr, translations) if self._use_horizontal_translations and translations is not None: # Horizontal translations are handled by "axis_correction" in backprojector horizontal_translations = translations[:, 0] if np.max(np.abs(horizontal_translations)) > 1e-3: self.dataset_info.axis_correction = horizontal_translations def _get_rotation_axis_position(self): super()._get_rotation_axis_position() rec_params = self.nabu_config["reconstruction"] axis_correction_file = rec_params["axis_correction_file"] axis_correction = None if axis_correction_file is not None: try: axis_correction = get_values_from_file( axis_correction_file, n_values=self.dataset_info.n_angles, any_size=True, ).astype(np.float32) except ValueError: print("Something wrong with axis correction file %s" % axis_correction_file) raise self.dataset_info.axis_correction = axis_correction def _update_rotation_angles(self): rec_params = self.nabu_config["reconstruction"] n_angles = self.dataset_info.n_angles angles_file = rec_params["angles_file"] if angles_file is not None: try: angles = get_values_from_file(angles_file, n_values=n_angles, any_size=True) angles = np.deg2rad(angles) except ValueError: self.logger.fatal("Something wrong with angle file %s" % angles_file) raise self.dataset_info.rotation_angles = angles elif self.dataset_info.rotation_angles is None: angles_range_txt = "[0, 180[ degrees" if rec_params["enable_halftomo"]: angles_range_txt = "[0, 360] degrees" angles = np.linspace(0, 2 * np.pi, n_angles, True) else: angles = np.linspace(0, np.pi, n_angles, False) self.logger.warning( "No information was found on rotation angles. 
Using default %s (half tomo is %s)" % (angles_range_txt, {0: "OFF", 1: "ON"}[int(self.do_halftomo)]) ) self.dataset_info.rotation_angles = angles def _get_cor(self): cor = self.nabu_config["reconstruction"]["rotation_axis_position"] if isinstance(cor, str): # auto-CoR cor = estimate_cor( cor, self.dataset_info, do_flatfield=(self.nabu_config["preproc"]["flatfield"] is not False), cor_options_str=self.nabu_config["reconstruction"]["cor_options"], logger=self.logger, ) self.logger.info("Estimated center of rotation: %.3f" % cor) self.dataset_info.axis_position = cor def _get_user_sino_normalization(self): self._sino_normalization_arr = None norm = self.nabu_config["preproc"]["sino_normalization"] if norm not in ["subtraction", "division"]: return norm_path = "silx://" + self.nabu_config["preproc"]["sino_normalization_file"].strip() url = DataUrl(norm_path) try: norm_array = get_data(url) self._sino_normalization_arr = norm_array.astype("f") except (ValueError, OSError) as exc: error_msg = "Could not load sino_normalization_file %s. The error was: %s" % (norm_path, str(exc)) self.logger.error(error_msg) raise ValueError(error_msg) @property def do_halftomo(self): """ Return True if the current dataset is to be reconstructed using 'half-acquisition' setting. """ enable_halftomo = self.nabu_config["reconstruction"]["enable_halftomo"] is_halftomo_dataset = self.dataset_info.is_halftomo if enable_halftomo == "auto": if is_halftomo_dataset is None: raise ValueError( "enable_halftomo was set to 'auto', but information on field of view was not found. Please set either 0 or 1 for enable_halftomo" ) return is_halftomo_dataset return enable_halftomo def _coupled_validation(self): self.logger.debug("Doing coupled validation") self._dataset_validator = FullFieldDatasetValidator(self.nabu_config, self.dataset_info) # Not so ideal to propagate fields like this for what in ["rec_params", "rec_region", "binning", "subsampling_factor"]: setattr(self, what, getattr(self._dataset_validator, what)) # # Attributes that combine dataset information and user overwrites (eg. binning). # Must be accessed after __init__() is done # @property def binning_x(self): return getattr(self, "binning", (1, 1))[0] @property def binning_z(self): return getattr(self, "binning", (1, 1))[1] @property def subsampling(self): return getattr(self, "subsampling_factor", None) def radio_shape(self, binning=False): n_x, n_z = self.dataset_info.radio_dims if binning: n_z //= self.binning_z n_x //= self.binning_x return (n_z, n_x) def n_angles(self, subsampling=False): rot_angles = self.dataset_info.rotation_angles if subsampling: rot_angles = rot_angles[:: (self.subsampling or 1)] return len(rot_angles) def radios_shape(self, binning=False, subsampling=False): n_z, n_x = self.radio_shape(binning=binning) n_a = self.n_angles(subsampling=subsampling) return (n_a, n_z, n_x) def rotation_axis_position(self, binning=False): cor = self.dataset_info.axis_position # might be None (default to the middle of detector) if cor is None and self.do_halftomo: raise ValueError("No information on center of rotation, cannot use half-tomography reconstruction") if cor is not None and binning: # Backprojector uses middle of pixel for coordinate indices. # This means that the leftmost edge of the leftmost pixel has coordinate -0.5. 
# When using binning with a factor 'b', the CoR has to adapted as # cor_binned = (cor + 0.5)/b - 0.5 cor = (cor + 0.5) / self.binning_x - 0.5 return cor def sino_shape(self, binning=False, subsampling=False, after_ha=False): """ Return the shape of a sinogram image. Parameter --------- binning: bool Whether to account for image binning subsampling: bool Whether to account for projections subsampling after_ha: bool Whether to return the sinogram shape before building the extended-FoV sinogram (default) or after. """ n_a, _, n_x = self.radios_shape(binning=binning, subsampling=subsampling) if self.do_halftomo and after_ha: # int(round(cor/binning)) is not equivalent to int(round(cor))//binning # the former should be used because HA sinograms are built after data is binned cor = self.rotation_axis_position(binning=binning) n_x = get_extended_sinogram_width(n_x, cor) n_a = (n_a + 1) // 2 return (n_a, n_x) def sinos_shape(self, binning=False, subsampling=False, after_ha=False): n_z, _ = self.radio_shape(binning=binning) return (n_z,) + self.sino_shape(binning=binning, subsampling=subsampling, after_ha=after_ha) def projs_indices(self, subsampling=False): step = 1 if subsampling: step = self.subsampling or 1 return sorted(self.dataset_info.projections.keys())[::step] def rotation_angles(self, subsampling=False): step = 1 if subsampling: step = self.subsampling or 1 return self.dataset_info.rotation_angles[::step] @property def rec_roi(self): """ Returns the reconstruction region of interest (ROI), accounting for binning in both dimensions. """ rec_params = self.rec_params # accounts for binning x_s, x_e = rec_params["start_x"], rec_params["end_x"] y_s, y_e = rec_params["start_y"], rec_params["end_y"] # Upper bound (end_xy) is INCLUDED in nabu config, hence the +1 here return (x_s, x_e + 1, y_s, y_e + 1) @property def rec_shape(self): # Accounts for binning! return tuple(np.diff(self.rec_roi)[::-2]) @property def rec_delta_z(self): # Accounts for binning! z_s, z_e = self.rec_params["start_z"], self.rec_params["end_z"] # Upper bound (end_xy) is INCLUDED in nabu config, hence the +1 here return z_e + 1 - z_s def is_before_radios_cropping(self, step): """ Return true if a given processing step happens before radios cropping """ if step == "sinogram": return False if step not in self._all_processing_steps: raise ValueError("Unknown step: '%s'. Available are: %s" % (step, self._all_processing_steps)) # sino_normalization return self._all_processing_steps.index(step) <= self._all_processing_steps.index("radios_movements") # # Build processing steps # # TODO update behavior and remove this function once GroupedPipeline cuda backend is implemented def get_radios_rotation_mode(self): """ Determine whether projections are to be rotated, and if so, when they are to be rotated. Returns ------- method: str or None Rotation method: one of the values of `nabu.resources.params.radios_rotation_mode` """ user_rotate_projections = self.nabu_config["preproc"]["rotate_projections"] tilt = self.dataset_info.detector_tilt phase_method = self.nabu_config["phase"]["method"] do_ctf = phase_method == "CTF" do_pag = phase_method == "paganin" do_unsharp = ( self.nabu_config["phase"]["unsharp_method"] is not None and self.nabu_config["phase"]["unsharp_coeff"] > 0 ) if user_rotate_projections is None and tilt is None: return None if do_ctf: return "full" # TODO "chunked" rotation is done only when using a "processing margin" # For now the processing margin is enabled only if phase or unsharp is enabled. 
# We can either # - Enable processing margin if rotating projections is needed (more complicated to implement) # - Always do "full" rotation (simpler to implement, at the expense of performances) if do_pag or do_unsharp: return "chunk" else: return "full" def _build_processing_steps(self): nabu_config = self.nabu_config dataset_info = self.dataset_info binning = (nabu_config["dataset"]["binning"], nabu_config["dataset"]["binning_z"]) tasks = [] options = {} # # Dataset / Get data # # First thing to do is to get the data (radios or sinograms) # For now data is assumed to be on disk (see issue #66). tasks.append("read_chunk") options["read_chunk"] = { "files": dataset_info.projections, "sub_region": None, "binning": binning, "dataset_subsampling": nabu_config["dataset"]["projections_subsampling"], } # # Flat-field # if nabu_config["preproc"]["flatfield"]: tasks.append("flatfield") options["flatfield"] = { # ChunkReader handles binning/subsampling by itself, # but FlatField needs "real" indices (after binning/subsampling) "projs_indices": self.projs_indices(subsampling=False), "binning": binning, "do_flat_distortion": nabu_config["preproc"]["flat_distortion_correction_enabled"], "flat_distortion_params": extract_parameters(nabu_config["preproc"]["flat_distortion_params"]), } normalize_srcurrent = nabu_config["preproc"]["normalize_srcurrent"] radios_srcurrent = None flats_srcurrent = None if normalize_srcurrent: if ( dataset_info.projections_srcurrent is None or dataset_info.flats_srcurrent is None or len(dataset_info.flats_srcurrent) == 0 ): self.logger.error("Cannot do SRCurrent normalization: missing flats and/or projections SRCurrent") normalize_srcurrent = False else: radios_srcurrent = dataset_info.projections_srcurrent flats_srcurrent = dataset_info.flats_srcurrent options["flatfield"].update( { "normalize_srcurrent": normalize_srcurrent, "radios_srcurrent": radios_srcurrent, "flats_srcurrent": flats_srcurrent, } ) # # Spikes filter # if nabu_config["preproc"]["ccd_filter_enabled"]: tasks.append("ccd_correction") options["ccd_correction"] = { "type": "median_clip", # only one available for now "median_clip_thresh": nabu_config["preproc"]["ccd_filter_threshold"], } # # Double flat field # if nabu_config["preproc"]["double_flatfield_enabled"]: tasks.append("double_flatfield") options["double_flatfield"] = { "sigma": nabu_config["preproc"]["dff_sigma"], "processes_file": nabu_config["preproc"]["processes_file"], } # # Radios rotation (do it here if possible) # if self.get_radios_rotation_mode() == "chunk": tasks.append("rotate_projections") options["rotate_projections"] = { "angle": nabu_config["preproc"]["rotate_projections"] or dataset_info.detector_tilt, "center": nabu_config["preproc"]["rotate_projections_center"], "mode": "chunk", } # # # Phase retrieval # if nabu_config["phase"]["method"] is not None: tasks.append("phase") options["phase"] = copy_dict_items(nabu_config["phase"], ["method", "delta_beta", "padding_type"]) options["phase"].update( { "energy_kev": dataset_info.energy, "distance_cm": dataset_info.distance * 1e2, "distance_m": dataset_info.distance, "pixel_size_microns": dataset_info.pixel_size, "pixel_size_m": dataset_info.pixel_size * 1e-6, } ) if binning != (1, 1): options["phase"]["delta_beta"] /= binning[0] * binning[1] if options["phase"]["method"] == "CTF": self._get_ctf_parameters(options["phase"]) # # Unsharp # if nabu_config["phase"]["unsharp_method"] is not None and nabu_config["phase"]["unsharp_coeff"] > 0: tasks.append("unsharp_mask") options["unsharp_mask"] = 
copy_dict_items( nabu_config["phase"], ["unsharp_coeff", "unsharp_sigma", "unsharp_method"] ) # # -logarithm # if nabu_config["preproc"]["take_logarithm"]: tasks.append("take_log") options["take_log"] = copy_dict_items(nabu_config["preproc"], ["log_min_clip", "log_max_clip"]) # # Radios rotation (do it here if mode=="full") # if self.get_radios_rotation_mode() == "full": tasks.append("rotate_projections") options["rotate_projections"] = { "angle": nabu_config["preproc"]["rotate_projections"] or dataset_info.detector_tilt, "center": nabu_config["preproc"]["rotate_projections_center"], "mode": "full", } # # Translation movements # translations = dataset_info.translations if translations is not None: tasks.append("radios_movements") options["radios_movements"] = {"translation_movements": dataset_info.translations[:: self.binning_z]} # # Sinogram normalization (before half-tomo) # if nabu_config["preproc"]["sino_normalization"] is not None: tasks.append("sino_normalization") options["sino_normalization"] = { "method": nabu_config["preproc"]["sino_normalization"], "normalization_array": self._sino_normalization_arr, } # # Sinogram-based rings artefacts removal # if nabu_config["preproc"]["sino_rings_correction"]: tasks.append("sino_rings_correction") options["sino_rings_correction"] = { "user_options": nabu_config["preproc"]["sino_rings_options"], } # # Reconstruction # if nabu_config["reconstruction"]["method"] is not None: tasks.append("reconstruction") # Iterative is not supported through configuration file for now. options["reconstruction"] = copy_dict_items( self.rec_params, [ "method", "fbp_filter_type", "fbp_filter_cutoff", "padding_type", "start_x", "end_x", "start_y", "end_y", "start_z", "end_z", "centered_axis", "clip_outer_circle", "source_sample_dist", "sample_detector_dist", ], ) rec_options = options["reconstruction"] rec_options["rotation_axis_position"] = self.rotation_axis_position(binning=True) rec_options["fbp_rotation_axis_position"] = rec_options["rotation_axis_position"] rec_options["enable_halftomo"] = self.do_halftomo rec_options["axis_correction"] = dataset_info.axis_correction if dataset_info.axis_correction is not None: rec_options["axis_correction"] = rec_options["axis_correction"][:: self.subsampling_factor] rec_options["angles"] = np.array(self.rotation_angles(subsampling=True)) rec_options["angles"] += np.deg2rad(nabu_config["reconstruction"]["angle_offset"]) rec_options["pixel_size_cm"] = dataset_info.pixel_size * 1e-4 # pix size is in microns in dataset_info if rec_options["sample_detector_dist"] is None: rec_options["sample_detector_dist"] = self.dataset_info.distance # was checked to be not None earlier # TODO improve halftomo handling if self.do_halftomo: rec_options["angles"] = rec_options["angles"][: (rec_options["angles"].size + 1) // 2] if rec_options["rotation_axis_position"] < (self.radio_shape(binning=True)[-1] - 1) / 2.0: rec_options["fbp_rotation_axis_position"] = ( self.radio_shape(binning=True)[-1] - rec_options["rotation_axis_position"] ) # --- # New key rec_options["cor_estimated_auto"] = isinstance(nabu_config["reconstruction"]["rotation_axis_position"], str) # # Histogram # if nabu_config["postproc"]["output_histogram"]: tasks.append("histogram") options["histogram"] = copy_dict_items(nabu_config["postproc"], ["histogram_bins"]) # # Save # if nabu_config["output"]["location"] is not None: tasks.append("save") options["save"] = copy_dict_items(nabu_config["output"], list(nabu_config["output"].keys())) options["save"]["overwrite"] = 
nabu_config["output"]["overwrite_results"] self.processing_steps = tasks self.processing_options = options if set(self.processing_steps) != set(self.processing_options.keys()): raise ValueError("Something wrong with process_config: options do not correspond to steps") # Add check if set(self.processing_steps) != set(self.processing_options.keys()): raise ValueError("Something wrong when building processing steps") def _get_ctf_parameters(self, phase_options): dataset_info = self.dataset_info user_phase_options = self.nabu_config["phase"] ctf_geom = extract_parameters(user_phase_options["ctf_geometry"]) ctf_advanced_params = extract_parameters(user_phase_options["ctf_advanced_params"]) # z1_vh z1_v = ctf_geom["z1_v"] z1_h = ctf_geom["z1_h"] z1_vh = None if z1_h is None and z1_v is None: # parallel beam z1_vh = None elif (z1_v is None) ^ (z1_h is None): # only one is provided: source-sample distance z1_vh = z1_v or z1_h if z1_h is not None and z1_v is not None: # distance of the vertically focused source (horizontal line) and the horizontaly focused source (vertical line) # for KB mirrors z1_vh = (z1_v, z1_h) # pix_size_det pix_size_det = ctf_geom["detec_pixel_size"] or dataset_info.pixel_size * 1e-6 # wavelength wavelength = 1.23984199e-9 / dataset_info.energy phase_options["ctf_geo_pars"] = { "z1_vh": z1_vh, "z2": phase_options["distance_m"], "pix_size_det": pix_size_det, "wavelength": wavelength, "magnification": bool(ctf_geom["magnification"]), "length_scale": ctf_advanced_params["length_scale"], } phase_options["ctf_lim1"] = ctf_advanced_params["lim1"] phase_options["ctf_lim2"] = ctf_advanced_params["lim2"] phase_options["ctf_normalize_by_mean"] = ctf_advanced_params["normalize_by_mean"] def _configure_save_steps(self, steps_to_save=None): self.steps_to_save = [] self.dump_sinogram = False if steps_to_save is None: steps_to_save = self.nabu_config["pipeline"]["save_steps"] if steps_to_save in (None, ""): return steps_to_save = [s.strip() for s in steps_to_save.split(",")] for step in self.processing_steps: step = step.strip() if step in steps_to_save: self.processing_options[step]["save"] = True self.processing_options[step]["save_steps_file"] = self.get_save_steps_file(step_name=step) # "sinogram" is a special keyword, not explicitly in the processing steps if "sinogram" in steps_to_save: self.dump_sinogram = True self.dump_sinogram_file = self.get_save_steps_file(step_name="sinogram") self.steps_to_save = steps_to_save def _get_dump_file_and_h5_path(self): resume_from = self.resume_from_step process_file = self.get_save_steps_file(step_name=resume_from) if not os.path.isfile(process_file): self.logger.error("Cannot resume processing from step '%s': no such file %s" % (resume_from, process_file)) return None, None h5_entry = self.dataset_info.hdf5_entry or "entry" process_h5_path = posixpath.join(h5_entry, resume_from, "results/data") if not hdf5_entry_exists(process_file, process_h5_path): self.logger.error("Could not find data in %s in file %s" % (process_h5_path, process_file)) process_h5_path = None return process_file, process_h5_path def _configure_resume(self, resume_from=None): self.resume_from_step = None if resume_from is None: resume_from = self.nabu_config["pipeline"]["resume_from_step"] if resume_from in (None, ""): return resume_from = resume_from.strip(" ,;") self.resume_from_step = resume_from processing_steps = self.processing_steps # special case: resume from sinogram if resume_from == "sinogram": # disable up to 'reconstruction', not included if 
"sino_rings_correction" in processing_steps: # Sinogram destriping is done before building the half tomo sino. # Not sure if this is needed (i.e can we do before building the extended sino ?) # TODO find a more elegant way idx = processing_steps.index("sino_rings_correction") else: idx = processing_steps.index("reconstruction") # elif resume_from in processing_steps: idx = processing_steps.index(resume_from) + 1 # disable up to resume_from, included else: msg = "Cannot resume processing from step '%s': no such step in the current configuration" % resume_from self.logger.error(msg) self.resume_from_step = None return # Get corresponding file and h5 path process_file, process_h5_path = self._get_dump_file_and_h5_path() if process_file is None or process_h5_path is None: self.resume_from_step = None return dump_info = self._check_dump_file(process_file, raise_on_error=False) if dump_info is None: self.logger.error("Cannot resume from step %s: cannot use file %s" % (resume_from, process_file)) self.resume_from_step = None return dump_start_z, dump_end_z = dump_info # Disable steps steps_to_disable = processing_steps[1:idx] self.logger.debug("Disabling steps %s" % str(steps_to_disable)) for step_name in steps_to_disable: processing_steps.remove(step_name) self.processing_options.pop(step_name) # Update configuration self.logger.info("Processing will be resumed from step '%s' using file %s" % (resume_from, process_file)) self._old_read_chunk = self.processing_options["read_chunk"] self.processing_options["read_chunk"] = { "process_file": process_file, "process_h5_path": process_h5_path, "step_name": resume_from, "dump_start_z": dump_start_z, "dump_end_z": dump_end_z, } # Dont dump a step if we resume from this step if resume_from in self.steps_to_save: self.logger.warning( "Processing is resumed from step '%s'. 
This step won't be dumped to a file" % resume_from ) self.steps_to_save.remove(resume_from) if resume_from == "sinogram": self.dump_sinogram = False else: if resume_from in self.processing_options: # should not happen self.processing_options[resume_from].pop("save") def _check_dump_file(self, process_file, raise_on_error=False): """ Return (start_z, end_z) on success Return None on failure """ # Ensure data in the file correspond to what is currently asked if self.resume_from_step is None: return None # Check dataset shape/start_z/end_z rec_cfg_h5_path = posixpath.join( self.dataset_info.hdf5_entry or "entry", self.resume_from_step, "configuration/processing_options/reconstruction", ) dump_start_z = get_h5_value(process_file, posixpath.join(rec_cfg_h5_path, "start_z")) dump_end_z = get_h5_value(process_file, posixpath.join(rec_cfg_h5_path, "end_z")) if dump_end_z < 0: dump_end_z += self.radio_shape(binning=False)[0] start_z, end_z = ( self.processing_options["reconstruction"]["start_z"], self.processing_options["reconstruction"]["end_z"], ) if not (dump_start_z <= start_z and end_z <= dump_end_z): msg = ( "File %s was built with start_z=%d, end_z=%d but current configuration asks for start_z=%d, end_z=%d" % (process_file, dump_start_z, dump_end_z, start_z, end_z) ) if not raise_on_error: self.logger.error(msg) return None self.logger.fatal(msg) raise ValueError(msg) # Check parameters other than reconstruction filedump_nabu_config = import_h5_to_dict( process_file, posixpath.join(self.dataset_info.hdf5_entry or "entry", self.resume_from_step, "configuration/nabu_config"), ) sections_to_ignore = ["reconstruction", "output", "pipeline"] for section in sections_to_ignore: filedump_nabu_config[section] = self.nabu_config[section] # special case of the "save_steps process" # filedump_nabu_config["pipeline"]["save_steps"] = self.nabu_config["pipeline"]["save_steps"] diff = compare_dicts(filedump_nabu_config, self.nabu_config) if diff is not None: msg = "Nabu configuration in file %s differ from the current one: %s" % (process_file, diff) if not raise_on_error: self.logger.error(msg) return None self.logger.fatal(msg) raise ValueError(msg) # return (dump_start_z, dump_end_z) def get_save_steps_file(self, step_name=None): if self.nabu_config["pipeline"]["steps_file"] not in (None, ""): return self.nabu_config["pipeline"]["steps_file"] nabu_save_options = self.nabu_config["output"] output_dir = nabu_save_options["location"] file_prefix = step_name + "_" + nabu_save_options["file_prefix"] fname = os.path.join(output_dir, file_prefix) + ".hdf5" return fname ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678380095.0 nabu-2023.1.1/nabu/pipeline/fullfield/reconstruction.py0000644000175000017500000010357200000000000022367 0ustar00pierrepierrefrom os import environ from os.path import join, isfile, basename, dirname from math import ceil import gc from psutil import virtual_memory from silx.io import get_data from silx.io.url import DataUrl from ...utils import check_supported, subdivide_into_overlapping_segment from ...resources.logger import LoggerOrPrint from ...io.writer import merge_hdf5_files, NXProcessWriter from ...cuda.utils import collect_cuda_gpus from ...preproc.phase import compute_paganin_margin from ...misc.histogram import PartialHistogram, add_last_bin, hist_as_2Darray from .chunked import ChunkedPipeline from .chunked_cuda import CudaChunkedPipeline from .computations import estimate_max_chunk_size def variable_idxlen_sort(fname): return 
int(fname.split("_")[-1].split(".")[0]) class FullFieldReconstructor: """ A reconstructor spawns and manages Pipeline objects, depending on the current chunk/group size. """ _available_backends = ["cuda", "numpy"] _process_name = "reconstruction" default_advanced_options = { "gpu_mem_fraction": 0.9, "cpu_mem_fraction": 0.9, "chunk_size": None, "margin": None, "force_grouped_mode": False, } def __init__(self, process_config, logger=None, backend="cuda", extra_options=None, cuda_options=None): """ Initialize a Reconstructor object. This class is used for managing pipelines. Parameters ---------- process_config: ProcessConfig object Data structure with process configuration logger: Logger, optional logging object backend: str, optional Which backend to use. Available are: "cuda", "numpy". extra_options: dict, optional Dictionary with advanced options. Please see 'Other parameters' below cuda_options: dict, optional Dictionary with cuda options passed to `nabu.cuda.processing.CudaProcessing` Other parameters ----------------- Advanced options can be passed in the 'extra_options' dictionary. These can be: - cpu_mem_fraction: 0.9, - gpu_mem_fraction: 0.9, - chunk_size: None, - margin: None, - force_grouped_mode: False """ self.logger = LoggerOrPrint(logger) self.process_config = process_config self._set_extra_options(extra_options) self._get_reconstruction_range() self._get_resources() self._get_backend(backend, cuda_options) self._compute_margin() self._compute_max_chunk_size() self._get_pipeline_mode() self._build_tasks() self._do_histograms = self.process_config.nabu_config["postproc"]["output_histogram"] self._histogram_merged = False self.pipeline = None self._histogram_merged = False # # Initialization # def _set_extra_options(self, extra_options): self.extra_options = self.default_advanced_options.copy() self.extra_options.update(extra_options or {}) self.gpu_mem_fraction = self.extra_options["gpu_mem_fraction"] self.cpu_mem_fraction = self.extra_options["cpu_mem_fraction"] def _get_reconstruction_range(self): rec_region = self.process_config.rec_region # without binning self.z_min = rec_region["start_z"] # In the user configuration, the upper bound is included: [start_z, end_z]. # In python syntax, the upper bound is not: [start_z, end_z[ self.z_max = rec_region["end_z"] + 1 self.delta_z = self.z_max - self.z_min # Cache some volume info self.n_angles = self.process_config.n_angles(subsampling=False) self.n_z, self.n_x = self.process_config.radio_shape(binning=False) def _get_resources(self): self.resources = {} self._get_gpu() self._get_memory() def _get_memory(self): vm = virtual_memory() self.resources["mem_avail_GB"] = vm.available / 1e9 # Account for other memory constraints. There might be a better way slurm_mem_constraint_MB = int(environ.get("SLURM_MEM_PER_NODE", 0)) if slurm_mem_constraint_MB > 0: self.resources["mem_avail_GB"] = slurm_mem_constraint_MB / 1e3 # self.cpu_mem = self.resources["mem_avail_GB"] * self.cpu_mem_fraction def _get_gpu(self): avail_gpus = collect_cuda_gpus() or {} self.resources["gpus"] = avail_gpus if len(avail_gpus) == 0: return # pick first GPU by default. 
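# Illustrative sketch (not part of nabu's API): how the CPU memory budget used by
# _compute_max_chunk_size() is obtained, mirroring _get_memory() above. The helper
# name and the 0.9 value (the default cpu_mem_fraction) are the only assumptions.
from os import environ
from psutil import virtual_memory

def cpu_memory_budget_gb(cpu_mem_fraction=0.9):
    mem_gb = virtual_memory().available / 1e9
    # A SLURM allocation can be more restrictive than what psutil reports
    slurm_mem_mb = int(environ.get("SLURM_MEM_PER_NODE", 0))
    if slurm_mem_mb > 0:
        mem_gb = slurm_mem_mb / 1e3
    return mem_gb * cpu_mem_fraction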
TODO: handle user's nabu_config["resources"]["gpu_id"] self.resources["gpu_id"] = self._gpu_id = list(avail_gpus.keys())[0] def _get_backend(self, backend, cuda_options): self._pipeline_cls = ChunkedPipeline check_supported(backend, self._available_backends, "backend") if backend == "cuda": self.cuda_options = cuda_options if len(self.resources["gpus"]) == 0: # Not sure if an error should be raised in this case self.logger.error("Could not find any cuda device. Falling back on numpy/CPU backend.") backend = "numpy" else: self.gpu_mem = self.resources["gpus"][self._gpu_id]["memory_GB"] * self.gpu_mem_fraction if backend == "cuda": self._pipeline_cls = CudaChunkedPipeline self.backend = backend def _compute_max_chunk_size(self): """ Compute the maximum number of (partial) radios that can be processed in memory. Ideally, the processing is done by reading N lines of all the projections. This function estimates max_chunk_size = N_max, the maximum number of lines that can be read in all the projections while still fitting the memory. On the other hand, if a "processing margin" is needed (eg. phase retrieval), then we need to read at least N_min = 2 * margin_v + n_slices lines of each image. For large datasets, we have N_min > max_chunk_size, so we can't read lines from all the projections. """ user_chunk_size = self.extra_options["chunk_size"] if user_chunk_size is not None: self.chunk_size = user_chunk_size else: self.cpu_max_chunk_size = estimate_max_chunk_size( self.cpu_mem, self.process_config, pipeline_part="all", step=5 ) self.chunk_size = self.cpu_max_chunk_size if self.backend == "cuda": self.gpu_max_chunk_size = estimate_max_chunk_size( self.gpu_mem, self.process_config, pipeline_part="all", step=5, ) self.chunk_size = min(self.gpu_max_chunk_size, self.cpu_max_chunk_size) self.chunk_size = min(self.chunk_size, self.n_z) def _compute_max_group_size(self): """ Compute the maximum number of (partial) images that can be processed in memory """ # # "Group size" (i.e, how many radios can be processed in one pass for the first part of the pipeline) # user_group_size = self.extra_options.get("chunk_size", None) if user_group_size is not None: self.group_size = user_group_size else: self.cpu_max_group_size = estimate_max_chunk_size( self.cpu_mem, self.process_config, pipeline_part="radios", n_rows=min(2 * self._margin_v + self.delta_z, self.n_z), step=10, ) self.group_size = self.cpu_max_group_size if self.backend == "cuda": self.gpu_max_group_size = estimate_max_chunk_size( self.gpu_mem, self.process_config, pipeline_part="radios", n_rows=min(2 * self._margin_v + self.delta_z, self.n_z), step=10, ) self.group_size = min(self.gpu_max_group_size, self.cpu_max_group_size) self.group_size = min(self.group_size, self.n_angles) # # "sinos chunk size" (i.e, how many sinograms can be processed in one pass for the second part of the pipeline) # self.cpu_max_chunk_size_sinos = estimate_max_chunk_size( self.cpu_mem, self.process_config, pipeline_part="sinos", step=10, ) if self.backend == "cuda": self.gpu_max_chunk_size_sinos = estimate_max_chunk_size( self.gpu_mem, self.process_config, pipeline_part="sinos", step=5, ) self.chunk_size_sinos = min(self.gpu_max_chunk_size_sinos, self.cpu_max_chunk_size_sinos) self.chunk_size_sinos = min(self.chunk_size_sinos, self.delta_z) def _get_pipeline_mode(self): # "Pipeline mode" means we either process data chunks of type [:, delta_z, :] or [delta_theta, :, :]. # The first case is better in terms of performances and should be preferred. 
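# Illustrative sketch (not part of nabu): the feasibility reasoning behind the
# chunked/grouped decision discussed here. chunk_size is the number of detector
# rows that fit in memory, margin_v the extra rows needed on each side (e.g. for
# phase retrieval), delta_z the number of slices to reconstruct. The function name
# and the example figures are hypothetical.
from math import ceil

def chunked_mode_is_reasonable(chunk_size, margin_v, delta_z):
    # At least one usable row must remain once both margins are read
    if 1 + 2 * margin_v > chunk_size:
        return False
    # If the margin eats most of the chunk, too many passes would be needed
    n_slices = max(1, chunk_size - 2 * margin_v)
    n_stages = ceil(delta_z / n_slices)
    return not (margin_v > chunk_size // 3 and n_stages > 1)

print(chunked_mode_is_reasonable(chunk_size=200, margin_v=40, delta_z=500))  # True
print(chunked_mode_is_reasonable(chunk_size=100, margin_v=45, delta_z=500))  # False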
# However, in some cases, it's not possible to use it (eg. high "delta_z" because of margin) chunk_size_for_one_slice = 1 + 2 * self._margin_v # TODO ignore margin when resuming from sinogram ? chunk_is_too_small = False if chunk_size_for_one_slice > self.chunk_size: msg = str( "Margin is %d, so we need to process at least %d detector rows. However, the available memory enables to process only %d rows at once" % (self._margin_v, chunk_size_for_one_slice, self.chunk_size) ) chunk_is_too_small = True if self._margin_v > self.chunk_size // 3: n_slices = max(1, self.chunk_size - (2 * self._margin_v)) n_stages = ceil(self.delta_z / n_slices) if n_stages > 1: msg = str("Margin (%d) is too big for chunk size (%d)" % (self._margin_v, self.chunk_size)) chunk_is_too_small = True force_grouped_mode = self.extra_options.get("force_grouped_mode", False) if force_grouped_mode: msg = "Forcing grouped mode" if self.process_config.processing_options.get("phase", {}).get("method", None) == "CTF": force_grouped_mode = True msg = "CTF phase retrieval needs to process full radios" if self.process_config.processing_options.get("rotate_projections", {}).get("angle", 0) > 15: force_grouped_mode = True msg = "Radios rotation with a large angle needs to process full radios" if self.process_config.resume_from_step == "sinogram" and force_grouped_mode: self.logger.warning("Cannot use grouped-radios processing when resuming from sinogram") force_grouped_mode = False if not (chunk_is_too_small or force_grouped_mode): # Default case (preferred) self._pipeline_mode = "chunked" self.chunk_shape = (self.n_angles, self.chunk_size, self.n_x) else: # Fall-back mode (slower) self.logger.warning(msg) self._pipeline_mode = "grouped" self._compute_max_group_size() self.chunk_shape = (self.group_size, self.delta_z, self.n_x) self.logger.info("Using 'grouped' pipeline mode") # # "Margin" # def _compute_margin(self): user_margin = self.extra_options.get("margin", None) if self.process_config.resume_from_step == "sinogram": self.logger.debug("Margin not needed when resuming from sinogram") margin_v, margin_h = 0, 0 elif user_margin is not None and user_margin > 0: margin_v, margin_h = user_margin, user_margin self.logger.info("Using user-defined margin: %d" % user_margin) else: unsharp_margin = self._compute_unsharp_margin() phase_margin = self._compute_phase_margin() translations_margin = self._compute_translations_margin() cone_margin = self._compute_cone_overlap() # TODO radios rotation/movements margin_v = max(unsharp_margin[0], phase_margin[0], translations_margin[0], cone_margin[0]) margin_h = max(unsharp_margin[1], phase_margin[1], translations_margin[1], cone_margin[1]) if margin_v > 0: self.logger.info("Estimated margin: %d pixels" % margin_v) margin_v = min(margin_v, self.n_z) margin_h = min(margin_h, self.n_x) self._margin = margin_v, margin_h self._margin_v = margin_v def _compute_unsharp_margin(self): if "unsharp_mask" not in self.process_config.processing_steps: return (0, 0) opts = self.process_config.processing_options["unsharp_mask"] sigma = opts["unsharp_sigma"] # nabu uses cutoff = 4 cutoff = 4 gaussian_kernel_size = int(ceil(2 * cutoff * sigma + 1)) self.logger.debug("Unsharp mask margin: %d pixels" % gaussian_kernel_size) return (gaussian_kernel_size, gaussian_kernel_size) def _compute_phase_margin(self): phase_options = self.process_config.processing_options.get("phase", None) if phase_options is None: margin_v, margin_h = (0, 0) elif phase_options["method"] == "paganin": radio_shape = 
self.process_config.dataset_info.radio_dims[::-1] margin_v, margin_h = compute_paganin_margin( radio_shape, distance=phase_options["distance_m"], energy=phase_options["energy_kev"], delta_beta=phase_options["delta_beta"], pixel_size=phase_options["pixel_size_m"], padding=phase_options["padding_type"], use_rfft=False, # disable fftw here ) elif phase_options["method"] == "CTF": # The whole projection has to be processed! margin_v = ceil( (self.n_z - self.delta_z) / 2 ) # not so elegant. Use a dedicated flag eg. _process_whole_image ? margin_h = 0 # unused for now else: margin_v, margin_h = (0, 0) return (margin_v, margin_h) def _compute_translations_margin(self): translations = self.process_config.dataset_info.translations if translations is None: return (0, 0) margin_h_v = [] for i in range(2): transl = translations[:, i] margin_h_v.append(ceil(max([transl.max(), (-transl).max()]))) self.logger.debug("Maximum vertical displacement: %d pixels" % margin_h_v[1]) return tuple(margin_h_v[::-1]) def _compute_cone_overlap(self): rec_cfg = self.process_config.processing_options.get("reconstruction", {}) rec_method = rec_cfg.get("method", None) if rec_method != "cone": return (0, 0) d1 = rec_cfg["source_sample_dist"] d2 = rec_cfg["sample_detector_dist"] n_z, _ = self.process_config.radio_shape(binning=True) delta_z = self.process_config.rec_delta_z # accounts_for_binning overlap = ceil(delta_z * d2 / (d1 + d2)) # sqrt(2) missing ? max_overlap = ceil(n_z * d2 / (d1 + d2)) # sqrt(2) missing ? return (max_overlap, 0) def _ensure_good_chunk_size_and_margin(self): """ Check that "chunk_size" and "margin" (if any) are a multiple of binning factor. See nabu:!208 After that, all subregion lengths of _build_tasks_chunked() should be multiple of the binning factor, because "delta_z" itself was checked to be a multiple in DatasetValidator._handle_binning() """ bin_z = self.process_config.binning_z if bin_z <= 1: return self.chunk_size -= self.chunk_size % bin_z if self._margin_v > 0 and (self._margin_v % bin_z) > 0: self._margin_v += bin_z - (self._margin_v % bin_z) # i.e margin = ((margin % bin_z) + 1) * bin_z # # Tasks management # def _modify_processconfig_stage_1(self): # Modify the "process_config" object to dump sinograms proc = self.process_config self._old_steps_to_save = proc.steps_to_save.copy() if "sinogram" in proc.steps_to_save: return proc._configure_save_steps(self._old_steps_to_save + ["sinogram"]) def _undo_modify_processconfig_stage_1(self): self.process_config.steps_to_save = self._old_steps_to_save if "sinogram" not in self._old_steps_to_save: self.process_config.dump_sinogram = False def _modify_processconfig_stage_2(self): # Modify the "process_config" object to resume from sinograms proc = self.process_config self._old_resume_from = proc.resume_from_step self._old_proc_steps = proc.processing_steps.copy() self._old_proc_options = proc.processing_options.copy() proc._configure_resume(resume_from="sinogram") def _undo_modify_processconfig_stage_2(self): self.process_config.resume_from_step = self._old_resume_from self.process_config.processing_steps = self._old_proc_steps self.process_config.processing_options = self._old_proc_options def _build_tasks_grouped(self): tasks = [] segments = subdivide_into_overlapping_segment(self.n_angles, self.group_size, 0) for segment in segments: _, start_angle, end_angle, _ = segment z_min = max(self.z_min - self._margin_v, 0) z_max = min(self.z_max + self._margin_v, self.n_z) sub_region = ((start_angle, end_angle), (z_min, z_max), (0, 
self.chunk_shape[-1])) tasks.append({"sub_region": sub_region, "margin": (self.z_min - z_min, z_max - self.z_max)}) self.tasks = tasks # Build tasks for stage 2 (sinogram processing + reconstruction) segments = subdivide_into_overlapping_segment(self.delta_z, self.chunk_size_sinos, 0) self._sino_tasks = [] for segment in segments: _, start_z, end_z, _ = segment z_min = self.z_min + start_z z_max = min(self.z_min + end_z, self.n_z) sub_region = ((0, self.n_angles), (z_min, z_max), (0, self.chunk_shape[-1])) self._sino_tasks.append({"sub_region": sub_region, "margin": None}) def _build_tasks_chunked(self): margin_v = self._margin_v if self.chunk_size >= self.delta_z and self.z_min == 0 and self.z_max == self.n_z: # In this case we can do everything in a single stage, without margin self.tasks = [ { "sub_region": ((0, self.n_angles), (self.z_min, self.z_max), (0, self.chunk_shape[-1])), "margin": None, } ] return if self.chunk_size - (2 * margin_v) >= self.delta_z: # In this case we can do everything in a single stage n_slices = self.delta_z (margin_up, margin_down) = (min(margin_v, self.z_min), min(margin_v, self.n_z - self.z_max)) self.tasks = [ { "sub_region": ( (0, self.n_angles), (self.z_min - margin_up, self.z_max + margin_down), (0, self.chunk_shape[-1]), ), "margin": ((margin_up, margin_down), (0, 0)), } ] return # In this case there are at least two stages n_slices = self.chunk_size - (2 * margin_v) n_stages = ceil(self.delta_z / n_slices) self.tasks = [] curr_z_min = self.z_min curr_z_max = self.z_min + n_slices for i in range(n_stages): margin_up = min(margin_v, curr_z_min) margin_down = min(margin_v, max(self.n_z - curr_z_max, 0)) if curr_z_max + margin_down >= self.z_max: curr_z_max -= curr_z_max - (self.z_max + 0) margin_down = min(margin_v, max(self.n_z - 1 - curr_z_max, 0)) self.tasks.append( { "sub_region": ( (0, self.n_angles), (curr_z_min - margin_up, curr_z_max + margin_down), (0, self.chunk_shape[-1]), ), "margin": ((margin_up, margin_down), (0, 0)), } ) if curr_z_max == self.z_max: # No need for further tasks break curr_z_min += n_slices curr_z_max += n_slices def _build_tasks(self): if self._pipeline_mode == "grouped": self._build_tasks_grouped() else: self._ensure_good_chunk_size_and_margin() self._build_tasks_chunked() def _print_tasks_chunked(self): for task in self.tasks: margin_up, margin_down = task["margin"][0] s_u, s_d = task["sub_region"][1] print( "Top Margin: [%04d, %04d[ | Slices: [%04d, %04d[ | Bottom Margin: [%04d, %04d[" % (s_u, s_u + margin_up, s_u + margin_up, s_d - margin_down, s_d - margin_down, s_d) ) def _print_tasks(self): for task in self.tasks: margin_up, margin_down = task["phase_margin"][0] s_u, s_d = task["sub_region"] print( "Top Margin: [%04d, %04d[ | Slices: [%04d, %04d[ | Bottom Margin: [%04d, %04d[" # pylint: disable=E1307 % (s_u, s_u + margin_up, s_u + margin_up, s_d - margin_down, s_d - margin_down, s_d) ) def _get_chunk_length(self, task): if self._pipeline_mode == "helical": (start_z, end_z) = task["sub_region"] return end_z - start_z else: (start_angle, end_angle), (start_z, end_z), _ = task["sub_region"] if self._pipeline_mode == "chunked": return end_z - start_z else: return end_angle - start_angle def _give_progress_info(self, task): self.logger.info("Processing sub-volume %s" % (str(task["sub_region"][:-1]))) # # Reconstruction # def _instantiate_pipeline(self, task): self.logger.debug("Creating a new pipeline object") chunk_shape = tuple(s[1] - s[0] for s in task["sub_region"]) args = [self.process_config, chunk_shape] kwargs = 
{} if self.backend == "cuda": kwargs["cuda_options"] = self.cuda_options kwargs["use_grouped_mode"] = self._pipeline_mode == "grouped" pipeline = self._pipeline_cls(*args, logger=self.logger, margin=task["margin"], **kwargs) self.pipeline = pipeline def _instantiate_pipeline_if_necessary(self, current_task, other_task): """ Instantiate a pipeline only if current_task has a different "delta z" than other_task """ if self.pipeline is None: self._instantiate_pipeline(current_task) return length_cur = self._get_chunk_length(current_task) length_other = self._get_chunk_length(other_task) if length_cur != length_other: self.logger.debug("Destroying pipeline instance and releasing memory") self._destroy_pipeline() self._instantiate_pipeline(current_task) def _destroy_pipeline(self): self.pipeline = None # Not elegant, but for now the only way to release Cuda memory gc.collect() def _reconstruct_chunked(self, tasks=None): self.results = {} self._histograms = {} self._data_dumps = {} tasks = tasks or self.tasks prev_task = tasks[0] for task in tasks: self.logger.info("Processing sub-volume %s" % (str(task["sub_region"]))) self._instantiate_pipeline_if_necessary(task, prev_task) self.pipeline.process_chunk(task["sub_region"]) task_key = self.pipeline.sub_region task_result = self.pipeline.writer.fname self.results[task_key] = task_result if self.pipeline.writer.histogram: self._histograms[task_key] = self.pipeline.writer.histogram_writer.fname if len(self.pipeline.datadump_manager.data_dump) > 0: self._data_dumps[task_key] = {} for step_name, writer in self.pipeline.datadump_manager.data_dump.items(): self._data_dumps[task_key][step_name] = writer.fname prev_task = task def _reconstruct_grouped(self): self.results = {} # self._histograms = {} self._data_dumps = {} prev_task = self.tasks[0] # Stage 1: radios processing self._modify_processconfig_stage_1() for task in self.tasks: self.logger.info("Processing sub-volume %s" % (str(task["sub_region"]))) self._instantiate_pipeline_if_necessary(task, prev_task) self.pipeline.process_chunk(task["sub_region"]) task_key = self.pipeline.sub_region task_result = self.pipeline.datadump_manager.data_dump["sinogram"].fname self.results[task_key] = task_result if len(self.pipeline.datadump_manager.data_dump) > 0: self._data_dumps[task_key] = {} for step_name, writer in self.pipeline.datadump_manager.data_dump.items(): self._data_dumps[task_key][step_name] = writer.fname prev_task = task self.merge_data_dumps(axis=0) self._destroy_pipeline() self.logger.info("End of first stage of processing. 
Will now process sinograms saved on disk") self._undo_modify_processconfig_stage_1() # Stage 2: sinograms processing and reconstruction self._modify_processconfig_stage_2() self._pipeline_mode = "chunked" self._reconstruct_chunked(tasks=self._sino_tasks) self._pipeline_mode = "grouped" self._undo_modify_processconfig_stage_2() def reconstruct(self): if self._pipeline_mode == "chunked": self._reconstruct_chunked() else: self._reconstruct_grouped() # # Writing data # def get_relative_files(self, files=None): out_cfg = self.process_config.nabu_config["output"] if files is None: files = list(self.results.values()) try: files.sort(key=variable_idxlen_sort) except: self.logger.error( "Lexical ordering failed, falling back to default sort - it will fail for more than 10k projections" ) files.sort() local_files = [join(out_cfg["file_prefix"], basename(fname)) for fname in files] return local_files def merge_hdf5_reconstructions( self, output_file=None, prefix=None, files=None, process_name=None, axis=0, merge_histograms=True, output_dir=None, ): """ Merge existing hdf5 files by creating a HDF5 virtual dataset. Parameters ---------- output_file: str, optional Output file name. If not given, the file prefix in section "output" of nabu config will be taken. """ out_cfg = self.process_config.nabu_config["output"] out_dir = output_dir or out_cfg["location"] prefix = prefix or "" # Prevent issue when out_dir is empty, which happens only if dataset/location is a relative path. # TODO this should be prevented earlier if out_dir is None or len(out_dir.strip()) == 0: out_dir = dirname(dirname(self.results[list(self.results.keys())[0]])) # if output_file is None: output_file = join(out_dir, prefix + out_cfg["file_prefix"]) + ".hdf5" if isfile(output_file): msg = str("File %s already exists" % output_file) if out_cfg["overwrite_results"]: msg += ". Overwriting as requested in configuration file" self.logger.warning(msg) else: msg += ". Set overwrite_results to True in [output] to overwrite existing files." self.logger.fatal(msg) raise ValueError(msg) local_files = files if local_files is None: local_files = self.get_relative_files() if local_files == []: self.logger.error("No files to merge") return entry = getattr(self.process_config.dataset_info.dataset_scanner, "entry", "entry") process_name = process_name or self._process_name h5_path = join(entry, *[process_name, "results", "data"]) # self.logger.info("Merging %ss to %s" % (process_name, output_file)) # # When dumping to disk an intermediate step taking place before cropping the radios, # 'start_z' and 'end_z' found in nabu config have to be augmented with margin_z. 
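# Illustrative sketch (not part of nabu's API): why get_relative_files() sorts the
# partial reconstruction files with a numeric key (variable_idxlen_sort) rather
# than lexically. The file names below are made up.
def trailing_index(fname):
    # same idea as variable_idxlen_sort(): integer after the last "_"
    return int(fname.split("_")[-1].split(".")[0])

files = ["rec_subvol_2.hdf5", "rec_subvol_10.hdf5", "rec_subvol_1.hdf5"]
print(sorted(files))                      # lexical: ..._1, ..._10, ..._2  (wrong order)
print(sorted(files, key=trailing_index))  # numeric: ..._1, ..._2, ..._10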
# (these values are checked in ProcessConfig._configure_resume()) # patched_start_end_z = False user_rec_config = self.process_config.processing_options["reconstruction"] if ( self._margin_v > 0 and process_name != "reconstruction" and self.process_config.is_before_radios_cropping(process_name) ): patched_start_end_z = True old_start_z = user_rec_config["start_z"] old_end_z = user_rec_config["end_z"] user_rec_config["start_z"] = max(0, old_start_z - self._margin_v) user_rec_config["end_z"] = min(self.n_z, old_end_z + self._margin_v) # merge_hdf5_files( local_files, h5_path, output_file, process_name, output_entry=entry, output_filemode="a", processing_index=0, config={ self._process_name + "_stages": {str(k): v for k, v in zip(self.results.keys(), local_files)}, "nabu_config": self.process_config.nabu_config, "processing_options": self.process_config.processing_options, }, base_dir=out_dir, axis=axis, overwrite=out_cfg["overwrite_results"], ) if merge_histograms: self.merge_histograms(output_file=output_file) if patched_start_end_z: user_rec_config["start_z"] = old_start_z user_rec_config["end_z"] = old_end_z return output_file merge_hdf5_files = merge_hdf5_reconstructions def merge_histograms(self, output_file=None, force_merge=False): """ Merge the partial histograms """ if not (self._do_histograms): return if self._histogram_merged and not (force_merge): return self.logger.info("Merging histograms") masterfile_entry = getattr(self.process_config.dataset_info.dataset_scanner, "entry", "entry") masterfile_process_name = "histogram" # TODO don't hardcode process name output_entry = masterfile_entry out_cfg = self.process_config.nabu_config["output"] if output_file is None: output_file = ( join(dirname(list(self._histograms.values())[0]), out_cfg["file_prefix"] + "_histogram") + ".hdf5" ) local_files = self.get_relative_files(files=list(self._histograms.values())) # h5_path = join(masterfile_entry, *[masterfile_process_name, "results", "data"]) # try: files = sorted(self._histograms.values(), key=variable_idxlen_sort) except: self.logger.error( "Lexical ordering of histogram failed, falling back to default sort - it will fail for more than 10k projections" ) files = sorted(self._histograms.values()) data_urls = [] for fname in files: url = DataUrl(file_path=fname, data_path=h5_path, data_slice=None, scheme="silx") data_urls.append(url) histograms = [] for data_url in data_urls: h2D = get_data(data_url) histograms.append((h2D[0], add_last_bin(h2D[1]))) histograms_merger = PartialHistogram(method="fixed_bins_number", num_bins=histograms[0][0].size) merged_hist = histograms_merger.merge_histograms(histograms) rec_region = self.process_config.rec_region # Not sure if we should account for binning here. # (note that "rec_region" does not account for binning). 
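# Illustrative sketch (not part of the pipeline): reading one partial histogram
# through a silx DataUrl, the same access pattern merge_histograms() uses above.
# The file name and HDF5 path below are hypothetical placeholders.
from silx.io import get_data
from silx.io.url import DataUrl

url = DataUrl(
    file_path="rec_subvol_00000.hdf5",            # assumed partial result file
    data_path="entry/histogram/results/data",     # assumed HDF5 data path
    data_slice=None,
    scheme="silx",
)
h2d = get_data(url)   # 2D array: h2d[0] holds the counts, h2d[1] the bin values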
# Anyway the histogram has little use when using binning volume_shape = ( rec_region["end_z"] - rec_region["start_z"] + 1, rec_region["end_y"] - rec_region["start_y"] + 1, rec_region["end_x"] - rec_region["start_x"] + 1, ) writer = NXProcessWriter(output_file, entry=output_entry, filemode="a", overwrite=True) writer.write( hist_as_2Darray(merged_hist), "histogram", processing_index=1, config={ "files": local_files, "bins": self.process_config.nabu_config["postproc"]["histogram_bins"], "volume_shape": volume_shape, }, is_frames_stack=False, direct_access=False, ) self._histogram_merged = True def merge_data_dumps(self, axis=1): # Collect in a dict where keys are step names (instead of task keys) dumps = {} for task_key, data_dumps in self._data_dumps.items(): for step_name, fname in data_dumps.items(): fname = join(basename(dirname(fname)), basename(fname)) if step_name not in dumps: dumps[step_name] = [fname] else: dumps[step_name].append(fname) # Merge HDF5 files for step_name, files in dumps.items(): dump_file = self.process_config.get_save_steps_file(step_name=step_name) self.merge_hdf5_files( output_file=dump_file, output_dir=dirname(dump_file), files=files, process_name=step_name, axis=axis, merge_histograms=False, ) ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4647331 nabu-2023.1.1/nabu/pipeline/helical/0000755000175000017500000000000000000000000016357 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1656056444.0 nabu-2023.1.1/nabu/pipeline/helical/__init__.py0000644000175000017500000000000000000000000020456 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/pipeline/helical/dataset_validator.py0000644000175000017500000000057600000000000022433 0ustar00pierrepierrefrom ..fullfield.dataset_validator import * class HelicalDatasetValidator(FullFieldDatasetValidator): """Allows more freedom in the choice of the slice indices""" # this in the fullfield base class is instead True _check_also_z = False def _check_slice_indices(self): """Slice indices can be far beyond what fullfield pipeline accepts""" return ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/pipeline/helical/fbp.py0000644000175000017500000001271400000000000017505 0ustar00pierrepierrefrom ...reconstruction.fbp import * from .filtering import HelicalSinoFilter from ...utils import convert_index class BackprojectorHelical(Backprojector): """This is the Backprojector derived class for helical reconstruction. The modifications are detailed here : * the backprojection is decoupled from the filtering. This allows, in the pipeline, for doing first a filtering using backprojector_object.sino_filter subobject, then calling backprojector_object.backprojection only after reweigthing the result of the filter_sino method of sino_filter subobject. * the angles can be resetted on the fly, and the class can work with a variable number of projection. As a matter of fact, with the helical_chunked_regridded.py pipeline, the reconstruction is done each time with the same set of angles, this because of the regridding mechanism. The feature might return useful in the future if alternative approachs are developed again. 
* """ def set_custom_angles_and_axis_corrections(self, angles_rad, x_per_proj): """To arbitrarily change angles Parameters ========== angles_rad: array of floats one angle per each projection in radians x_per_proj: array of floats each entry is the axis shift for a projection, in pixels. """ self.n_provided_angles = len(angles_rad) self._axis_correction = np.zeros((1, self.n_angles), dtype=np.float32) self._axis_correction[0, : self.n_provided_angles] = -x_per_proj self.angles[: self.n_provided_angles] = angles_rad self._compute_angles_again() self.kern_proj_args[1] = self.n_provided_angles self.kern_proj_args[6] = self.offsets["x"] self.kern_proj_args[7] = self.offsets["y"] def _compute_angles_again(self): """to update angles dependent auxiliary arrays""" self.h_cos[0] = np.cos(self.angles).astype("f") self.h_msin[0] = (-np.sin(self.angles)).astype("f") self._d_msin[:] = self.h_msin[0] self._d_cos[:] = self.h_cos[0] if self._axis_correction is not None: self._d_axcorr[:] = self._axis_correction def _init_filter(self, filter_name): """To use the HelicalSinoFilter which is derived from SinoFilter with a slightly different padding scheme """ self.filter_name = filter_name self.sino_filter = HelicalSinoFilter( self.sino_shape, filter_name=self.filter_name, padding_mode=self.padding_mode, cuda_options={"ctx": self.cuda_processing.ctx}, ) def backprojection(self, sino, output=None): """Redefined here to do backprojection only, compare to the base class method.""" self._d_sino[:] = sino res = self.backproj(self._d_sino, output=output) return res def _init_geometry(self, sino_shape, slice_shape, angles, rot_center, slice_roi): """this is identical to _init_geometry of the base class with the exception that self.extra_options["centered_axis"] is not considered and as a consequence self.offsets is not set here and the one of _set_slice_roi is kept. 
""" if slice_shape is not None and slice_roi is not None: raise ValueError("slice_shape and slice_roi cannot be used together") self.sino_shape = sino_shape if len(sino_shape) == 2: n_angles, dwidth = sino_shape else: raise ValueError("Expected 2D sinogram") self.dwidth = dwidth self.rot_center = rot_center or (self.dwidth - 1) / 2.0 self._set_slice_shape( slice_shape, ) self.axis_pos = self.rot_center self._set_angles(angles, n_angles) self._set_slice_roi(slice_roi) self._set_axis_corr() def _set_slice_shape(self, slice_shape): """this is identical to the _set_slice_shape ofthe base class with the exception that n_y,n_x default to the largest possible reconstructible slice """ n_y = self.dwidth + abs(self.dwidth - 1 - self.rot_center * 2) n_x = self.dwidth + abs(self.dwidth - 1 - self.rot_center * 2) if slice_shape is not None: if np.isscalar(slice_shape): slice_shape = (slice_shape, slice_shape) n_y, n_x = slice_shape self.n_x = n_x self.n_y = n_y self.slice_shape = (n_y, n_x) def _set_slice_roi(self, slice_roi): """Automatically tune the offset to in all cases.""" self.slice_roi = slice_roi if slice_roi is None: off = -(self.dwidth - 1 - self.rot_center * 2) if off < 0: self.offsets = {"x": off, "y": off} else: self.offsets = {"x": 0, "y": 0} else: start_x, end_x, start_y, end_y = slice_roi # convert negative indices slice_width, _ = self.slice_shape off = min(0, -(self.dwidth - 1 - self.rot_center * 2)) if end_x < start_x: start_x = off end_x = off + slice_width if end_y < start_y: start_y = off end_y = off + slice_width self.slice_shape = (end_y - start_y, end_x - start_x) self.n_x = self.slice_shape[-1] self.n_y = self.slice_shape[-2] self.offsets = {"x": start_x, "y": start_y} ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/pipeline/helical/filtering.py0000644000175000017500000002416100000000000020720 0ustar00pierrepierre# pylint: disable=too-many-arguments from ...reconstruction.filtering import * import math import os # pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-function-args class HelicalSinoFilter(SinoFilter): def __init__( self, sino_shape, filter_name=None, padding_mode="zeros", extra_options=None, cuda_options=None, ): """Derived from nabu.reconstruction.filtering.SinoFilter It is used by helical_chunked_regridded pipeline. The shape of the processed sino, as a matter of fact which is due to the helical_chunked_regridded.py module which is using the here present class, is always, but not necessarily [nangles, nslices, nhorizontal] with nslices = 1. This because helical_chunked_regridded.py after a first preprocessing phase, always processes slices one by one. In helical_chunked_regridded .py, the call to the filter_sino method here contained is followed by the weight redistribution ( done by another module), which solves the HA problem, and the backprojection. The latter is performed by fbp.py or hbp.py The reason for having this class, derived from nabu.reconstruction.filtering.SinoFilter, is that the padding mechanism here implemented incorporates the padding with the available theta+180 projection on the half-tomo side. """ super().__init__( sino_shape, filter_name=filter_name, padding_mode=padding_mode, extra_options=extra_options, cuda_options=cuda_options, ) self._init_pad_kernel() def _check_array(self, arr): """ This class may work with an arbitrary number of projections. This is a consequence of the first implementation of the helical pipeline. 
In the first implementation the slices were reconstructed by backprojecting several turns, and the number of useful projections was different from the beginning or end, to the center of the scan. Now in helical_chunked_regridded.py the number of projections is fixed. The only relic left is that the present class may work with an arbitrary number of projections. """ if arr.dtype != np.float32: raise ValueError("Expected data type = numpy.float32") if arr.shape[1:] != self.sino_shape[1:]: raise ValueError("Expected sinogram shape %s, got %s" % (self.sino_shape, arr.shape)) def _init_pad_kernel(self): """The four possible padding kernels. The first two, compared to nabu.reconstruction.filtering.SinoFilter can work with an arbitrary number of projection. The latter two implement the padding with the available information from theta+180. """ self.kern_args = (self.d_sino_f, self.d_filter_f) self.kern_args += self.d_sino_f.shape[::-1] self._pad_mirror_edges_kernel = CudaKernel( "padding", filename=get_cuda_srcfile("helical_padding.cu"), signature="PPiiiiii", options=[str("-DMIRROR_EDGES")], ) self._pad_mirror_constant_kernel = CudaKernel( "padding", filename=get_cuda_srcfile("helical_padding.cu"), signature="PPiiiiiiff", options=[str("-DMIRROR_CONSTANT")], ) self._pad_mirror_edges_variable_rot_pos_kernel = CudaKernel( "padding", filename=get_cuda_srcfile("helical_padding.cu"), signature="PPPiiiii", options=[str("-DMIRROR_EDGES_VARIABLE_ROT_POS")], ) self._pad_mirror_constant_variable_rot_pos_kernel = CudaKernel( "padding", filename=get_cuda_srcfile("helical_padding.cu"), signature="PPPiiiiiff", options=[str("-DMIRROR_CONSTANT_VARIABLE_ROT_POS")], ) self.d_mirror_indexes = garray.zeros((self.sino_padded_shape[-2],), np.int32) self.d_variable_rot_pos = garray.zeros((self.sino_padded_shape[-2],), np.int32) self._pad_edges_kernel = CudaKernel( "padding_edge", filename=get_cuda_srcfile("padding.cu"), signature="Piiiiiiii" ) self._pad_block = (32, 32, 1) self._pad_grid = tuple([updiv(n, p) for n, p in zip(self.sino_padded_shape[::-1], self._pad_block)]) def _pad_sino(self, sino, mirror_indexes=None, rot_center=None): """redefined here to adapt the memory copy to the lenght of the sino argument which, in the general helical case may be varying """ if mirror_indexes is None: self._pad_sino_simple(sino) self.d_mirror_indexes[:] = np.zeros([len(self.d_mirror_indexes)], np.int32) self.d_mirror_indexes[: len(mirror_indexes)] = mirror_indexes.astype(np.int32) if np.isscalar(rot_center): argument_rot_center = int(round(rot_center)) tmp_pad_mirror_edges_kernel = self._pad_mirror_edges_kernel tmp_pad_mirror_constant_kernel = self._pad_mirror_constant_kernel else: self.d_variable_rot_pos[: len(rot_center)] = np.around(rot_center).astype(np.int32) argument_rot_center = self.d_variable_rot_pos tmp_pad_mirror_edges_kernel = self._pad_mirror_edges_variable_rot_pos_kernel tmp_pad_mirror_constant_kernel = self._pad_mirror_constant_variable_rot_pos_kernel self.d_sino_padded[: len(sino), : self.dwidth] = sino[:] if self.padding_mode == "edges": tmp_pad_mirror_edges_kernel( self.d_sino_padded, self.d_mirror_indexes, argument_rot_center, self.dwidth, self.n_angles, self.dwidth_padded, self.pad_left, self.pad_right, grid=self._pad_grid, block=self._pad_block, ) else: tmp_pad_mirror_constant_kernel( self.d_sino_padded, self.d_mirror_indexes, argument_rot_center, self.dwidth, self.n_angles, self.dwidth_padded, self.pad_left, self.pad_right, 0.0, 0.0, grid=self._pad_grid, block=self._pad_block, ) def _pad_sino_simple(self, 
sino): if self.padding_mode == "edges": self.d_sino_padded[: len(sino), : self.dwidth] = sino[:] self._pad_edges_kernel( self.d_sino_padded, self.dwidth, self.n_angles, self.dwidth_padded, self.n_angles, self.pad_left, self.pad_right, 0, 0, grid=self._pad_grid, block=self._pad_block, ) else: # zeros self.d_sino_padded.fill(0) if self.ndim == 2: self.d_sino_padded[: len(sino), : self.dwidth] = sino[:] else: self.d_sino_padded[: len(sino), :, : self.dwidth] = sino[:] def filter_sino(self, sino, mirror_indexes=None, rot_center=None, output=None, no_output=False): """ Perform the sinogram siltering. redefined here to use also mirror data Parameters ---------- sino: numpy.ndarray or pycuda.gpuarray.GPUArray Input sinogram (2D or 3D) output: numpy.ndarray or pycuda.gpuarray.GPUArray, optional Output array. no_output: bool, optional If set to True, no copy is be done. The resulting data lies in self.d_sino_padded. """ self._check_array(sino) # copy2d/copy3d self._pad_sino(sino, mirror_indexes=mirror_indexes, rot_center=rot_center) # FFT self.fft.fft(self.d_sino_padded, output=self.d_sino_f) # multiply padded sinogram with filter in the Fourier domain self.mult_kernel(*self.kern_args) # TODO tune block size ? # iFFT self.fft.ifft(self.d_sino_f, output=self.d_sino_padded) # return if no_output: return self.d_sino_padded if output is None: res = np.zeros(self.sino_shape, dtype=np.float32) # can't do memcpy2d D->H ? (self.d_sino_padded[:, w]) I have to get() sino_ref = self.d_sino_padded.get() else: res = output sino_ref = self.d_sino_padded if self.ndim == 2: res[:] = sino_ref[:, : self.dwidth] else: res[:] = sino_ref[:, :, : self.dwidth] return res def _calculate_shapes(self, sino_shape): """redefined here without modifications in order to use the here defined get_next_power_of_two from pyhst2""" self.ndim = len(sino_shape) if self.ndim == 2: n_angles, dwidth = sino_shape n_sinos = 1 elif self.ndim == 3: n_sinos, n_angles, dwidth = sino_shape else: raise ValueError("Invalid sinogram number of dimensions") self.sino_shape = sino_shape self.n_angles = n_angles self.dwidth = dwidth # int() is crucial here ! Otherwise some pycuda arguments (ex. 
memcpy2D) # will not work with numpy.int64 (as for 2018.X) ### the original get_next_power used in nabu gives a lower ram footprint self.dwidth_padded = 2 * int(get_next_power(self.dwidth)) # self.dwidth_padded = 2 * get_next_power_of_two(self.dwidth) self.sino_padded_shape = (n_angles, self.dwidth_padded) if self.ndim == 3: self.sino_padded_shape = (n_sinos,) + self.sino_padded_shape sino_f_shape = list(self.sino_padded_shape) sino_f_shape[-1] = sino_f_shape[-1] // 2 + 1 self.sino_f_shape = tuple(sino_f_shape) # self.pad_left = (self.dwidth_padded - self.dwidth) // 2 self.pad_right = self.dwidth_padded - self.dwidth - self.pad_left def get_next_power_of_two(num_bins): two_power = (int)(math.log((2.0 * num_bins - 1)) / math.log(2.0) + 0.9999) res = 2 ** (two_power + 1) return res ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/pipeline/helical/gridded_accumulator.py0000644000175000017500000005030000000000000022730 0ustar00pierrepierrefrom ...preproc.flatfield import FlatFieldArrays import numpy as np from scipy import ndimage as nd import math class GriddedAccumulator: def __init__( self, gridded_radios, gridded_weights, diagnostic_radios, diagnostic_weights, diagnostic_angles, dark=None, flat_indexes=None, flats=None, weights=None, double_flat=None, ): """ This class creates, for a selected volume slab, a standard set of radios from an helical dataset. Parameters ========== gridded_radios : 3D np.array this is the stack of new radios which will be resynthetised, by this class, for a selected slab. The object is initialised with this array, and this array will accumulate, during subsequent calls to method extract_preprocess_with_flats, the sum of the transformed contributions obtained from the arguments of the mentioned method (extract_preprocess_with_flats). gridded_weights : 3d np.array same shape as gridded_radios, but it will accumulate the weights, during calls to extract_preprocess_with_flats diagnostic_radios : 3d np.array, a stack composed of two radios each radio must have the same size as a radio of the gridded_radios argument. During the calls to extract_preprocess_with_flats methods, the first radio will collect the transformed data for angle=0 ( and the neighbouring ones according to angular interpolation coefficients) and this only for the first occurring turn. The second radio will be initialised at the second turn, if any. These array are meant to be used to check the translation step over one turn. diagnostic_weights: 3d np.array a stack composed of two radios Same shape as diagnostic_radios. The weigths for diagnostic radios ( will be zero on pixel where no data is available, or where the weight is null) diagnostic_angles : 1D np.array Must have shape==(2,). The two entries will be filled with the angles at which the contributions to diagnostic_radios have been summed. dark = None or 2D np.array must have the shape of the detector ( generally larger that a radio of gridded_radios) If given, the dark will be subtracted from data and flats. flat_indexes: None or a list of integers the projection index corresponding to the flats flats : None or 3D np.array the stack of flats. Each flat must have the shape of the detector (generally larger that a radio of gridded_radios) The flats, if given, are subtracted of the dark, if given, and the result is used to normalise the data. weights : None or 2D np.array If not given each data pixel will be considered with unit weight. 
If given it must have the same shape as the detector. double_flat = None or 2D np.array If given, the double flat will be applied (division by double_flat) Must have the same shape as the detector. """ self.gridded_radios = gridded_radios self.gridded_weights = gridded_weights self.diagnostic_radios = diagnostic_radios self.diagnostic_weights = diagnostic_weights self.diagnostic_angles = diagnostic_angles self.dark = dark self.flat_indexes = flat_indexes self.flat_indexes_reverse_map = dict( [(global_index, local_index) for (local_index, global_index) in enumerate(flat_indexes)] ) self.flats = flats self.weights = weights self.double_flat = double_flat def extract_preprocess_with_flats( self, subchunk_slice, subchunk_file_indexes, chunk_info, subr_start_end, dtasrc_start_end, data_raw ): """ This functions is meant to be called providing, each time, a subset of the data which are needed to reconstruct a chunk (to reconstruct a slab). When all the necessary data have flown through the subsequent calls to this method, the accumulators are ready. Parameters ========== subchunk_slice: an object of the python class "slice" this slice slices the angular domain which corresponds to the useful projections which are useful for the chunk, and whose informations are contained in the companion argument "chunk_info" Such slicing correspond to the angular subset, for which we are providing data_raw subchunk_file_indexes: a sequence of integers. they correspond to the projection numbers from which the data in data_raw are coming. They are used to interpolate between falt fields chunk_info: an object returned by the get_chunk_info of the SpanStrategy class this object must have the following members, which relates to the wanted chunk angle_index_span: a pair of integers indicating the start and the end of useful angles in the array of all the scan angle self.projection_angles_deg span_v: a pair of two integers indicating the start and end of the span relatively to the lowest value of array self.total_view_heights integer_shift_v: an array, containing for each one of the useful projections of the span, the integer part of vertical shift to be used in cropping, fract_complement_to_integer_shift_v : the fractional remainder for cropping. z_pix_per_proj: an array, containing for each to be used projection of the span the vertical shift x_pix_per_proj: ....the horizontal shift angles_rad : an array, for each useful projection of the chunk the angle in radian subr_start_end: a pair of integers the start height, and the end height, of the slab for which we are collecting the data. The number are given with the same logic as for member span_v of the chunk_info. Without the phase margin, when the phase margin is zero, hey would correspond exactly to the start and end, vertically, of the reconstructed slices. dtasrc_start_end: a pair of integers This number are relative to the detector ( they are detector line indexes). They indicated, vertically, the detector portion the data_raw data correspond to data_raw: np.array 3D the data which correspond to a limited detector stripe and a limited angular subset """ # the object below is going to containing some auxiliary variable that are use to reframe the data. # This object is used to pass in a compact way such informations to different methods. 
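# Note on the regridding performed further below in this method: every raw
# projection is spread between the two nearest regridded angles with linear
# weights.  With i_float = angle_rad * n_gridded_angles / (2*pi), the
# projection contributes with weight (1 - epsilon) to index
# i0 = floor(i_float) % n_gridded_angles and with weight
# epsilon = i_float - floor(i_float) to index i1 = (i0 + 1) % n_gridded_angles.
# Illustrative example (hypothetical numbers): with n_gridded_angles = 3600,
# i.e. a 0.1 degree grid, a raw projection at 12.345 degrees gives
# i_float = 123.45, so its weighted data are summed with factor 0.55 into
# gridded radio 123 and with factor 0.45 into gridded radio 124; the same
# factors are accumulated into gridded_weights so that the caller can later
# normalise the accumulated data by the accumulated weights.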
# The informations are extracted from chunk info reframing_infos = self._ReframingInfos( chunk_info, subchunk_slice, subr_start_end, dtasrc_start_end, subchunk_file_indexes ) # give the proper dimensioning to an auxiliary stack which will contain the reframed data extracted # from the treatement of the sub-chunk radios_subset = np.zeros( [data_raw.shape[0], subr_start_end[1] - subr_start_end[0], data_raw.shape[2]], np.float32 ) # ... and in the same way we dimension the container for the associated reframed weights. radios_weights_subset = np.zeros( [data_raw.shape[0], subr_start_end[1] - subr_start_end[0], data_raw.shape[2]], np.float32 ) # extraction of the data self._extract_preprocess_with_flats(data_raw, reframing_infos, chunk_info, radios_subset) if self.weights is not None: # ... and, if required, extraction of the associated weights wdata_read = self.weights.data[reframing_infos.dtasrc_start_z : reframing_infos.dtasrc_end_z] self._extract_preprocess_with_flats( wdata_read, reframing_infos, chunk_info, radios_weights_subset, it_is_weight=True ) else: radios_weights_subset[:] = 1.0 # and the remaining part is a simple projection over the accumulators, for # the data and for the weights my_angles = chunk_info.angles_rad[subchunk_slice] n_gridded_angles = self.gridded_radios.shape[0] my_i_float = my_angles * (n_gridded_angles / (2 * math.pi)) tmp_i_rounded = np.floor(my_i_float).astype(np.int32) my_epsilon = my_i_float - tmp_i_rounded my_i0 = np.mod(tmp_i_rounded, n_gridded_angles) my_i1 = np.mod(my_i0 + 1, n_gridded_angles) for i0, epsilon, i1, data, weight, original_angle in zip( my_i0, my_epsilon, my_i1, radios_subset, radios_weights_subset, chunk_info.angles_rad[subchunk_slice] ): data_token = data * weight self.gridded_radios[i0] += data_token * (1 - epsilon) self.gridded_radios[i1] += data_token * epsilon self.gridded_weights[i0] += weight * (1 - epsilon) self.gridded_weights[i1] += weight * epsilon if i0 == 0 or i1 == 0: # There is a contribution to the first regridded radio ( the one indexed by 0) # We build two diagnostics for the contributions to this radio. # The first for the first pass (i_diag=0) # The second for the second pass if any (i_diag=1) # To discriminate we introduce # An angular margin beyond which we know that a possible contribution # is coming from another turn safe_angular_margin = 3.14 / 10 for i_diag in range(2): if original_angle < self.diagnostic_angles[i_diag] + safe_angular_margin: # we are searching for the first contributions ( the one at the lowest angle) # for the two diagnostics. With the constraint that the second is at an higher angle # than the first. So if we are here this means that we have found an occurrence with # lower angle and we discard what we could have previously found. 
self.diagnostic_radios[i_diag][:] = 0 self.diagnostic_weights[i_diag][:] = 0 self.diagnostic_angles[i_diag] = original_angle if abs(original_angle - self.diagnostic_angles[i_diag]) < safe_angular_margin: if i0 == 0: factor = 1 - epsilon else: factor = epsilon self.diagnostic_radios[i_diag] += data_token * factor self.diagnostic_weights[i_diag] += weight * factor break class _ReframingInfos: def __init__(self, chunk_info, subchunk_slice, subr_start_end, dtasrc_start_end, subchunk_file_indexes): self.subchunk_file_indexes = subchunk_file_indexes my_integer_shifts_v = chunk_info.integer_shift_v[subchunk_slice] self.fract_complement_shifts_v = chunk_info.fract_complement_to_integer_shift_v[subchunk_slice] self.x_shifts_list = chunk_info.x_pix_per_proj[subchunk_slice] subr_start_z, subr_end_z = subr_start_end self.subr_start_z_list = subr_start_z - my_integer_shifts_v self.subr_end_z_list = subr_end_z - my_integer_shifts_v + 1 self.dtasrc_start_z, self.dtasrc_end_z = dtasrc_start_end floating_start_z = self.subr_start_z_list.min() floating_end_z = self.subr_end_z_list.max() self.floating_subregion = None, None, floating_start_z, floating_end_z def _extract_preprocess_with_flats(self, data_raw, reframing_infos, chunk_info, output, it_is_weight=False): if not it_is_weight: if self.dark is not None: data_raw = data_raw - self.dark[reframing_infos.dtasrc_start_z : reframing_infos.dtasrc_end_z] if self.flats is not None: for i, idx in enumerate(reframing_infos.subchunk_file_indexes): flat = self._get_flat(idx, slice(reframing_infos.dtasrc_start_z, reframing_infos.dtasrc_end_z)) if self.dark is not None: flat = flat - self.dark[reframing_infos.dtasrc_start_z : reframing_infos.dtasrc_end_z] data_raw[i] = data_raw[i] / flat if self.double_flat is not None: data_raw = data_raw / self.double_flat[reframing_infos.dtasrc_start_z : reframing_infos.dtasrc_end_z] if it_is_weight: # for the weight, the detector weights, depends on the detector portion, # the one corresponding to dtasrc_start_end, and is the same across # the subchunk first index ( the projection index) take_data_from_this = [data_raw] * len(reframing_infos.subr_start_z_list) else: take_data_from_this = data_raw for data_read, list_subr_start_z, list_subr_end_z, fract_shift, x_shift, data_target in zip( take_data_from_this, reframing_infos.subr_start_z_list, reframing_infos.subr_end_z_list, reframing_infos.fract_complement_shifts_v, reframing_infos.x_shifts_list, output, ): _fill_in_chunk_by_shift_crop_data( data_target, data_read, fract_shift, list_subr_start_z, list_subr_end_z, reframing_infos.dtasrc_start_z, reframing_infos.dtasrc_end_z, x_shift=x_shift, extension_padding=(not it_is_weight), ) def _get_flat(self, idx, slice_y=slice(None, None), slice_x=slice(None, None), dtype=np.float32): prev_next = FlatFieldArrays.get_previous_next_indices(self.flat_indexes, idx) if len(prev_next) == 1: # current index corresponds to an acquired flat flat_data = self.flats[self.flat_indexes_reverse_map[prev_next[0]]][slice_y, slice_x] else: # interpolate prev_idx, next_idx = prev_next flat_data_prev = self.flats[self.flat_indexes_reverse_map[prev_idx]][slice_y, slice_x] flat_data_next = self.flats[self.flat_indexes_reverse_map[next_idx]][slice_y, slice_x] delta = next_idx - prev_idx w1 = 1 - (idx - prev_idx) / delta w2 = 1 - (next_idx - idx) / delta flat_data = w1 * flat_data_prev + w2 * flat_data_next if flat_data.dtype != dtype: flat_data = np.ascontiguousarray(flat_data, dtype=dtype) return flat_data def _fill_in_chunk_by_shift_crop_data( data_target, 
data_read, fract_shift, my_subr_start_z, my_subr_end_z, dtasrc_start_z, dtasrc_end_z, x_shift=0.0, extension_padding=True, ): data_read_precisely_shifted = nd.shift(data_read, (-fract_shift, x_shift), order=1, mode="nearest")[:-1] target_central_slicer, dtasrc_central_slicer = overlap_logic( my_subr_start_z, my_subr_end_z - 1, dtasrc_start_z, dtasrc_end_z - 1 ) if None not in [target_central_slicer, dtasrc_central_slicer]: data_target[target_central_slicer] = data_read_precisely_shifted[dtasrc_central_slicer] target_lower_slicer, target_upper_slicer = padding_logic( my_subr_start_z, my_subr_end_z - 1, dtasrc_start_z, dtasrc_end_z - 1 ) if extension_padding: if target_lower_slicer is not None: data_target[target_lower_slicer] = data_read_precisely_shifted[0] if target_upper_slicer is not None: data_target[target_upper_slicer] = data_read_precisely_shifted[-1] else: if target_lower_slicer is not None: data_target[target_lower_slicer] = 1.0e-6 if target_upper_slicer is not None: data_target[target_upper_slicer] = 1.0e-6 def overlap_logic(subr_start_z, subr_end_z, dtasrc_start_z, dtasrc_end_z): """determines the useful lines which can be transferred from the dtasrc_start_z:dtasrc_end_z range targeting the range subr_start_z: subr_end_z .................. """ t_h = subr_end_z - subr_start_z s_h = dtasrc_end_z - dtasrc_start_z my_start = max(0, dtasrc_start_z - subr_start_z) my_end = min(t_h, dtasrc_end_z - subr_start_z) if my_start >= my_end: return None, None target_central_slicer = slice(my_start, my_end) my_start = max(0, subr_start_z - dtasrc_start_z) my_end = min(s_h, subr_end_z - dtasrc_start_z) assert my_start < my_end, "Overlap_logic logic error" dtasrc_central_slicer = slice(my_start, my_end) return target_central_slicer, dtasrc_central_slicer def padding_logic(subr_start_z, subr_end_z, dtasrc_start_z, dtasrc_end_z): """.......... and the missing ranges which possibly could be obtained by extension padding""" t_h = subr_end_z - subr_start_z s_h = dtasrc_end_z - dtasrc_start_z if dtasrc_start_z <= subr_start_z: target_lower_padding = None else: target_lower_padding = slice(0, dtasrc_start_z - subr_start_z) if dtasrc_end_z >= subr_end_z: target_upper_padding = None else: target_upper_padding = slice(dtasrc_end_z - subr_end_z, None) return target_lower_padding, target_upper_padding def get_reconstruction_space(span_info, min_scanwise_z, end_scanwise_z, phase_margin_pix): """Utility function which, given the span_info object, creates the auxiliary collection arrays and initialises the my_z_min, my_z_end variable keeping into account the scan direction and the min_scanwise_z, end_scanwise_z input arguments Parameters ========== span_info: SpanStrategy min_scanwise_z: int non negative number, where zero indicates the first feaseable slice doable scanwise. Indicates the first (scanwise) requested slice to be reconstructed end_scanwise_z: int non negative number, where zero indicates the first feaseable slice doable scanwise. Indicates the end (scanwise) slice which delimity the to be reconstructed requested slab. """ detector_z_start, detector_z_end = (span_info.get_doable_span()).view_heights_minmax if span_info.z_pix_per_proj[-1] > span_info.z_pix_per_proj[0]: my_z_min = detector_z_start + min_scanwise_z my_z_end = detector_z_start + end_scanwise_z else: my_z_min = detector_z_end - (end_scanwise_z - 1) my_z_end = detector_z_end - (min_scanwise_z + 1) # while the raw dataset may have non uniform angular step # the regridded dataset will have a constant step. 
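# As an illustrative example of the estimate below (hypothetical numbers):
# a raw scan advancing by roughly 0.2 degrees per projection on average gives
# int(round(360.0 / 0.2)) = 1800 regridded radios, i.e. one full turn sampled
# with a constant angular step.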
# We evaluate here below the number of angles for the # regridded dataset, estimating a meaningul angular step representative # of the raw data my_angle_step = abs(np.diff(span_info.projection_angles_deg).mean()) n_gridded_angles = int(round(360.0 / my_angle_step)) radios_h = phase_margin_pix + (my_z_end - my_z_min) + phase_margin_pix # the accumulators gridded_radios = np.zeros([n_gridded_angles, radios_h, span_info.detector_shape_vh[1]], np.float32) gridded_cumulated_weights = np.zeros([n_gridded_angles, radios_h, span_info.detector_shape_vh[1]], np.float32) diagnostic_radios = np.zeros((2,) + gridded_radios.shape[1:], np.float32) diagnostic_weights = np.zeros((2,) + gridded_radios.shape[1:], np.float32) diagnostic_proj_angle = np.zeros([2], "f") gridded_angles_rad = np.arange(n_gridded_angles) * 2 * np.pi / n_gridded_angles gridded_angles_deg = np.rad2deg(gridded_angles_rad) res = type( "on_the_fly_class_for_reconstruction_room_in_gridded_accumulator.py", (object,), { "my_z_min": my_z_min, "my_z_end": my_z_end, "gridded_radios": gridded_radios, "gridded_cumulated_weights": gridded_cumulated_weights, "diagnostic_radios": diagnostic_radios, "diagnostic_weights": diagnostic_weights, "diagnostic_proj_angle": diagnostic_proj_angle, "gridded_angles_rad": gridded_angles_rad, "gridded_angles_deg": gridded_angles_deg, }, ) return res ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1679996432.0 nabu-2023.1.1/nabu/pipeline/helical/helical_chunked_regridded.py0000644000175000017500000017155300000000000024060 0ustar00pierrepierre# pylint: skip-file from os import path from time import time import numpy as np import math from silx.image.tomography import get_next_power from scipy import ndimage as nd import h5py import silx.io import copy from silx.io.url import DataUrl from ...resources.logger import LoggerOrPrint from ...resources.utils import is_hdf5_extension, extract_parameters from ...io.reader_helical import ChunkReaderHelical, get_hdf5_dataset_shape from ...preproc.flatfield_variable_region import FlatFieldDataVariableRegionUrls as FlatFieldDataHelicalUrls from ...preproc.distortion import DistortionCorrection from ...preproc.shift import VerticalShift from ...preproc.double_flatfield_variable_region import DoubleFlatFieldVariableRegion as DoubleFlatFieldHelical from ...preproc.phase import PaganinPhaseRetrieval from ...reconstruction.sinogram import SinoBuilder, SinoNormalization from ...misc.unsharp import UnsharpMask from ...misc.histogram import PartialHistogram, hist_as_2Darray from ..utils import use_options, pipeline_step from ...resources.utils import extract_parameters from ..detector_distortion_provider import DetectorDistortionProvider from .utils import ( WriterConfiguratorHelical as WriterConfigurator, ) # .utils is the same as ..utils but internally we retouch the key associated to "tiffwriter" of Writers to # point to our class which can write tiff with names indexed by the z height above the sample stage in millimiters from numpy.lib.stride_tricks import sliding_window_view from ...misc.binning import get_binning_function from .helical_utils import find_mirror_indexes from ...preproc.ccd import Log, CCDFilter from . import gridded_accumulator # For now we don't have a plain python/numpy backend for reconstruction Backprojector = None class HelicalChunkedRegriddedPipeline: """ Pipeline for "helical" full or half field tomography. Data is processed by chunks. 
A chunk consists in K+-1 contiguous lines of all the radios which are read at variable height following the translations """ extra_marge_granularity = 4 """ This offers extra reading space to be able to read the redundant part which might be sligtly larger and or require extra border for interpolation """ FlatFieldClass = FlatFieldDataHelicalUrls DoubleFlatFieldClass = DoubleFlatFieldHelical CCDFilterClass = CCDFilter MLogClass = Log PaganinPhaseRetrievalClass = PaganinPhaseRetrieval UnsharpMaskClass = UnsharpMask VerticalShiftClass = VerticalShift SinoBuilderClass = SinoBuilder FBPClass = Backprojector HBPClass = None HistogramClass = PartialHistogram regular_accumulator = None def __init__( self, process_config, sub_region, logger=None, extra_options=None, phase_margin=None, reading_granularity=10, span_info=None, ): """ Initialize a "HelicalChunked" pipeline. Parameters ---------- process_config: `nabu.resources.processcinfig.ProcessConfig` Process configuration. sub_region: tuple Sub-region to process in the volume for this worker, in the format `(start_x, end_x, start_z, end_z)`. logger: `nabu.app.logger.Logger`, optional Logger class extra_options: dict, optional Advanced extra options. phase_margin: tuple, optional Margin to use when performing phase retrieval, in the form ((up, down), (left, right)). See also the documentation of PaganinPhaseRetrieval. If not provided, no margin is applied. reading_granularity: int The data angular span which needs to be read for a reconstruction is read step by step, reading each time a maximum of reading_granularity radios, and doing the preprocessing till phase retrieval for each of these angular groups Notes ------ Using a `phase_margin` results in a lesser number of reconstructed slices. More specifically, if `phase_margin = (V, H)`, then there will be `chunk_size - 2*V` reconstructed slices (if the sub-region is in the middle of the volume) or `chunk_size - V` reconstructed slices (if the sub-region is on top or bottom of the volume). 
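For instance, with `chunk_size = 50` and `phase_margin = ((10, 10), (0, 0))`, 30 slices are reconstructed when the sub-region lies in the middle of the volume, and 40 when it lies at the top or bottom.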
""" self.span_info = span_info self.reading_granularity = reading_granularity self.logger = LoggerOrPrint(logger) self._set_params(process_config, sub_region, extra_options, phase_margin) self._init_pipeline() def _set_params(self, process_config, sub_region, extra_options, phase_margin): self.process_config = process_config self.dataset_info = self.process_config.dataset_info self.processing_steps = self.process_config.processing_steps.copy() self.processing_options = self.process_config.processing_options sub_region = self._check_subregion(sub_region) self.chunk_size = sub_region[-1] - sub_region[-2] self.radios_buffer = None self._set_detector_distortion_correction() self.set_subregion(sub_region) self._set_phase_margin(phase_margin) self._set_extra_options(extra_options) self._callbacks = {} self._steps_name2component = {} self._steps_component2name = {} self._data_dump = {} self._resume_from_step = None @staticmethod def _check_subregion(sub_region): if len(sub_region) < 4: assert len(sub_region) == 2, " at least start_z and end_z are required in subregion" sub_region = (None, None) + sub_region if None in sub_region[-2:]: raise ValueError("Cannot set z_min or z_max to None") return sub_region def _set_extra_options(self, extra_options): if extra_options is None: extra_options = {} advanced_options = {} advanced_options.update(extra_options) self.extra_options = advanced_options def _set_phase_margin(self, phase_margin): if phase_margin is None: phase_margin = ((0, 0), (0, 0)) self._phase_margin_up = phase_margin[0][0] self._phase_margin_down = phase_margin[0][1] self._phase_margin_left = phase_margin[1][0] self._phase_margin_right = phase_margin[1][1] def set_subregion(self, sub_region): """ Set a sub-region to process. Parameters ---------- sub_region: tuple Sub-region to process in the volume, in the format `(start_x, end_x, start_z, end_z)` or `(start_z, end_z)`. """ sub_region = self._check_subregion(sub_region) dz = sub_region[-1] - sub_region[-2] if dz != self.chunk_size: raise ValueError( "Class was initialized for chunk_size = %d but provided sub_region has chunk_size = %d" % (self.chunk_size, dz) ) self.sub_region = sub_region self.z_min = sub_region[-2] self.z_max = sub_region[-1] def _compute_phase_kernel_margin(self): """ Get the "margin" to pass to classes like PaganinPhaseRetrieval. In order to have a good accuracy for filter-based phase retrieval methods, we need to load extra data around the edges of each image. Otherwise, a default padding type is applied. 
""" if not (self.use_radio_processing_margin): self._phase_margin = None return up_margin = self._phase_margin_up down_margin = self._phase_margin_down # Horizontal margin is not implemented left_margin, right_margin = (0, 0) self._phase_margin = ((up_margin, down_margin), (left_margin, right_margin)) @property def use_radio_processing_margin(self): return ("phase" in self.processing_steps) or ("unsharp_mask" in self.processing_steps) def _get_phase_margin(self): if not (self.use_radio_processing_margin): return ((0, 0), (0, 0)) return self._phase_margin @property def phase_margin(self): """ Return the margin for phase retrieval in the form ((up, down), (left, right)) """ return self._get_phase_margin() def _get_process_name(self, kind="reconstruction"): # In the future, might be something like "reconstruction-" if kind == "reconstruction": return "reconstruction" elif kind == "histogram": return "histogram" return kind def _configure_dump(self, step_name): if step_name not in self.processing_steps: if step_name == "sinogram" and self.process_config._dump_sinogram: fname_full = self.process_config._dump_sinogram_file else: return else: if not self.processing_options[step_name].get("save", False): return fname_full = self.processing_options[step_name]["save_steps_file"] fname, ext = path.splitext(fname_full) dirname, file_prefix = path.split(fname) output_dir = path.join(dirname, file_prefix) file_prefix += str("_%06d" % self._get_image_start_index()) self.logger.info("omitting config in data_dump because of too slow nexus writer ") self._data_dump[step_name] = WriterConfigurator( output_dir, file_prefix, file_format="hdf5", overwrite=True, logger=self.logger, nx_info={ "process_name": step_name, "processing_index": 0, # TODO # "config": {"processing_options": self.processing_options, "nabu_config": self.process_config.nabu_config}, "config": None, "entry": getattr(self.dataset_info.dataset_scanner, "entry", None), }, ) def _configure_data_dumps(self): self.process_config._configure_save_steps() for step_name in self.processing_steps: self._configure_dump(step_name) # sinogram is a special keyword: not in processing_steps, but guaranteed to be before sinogram generation if self.process_config._dump_sinogram: self._configure_dump("sinogram") # # Callbacks # def register_callback(self, step_name, callback): """ Register a callback for a pipeline processing step. Parameters ---------- step_name: str processing step name callback: callable A function. It will be executed once the processing step `step_name` is finished. The function takes only one argument: the class instance. """ if step_name not in self.processing_steps: raise ValueError("'%s' is not in processing steps %s" % (step_name, self.processing_steps)) if step_name in self._callbacks: self._callbacks[step_name].append(callback) else: self._callbacks[step_name] = [callback] # # Overwritten in inheriting classes # def _get_shape(self, step_name): """ Get the shape to provide to the class corresponding to step_name. 
""" if step_name == "flatfield": shape = self.radios_subset.shape elif step_name == "double_flatfield": shape = self.radios_subset.shape elif step_name == "phase": shape = self.radios_subset.shape[1:] elif step_name == "ccd_correction": shape = self.gridded_radios.shape[1:] elif step_name == "unsharp_mask": shape = self.radios_subset.shape[1:] elif step_name == "take_log": shape = self.radios.shape elif step_name == "radios_movements": shape = self.radios.shape elif step_name == "sino_normalization": shape = self.radios.shape elif step_name == "sino_normalization_slim": shape = self.radios.shape[:1] + (1,) + self.radios.shape[2:] elif step_name == "one_sino_slim": shape = self.radios.shape[:1] + self.radios.shape[2:] elif step_name == "build_sino": shape = self.radios.shape[:1] + (1,) + self.radios.shape[2:] elif step_name == "reconstruction": shape = self.sino_builder.output_shape[1:] else: raise ValueError("Unknown processing step %s" % step_name) self.logger.debug("Data shape for %s is %s" % (step_name, str(shape))) return shape def _allocate_array(self, shape, dtype, name=None): """this function can be redefined in the derived class which is dedicated to gpu and will return gpu garrays """ return _cpu_allocate_array(shape, dtype, name=name) def _cpu_allocate_array(self, shape, dtype, name=None): """For objects used in the pre-gpu part. They will be always on CPU even in the derived class""" result = np.zeros(shape, dtype=dtype) return result def _allocate_sinobuilder_output(self): return self._cpu_allocate_array(self.sino_builder.output_shape, "f", name="sinos") def _allocate_recs(self, ny, nx): self.n_slices = self.gridded_radios.shape[1] if self.use_radio_processing_margin: self.n_slices -= sum(self.phase_margin[0]) self.recs = self._allocate_array((1, ny, nx), "f", name="recs") self.recs_stack = self._cpu_allocate_array((self.n_slices, ny, nx), "f", name="recs_stack") def _reset_memory(self): pass def _get_read_dump_subregion(self): read_opts = self.processing_options["read_chunk"] if read_opts.get("process_file", None) is None: return None dump_start_z, dump_end_z = read_opts["dump_start_z"], read_opts["dump_end_z"] relative_start_z = self.z_min - dump_start_z relative_end_z = relative_start_z + self.chunk_size # (n_angles, n_z, n_x) subregion = (None, None, relative_start_z, relative_end_z, None, None) return subregion def _check_resume_from_step(self): if self._resume_from_step is None: return read_opts = self.processing_options["read_chunk"] expected_radios_shape = get_hdf5_dataset_shape( read_opts["process_file"], read_opts["process_h5_path"], sub_region=self._get_read_dump_subregion(), ) # TODO check def _init_reader_finalize(self): """ Method called after _init_reader """ self._check_resume_from_step() self._compute_phase_kernel_margin() self._allocate_reduced_gridded_and_subset_radios() def _allocate_reduced_gridded_and_subset_radios(self): shp_h = self.chunk_reader.data.shape[-1] sliding_window_size = self.chunk_size if sliding_window_size % 2 == 0: sliding_window_size += 1 sliding_window_radius = (sliding_window_size - 1) // 2 if sliding_window_radius == 0: n_projs_max = (self.span_info.sunshine_ends - self.span_info.sunshine_starts).max() else: padded_starts = self.span_info.sunshine_starts padded_ends = self.span_info.sunshine_ends padded_starts = np.concatenate( [[padded_starts[0]] * sliding_window_radius, padded_starts, [padded_starts[-1]] * sliding_window_radius] ) starts = sliding_window_view(padded_starts, sliding_window_size).min(axis=-1) padded_ends = 
np.concatenate( [[padded_ends[0]] * sliding_window_radius, padded_ends, [padded_ends[-1]] * sliding_window_radius] ) ends = sliding_window_view(padded_ends, sliding_window_size).max(axis=-1) n_projs_max = (ends - starts).max() ((up_margin, down_margin), (left_margin, right_margin)) = self.phase_margin (start_x, end_x, start_z, end_z) = self.sub_region ## and now the gridded ones my_angle_step = abs(np.diff(self.span_info.projection_angles_deg).mean()) self.n_gridded_angles = int(round(360.0 / my_angle_step)) self.my_angles_rad = np.arange(self.n_gridded_angles) * 2 * np.pi / self.n_gridded_angles my_angles_deg = np.rad2deg(self.my_angles_rad) self.mirror_angle_relative_indexes = find_mirror_indexes(my_angles_deg) if "read_chunk" not in self.processing_steps: raise ValueError("Cannot proceed without reading data") r_shp_v, r_shp_h = self.whole_radio_shape (subr_start_x, subr_end_x, subr_start_z, subr_end_z) = self.sub_region subradio_shape = subr_end_z - subr_start_z, r_shp_h ### these radios are for diagnostic of the translations ( they will be optionally written, for being further used ## by correlation techniques ). Two radios for the first two pass over the first gridded angles self.diagnostic_radios = np.zeros((2,) + subradio_shape, np.float32) self.diagnostic_weights = np.zeros((2,) + subradio_shape, np.float32) self.diagnostic_proj_angle = np.zeros([2], "f") self.diagnostic = { "radios": self.diagnostic_radios, "weights": self.diagnostic_weights, "angles": self.diagnostic_proj_angle, } ## ------- self.gridded_radios = np.zeros((self.n_gridded_angles,) + subradio_shape, np.float32) self.gridded_cumulated_weights = np.zeros((self.n_gridded_angles,) + subradio_shape, np.float32) self.radios_subset = np.zeros((self.reading_granularity,) + subradio_shape, np.float32) self.radios_weights_subset = np.zeros((self.reading_granularity,) + subradio_shape, np.float32) self.radios = np.zeros( (self.n_gridded_angles,) + ((end_z - down_margin) - (start_z + up_margin), shp_h), np.float32 ) self.radios_weights = np.zeros_like(self.radios) self.radios_slim = self._allocate_array(self._get_shape("one_sino_slim"), "f", name="radios_slim") def _process_finalize(self): """ Method called once the pipeline has been executed """ pass def _get_slice_start_index(self): return self.z_min + self._phase_margin_up _get_image_start_index = _get_slice_start_index # # Pipeline initialization # def _init_pipeline(self): self._get_size_of_a_raw_radio() self._init_reader() self._init_flatfield() self._init_double_flatfield() self._init_weights_field() self._init_ccd_corrections() self._init_phase() self._init_unsharp() self._init_mlog() self._init_sino_normalization() self._init_sino_builder() self._prepare_reconstruction() self._init_reconstruction() self._init_histogram() self._init_writer() self._configure_data_dumps() self._configure_regular_accumulator() def _set_detector_distortion_correction(self): if self.process_config.nabu_config["preproc"]["detector_distortion_correction"] is None: self.detector_corrector = None else: self.detector_corrector = DetectorDistortionProvider( detector_full_shape_vh=self.process_config.dataset_info.radio_dims[::-1], correction_type=self.process_config.nabu_config["preproc"]["detector_distortion_correction"], options=self.process_config.nabu_config["preproc"]["detector_distortion_correction_options"], ) def _configure_regular_accumulator(self): accumulator_cls = gridded_accumulator.GriddedAccumulator self.regular_accumulator = accumulator_cls( gridded_radios=self.gridded_radios, 
gridded_weights=self.gridded_cumulated_weights, diagnostic_radios=self.diagnostic_radios, diagnostic_weights=self.diagnostic_weights, diagnostic_angles=self.diagnostic_proj_angle, dark=self.flatfield.get_dark(), flat_indexes=self.flatfield._sorted_flat_indices, flats=self.flatfield.flats_stack, weights=self.weights_field.data, double_flat=self.double_flatfield.data, ) return def _get_size_of_a_raw_radio(self): """Once for all we find the shape of a radio. This information will be used in other parts of the code when allocating bunch of data holders """ if "read_chunk" not in self.processing_steps: raise ValueError("Cannot proceed without reading data") options = self.processing_options["read_chunk"] here_a_file = next(iter(options["files"].values())) here_a_radio = silx.io.get_data(here_a_file) binning_x, binning_z = self._get_binning() if (binning_z, binning_x) != (1, 1): binning_function = get_binning_function((binning_z, binning_x)) here_a_radio = binning_function(here_a_radio) self.whole_radio_shape = here_a_radio.shape return self.whole_radio_shape @use_options("read_chunk", "chunk_reader") def _init_reader(self): if "read_chunk" not in self.processing_steps: raise ValueError("Cannot proceed without reading data") options = self.processing_options["read_chunk"] assert options.get("process_file", None) is None, "Resume not yet implemented in helical pipeline" # dummy initialisation, it will be _set_subregion'ed and set_data_buffer'ed in the loops self.chunk_reader = ChunkReaderHelical( options["files"], sub_region=None, # setting of subregion will be already done by calls to set_subregion detector_corrector=self.detector_corrector, convert_float=True, binning=options["binning"], dataset_subsampling=options["dataset_subsampling"], data_buffer=None, pre_allocate=True, ) self._init_reader_finalize() @use_options("flatfield", "flatfield") def _init_flatfield(self, shape=None): if shape is None: shape = self._get_shape("flatfield") options = self.processing_options["flatfield"] distortion_correction = None if options["do_flat_distortion"]: self.logger.info("Flats distortion correction will be applied") estimation_kwargs = {} estimation_kwargs.update(options["flat_distortion_params"]) estimation_kwargs["logger"] = self.logger distortion_correction = DistortionCorrection( estimation_method="fft-correlation", estimation_kwargs=estimation_kwargs, correction_method="interpn" ) self.flatfield = self.FlatFieldClass( shape, flats=self.dataset_info.flats, darks=self.dataset_info.darks, radios_indices=options["projs_indices"], interpolation="linear", distortion_correction=distortion_correction, radios_srcurrent=options["radios_srcurrent"], flats_srcurrent=options["flats_srcurrent"], detector_corrector=self.detector_corrector, ## every flat will be read at a different heigth ### sub_region=self.sub_region, binning=options["binning"], convert_float=True, ) def _get_binning(self): options = self.processing_options["read_chunk"] binning = options["binning"] if binning is None: return 1, 1 else: return binning def _init_double_flatfield(self): options = self.processing_options["double_flatfield"] binning_x, binning_z = self._get_binning() result_url = None self.double_flatfield = None if options["processes_file"] not in (None, ""): file_path = options["processes_file"] data_path = (self.dataset_info.hdf5_entry or "entry") + "/double_flatfield/results/data" if path.exists(file_path) and (data_path in h5py.File(file_path, "r")): result_url = DataUrl(file_path=file_path, data_path=data_path) 
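# A precomputed double flatfield found in the processes file is loaded from this URL;
# the gridded accumulator later applies it by dividing the radios by it.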
self.logger.info("Loading double flatfield from %s" % result_url.file_path()) self.double_flatfield = self.DoubleFlatFieldClass( self._get_shape("double_flatfield"), result_url=result_url, binning_x=binning_x, binning_z=binning_z, detector_corrector=self.detector_corrector, ) def _init_weights_field(self): options = self.processing_options["double_flatfield"] result_url = None binning_x, binning_z = self.chunk_reader.get_binning() self.weights_field = None if options["processes_file"] not in (None, ""): file_path = options["processes_file"] data_path = (self.dataset_info.hdf5_entry or "entry") + "/weights_field/results/data" if path.exists(file_path) and (data_path in h5py.File(file_path, "r")): result_url = DataUrl(file_path=file_path, data_path=data_path) self.logger.info("Loading weights_field from %s" % result_url.file_path()) self.weights_field = self.DoubleFlatFieldClass( self._get_shape("double_flatfield"), result_url=result_url, binning_x=binning_x, binning_z=binning_z ) def _init_ccd_corrections(self): if "ccd_correction" not in self.processing_steps: return options = self.processing_options["ccd_correction"] median_clip_thresh = options["median_clip_thresh"] self.ccd_correction = self.CCDFilterClass( self._get_shape("ccd_correction"), median_clip_thresh=median_clip_thresh ) @use_options("phase", "phase_retrieval") def _init_phase(self): options = self.processing_options["phase"] # If unsharp mask follows phase retrieval, then it should be done # before cropping to the "inner part". # Otherwise, crop the data just after phase retrieval. if "unsharp_mask" in self.processing_steps: margin = None else: margin = self._phase_margin self.phase_retrieval = self.PaganinPhaseRetrievalClass( self._get_shape("phase"), distance=options["distance_m"], energy=options["energy_kev"], delta_beta=options["delta_beta"], pixel_size=options["pixel_size_m"], padding=options["padding_type"], margin=margin, fftw_num_threads=True, # TODO tune in advanced params of nabu config file ) if self.phase_retrieval.use_fftw: self.logger.debug( "PaganinPhaseRetrieval using FFTW with %d threads" % self.phase_retrieval.fftw.num_threads ) ##@use_options("unsharp_mask", "unsharp_mask") def _init_unsharp(self): if "unsharp_mask" not in self.processing_steps: self.unsharp_mask = None self.unsharp_sigma = 0.0 self.unsharp_coeff = 0.0 self.unsharp_method = "log" else: options = self.processing_options["unsharp_mask"] self.unsharp_sigma = options["unsharp_sigma"] self.unsharp_coeff = options["unsharp_coeff"] self.unsharp_method = options["unsharp_method"] self.unsharp_mask = self.UnsharpMaskClass( self._get_shape("unsharp_mask"), options["unsharp_sigma"], options["unsharp_coeff"], mode="reflect", method=options["unsharp_method"], ) def _init_mlog(self): options = self.processing_options["take_log"] self.mlog = self.MLogClass( self._get_shape("take_log"), clip_min=options["log_min_clip"], clip_max=options["log_max_clip"] ) @use_options("sino_normalization", "sino_normalization") def _init_sino_normalization(self): options = self.processing_options["sino_normalization"] self.sino_normalization = self.SinoNormalizationClass( kind=options["method"], radios_shape=self._get_shape("sino_normalization_slim"), ) def _init_sino_builder(self): options = self.processing_options["reconstruction"] ## build_sino class disappeared disappeared self.sino_builder = self.SinoBuilderClass( radios_shape=self._get_shape("build_sino"), rot_center=options["rotation_axis_position"], halftomo=False, ) self._sinobuilder_copy = False 
self._sinobuilder_output = None self.sinos = None # this should be renamed, as it could be confused with _init_reconstruction. What about _get_reconstruction_array ? @use_options("reconstruction", "reconstruction") def _prepare_reconstruction(self): options = self.processing_options["reconstruction"] x_s, x_e = options["start_x"], options["end_x"] y_s, y_e = options["start_y"], options["end_y"] self._rec_roi = (x_s, x_e + 1, y_s, y_e + 1) self._allocate_recs(y_e - y_s + 1, x_e - x_s + 1) @use_options("reconstruction", "reconstruction") def _init_reconstruction(self): options = self.processing_options["reconstruction"] if self.sino_builder is None: raise ValueError("Reconstruction cannot be done without build_sino") if self.FBPClass is None: raise ValueError("No usable FBP module was found") rot_center = options["rotation_axis_position"] start_y, end_y, start_x, end_x = self._rec_roi if self.HBPClass is not None: fan_source_distance_meters = self.process_config.nabu_config["reconstruction"]["fan_source_distance_meters"] self.reconstruction_hbp = self.HBPClass( self._get_shape("one_sino_slim"), slice_shape=(end_y - start_y, end_x - start_x), angles=self.my_angles_rad, rot_center=rot_center, extra_options={"axis_correction": np.zeros(self.radios.shape[0], "f")}, axis_source_meters=fan_source_distance_meters, voxel_size_microns=options["pixel_size_cm"] * 1.0e4, scale_factor=1.0 / options["pixel_size_cm"], ) else: self.reconstruction_hbp = None self.reconstruction = self.FBPClass( self._get_shape("reconstruction"), angles=np.zeros(self.radios.shape[0], "f"), rot_center=rot_center, filter_name=options["fbp_filter_type"], slice_roi=self._rec_roi, # slice_shape = ( end_y-start_y, end_x- start_x ), scale_factor=1.0 / options["pixel_size_cm"], padding_mode=options["padding_type"], extra_options={ "scale_factor": 1.0 / options["pixel_size_cm"], "axis_correction": np.zeros(self.radios.shape[0], "f"), "clip_outer_circle": options["clip_outer_circle"], }, # "padding_mode": options["padding_type"], }, ) my_options = self.process_config.nabu_config["reconstruction"] if my_options["axis_to_the_center"]: x_s, x_ep1, y_s, y_ep1 = self._rec_roi off_x = -int(round((x_s + x_ep1 - 1) / 2.0 - rot_center)) off_y = -int(round((y_s + y_ep1 - 1) / 2.0 - rot_center)) self.reconstruction.offsets = {"x": off_x, "y": off_y} if options["fbp_filter_type"] is None: self.reconstruction.fbp = self.reconstruction.backproj @use_options("histogram", "histogram") def _init_histogram(self): options = self.processing_options["histogram"] self.histogram = self.HistogramClass(method="fixed_bins_number", num_bins=options["histogram_bins"]) @use_options("save", "writer") def _init_writer(self, chunk_info=None): options = self.processing_options["save"] file_prefix = options["file_prefix"] output_dir = path.join(options["location"], file_prefix) nx_info = None self._hdf5_output = is_hdf5_extension(options["file_format"]) if chunk_info is not None: d_v, d_h = self.process_config.dataset_info.radio_dims[::-1] h_rels = self._get_slice_start_index() + np.arange(chunk_info.span_v[1] - chunk_info.span_v[0]) fact_mm = self.process_config.dataset_info.pixel_size * 1.0e-3 heights_mm = ( fact_mm * (-self.span_info.z_pix_per_proj[0] + (d_v - 1) / 2 - h_rels) - self.span_info.z_offset_mm ) else: heights_mm = None if self._hdf5_output: fname_start_index = None file_prefix += str("_%06d" % self._get_slice_start_index()) entry = getattr(self.dataset_info.dataset_scanner, "entry", None) nx_info = { "process_name": self._get_process_name(), 
"processing_index": 0, "config": { "processing_options": self.processing_options, "nabu_config": self.process_config.nabu_config, }, "entry": entry, } self._histogram_processing_index = nx_info["processing_index"] + 1 elif options["file_format"] in ["tif", "tiff"]: fname_start_index = self._get_slice_start_index() self._histogram_processing_index = 1 self._writer_configurator = WriterConfigurator( output_dir, file_prefix, file_format=options["file_format"], overwrite=options["overwrite"], start_index=fname_start_index, heights_above_stage_mm=heights_mm, logger=self.logger, nx_info=nx_info, write_histogram=("histogram" in self.processing_steps), histogram_entry=getattr(self.dataset_info.dataset_scanner, "entry", "entry"), ) self.writer = self._writer_configurator.writer self._writer_exec_args = self._writer_configurator._writer_exec_args self._writer_exec_kwargs = self._writer_configurator._writer_exec_kwargs self.histogram_writer = self._writer_configurator.get_histogram_writer() def _apply_expand_fact(self, t): if t is not None: t = t * self.chunk_reader.dataset_subsampling return t def _expand_slice(self, subchunk_slice): start, stop, step = subchunk_slice.start, subchunk_slice.stop, subchunk_slice.step if step is None: step = 1 start, stop, step = list(map(self._apply_expand_fact, [start, stop, step])) result_slice = slice(start, stop, step) return result_slice def _extract_preprocess_with_flats(self, sub_total_prange_slice, subchunk_slice, chunk_info, output): """Read, and apply dark+ff to, a small angular domain corresponding to the slice argument sub_total_prange_slice without refilling the holes. """ if self.chunk_reader.dataset_subsampling > 1: subsampling_file_slice = self._expand_slice(sub_total_prange_slice) else: subsampling_file_slice = sub_total_prange_slice my_integer_shifts_v = chunk_info.integer_shift_v[subchunk_slice] fract_complement_shifts_v = chunk_info.fract_complement_to_integer_shift_v[subchunk_slice] x_shifts_list = chunk_info.x_pix_per_proj[subchunk_slice] (subr_start_x, subr_end_x, subr_start_z, subr_end_z) = self.sub_region subr_start_z_list = subr_start_z - my_integer_shifts_v subr_end_z_list = subr_end_z - my_integer_shifts_v + 1 floating_start_z = subr_start_z_list.min() floating_end_z = subr_end_z_list.max() floating_subregion = None, None, floating_start_z, floating_end_z self._reset_reader_subregion(floating_subregion) self.chunk_reader.load_data(overwrite=True, sub_total_prange_slice=sub_total_prange_slice) my_indexes = self.chunk_reader._sorted_files_indices[subsampling_file_slice] data_raw = self.chunk_reader.data[: len(my_indexes)] if (self.flatfield is not None) or (self.double_flatfield is not None): sub_regions_per_radio = [self.trimmed_floating_subregion] * len(my_indexes) if self.flatfield is not None: self.flatfield.normalize_radios(data_raw, my_indexes, sub_regions_per_radio) if self.double_flatfield is not None: self.double_flatfield.apply_double_flatfield_for_sub_regions(data_raw, sub_regions_per_radio) source_start_x, source_end_x, source_start_z, sources_end_z = self.trimmed_floating_subregion if self.weights_field is not None: data_weight = self.weights_field.data[source_start_z:sources_end_z] else: data_weight = None for data_read, list_subr_start_z, list_subr_end_z, fract_shit, x_shift, data_target in zip( data_raw, subr_start_z_list, subr_end_z_list, fract_complement_shifts_v, x_shifts_list, output ): _fill_in_chunk_by_shift_crop_data( data_target, data_read, fract_shit, list_subr_start_z, list_subr_end_z, source_start_z, sources_end_z, 
x_shift=x_shift, ) def _read_data_and_apply_flats(self, sub_total_prange_slice, subchunk_slice, chunk_info): my_integer_shifts_v = chunk_info.integer_shift_v[subchunk_slice] fract_complement_shifts_v = chunk_info.fract_complement_to_integer_shift_v[subchunk_slice] x_shifts_list = chunk_info.x_pix_per_proj[subchunk_slice] (subr_start_x, subr_end_x, subr_start_z, subr_end_z) = self.sub_region subr_start_z_list = subr_start_z - my_integer_shifts_v subr_end_z_list = subr_end_z - my_integer_shifts_v + 1 self._reset_reader_subregion((None, None, subr_start_z_list.min(), subr_end_z_list.max())) dtasrc_start_x, dtasrc_end_x, dtasrc_start_z, dtasrc_end_z = self.trimmed_floating_subregion self.chunk_reader.load_data(overwrite=True, sub_total_prange_slice=sub_total_prange_slice) if self.chunk_reader.dataset_subsampling > 1: subsampling_file_slice = self._expand_slice(sub_total_prange_slice) else: subsampling_file_slice = sub_total_prange_slice my_subsampled_indexes = self.chunk_reader._sorted_files_indices[subsampling_file_slice] data_raw = self.chunk_reader.data[: len(my_subsampled_indexes)] self.regular_accumulator.extract_preprocess_with_flats( subchunk_slice, my_subsampled_indexes, chunk_info, np.array((subr_start_z, subr_end_z), "i"), np.array((dtasrc_start_z, dtasrc_end_z), "i"), data_raw, ) def binning_expanded(self, region): binning_x, binning_z = self.chunk_reader.get_binning() binnings = [binning_x] * 2 + [binning_z] * 2 res = [None if tok is None else tok * fact for tok, fact in zip(region, binnings)] return res def _reset_reader_subregion(self, floating_subregion): if self._resume_from_step is None: binning_x, binning_z = self.chunk_reader.get_binning() start_x, end_x, start_z, end_z = floating_subregion trimmed_start_z = max(0, start_z) trimmed_end_z = min(self.whole_radio_shape[0], end_z) my_buffer_height = trimmed_end_z - trimmed_start_z if self.radios_buffer is None or my_buffer_height > self.safe_buffer_height: self.safe_buffer_height = end_z - start_z assert ( self.safe_buffer_height >= my_buffer_height ), "This should always be true, if not contact the developer" self.radios_buffer = None self.radios_buffer = np.zeros( (self.reading_granularity + self.extra_marge_granularity,) + (self.safe_buffer_height, self.whole_radio_shape[1]), np.float32, ) self.trimmed_floating_subregion = start_x, end_x, trimmed_start_z, trimmed_end_z self.chunk_reader._set_subregion(self.binning_expanded(self.trimmed_floating_subregion)) self.chunk_reader._init_reader() self.chunk_reader._loaded = False self.chunk_reader.set_data_buffer(self.radios_buffer[:, :my_buffer_height, :], pre_allocate=False) else: message = "Resume not yet implemented in helical pipeline" raise RuntimeError(message) def _ccd_corrections(self, radios=None): if radios is None: radios = self.gridded_radios if hasattr(self.ccd_correction, "median_clip_correction_multiple_images"): self.ccd_correction.median_clip_correction_multiple_images(radios) else: _tmp_radio = self._cpu_allocate_array(radios.shape[1:], "f", name="tmp_ccdcorr_radio") for i in range(radios.shape[0]): self.ccd_correction.median_clip_correction(radios[i], output=_tmp_radio) radios[i][:] = _tmp_radio[:] def _retrieve_phase(self): if "unsharp_mask" in self.processing_steps: for i in range(self.gridded_radios.shape[0]): self.gridded_radios[i] = self.phase_retrieval.apply_filter(self.gridded_radios[i]) else: for i in range(self.gridded_radios.shape[0]): self.radios[i] = self.phase_retrieval.apply_filter(self.gridded_radios[i]) def _nophase_put_to_radios(self, target, 
source): ((up_margin, down_margin), (left_margin, right_margin)) = self.phase_margin zslice = slice(up_margin or None, -down_margin or None) xslice = slice(left_margin or None, -right_margin or None) for i in range(target.shape[0]): target[i] = source[i][zslice, xslice] def _apply_unsharp(): ((up_margin, down_margin), (left_margin, right_margin)) = self._phase_margin zslice = slice(up_margin or None, -down_margin or None) xslice = slice(left_margin or None, -right_margin or None) for i in range(self.radios.shape[0]): self.radios[i] = self.unsharp_mask.unsharp(self.gridded_radios[i])[zslice, xslice] def _take_log(self): self.mlog.take_logarithm(self.radios) @pipeline_step("sino_normalization", "Normalizing sinograms") def _normalize_sinos(self, radios=None): if radios is None: radios = self.radios sinos = radios.transpose((1, 0, 2)) self.sino_normalization.normalize(sinos) def _dump_sinogram(self, radios=None): if radios is None: radios = self.radios self._dump_data_to_file("sinogram", data=radios) @pipeline_step("sino_builder", "Building sinograms") def _build_sino(self): self.sinos = self.radios_slim def _filter(self): rot_center = self.processing_options["reconstruction"]["rotation_axis_position"] self.reconstruction.sino_filter.filter_sino( self.radios_slim, mirror_indexes=self.mirror_angle_relative_indexes, rot_center=rot_center, output=self.radios_slim, ) def _build_sino(self): self.sinos = self.radios_slim def _reconstruct(self, sinos=None, chunk_info=None, i_slice=0): if sinos is None: sinos = self.sinos use_hbp = self.process_config.nabu_config["reconstruction"]["use_hbp"] if not use_hbp: if i_slice == 0: self.reconstruction.set_custom_angles_and_axis_corrections( self.my_angles_rad, np.zeros_like(self.my_angles_rad) ) self.reconstruction.backprojection(sinos, output=self.recs[0]) self.recs[0].get(self.recs_stack[i_slice]) else: if self.reconstruction_hbp is None: raise ValueError("You requested the hierchical backprojector but the module could not be imported") self.reconstruction_hbp.backprojection(sinos, output=self.recs_stack[i_slice]) @pipeline_step("histogram", "Computing histogram") def _compute_histogram(self, data=None, islice=None, num_slices=None, histo_stack=[]): if data is None: data = self.recs if i_slice == 0: pass my_histo = self.histogram.compute_histogram(data.ravel()) histo_stack.append(my_histo) if i_slice == num_slices - 1: self.recs_histogram = self.histogram.merge_histograms(histo_stack) def _write_data(self, data=None, counter=[0]): if data is None: data = self.recs_stack my_kw_args = copy.copy(self._writer_exec_kwargs) if "config" in my_kw_args: self.logger.info("omitting config in writer because of too slow nexus writer ") my_kw_args["config"] = {"test": counter[0]} counter[0] += 1 self.writer.write(data, *self._writer_exec_args, **my_kw_args) self.logger.info("Wrote %s" % self.writer.get_filename()) self._write_histogram() def _write_histogram(self): if "histogram" not in self.processing_steps: return self.logger.info("Saving histogram") self.histogram_writer.write( hist_as_2Darray(self.recs_histogram), self._get_process_name(kind="histogram"), processing_index=self._histogram_processing_index, config={ "file": path.basename(self.writer.get_filename()), "bins": self.processing_options["histogram"]["histogram_bins"], }, ) def _dump_data_to_file(self, step_name, data=None): if step_name not in self._data_dump: return self.logger.info(f"DUMP step_name={step_name}") if data is None: data = self.radios writer = self._data_dump[step_name] 
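# The writer for this step was configured in _configure_dump() (hdf5 output);
# when no data is passed, the current radios stack is dumped.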
self.logger.info("Dumping data to %s" % writer.fname) writer.write_data(data) def balance_weights(self): options = self.processing_options["reconstruction"] rot_center = options["rotation_axis_position"] self.radios_weights[:] = rebalance(self.radios_weights, self.my_angles_rad, rot_center) # When standard scans are incomplete, due to motors errors, some angular range # is missing short of 360 degrees. # The weight accounting correctly deal with it, but still the padding # procedure with theta+180 data may fall on empty data # and this may cause problems, coming from the ramp filter, # in half tomo. # To correct this we complete with what we have at hand from the nearest # non empty data # to_be_filled = [] for i in range(len(self.radios_weights) - 1, 0, -1): if self.radios_weights[i].sum(): break to_be_filled.append(i) for i in to_be_filled: self.radios[i] = self.radios[to_be_filled[-1] - 1] def _post_primary_data_reduction(self, i_slice): """This will be used in the derived class to transfer data to gpu""" self.radios_slim[:] = self.radios[:, i_slice, :] def reset_translation_diagnostics_accumulators(self): self.diagnostic_radios[:] = 0 self.diagnostic_weights[:] = 0 self.diagnostic_proj_angle[1] = (2**30) * 3.14 self.diagnostic_proj_angle[0] = (2**30 - 1) * 3.14 def process_chunk(self, sub_region=None): self._private_process_chunk(sub_region=sub_region) self._process_finalize() def _private_process_chunk(self, sub_region=None): assert sub_region is not None, "sub_region argument is mandatory in helical pipeline" self.set_subregion(sub_region) self.reset_translation_diagnostics_accumulators() # self._allocate_reduced_radios() # self._allocate_reduced_gridded_and_subset_radios() (subr_start_x, subr_end_x, subr_start_z, subr_end_z) = self.sub_region span_v = subr_start_z + self._phase_margin_up, subr_end_z - self._phase_margin_down chunk_info = self.span_info.get_chunk_info(span_v) self._reset_memory() self._init_writer(chunk_info) self._configure_data_dumps() proj_num_start, proj_num_end = chunk_info.angle_index_span n_granularity = self.reading_granularity pnum_start_list = list(np.arange(proj_num_start, proj_num_end, n_granularity)) pnum_end_list = pnum_start_list[1:] + [proj_num_end] my_first_pnum = proj_num_start self.gridded_cumulated_weights[:] = 0 self.gridded_radios[:] = 0 for pnum_start, pnum_end in zip(pnum_start_list, pnum_end_list): start_in_chunk = pnum_start - my_first_pnum end_in_chunk = pnum_end - my_first_pnum self._read_data_and_apply_flats( slice(pnum_start, pnum_end), slice(start_in_chunk, end_in_chunk), chunk_info ) self.gridded_radios[:] /= self.gridded_cumulated_weights if "flatfield" in self._data_dump: paganin_margin = self._phase_margin_up if paganin_margin: data_to_dump = self.gridded_radios[:, paganin_margin:-paganin_margin, :] else: data_to_dump = self.gridded_radios self._dump_data_to_file("flatfield", data_to_dump) if self.process_config.nabu_config["pipeline"]["skip_after_flatfield_dump"]: return if "ccd_correction" in self.processing_steps: self._ccd_corrections() if ("phase" in self.processing_steps) or ("unsharp_mask" in self.processing_steps): self._retrieve_phase() if "unsharp_mask" in self.processing_steps: self._apply_unsharp() else: self._nophase_put_to_radios(self.radios, self.gridded_radios) self.logger.info(" LOG ") self._nophase_put_to_radios(self.radios_weights, self.gridded_cumulated_weights) # print( " processing steps ", self.processing_steps ) # ['read_chunk', 'flatfield', 'double_flatfield', 'take_log', 'reconstruction', 'save'] if 
"take_log" in self.processing_steps: self._take_log() self.logger.info(" BALANCE ") self.balance_weights() num_slices = self.radios.shape[1] self.logger.info(" NORMALIZE") self._normalize_sinos() self._dump_sinogram() if "reconstruction" in self.processing_steps: for i_slice in range(num_slices): self._post_primary_data_reduction(i_slice) # charge on self.radios_slim self._filter() self.apply_weights(i_slice) self._build_sino() self._reconstruct(chunk_info=chunk_info, i_slice=i_slice) self._compute_histogram(islice=i_slice, num_slices=num_slices) self._write_data() def apply_weights(self, i_slice): """radios_slim is on gpu""" n_provided_angles = self.radios_slim.shape[0] for first_angle_index in range(0, n_provided_angles, self.num_weight_radios_per_app): end_angle_index = min(n_provided_angles, first_angle_index + self.num_weight_radios_per_app) self._d_radios_weights[: end_angle_index - first_angle_index].set( self.radios_weights[first_angle_index:end_angle_index, i_slice] ) self.radios_slim[first_angle_index:end_angle_index] *= self._d_radios_weights[ : end_angle_index - first_angle_index ] @classmethod def estimate_required_memory( cls, process_config, reading_granularity=None, chunk_size=None, margin_v=0, span_info=None ): """ Estimate the memory (RAM) needed for a reconstruction. Parameters ----------- process_config: `ProcessConfig` object Data structure with the processing configuration chunk_size: int, optional Size of a "radios chunk", i.e "delta z". A radios chunk is a 3D array of shape (n_angles, chunk_size, n_x) If set to None, then chunk_size = n_z Notes ----- It seems that Cuda does not allow allocating and/or transferring more than 16384 MiB (17.18 GB). If warn_from_GB is not None, then the result is in the form (estimated_memory_GB, warning) where warning is a boolean indicating wheher memory allocation/transfer might be problematic. 
""" dataset = process_config.dataset_info nabu_config = process_config.nabu_config processing_steps = process_config.processing_steps Nx, Ny = dataset.radio_dims total_memory_needed = 0 # Read data # ---------- # gridded part tmp_angles_deg = np.rad2deg(process_config.processing_options["reconstruction"]["angles"]) tmp_my_angle_step = abs(np.diff(tmp_angles_deg).mean()) my_angle_step = abs(np.diff(span_info.projection_angles_deg).mean()) n_gridded_angles = int(round(360.0 / my_angle_step)) binning_z = nabu_config["dataset"]["binning_z"] projections_subsampling = nabu_config["dataset"]["projections_subsampling"] # the gridded target total_memory_needed += Nx * (2 * margin_v + chunk_size) * n_gridded_angles * 4 # the gridded weights total_memory_needed += Nx * (2 * margin_v + chunk_size) * n_gridded_angles * 4 # the read grain total_memory_needed += ( (reading_granularity + cls.extra_marge_granularity) * (2 * margin_v + chunk_size + 2) * Nx * 4 ) total_memory_needed += ( (reading_granularity + cls.extra_marge_granularity) * (2 * margin_v + chunk_size + 2) * Nx * 4 ) # the preprocessed radios, their weigth and the buffer used for balancing ( total of three buffer of the same size plus mask plus temporary) total_memory_needed += 5 * (Nx * (chunk_size) * n_gridded_angles) * 4 if "flatfield" in processing_steps: # Flat-field is done in-place, but still need to load darks/flats n_darks = len(dataset.darks) n_flats = len(dataset.flats) darks_size = n_darks * Nx * (2 * margin_v + chunk_size) * 2 # uint16 flats_size = n_flats * Nx * (2 * margin_v + chunk_size) * 4 # f32 total_memory_needed += darks_size + flats_size if "ccd_correction" in processing_steps: total_memory_needed += Nx * (2 * margin_v + chunk_size) * 4 # Phase retrieval # --------------- if "phase" in processing_steps: # Phase retrieval is done image-wise, so near in-place, but needs to # allocate some images, fft plans, and so on Nx_p = get_next_power(2 * Nx) Ny_p = get_next_power(2 * (2 * margin_v + chunk_size)) img_size_real = 2 * 4 * Nx_p * Ny_p img_size_cplx = 2 * 8 * ((Nx_p * Ny_p) // 2 + 1) total_memory_needed += 2 * img_size_real + 3 * img_size_cplx # Reconstruction # --------------- reconstructed_volume_size = 0 if "reconstruction" in processing_steps: ## radios_slim is used to process one slice at once, It will be on the gpu ## and cannot be reduced further, therefore no need to estimate it. ## Either it passes or it does not. #### if radios_and_sinos: #### togtal_memory_needed += data_volume_size # radios + sinos rec_config = process_config.processing_options["reconstruction"] Nx_rec = rec_config["end_x"] - rec_config["start_x"] + 1 Ny_rec = rec_config["end_y"] - rec_config["start_y"] + 1 Nz_rec = chunk_size // binning_z ## the volume is used to reconstruct for each chunk reconstructed_volume_size = Nx_rec * Ny_rec * Nz_rec * 4 # float32 total_memory_needed += reconstructed_volume_size return total_memory_needed # target_central_slicer, source_central_slicer = overlap_logic( subr_start_z, subr_end_z, dtasrc_start_z, dtasrcs_end_z ) def overlap_logic(subr_start_z, subr_end_z, dtasrc_start_z, dtasrc_end_z): """determines the useful lines which can be transferred from the dtasrc_start_z:dtasrc_end_z range targeting the range subr_start_z: subr_end_z .................. 
""" t_h = subr_end_z - subr_start_z s_h = dtasrc_end_z - dtasrc_start_z my_start = max(0, dtasrc_start_z - subr_start_z) my_end = min(t_h, dtasrc_end_z - subr_start_z) if my_start >= my_end: return None, None target_central_slicer = slice(my_start, my_end) my_start = max(0, subr_start_z - dtasrc_start_z) my_end = min(s_h, subr_end_z - dtasrc_start_z) assert my_start < my_end, "Overlap_logic logic error" dtasrc_central_slicer = slice(my_start, my_end) return target_central_slicer, dtasrc_central_slicer def padding_logic(subr_start_z, subr_end_z, dtasrc_start_z, dtasrc_end_z): """.......... and the missing ranges which possibly could be obtained by extension padding""" t_h = subr_end_z - subr_start_z s_h = dtasrc_end_z - dtasrc_start_z if dtasrc_start_z <= subr_start_z: target_lower_padding = None else: target_lower_padding = slice(0, dtasrc_start_z - subr_start_z) if dtasrc_end_z >= subr_end_z: target_upper_padding = None else: target_upper_padding = slice(dtasrc_end_z - subr_end_z, None) return target_lower_padding, target_upper_padding def _fill_in_chunk_by_shift_crop_data( data_target, data_read, fract_shit, my_subr_start_z, my_subr_end_z, dtasrc_start_z, dtasrc_end_z, x_shift=0.0, extension_padding=True, ): """given a freshly read cube of data, it dispatches every slice to its proper vertical position and proper radio by shifting, cropping, and extending if necessary""" data_read_precisely_shifted = nd.interpolation.shift(data_read, (-fract_shit, x_shift), order=1, mode="nearest")[ :-1 ] target_central_slicer, dtasrc_central_slicer = overlap_logic( my_subr_start_z, my_subr_end_z - 1, dtasrc_start_z, dtasrc_end_z - 1 ) if None not in [target_central_slicer, dtasrc_central_slicer]: data_target[target_central_slicer] = data_read_precisely_shifted[dtasrc_central_slicer] target_lower_slicer, target_upper_slicer = padding_logic( my_subr_start_z, my_subr_end_z - 1, dtasrc_start_z, dtasrc_end_z - 1 ) if extension_padding: if target_lower_slicer is not None: data_target[target_lower_slicer] = data_read_precisely_shifted[0] if target_upper_slicer is not None: data_target[target_upper_slicer] = data_read_precisely_shifted[-1] else: if target_lower_slicer is not None: data_target[target_lower_slicer] = 1.0e-6 if target_upper_slicer is not None: data_target[target_upper_slicer] = 1.0e-6 def shift(arr, shift, fill_value=0.0): """trivial horizontal shift. 
Contrarily to scipy.ndimage.interpolation.shift, this shift does not cut the tails abruptly, but by interpolation """ result = np.zeros_like(arr) num1 = int(math.floor(shift)) num2 = num1 + 1 partition = shift - num1 for num, factor in zip([num1, num2], [(1 - partition), partition]): if num > 0: result[:, :num] += fill_value * factor result[:, num:] += arr[:, :-num] * factor elif num < 0: result[:, num:] += fill_value * factor result[:, :num] += arr[:, -num:] * factor else: result[:] += arr * factor return result def rebalance(radios_weights, angles, ax_pos): """rebalance the weights, within groups of equivalent (up to multiple of 180), data pixels""" balanced = np.zeros_like(radios_weights) n_span = int(math.ceil(angles[-1] - angles[0]) / np.pi) center = (radios_weights.shape[-1] - 1) / 2 nloop = balanced.shape[0] for i in range(nloop): w_res = balanced[i] angle = angles[i] for i_half_turn in range(-n_span - 1, n_span + 2): if i_half_turn == 0: w_res[:] += radios_weights[i] continue shifted_angle = angle + i_half_turn * np.pi insertion_index = np.searchsorted(angles, shifted_angle) if insertion_index in [0, angles.shape[0]]: if insertion_index == 0: continue else: if shifted_angle > 2 * np.pi: continue myimage = radios_weights[-1] else: partition = (shifted_angle - angles[insertion_index - 1]) / ( angles[insertion_index] - angles[insertion_index - 1] ) myimage = (1.0 - partition) * radios_weights[insertion_index - 1] + partition * radios_weights[ insertion_index ] if i_half_turn % 2 == 0: w_res[:] += myimage else: myimage = np.fliplr(myimage) w_res[:] += shift(myimage, (2 * ax_pos - 2 * center)) mask = np.equal(0, radios_weights) balanced[:] = radios_weights / balanced balanced[mask] = 0 return balanced ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/pipeline/helical/helical_chunked_regridded_cuda.py0000644000175000017500000000765500000000000025055 0ustar00pierrepierrefrom math import ceil import numpy as np from ...utils import deprecated from ...preproc.flatfield_cuda import CudaFlatFieldDataUrls from ...preproc.shift import VerticalShift from ...preproc.shift_cuda import CudaVerticalShift from ...preproc.phase_cuda import CudaPaganinPhaseRetrieval from ...reconstruction.sinogram_cuda import CudaSinoBuilder, CudaSinoNormalization from ...reconstruction.sinogram import SinoBuilder, SinoNormalization from ...misc.unsharp_cuda import CudaUnsharpMask from ...misc.rotation_cuda import CudaRotation from ...misc.histogram_cuda import CudaPartialHistogram from .fbp import BackprojectorHelical try: from ...reconstruction.hbp import HierarchicalBackprojector # pylint: disable=E0401,E0611 print("Successfully imported hbp") except: HierarchicalBackprojector = None from ...cuda.utils import get_cuda_context, __has_pycuda__, __pycuda_error_msg__, replace_array_memory from ..utils import pipeline_step from .helical_chunked_regridded import HelicalChunkedRegriddedPipeline if __has_pycuda__: import pycuda.gpuarray as garray class CudaHelicalChunkedRegriddedPipeline(HelicalChunkedRegriddedPipeline): """ Cuda backend of HelicalChunkedPipeline """ VerticalShiftClass = CudaVerticalShift SinoBuilderClass = CudaSinoBuilder FBPClass = BackprojectorHelical HBPClass = HierarchicalBackprojector HistogramClass = CudaPartialHistogram SinoNormalizationClass = CudaSinoNormalization def __init__( self, process_config, sub_region, logger=None, extra_options=None, phase_margin=None, cuda_options=None, reading_granularity=10, span_info=None, 
num_weight_radios_per_app=1000, ): self._init_cuda(cuda_options) super().__init__( process_config, sub_region, logger=logger, extra_options=extra_options, phase_margin=phase_margin, reading_granularity=reading_granularity, span_info=span_info, ) self._register_callbacks() self.num_weight_radios_per_app = num_weight_radios_per_app def _init_cuda(self, cuda_options): if not (__has_pycuda__): raise ImportError(__pycuda_error_msg__) cuda_options = cuda_options or {} self.ctx = get_cuda_context(**cuda_options) self._d_radios = None self._d_radios_weights = None self._d_sinos = None self._d_recs = None def _allocate_array(self, shape, dtype, name=None): name = name or "tmp" # should be mandatory d_name = "_d_" + name d_arr = getattr(self, d_name, None) if d_arr is None: self.logger.debug("Allocating %s: %s" % (name, str(shape))) d_arr = garray.zeros(shape, dtype) setattr(self, d_name, d_arr) return d_arr def _process_finalize(self): pass def _post_primary_data_reduction(self, i_slice): self._allocate_array((self.num_weight_radios_per_app,) + self.radios_slim.shape[1:], "f", name="radios_weights") if self.process_config.nabu_config["reconstruction"]["angular_tolerance_steps"]: self.radios[:, i_slice, :][np.isnan(self.radios[:, i_slice, :])] = 0 self.radios_slim.set(self.radios[:, i_slice, :]) def _register_callbacks(self): pass # # Pipeline execution (class specialization) # @pipeline_step("histogram", "Computing histogram") def _compute_histogram(self, data=None): if data is None: data = self._d_recs self.recs_histogram = self.histogram.compute_histogram(data) def _dump_data_to_file(self, step_name, data=None): if data is None: data = self.radios if step_name not in self._data_dump: return if isinstance(data, garray.GPUArray): data = data.get() super()._dump_data_to_file(step_name, data=data) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1679996432.0 nabu-2023.1.1/nabu/pipeline/helical/helical_reconstruction.py0000644000175000017500000005331000000000000023475 0ustar00pierrepierrefrom os.path import join, isfile, basename, dirname from math import ceil import gc from time import time from psutil import virtual_memory from silx.io import get_data from silx.io.url import DataUrl import numpy as np import math import copy from ...resources.logger import LoggerOrPrint from ...io.writer import merge_hdf5_files, NXProcessWriter from ...cuda.utils import collect_cuda_gpus from ...preproc.phase import compute_paganin_margin from ...misc.histogram import PartialHistogram, add_last_bin, hist_as_2Darray from .span_strategy import SpanStrategy from .helical_chunked_regridded_cuda import CudaHelicalChunkedRegriddedPipeline from ..fullfield.reconstruction import collect_cuda_gpus, variable_idxlen_sort, FullFieldReconstructor avail_gpus = collect_cuda_gpus() or {} class HelicalReconstructorRegridded: """ A class for obtaining a full-volume reconstructions from helical-scan datasets. """ _pipeline_cls = CudaHelicalChunkedRegriddedPipeline _process_name = "reconstruction" _pipeline_mode = "helical" reading_granularity = 100 """ The data angular span which needs to be read for a reconstruction is read step by step, reading each time a maximum of reading_granularity radios, and doing the preprocessing till phase retrieval for each of these angular groups """ def __init__(self, process_config, logger=None, extra_options=None, cuda_options=None): """ Initialize a LocalReconstruction object. 
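        A typical, illustrative sequence of calls (process_config and logger
        are assumed to be built beforehand, e.g. from a nabu configuration
        file):

            reconstructor = HelicalReconstructorRegridded(process_config, logger=logger)
            reconstructor.reconstruct()
            reconstructor.merge_hdf5_reconstructions()
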
This class is used for managing pipelines Parameters ---------- process_config: ProcessConfig object Data structure with process configuration logger: Logger, optional logging object extra_options: dict, optional Dictionary with advanced options. Please see 'Other parameters' below cuda_options: dict, optional Dictionary with cuda options passed to `nabu.cuda.processing.CudaProcessing` Other parameters ----------------- Advanced options can be passed in the 'extra_options' dictionary. These can be: - "gpu_mem_fraction": 0.9, - "cpu_mem_fraction": 0.9, - "use_phase_margin": True, - "max_chunk_size": None, - "phase_margin": None, - "dry_run": 0, """ self.logger = LoggerOrPrint(logger) self.process_config = process_config self._set_extra_options(extra_options) self._get_resources() self._margin_v, self._margin_h = self._compute_phase_margin() self._setup_span_info() self._compute_max_chunk_size() self._get_reconstruction_range() self._build_tasks() self.pipeline = None self.cuda_options = cuda_options def _set_extra_options(self, extra_options): if extra_options is None: extra_options = {} advanced_options = { "gpu_mem_fraction": 0.9, "cpu_mem_fraction": 0.9, "use_phase_margin": True, "max_chunk_size": None, "phase_margin": None, "dry_run": 0, } advanced_options.update(extra_options) self.extra_options = advanced_options self.gpu_mem_fraction = self.extra_options["gpu_mem_fraction"] self.cpu_mem_fraction = self.extra_options["cpu_mem_fraction"] self.use_phase_margin = self.extra_options["use_phase_margin"] self.dry_run = self.extra_options["dry_run"] self._do_histograms = self.process_config.nabu_config["postproc"]["output_histogram"] self._histogram_merged = False self._span_info = None def _get_reconstruction_range(self): rec_cfg = self.process_config.nabu_config["reconstruction"] self.z_min = rec_cfg["start_z"] self.z_max = rec_cfg["end_z"] + 1 z_min_mm = rec_cfg["start_z_mm"] z_max_mm = rec_cfg["end_z_mm"] if z_min_mm != 0.0 or z_max_mm != 0.0: z_min_mm += self.z_offset_mm z_max_mm += self.z_offset_mm d_v, d_h = self.process_config.dataset_info.radio_dims[::-1] z_start, z_end = (self._span_info.get_doable_span()).view_heights_minmax z_end += 1 h_s = np.arange(z_start, z_end) fact_mm = self.process_config.dataset_info.pixel_size * 1.0e-3 z_mm_s = fact_mm * (-self._span_info.z_pix_per_proj[0] + (d_v - 1) / 2 - h_s) self.z_min = 0 self.z_max = len(z_mm_s) if z_mm_s[-1] > z_mm_s[0]: for i in range(len(z_mm_s) - 1): if (z_min_mm - z_mm_s[i]) * (z_min_mm - z_mm_s[i + 1]) <= 0: self.z_min = i break for i in range(len(z_mm_s) - 1): if (z_max_mm - z_mm_s[i]) * (z_max_mm - z_mm_s[i]) <= 0: self.z_max = i + 1 break else: for i in range(len(z_mm_s) - 1): if (z_max_mm - z_mm_s[i]) * (z_max_mm - z_mm_s[i + 1]) <= 0: self.z_max = len(z_mm_s) - 2 - i break for i in range(len(z_mm_s) - 1): if (z_min_mm - z_mm_s[i]) * (z_min_mm - z_mm_s[i + 1]) <= 0: self.z_min = len(z_mm_s) - 1 - i break _get_resources = FullFieldReconstructor._get_resources _get_memory = FullFieldReconstructor._get_memory _get_gpu = FullFieldReconstructor._get_gpu _compute_phase_margin = FullFieldReconstructor._compute_phase_margin _compute_unsharp_margin = FullFieldReconstructor._compute_unsharp_margin _print_tasks = FullFieldReconstructor._print_tasks _instantiate_pipeline_if_necessary = FullFieldReconstructor._instantiate_pipeline_if_necessary _destroy_pipeline = FullFieldReconstructor._destroy_pipeline _give_progress_info = FullFieldReconstructor._give_progress_info get_relative_files = FullFieldReconstructor.get_relative_files 
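    # The attribute assignments above and below reuse FullFieldReconstructor helpers
    # unchanged: the functions are attached to this class as-is, so only the
    # helical-specific behaviour needs to be redefined further down.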
merge_histograms = FullFieldReconstructor.merge_histograms merge_data_dumps = FullFieldReconstructor.merge_data_dumps _get_chunk_length = FullFieldReconstructor._get_chunk_length # redefined here, and with self, otherwise @static and inheritance gives "takes 1 positional argument but 2 were given" # when called from inside the inherite class def _get_delta_z(self, task): return task["sub_region"][1] - task["sub_region"][0] def _get_task_key(self): """ Get the 'key' (number) associated to the current task/pipeline """ return self.pipeline.sub_region[-2:] # Gpu required memory size does not depend on the number of slices def _compute_max_chunk_size(self): cpu_mem = self.resources["mem_avail_GB"] * self.cpu_mem_fraction user_max_chunk_size = self.extra_options["max_chunk_size"] self.cpu_max_chunk_size = self.estimate_chunk_size( cpu_mem, self.process_config, chunk_step=1, user_max_chunk_size=user_max_chunk_size ) if user_max_chunk_size is not None: self.cpu_max_chunk_size = min(self.cpu_max_chunk_size, user_max_chunk_size) self.user_slices_at_once = self.cpu_max_chunk_size # cannot use the estimate_chunk_size from computations.py beacause it has the estimate_required_memory hard-coded def estimate_chunk_size(self, available_memory_GB, process_config, chunk_step=1, user_max_chunk_size=None): """ Estimate the maximum chunk size given an avaiable amount of memory. Parameters ----------- available_memory_GB: float available memory in Giga Bytes (GB - not GiB !). process_config: ProcessConfig ProcessConfig object """ chunk_size = chunk_step radios_and_sinos = False if ( "reconstruction" in process_config.processing_steps and process_config.processing_options["reconstruction"]["enable_halftomo"] ): radios_and_sinos = True max_dz = process_config.dataset_info.radio_dims[1] chunk_size = chunk_step last_good_chunk_size = chunk_size while True: required_mem = self._pipeline_cls.estimate_required_memory( process_config, chunk_size=chunk_size, reading_granularity=self.reading_granularity, margin_v=self._margin_v, span_info=self._span_info, ) required_mem_GB = required_mem / 1e9 if required_mem_GB > available_memory_GB: break last_good_chunk_size = chunk_size if user_max_chunk_size is not None and chunk_size > user_max_chunk_size: last_good_chunk_size = user_max_chunk_size break chunk_size += chunk_step return last_good_chunk_size # different because of dry_run def _build_tasks(self): if self.dry_run: self.tasks = [] else: self._compute_volume_chunks() # this is very different def _compute_volume_chunks(self): margin_v = self._margin_v # self._margin_far_up = min(margin_v, self.z_min) # self._margin_far_down = min(margin_v, n_z - (self.z_max + 1)) ## It will be the reading process which pads self._margin_far_up = margin_v self._margin_far_down = margin_v # | margin_up | n_slices | margin_down | # |-----------|-----------------|--------------| # |----------------------------------------------------| # delta_z n_slices = self.user_slices_at_once z_start, z_end = (self._span_info.get_doable_span()).view_heights_minmax z_end += 1 if (self.z_min, self.z_max) == (0, 0): self.z_min, self.z_max = z_start, z_end my_z_min = z_start my_z_end = z_end else: if self.z_max <= self.z_min: message = f"""" The input file provide start_z end_z {self.z_min,self.z_max} but it is necessary that start_z < end_z """ raise ValueError(message) if self._span_info.z_pix_per_proj[-1] > self._span_info.z_pix_per_proj[0]: my_z_min = z_start + self.z_min my_z_end = z_start + self.z_max else: my_z_min = z_end - self.z_max my_z_end = 
z_end - self.z_min my_z_min = max(z_start, my_z_min) my_z_end = min(z_end, my_z_end) print("my_z_min my_z_end ", my_z_min, my_z_end) if my_z_min >= my_z_end: message = f""" The requested vertical span, after translation to absolute doable heights would be {my_z_min, my_z_end} is not doable (start>=end) """ raise ValueError(message) # if my_z_min != self.z_min or my_z_end != self.z_max: # message = f""" The requested vertical span given by self.z_min, self.z_max+1 ={self.z_min, self.z_max} # is not withing the doable span which is {z_start, z_end} # """ # raise ValueError(message) tasks = [] n_stages = ceil((my_z_end - my_z_min) / n_slices) curr_z_min = my_z_min curr_z_max = my_z_min + n_slices for i in range(n_stages): if curr_z_max >= my_z_end: curr_z_max = my_z_end margin_down = margin_v margin_up = margin_v tasks.append( { "sub_region": (curr_z_min - margin_up, curr_z_max + margin_down), "phase_margin": ((margin_up, margin_down), (0, 0)), } ) if curr_z_max == my_z_end: # No need for further tasks break curr_z_min += n_slices curr_z_max += n_slices ## ## ########################################################################################### self.tasks = tasks self.n_slices = n_slices self._print_tasks() def _setup_span_info(self): """We create here an instance of SpanStrategy class for helical scans. This class do all the accounting for the doable slices, giving for each the useful angle, the shifts .. """ # projections_subsampling = self.process_config.dataset_info.projections_subsampling projections_subsampling = self.process_config.nabu_config["dataset"]["projections_subsampling"] radio_shape = self.process_config.dataset_info.radio_dims[::-1] dz_per_proj = self.process_config.nabu_config["reconstruction"]["dz_per_proj"] z_per_proj = self.process_config.dataset_info.z_per_proj dx_per_proj = self.process_config.nabu_config["reconstruction"]["dx_per_proj"] x_per_proj = self.process_config.dataset_info.x_per_proj tot_num_images = len(self.process_config.processing_options["read_chunk"]["files"]) // projections_subsampling if z_per_proj is not None: z_per_proj = np.array(z_per_proj) self.logger.info(" z_per_proj has been explicitely provided") if len(z_per_proj) != tot_num_images: message = f""" The provided array z_per_proj, which has length {len(z_per_proj)} must match in lenght the number of radios which is {tot_num_images} """ raise ValueError(message) else: z_per_proj = self.process_config.dataset_info.z_translation if dz_per_proj is not None: self.logger.info("correcting vertical displacement by provided screw rate dz_per_proj") z_per_proj += np.arange(tot_num_images) * dz_per_proj if x_per_proj is not None: x_per_proj = np.array(x_per_proj) self.logger.info(" x_per_proj has been explicitely provided") if len(x_per_proj) != tot_num_images: message = f""" The provided array x_per_proj, which has length {len(x_per_proj)} must match in lenght the number of radios which is {tot_num_images} """ raise ValueError(message) else: x_per_proj = self.process_config.dataset_info.x_translation if dx_per_proj is not None: self.logger.info("correcting vertical displacement by provided screw rate dx_per_proj") x_per_proj += np.arange(tot_num_images) * dx_per_proj x_per_proj = x_per_proj - x_per_proj[0] self.z_offset_mm = z_per_proj[0] * self.process_config.dataset_info.pixel_size * 1.0e-3 # micron to mm z_per_proj = z_per_proj - z_per_proj[0] binning = self.process_config.nabu_config["dataset"]["binning"] if binning is not None: if np.isscalar(binning): binning = (binning, binning) binning_x, 
binning_z = binning x_per_proj = x_per_proj / binning_x z_per_proj = z_per_proj / binning_z x_per_proj = projections_subsampling * x_per_proj z_per_proj = projections_subsampling * z_per_proj angles_rad = self.process_config.processing_options["reconstruction"]["angles"] angles_rad = np.unwrap(angles_rad) angles_deg = np.rad2deg(angles_rad) redundancy_angle_deg = self.process_config.nabu_config["reconstruction"]["redundancy_angle_deg"] do_helical_half_tomo = self.process_config.nabu_config["reconstruction"]["helical_halftomo"] self.logger.info("Creating SpanStrategy object for helical ") t0 = time() self._span_info = SpanStrategy( z_offset_mm=self.z_offset_mm, z_pix_per_proj=z_per_proj, x_pix_per_proj=x_per_proj, detector_shape_vh=radio_shape, phase_margin_pix=self._margin_v, projection_angles_deg=angles_deg, pixel_size_mm=self.process_config.dataset_info.pixel_size * 1.0e-3, # micron to mm logger=self.logger, require_redundancy=(redundancy_angle_deg > 0), angular_tolerance_steps=self.process_config.nabu_config["reconstruction"]["angular_tolerance_steps"], ) duration = time() - t0 self.logger.info(f"Creating SpanStrategy object for helical in {duration} seconds") if self.dry_run: info_string = self._span_info.get_informative_string() print(" Informations about the doable vertical span") print(info_string) return def _instantiate_pipeline(self, task): self.logger.debug("Creating a new pipeline object") args = [self.process_config, task["sub_region"]] dz = self._get_delta_z(task) pipeline = self._pipeline_cls( *args, logger=self.logger, phase_margin=task["phase_margin"], reading_granularity=self.reading_granularity, span_info=self._span_info # cuda_options=self.cuda_options ) self.pipeline = pipeline # kept to save diagnostic def _process_task(self, task): self.pipeline.process_chunk(sub_region=task["sub_region"]) key = len(list(self.diagnostic_per_chunk.keys())) self.diagnostic_per_chunk[key] = copy.deepcopy(self.pipeline.diagnostic) # kept for diagnostic and dry run def reconstruct(self): self._print_tasks() self.diagnostic_per_chunk = {} tasks = self.tasks self.results = {} self._histograms = {} self._data_dumps = {} prev_task = None for task in tasks: if prev_task is None: prev_task = task self._give_progress_info(task) self._instantiate_pipeline_if_necessary(task, prev_task) if self.dry_run: info_string = self._span_info.get_informative_string() print(" SPAN_INFO informations ") print(info_string) return self._process_task(task) if self.pipeline.writer is not None: task_key = self._get_task_key() self.results[task_key] = self.pipeline.writer.fname if self.pipeline.histogram_writer is not None: # self._do_histograms self._histograms[task_key] = self.pipeline.histogram_writer.fname if len(self.pipeline._data_dump) > 0: self._data_dumps[task_key] = {} for step_name, writer in self.pipeline._data_dump.items(): self._data_dumps[task_key][step_name] = writer.fname prev_task = task ## kept in order to speed it up by omitting the super slow python writing ## of thousand of unused urls def merge_hdf5_reconstructions( self, output_file=None, prefix=None, files=None, process_name=None, axis=0, merge_histograms=True, output_dir=None, ): """ Merge existing hdf5 files by creating a HDF5 virtual dataset. Parameters ---------- output_file: str, optional Output file name. If not given, the file prefix in section "output" of nabu config will be taken. 
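
        Examples
        --------
        A minimal, illustrative call (the reconstructor instance is assumed to
        have already run its reconstruction tasks):

            final_file = reconstructor.merge_hdf5_reconstructions(prefix="merged_")
            # gathers the per-chunk HDF5 files listed in self.results into
            # <output location>/merged_<file_prefix>.hdf5 as a virtual dataset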
""" out_cfg = self.process_config.nabu_config["output"] out_dir = output_dir or out_cfg["location"] prefix = prefix or "" # Prevent issue when out_dir is empty, which happens only if dataset/location is a relative path. # TODO this should be prevented earlier if out_dir is None or len(out_dir.strip()) == 0: out_dir = dirname(dirname(self.results[list(self.results.keys())[0]])) # if output_file is None: output_file = join(out_dir, prefix + out_cfg["file_prefix"]) + ".hdf5" if isfile(output_file): msg = str("File %s already exists" % output_file) if out_cfg["overwrite_results"]: msg += ". Overwriting as requested in configuration file" self.logger.warning(msg) else: msg += ". Set overwrite_results to True in [output] to overwrite existing files." self.logger.fatal(msg) raise ValueError(msg) local_files = files if local_files is None: local_files = self.get_relative_files() if local_files == []: self.logger.error("No files to merge") return entry = getattr(self.process_config.dataset_info.dataset_scanner, "entry", "entry") process_name = process_name or self._process_name h5_path = join(entry, *[process_name, "results", "data"]) # self.logger.info("Merging %ss to %s" % (process_name, output_file)) print("omitting config in call to merge_hdf5_files because export2dict too slow") merge_hdf5_files( local_files, h5_path, output_file, process_name, output_entry=entry, output_filemode="a", processing_index=0, config={ self._process_name + "_stages": {str(k): v for k, v in zip(self.results.keys(), local_files)}, "diagnostics": self.diagnostic_per_chunk, }, # config={ # self._process_name + "_stages": {str(k): v for k, v in zip(self.results.keys(), local_files)}, # "nabu_config": self.process_config.nabu_config, # "processing_options": self.process_config.processing_options, # }, base_dir=out_dir, axis=axis, overwrite=out_cfg["overwrite_results"], ) if merge_histograms: self.merge_histograms(output_file=output_file) return output_file merge_hdf5_files = merge_hdf5_reconstructions ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/pipeline/helical/helical_utils.py0000644000175000017500000000263700000000000021562 0ustar00pierrepierrefrom ...resources.logger import LoggerOrPrint import numpy as np logger = LoggerOrPrint(None) def find_mirror_indexes(angles_deg, tolerance_factor=1.0): """return a list of indexes where the ith elememnt contains the index of the angles_deg array element which has the value the closest to angles_deg[i] + 180. It is used for padding in halftomo. 
Parameters: ----------- angles_deg: a nd.array of floats tolerance: float if the mirror positions are not within a distance less than tolerance fro; the ideal position a warning is raised """ av_step = abs(np.diff(angles_deg).mean()) tolerance = av_step * tolerance_factor tmp_mirror_angles_deg = angles_deg + 180 mirror_angle_relative_indexes = (abs(abs(np.mod(tmp_mirror_angles_deg[:, None] - angles_deg, 360) - 180))).argmax( axis=-1 ) mirror_values = angles_deg[mirror_angle_relative_indexes] differences = abs(np.mod(mirror_values - angles_deg, 360) - 180) if differences.max() > tolerance: logger.warning( f"""In function find_mirror_indexes the mirror position are far beyon tolerance from ideal position tolerance is {tolerance} given by average step {av_step} and tolerance_factor {tolerance_factor} and the maximum error is {differences.max()} """ ) return mirror_angle_relative_indexes ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1679996432.0 nabu-2023.1.1/nabu/pipeline/helical/nabu_config.py0000644000175000017500000001620300000000000021205 0ustar00pierrepierrefrom ..fullfield.nabu_config import * import copy ## keep the text below for future inclusion in the documentation # start_z, end_z , start_z_mm, end_z_mm # ---------------------------------------------------------------- # By default, all the volume is reconstructed slice by slice, along the axis 'z'. # ** option 1) you can set # start_z_mm=0 # end_z_mm = 0 # # Now, concerning start_z and end_z, Use positive integers, with start_z < end_z # The reconstructed vertical region will be # slice start = first doable + start_z # slice end = first doable + end_z # or less if such range needs to be clipped to the doable one # # As an example start_z= 10, end_z = 20 # for reconstructing 10 slices close to scan start. # NOTE: we are proceeding in the direction of the scan so that, in millimiters, # the start may be above or below the end # To reconstruct the whole doable volume set # # start_z= 0 # end_z =-1 # # ** option 2) using start_z_mm, end_z_mm # Use positive floats, in millimiters. They indicate the height above the sample stage # The values of start_z and end_z are not used in this case help_start_end_z = """ If both start_z_mm and end_z_mm are seto to zero, then start_z and end_z will be effective. otherwhise the slices whose height above the sample stage, in millimiters, between start_z_mm and end_z_mm are reconstructed """ # we need to deepcopy this in order not to mess the original nabu_config of the full-field pipeline nabu_config = copy.deepcopy(nabu_config) nabu_config["preproc"]["processes_file"] = { "default": "", "help": "Path tgo the file where some operations should be stored for later use. By default it is 'xxx_nabu_processes.h5'", "validator": optional_file_location_validator, "type": "required", } nabu_config["preproc"]["double_flatfield_enabled"]["default"] = 1 nabu_config["reconstruction"].update( { "dz_per_proj": { "default": 0, "help": " A positive DZPERPROJ means that the rotation axis is going up. Alternatively the vertical translations, can be given through an array using the variable z_per_proj_file", "validator": float_validator, "type": "optional", }, "z_per_proj_file": { "default": "", "help": "Alternative to dz_per_proj. A file where each line has one value: vertical displacements of the axis. 
There should be as many values as there are projection images.", "validator": optional_file_location_validator, "type": "optional", }, "dx_per_proj": { "default": 0, "help": " A positive value means that the rotation axis is going on the rigth. Alternatively the horizontal translations, can be given through an array using the variable x_per_proj_file", "validator": float_validator, "type": "optional", }, "x_per_proj_file": { "default": "", "help": "Alternative to dx_per_proj. A file where each line has one value: horizontal displacements of the axis. There should be as many values as there are projection images.", "validator": optional_file_location_validator, "type": "optional", }, "axis_to_the_center": { "default": "1", "help": "Whether to shift start_x and start_y so to have the axis at the center", "validator": boolean_validator, "type": "optional", }, "auto_size": { "default": "1", "help": "Wether to set automatically start_x end_x start_y end_y ", "validator": boolean_validator, "type": "optional", }, "use_hbp": { "default": "0", "help": "Wether to use hbp routine instead of the backprojector from fbp ", "validator": boolean_validator, "type": "optional", }, "fan_source_distance_meters": { "default": 1.0e9, "help": "For HBP, for the description of the fan geometry, the source to axis distance. Defaults to a large value which implies parallel geometry", "validator": float_validator, "type": "optional", }, "start_z_mm": { "default": "0", "help": help_start_end_z, "validator": float_validator, "type": "optional", }, "end_z_mm": { "default": "0", "help": " To determine the reconstructed vertical range: the height in millimiters above the stage below which slices are reconstructed ", "validator": float_validator, "type": "optional", }, "start_z": { "default": "0", "help": "the first slice of the reconstructed range. Numbered going in the direction of the scan and starting with number zero for the first doable slice", "validator": slice_num_validator, "type": "optional", }, "end_z": { "default": "-1", "help": "the " "end" " slice of the reconstructed range. Numbered going in the direction of the scan and starting with number zero for the first doable slice", "validator": slice_num_validator, "type": "optional", }, } ) nabu_config["pipeline"].update( { "skip_after_flatfield_dump": { "default": "0", "help": "When the writing of the flatfielded data is activated, if this option is set, then the phase and reconstruction steps are skipped", "validator": boolean_validator, "type": "optional", }, } ) nabu_config["reconstruction"].update( { "angular_tolerance_steps": { "default": "3.0", "help": "the angular tolerance, an angular width expressed in units of an angular step, which is tolerated in the criteria for deciding if a slice is reconstructable or not", "validator": float_validator, "type": "advanced", }, "redundancy_angle_deg": { "default": "0", "help": "Can be 0,180 or 360. If there are dead detector regions (notably scintillator junction (stripes) which need to be complemented at +-360 for local tomo or +- 180 for conventional tomo. This may have an impact on the doable vertical span (you can check it with the --dry-run 1 option)", "validator": float_validator, "type": "advanced", }, "enable_halftomo": { "default": "0", "help": "nabu-helical applies the same treatment for half-tomo as for full-tomo. 
Always let this key to zero", "validator": boolean_validator, "type": "advanced", }, "helical_halftomo": { "default": "1", "help": "Wether to consider doable slices those which are contributed by an angular span greater or equal to 360, instead of just 180 or more", "validator": boolean_validator, "type": "advanced", }, } ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956532.0 nabu-2023.1.1/nabu/pipeline/helical/processconfig.py0000644000175000017500000000501300000000000021574 0ustar00pierrepierrefrom .nabu_config import nabu_config, renamed_keys from .dataset_validator import HelicalDatasetValidator from ..fullfield import processconfig as ff_processconfig from ...resources import dataset_analyzer class ProcessConfig(ff_processconfig.ProcessConfig): default_nabu_config = nabu_config config_renamed_keys = renamed_keys _use_horizontal_translations = False def _configure_save_steps(self): self._dump_sinogram = False steps_to_save = self.nabu_config["pipeline"]["save_steps"] if steps_to_save in (None, ""): self.steps_to_save = [] return steps_to_save = [s.strip() for s in steps_to_save.split(",")] for step in self.processing_steps: step = step.strip() if step in steps_to_save: self.processing_options[step]["save"] = True self.processing_options[step]["save_steps_file"] = self.get_save_steps_file(step_name=step) # "sinogram" is a special keyword, not explicitly in the processing steps if "sinogram" in steps_to_save: self._dump_sinogram = True self._dump_sinogram_file = self.get_save_steps_file(step_name="sinogram") self.steps_to_save = steps_to_save def _update_dataset_info_with_user_config(self): super()._update_dataset_info_with_user_config() self._get_translation_file("reconstruction", "z_per_proj_file", "z_per_proj", last_dim=1) self._get_translation_file("reconstruction", "x_per_proj_file", "x_per_proj", last_dim=1) def _get_user_sino_normalization(self): """is called by the base class but it is not used in helical""" pass def _coupled_validation(self): self.logger.debug("Doing coupled validation") self._dataset_validator = HelicalDatasetValidator(self.nabu_config, self.dataset_info) def _browse_dataset(self, dataset_info): """ Browse a dataset and builds a data structure with the relevant information. """ self.logger.debug("Browsing dataset") if dataset_info is not None: self.dataset_info = dataset_info else: self.dataset_info = dataset_analyzer.analyze_dataset( self.nabu_config["dataset"]["location"], extra_options={ "exclude_projections": self.nabu_config["dataset"]["exclude_projections"], "hdf5_entry": self.nabu_config["dataset"]["hdf5_entry"], }, logger=self.logger, ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1679996432.0 nabu-2023.1.1/nabu/pipeline/helical/span_strategy.py0000644000175000017500000006072300000000000021624 0ustar00pierrepierreimport math import numpy as np from ...resources.logger import LoggerOrPrint from ...utils import DictToObj class SpanStrategy: def __init__( self, z_pix_per_proj, x_pix_per_proj, detector_shape_vh, phase_margin_pix, projection_angles_deg, require_redundancy=False, pixel_size_mm=0.1, z_offset_mm=0.0, logger=None, angular_tolerance_steps=0.0, ): """ This class does all the accounting for the reconstructible slices, giving for each one the list of the useful angles, of the vertical and horizontal shifts,and more ... Parameters ---------- z_pix_per_proj : array of floats an array of floats with one entry per projection, in pixel units. 
The values are the vertical displacements of the detector. An decreasing z means that the rotation axis is following the positive direction of the detector vertical axis, which is pointing toward the ground. In the experimental setup, the vertical detector axis is pointing toward the ground. Moreover the values are offsetted so that the first value is zero. The offset value, in millimiters is z_offset_mm and it is the vertical position of the sample stage relatively to the center of the detector. A negative z_offset_mm means that the sample stage is below the detector for the first projection, and this is almost always the case, because the sample is above the sample stage. A z_pix=0 value indicates that the translation-rotation stage "ground" is exactly at the beam height ( supposed to be near the central line of the detector) plus z_offset_mm. A positive z_pix means that the translation stage has been lowered, compared to the first projection, in order to scan higher parts of the sample. ( the sample is above the translation-rotation stage). x_pix_per_proj : array of floats one entry per projection. The horizontal displacement of the detector respect to the rotation axis. A positive x means that the sample shadow on the detector is moving toward the left of the detector. (the detector is going right) detector_shape_vh : a tuple of two ints the vertical and horizontal dimensions phase_margin_pix : int the maximum margin needed for the different kernels (phase, unsharp..) otherwhere in the pipeline projection_angles_deg : array of floats per each projection the rotation angle of the sample in degree. require_redundancy: bool, optional, defaults to False It can be set to True, when there are dead zones in the detector. In this case the minimal required angular span is increased from 360 to 2*360 in order to enforce the necessary redundancy, which allows the correction of the dead zones. The lines which do not satisfy this requirement are not doable. z_offset_mm: float the vertical position of the sample stage relatively to the center of the detector. A negative z_offset_mm means that the sample stage is below the detector for the first projection, and this is almost always the case, because the sample is above the sample stage. pixel_size_mm: float, the pixel size in millimiters this value is used to give results in units of " millimeters above the sample stage" Althougth internally all is calculated in pixel units, it is useful to incorporate such information in the spanstrategy object which will then be able to setup reconstruction informations according to several request formats: be they in sequential number of reconstructible slices, or millimeters above the stage. angular_tolerance_steps: float, defaults to zero the angular tolerance, an angular width expressed in units of an angular step, which is tolerated in the criteria for deciding if a slice is reconstructable or not logger : a logger, optional """ self.logger = LoggerOrPrint(logger) self.require_redundancy = require_redundancy self.z_pix_per_proj = z_pix_per_proj self.x_pix_per_proj = x_pix_per_proj self.detector_shape_vh = detector_shape_vh self.total_num_images = len(self.z_pix_per_proj) self.phase_margin_pix = phase_margin_pix self.pix_size_mm = pixel_size_mm self.z_offset_mm = z_offset_mm self.angular_tolerance_steps = angular_tolerance_steps # internally we use increasing angles, so that in all inequalities, that are used # to check the span, only such case has to be contemplated. 
# To do so, if needed, we change the sign of the angles. if projection_angles_deg[-1] > projection_angles_deg[0]: self.projection_angles_deg_internally = projection_angles_deg self.angles_are_inverted = False else: self.projection_angles_deg_internally = -projection_angles_deg self.angles_are_inverted = True self.projection_angles_deg = projection_angles_deg if ( len(self.x_pix_per_proj) != self.total_num_images or len(self.projection_angles_deg_internally) != self.total_num_images ): message = f""" all the arguments z_pix_per_proj, x_pix_per_proj and projection_angles_deg must have the same lenght but their lenght were {len(self.z_pix_per_proj) }, {len(self.x_pix_per_proj) }, {len(self.projection_angles_deg_internally) } respectively """ raise ValueError(message) ## informations to be built are initialised to None here below """ For a given slice, the procedure for obtaining the useful angles, is based on the "sunshine" image, The sunshine image has two dimensions, the second one is the projection number while the first runs over the heights of the slices. All the logic is based on this image: when a given pixel of this image is zero, this corresponds to a pair (height,projection) for which there is no contribution of that projection to that slice. """ self.sunshine_image = None """ This will be an array of integer heights. We use a one-to-one correspondance beween these integers and slices in the reconstructed volume. The value of a given item is the vertical coordinate of an horizontal line in in the dectector (or above or below (negative values) ). More precisely the line over which the corresponding slice gets projected at the beginning of the scan ( in other words for the first vertical translation entry of the z_pix_per_proj argument) Illustratively, a view height equal to zero corresponds to a slice which projects on row 0 when translation is given by the first value in self.z_pix_per_proj. Negative integers for the heights are possible too, according to the direction of translations, which may bring above or below the detector. ( Illustratively the values in z_pix_per_proj are alway positive for the simple fact that the sample is above the roto-translation stage and the stage must be lowered in height for the beam to hit the sample. A scan starts always from a positive z_pix translation but the z_pix value may either decrease or increase. In fact for practical reasons, after having done a scan in one direction it is convenient to scan also while coming back after a previous scan) """ self.total_view_heights = None """ A list wich will contain for every reconstructable heigth, listed in self.total_view_heights, and in the same order, a pair of two integer, the first is the first sequential number of the projection for which the height is projected inside the detector, the second integer is the last projection number for which the projection occurs inside the detector. """ self.on_detector_projection_intervals = None """ All like self.on_detector_projection_intervals, but considering also the margins (phase, unsharp, etc). so that the height is projected inside the detector but while keeping a safe phase_margin_pix distance from the upper and lower border of the detector. 
""" self.within_margin_projection_intervals = None """ This will contain projection number i_pro, the integer value given by ceil(z_pix_per_proj[i_pro]) This array, together with the here below array self.fract_complement_to_integer_shift_v will be used for cropping when the data are collected for a given to be reconstructed chunk. """ self.integer_shift_v = np.zeros([self.total_num_images], "i") """ The fractional vertical shifts are positive floats < 1.0 pixels. This is the fractional part which, added to self.integer_shift_v, gives back z_pix_per_proj. Together with integer_shift_v, this array is meant to be used, by other modules, for cropping when the data are collected for a given to be reconstructed chunk.""" self.fract_complement_to_integer_shift_v = np.zeros([self.total_num_images], "f") self._setup_ephemerides() self._setup_sunshine() def get_doable_span(self): """return an object with two properties: view_heights_minmax: containining minimum and maximum doable height ( detector reference at iproj=0) z_pix_minmax : containing minimum and maximum heights above the roto-translational sample stage """ vertical_profile = self.sunshine_image.sum(axis=1) doable_indexes = np.arange(len(vertical_profile))[vertical_profile > 0] vertical_span = doable_indexes.min(), doable_indexes.max() if not (vertical_profile[vertical_span[0] : vertical_span[1] + 1] > 0).all(): message = """ Something wrong occurred in the span preprocessing. It appears that some intermetiade slices are not doables. The doable span should instead be contiguous. Please signal the problem""" raise RuntimeError(message) view_heights_minmax = self.total_view_heights[list(vertical_span)] hmin, hmax = view_heights_minmax d_v, d_h = self.detector_shape_vh z_min, z_max = (-self.z_pix_per_proj[0] + (d_v - 1) / 2 - hmax, -self.z_pix_per_proj[0] + (d_v - 1) / 2 - hmin) res = { "view_heights_minmax": view_heights_minmax, "z_pix_minmax": (z_min, z_max), "z_mm_minmax": (z_min * self.pix_size_mm, z_max * self.pix_size_mm), } return DictToObj(res) def get_informative_string(self): doable_span_v = self.get_doable_span() if self.z_pix_per_proj[-1] > self.z_pix_per_proj[-1]: direction = "ascending" else: direction = "descending" s = f""" Doable vertical span -------------------- The scan has been performed with an {direction} vertical translation of the rotation axis. The detector vertical axis is up side down. Detector reference system at iproj=0: from vertical view height ... {doable_span_v.view_heights_minmax[0]} up to (included) ... {doable_span_v.view_heights_minmax[1]} The slice that projects to the first line of the first projection corresponds to vertical heigth = 0 In voxels, the vertical doable span measures: {doable_span_v.z_pix_minmax[1] - doable_span_v.z_pix_minmax[0]} And in millimiters above the stage: from vertical height above stage ( mm units) ... {doable_span_v.z_mm_minmax[0] - self.z_offset_mm } up to (included) ... 
{doable_span_v.z_mm_minmax[1] - self.z_offset_mm } """ return s def get_chunk_info(self, span_v_absolute): """ This method returns an object containing all the informations that are needed to reconstruct the corresponding chunk angle_index_span: a pair of integers indicating the start and the end of useful angles in the array of all the scan angle self.projection_angles_deg span_v: a pair of two integers indicating the start and end of the span relatively to the lowest value of array self.total_view_heights integer_shift_v: an array, containing for each one of the useful projections of the span, the integer part of vertical shift to be used in cropping, fract_complement_to_integer_shift_v : the fractional remainder for cropping. z_pix_per_proj: an array, containing for each to be used projection of the span the vertical shift x_pix_per_proj: ....the horizontal shit angles_rad : an array, for each useful projection of the chunk the angle in radian Parameters: ----------- span_v_absolute: tuple of integers a pair of two integers the first view height ( referred to the detector y axis at iproj=0) the second view height with the first height smaller than the second. """ span_v = (span_v_absolute[0] - self.total_view_heights[0], span_v_absolute[1] - self.total_view_heights[0]) sunshine_subset = self.sunshine_image[span_v[0] : span_v[1]] angular_profile = sunshine_subset.sum(axis=0) angle_indexes = np.arange(len(self.projection_angles_deg_internally))[angular_profile > 0] angle_index_span = angle_indexes.min(), angle_indexes.max() + 1 if not (np.less(0, angular_profile[angle_index_span[0] : angle_index_span[1]]).all()): message = """ Something wrong occurred in the span preprocessing. It appears that some intermediate slices are not doables. The doable span should instead be contiguous. Please signal the problem""" raise RuntimeError(message) chunk_angles_deg = self.projection_angles_deg[angle_indexes] my_slicer = slice(angle_index_span[0], angle_index_span[1]) values = ( angle_index_span, span_v_absolute, self.integer_shift_v[my_slicer], self.fract_complement_to_integer_shift_v[my_slicer], self.z_pix_per_proj[my_slicer], self.x_pix_per_proj[my_slicer], np.deg2rad(chunk_angles_deg) * (1 - 2 * int(self.angles_are_inverted)), ) key_names = ( "angle_index_span", "span_v", "integer_shift_v", "fract_complement_to_integer_shift_v", "z_pix_per_proj", "x_pix_per_proj", "angles_rad", ) return DictToObj(dict(zip(key_names, values))) def _setup_ephemerides(self): """ A function which will set : * self.integer_shift_v * self.fract_complement_to_integer_shift_v * self.total_view_heights * self.on_detector_projection_intervals * self.within_margin_projection_intervals """ for i_pro in range(self.total_num_images): trans_v = self.z_pix_per_proj[i_pro] self.integer_shift_v[i_pro] = math.ceil(trans_v) self.fract_complement_to_integer_shift_v[i_pro] = math.ceil(trans_v) - trans_v ## The two following line initialize the view height, then considering the vertical translation # the filed of view will be expanded total_view_top = self.detector_shape_vh[0] total_view_bottom = 0 total_view_top = max(total_view_top, int(math.ceil(total_view_top + self.z_pix_per_proj.max()))) total_view_bottom = min(total_view_bottom, int(math.floor(total_view_bottom + self.z_pix_per_proj.min()))) self.total_view_heights = np.arange(total_view_bottom, total_view_top + 1) ## where possible only data from within safe phase margin will be considered. (within_margin) ## if it is enough for 360 or more degree. 
## If it is not enough we'll complete wih data close to the border, provided that data comes from within detector ## This will be ## the case for the first and last doable slices. self.within_margin_projection_intervals = np.zeros(self.total_view_heights.shape + (2,), "i") self.on_detector_projection_intervals = np.zeros(self.total_view_heights.shape + (2,), "i") self.on_detector_projection_intervals[:, 1] = 0 # empty intervals self.within_margin_projection_intervals[:, 1] = 0 for i_h, height in enumerate(self.total_view_heights): previous_is_inside_detector = False previous_is_inside_margin = False pos_inside_filtered_v = height - self.integer_shift_v is_inside_detector = np.less_equal(0, pos_inside_filtered_v) is_inside_detector *= np.less( pos_inside_filtered_v, self.detector_shape_vh[0] - np.less(0, self.fract_complement_to_integer_shift_v) ) is_inside_margin = np.less_equal(self.phase_margin_pix, pos_inside_filtered_v) is_inside_margin *= np.less( pos_inside_filtered_v, self.detector_shape_vh[0] - self.phase_margin_pix - np.less(0, self.fract_complement_to_integer_shift_v), ) tmp = np.arange(self.total_num_images)[is_inside_detector] if len(tmp): self.on_detector_projection_intervals[i_h, :] = (tmp.min(), tmp.max()) tmp = np.arange(self.total_num_images)[is_inside_detector] if len(tmp): self.within_margin_projection_intervals[i_h, :] = (tmp.min(), tmp.max()) def _setup_sunshine(self): """It prepares * self.sunshine_image an image which for every height, contained in self.total_view_heights, and in the same order, contains the list of factors, one per every projection of the total list. Each factor is a backprojection weight. A non-doable height corresponds to a line full of zeros. A doable height must correspond to a line having one and only one segment of contiguous non zero elements. 
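        As an illustration of the weighting (numbers are indicative): for a
        height that stays illuminated over roughly one and a half turns, the
        angles covered twice accumulate a weight of 2 in the collector and thus
        end up with a factor of about 1/2 in sunshine_image, while angles
        covered only once keep a factor of 1; heights whose angular coverage
        does not reach a full turn are left entirely at zero.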
""" self.sunshine_image = np.zeros(self.total_view_heights.shape + (self.total_num_images,), "f") avg_angular_step_deg = np.diff(self.projection_angles_deg_internally).mean() projection_angles_for_interp = np.concatenate( [ [self.projection_angles_deg_internally[0] - avg_angular_step_deg], self.projection_angles_deg_internally, [self.projection_angles_deg_internally[-1] + avg_angular_step_deg], ] ) data_container_for_interp = np.zeros_like(projection_angles_for_interp) num_angular_periods = math.ceil( (self.projection_angles_deg_internally.max() - self.projection_angles_deg_internally.min()) / 360 ) for i_h, height in enumerate(self.total_view_heights): first_last_on_dect = self.on_detector_projection_intervals[i_h] first_last_within_margin = self.within_margin_projection_intervals[i_h] if first_last_on_dect[1] == 0: # this line never entered in the fov continue angle_on_dect_first_last = self.projection_angles_deg_internally[first_last_on_dect] angle_within_margin_first_last = self.projection_angles_deg_internally[first_last_within_margin] # a mask which is positive for angular positions for which the height i_h gets projected within the detector mask_on_dect = ( np.less_equal(angle_on_dect_first_last[0], self.projection_angles_deg_internally) * np.less_equal(self.projection_angles_deg_internally, angle_on_dect_first_last[1]) ).astype("f") # a mask which is positive for angular positions for which the height i_h gets projected within the margins mask_within_margin = ( np.less_equal(angle_within_margin_first_last[0], self.projection_angles_deg_internally) * np.less_equal(self.projection_angles_deg_internally, angle_within_margin_first_last[1]) ).astype("f") ## create a line which collects contributions from redundant angles detector_collector = np.zeros(self.projection_angles_deg_internally.shape, "f") margin_collector = np.zeros(self.projection_angles_deg_internally.shape, "f") # the following loop tracks, for each projection, the total weight available at the projection angle # The additional weight is coming from redundant angles. # In this sense the sunshine_image implements a first rudimentary reweighting, # which could be in principle used in the full pipeline, althought the regridded pipeline # implements a, better, reweighting o its own. 
for i_shift in range(-num_angular_periods, num_angular_periods + 1): signus = ( 1 if (self.projection_angles_deg_internally[-1] > self.projection_angles_deg_internally[0]) else -1 ) data_container_for_interp[1:-1] = mask_on_dect detector_collector = detector_collector + mask_on_dect * np.interp( (self.projection_angles_deg_internally + i_shift * 360) * signus, signus * projection_angles_for_interp, data_container_for_interp, left=0, right=0, ) data_container_for_interp[1:-1] = mask_within_margin margin_collector = margin_collector + mask_within_margin * np.interp( (self.projection_angles_deg_internally + i_shift * 360) * signus, signus * projection_angles_for_interp, data_container_for_interp, left=0, right=0, ) detector_shined_angles = self.projection_angles_deg_internally[detector_collector > 0.99] margin_shined_angles = self.projection_angles_deg_internally[margin_collector > 0.99] if not len(detector_shined_angles) > 1: continue avg_step_deg = abs(avg_angular_step_deg) if len(margin_shined_angles): angular_span_safe_margin = ( margin_shined_angles.max() - margin_shined_angles.min() + avg_step_deg * (1.01 + self.angular_tolerance_steps) ) else: angular_span_safe_margin = 0 angular_span_bare_border = ( detector_shined_angles.max() - detector_shined_angles.min() + avg_step_deg * (1.01 + self.angular_tolerance_steps) ) if not self.require_redundancy: if angular_span_safe_margin >= 360: self.sunshine_image[i_h] = margin_collector elif angular_span_bare_border >= 360: self.sunshine_image[i_h] = detector_collector else: redundancy_angle_deg = 360 if angular_span_safe_margin >= 360 and angular_span_safe_margin > 2 * ( redundancy_angle_deg + avg_step_deg ): self.sunshine_image[i_h] = margin_collector elif angular_span_bare_border >= 360 and angular_span_bare_border > 2 * ( redundancy_angle_deg + avg_step_deg ): self.sunshine_image[i_h] = detector_collector sunshine_mask = np.less(0.99, self.sunshine_image) self.sunshine_image[np.array([True]) ^ sunshine_mask] = 1.0 self.sunshine_image[:] = 1 / self.sunshine_image self.sunshine_image[np.array([True]) ^ sunshine_mask] = 0.0 shp = self.sunshine_image.shape X, Y = np.meshgrid(np.arange(shp[1]), np.arange(shp[0])) condition = self.sunshine_image > 0 self.sunshine_starts = X.min(axis=1, initial=shp[1], where=condition) self.sunshine_ends = X.max(axis=1, initial=0, where=condition) self.sunshine_ends[self.sunshine_ends > 0] += 1 ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4647331 nabu-2023.1.1/nabu/pipeline/helical/tests/0000755000175000017500000000000000000000000017521 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1656056444.0 nabu-2023.1.1/nabu/pipeline/helical/tests/__init__.py0000644000175000017500000000000000000000000021620 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/pipeline/helical/tests/test_accumulator.py0000644000175000017500000001401700000000000023454 0ustar00pierrepierrefrom nabu.pipeline.helical import gridded_accumulator, span_strategy from nabu.testutils import get_data import pytest import numpy as np import os import h5py @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls # This is a helical dataset derived # from "crayon" dataset, using 5 slices and covering 2.5 x 360 angular span # in halftomo, with vertical translations. 
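    # The .npz bundle provides (see dataset_keys below): the transposed radios, dark
    # and flats, per-projection vertical/horizontal shifts, projection angles, pixel
    # size, phase margin, a weights field, a double flat-field, the rotation axis
    # position, the detector shape and a reference result inset.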
helical_dataset = get_data("small_sparse_helical_dataset.npz") helical_dataset = dict([item for item in helical_dataset.items()]) # the radios, in the dataset file, are stored by swapping angular and x dimension # so that the fast running dimension runs over the projections. # Due to the sparsity of the dataset, where only an handful of slices # has signal, this gives a much better compression even when the axis is translating # vertically helical_dataset["radios"] = np.array(np.swapaxes(helical_dataset["radios_transposed"], 0, 2)) del helical_dataset["radios_transposed"] # adding members: radios, dark, flats, z_pix_per_proj, x_pix_per_proj, projection_angles_deg, pixel_size_mm, phase_margin_pix, # weigth_field=weigth_field, double_flat dataset_keys = [ "dark", "flats", "z_pix_per_proj", "x_pix_per_proj", "projection_angles_deg", "pixel_size_mm", "phase_margin_pix", "weights_field", "double_flat", "rotation_axis_position", "detector_shape_vh", "result_inset", "radios", ] for key in dataset_keys: setattr(cls, key, helical_dataset[key]) cls.rtol_regridded = 1.0e-6 @pytest.mark.usefixtures("bootstrap") class TestGriddedAccumulator: """ Test the GriddedAccumulator. Rebuilds the sinogram for some selected slices of the crayon dataset """ def test_regridding(self): span_info = span_strategy.SpanStrategy( z_pix_per_proj=self.z_pix_per_proj, x_pix_per_proj=self.x_pix_per_proj, detector_shape_vh=self.detector_shape_vh, phase_margin_pix=self.phase_margin_pix, projection_angles_deg=self.projection_angles_deg, require_redundancy=True, pixel_size_mm=self.pixel_size_mm, logger=None, ) # I would like to reconstruct from feaseable height 15 to feaseable height 18 # relatively to the first doable slice in the vertical translation direction # I get the heights in the detector frame of the first and of the last reconstruction_space = gridded_accumulator.get_reconstruction_space( span_info=span_info, min_scanwise_z=15, end_scanwise_z=18, phase_margin_pix=self.phase_margin_pix ) chunk_info = span_info.get_chunk_info((reconstruction_space.my_z_min, reconstruction_space.my_z_end)) sub_region = ( reconstruction_space.my_z_min - self.phase_margin_pix, reconstruction_space.my_z_end + self.phase_margin_pix, ) ## useful projections proj_num_start, proj_num_end = chunk_info.angle_index_span # the first of the chunk angular range my_first_pnum = proj_num_start self.accumulator = gridded_accumulator.GriddedAccumulator( gridded_radios=reconstruction_space.gridded_radios, gridded_weights=reconstruction_space.gridded_cumulated_weights, diagnostic_radios=reconstruction_space.diagnostic_radios, diagnostic_weights=reconstruction_space.diagnostic_weights, diagnostic_angles=reconstruction_space.diagnostic_proj_angle, dark=self.dark, flat_indexes=[0, 7501], flats=self.flats, weights=self.weights_field, double_flat=self.double_flat, ) # splitting in sub ranges of 100 projections n_granularity = 100 pnum_start_list = list(np.arange(proj_num_start, proj_num_end, n_granularity)) pnum_end_list = pnum_start_list[1:] + [proj_num_end] for pnum_start, pnum_end in zip(pnum_start_list, pnum_end_list): start_in_chunk = pnum_start - my_first_pnum end_in_chunk = pnum_end - my_first_pnum self._read_data_and_apply_flats( slice(pnum_start, pnum_end), slice(start_in_chunk, end_in_chunk), chunk_info, sub_region, span_info ) res = reconstruction_space.gridded_radios / reconstruction_space.gridded_cumulated_weights # check only a sub part to avoid further increasing of the file on edna site errmax = np.max(np.abs(res[:200, 1, -500:] - 
self.result_inset) / np.max(res)) assert errmax < self.rtol_regridded, "Max error is too high" # uncomment this to see # h5py.File("processed_sinogram.h5","w")["sinogram"] = res def _read_data_and_apply_flats(self, sub_total_prange_slice, subchunk_slice, chunk_info, sub_region, span_info): my_integer_shifts_v = chunk_info.integer_shift_v[subchunk_slice] fract_complement_shifts_v = chunk_info.fract_complement_to_integer_shift_v[subchunk_slice] x_shifts_list = chunk_info.x_pix_per_proj[subchunk_slice] subr_start_z, subr_end_z = sub_region subr_start_z_list = subr_start_z - my_integer_shifts_v subr_end_z_list = subr_end_z - my_integer_shifts_v + 1 dtasrc_start_z = max(0, subr_start_z_list.min()) dtasrc_end_z = min(span_info.detector_shape_vh[0], subr_end_z_list.max()) data_raw = self.radios[sub_total_prange_slice, slice(dtasrc_start_z, dtasrc_end_z), :] subsampling_file_slice = sub_total_prange_slice # my_subsampled_indexes = self.chunk_reader._sorted_files_indices[subsampling_file_slice] my_subsampled_indexes = (np.arange(10000))[subsampling_file_slice] self.accumulator.extract_preprocess_with_flats( subchunk_slice, my_subsampled_indexes, chunk_info, np.array((subr_start_z, subr_end_z), "i"), np.array((dtasrc_start_z, dtasrc_end_z), "i"), data_raw, ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678380095.0 nabu-2023.1.1/nabu/pipeline/helical/tests/test_pipeline_elements_full.py0000644000175000017500000003435100000000000025663 0ustar00pierrepierrefrom nabu.pipeline.helical import gridded_accumulator, span_strategy from nabu.testutils import get_data, __do_long_tests__ import os import numpy as np import pytest from nabu.preproc.ccd import Log, CCDFilter from nabu.preproc.phase import PaganinPhaseRetrieval from nabu.cuda.utils import get_cuda_context, __has_pycuda__ from nabu.pipeline.helical import gridded_accumulator, span_strategy from nabu.pipeline.helical.weight_balancer import WeightBalancer from nabu.pipeline.helical.helical_utils import find_mirror_indexes from nabu.cuda.utils import get_cuda_context, __has_pycuda__, __pycuda_error_msg__, replace_array_memory if __has_pycuda__: import pycuda.gpuarray as garray from nabu.pipeline.helical.fbp import BackprojectorHelical as FBPClass @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls # This is a helical dataset derived # from "crayon" dataset, using 5 slices and covering 2.5 x 360 angular span # in halftomo, with vertical translations. # >>> d=load("small_sparse_helical_dataset.npz") # dd=dict( d.items() ) # >>> dd["median_clip_threshold"] = 0.04 # >>> savez("retouched_test.npz",**dd) helical_dataset = get_data("small_sparse_helical_dataset.npz") helical_dataset = dict(list(helical_dataset.items())) # the radios, in the dataset file, are stored by swapping angular and x dimension # so that the fast running dimension runs over the projections. 
# Due to the sparsity of the dataset, where only an handful of slices # has signal, this gives a much better compression even when the axis is translating # vertically helical_dataset["radios"] = np.array(np.swapaxes(helical_dataset["radios_transposed"], 0, 2)) del helical_dataset["radios_transposed"] # adding members: radios, dark, flats, z_pix_per_proj, x_pix_per_proj, projection_angles_deg, # pixel_size_mm, phase_margin_pix, # weigth_field=weigth_field, double_flat dataset_keys = [ "dark", "flats", "z_pix_per_proj", "x_pix_per_proj", "projection_angles_deg", "pixel_size_mm", "phase_margin_pix", "weights_field", "double_flat", "rotation_axis_position", "detector_shape_vh", "result_inset", "radios", # further parameters, added on top of the test data which was originally made for gridded_accumulator "median_clip_threshold", "distance_m", "energy_kev", "delta_beta", "pixel_size_m", "padding_type", "phase_margin_for_pag", "rec_reference", "ref_tol", "ref_start", "ref_end", ] # the test dataset is the original one from the accumulator test # plus some patched metadeta information for phase retrieval. # The original dataset had phase_margin_pix and had 3 usefule slices. # Here, to test phase retrieval, we redefine the phase margin to the maximum # that we can do with such a small dataset helical_dataset["phase_margin_pix"] = helical_dataset["phase_margin_for_pag"] for key in dataset_keys: setattr(cls, key, helical_dataset[key]) cls.padding_type = str(cls.padding_type) cls.rotation_axis_position = float(cls.rotation_axis_position) cls.rtol_regridded = 1.0e-6 cls.projection_angles_rad = np.rad2deg(cls.projection_angles_deg) @pytest.mark.skipif(not (__do_long_tests__), reason="need environment variable NABU_LONG_TESTS=1") @pytest.mark.skipif(not (__has_pycuda__), reason="Needs pycuda for this test") @pytest.mark.usefixtures("bootstrap") class TestGriddedAccumulator: """ Test the GriddedAccumulator. 
Rebuilds the sinogram for some selected slices of the crayon dataset """ def test_regridding(self): span_info = span_strategy.SpanStrategy( z_pix_per_proj=self.z_pix_per_proj, x_pix_per_proj=self.x_pix_per_proj, detector_shape_vh=self.detector_shape_vh, phase_margin_pix=self.phase_margin_pix, projection_angles_deg=self.projection_angles_deg, require_redundancy=True, pixel_size_mm=self.pixel_size_mm, logger=None, ) # I would like to reconstruct from feaseable height 15 to feaseable height 18 # relatively to the first doable slice in the vertical translation direction # I get the heights in the detector frame of the first and of the last self.reconstruction_space = gridded_accumulator.get_reconstruction_space( span_info=span_info, min_scanwise_z=15, end_scanwise_z=18, phase_margin_pix=self.phase_margin_pix ) chunk_info = span_info.get_chunk_info((self.reconstruction_space.my_z_min, self.reconstruction_space.my_z_end)) sub_region = ( self.reconstruction_space.my_z_min - self.phase_margin_pix, self.reconstruction_space.my_z_end + self.phase_margin_pix, ) # useful projections proj_num_start, proj_num_end = chunk_info.angle_index_span # the first of the chunk angular range my_first_pnum = proj_num_start self.accumulator = gridded_accumulator.GriddedAccumulator( gridded_radios=self.reconstruction_space.gridded_radios, gridded_weights=self.reconstruction_space.gridded_cumulated_weights, diagnostic_radios=self.reconstruction_space.diagnostic_radios, diagnostic_weights=self.reconstruction_space.diagnostic_weights, diagnostic_angles=self.reconstruction_space.diagnostic_proj_angle, dark=self.dark, flat_indexes=[0, 7501], flats=self.flats, weights=self.weights_field, double_flat=self.double_flat, ) # splitting in sub ranges of 100 projections n_granularity = 100 pnum_start_list = list(np.arange(proj_num_start, proj_num_end, n_granularity)) pnum_end_list = pnum_start_list[1:] + [proj_num_end] for pnum_start, pnum_end in zip(pnum_start_list, pnum_end_list): start_in_chunk = pnum_start - my_first_pnum end_in_chunk = pnum_end - my_first_pnum self._read_data_and_apply_flats( slice(pnum_start, pnum_end), slice(start_in_chunk, end_in_chunk), chunk_info, sub_region, span_info ) res_flatfielded = self.reconstruction_space.gridded_radios / self.reconstruction_space.gridded_cumulated_weights # but in real pipeline the radio_shape is obtained from the pipeline get_shape utility method self._init_ccd_corrections(res_flatfielded.shape[1:]) # but in the actual pipeline the argument is not given, and the processed stack is the one internally # kept by the pipeline object ( self.gridded_radios in the pipeline ) self._ccd_corrections(res_flatfielded) self._init_phase(res_flatfielded.shape[1:]) processed_radios = self._retrieve_phase(res_flatfielded) self._init_mlog(processed_radios.shape) self._take_log(processed_radios) top_margin = -self.phase_margin_pix if self.phase_margin_pix else None processed_weights = self.reconstruction_space.gridded_cumulated_weights[ :, self.phase_margin_pix : top_margin, : ] self._init_weight_balancer() self._balance_weights(processed_weights) self._init_reconstructor(processed_radios.shape) i_slice = 0 self.d_radios_slim.set(processed_radios[:, i_slice, :]) self._filter() self._apply_weights(i_slice, processed_weights) res = self._reconstruct() test_slicing = slice(self.ref_start, self.ref_end) tested_inset = res[test_slicing, test_slicing] assert np.max(np.abs(tested_inset - self.rec_reference)) < self.ref_tol # uncomment the four following lines to get the slice image # import fabio # 
edf = fabio.edfimage.edfimage() # edf.data = res # edf.write("reconstructed_slice.edf") # put the test here def _reconstruct(self): axis_corrections = np.zeros_like(self.reconstruction_space.gridded_angles_rad) self.reconstruction.set_custom_angles_and_axis_corrections( self.reconstruction_space.gridded_angles_rad, axis_corrections ) self.reconstruction.backprojection(self.d_radios_slim, output=self.d_rec_res) self.d_rec_res.get(self.rec_res) return self.rec_res def _apply_weights(self, i_slice, weights): """d_radios_slim is on gpu""" n_provided_angles = self.d_radios_slim.shape[0] for first_angle_index in range(0, n_provided_angles, self.num_weight_radios_per_app): end_angle_index = min(n_provided_angles, first_angle_index + self.num_weight_radios_per_app) self._d_radios_weights[: end_angle_index - first_angle_index].set( weights[first_angle_index:end_angle_index, i_slice] ) self.d_radios_slim[first_angle_index:end_angle_index] *= self._d_radios_weights[ : end_angle_index - first_angle_index ] def _filter(self): self.mirror_angle_relative_indexes = find_mirror_indexes(self.reconstruction_space.gridded_angles_deg) self.reconstruction.sino_filter.filter_sino( self.d_radios_slim, mirror_indexes=self.mirror_angle_relative_indexes, rot_center=self.rotation_axis_position, output=self.d_radios_slim, ) def _init_reconstructor(self, processed_radios_shape): one_slice_data_shape = processed_radios_shape[:1] + processed_radios_shape[2:] self.d_radios_slim = garray.zeros(one_slice_data_shape, np.float32) # let's make room for loading chunck of weights without necessarily doubling the memory footprint. # The weights will be used to multiplied the d_radios_slim. # We will proceed by bunches self.num_weight_radios_per_app = 200 self._d_radios_weights = garray.zeros((self.num_weight_radios_per_app,) + one_slice_data_shape[1:], np.float32) pixel_size_cm = self.pixel_size_m * 100 radio_size_h = processed_radios_shape[-1] assert ( 2 * self.rotation_axis_position > radio_size_h ), """The code of this test is adapted for HA on the right. 
This seems to be a case of HA on the left because self.rotation_axis_position={self.rotation_axis_position} and radio_size_h = {radio_size_h} """ rec_dim = int(round(2 * self.rotation_axis_position)) self.reconstruction = FBPClass( one_slice_data_shape, angles=np.zeros(processed_radios_shape[0], "f"), rot_center=self.rotation_axis_position, filter_name=None, slice_roi=(0, rec_dim, 0, rec_dim), extra_options={ "scale_factor": 1.0 / pixel_size_cm, "axis_correction": np.zeros(processed_radios_shape[0], "f"), "padding_mode": "edge", }, ) self.reconstruction.fbp = self.reconstruction.backproj self.d_rec_res = garray.zeros((rec_dim, rec_dim), np.float32) self.rec_res = np.zeros((rec_dim, rec_dim), np.float32) def _init_weight_balancer(self): self.weight_balancer = WeightBalancer(self.rotation_axis_position, self.reconstruction_space.gridded_angles_rad) def _balance_weights(self, weights): self.weight_balancer.balance_weights(weights) def _retrieve_phase(self, radios): processed_radios = np.zeros( (radios.shape[0],) + (radios.shape[1] - 2 * self.phase_margin_pix,) + (radios.shape[2],), radios.dtype ) for i in range(radios.shape[0]): processed_radios[i] = self.phase_retrieval.apply_filter(radios[i]) return processed_radios def _read_data_and_apply_flats(self, sub_total_prange_slice, subchunk_slice, chunk_info, sub_region, span_info): my_integer_shifts_v = chunk_info.integer_shift_v[subchunk_slice] subr_start_z, subr_end_z = sub_region subr_start_z_list = subr_start_z - my_integer_shifts_v subr_end_z_list = subr_end_z - my_integer_shifts_v + 1 dtasrc_start_z = max(0, subr_start_z_list.min()) dtasrc_end_z = min(span_info.detector_shape_vh[0], subr_end_z_list.max()) data_raw = self.radios[sub_total_prange_slice, slice(dtasrc_start_z, dtasrc_end_z), :] subsampling_file_slice = sub_total_prange_slice # my_subsampled_indexes = self.chunk_reader._sorted_files_indices[subsampling_file_slice] my_subsampled_indexes = (np.arange(10000))[subsampling_file_slice] self.accumulator.extract_preprocess_with_flats( subchunk_slice, my_subsampled_indexes, chunk_info, np.array((subr_start_z, subr_end_z), "i"), np.array((dtasrc_start_z, dtasrc_end_z), "i"), data_raw, ) def _init_ccd_corrections(self, radio_shape): # but in real pipeline the radio_shape is obtained from the pipeline get_shape utility method self.ccd_correction = CCDFilter(radio_shape, median_clip_thresh=self.median_clip_threshold) def _ccd_corrections(self, radios): _tmp_radio = np.empty_like(radios[0]) for i in range(radios.shape[0]): self.ccd_correction.median_clip_correction(radios[i], output=_tmp_radio) radios[i][:] = _tmp_radio[:] def _take_log(self, radios): self.mlog.take_logarithm(radios) def _init_mlog(self, radios_shape): log_shape = radios_shape clip_min = 1.0e-6 clip_max = 1.1 self.mlog = Log(log_shape, clip_min=clip_min, clip_max=clip_max) def _init_phase(self, raw_shape): self.phase_retrieval = PaganinPhaseRetrieval( raw_shape, distance=self.distance_m, energy=self.energy_kev, delta_beta=self.delta_beta, pixel_size=self.pixel_size_m, padding=self.padding_type, margin=((self.phase_margin_pix,) * 2, (0, 0)), use_R2C=True, fftw_num_threads=True, # TODO tune in advanced params of nabu config file ) if self.phase_retrieval.use_fftw: self.logger.debug( "PaganinPhaseRetrieval using FFTW with %d threads" % self.phase_retrieval.fftw.num_threads ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/pipeline/helical/tests/test_strategy.py0000644000175000017500000000524100000000000022776 
0ustar00pierrepierreimport pytest import numpy as np from nabu.testutils import get_data as nabu_get_data from nabu.pipeline.helical.span_strategy import SpanStrategy @pytest.fixture(scope="class") def bootstrap_TestStrategy(request): cls = request.cls cls.abs_tol = 1.0e-6 # from the Paul telephone dataset test_data = nabu_get_data("data_test_strategy.npz") cls.z_pix_per_proj = test_data["z_pix_per_proj"] cls.x_pix_per_proj = test_data["x_pix_per_proj"] cls.detector_shape_vh = test_data["detector_shape_vh"] cls.phase_margin_pix = test_data["phase_margin_pix"] cls.projection_angles_deg = test_data["projection_angles_deg"] cls.require_redundancy = test_data["require_redundancy"] cls.pixel_size_mm = test_data["pixel_size_mm"] cls.result_angle_index_span = test_data["result_angle_index_span"] cls.result_angles_rad = test_data["result_angles_rad"] cls.result_fract_complement_to_integer_shift_v = test_data["result_fract_complement_to_integer_shift_v"] cls.result_integer_shift_v = test_data["result_integer_shift_v"] cls.result_span_v = test_data["result_span_v"] cls.result_x_pix_per_proj = test_data["result_x_pix_per_proj"] cls.result_z_pix_per_proj = test_data["result_z_pix_per_proj"] cls.test_data = test_data @pytest.mark.usefixtures("bootstrap_TestStrategy") class TestStrategy: def test_strategy(self): # the python implementation is slow. so we take only a p[art of the scan limit = 4000 span_info = SpanStrategy( z_pix_per_proj=self.z_pix_per_proj[:limit], x_pix_per_proj=self.x_pix_per_proj[:limit], detector_shape_vh=self.detector_shape_vh, phase_margin_pix=self.phase_margin_pix, projection_angles_deg=self.projection_angles_deg[:limit], pixel_size_mm=self.pixel_size_mm, require_redundancy=self.require_redundancy, ) print(span_info.get_informative_string()) chunk_info = span_info.get_chunk_info(self.result_span_v) for key, val in chunk_info.__dict__.items(): reference = getattr(self, "result_" + key) ref_array = np.array(reference) res_array = np.array(val) if res_array.dtype in [bool, np.int32, np.int64]: message = f" different result for {key} attribute in the chunk_info returned value " assert np.array_equal(res_array, ref_array), message elif res_array.dtype in [np.float32, np.float64]: message = f" different result for {key} attribute in the chunk_info returned value " assert np.all(np.isclose(res_array, ref_array, atol=self.abs_tol)), message ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/pipeline/helical/utils.py0000644000175000017500000000302300000000000020067 0ustar00pierrepierrefrom ...utils import * from ...io.writer import Writers, NXProcessWriter from ...io.tiffwriter_zmm import TIFFWriter from ...resources.logger import LoggerOrPrint from ...resources.utils import is_hdf5_extension from os import path, mkdir from ...utils import check_supported from ..fallback_utils import WriterConfigurator from ..params import files_formats Writers["tif"] = TIFFWriter Writers["tiff"] = TIFFWriter class WriterConfiguratorHelical(WriterConfigurator): def _get_initial_writer_kwarg(self): if self.file_format in ["tif", "tiff"]: return {"heights_above_stage_mm": self.heights_above_stage_mm} else: return {} def __init__( self, output_dir, file_prefix, file_format="hdf5", overwrite=False, start_index=None, logger=None, nx_info=None, write_histogram=False, histogram_entry="entry", writer_options=None, extra_options=None, heights_above_stage_mm=None, ): self.heights_above_stage_mm = heights_above_stage_mm self.file_format = file_format 
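# Note: 'heights_above_stage_mm' and 'file_format' are stored *before* calling the
# parent constructor, presumably because the base WriterConfigurator consults
# _get_initial_writer_kwarg() (overridden above) while it sets up the writer.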
super().__init__( output_dir, file_prefix, file_format=file_format, overwrite=overwrite, start_index=start_index, logger=logger, nx_info=nx_info, write_histogram=write_histogram, histogram_entry=histogram_entry, writer_options=writer_options, extra_options=extra_options, ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/pipeline/helical/weight_balancer.py0000644000175000017500000000761500000000000022060 0ustar00pierrepierreimport numpy as np import math class WeightBalancer: def __init__(self, rot_center, angles_rad): """This class contains the method for rebalancing the weight prior to backprojection. The weights of halfomo redundacy data ( the central part) are rebalanced. In a pipeline, the weights rebalanced by the method balance_weight, have to be applied to the ramp-filtered data prior to backprojection. As a matter of fact the method balance_weights could be called as a function, but in order to be conformant to Nabu, we create this class and follow the scheme initialisation + application. Parameters ========== rot_center : float the center of rotation in pixel units angles_rad : the angles corresponding to the to be rebalanced projections """ self.rot_center = rot_center self.my_angles_rad = angles_rad def balance_weights(self, radios_weights): """ The parameter radios_weights is a stack having having the same weight as the stack of projections. It is modified in place, correcting the value of overlapping data, so that the sum is always one """ radios_weights[:] = self._rebalance(radios_weights) def _rebalance(self, radios_weights): """rebalance the weights, within groups of equivalent (up to multiple of 180), data pixels""" balanced = np.zeros_like(radios_weights) n_span = int(math.ceil(self.my_angles_rad[-1] - self.my_angles_rad[0]) / np.pi) center = (radios_weights.shape[-1] - 1) / 2 nloop = balanced.shape[0] for i in range(nloop): w_res = balanced[i] angle = self.my_angles_rad[i] for i_half_turn in range(-n_span - 1, n_span + 2): if i_half_turn == 0: w_res[:] += radios_weights[i] continue shifted_angle = angle + i_half_turn * np.pi insertion_index = np.searchsorted(self.my_angles_rad, shifted_angle) if insertion_index in [0, self.my_angles_rad.shape[0]]: if insertion_index == 0: if abs(self.my_angles_rad[0] - shifted_angle) > np.pi / 100: continue myimage = radios_weights[0] else: if abs(self.my_angles_rad[-1] - shifted_angle) > np.pi / 100: continue myimage = radios_weights[-1] else: partition = shifted_angle - self.my_angles_rad[insertion_index - 1] myimage = (1.0 - partition) * radios_weights[insertion_index - 1] + partition * radios_weights[ insertion_index ] if i_half_turn % 2 == 0: w_res[:] += myimage else: myimage = np.fliplr(myimage) w_res[:] += shift(myimage, (2 * self.rot_center - 2 * center)) mask = np.equal(0, radios_weights) balanced[:] = radios_weights / balanced balanced[mask] = 0 return balanced def shift(arr, shift, fill_value=0.0): """trivial horizontal shift. 
Contrarily to scipy.ndimage.interpolation.shift, this shift does not cut the tails abruptly, but by interpolation """ result = np.zeros_like(arr) num1 = int(math.floor(shift)) num2 = num1 + 1 partition = shift - num1 for num, factor in zip([num1, num2], [(1 - partition), partition]): if num > 0: result[:, :num] += fill_value * factor result[:, num:] += arr[:, :-num] * factor elif num < 0: result[:, num:] += fill_value * factor result[:, :num] += arr[:, -num:] * factor else: result[:] += arr * factor return result ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/pipeline/params.py0000644000175000017500000000533200000000000016616 0ustar00pierrepierreflatfield_modes = { "true": True, "1": True, "false": False, "0": False, "forced": "force-load", "force-load": "force-load", "force-compute": "force-compute", } phase_retrieval_methods = { "": None, "none": None, "paganin": "paganin", "tie": "paganin", "ctf": "CTF", } unsharp_methods = { "gaussian": "gaussian", "log": "log", "laplacian": "log", "imagej": "imagej", "none": None, "": None, } padding_modes = { "edges": "edge", "edge": "edge", "mirror": "mirror", "zeros": "zeros", "zero": "zeros", } reconstruction_methods = {"fbp": "FBP", "cone": "cone", "conic": "cone", "none": None, "": None} fbp_filters = { "ramlak": "ramlak", "ram-lak": "ramlak", "none": None, "": None, "shepp-logan": "shepp-logan", "cosine": "cosine", "hamming": "hamming", "hann": "hann", "tukey": "tukey", "lanczos": "lanczos", "hilbert": "hilbert", } iterative_methods = { "tv": "TV", "wavelets": "wavelets", "l2": "L2", "ls": "L2", "sirt": "SIRT", "em": "EM", } optim_algorithms = { "chambolle": "chambolle-pock", "chambollepock": "chambolle-pock", "fista": "fista", } files_formats = { "h5": "hdf5", "hdf5": "hdf5", "nexus": "hdf5", "nx": "hdf5", "npy": "npy", "npz": "npz", "tif": "tiff", "tiff": "tiff", "jp2": "jp2", "jp2k": "jp2", "j2k": "jp2", "jpeg2000": "jp2", "edf": "edf", "vol": "vol", } distribution_methods = { "local": "local", "slurm": "slurm", "": "local", "preview": "preview", } log_levels = { "0": "error", "1": "warning", "2": "info", "3": "debug", } sino_normalizations = { "none": None, "": None, "chebyshev": "chebyshev", "subtraction": "subtraction", "division": "division", } cor_methods = { "auto": "centered", "centered": "centered", "global": "global", "sliding-window": "sliding-window", "sliding window": "sliding-window", "growing-window": "growing-window", "growing window": "growing-window", "sino-coarse-to-fine": "sino-coarse-to-fine", "composite-coarse-to-fine": "composite-coarse-to-fine", "near": "composite-coarse-to-fine", } tilt_methods = { "1d-correlation": "1d-correlation", "1dcorrelation": "1d-correlation", "polarfft": "fft-polar", "polar-fft": "fft-polar", "fft-polar": "fft-polar", } rings_methods = { "none": None, "": None, "munch": "munch", } detector_distortion_correction_methods = {"none": None, "": None, "identity": "identity", "map_xz": "map_xz"} radios_rotation_mode = { "none": None, "": None, "chunk": "chunk", "chunks": "chunk", "full": "full", } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/pipeline/processconfig.py0000644000175000017500000002053300000000000020177 0ustar00pierrepierreimport os from .config import parse_nabu_config_file from ..utils import deprecation_warning, is_writeable from ..resources.logger import Logger, PrinterLogger from .config import validate_config from ..resources.dataset_analyzer import 
analyze_dataset, _tomoscan_has_nxversion from .estimators import DetectorTiltEstimator class ProcessConfigBase: """ A class for describing the Nabu process configuration. """ # Must be overriden by inheriting class default_nabu_config = None config_renamed_keys = None def __init__( self, conf_fname=None, conf_dict=None, dataset_info=None, checks=True, remove_unused_radios=True, create_logger=False, ): """ Initialize a ProcessConfig class. Parameters ---------- conf_fname: str Path to the nabu configuration file. If provided, the parameters `conf_dict` is ignored. conf_dict: dict A dictionary describing the nabu processing steps. If provided, the parameter `conf_fname` is ignored. dataset_info: DatasetAnalyzer A `DatasetAnalyzer` class instance. checks: bool, optional, default is True Whether to perform checks on configuration and datasets (recommended !) remove_unused_radios: bool, optional, default is True Whether to remove unused radios, i.e radios present in the dataset, but not explicitly listed in the scan metadata. create_logger: str or bool, optional Whether to create a Logger object. Default is False, meaning that the logger object creation is left to the user. If set to True, a Logger object is created, and logs will be written to the file "nabu_dataset_name.log". If set to a string, a Logger object is created, and the logs will be written to the file specified by this string. """ # Step (1a): create 'nabu_config' self._parse_configuration(conf_fname, conf_dict) self._create_logger(create_logger) # Step (1b): create 'dataset_info' self._browse_dataset(dataset_info) # Step (2) self._update_dataset_info_with_user_config() # Step (3): estimate tilt, CoR, ... self._dataset_estimations() # Step (4) self._coupled_validation() # Step (5) self._build_processing_steps() # Step (6) self._configure_save_steps() self._configure_resume() def _create_logger(self, create_logger): if create_logger is False: self.logger = PrinterLogger() return elif create_logger is True: dataset_loc = self.nabu_config["dataset"]["location"] dataset_fname_rel = os.path.basename(dataset_loc) if os.path.isfile(dataset_loc): logger_filename = os.path.join( os.path.abspath(os.getcwd()), os.path.splitext(dataset_fname_rel)[0] + "_nabu.log" ) else: logger_filename = os.path.join(os.path.abspath(os.getcwd()), dataset_fname_rel + "_nabu.log") elif isinstance(create_logger, str): logger_filename = create_logger else: raise ValueError("Expected bool or str for create_logger") if not is_writeable(os.path.dirname(logger_filename)): self.logger = PrinterLogger() self.logger.error("Cannot create logger file %s: no permission to write therein" % logger_filename) else: self.logger = Logger("nabu", level=self.nabu_config["pipeline"]["verbosity"], logfile=logger_filename) def _parse_configuration(self, conf_fname, conf_dict): """ Parse the user configuration and builds a dictionary. Parameters ---------- conf_fname: str Path to the .conf file. Mutually exclusive with 'conf_dict' conf_dict: dict Dictionary with the configuration. 
Mutually exclusive with 'conf_fname' """ if not ((conf_fname is None) ^ (conf_dict is None)): raise ValueError("You must either provide 'conf_fname' or 'conf_dict'") if conf_fname is not None: if not os.path.isfile(conf_fname): raise ValueError("No such file: %s" % conf_fname) self.conf_fname = conf_fname self.conf_dict = parse_nabu_config_file(conf_fname) else: self.conf_dict = conf_dict if self.default_nabu_config is None or self.config_renamed_keys is None: raise ValueError( "'default_nabu_config' and 'config_renamed_keys' must be specified by classes inheriting from ProcessConfig" ) self.nabu_config = validate_config( self.conf_dict, self.default_nabu_config, self.config_renamed_keys, ) def _browse_dataset(self, dataset_info): """ Browse a dataset and builds a data structure with the relevant information. """ self.logger.debug("Browsing dataset") if dataset_info is not None: self.dataset_info = dataset_info else: extra_options = { "exclude_projections": self.nabu_config["dataset"]["exclude_projections"], "hdf5_entry": self.nabu_config["dataset"]["hdf5_entry"], } if _tomoscan_has_nxversion: # legacy - should become "if True" soon extra_options["nx_version"] = self.nabu_config["dataset"]["nexus_version"] else: self.logger.warning("Cannot use 'nx_version' for browsing dataset: need tomoscan > 0.6.0") self.dataset_info = analyze_dataset( self.nabu_config["dataset"]["location"], extra_options=extra_options, logger=self.logger ) def _update_dataset_info_with_user_config(self): """ Update the 'dataset_info' (DatasetAnalyzer class instance) data structure with options from user configuration. """ raise ValueError("Base class") def _get_rotation_axis_position(self): self.dataset_info.axis_position = self.nabu_config["reconstruction"]["rotation_axis_position"] def _update_rotation_angles(self): raise ValueError("Base class") def _dataset_estimations(self): """ Perform estimation of several parameters like center of rotation and detector tilt angle. """ self.logger.debug("Doing dataset estimations") self._get_tilt() self._get_cor() def _get_cor(self): raise ValueError("Base class") def _get_tilt(self): tilt = self.nabu_config["preproc"]["tilt_correction"] user_rot_projs = self.nabu_config["preproc"]["rotate_projections"] if user_rot_projs is not None and tilt is not None: msg = "=" * 80 + "\n" msg += ( "Both 'detector_tilt' and 'rotate_projections' options were provided. The option 'rotate_projections' will take precedence. This means that the projections will be rotated by %f degrees and the option 'detector_tilt' will be ignored." % user_rot_projs ) msg += "\n" + "=" * 80 self.logger.warning(msg) tilt = user_rot_projs # if isinstance(tilt, str): # auto-tilt self.tilt_estimator = DetectorTiltEstimator( self.dataset_info, logger=self.logger, autotilt_options=self.nabu_config["preproc"]["autotilt_options"] ) tilt = self.tilt_estimator.find_tilt(tilt_method=tilt) self.dataset_info.detector_tilt = tilt def _coupled_validation(self): """ Validate together the dataset information and user configuration. Update 'dataset_info' and 'nabu_config' """ raise ValueError("Base class") def _build_processing_steps(self): """ Build the processing steps, i.e a tuple (steps, options) where - steps is a list of str (list of processing steps names) - options is a dict with processing options """ raise ValueError("Base class") build_processing_steps = _build_processing_steps # COMPAT. 
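# Usage sketch (illustrative only; this base class is abstract): a concrete subclass
# -- e.g. the full-field pipeline's ProcessConfig, assumed here -- is built either
# from a configuration file or from an equivalent dictionary:
#
#   proc = ProcessConfig(conf_fname="nabu.conf", create_logger=True)
#   # or, starting from an already-parsed configuration:
#   proc = ProcessConfig(conf_dict=parse_nabu_config_file("nabu.conf"))
#
# Exactly one of 'conf_fname' / 'conf_dict' must be given, as enforced by
# _parse_configuration() above.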
def _configure_save_steps(self): raise ValueError("Base class") def _configure_resume(self): raise ValueError("Base class") ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/pipeline/utils.py0000644000175000017500000000525200000000000016474 0ustar00pierrepierrefrom ..utils import deprecated_class # # Decorators and callback mechanism # def use_options(step_name, step_attr): def decorator(func): def wrapper(*args, **kwargs): self = args[0] if step_name not in self.processing_steps: self.__setattr__(step_attr, None) return self._steps_name2component[step_name] = step_attr self._steps_component2name[step_attr] = step_name return func(*args, **kwargs) return wrapper return decorator def pipeline_step(step_attr, step_desc): def decorator(func): def wrapper(*args, **kwargs): self = args[0] if self.__getattribute__(step_attr) is None: return self.logger.info(step_desc) res = func(*args, **kwargs) step_name = self._steps_component2name[step_attr] callbacks = self._callbacks.get(step_name, None) if callbacks is not None: for callback in callbacks: callback(self) if self.datadump_manager is not None and step_name in self.datadump_manager.data_dump: self.datadump_manager.dump_data_to_file(step_name, self.radios) return res return wrapper return decorator # # sub-region, shapes, etc # def get_subregion(sub_region, ndim=3): """ Return a "normalized" sub-region in the form ((start_z, end_z), (start_y, end_y), (start_x, end_x)). Parameters ---------- sub_region: tuple A tuple of tuples or tuple of integers. Notes ----- The parameter "sub_region" is normally a tuple of tuples of integers. However it can be more convenient to use tuple of integers. This function will attempt at catching the different cases, but will fail if 'sub_region' contains heterogeneous types (ex. tuples along with int) """ if sub_region is None: res = ((None, None),) elif hasattr(sub_region[0], "__iter__"): if set(map(len, sub_region)) != set([2]): raise ValueError("Expected each tuple to be in the form (start, end)") res = sub_region else: if len(sub_region) % 2: raise ValueError("Expected even number of elements") starts, ends = sub_region[::2], sub_region[1::2] res = tuple([(s, e) for s, e in zip(starts, ends)]) if len(res) != ndim: res += ((None, None),) * (ndim - len(res)) return res # # Writer - moved to pipeline.writer # from .writer import WriterManager WriterConfigurator = deprecated_class("WriterConfigurator moved to nabu.pipeline.writer.WriterManager", do_print=True)( WriterManager ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/pipeline/writer.py0000644000175000017500000002525000000000000016650 0ustar00pierrepierrefrom os import path from pathlib import Path as pathlib_Path from posixpath import join as posixjoin from silx.io.dictdump import dicttonx from tomoscan.esrf import HDF5Volume, TIFFVolume, MultiTIFFVolume, EDFVolume, JP2KVolume from ..resources.logger import LoggerOrPrint from ..io.writer import get_datetime, NXProcessWriter, HSTVolVolume from ..io.utils import convert_dict_values from .. import version as nabu_version from ..resources.utils import is_hdf5_extension # # There are still multiple issues: # - When using HDF5, we still have to do self.file_prefix += str("_%05d" % self.start_index) # because we are writing partial files. 
This should be done by tomoscan but it's likely incompatible # with its current architecture # - _configure_metadata() is too long, and somehow duplicates what is done in LegacyNXProcessWriter. # - LegacyNXProcessWriter still has to be used for writing 1D data (histogram) # # All in all, tomoscan.esrf.volume does not seem to make things simpler, at least for HDF5. # class WriterManager: """ This class is a wrapper on top of all "writers". It will create the right "writer" with all the necessary options, and the histogram writer. The layout is the following. For "single-file volume" formats (HDF5, big tiff): bigtiff: output_dir/file_prefix.tiff (bigtiff) hdf5: [output_dir/file_prefix/file_prefix_{%05d}.h5] (partial results) output_dir/file_prefix.h5 (master file) For "one file per slice" formats (tiff, jp2, edf) output_dir/file_prefix_%05d.{ext} When saving intermediate steps (eg. sinogram): HDF5 format is always used. So the layout is [output_dir/sinogram_file_prefix/sinogram_file_prefix_%05d.h5] (partial results) output_dir/sinogram_file_prefix.h5 (master file) """ _overwrite_warned = False def __init__( self, output_dir, file_prefix, file_format="hdf5", overwrite=False, start_index=0, logger=None, metadata=None, histogram=False, extra_options=None, ): """ Create a Writer from a set of parameters. Parameters ---------- output_dir: str Directory where the file(s) will be written. file_prefix: str File prefix (without leading path) start_index: int, optional Index to start the files numbering (filename_0123.ext). Default is 0. Ignored for HDF5 extension. logger: nabu.resources.logger.Logger, optional Logger object metadata: dict, optional Metadata, eg. information on various processing steps. For HDF5, it will go to "configuration" histogram: bool, optional Whether to also write a histogram of data. If set to True, it will configure an additional "writer". extra_options: dict, optional Other advanced options to pass to Writer class. 
""" self.overwrite = overwrite self.start_index = start_index self.logger = LoggerOrPrint(logger) self.histogram = histogram self.extra_options = extra_options or {} self.is_hdf5_output = is_hdf5_extension(file_format) self.is_bigtiff = file_format in ["tiff", "tif"] and any( [self.extra_options.get(opt, False) for opt in ["tiff_single_file", "use_bigtiff"]] ) self.is_vol = file_format == "vol" self.file_prefix = file_prefix self._set_output_dir(output_dir) self._set_file_name(file_format) self._configure_metadata(metadata) # tomoscan.esrf.volume arguments def _get_writer_kwargs_single_frame(): return { "folder": self.output_dir, "volume_basename": self.file_prefix, "start_index": self.start_index, "overwrite": self.overwrite, } def _get_writer_kwargs_multi_frames(): return { "file_path": self.fname, "overwrite": self.overwrite, } if self.is_hdf5_output: writer = HDF5Volume writer_kwargs = _get_writer_kwargs_multi_frames() writer_kwargs.update( { "data_path": posixjoin(self._h5_entry, self._h5_process_name), } ) elif file_format in ["tiff", "tif"]: if self.is_bigtiff: writer = MultiTIFFVolume writer_kwargs = _get_writer_kwargs_multi_frames() writer_kwargs.update({"append": self.extra_options.get("single_output_file_initialized", False)}) else: writer = TIFFVolume writer_kwargs = _get_writer_kwargs_single_frame() writer_kwargs.update( { "folder": self.output_dir, "volume_basename": self.file_prefix, "overwrite": True, } ) elif file_format == "vol": writer = HSTVolVolume writer_kwargs = _get_writer_kwargs_multi_frames() writer_kwargs.update({"append": self.extra_options.get("single_output_file_initialized", False)}) elif file_format == "edf": writer = EDFVolume writer_kwargs = _get_writer_kwargs_single_frame() elif file_format in ["jp2k", "j2k", "jp2", "jp2000", "jpeg2000"]: writer = JP2KVolume writer_kwargs = _get_writer_kwargs_single_frame() else: raise ValueError("Unsupported file format: %s" % file_format) self.writer = writer(**writer_kwargs) self._init_histogram_writer() def _set_output_dir(self, output_dir): # This class is generally used to create partial files, i.e files containing a subset of the processed volume. # In this case, the files containing partial results are stored in a sub-directory with the same file prefix. # Otherwise, everything is put in a single file (for now it's only the case for "big tiff"). self.is_partial_file = not (self.is_bigtiff or self.is_vol) if self.is_partial_file: output_dir = path.join(output_dir, self.file_prefix) self.output_dir = output_dir if path.exists(self.output_dir): if not (path.isdir(self.output_dir)): raise ValueError( "Unable to create directory %s: already exists and is not a directory" % self.output_dir ) else: self.logger.debug("Creating directory %s" % self.output_dir) pathlib_Path(self.output_dir).mkdir(parents=True, exist_ok=True) def _set_file_name(self, file_format): if self.is_hdf5_output: # HDF5Volume() does not have a "start_index" parameter, so we have to handle it here # (HDF5 files are partial files that are eventually merged into a master file, # so they have a _%05d suffix) self.file_prefix += str("_%05d" % self.start_index) self.file_format = file_format self.fname = path.join(self.output_dir, self.file_prefix + "." + file_format) if path.exists(self.fname): err = "File already exists: %s" % self.fname if self.overwrite: if not (self.__class__._overwrite_warned): self.logger.warning(err + ". 
It will be overwritten as requested in configuration") self.__class__._overwrite_warned = True else: self.logger.fatal(err) raise ValueError(err) def _configure_metadata(self, metadata): metadata = metadata or {} self.metadata = convert_dict_values(metadata, {None: "None"}) self._h5_entry = self.metadata.pop("entry", "entry") if self.is_hdf5_output: self.metadata.update({"@NX_class": "NXcollection"}) # should be done by tomoscan... self._h5_process_name = process_name = self.metadata.pop("process_name", "reconstruction") # Metadata in {entry}/reconstruction. Can be written now. nabu_process_info = { "@NX_class": "NXentry", "@default": f"{process_name}/results", f"{process_name}@NX_class": "NXprocess", f"{process_name}/program": "nabu", f"{process_name}/version": nabu_version, f"{process_name}/date": get_datetime(), f"{process_name}/sequence_index": self.metadata.pop("processing_index", 0), f"{process_name}@default": "results", } dicttonx( nabu_process_info, h5file=self.fname, h5path=self._h5_entry, update_mode="replace", mode="a", ) # Metadata in {entry}/reconstruction/results. Will be written after data. self._h5_results_metadata = { f"{process_name}/results@NX_class": "NXdata", f"{process_name}/results@signal": "data", # TODO "data_path" ? f"{process_name}/results@interpretation": "image", f"{process_name}/results/data@interpretation": "image", } def _init_histogram_writer(self): if not self.histogram: return separate_histogram_file = not (self.is_hdf5_output) if separate_histogram_file: fmode = "w" hist_fname = path.join(self.output_dir, "histogram_%05d.hdf5" % self.start_index) else: fmode = "a" hist_fname = self.fname # Nabu's original NXProcessWriter has to be used here, as histogram is not 3D self.histogram_writer = NXProcessWriter( hist_fname, entry=self._h5_entry, filemode=fmode, overwrite=True, ) def write_histogram(self, data, config=None, processing_index=1): if not (self.histogram): return self.histogram_writer.write( data, "histogram", processing_index=processing_index, config=config, is_frames_stack=False, direct_access=False, ) def _write_metadata(self): self.writer.metadata = self.metadata self.writer.save_metadata() if self.is_hdf5_output: dicttonx( self._h5_results_metadata, h5file=self.fname, h5path=self._h5_entry, update_mode="replace", mode="a", ) def write_data(self, data): self.writer.data = data self.writer.save() self._write_metadata() ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4647331 nabu-2023.1.1/nabu/pipeline/xrdct/0000755000175000017500000000000000000000000016102 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1628752049.0 nabu-2023.1.1/nabu/pipeline/xrdct/__init__.py0000644000175000017500000000000000000000000020201 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4647331 nabu-2023.1.1/nabu/preproc/0000755000175000017500000000000000000000000014623 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/preproc/__init__.py0000644000175000017500000000043400000000000016735 0ustar00pierrepierrefrom .ccd import CCDFilter, Log from .ctf import CTFPhaseRetrieval from .distortion import DistortionCorrection from .double_flatfield import DoubleFlatField from .flatfield import FlatField, FlatFieldDataUrls from .phase import PaganinPhaseRetrieval from .shift import VerticalShift 
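# Minimal usage sketch of two of the building blocks re-exported above (array shapes
# and thresholds are made up for the example):
#
#   import numpy as np
#   from nabu.preproc import CCDFilter, Log
#
#   radios = np.random.rand(10, 256, 256).astype("f") + 0.1   # fake flat-fielded radios
#   ccd_filter = CCDFilter(radios.shape, median_clip_thresh=0.1)
#   for i in range(radios.shape[0]):
#       radios[i] = ccd_filter.median_clip_correction(radios[i])  # despike frame by frame
#   mlog = Log(radios.shape, clip_min=1.0e-6, clip_max=1.1)
#   mlog.take_logarithm(radios)   # in-place: radios now hold -log(clipped values)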
././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/preproc/alignment.py0000644000175000017500000000057300000000000017160 0ustar00pierrepierre# Backward compat. from ..estimation.alignment import AlignmentBase from ..estimation.cor import ( CenterOfRotation, CenterOfRotationAdaptiveSearch, CenterOfRotationGrowingWindow, CenterOfRotationSlidingWindow, ) from ..estimation.translation import DetectorTranslationAlongBeam from ..estimation.focus import CameraFocus from ..estimation.tilt import CameraTilt ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/preproc/ccd.py0000644000175000017500000001203000000000000015722 0ustar00pierrepierreimport numpy as np from ..utils import check_supported from scipy.ndimage import median_filter class CCDFilter: """ Filtering applied on radios. """ _supported_ccd_corrections = ["median_clip"] def __init__( self, radios_shape: tuple, correction_type: str = "median_clip", median_clip_thresh: float = 0.1, abs_diff=False, preserve_borders=False, ): """ Initialize a CCDCorrection instance. Parameters ----------- radios_shape: tuple A tuple describing the shape of the radios stack, in the form `(n_radios, n_z, n_x)`. correction_type: str Correction type for radios ("median_clip", "sigma_clip", ...) median_clip_thresh: float, optional Threshold for the median clipping method. abs_diff: boolean by default False: the correction is triggered when img - median > threshold. If equals True: correction is triggered for abs(img-media) > threshold preserve borders: boolean by default False: If equals True: the borders (width=1) are not modified. Notes ------ A CCD correction is a process (usually filtering) taking place in the radios space. Available filters: - median_clip: if the value of the current pixel exceeds the median of adjacent pixels (a 3x3 neighborhood) more than a threshold, then this pixel value is set to the median value. """ self._set_radios_shape(radios_shape) check_supported(correction_type, self._supported_ccd_corrections, "CCD correction mode") self.correction_type = correction_type self.median_clip_thresh = median_clip_thresh self.abs_diff = abs_diff self.preserve_borders = preserve_borders def _set_radios_shape(self, radios_shape): if len(radios_shape) == 2: self.radios_shape = (1,) + radios_shape elif len(radios_shape) == 3: self.radios_shape = radios_shape else: raise ValueError("Expected radios to have 2 or 3 dimensions") n_radios, n_z, n_x = self.radios_shape self.n_radios = n_radios self.n_angles = n_radios self.shape = (n_z, n_x) @staticmethod def median_filter(img): """ Perform a median filtering on an image. """ return median_filter(img, (3, 3), mode="reflect") def median_clip_mask(self, img, return_medians=False): """ Compute a mask indicating whether a pixel is valid or not, according to the median-clip method. Parameters ---------- img: numpy.ndarray Input image return_medians: bool, optional Whether to return the median values additionally to the mask. """ median_values = self.median_filter(img) if not self.abs_diff: invalid_mask = img >= median_values + self.median_clip_thresh else: invalid_mask = abs(img - median_values) > self.median_clip_thresh if return_medians: return invalid_mask, median_values else: return invalid_mask def median_clip_correction(self, radio, output=None): """ Compute the median clip correction on one image. Parameters ---------- radios: numpy.ndarray, optional A radio image. 
output: numpy.ndarray, optional Output array """ assert radio.shape == self.shape if output is None: output = np.copy(radio) else: output[:] = radio[:] invalid_mask, medians = self.median_clip_mask(radio, return_medians=True) if self.preserve_borders: fixed_border = np.array(radio[[0, 0, -1, -1], [0, -1, 0, -1]]) output[invalid_mask] = medians[invalid_mask] if self.preserve_borders: output[[0, 0, -1, -1], [0, -1, 0, -1]] = fixed_border return output class Log: """ Helper class to take -log(radios) Parameters ----------- clip_min: float, optional Before taking the logarithm, the values are clipped to this minimum. clip_max: float, optional Before taking the logarithm, the values are clipped to this maximum. """ def __init__(self, radios_shape, clip_min=None, clip_max=None): self.radios_shape = radios_shape self.clip_min = clip_min self.clip_max = clip_max def take_logarithm(self, radios): """ Take the negative logarithm of a radios chunk. Processing is done in-place ! Parameters ----------- radios: array Radios chunk. """ if (self.clip_min is not None) or (self.clip_max is not None): np.clip(radios, self.clip_min, self.clip_max, out=radios) np.log(radios, out=radios) else: np.log(radios, out=radios) radios[:] *= -1 return radios ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/preproc/ccd_cuda.py0000644000175000017500000001365600000000000016735 0ustar00pierrepierrefrom typing import Union import numpy as np from ..preproc.ccd import CCDFilter, Log from ..cuda.kernel import CudaKernel from ..cuda.medfilt import MedianFilter from ..utils import get_cuda_srcfile, updiv, deprecated_class # COMPAT. from .flatfield_cuda import ( CudaFlatField as CudaFlatfield_, CudaFlatFieldArrays as CudaFlatFieldArrays_, CudaFlatFieldDataUrls as CudaFlatFieldDataUrls_, ) FlatField = deprecated_class( "preproc.ccd_cuda.CudaFlatField was moved to preproc.flatfield_cuda.CudaFlatField", do_print=True )(CudaFlatfield_) FlatFieldArrays = deprecated_class( "preproc.ccd_cuda.CudaFlatFieldArrays was moved to preproc.flatfield_cuda.CudaFlatFieldArrays", do_print=True )(CudaFlatFieldArrays_) FlatFieldDataUrls = deprecated_class( "preproc.ccd_cuda.CudaFlatFieldDataUrls was moved to preproc.flatfield_cuda.CudaFlatFieldDataUrls", do_print=True )(CudaFlatFieldDataUrls_) # class CudaCCDFilter(CCDFilter): def __init__( self, radios_shape: tuple, correction_type: str = "median_clip", median_clip_thresh: float = 0.1, abs_diff=False, cuda_options: Union[dict, None] = None, ): """ Initialize a CudaCCDCorrection instance. Please refer to the documentation of CCDCorrection. """ super().__init__( radios_shape, correction_type=correction_type, median_clip_thresh=median_clip_thresh, ) self._set_cuda_options(cuda_options) self.cuda_median_filter = None if correction_type == "median_clip": self.cuda_median_filter = MedianFilter( self.shape, footprint=(3, 3), mode="reflect", threshold=median_clip_thresh, abs_diff=abs_diff, cuda_options={ "device_id": self.cuda_options["device_id"], "ctx": self.cuda_options["ctx"], "cleanup_at_exit": self.cuda_options["cleanup_at_exit"], }, ) def _set_cuda_options(self, user_cuda_options): self.cuda_options = {"device_id": None, "ctx": None, "cleanup_at_exit": None} if user_cuda_options is None: user_cuda_options = {} self.cuda_options.update(user_cuda_options) def median_clip_correction(self, radio, output=None): """ Compute the median clip correction on one image. 
Parameters ---------- radio: pycuda.gpuarray A radio image output: pycuda.gpuarray, optional Output data. """ assert radio.shape == self.shape return self.cuda_median_filter.medfilt2(radio, output=output) CudaCCDCorrection = deprecated_class("CudaCCDCorrection is replaced with CudaCCDFilter", do_print=True)(CudaCCDFilter) class CudaLog(Log): """ Helper class to take -log(radios) """ def __init__(self, radios_shape, clip_min=None, clip_max=None): """ Initialize a Log processing. Parameters ----------- radios_shape: tuple The shape of 3D radios stack. clip_min: float, optional Data smaller than this value is replaced by this value. clip_max: float, optional. Data bigger than this value is replaced by this value. """ super().__init__(radios_shape, clip_min=clip_min, clip_max=clip_max) self._init_kernels() def _init_kernels(self): self._do_clip_min = int(self.clip_min is not None) self._do_clip_max = int(self.clip_max is not None) self.clip_min = np.float32(self.clip_min or 0) self.clip_max = np.float32(self.clip_max or 1) self._nlog_srcfile = get_cuda_srcfile("ElementOp.cu") nz, ny, nx = self.radios_shape self._nx = np.int32(nx) self._ny = np.int32(ny) self._nz = np.int32(nz) self._nthreadsperblock = (16, 16, 4) # TODO tune ? self._nblocks = tuple([updiv(n, p) for n, p in zip([nx, ny, nz], self._nthreadsperblock)]) self.nlog_kernel = CudaKernel( "nlog", filename=self._nlog_srcfile, signature="Piiiff", options=[ "-DDO_CLIP_MIN=%d" % self._do_clip_min, "-DDO_CLIP_MAX=%d" % self._do_clip_max, ], ) def take_logarithm(self, radios, clip_min=None, clip_max=None): """ Take the negative logarithm of a radios chunk. Parameters ----------- radios: `pycuda.gpuarray.GPUArray` Radios chunk If not provided, a new GPU array is created. clip_min: float, optional Before taking the logarithm, the values are clipped to this minimum. clip_max: float, optional Before taking the logarithm, the values are clipped to this maximum. """ clip_min = clip_min or self.clip_min clip_max = clip_max or self.clip_max if radios.flags.c_contiguous: self.nlog_kernel( radios, self._nx, self._ny, self._nz, clip_min, clip_max, grid=self._nblocks, block=self._nthreadsperblock, ) else: # map-like operations cannot be directly applied on 3D arrays # that are not C-contiguous. We have to process image per image. # TODO it's even worse when each single frame is not C-contiguous. For now this case is not handled nz = np.int32(1) nthreadsperblock = (32, 32, 1) nblocks = tuple([updiv(n, p) for n, p in zip([int(self._nx), int(self._ny), int(nz)], nthreadsperblock)]) for i in range(radios.shape[0]): self.nlog_kernel( radios[i], self._nx, self._ny, nz, clip_min, clip_max, grid=nblocks, block=nthreadsperblock ) return radios ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678380095.0 nabu-2023.1.1/nabu/preproc/ctf.py0000644000175000017500000003505300000000000015757 0ustar00pierrepierreimport math import numpy as np from ..resources.logger import LoggerOrPrint from ..misc import fourier_filters from ..misc.padding import pad_interpolate, recut from ..utils import get_num_threads class GeoPars: """ A class to describe the geometry of a phase contrast radiography with a source obtained by a focussing system, possibly astigmatic, which is at distance z1_vh from the sample. 
The detector is at z2 from the sample """ def __init__( self, z1_vh=None, z2=None, pix_size_det=None, wavelength=None, magnification=True, length_scale=10.0e-6, logger=None, ): """ Parameters ---------- z1_vh : None, a float, or a sequence of two floats the source sample distance (meters), if None the parallel beam is assumed. If two floats are given then they are taken as the distance of the vertically focused source (horizontal line) and the horizontaly focused source (vertical line) for KB mirrors. z2 : float the sample detector distance (meters). pix_size_det: float pixel size (meters) wavelength: float beam wave length (meters). magnification: boolean defaults to True if false no magnification is considered length_scale: float rescaling length scale, meant to avoid having too big or too small numbers. defaults to 10.0e-6 logger: Logger, optional A logging object """ self.logger = LoggerOrPrint(logger) if z1_vh is None: self.z1_vh = None else: if hasattr(type(z1_vh), "__iter__"): self.z1_vh = np.array(z1_vh) else: self.z1_vh = np.array([z1_vh, z1_vh]) self.z2 = z2 self.magnification = magnification self.pix_size_det = pix_size_det if self.magnification and self.z1_vh is not None: self.M_vh = (self.z1_vh + self.z2) / self.z1_vh else: self.M_vh = np.array([1, 1]) self.logger.debug("Magnification : h ({}) ; v ({}) ".format(self.M_vh[1], self.M_vh[0])) self.length_scale = length_scale self.wavelength = wavelength self.maxM = self.M_vh.max() self.pix_size_rec = self.pix_size_det / self.maxM # we bring everything to highest magnification which_unit = int(np.sum(np.array([self.pix_size_rec > small for small in [1.0e-6, 1.0e-7]]).astype(np.int32))) self.pixelsize_string = [ "{:.1f} nm".format(self.pix_size_rec * 1e9), "{:.3f} um".format(self.pix_size_rec * 1e6), "{:.1f} um".format(self.pix_size_rec * 1e6), ][which_unit] if self.magnification: self.logger.debug( "All images are resampled to smallest pixelsize: {}".format(self.pixelsize_string), ) else: self.logger.debug("Pixelsize images: {}".format(self.pixelsize_string)) class CTFPhaseRetrieval: """ This class implements the CTF formula of [1] in its regularised form which avoids the zeros of unregularised_filter_denominator (unreg_filter_denom is the so here named denominator/delta_beta of equation 8). References ----------- [1] B. Yu, L. Weber, A. Pacureanu, M. Langer, C. Olivier, P. Cloetens, and F. Peyrin, "Evaluation of phase retrieval approaches in magnified X-ray phase nano computerized tomography applied to bone tissue", Optics Express, Vol 26, No 9, 11110-11124 (2018) """ def __init__( self, shape, geo_pars, delta_beta, padded_shape="auto", padding_mode="reflect", translation_vh=None, normalize_by_mean=False, lim1=1.0e-5, lim2=0.2, use_rfft=False, fftw_num_threads=None, logger=None, ): """ Initialize a Contrast Transfer Function phase retrieval. Parameters ---------- geo_pars: GeoPars the geometry description delta_beta : float the delta/beta ratio padded_shape: str or tuple, optional Padded image shape, in the form (num_rows, num_columns) i.e (vertical, horizontal). By default, it is twice the image shape. padding_mode: str Padding mode. It must be valid for the numpy.pad function translation_vh: array, optional Shift in the form (y, x). It is used to perform a translation of the image before applying the CTF filter. normalize_by_mean: bool Whether to divide the (padded) image with its mean before applying the CTF filter. 
lim1: float >0 the regulariser strenght at low frequencies lim2: float >0 the regulariser strenght at high frequencies use_rfft: bool, optional Whether to use real-to-complex (R2C) FFT instead of usual complex-to-complex (C2C). fftw_num_threads: bool or None or int, optional If False is passed: don't use FFTW. If None is passed: use all available threads. If a number is provided: number of threads to use for FFTW. You can pass a negative number to use N - fftw_num_threads cores. logger: optional a logger object """ self.logger = LoggerOrPrint(logger) if not isinstance(geo_pars, GeoPars): raise ValueError("Expected GeoPars instance for 'geo_pars' parameter") self.geo_pars = geo_pars self._calc_shape(shape, padded_shape, padding_mode) self.delta_beta = delta_beta self.lim = None self.lim1 = lim1 self.lim2 = lim2 self.normalize_by_mean = normalize_by_mean self.translation_vh = translation_vh self._setup_fft(use_rfft, fftw_num_threads) self._get_ctf_filter() def _calc_shape(self, shape, padded_shape, padding_mode): if np.isscalar(shape): shape = (shape, shape) else: assert len(shape) == 2 self.shape = shape if padded_shape is None or padded_shape is False: padded_shape = self.shape # no padding elif isinstance(padded_shape, (tuple, list, np.ndarray)): pass elif padded_shape == "auto": padded_shape = (2 * self.shape[0], 2 * self.shape[1]) self.shape_padded = tuple(padded_shape) self.padding_mode = padding_mode def _setup_fft(self, use_rfft, fftw_num_threads): self.use_rfft = use_rfft self._fft_func = np.fft.rfft2 if use_rfft else np.fft.fft2 self._ifft_func = np.fft.irfft2 if use_rfft else np.fft.ifft2 self.use_fftw = False if fftw_num_threads is False: return fftw_num_threads = get_num_threads(fftw_num_threads) if self.use_rfft and (fftw_num_threads > 0): # importing silx.math.fft creates opencl contexts all over the place # because of the silx.opencl.ocl singleton. # So, import silx as late as possible from silx.math.fft.fftw import FFTW, __have_fftw__ if __have_fftw__: self.use_fftw = True self.fftw = FFTW(shape=self.shape_padded, dtype="f", num_threads=fftw_num_threads) self._fft_func = self.fftw.fft self._ifft_func = self.fftw.ifft def _get_ctf_filter(self): """ The parameter "length_scale" was mentioned, in the octave code, as a rescaling length scale, which is meant to avoid having too big or too small numbers. From the mathematical point of view, it is infact completely trasparent: its action is on fsamplex, fsampley and betash, betasv. But these latters ( beta's) are multiplied by the formers (fsample's) so that "length_scale" mathematically disappears, however in case of simple precision float the exponent of a float number ranges from -38 to 38, and one could approach it as an example by taking the square of a very small number ( 1.0e-19), and losing significant bits in he mantissa or getting zero, or the square of a big number, thus generting inf Althought the values involved in our x-ray regimes seems safe, with respect to these problems, this length_scale parameters does not hurt. 
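A minimal numerical illustration of the single-precision issue described above (purely illustrative, not part of the filter computation): squaring a very small or very large float32 value silently underflows or overflows,

    import numpy as np
    np.float32(1e-23) ** 2   # underflows to 0.0 in single precision
    np.float32(1e23) ** 2    # overflows to inf in single precision

which is why the rescaling of the fsample's and beta's by "length_scale" is kept, even though the factor cancels out mathematically.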
""" padded_img_shape = self.shape_padded fsample_vh = np.array( [ self.geo_pars.length_scale / self.geo_pars.pix_size_rec, self.geo_pars.length_scale / self.geo_pars.pix_size_rec, ] ) if not self.use_rfft: ff_index_vh = list(map(np.fft.fftfreq, padded_img_shape)) else: ff_index_vh = [np.fft.fftfreq(padded_img_shape[0]), np.fft.rfftfreq(padded_img_shape[1])] # if padded_img_shape[1]%2 == 0 : # change to holotomo_slave indexing (by a transparent 2pi shift) # ff_index_x[ ff_index_x == -0.5 ] = +0.5 # if padded_img_shape[0]%2 == 0 : # change to holotomo_slave indexing (by a transparent 2pi shift) # ff_index_y[ ff_index_y == -0.5 ] = +0.5 frequencies_vh = np.array( np.meshgrid(ff_index_vh[0] * fsample_vh[0], ff_index_vh[1] * fsample_vh[1], indexing="ij") ) frequencies_squared_vh = frequencies_vh * frequencies_vh """ --------------- fresnelnumbers and forward propagators ------------------- In the limit of parallel beam, z1_h and z1_v would be infinite so that the here below distances becomes z2 which is sample-detector distance. """ if self.geo_pars.z1_vh is not None: distances_vh = (self.geo_pars.z1_vh * self.geo_pars.z2) / (self.geo_pars.z1_vh + self.geo_pars.z2) else: distances_vh = np.array([self.geo_pars.z2, self.geo_pars.z2]) """ Citing David Paganin (2002) : The intensity I_{R_1}(r⊥,z) at a distance z of a weakly refracting object illuminated by a point source at distance R1 behind the said object, is related to the intensity I∞(r⊥,z), which would result from normally incident collimated illumination of the same object, by (Pogany et al., 1997): I_{R_1}(r⊥,z) = 1/{M^2} I∞(r⊥/M , z/M) where M is the magnification. This explains the effective distance formula expressed by distancesh, distancesv above. ------------------------------------------------------------------------------ """ lambda_dist_vh = self.geo_pars.wavelength * distances_vh / (self.geo_pars.length_scale**2) """ --------------------------------------------------------------------------------------- -> cut_v at first maximum of ctf largest distance In the paraxial expansion of the Fresnel propagator, the phase is equal to 1/2 K_{parallel}^2 wavelength * Distance /2 / pi When this is equal to a multiple of 2 pi, the effect of propagation disappears and we have a singularity in equation 8. The sampling in the reciprocal space is done with a step length of 2*pi*fsamplex,y (note: fsamples are the plain inverse of pixel size) The first singularity occurs at a frequence number K_{parallel}/2/pi/ fsamplex for K_{parallel}^2 = 2 * (2pi)^2 /( wavelength * distance ) which would correspond to K_{parallel}/2/pi/ fsamplex = sqrt( 2/wavelength*distance) / fsample Question:( why the factor 2 appear at the denominator in the square root below?) Answer : the below defined cut corresponds to the first maximum of the denominator, before arriving to the first zero. In this way the regularisation is already at ( almost) full strenght on the first pole.ation ## is already at ( almost) full strenght on the first pole. 
""" self.cut_v = math.sqrt(1.0 / 2 / lambda_dist_vh[0]) / fsample_vh[0] self.cut_v = min(self.cut_v, 0.5) self.logger.debug("Normalized cut-off = {:5.3f}".format(self.cut_v)) self.r = fourier_filters.get_lowpass_filter( padded_img_shape, cutoff_par=( 0.5 / (self.cut_v + 1.0 / padded_img_shape[0]), 0.01 * padded_img_shape[0] / (1 + self.cut_v * padded_img_shape[0]), ), use_rfft=self.use_rfft, ) self.r /= self.r[0, 0] self.lim = self.lim1 * self.r + self.lim2 * (1 - self.r) # more methods exist in the original code, and they are initialized starting from here # (ht_app1, ht_app2... ht_app7) fresnel_phase = ( np.pi * lambda_dist_vh[1] * frequencies_squared_vh[1] + np.pi * lambda_dist_vh[0] * frequencies_squared_vh[0] ) if self.delta_beta: unreg_filter_denom = np.sin(fresnel_phase) + (1.0 / self.delta_beta) * np.cos(fresnel_phase) else: unreg_filter_denom = np.sin(fresnel_phase) self.unreg_filter_denom = unreg_filter_denom.astype(np.float32) self._ctf_filter_denom = (2 * self.unreg_filter_denom * self.unreg_filter_denom + self.lim).astype(np.complex64) def _apply_filter(self, img): img_f = self._fft_func(img) img_f *= self.unreg_filter_denom unreg_filter_denom_0_mean = self.unreg_filter_denom[0, 0] nf, mf = img.shape # here it is assumed that the average of img is 1 and the DC component is removed img_f[0, 0] -= nf * mf * unreg_filter_denom_0_mean ## formula 8, with regularisation to stay at a safe distance from the poles img_f /= self._ctf_filter_denom ph = self._ifft_func(img_f).real return ph def retrieve_phase(self, img, output=None): """ Apply the CTF filter to retrieve the phase. Parameters ---------- img: np.ndarray Projection image. It must have been already flat-fielded. Returns -------- ph: numpy.ndarray Phase image """ padded_img = pad_interpolate( img, self.shape_padded, translation_vh=self.translation_vh, padding_mode=self.padding_mode ) if self.normalize_by_mean: padded_img /= padded_img.mean() phase_img = self._apply_filter(padded_img) res = recut(phase_img, img.shape) if output is not None: output[:, :] = res[:, :] return output return res __call__ = retrieve_phase CtfFilter = CTFPhaseRetrieval ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678380095.0 nabu-2023.1.1/nabu/preproc/ctf_cuda.py0000644000175000017500000001155400000000000016753 0ustar00pierrepierreimport numpy as np from pycuda import gpuarray as garray from ..utils import calc_padding_lengths, updiv, get_cuda_srcfile from ..cuda.processing import CudaProcessing from ..cuda.kernel import CudaKernel from ..cuda.padding import CudaPadding from .phase_cuda import CudaPaganinPhaseRetrieval from .ctf import CTFPhaseRetrieval # TODO: # - better padding scheme (for now 2*shape) # - rework inheritance scheme ? (base class SingleDistancePhaseRetrieval and its cuda counterpart) class CudaCTFPhaseRetrieval(CTFPhaseRetrieval): """ Cuda back-end of CTFPhaseRetrieval """ def __init__( self, shape, geo_pars, delta_beta, padded_shape="auto", padding_mode="reflect", translation_vh=None, normalize_by_mean=False, lim1=1.0e-5, lim2=0.2, use_rfft=True, fftw_num_threads=None, logger=None, cuda_options=None, ): """ Initialize a CudaCTFPhaseRetrieval. Parameters ---------- shape: tuple Shape of the images to process padding_mode: str Padding mode. Default is "reflect". Other parameters ----------------- Please refer to CTFPhaseRetrieval documentation. 
""" if not use_rfft: raise ValueError("Only use_rfft=True is supported") self.cuda_processing = CudaProcessing(**(cuda_options or {})) super().__init__( shape, geo_pars, delta_beta, padded_shape=padded_shape, padding_mode=padding_mode, translation_vh=translation_vh, normalize_by_mean=normalize_by_mean, lim1=lim1, lim2=lim2, logger=logger, use_rfft=True, fftw_num_threads=False, ) self._init_ctf_filter() self._init_cuda_padding() self._init_fft() self._init_mult_kernel() def _init_ctf_filter(self): self._mean_scale_factor = self.unreg_filter_denom[0, 0] * np.prod(self.shape_padded) self._d_filter_num = garray.to_gpu(self.unreg_filter_denom).astype("f") self._d_filter_denom = garray.to_gpu( (1.0 / (2 * self.unreg_filter_denom * self.unreg_filter_denom + self.lim)).astype("f") ) def _init_cuda_padding(self): pad_width = calc_padding_lengths(self.shape, self.shape_padded) # Custom coordinate transform to get directly FFT layout R, C = np.indices(self.shape, dtype=np.int32) coords_R = np.roll( np.pad(R, pad_width, mode=self.padding_mode), (-pad_width[0][0], -pad_width[1][0]), axis=(0, 1) ) coords_C = np.roll( np.pad(C, pad_width, mode=self.padding_mode), (-pad_width[0][0], -pad_width[1][0]), axis=(0, 1) ) self.cuda_padding = CudaPadding( self.shape, (coords_R, coords_C), mode=self.padding_mode, # propagate cuda options ? ) def _init_fft(self): # Import has to be done here, otherwise scikit-cuda creates a cuda/cublas context at import from silx.math.fft.cufft import CUFFT self.cufft = CUFFT(template=np.zeros(self.shape_padded, dtype="f")) self.d_radio_padded = self.cufft.data_in self.d_radio_f = self.cufft.data_out def _init_mult_kernel(self): self.cpxmult_kernel = CudaKernel( "CTF_kernel", filename=get_cuda_srcfile("ElementOp.cu"), signature="PPPfii", ) Nx = np.int32(self.shape_padded[1] // 2 + 1) Ny = np.int32(self.shape_padded[0]) self._cpxmult_kernel_args = [ self.d_radio_f, self._d_filter_num, self._d_filter_denom, np.float32(self._mean_scale_factor), Nx, Ny, ] blk = (32, 32, 1) grd = (updiv(Nx, blk[0]), updiv(Ny, blk[1])) self._cpxmult_kernel_kwargs = {"grid": grd, "block": blk} set_input = CudaPaganinPhaseRetrieval.set_input def retrieve_phase(self, image, output=None): """ Perform padding on an image. Please see the documentation of CTFPhaseRetrieval.retrieve_phase(). """ self.set_input(image) self.cuda_padding.pad(image, output=self.d_radio_padded) if self.normalize_by_mean: m = garray.sum(self.d_radio_padded).get() / np.prod(self.shape_padded) self.d_radio_padded /= m self.cufft.fft(self.d_radio_padded, output=self.d_radio_f) self.cpxmult_kernel(*self._cpxmult_kernel_args, **self._cpxmult_kernel_kwargs) self.cufft.ifft(self.d_radio_f, output=self.d_radio_padded) if output is None: output = self.cuda_processing.allocate_array("d_output", self.shape) output[:, :] = self.d_radio_padded[: self.shape[0], : self.shape[1]] return output ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/preproc/distortion.py0000644000175000017500000000635500000000000017404 0ustar00pierrepierreimport numpy as np from scipy.interpolate import RegularGridInterpolator from ..utils import check_supported from ..estimation.distortion import estimate_flat_distortion def correct_distortion_interpn(image, coords, bounds_error=False, fill_value=None): """ Correct image distortion with scipy.interpolate.interpn. 
Parameters ---------- image: array Distorted image coords: array Coordinates of the distortion correction to apply, with the shape (Ny, Nx, 2) """ foo = RegularGridInterpolator( (np.arange(image.shape[0]), np.arange(image.shape[1])), image, bounds_error=bounds_error, method="linear", fill_value=fill_value, ) return foo(coords) class DistortionCorrection: """ A class for estimating and correcting image distortion. """ estimation_methods = { "fft-correlation": estimate_flat_distortion, } correction_methods = { "interpn": correct_distortion_interpn, } def __init__( self, estimation_method="fft-correlation", estimation_kwargs=None, correction_method="interpn", correction_kwargs=None, ): """ Initialize a DistortionCorrection object. Parameters ----------- estimation_method: str Name of the method to use for estimating the distortion estimation_kwargs: dict, optional Named arguments to pass to the estimation method, in the form of a dictionary. correction_method: str Name of the method to use for correcting the distortion correction_kwargs: dict, optional Named arguments to pass to the correction method, in the form of a dictionary. """ self._set_estimator(estimation_method, estimation_kwargs) self._set_corrector(correction_method, correction_kwargs) def _set_estimator(self, estimation_method, estimation_kwargs): check_supported(estimation_method, self.estimation_methods.keys(), "estimation method") self.estimator = self.estimation_methods[estimation_method] self._estimator_kwargs = estimation_kwargs or {} def _set_corrector(self, correction_method, correction_kwargs): check_supported(correction_method, self.correction_methods.keys(), "correction method") self.corrector = self.correction_methods[correction_method] self._corrector_kwargs = correction_kwargs or {} def estimate_distortion(self, image, reference_image): return self.estimator(image, reference_image, **self._estimator_kwargs) estimate = estimate_distortion def correct_distortion(self, image, coords): image_corrected = self.corrector(image, coords, **self._corrector_kwargs) fill_value = self._corrector_kwargs.get("fill_value", None) if fill_value is not None and np.isnan(fill_value): mask = np.isnan(image_corrected) image_corrected[mask] = image[mask] return image_corrected correct = correct_distortion def estimate_and_correct(self, image, reference_image): coords = self.estimate_distortion(image, reference_image) image_corrected = self.correct_distortion(image, coords) return image_corrected ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/preproc/double_flatfield.py0000644000175000017500000001727000000000000020470 0ustar00pierrepierrefrom os import path import numpy as np from scipy.ndimage import gaussian_filter from silx.io.url import DataUrl from ..utils import check_supported, check_shape, get_2D_3D_shape from ..io.reader import Readers from ..io.writer import Writers class DoubleFlatField: _default_h5_path = "/entry/double_flatfield/results" _small = 1e-7 def __init__( self, shape, result_url=None, sub_region=None, detector_corrector=None, input_is_mlog=True, output_is_mlog=False, average_is_on_log=False, sigma_filter=None, filter_mode="reflect", ): """ Init double flat field by summing a series of urls and considering the same subregion of them. 
Parameters ---------- shape: tuple Expected shape of radios chunk to process result_url: url, optional where the double-flatfield is stored after being computed, and possibly read (instead of re-computed) before processing the images. sub_region: tuple, optional If provided, this must be a tuple in the form (start_x, end_x, start_y, end_y). Each image will be cropped to this region. This is used to specify a chunk of files. Each of the parameters can be None, in this case the default start and end are taken in each dimension. input_is_mlog: boolean, default True the input is considred as minus logarithm of normalised radios output_is_mlog: boolean, default True the output is considred as minus logarithm of normalised radios average_is_on_log : boolean, False the minus logarithm of the data is averaged the clipping value that is applied prior to the logarithm sigma_filter: optional if given a high pass filter is applied by signal -gaussian_filter(signal,sigma,filter_mode) filter_mode: optional, default 'reflect' the padding scheme applied a the borders ( same as scipy.ndimage.filtrs.gaussian_filter) """ self.radios_shape = get_2D_3D_shape(shape) self.n_angles = self.radios_shape[0] self.shape = self.radios_shape[1:] self._init_filedump(result_url, sub_region, detector_corrector) self._init_processing(input_is_mlog, output_is_mlog, average_is_on_log, sigma_filter, filter_mode) self._computed = False def _get_reader_writer_class(self): ext = path.splitext(self.result_url.file_path())[-1].replace(".", "") check_supported(ext, list(Writers.keys()), "file format") self._writer_cls = Writers[ext] self._reader_cls = Readers[ext] def _load_dff_dump(self): res = self.reader.get_data(self.result_url) if self.detector_corrector is not None: if res.ndim == 2: res = self.detector_corrector.transform(res) else: for i in range(res.shape[0]): res[i] = self.detector_corrector.transform(res[i]) if res.ndim == 3 and res.shape[0] == 1: res = res.reshape(res.shape[1], res.shape[2]) if res.shape != self.shape: raise ValueError( "Data in %s has shape %s, but expected %s" % (self.result_url.file_path(), str(res.shape), str(self.shape)) ) return res def _init_filedump(self, result_url, sub_region, detector_corrector=None): if isinstance(result_url, str): result_url = DataUrl(file_path=result_url, data_path=self._default_h5_path) self.sub_region = sub_region self.detector_corrector = detector_corrector self.result_url = result_url self.writer = None self.reader = None if self.result_url is None: return self._get_reader_writer_class() if path.exists(result_url.file_path()): if detector_corrector is None: adapted_subregion = sub_region else: adapted_subregion = self.detector_corrector.get_adapted_subregion(sub_region) self.reader = self._reader_cls(sub_region=adapted_subregion) else: self.writer = self._writer_cls(self.result_url.file_path()) def _init_processing(self, input_is_mlog, output_is_mlog, average_is_on_log, sigma_filter, filter_mode): self.input_is_mlog = input_is_mlog self.output_is_mlog = output_is_mlog self.average_is_on_log = average_is_on_log self.sigma_filter = sigma_filter if self.sigma_filter is not None and abs(float(self.sigma_filter)) < 1e-4: self.sigma_filter = None self.filter_mode = filter_mode proc = lambda x, o: np.copyto(o, x) if self.input_is_mlog: if not self.average_is_on_log: proc = lambda x, o: np.exp(-x, out=o) else: if self.average_is_on_log: proc = lambda x, o: -np.log(x, out=o) postproc = lambda x: x if self.output_is_mlog: if not self.average_is_on_log: postproc = lambda x: 
-np.log(x) else: if self.average_is_on_log: postproc = lambda x: np.exp(-x) self.proc = proc self.postproc = postproc def compute_double_flatfield(self, radios, recompute=False): """ Read the radios and generate the "double flat field" by averaging and possibly other processing. Parameters ---------- radios: array Input radios chunk. recompute: bool, optional Whether to recompute the double flatfield if already computed. """ if self._computed and not (recompute): return self.doubleflatfield # pylint: disable=E0203 acc = np.zeros(radios[0].shape, "f") tmpdat = np.zeros(radios[0].shape, "f") for ima in radios: self.proc(ima, tmpdat) acc += tmpdat acc /= radios.shape[0] if self.sigma_filter is not None: acc = acc - gaussian_filter(acc, self.sigma_filter, mode=self.filter_mode) self.doubleflatfield = self.postproc(acc) # Handle small values to avoid issues when dividing self.doubleflatfield[np.abs(self.doubleflatfield) < self._small] = 1.0 self.doubleflatfield = self.doubleflatfield.astype("f") if self.writer is not None: self.writer.write(self.doubleflatfield, "double_flatfield") self._computed = True return self.doubleflatfield def get_double_flatfield(self, radios=None, compute=False): """ Get the double flat field or a subregion of it. Parameters ---------- radios: array, optional Input radios chunk compute: bool, optional Whether to compute the double flatfield anyway even if a dump file exists. """ if self.reader is None: if radios is None: raise ValueError("result_url was not provided. Please provide 'radios' to this function") return self.compute_double_flatfield(radios) if radios is not None and compute: res = self.compute_double_flatfield(radios) else: res = self._load_dff_dump() self._computed = True return res def apply_double_flatfield(self, radios): """ Apply the "double flatfield" filter on a chunk of radios. The processing is done in-place ! """ check_shape(radios.shape, self.radios_shape, "radios") dff = self.get_double_flatfield(radios=radios) for i in range(self.n_angles): radios[i] /= dff return radios ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/preproc/double_flatfield_cuda.py0000644000175000017500000001376400000000000021470 0ustar00pierrepierrefrom .double_flatfield import DoubleFlatField from ..utils import check_shape from ..cuda.utils import __has_pycuda__ from ..cuda.processing import CudaProcessing from ..misc.unsharp_cuda import CudaUnsharpMask if __has_pycuda__: import pycuda.gpuarray as garray import pycuda.cumath as cumath class CudaDoubleFlatField(DoubleFlatField): def __init__( self, shape, result_url=None, sub_region=None, detector_corrector=None, input_is_mlog=True, output_is_mlog=False, average_is_on_log=False, sigma_filter=None, filter_mode="reflect", cuda_options=None, ): """ Init double flat field with Cuda backend. 
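A minimal usage sketch (illustrative; ``d_radios`` is assumed to be a float32 ``pycuda.gpuarray.GPUArray`` radios chunk of shape (n_angles, n_z, n_x), with an active Cuda context):

    from nabu.preproc.double_flatfield_cuda import CudaDoubleFlatField

    dff = CudaDoubleFlatField(d_radios.shape, input_is_mlog=False)
    dff.apply_double_flatfield(d_radios)  # in-place division by the double flat-field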
""" self.cuda_processing = CudaProcessing(**(cuda_options or {})) super().__init__( shape, result_url=result_url, sub_region=sub_region, detector_corrector=detector_corrector, input_is_mlog=input_is_mlog, output_is_mlog=output_is_mlog, average_is_on_log=average_is_on_log, sigma_filter=sigma_filter, filter_mode=filter_mode, ) self._init_gaussian_filter() def _init_gaussian_filter(self): if self.sigma_filter is None: return self._unsharp_mask = CudaUnsharpMask(self.shape, self.sigma_filter, -1.0, mode=self.filter_mode, method="log") @staticmethod def _proc_copy(x, o): o[:] = x[:] return o @staticmethod def _proc_expm(x, o): o[:] = x[:] o[:] *= -1 cumath.exp(o, out=o) return o @staticmethod def _proc_mlog(x, o, min_clip=None): if min_clip is not None: garray.maximum(x, min_clip, out=o) cumath.log(o, out=o) else: cumath.log(x, out=o) o *= -1 return o def _init_processing(self, input_is_mlog, output_is_mlog, average_is_on_log, sigma_filter, filter_mode): self.input_is_mlog = input_is_mlog self.output_is_mlog = output_is_mlog self.average_is_on_log = average_is_on_log self.sigma_filter = sigma_filter if self.sigma_filter is not None and abs(float(self.sigma_filter)) < 1e-4: self.sigma_filter = None self.filter_mode = filter_mode # proc = lambda x,o: np.copyto(o, x) proc = self._proc_copy if self.input_is_mlog: if not self.average_is_on_log: # proc = lambda x,o: np.exp(-x, out=o) proc = self._proc_expm else: if self.average_is_on_log: # proc = lambda x,o: -np.log(x, out=o) proc = self._proc_mlog # postproc = lambda x: x postproc = self._proc_copy if self.output_is_mlog: if not self.average_is_on_log: # postproc = lambda x: -np.log(x) postproc = self._proc_mlog else: if self.average_is_on_log: # postproc = lambda x: np.exp(-x) postproc = self._proc_expm self.proc = proc self.postproc = postproc def compute_double_flatfield(self, radios, recompute=False): """ Read the radios and generate the "double flat field" by averaging and possibly other processing. Parameters ---------- radios: array Input radios chunk. recompute: bool, optional Whether to recompute the double flatfield if already computed. """ if not (isinstance(radios, garray.GPUArray)): raise ValueError("Expected pycuda.gpuarray.GPUArray for radios") if self._computed and not (recompute): return self.doubleflatfield acc = garray.zeros(radios[0].shape, "f") tmpdat = garray.zeros(radios[0].shape, "f") for i in range(radios.shape[0]): self.proc(radios[i], tmpdat) acc += tmpdat acc /= radios.shape[0] if self.sigma_filter is not None: # acc = acc - gaussian_filter(acc, self.sigma_filter, mode=self.filter_mode) self._unsharp_mask.unsharp(acc, tmpdat) acc[:] = tmpdat[:] self.postproc(acc, tmpdat) self.doubleflatfield = tmpdat # Handle small values to avoid issues when dividing # self.doubleflatfield[np.abs(self.doubleflatfield) < self._small] = 1. cumath.fabs(self.doubleflatfield, out=acc) acc -= self._small # acc = abs(doubleflatfield) - _small garray.if_positive(acc, self.doubleflatfield, garray.zeros_like(acc) + self._small, out=self.doubleflatfield) if self.writer is not None: self.writer.write(self.doubleflatfield.get(), "double_flatfield") self._computed = True return self.doubleflatfield def get_double_flatfield(self, radios=None, compute=False): """ Get the double flat field or a subregion of it. Parameters ---------- radios: array, optional Input radios chunk compute: bool, optional Whether to compute the double flatfield anyway even if a dump file exists. 
""" if self.reader is None: if radios is None: raise ValueError("result_url was not provided. Please provide 'radios' to this function") return self.compute_double_flatfield(radios) if radios is not None and compute: res = self.compute_double_flatfield(radios) else: res = self._load_dff_dump() res = garray.to_gpu(res) self._computed = True return res def apply_double_flatfield(self, radios): """ Apply the "double flatfield" filter on a chunk of radios. The processing is done in-place ! """ check_shape(radios.shape, self.radios_shape, "radios") dff = self.get_double_flatfield(radios=radios) for i in range(self.n_angles): radios[i] /= dff return radios ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/preproc/double_flatfield_variable_region.py0000644000175000017500000000434700000000000023701 0ustar00pierrepierrefrom .double_flatfield import ( DoubleFlatField, DoubleFlatField, check_shape, get_2D_3D_shape, ) from ..misc.binning import get_binning_function class DoubleFlatFieldVariableRegion(DoubleFlatField): def __init__( self, shape, result_url=None, binning_x=None, binning_z=None, detector_corrector=None, ): """This class provides the division by the double flat field. At variance with the standard class, it store as member the whole field, and performs the division by the proper region according to the positionings of the processed radios which is passed by the argument array sub_regions_per_radio to the method apply_double_flatfield_for_sub_regions """ self.radios_shape = get_2D_3D_shape(shape) self.n_angles = self.radios_shape[0] self.shape = self.radios_shape[1:] self._init_filedump(result_url, None, detector_corrector) data = self._load_dff_full_dump() if (binning_z, binning_x) != (1, 1): print(" (binning_z, binning_x) ", (binning_z, binning_x)) binning_function = get_binning_function((binning_z, binning_x)) if binning_function is None: raise NotImplementedError(f"Binning factor for {(binning_z, binning_x)} is not implemented yet") self.data = binning_function(data) else: self.data = data def _load_dff_full_dump(self): res = self.reader.get_data(self.result_url) if self.detector_corrector is not None: self.detector_corrector.set_full_transformation() res = self.detector_corrector.transform(res, do_full=True) return res def apply_double_flatfield_for_sub_regions(self, radios, sub_regions_per_radio): """ Apply the "double flatfield" filter on a chunk of radios. The processing is done in-place ! 
""" my_double_ff = self.data for i in range(radios.shape[0]): s_x, e_x, s_y, e_y = sub_regions_per_radio[i] dff = my_double_ff[s_y:e_y, s_x:e_x] check_shape(radios[i].shape, dff.shape, "radios") radios[i] /= dff return radios ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1679996432.0 nabu-2023.1.1/nabu/preproc/flatfield.py0000644000175000017500000004540200000000000017134 0ustar00pierrepierrefrom multiprocessing.pool import ThreadPool from bisect import bisect_left import numpy as np from ..io.reader import load_images_from_dataurl_dict from ..utils import check_supported, get_num_threads class FlatFieldArrays: """ A class for flat-field normalization """ # the variable below will be True for the derived class # which is taylored for to helical case _full_shape = False _supported_interpolations = ["linear", "nearest"] def __init__( self, radios_shape: tuple, flats, darks, radios_indices=None, interpolation: str = "linear", distortion_correction=None, nan_value=1.0, radios_srcurrent=None, flats_srcurrent=None, n_threads=None, ): """ Initialize a flat-field normalization process. Parameters ---------- radios_shape: tuple A tuple describing the shape of the radios stack, in the form `(n_radios, n_z, n_x)`. flats: dict Dictionary where each key is the flat index, and the value is a numpy.ndarray of the flat image. darks: dict Dictionary where each key is the dark index, and the value is a numpy.ndarray of the dark image. radios_indices: array of int, optional Array containing the radios indices in the scan. `radios_indices[0]` is the index of the first radio, and so on. interpolation: str, optional Interpolation method for flat-field. See below for more details. distortion_correction: DistortionCorrection, optional A DistortionCorrection object. If provided, it is used to correct flat distortions based on each radio. nan_value: float, optional Which float value is used to replace nan/inf after flat-field. radios_srcurrent: array, optional Array with the same shape as radios_indices. Each item contains the synchrotron electric current. If not None, normalization with current is applied. Please refer to "Notes" for more information on this normalization. flats_srcurrent: array, optional Array with the same length as "flats". Each item is a measurement of the synchrotron electric current for the corresponding flat. The items must be ordered in the same order as the flats indices (`flats.keys()`). This parameter must be used along with 'radios_srcurrent'. Please refer to "Notes" for more information on this normalization. n_threads: int or None, optional Number of threads to use for flat-field correction. Default is to use half the threads. Important ---------- `flats` and `darks` are expected to be a dictionary with integer keys (the flats/darks indices) and numpy array values. You can use the following helper functions: `nabu.io.reader.load_images_from_dataurl_dict` and `nabu.io.utils.create_dict_of_indices` Notes ------ Usually, when doing a scan, only one or a few darks/flats are acquired. However, the flat-field normalization has to be performed on each radio, although incoming beam can fluctuate between projections. The usual way to overcome this is to interpolate between flats. If interpolation="nearest", the first flat is used for the first radios subset, the second flat is used for the second radios subset, and so on. If interpolation="linear", the normalization is done as a linear function of the radio index. 
The normalization with synchrotron electric current is done as follows. Let s = sr/sr_max denote the ratio between current and maximum current, D be the dark-current frame, and X' be the normalized frame. Then: srcurrent_normalization(X) = X' = (X - D)/s + D flatfield_normalization(X') = (X' - D)/(F' - D) = (X - D) / (F - D) * sF/sX So current normalization boils down to a scalar multiplication after flat-field. """ if self._full_shape: # this is never going to happen in this base class. But in the derived class for helical # which needs to keep the full shape if radios_indices is not None: radios_shape = (len(radios_indices),) + radios_shape[1:] self._set_parameters(radios_shape, radios_indices, interpolation, nan_value) self._set_flats_and_darks(flats, darks) self._precompute_flats_indices_weights() self._configure_srcurrent_normalization(radios_srcurrent, flats_srcurrent) self.distortion_correction = distortion_correction self.n_threads = min(1, get_num_threads(n_threads) // 2) def _set_parameters(self, radios_shape, radios_indices, interpolation, nan_value): self._set_radios_shape(radios_shape) if radios_indices is None: radios_indices = np.arange(0, self.n_radios, dtype=np.int32) else: radios_indices = np.array(radios_indices, dtype=np.int32) self._check_radios_and_indices_congruence(radios_indices) self.radios_indices = radios_indices self.interpolation = interpolation check_supported(interpolation, self._supported_interpolations, "Interpolation mode") self.nan_value = nan_value self._radios_idx_to_pos = dict(zip(self.radios_indices, np.arange(self.radios_indices.size))) def _set_radios_shape(self, radios_shape): if len(radios_shape) == 2: self.radios_shape = (1,) + radios_shape elif len(radios_shape) == 3: self.radios_shape = radios_shape else: raise ValueError("Expected radios to have 2 or 3 dimensions") n_radios, n_z, n_x = self.radios_shape self.n_radios = n_radios self.n_angles = n_radios self.shape = (n_z, n_x) def _set_flats_and_darks(self, flats, darks): self._check_frames(flats, "flats", 1, 9999) self.n_flats = len(flats) self.flats = flats self._sorted_flat_indices = sorted(self.flats.keys()) if self._full_shape: # this is never going to happen in this base class. 
But in the derived class for helical # which needs to keep the full shape self.shape = flats[self._sorted_flat_indices[0]].shape self._flat2arrayidx = dict(zip(self._sorted_flat_indices, np.arange(self.n_flats))) self.flats_arr = np.zeros((self.n_flats,) + self.shape, "f") for i, idx in enumerate(self._sorted_flat_indices): self.flats_arr[i] = self.flats[idx] self._check_frames(darks, "darks", 1, 1) self.darks = darks self.n_darks = len(darks) self._sorted_dark_indices = sorted(self.darks.keys()) self._dark = None def _check_frames(self, frames, frames_type, min_frames_required, max_frames_supported): n_frames = len(frames) if n_frames < min_frames_required: raise ValueError("Need at least %d %s" % (min_frames_required, frames_type)) if n_frames > max_frames_supported: raise ValueError( "Flat-fielding with more than %d %s is not supported" % (max_frames_supported, frames_type) ) self._check_frame_shape(frames, frames_type) def _check_frame_shape(self, frames, frames_type): for frame_idx, frame in frames.items(): if frame.shape != self.shape: raise ValueError( "Invalid shape for %s %s: expected %s, but got %s" % (frames_type, frame_idx, str(self.shape), str(frame.shape)) ) def _check_radios_and_indices_congruence(self, radios_indices): if radios_indices.size != self.n_radios: raise ValueError( "Expected radios_indices to have length %s = n_radios, but got length %d" % (self.n_radios, radios_indices.size) ) def _precompute_flats_indices_weights(self): """ Build two arrays: "indices" and "weights". These arrays contain pre-computed information so that the interpolated flat is obtained with flat_interpolated = weight_prev * flat_prev + weight_next * flat_next where weight_prev, weight_next = weights[2*i], weights[2*i+1] idx_prev, idx_next = indices[2*i], indices[2*i+1] flat_prev, flat_next = flats[idx_prev], flats[idx_next] In words: - If a projection has an index between two flats, the equivalent flat is a linear interpolation between "previous flat" and "next flat". 
- If a projection has the same index as a flat, only this flat is used for normalization (this case normally never occurs, but it's handled in the code) """ def _interp_linear(idx, prev_next): if len(prev_next) == 1: # current index corresponds to an acquired flat weights = (1, 0) f_idx = (self._flat2arrayidx[prev_next[0]], -1) else: prev_idx, next_idx = prev_next delta = next_idx - prev_idx w1 = 1 - (idx - prev_idx) / delta w2 = 1 - (next_idx - idx) / delta weights = (w1, w2) f_idx = (self._flat2arrayidx[prev_idx], self._flat2arrayidx[next_idx]) return f_idx, weights def _interp_nearest(idx, prev_next): if len(prev_next) == 1: # current index corresponds to an acquired flat weights = (1, 0) f_idx = (self._flat2arrayidx[prev_next[0]], -1) else: prev_idx, next_idx = prev_next idx_to_take = prev_idx if abs(idx - prev_idx) < abs(idx - next_idx) else next_idx weights = (1, 0) f_idx = (self._flat2arrayidx[idx_to_take], -1) return f_idx, weights self.flats_idx = np.zeros((self.n_radios, 2), dtype=np.int32) self.flats_weights = np.zeros((self.n_radios, 2), dtype=np.float32) for i, idx in enumerate(self.radios_indices): prev_next = self.get_previous_next_indices(self._sorted_flat_indices, idx) if self.interpolation == "nearest": f_idx, weights = _interp_nearest(idx, prev_next) elif self.interpolation == "linear": f_idx, weights = _interp_linear(idx, prev_next) self.flats_idx[i] = f_idx self.flats_weights[i] = weights # pylint: disable=E1307 def _configure_srcurrent_normalization(self, radios_srcurrent, flats_srcurrent): self.normalize_srcurrent = False if radios_srcurrent is None or flats_srcurrent is None: return radios_srcurrent = np.array(radios_srcurrent) if radios_srcurrent.size != self.n_radios: raise ValueError( "Expected 'radios_srcurrent' to have %d elements but got %d" % (self.n_radios, radios_srcurrent.size) ) flats_srcurrent = np.array(flats_srcurrent) if flats_srcurrent.size != self.n_flats: raise ValueError( "Expected 'flats_srcurrent' to have %d elements but got %d" % (self.n_flats, flats_srcurrent.size) ) self.normalize_srcurrent = True self.radios_srcurrent = radios_srcurrent self.flats_srcurrent = flats_srcurrent self.srcurrent_ratios = np.zeros(self.n_radios, "f") # Flats SRCurrent is obtained with "nearest" interp, to emulate an already-done flats SR current normalization for i, radio_idx in enumerate(self.radios_indices): flat_idx = self.get_nearest_index(self._sorted_flat_indices, radio_idx) flat_srcurrent = self.flats_srcurrent[self._flat2arrayidx[flat_idx]] self.srcurrent_ratios[i] = flat_srcurrent / self.radios_srcurrent[i] @staticmethod def get_previous_next_indices(arr, idx): pos = bisect_left(arr, idx) if pos == len(arr): # outside range return (arr[-1],) if arr[pos] == idx: return (idx,) if pos == 0: return (arr[0],) return arr[pos - 1], arr[pos] @staticmethod def get_nearest_index(arr, idx): pos = bisect_left(arr, idx) if pos == len(arr) or arr[pos] == idx: return arr[-1] return arr[pos - 1] if idx - arr[pos - 1] < arr[pos] - idx else arr[pos] @staticmethod def interp(pos, indices, weights, array, slice_y=slice(None, None), slice_x=slice(None, None)): """ Interpolate between two values. 
The interpolator consists in pre-computed arrays such that prev, next = indices[pos] w1, w2 = weights[pos] interpolated_value = w1 * array[prev] + w2 * array[next] """ prev_idx = indices[pos, 0] next_idx = indices[pos, 1] if slice_y != slice(None, None) or slice_x != slice(None, None): w1 = weights[pos, 0][slice_y, slice_x] w2 = weights[pos, 1][slice_y, slice_x] else: w1 = weights[pos, 0] w2 = weights[pos, 1] if next_idx == -1: val = array[prev_idx] else: val = w1 * array[prev_idx] + w2 * array[next_idx] return val def get_flat(self, pos, dtype=np.float32, slice_y=slice(None, None), slice_x=slice(None, None)): flat = self.interp(pos, self.flats_idx, self.flats_weights, self.flats_arr, slice_y=slice_y, slice_x=slice_x) if flat.dtype != dtype: flat = np.ascontiguousarray(flat, dtype=dtype) return flat def get_dark(self): if self._dark is None: first_dark_idx = self._sorted_dark_indices[0] dark = np.ascontiguousarray(self.darks[first_dark_idx], dtype=np.float32) self._dark = dark return self._dark def remove_invalid_values(self, img): if self.nan_value is None: return invalid_mask = np.logical_not(np.isfinite(img)) img[invalid_mask] = self.nan_value def normalize_radios(self, radios): """ Apply a flat-field normalization, with the current parameters, to a stack of radios. The processing is done in-place, meaning that the radios content is overwritten. Parameters ----------- radios: numpy.ndarray Radios chunk """ do_flats_distortion_correction = self.distortion_correction is not None dark = self.get_dark() def apply_flatfield(i): radio_data = radios[i] radio_data -= dark flat = self.get_flat(i) flat = flat - dark if do_flats_distortion_correction: flat = self.distortion_correction.estimate_and_correct(flat, radio_data) np.divide(radio_data, flat, out=radio_data) self.remove_invalid_values(radio_data) if self.n_threads > 2: with ThreadPool(self.n_threads) as tp: tp.map(apply_flatfield, range(self.n_radios)) else: for i in range(self.n_radios): apply_flatfield(i) if self.normalize_srcurrent: radios *= self.srcurrent_ratios[:, np.newaxis, np.newaxis] return radios def normalize_single_radio( self, radio, radio_idx, dtype=np.float32, slice_y=slice(None, None), slice_x=slice(None, None) ): """ Apply a flat-field normalization to a single projection image. """ dark = self.get_dark()[slice_y, slice_x] radio -= dark radio_pos = self._radios_idx_to_pos[radio_idx] flat = self.get_flat(radio_pos, dtype=dtype, slice_y=slice_y, slice_x=slice_x) flat = flat - dark if self.distortion_correction is not None: flat = self.distortion_correction.estimate_and_correct(flat, radio) radio /= flat if self.normalize_srcurrent: radio *= self.srcurrent_ratios[radio_pos] self.remove_invalid_values(radio) return radio FlatField = FlatFieldArrays class FlatFieldDataUrls(FlatField): def __init__( self, radios_shape: tuple, flats: dict, darks: dict, radios_indices=None, interpolation: str = "linear", distortion_correction=None, nan_value=1.0, radios_srcurrent=None, flats_srcurrent=None, **chunk_reader_kwargs, ): """ Initialize a flat-field normalization process with DataUrls. Parameters ---------- radios_shape: tuple A tuple describing the shape of the radios stack, in the form `(n_radios, n_z, n_x)`. flats: dict Dictionary where the key is the flat index, and the value is a silx.io.DataUrl pointing to the flat. darks: dict Dictionary where the key is the dark index, and the value is a silx.io.DataUrl pointing to the dark. radios_indices: array, optional Array containing the radios indices. 
`radios_indices[0]` is the index of the first radio, and so on. interpolation: str, optional Interpolation method for flat-field. See below for more details. distortion_correction: DistortionCorrection, optional A DistortionCorrection object. If provided, it is used to correct flat distortions based on each radio. nan_value: float, optional Which float value is used to replace nan/inf after flat-field. Other Parameters ---------------- The other named parameters are passed to ChunkReader(). Please read its documentation for more information. Notes ------ Usually, when doing a scan, only one or a few darks/flats are acquired. However, the flat-field normalization has to be performed on each radio, although incoming beam can fluctuate between projections. The usual way to overcome this is to interpolate between flats. If interpolation="nearest", the first flat is used for the first radios subset, the second flat is used for the second radios subset, and so on. If interpolation="linear", the normalization is done as a linear function of the radio index. """ flats_arrays_dict = load_images_from_dataurl_dict(flats, **chunk_reader_kwargs) darks_arrays_dict = load_images_from_dataurl_dict(darks, **chunk_reader_kwargs) super().__init__( radios_shape, flats_arrays_dict, darks_arrays_dict, radios_indices=radios_indices, interpolation=interpolation, distortion_correction=distortion_correction, nan_value=nan_value, radios_srcurrent=radios_srcurrent, flats_srcurrent=flats_srcurrent, ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/preproc/flatfield_cuda.py0000644000175000017500000001247600000000000020135 0ustar00pierrepierrefrom typing import Union import numpy as np import pycuda.gpuarray as garray from ..preproc.flatfield import FlatFieldArrays from ..cuda.kernel import CudaKernel from ..utils import get_cuda_srcfile from ..io.reader import load_images_from_dataurl_dict class CudaFlatFieldArrays(FlatFieldArrays): def __init__( self, radios_shape: tuple, flats: dict, darks: dict, radios_indices=None, interpolation: str = "linear", distortion_correction=None, nan_value=1.0, radios_srcurrent=None, flats_srcurrent=None, cuda_options: Union[dict, None] = None, ): """ Initialize a flat-field normalization CUDA process. Please read the documentation of nabu.preproc.flatfield.FlatField for help on the parameters. 
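Examples
--------
A minimal usage sketch (illustrative shapes; an active Cuda context is assumed, e.g. via ``pycuda.autoinit``):

    import numpy as np
    import pycuda.gpuarray as garray
    from nabu.preproc.flatfield_cuda import CudaFlatField

    img_shape = (256, 256)
    flats = {0: np.ones(img_shape, "f"), 100: np.ones(img_shape, "f")}
    darks = {0: np.zeros(img_shape, "f")}
    radios_shape = (50,) + img_shape
    ff = CudaFlatField(radios_shape, flats, darks)
    d_radios = garray.to_gpu(np.ones(radios_shape, "f"))
    ff.normalize_radios(d_radios)  # in-place normalization on the GPU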
""" # if distortion_correction is not None: raise NotImplementedError("Flats distortion correction is not implemented with the Cuda backend") # super().__init__( radios_shape, flats, darks, radios_indices=radios_indices, interpolation=interpolation, distortion_correction=distortion_correction, radios_srcurrent=radios_srcurrent, flats_srcurrent=flats_srcurrent, nan_value=nan_value, ) self._set_cuda_options(cuda_options) self._init_cuda_kernels() self._load_flats_and_darks_on_gpu() def _set_cuda_options(self, user_cuda_options): self.cuda_options = {"device_id": None, "ctx": None, "cleanup_at_exit": None} if user_cuda_options is None: user_cuda_options = {} self.cuda_options.update(user_cuda_options) def _init_cuda_kernels(self): # TODO if self.interpolation != "linear": raise ValueError("Interpolation other than linar is not yet implemented in the cuda back-end") # self._cuda_fname = get_cuda_srcfile("flatfield.cu") options = [ "-DN_FLATS=%d" % self.n_flats, "-DN_DARKS=%d" % self.n_darks, ] if self.nan_value is not None: options.append("-DNAN_VALUE=%f" % self.nan_value) self.cuda_kernel = CudaKernel( "flatfield_normalization", self._cuda_fname, signature="PPPiiiPP", options=options ) self._nx = np.int32(self.shape[1]) self._ny = np.int32(self.shape[0]) def _load_flats_and_darks_on_gpu(self): # Flats self.d_flats = garray.zeros((self.n_flats,) + self.shape, np.float32) for i, flat_idx in enumerate(self._sorted_flat_indices): self.d_flats[i].set(np.ascontiguousarray(self.flats[flat_idx], dtype=np.float32)) # Darks self.d_darks = garray.zeros((self.n_darks,) + self.shape, np.float32) for i, dark_idx in enumerate(self._sorted_dark_indices): self.d_darks[i].set(np.ascontiguousarray(self.darks[dark_idx], dtype=np.float32)) self.d_darks_indices = garray.to_gpu(np.array(self._sorted_dark_indices, dtype=np.int32)) # Indices self.d_flats_indices = garray.to_gpu(self.flats_idx) self.d_flats_weights = garray.to_gpu(self.flats_weights) def normalize_radios(self, radios): """ Apply a flat-field correction, with the current parameters, to a stack of radios. Parameters ----------- radios_shape: `pycuda.gpuarray.GPUArray` Radios chunk. 
""" if not (isinstance(radios, garray.GPUArray)): raise ValueError("Expected a pycuda.gpuarray (got %s)" % str(type(radios))) if radios.dtype != np.float32: raise ValueError("radios must be in float32 dtype (got %s)" % str(radios.dtype)) if radios.shape != self.radios_shape: raise ValueError("Expected radios shape = %s but got %s" % (str(self.radios_shape), str(radios.shape))) self.cuda_kernel( radios, self.d_flats, self.d_darks, self._nx, self._ny, np.int32(self.n_radios), self.d_flats_indices, self.d_flats_weights, ) if self.normalize_srcurrent: for i in range(self.n_radios): radios[i] *= self.srcurrent_ratios[i] return radios CudaFlatField = CudaFlatFieldArrays class CudaFlatFieldDataUrls(CudaFlatField): def __init__( self, radios_shape: tuple, flats: dict, darks: dict, radios_indices=None, interpolation: str = "linear", distortion_correction=None, nan_value=1.0, radios_srcurrent=None, flats_srcurrent=None, cuda_options: Union[dict, None] = None, **chunk_reader_kwargs, ): flats_arrays_dict = load_images_from_dataurl_dict(flats, **chunk_reader_kwargs) darks_arrays_dict = load_images_from_dataurl_dict(darks, **chunk_reader_kwargs) super().__init__( radios_shape, flats_arrays_dict, darks_arrays_dict, radios_indices=radios_indices, interpolation=interpolation, distortion_correction=distortion_correction, radios_srcurrent=radios_srcurrent, flats_srcurrent=flats_srcurrent, cuda_options=cuda_options, ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1679996432.0 nabu-2023.1.1/nabu/preproc/flatfield_variable_region.py0000644000175000017500000000612300000000000022341 0ustar00pierrepierreimport numpy as np from .flatfield import FlatFieldArrays, load_images_from_dataurl_dict, check_supported class FlatFieldArraysVariableRegion(FlatFieldArrays): _full_shape = True def _check_frame_shape(self, frames, frames_type): # in helical the flat is the whole one and its shape does not necesseraly match the smaller frames. # Therefore no check is done to allow this. pass def _check_radios_and_indices_congruence(self, radios_indices): """At variance with parent class, preprocesing is done with on a fraction of the radios, whose length may vary. So we dont enforce here that the lenght is always the same """ pass def _normalize_radios(self, radios, sub_indexes, sub_regions_per_radio): """ Apply a flat-field normalization, with the current parameters, to a stack of radios. The processing is done in-place, meaning that the radios content is overwritten. 
""" if len(sub_regions_per_radio) != len(sub_indexes): message = f""" The length of sub_regions_per_radio,which is {len(sub_regions_per_radio)} , does not correspond to the length of sub_indexes which is {len(sub_indexes)} """ raise ValueError(message) do_flats_distortion_correction = self.distortion_correction is not None whole_dark = self.get_dark() for i, (idx, sub_r) in enumerate(zip(sub_indexes, sub_regions_per_radio)): start_x, end_x, start_y, end_y = sub_r slice_x = slice(start_x, end_x) slice_y = slice(start_y, end_y) self.normalize_single_radio(radios[i], idx, dtype=np.float32, slice_y=slice_y, slice_x=slice_x) return radios class FlatFieldDataVariableRegionUrls(FlatFieldArraysVariableRegion): def __init__( self, radios_shape: tuple, flats: dict, darks: dict, radios_indices=None, interpolation: str = "linear", distortion_correction=None, nan_value=1.0, radios_srcurrent=None, flats_srcurrent=None, **chunk_reader_kwargs, ): flats_arrays_dict = load_images_from_dataurl_dict(flats, **chunk_reader_kwargs) darks_arrays_dict = load_images_from_dataurl_dict(darks, **chunk_reader_kwargs) _flats_indexes = list(flats_arrays_dict.keys()) _flats_indexes.sort() self.flats_indexes = np.array(_flats_indexes) self.flats_stack = np.array([flats_arrays_dict[i] for i in self.flats_indexes], "f") flats_arrays_dict = dict([[indx, flat] for indx, flat in zip(self.flats_indexes, self.flats_stack)]) super().__init__( radios_shape, flats_arrays_dict, darks_arrays_dict, radios_indices=radios_indices, interpolation=interpolation, distortion_correction=distortion_correction, nan_value=nan_value, radios_srcurrent=radios_srcurrent, flats_srcurrent=flats_srcurrent, ) self._sorted_flat_indices = np.array(self._sorted_flat_indices, "i") ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678380095.0 nabu-2023.1.1/nabu/preproc/phase.py0000644000175000017500000004027400000000000016304 0ustar00pierrepierrefrom math import pi from bisect import bisect import numpy as np from ..utils import generate_powers, get_decay, check_supported, get_num_threads, deprecation_warning # from .ctf import CTFPhaseRetrieval # def lmicron_to_db(Lmicron, energy, distance): """ Utility to convert the "Lmicron" parameter of PyHST to a value of delta/beta. Parameters ----------- Lmicron: float Length in microns, values of the parameter "PAGANIN_Lmicron" in PyHST2 parameter file. energy: float Energy in keV. distance: float Sample-detector distance in microns Notes -------- The conversion is done using the formula .. math:: L^2 = \\pi \\lambda D \\frac{\\delta}{\\beta} """ L2 = Lmicron**2 wavelength = 1.23984199e-3 / energy return L2 / (pi * wavelength * distance) class PaganinPhaseRetrieval: available_padding_modes = ["zeros", "mean", "edge", "symmetric", "reflect"] powers = generate_powers() def __init__( self, shape, distance=0.5, energy=20, delta_beta=250.0, pixel_size=1e-6, padding="edge", margin=None, use_rfft=True, use_R2C=None, fftw_num_threads=None, ): """ Paganin Phase Retrieval for an infinitely distant point source. Formula (10) in [1]. Parameters ---------- shape: int or tuple Shape of each radio, in the format (num_rows, num_columns), i.e (size_vertical, size_horizontal). If an integer is provided, the shape is assumed to be square. distance : float, optional Propagation distance in meters. energy : float, optional Energy in keV. delta_beta: float, optional delta/beta ratio, where n = (1 - delta) + i*beta is the complex refractive index of the sample. 
pixel_size : float, optional Detector pixel size in meters. Default is 1e-6 (one micron) padding : str, optional Padding method. Available are "zeros", "mean", "edge", "sym", "reflect". Default is "edge". Please refer to the "Padding" section below for more details. margin: tuple, optional The user may provide integers values U, D, L, R as a tuple under the form ((U, D), (L, R)) (same syntax as numpy.pad()). The resulting filtered radio will have a size equal to (size_vertic - U - D, size_horiz - L - R). These values serve to create a "margin" for the filtering process, where U, D, L R are the margin of the Up, Down, Left and Right part, respectively. The filtering is done on a subset of the input radio. The subset size is (Nrows - U - D, Ncols - R - L). The margins is used to do the padding for the rest of the padded array. For example in one dimension, where ``padding="edge"``:: <------------------------------ padded_size ---------------------------> [padding=edge | padding=data | radio data | padding=data | padding=edge] <------ N2 ---><----- L -----><- (N-L-R)--><----- R -----><----- N2 ---> Some or all the values U, D, L, R can be 0. In this case, the padding of the parts related to the zero values will fall back to the one of "padding" parameter. For example, if padding="edge" and L, R are 0, then the left and right parts will be padded with the edges, while the Up and Down parts will be padded using the the user-provided margins of the radio, and the final data will have shape (Nrows - U - D, Ncols). Some or all the values U, D, L, R can be the string "auto". In this case, the values of U, D, L, R are automatically computed as a function of the Paganin filter width. use_rfft: bool, optional Whether to use Real-to-Complex (R2C) transform instead of standard Complex-to-Complex transform, providing better performances use_R2C: bool, optional DEPRECATED, use use_rfft instead fftw_num_threads: bool or None or int, optional Whether to use FFTW for speeding up FFT. Default is to use all available threads. You can pass a negative number to use N - fftw_num_threads cores. Important ---------- Mind the units! Distance and pixel size are in meters, and energy is in keV. Notes ------ **Padding methods** The phase retrieval is a convolution done in Fourier domain using FFT, so the Fourier transform size has to be at least twice the size of the original data. Mathematically, the data should be padded with zeros before being Fourier transformed. However, in practice, this can lead to artefacts at the edges (Gibbs effect) if the data does not go to zero at the edges. Apart from applying an apodization (Hamming, Blackman, etc), a common strategy to avoid these artefacts is to pad the data. In tomography reconstruction, this is usually done by replicating the last(s) value(s) of the edges ; but one can think of other methods: - "zeros": the data is simply padded with zeros. - "mean": the upper side of extended data is padded with the mean of the first row, the lower side with the mean of the last row, etc. - "edge": the data is padded by replicating the edges. This is the default mode. - "sym": the data is padded by mirroring the data with respect to its edges. See ``numpy.pad()``. - "reflect": the data is padded by reflecting the data with respect to its edges, including the edges. See ``numpy.pad()``. **Formulas** The radio is divided, in the Fourier domain, by the original "Paganin filter" `[1]`. .. math:: F = 1 + \\frac{\\delta}{\\beta} \\lambda D \\pi |k|^2 where k is the wave vector. 
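For instance, the PyHST2 parameter ``PAGANIN_Lmicron`` can be converted with ``lmicron_to_db`` (which takes the distance in microns), and the resulting ratio passed to this class. A minimal sketch with illustrative values:

    from nabu.preproc.phase import PaganinPhaseRetrieval, lmicron_to_db

    db = lmicron_to_db(50.0, energy=20.0, distance=0.5e6)  # 0.5 m expressed in microns
    paganin = PaganinPhaseRetrieval((2048, 2048), distance=0.5, energy=20.0,
                                    delta_beta=db, pixel_size=6.5e-7)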
References ----------- [1] D. Paganin Et Al, "Simultaneous phase and amplitude extraction from a single defocused image of a homogeneous object", Journal of Microscopy, Vol 206, Part 1, 2002 """ self._init_parameters(distance, energy, pixel_size, delta_beta, padding) self._calc_shape(shape, margin) # COMPAT. if use_R2C is not None: deprecation_warning("'use_R2C' is replaced with 'use_rfft'", func_name="pag_r2c") # - self._get_fft(use_rfft, fftw_num_threads) self.compute_filter() def _init_parameters(self, distance, energy, pixel_size, delta_beta, padding): self.distance_cm = distance * 1e2 self.distance_micron = distance * 1e6 self.energy_kev = energy self.pixel_size_micron = pixel_size * 1e6 self.delta_beta = delta_beta self.wavelength_micron = 1.23984199e-3 / self.energy_kev self.padding = padding self.padding_methods = { "zeros": self._pad_zeros, "mean": self._pad_mean, "edge": self._pad_edge, "symmetric": self._pad_sym, "reflect": self._pad_reflect, } def _get_fft(self, use_rfft, fftw_num_threads): self.use_rfft = use_rfft self.use_R2C = use_rfft # Compat. fftw_num_threads = get_num_threads(fftw_num_threads) if self.use_rfft: self.fft_func = np.fft.rfft2 self.ifft_func = np.fft.irfft2 else: self.fft_func = np.fft.fft2 self.ifft_func = np.fft.ifft2 self.use_fftw = False if self.use_rfft and (fftw_num_threads > 0): # importing silx.math.fft creates opencl contexts all over the place # because of the silx.opencl.ocl singleton. # So, import silx as late as possible from silx.math.fft.fftw import FFTW, __have_fftw__ if __have_fftw__: self.use_fftw = True self.fftw = FFTW(shape=self.shape_padded, dtype="f", num_threads=fftw_num_threads) self.fft_func = self.fftw.fft self.ifft_func = self.fftw.ifft def _calc_shape(self, shape, margin): if np.isscalar(shape): shape = (shape, shape) else: assert len(shape) == 2 self.shape = shape self._set_margin_value(margin) self._calc_padded_shape() def _set_margin_value(self, margin): self.margin = margin if margin is None: self.shape_inner = self.shape self.use_margin = False self.margin = ((0, 0), (0, 0)) return self.use_margin = True try: ((U, D), (L, R)) = margin except ValueError: raise ValueError("Expected margin in the format ((U, D), (L, R))") for val in [U, D, L, R]: if isinstance(val, str) and val != "auto": raise ValueError("Expected either an integer, or 'auto'") if int(val) != val or val < 0: raise ValueError("Expected positive integers for margin values") self.shape_inner = (self.shape[0] - U - D, self.shape[1] - L - R) def _calc_padded_shape(self): """ Compute the padded shape. If margin = 0, length_padded = next_power(2*length). 
Otherwise : length_padded = next_power(2*(length - margins)) Principle ---------- <--------------------- nx_p ---------------------> | | original data | | < -- Pl - ><-- L -->< -- nx --><-- R --><-- Pr --> <----------- nx0 -----------> Pl, Pr : left/right padding length L, R : left/right margin nx : length of inner data (and length of final result) nx0 : length of original data nx_p : total length of padded data """ n_y, n_x = self.shape_inner n_y0, n_x0 = self.shape n_y_p = self._get_next_power(max(2 * n_y, n_y0)) n_x_p = self._get_next_power(max(2 * n_x, n_x0)) self.shape_padded = (n_y_p, n_x_p) self.data_padded = np.zeros((n_y_p, n_x_p), dtype=np.float64) ((U, D), (L, R)) = self.margin n_y0, n_x0 = self.shape self.pad_top_len = (n_y_p - n_y0) // 2 self.pad_bottom_len = n_y_p - n_y0 - self.pad_top_len self.pad_left_len = (n_x_p - n_x0) // 2 self.pad_right_len = n_x_p - n_x0 - self.pad_left_len def _get_next_power(self, n): """ Given a number, get the closest (upper) number p such that p is a power of 2, 3, 5 and 7. """ idx = bisect(self.powers, n) if self.powers[idx - 1] == n: return n return self.powers[idx] def compute_filter(self): nyp, nxp = self.shape_padded fftfreq = np.fft.rfftfreq if self.use_rfft else np.fft.fftfreq fy = np.fft.fftfreq(nyp, d=self.pixel_size_micron) fx = fftfreq(nxp, d=self.pixel_size_micron) self._coords_grid = np.add.outer(fy**2, fx**2) # k2 = self._coords_grid D = self.distance_micron L = self.wavelength_micron db = self.delta_beta self.paganin_filter = 1.0 / (1 + db * L * D * pi * k2) def pad_with_values(self, data, top_val=0, bottom_val=0, left_val=0, right_val=0): """ Pad the data into `self.padded_data` with values. Parameters ---------- data: numpy.ndarray data (radio) top_val: float or numpy.ndarray, optional Value(s) to fill the top of the padded data with. bottom_val: float or numpy.ndarray, optional Value(s) to fill the bottom of the padded data with. left_val: float or numpy.ndarray, optional Value(s) to fill the left of the padded data with. right_val: float or numpy.ndarray, optional Value(s) to fill the right of the padded data with. """ self.data_padded.fill(0) Pu, Pd = self.pad_top_len, self.pad_bottom_len Pl, Pr = self.pad_left_len, self.pad_right_len self.data_padded[:Pu, :] = top_val self.data_padded[-Pd:, :] = bottom_val self.data_padded[:, :Pl] = left_val self.data_padded[:, -Pr:] = right_val self.data_padded[Pu:-Pd, Pl:-Pr] = data # Transform the data to the FFT layout self.data_padded = np.roll(self.data_padded, (-Pu, -Pl), axis=(0, 1)) def _pad_zeros(self, data): return self.pad_with_values(data, top_val=0, bottom_val=0, left_val=0, right_val=0) def _pad_mean(self, data): """ Pad the data at each border with a different constant value. 
The value depends on the padding size: - On the left, value = mean(first data column) - On the right, value = mean(last data column) - On the top, value = mean(first data row) - On the bottom, value = mean(last data row) """ return self.pad_with_values( data, top_val=np.mean(data[0, :]), bottom_val=np.mean(data[-1, :]), left_val=np.mean(data[:, 0]), right_val=np.mean(data[:, -1]), ) def _pad_numpy(self, data, mode): data_padded = np.pad( data, ((self.pad_top_len, self.pad_bottom_len), (self.pad_left_len, self.pad_right_len)), mode=mode ) # Transform the data to the FFT layout Pu, Pl = self.pad_top_len, self.pad_left_len return np.roll(data_padded, (-Pu, -Pl), axis=(0, 1)) def _pad_edge(self, data): self.data_padded = self._pad_numpy(data, mode="edge") def _pad_sym(self, data): self.data_padded = self._pad_numpy(data, mode="symmetric") def _pad_reflect(self, data): self.data_padded = self._pad_numpy(data, mode="reflect") def pad_data(self, data, padding_method=None): padding_method = padding_method or self.padding check_supported(padding_method, self.available_padding_modes, "padding mode") if padding_method not in self.padding_methods: raise ValueError( "Unknown padding method %s. Available are: %s" % (padding_method, str(list(self.padding_methods.keys()))) ) pad_func = self.padding_methods[padding_method] pad_func(data) return self.data_padded def apply_filter(self, radio, padding_method=None, output=None): self.pad_data(radio, padding_method=padding_method) radio_f = self.fft_func(self.data_padded) radio_f *= self.paganin_filter radio_filtered = self.ifft_func(radio_f).real s0, s1 = self.shape_inner ((U, _), (L, _)) = self.margin if output is None: return radio_filtered[U : U + s0, L : L + s1] else: output[:, :] = radio_filtered[U : U + s0, L : L + s1] return output def lmicron_to_db(self, Lmicron): """ Utility to convert the "Lmicron" parameter of PyHST to a value of delta/beta. Please see the doc of nabu.preproc.phase.lmicron_to_db() """ return lmicron_to_db(Lmicron, self.energy_kev, self.distance_micron) __call__ = apply_filter retrieve_phase = apply_filter def compute_paganin_margin(shape, cutoff=1e3, **pag_kwargs): """ Compute the convolution margin to use when calling PaganinPhaseRetrieval class. 
Parameters ----------- shape: tuple Detector shape in the form (n_z, n_x) """ P = PaganinPhaseRetrieval(shape, **pag_kwargs) ifft_func = np.fft.irfft2 if P.use_rfft else np.fft.ifft2 conv_kernel = ifft_func(P.paganin_filter) vmax = conv_kernel[0, 0] v_margin = get_decay(conv_kernel[:, 0], cutoff=cutoff, vmax=vmax) h_margin = get_decay(conv_kernel[0, :], cutoff=cutoff, vmax=vmax) # If the Paganin filter is very narrow, then the corresponding convolution # kernel is constant, and np.argmax() gives 0 (when it should give the max value) if v_margin == 0: v_margin = shape[0] if h_margin == 0: h_margin = shape[1] return v_margin, h_margin ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678380095.0 nabu-2023.1.1/nabu/preproc/phase_cuda.py0000644000175000017500000001137700000000000017302 0ustar00pierrepierreimport numpy as np from math import sqrt, pi from ..utils import updiv, get_cuda_srcfile, _sizeof, check_supported from .phase import PaganinPhaseRetrieval import pycuda.driver as cuda from pycuda import gpuarray as garray from ..cuda.processing import CudaProcessing from ..cuda.kernel import CudaKernel class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval): supported_paddings = ["zeros", "constant", "edge"] def __init__( self, shape, distance=0.5, energy=20, delta_beta=250.0, pixel_size=1e-6, padding="edge", margin=None, cuda_options=None, fftw_num_threads=None, ): """ Please refer to the documentation of nabu.preproc.phase.PaganinPhaseRetrieval """ padding = self._check_padding(padding) self.cuda_processing = CudaProcessing(**(cuda_options or {})) super().__init__( shape, distance=distance, energy=energy, delta_beta=delta_beta, pixel_size=pixel_size, padding=padding, margin=margin, use_rfft=True, fftw_num_threads=None, ) self._init_gpu_arrays() self._init_fft() self._init_padding_kernel() self._init_mult_kernel() def _check_padding(self, padding): check_supported(padding, self.supported_paddings, "padding") if padding == "zeros": padding = "constant" return padding def _init_gpu_arrays(self): self.d_paganin_filter = garray.to_gpu(np.ascontiguousarray(self.paganin_filter, dtype=np.float32)) # overwrite parent method, don't initialize any FFT plan def _get_fft(self, use_rfft, fftw_num_threads): self.use_rfft = use_rfft self.use_fftw = False def _init_fft(self): # Import has to be done here, otherwise scikit-cuda creates a cuda/cublas context at import from silx.math.fft.cufft import CUFFT # self.cufft = CUFFT(template=self.data_padded.astype("f")) self.d_radio_padded = self.cufft.data_in self.d_radio_f = self.cufft.data_out def _init_padding_kernel(self): kern_signature = {"constant": "Piiiiiiiiffff", "edge": "Piiiiiiii"} self.padding_kernel = CudaKernel( "padding_%s" % self.padding, filename=get_cuda_srcfile("padding.cu"), signature=kern_signature[self.padding], ) Ny, Nx = self.shape Nyp, Nxp = self.shape_padded self.padding_kernel_args = [ self.d_radio_padded, Nx, Ny, Nxp, Nyp, self.pad_left_len, self.pad_right_len, self.pad_top_len, self.pad_bottom_len, ] # TODO configurable constant values if self.padding == "constant": self.padding_kernel_args.extend([0, 0, 0, 0]) def _init_mult_kernel(self): self.cpxmult_kernel = CudaKernel( "inplace_complexreal_mul_2Dby2D", filename=get_cuda_srcfile("ElementOp.cu"), signature="PPii", ) self.cpxmult_kernel_args = [ self.d_radio_f, self.d_paganin_filter, self.shape_padded[1] // 2 + 1, self.shape_padded[0], ] def set_input(self, data): assert data.shape == self.shape assert data.dtype == np.float32 # Rectangular memcopy # 
TODO profile, and if needed include this copy in the padding kernel if isinstance(data, np.ndarray) or isinstance(data, garray.GPUArray): self.d_radio_padded[: self.shape[0], : self.shape[1]] = data[:, :] elif isinstance(data, cuda.DeviceAllocation): # TODO manual memcpy2D raise NotImplementedError("pycuda buffers are not supported yet") else: raise ValueError("Expected either numpy array, pycuda array or pycuda buffer") def get_output(self, output): s0, s1 = self.shape_inner ((U, _), (L, _)) = self.margin if output is None: # copy D2H return self.d_radio_padded[U : U + s0, L : L + s1].get() assert output.shape == self.shape_inner assert output.dtype == np.float32 output[:, :] = self.d_radio_padded[U : U + s0, L : L + s1] return output def apply_filter(self, radio, output=None): self.set_input(radio) self.padding_kernel(*self.padding_kernel_args) self.cufft.fft(self.d_radio_padded, output=self.d_radio_f) self.cpxmult_kernel(*self.cpxmult_kernel_args) self.cufft.ifft(self.d_radio_f, output=self.d_radio_padded) return self.get_output(output) __call__ = apply_filter retrieve_phase = apply_filter ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/preproc/shift.py0000644000175000017500000000654000000000000016317 0ustar00pierrepierrefrom math import floor import numpy as np class VerticalShift: def __init__(self, radios_shape, shifts): """ This class is used when a vertical translation (along the tomography rotation axis) occurred. These translations are meant "per projection" and can be due either to mechanical errors, or can be applied purposefully, with known motor movements, to smear ring artefacts. The object is initialised with an array of shifts: one shift for each projection. A positive shift means that the axis has moved in the positive Z direction. The interpolation is done by taking, for a pixel (y, x), the pixel found at (y + shift, x) in the recorded images. The method apply_vertical_shifts performs the corrections on the radios. Parameters ---------- radios_shape: tuple Shape of the radios chunk, in the form (n_radios, n_y, n_x) shifts: sequence of floats One shift for each projection. Notes ------ During the acquisition, there might be other translations, each of them orthogonal to the rotation axis. - A "horizontal" translation in the detector plane: this is handled directly in the Backprojection operation. - A translation along the beam direction: this one is of no concern for parallel-beam geometry. """ self.radios_shape = radios_shape self.shifts = shifts self._init_interp_coefficients() def _init_interp_coefficients(self): self.interp_infos = [] for s in self.shifts: s0 = int(floor(s)) f = s - s0 self.interp_infos.append([s0, f]) def _check(self, radios, iangles): assert np.min(iangles) >= 0 assert np.max(iangles) < len(self.interp_infos) assert len(iangles) == radios.shape[0] def apply_vertical_shifts(self, radios, iangles, output=None): """ Parameters ---------- radios: a sequence of np.array The input radios. If the `output` parameter is not given, they are modified in-place. iangles: a sequence of integers Must have the same length as radios. It contains, for each radio, the index at which the corresponding shift is found in `self.shifts` (the `shifts` argument given at initialisation of the object). output: a sequence of np.array, optional If given, it will be modified to contain the shifted radios. Must have the same shape as `radios`. 
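Examples
--------
A minimal sketch on synthetic data; the array shapes and shift values below are illustrative only::

    import numpy as np
    from nabu.preproc.shift import VerticalShift

    # 5 radios of 64 rows x 32 columns, with one (sub-pixel) shift per radio
    radios = np.random.rand(5, 64, 32).astype(np.float32)
    shifter = VerticalShift(radios.shape, shifts=[0.0, 0.5, 1.0, 1.5, 2.0])
    # In-place correction (pass output=... to keep the input untouched)
    shifter.apply_vertical_shifts(radios, list(range(5)))
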
""" self._check(radios, iangles) newradio = np.zeros_like(radios[0]) for radio, ia in zip(radios, iangles): newradio[:] = 0 S0, f = self.interp_infos[ia] s0 = S0 if s0 > 0: newradio[:-s0] = radio[s0:] * (1 - f) elif s0 == 0: newradio[:] = radio[s0:] * (1 - f) else: newradio[-s0:] = radio[:s0] * (1 - f) s0 = S0 + 1 if s0 > 0: newradio[:-s0] += radio[s0:] * f elif s0 == 0: newradio[:] += radio[s0:] * f else: newradio[-s0:] += radio[:s0] * f if output is None: radios[ia] = newradio else: output[ia] = newradio ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/preproc/shift_cuda.py0000644000175000017500000000465500000000000017320 0ustar00pierrepierreimport numpy as np from math import floor from .shift import VerticalShift from ..cuda.utils import __has_pycuda__ if __has_pycuda__: import pycuda.gpuarray as garray class CudaVerticalShift(VerticalShift): def __init__(self, radios_shape, shifts): """ Vertical Shifter, Cuda backend. """ super().__init__(radios_shape, shifts) self._init_cuda_arrays() def _init_cuda_arrays(self): interp_infos_arr = np.zeros((len(self.interp_infos), 2), "f") self._d_interp_infos = garray.to_gpu(interp_infos_arr) self._d_radio_tmp = garray.zeros(self.radios_shape[1:], "f") def apply_vertical_shifts(self, radios, iangles, output=None): """ Parameters ---------- radios: 3D pycuda.gpuarray.GPUArray The input radios. If the optional parameter is not given, they are modified in-place iangles: a sequence of integers Must have the same lenght as radios. It contains the index at which the shift is found in `self.shifts` given by `shifts` argument in the initialisation of the object. output: 3D pycuda.gpuarray.GPUArray, optional If given, it will be modified to contain the shifted radios. Must be of the same shape of `radios`. 
""" self._check(radios, iangles) n_z = self.radios_shape[1] for ia in iangles: radio = radios[ia] self._d_radio_tmp.fill(0) S0, f = self.interp_infos[ia] s0 = S0 if s0 > 0: self._d_radio_tmp[:-s0] = radio[s0:] self._d_radio_tmp[:-s0] *= 1 - f elif s0 == 0: self._d_radio_tmp[:] = radio[s0:] self._d_radio_tmp[:] *= 1 - f else: self._d_radio_tmp[-s0:] = radio[:s0] self._d_radio_tmp[-s0:] *= 1 - f s0 = S0 + 1 f = np.float32(f) # "radios[] * f" is out of place but 2D if s0 > 0: if s0 < n_z: self._d_radio_tmp[:-s0] += radio[s0:] * f elif s0 == 0: self._d_radio_tmp[:] += radio[s0:] * f else: self._d_radio_tmp[-s0:] += radio[:s0] * f if output is None: radios[ia, :, :] = self._d_radio_tmp[:] else: output[ia, :, :] = self._d_radio_tmp[:] ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4687333 nabu-2023.1.1/nabu/preproc/tests/0000755000175000017500000000000000000000000015765 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1581878491.0 nabu-2023.1.1/nabu/preproc/tests/__init__.py0000644000175000017500000000000100000000000020065 0ustar00pierrepierre ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/preproc/tests/test_ccd_corr.py0000644000175000017500000000431100000000000021153 0ustar00pierrepierreimport pytest import numpy as np from nabu.utils import median2 as nabu_median_filter from nabu.testutils import get_data from nabu.cuda.utils import get_cuda_context, __has_pycuda__ from nabu.preproc.ccd import CCDFilter if __has_pycuda__: import pycuda.gpuarray as garray from nabu.preproc.ccd_cuda import CudaCCDFilter @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.data = get_data("mri_proj_astra.npz")["data"] cls.data /= cls.data.max() cls.put_hotspots_in_data() if __has_pycuda__: cls.ctx = get_cuda_context() @pytest.mark.usefixtures("bootstrap") class TestCCDFilter: @classmethod def put_hotspots_in_data(cls): # Put 5 hot spots in the data # (row, column, deviation from median) cls.hotspots = [(50, 51, 0.04), (151, 150, 0.08), (202, 303, 0.12), (322, 203, 0.14)] cls.threshold = 0.1 # parameterize ? 
data_medfilt = nabu_median_filter(cls.data) for r, c, deviation_from_median in cls.hotspots: cls.data[r, c] = data_medfilt[r, c] + deviation_from_median def check_detected_hotspots_locations(self, res): diff = self.data - res rows, cols = np.where(diff > 0) hotspots_arr = np.array(self.hotspots) M = hotspots_arr[:, -1] > self.threshold hotspots_rows = hotspots_arr[M, 0] hotspots_cols = hotspots_arr[M, 1] assert np.allclose(hotspots_rows, rows) assert np.allclose(hotspots_cols, cols) def test_median_clip(self): ccd_filter = CCDFilter(self.data.shape, median_clip_thresh=self.threshold) res = np.zeros_like(self.data) res = ccd_filter.median_clip_correction(self.data, output=res) self.check_detected_hotspots_locations(res) @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test") def test_cuda_median_clip(self): d_radios = garray.to_gpu(self.data) cuda_ccd_correction = CudaCCDFilter(d_radios.shape, median_clip_thresh=self.threshold) d_out = garray.zeros_like(d_radios) cuda_ccd_correction.median_clip_correction(d_radios, output=d_out) res = d_out.get() self.check_detected_hotspots_locations(res) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/preproc/tests/test_ctf.py0000644000175000017500000002317500000000000020162 0ustar00pierrepierreimport pytest import numpy as np import scipy.interpolate from scipy.misc import ascent from nabu.testutils import get_data as nabu_get_data from nabu.testutils import __do_long_tests__ from nabu.preproc.flatfield import FlatFieldArrays from nabu.preproc.ccd import CCDFilter from nabu.preproc import ctf from nabu.estimation.distortion import estimate_flat_distortion from nabu.misc.filters import correct_spikes from nabu.preproc.distortion import DistortionCorrection from nabu.cuda.utils import __has_pycuda__, __has_cufft__, get_cuda_context if __has_pycuda__ and __has_cufft__: from nabu.preproc.ctf_cuda import CudaCTFPhaseRetrieval import pycuda.gpuarray as garray @pytest.fixture(scope="class") def bootstrap_TestCtf(request): cls = request.cls cls.abs_tol = 1.0e-4 test_data = nabu_get_data("ctf_tests_data_all_pars.npz") cls.rand_disp_vh = test_data["rh"] ## the dimension number 1 is over holotomo distances, so far our filter is for one distance only cls.rand_disp_vh.shape = [cls.rand_disp_vh.shape[0], cls.rand_disp_vh.shape[2]] cls.dark = test_data["dark"] cls.flats = [test_data["ref0"], test_data["ref1"]] cls.im = test_data["im"] cls.ipro = int(test_data["ipro"]) cls.expected_result = test_data["result"] cls.ref_plain = test_data["ref_plain_float_flat"] cls.flats_n = test_data["refns"] cls.img_shape_vh = test_data["img_shape_vh"] cls.padded_img_shape_vh = test_data["padded_img_shape_vh"] cls.z1_vh = test_data["z1_vh"] cls.z2 = test_data["z2"] cls.pix_size_det = test_data["pix_size_det"] cls.length_scale = test_data["length_scale"] cls.wavelength = test_data["wave_length"] cls.remove_spikes_threshold = test_data["remove_spikes_threshold"] cls.delta_beta = 27 @pytest.mark.usefixtures("bootstrap_TestCtf") class TestCtf: def check_result(self, res, ref, error_message): diff = np.abs(res - ref) diff[diff > np.percentile(diff, 99)] = 0 assert diff.max() < self.abs_tol * (np.abs(ref).mean()), error_message def test_ctf_id16_way(self): """test the ctf phase retrieval. The cft filter, of the CtfFilter class is iniitalised with the geomety informations contained in geo_pars object of the GeoPars class. 
The geometry encompasses the case of an astigmatic wavefront with vertical and horizontal sources which are at distances z1_vh[0] and z1_vh[1] from the object. In the case of parallel geometry, put z1_vh[0] = z1_vh[1] = R where R is a large value (meters). The SI unit system is used, but the same results should be obtained with any homogeneous choice of the distance units. The img_shape is the shape of the images which will be processed. padded_img_shape is an intermediate shape which needs to be larger than the img_shape to avoid border effects due to convolutions. Length scale is an internal parameter which should not affect the result in any way, unless there are serious numerical problems involving very small lengths. You can safely keep the default value. """ geo_pars = ctf.GeoPars( z1_vh=self.z1_vh, z2=self.z2, pix_size_det=self.pix_size_det, length_scale=self.length_scale, wavelength=self.wavelength, ) flats = FlatFieldArrays( [1200] + list(self.img_shape_vh), {0: self.flats[0], 1200: self.flats[1]}, {0: self.dark} ) my_flat = flats.get_flat(self.ipro) my_img = self.im - self.dark my_flat = my_flat - self.dark new_coordinates = estimate_flat_distortion( my_flat, my_img, tile_size=100, interpolation_kind="cubic", padding_mode="edge", correction_spike_threshold=3, ) interpolator = scipy.interpolate.RegularGridInterpolator( (np.arange(my_flat.shape[0]), np.arange(my_flat.shape[1])), my_flat, bounds_error=False, method="linear", fill_value=None, ) my_flat = interpolator(new_coordinates) my_img = my_img / my_flat my_img = correct_spikes(my_img, self.remove_spikes_threshold) my_shift = self.rand_disp_vh[:, self.ipro] ctf_filter = ctf.CtfFilter( self.dark.shape, geo_pars, self.delta_beta, padded_shape=self.padded_img_shape_vh, translation_vh=my_shift, normalize_by_mean=True, lim1=1.0e-5, lim2=0.2, ) phase = ctf_filter.retrieve_phase(my_img) self.check_result( phase, self.expected_result, "retrieved phase and reference result differ beyond the accepted tolerance" ) @pytest.mark.skipif(not (__do_long_tests__), reason="need environment variable NABU_LONG_TESTS=1") def test_ctf_id16_class(self): geo_pars = ctf.GeoPars( z1_vh=self.z1_vh, z2=self.z2, pix_size_det=self.pix_size_det, length_scale=self.length_scale, wavelength=self.wavelength, ) distortion_correction = DistortionCorrection( estimation_method="fft-correlation", estimation_kwargs={ "tile_size": 100, "interpolation_kind": "cubic", "padding_mode": "edge", "correction_spike_threshold": 3.0, }, correction_method="interpn", correction_kwargs={"fill_value": None}, ) flats = FlatFieldArrays( [1200] + list(self.img_shape_vh), {0: self.flats[0], 1200: self.flats[1]}, {0: self.dark}, distortion_correction=distortion_correction, ) # The "correct_spikes" function is numerically unstable (comparison with a float threshold). # If float32 is used for the image, one spike is detected while it is not in the previous test # (although the max difference between the inputs is about 1e-8). # We use float64 data type for the image to make tests pass. 
img = self.im.astype(np.float64) flats.normalize_single_radio(img, self.ipro) img = correct_spikes(img, self.remove_spikes_threshold) shift = self.rand_disp_vh[:, self.ipro] ctf_filter = ctf.CtfFilter( img.shape, geo_pars, self.delta_beta, padded_shape=self.padded_img_shape_vh, translation_vh=shift, normalize_by_mean=True, lim1=1.0e-5, lim2=0.2, ) phase = ctf_filter.retrieve_phase(img) message = "retrieved phase and reference result differ beyond the accepted tolerance" assert np.abs(phase - self.expected_result).max() < self.abs_tol * ( np.abs(self.expected_result).mean() ), message def test_ctf_plain_way(self): geo_pars = ctf.GeoPars( z1_vh=None, z2=self.z2, pix_size_det=self.pix_size_det, length_scale=self.length_scale, wavelength=self.wavelength, ) flatfielder = FlatFieldArrays( [1] + list(self.img_shape_vh), {0: self.flats[0], 1200: self.flats[1]}, {0: self.dark}, radios_indices=[self.ipro], ) spikes_corrector = CCDFilter( self.dark.shape, median_clip_thresh=self.remove_spikes_threshold, abs_diff=True, preserve_borders=True ) img = self.im.astype("f") img = flatfielder.normalize_radios(np.array([img]))[0] img = spikes_corrector.median_clip_correction(img) ctf_args = [img.shape, geo_pars, self.delta_beta] ctf_kwargs = {"padded_shape": self.padded_img_shape_vh, "normalize_by_mean": True, "lim1": 1.0e-5, "lim2": 0.2} ctf_filter = ctf.CtfFilter(*ctf_args, **ctf_kwargs) phase = ctf_filter.retrieve_phase(img) self.check_result(phase, self.ref_plain, "Something wrong with CtfFilter") # Test R2C ctf_numpy = ctf.CtfFilter(*ctf_args, **ctf_kwargs, use_rfft=True) phase_r2c = ctf_numpy.retrieve_phase(img) self.check_result(phase_r2c, self.ref_plain, "Something wrong with CtfFilter-R2C") # Test FFTW ctf_fftw = ctf.CtfFilter(*ctf_args, **ctf_kwargs, use_rfft=True, fftw_num_threads=-1) if ctf_fftw.use_rfft: phase_fftw = ctf_fftw.retrieve_phase(img) self.check_result(phase_r2c, self.ref_plain, "Something wrong with CtfFilter-FFTW") @pytest.mark.skipif(not (__has_pycuda__ and __has_cufft__), reason="pycuda and scikit-cuda") def test_cuda_ctf(self): data = ascent().astype("f") delta_beta = 50.0 energy_kev = 22.0 distance_m = 1.0 pix_size_m = 0.1e-6 geo_pars = ctf.GeoPars(z2=distance_m, pix_size_det=pix_size_m, wavelength=1.23984199e-9 / energy_kev) ctx = get_cuda_context() for normalize in [True, False]: ctf_filter = ctf.CTFPhaseRetrieval( data.shape, geo_pars, delta_beta=delta_beta, normalize_by_mean=normalize, use_rfft=True ) cuda_ctf_filter = CudaCTFPhaseRetrieval( data.shape, geo_pars, delta_beta=delta_beta, use_rfft=True, normalize_by_mean=normalize, ) ref = ctf_filter.retrieve_phase(data) d_data = garray.to_gpu(data) res = cuda_ctf_filter.retrieve_phase(d_data).get() err_max = np.max(np.abs(res - ref)) assert err_max < 1e-2, "Something wrong with retrieve_phase(normalize_by_mean=%s)" % (str(normalize)) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/preproc/tests/test_double_flatfield.py0000644000175000017500000000505100000000000022663 0ustar00pierrepierreimport os.path as path from math import exp import tempfile import numpy as np import pytest from silx.io.url import DataUrl from tomoscan.esrf.mock import MockHDF5 from nabu.io.reader import HDF5Reader from nabu.preproc.double_flatfield import DoubleFlatField from nabu.preproc.double_flatfield_cuda import CudaDoubleFlatField, __has_pycuda__ if __has_pycuda__: import pycuda.gpuarray as garray @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.tmpdir = 
tempfile.TemporaryDirectory() dname = cls.tmpdir.name cls.dname = dname radios = MockHDF5( path.join(dname, "tmp"), 10, n_ini_proj=10, dim=100, n_refs=1, scene="increasing value", ).scan.projections reader = HDF5Reader() cls.radios = [] Rkeys = list(radios.keys()) for k in Rkeys: dataurl = radios[k] data = reader.get_data(dataurl) cls.radios.append(data) cls.radios = np.array(cls.radios) cls.ff_dump_url = DataUrl( file_path=path.join(cls.dname, "dff.h5"), data_path="/entry/double_flatfield/results/data" ) cls.ff_cuda_dump_url = DataUrl( file_path=path.join(cls.dname, "dff_cuda.h5"), data_path="/entry/double_flatfield/results/data" ) golden = 0 for i in range(10): golden += exp(-i) cls.golden = golden / 10 cls.tol = 1e-4 @pytest.mark.usefixtures("bootstrap") class TestDoubleFlatField: def test_dff_numpy(self): dff = DoubleFlatField(self.radios.shape, result_url=self.ff_dump_url) mydf = dff.get_double_flatfield(radios=self.radios) assert path.isfile(dff.result_url.file_path()) dff2 = DoubleFlatField(self.radios.shape, result_url=self.ff_dump_url) mydf2 = dff2.get_double_flatfield(radios=self.radios) assert np.max(np.abs(mydf2 - mydf)) < self.tol assert np.max(np.abs(mydf - self.golden)) < self.tol @pytest.mark.skipif(not (__has_pycuda__), reason="Need pycuda for double flatfield with cuda backend") def test_dff_cuda(self): dff = CudaDoubleFlatField(self.radios.shape, result_url=self.ff_cuda_dump_url) d_radios = garray.to_gpu(self.radios) mydf = dff.get_double_flatfield(radios=d_radios).get() assert path.isfile(dff.result_url.file_path()) dff2 = CudaDoubleFlatField(self.radios.shape, result_url=self.ff_cuda_dump_url) mydf2 = dff2.get_double_flatfield(radios=d_radios).get() assert np.max(np.abs(mydf2 - mydf)) < self.tol assert np.max(np.abs(mydf - self.golden)) < self.tol ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/preproc/tests/test_flatfield.py0000644000175000017500000005571200000000000021342 0ustar00pierrepierrefrom tempfile import mkdtemp import os import numpy as np import pytest from silx.io.url import DataUrl from silx.io import get_data from silx.io.dictdump import dicttoh5 from nabu.cuda.utils import get_cuda_context, __has_pycuda__ from nabu.preproc.flatfield import FlatField, FlatFieldDataUrls if __has_pycuda__: import pycuda.gpuarray as garray from nabu.preproc.flatfield_cuda import CudaFlatFieldDataUrls, CudaFlatField # Flats values should be O(k) so that linear interpolation between flats gives exact results flatfield_tests_cases = { "simple_nearest_interp": { "image_shape": (100, 512), "radios_values": np.arange(10) + 1, "radios_indices": None, "flats_values": [0.5], "flats_indices": [1], "darks_values": [1], "darks_indices": [0], "expected_result": np.arange(0, -2 * 10, -2), }, "two_flats_no_radios_indices": { "image_shape": (100, 512), "radios_values": np.arange(10) + 1, "radios_indices": None, "flats_values": [2, 11], "flats_indices": [0, 9], "darks_values": [1], "darks_indices": [0], "expected_result": np.arange(10) / (np.arange(10) + 1), }, "two_flats_with_radios_indices": { # Type D F R R R F # IDX 0 1 2 3 4 5 # Value 1 4 9 16 25 8 # F_interp 5 6 7 "image_shape": (16, 17), # R_k = (k + 1)**2 "radios_values": [9, 16, 25], "radios_indices": [2, 3, 4], # F_k = k+3 "flats_values": [4, 8], "flats_indices": [1, 5], # D_k = 1 "darks_values": [1], "darks_indices": [0], # Expected normalization result: N_k = k "expected_result": [2, 3, 4], }, "three_flats_srcurrent": { # Type D F R R | R F R R F # IDX 0 1 2 3 | 4 
5 6 7 8 # Value 1 4 9 16 | 25 8 46 67 14 # F_interp 5 6 | 7 10 12 # srCurrent 20 10 10 | 16 32 16 8 16 "image_shape": (16, 17), "radios_values": [9, 16, 25, 46, 67], "radios_indices": [2, 3, 4, 6, 7], "flats_values": [4, 8, 14], "flats_indices": [1, 5, 8], "darks_values": [1], "darks_indices": [0], # "expected_result": [2, 3, 4, 5, 6], # without SR normalization "expected_result": [4, 6, 8, 10, 12], # sr_flat/sr_radio = 2 "radios_srcurrent": [10, 16, 16, 16, 8], # "flats_srcurrent": [20, 18, 6, 4, 2], "flats_srcurrent": [20, 32, 16], }, } def generate_test_flatfield_generalized( image_shape, radios_indices, radios_values, flats_indices, flats_values, darks_indices, darks_values, h5_fname, dtype=np.uint16, ): """ Parameters ----------- image_shape: tuple of int shape of each image. radios_indices: array of int Indices where radios are found in the dataset. radios_values: array of scalars Value for each radio image. Length must be equal to `radios_shape[0]`. flats_indices: array of int Indices where flats are found in the dataset. flats_values: array of scalars Values of flat images. Length must be equal to `len(flats_indices)` darks_indices: array of int Indices where darks are found in the dataset. darks_values: array of scalars Values of dark images. Length must be equal to `len(darks_indices)` Returns ------- radios: numpy.ndarray 3D array with raw radios darks: dict of DataUrls Dictionary where each key is the dark indice, and value is a DataUrl flats: dict of DataUrls Dictionary where each key is the flat indice, and value is a DataUrl """ tempdir = mkdtemp(prefix="nabu_") testffname = os.path.join(tempdir, h5_fname) # Radios radios = np.zeros((len(radios_values),) + image_shape, dtype="f") n_radios = radios.shape[0] for i in range(n_radios): radios[i].fill(radios_values[i]) img_shape = radios.shape[1:] # Flats flats = {} flats_urls = {} for i, flat_idx in enumerate(flats_indices): flats["flats_%06d" % flat_idx] = np.zeros(img_shape, dtype=dtype) + flats_values[i] flats_urls[flat_idx] = DataUrl( file_path=testffname, data_path=str("/flats/flats_%06d" % flat_idx), scheme="silx" ) # Darks darks = {} darks_urls = {} for i, dark_idx in enumerate(darks_indices): darks["darks_%06d" % dark_idx] = np.zeros(img_shape, dtype=dtype) + darks_values[i] darks_urls[dark_idx] = DataUrl( file_path=testffname, data_path=str("/darks/darks_%06d" % dark_idx), scheme="silx" ) dicttoh5(flats, testffname, h5path="/flats", mode="w") dicttoh5(darks, testffname, h5path="/darks", mode="a") return radios, flats_urls, darks_urls @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.tmp_files = [] cls.tmp_dirs = [] cls.n_radios = 10 cls.n_z = 100 cls.n_x = 512 if __has_pycuda__: cls.ctx = get_cuda_context() yield # Tear-down for fname in cls.tmp_files: os.remove(fname) for dname in cls.tmp_dirs: os.rmdir(dname) @pytest.mark.usefixtures("bootstrap") class TestFlatField: def get_test_elements(self, case_name): config = flatfield_tests_cases[case_name] radios_stack, flats_url, darks_url = generate_test_flatfield_generalized( config["image_shape"], config["radios_indices"], config["radios_values"], config["flats_indices"], config["flats_values"], config["darks_indices"], config["darks_values"], "test_ff.h5", ) fname = flats_url[list(flats_url.keys())[0]].file_path() self.tmp_files.append(fname) self.tmp_dirs.append(os.path.dirname(fname)) return radios_stack, flats_url, darks_url, config @staticmethod def check_normalized_radios(radios_corr, expected_values): # must be the same value everywhere in 
the radio std = np.std(np.std(radios_corr, axis=-1), axis=-1) assert np.max(np.abs(std)) < 1e-7 # radios values must be 0, -2, -4, ... assert np.allclose(radios_corr[:, 0, 0], expected_values) def test_flatfield_simple(self): """ Test the flat-field normalization on a radios stack with 1 dark and 1 flat. (I - D)/(F - D) where I = (1, 2, ...), D = 1, F = 0.5 = (0, -2, -4, -6, ...) """ radios_stack, flats_url, darks_url, config = self.get_test_elements("simple_nearest_interp") flatfield = FlatFieldDataUrls(radios_stack.shape, flats_url, darks_url) radios_corr = flatfield.normalize_radios(np.copy(radios_stack)) self.check_normalized_radios(radios_corr, config["expected_result"]) def test_flatfield_simple_subregion(self): """ Same as test_flatfield_simple, but in a vertical subregion of the radios. """ radios_stack, flats_url, darks_url, config = self.get_test_elements("simple_nearest_interp") end_z = 51 radios_chunk = np.copy(radios_stack[:, :end_z, :]) # we only have a chunk in memory. Instantiate the class with the # corresponding subregion to only load the relevant part of dark/flat flatfield = FlatFieldDataUrls( radios_chunk.shape, flats_url, darks_url, sub_region=(None, None, None, end_z), # start_x, end_x, start_z, end_z ) radios_corr = flatfield.normalize_radios(radios_chunk) self.check_normalized_radios(radios_corr, config["expected_result"]) def test_flatfield_linear_interp(self): """ Test flat-field normalization with 1 dark and 2 flats, with linear interpolation between flats. I = 1 2 3 4 5 6 7 8 9 10 D = 1 (one dark) F = 2 11 (two flats) F_i = 2 3 4 5 6 7 8 9 10 11 (interpolated flats) R = 0 .5 .66 .75 .8 .83 .86 = (I-D)/(F-D) = (I-1)/I """ radios_stack, flats_url, darks_url, config = self.get_test_elements("two_flats_no_radios_indices") flatfield = FlatFieldDataUrls(radios_stack.shape, flats_url, darks_url) radios_corr = flatfield.normalize_radios(np.copy(radios_stack)) self.check_normalized_radios(radios_corr, config["expected_result"]) # Test 2: one of the flats is not at the beginning/end # I = 1 2 3 4 5 6 7 8 9 10 # F = 2 11 # F_i = 2 3.8 5.6 7.4 9.2 11 11 11 11 11 # R = 0 .357 .435 .469 .488 .5 .6 .7 .8 .9 flats_url = flats_url.copy() flats_url[5] = flats_url[9] flats_url.pop(9) flatfield = FlatFieldDataUrls(radios_stack.shape, flats_url, darks_url) radios_corr = flatfield.normalize_radios(np.copy(radios_stack)) self.check_normalized_radios( radios_corr, [0.0, 0.35714286, 0.43478261, 0.46875, 0.48780488, 0.5, 0.6, 0.7, 0.8, 0.9] ) @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test") def test_cuda_flatfield(self): """ Test the flat-field with cuda back-end. 
""" radios_stack, flats_url, darks_url, config = self.get_test_elements("two_flats_no_radios_indices") d_radios = garray.to_gpu(radios_stack.astype("f")) cuda_flatfield = CudaFlatFieldDataUrls( d_radios.shape, flats_url, darks_url, ) cuda_flatfield.normalize_radios(d_radios) radios_corr = d_radios.get() self.check_normalized_radios(radios_corr, config["expected_result"]) # Linear interpolation, two flats, one dark def test_twoflats_simple(self): radios, flats, darks, config = self.get_test_elements("two_flats_with_radios_indices") FF = FlatFieldDataUrls(radios.shape, flats, darks, radios_indices=config["radios_indices"]) FF.normalize_radios(radios) self.check_normalized_radios(radios, config["expected_result"]) def _setup_numerical_issue(self): radios, flats, darks, config = self.get_test_elements("two_flats_with_radios_indices") # Retrieve the actual data for radios/darks/flats to use FlatField instead of FlatFieldDataUrl. # Create a setting yielding "0/0": one pixel such that flat==dark and radio==dark for flat_idx, flat_url in flats.items(): flats[flat_idx] = get_data(flat_url) flats[flat_idx][0, 0] = 99 for dark_idx, dark_url in darks.items(): darks[dark_idx] = get_data(dark_url) darks[dark_idx][0, 0] = 99 radios[:, 0, 0] = 99 return radios, flats, darks, config def _check_numerical_issue(self, radios, expected_result, nan_value=None): if nan_value is None: assert np.alltrue(np.logical_not(np.isfinite(radios[:, 0, 0]))), "First pixel should be nan or inf" radios[:, 0, 0] = radios[:, 1, 1] self.check_normalized_radios(radios, expected_result) else: assert np.all(np.isfinite(radios)), "No inf/nan value should be there" assert np.allclose(radios[:, 0, 0], nan_value, atol=1e-7), ( "Handled NaN should have nan_value=%f" % nan_value ) radios[:, 0, 0] = radios[:, 1, 1] self.check_normalized_radios(radios, expected_result) def test_twoflats_numerical_issue(self): """ Same as above, but for the first radio: I==Dark and Flat==Dark For this radio, nan is replaced with 1. 
""" radios, flats, darks, config = self._setup_numerical_issue() radios0 = radios.copy() # FlatField without NaN handling yields NaN and raises RuntimeWarning FF_no_nan_handling = FlatField( radios.shape, flats, darks, radios_indices=config["radios_indices"], nan_value=None ) with pytest.warns(RuntimeWarning): FF_no_nan_handling.normalize_radios(radios) self._check_numerical_issue(radios, config["expected_result"], None) # FlatField with NaN handling nan_value = 50 radios = radios0.copy() FF_with_nan_handling = FlatField( radios.shape, flats, darks, radios_indices=config["radios_indices"], nan_value=nan_value ) with pytest.warns(RuntimeWarning): FF_with_nan_handling.normalize_radios(radios) self._check_numerical_issue(radios, config["expected_result"], nan_value) @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test") def test_cuda_twoflats_numerical_issue(self): """ Same as above, with the Cuda backend """ radios, flats, darks, config = self._setup_numerical_issue() radios0 = radios.copy() d_radios = garray.to_gpu(radios) FF_no_nan_handling = CudaFlatField( radios.shape, flats, darks, radios_indices=config["radios_indices"], nan_value=None ) # In a cuda kernel, no one can hear you scream FF_no_nan_handling.normalize_radios(d_radios) radios = d_radios.get() self._check_numerical_issue(radios, config["expected_result"], None) # FlatField with NaN handling nan_value = 50 d_radios.set(radios0) FF_with_nan_handling = CudaFlatField( radios.shape, flats, darks, radios_indices=config["radios_indices"], nan_value=nan_value ) FF_with_nan_handling.normalize_radios(d_radios) radios = d_radios.get() self._check_numerical_issue(radios, config["expected_result"], nan_value) def test_srcurrent(self): radios, flats, darks, config = self.get_test_elements("three_flats_srcurrent") FF = FlatFieldDataUrls( radios.shape, flats, darks, radios_indices=config["radios_indices"], radios_srcurrent=config["radios_srcurrent"], flats_srcurrent=config["flats_srcurrent"], ) radios_corr = FF.normalize_radios(np.copy(radios)) self.check_normalized_radios(radios_corr, config["expected_result"]) @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test") def test_srcurrent_cuda(self): radios, flats, darks, config = self.get_test_elements("three_flats_srcurrent") d_radios = garray.to_gpu(radios) FF = CudaFlatFieldDataUrls( radios.shape, flats, darks, radios_indices=config["radios_indices"], radios_srcurrent=config["radios_srcurrent"], flats_srcurrent=config["flats_srcurrent"], ) FF.normalize_radios(d_radios) radios_corr = d_radios.get() self.check_normalized_radios(radios_corr, config["expected_result"]) # This test should be closer to the ESRF standard setting. # There are 2 flats, one dark, 4000 radios. # dark : indice=0 value=10 # flat1 : indice=1 value=4202 # flat2 : indice=2102 value=2101 # # The projections have the following indices: # j: 0 1 1998 1999 2000 2001 3999 # idx: [102, 103, ..., 2100, 2101, 2203, 2204, ..., 4200, 4201, 4202] # Notice the gap in the middle. # # The linear interpolation is # flat_i = (n2 - i)/(n2 - n1)*flat_1 + (i - n1)/(n2 - n1)*flat_2 # where n1 and n2 are the indices of flat_1 and flat_2 respectively. # With the above values, we have flat_i = 4203 - i. # # The projections values are dark + i*(flat_i - dark), # so that the normalization norm_i = (proj_i - dark)/(flat_i - dark) gives # # idx 102 103 104 ... # flat 4101 4102 4103 ... # norm 102 103 104 ... 
# class FlatFieldTestDataset: # Parameters shp = (27, 32) n1 = 1 # flat indice 1 n2 = 2102 # flat indice 2 dark_val = 10 darks = {0: np.zeros(shp, "f") + dark_val} flats = {n1: np.zeros(shp, "f") + (n2 - 1) * 2, n2: np.zeros(shp, "f") + n2 - 1} projs_idx = list(range(102, 2102)) + list(range(2203, 4203)) # gap in the middle def __init__(self): self._generate_projections() self._dump_to_h5() self._generate_dataurls() def get_flat_idx(self, proj_idx): flats_idx = sorted(list(self.flats.keys())) if proj_idx <= flats_idx[0]: return (flats_idx[0],) elif proj_idx > flats_idx[0] and proj_idx < flats_idx[1]: return flats_idx else: return (flats_idx[1],) def get_flat(self, idx): flatidx = self.get_flat_idx(idx) if len(flatidx) == 1: flat = self.flats[flatidx[0]] else: nf1, nf2 = flatidx w1 = (nf2 - idx) / (nf2 - nf1) flat = w1 * self.flats[nf1] + (1 - w1) * self.flats[nf2] return flat def _generate_projections(self): self.projs_data = np.zeros((len(self.projs_idx),) + self.shp, "f") self.projs = {} for i, proj_idx in enumerate(self.projs_idx): flat = self.get_flat(proj_idx) proj_val = self.dark_val + proj_idx * (flat[0, 0] - self.dark_val) self.projs[str(proj_idx)] = np.zeros(self.shp, "f") + proj_val self.projs_data[i] = self.projs[str(proj_idx)] def _dump_to_h5(self): self.tempdir = mkdtemp(prefix="nabu_") self.fname = os.path.join(self.tempdir, "projs_flats.h5") dicttoh5( { "projs": self.projs, "flats": {str(k): v for k, v in self.flats.items()}, "darks": {str(k): v for k, v in self.darks.items()}, }, h5file=self.fname, ) def _generate_dataurls(self): self.flats_urls = {} for idx in self.flats.keys(): self.flats_urls[int(idx)] = DataUrl(file_path=self.fname, data_path="/flats/%d" % idx) self.darks_urls = {} for idx in self.darks.keys(): self.darks_urls[int(idx)] = DataUrl(file_path=self.fname, data_path="/darks/0") @pytest.fixture(scope="class") def bootstraph5(request): cls = request.cls cls.dataset = FlatFieldTestDataset() n1, n2 = cls.dataset.n1, cls.dataset.n2 # Interpolation function cls._weight1 = lambda i: (n2 - i) / (n2 - n1) cls.tol = 5e-4 cls.tol_std = 1e-3 yield # tear-down os.remove(cls.dataset.fname) os.rmdir(cls.dataset.tempdir) @pytest.mark.usefixtures("bootstraph5") class TestFlatFieldH5: def check_normalization(self, projs): # Check that each projection is filled with the same values std_projs = np.std(projs, axis=(-2, -1)) assert np.max(np.abs(std_projs)) < self.tol_std # Check that the normalized radios are equal to 102, 103, 104, ... errs = projs[:, 0, 0] - self.dataset.projs_idx assert np.max(np.abs(errs)) < self.tol, "Something wrong with flat-field normalization" def test_flatfield(self): flatfield = FlatFieldDataUrls( self.dataset.projs_data.shape, self.dataset.flats_urls, self.dataset.darks_urls, radios_indices=self.dataset.projs_idx, interpolation="linear", ) projs = np.copy(self.dataset.projs_data) flatfield.normalize_radios(projs) self.check_normalization(projs) @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test") def test_cuda_flatfield(self): d_projs = garray.to_gpu(self.dataset.projs_data) cuda_flatfield = CudaFlatFieldDataUrls( self.dataset.projs_data.shape, self.dataset.flats_urls, self.dataset.darks_urls, radios_indices=self.dataset.projs_idx, ) cuda_flatfield.normalize_radios(d_projs) projs = d_projs.get() self.check_normalization(projs) # # Another test with more than two flats. 
# # Here we have # # F_i = i + 2 # R_i = i*(F_i - 1) + 1 # N_i = (R_i - D)/(F_i - D) = i*(F_i - 1)/( F_i - 1) = i # def generate_test_flatfield(n_radios, radio_shape, flat_interval, h5_fname): radios = np.zeros((n_radios,) + radio_shape, "f") dark_data = np.ones(radios.shape[1:], "f") tempdir = mkdtemp(prefix="nabu_") testffname = os.path.join(tempdir, h5_fname) flats = {} flats_urls = {} # F_i = i + 2 # R_i = i*(F_i - 1) + 1 # N_i = (R_i - D)/(F_i - D) = i*(F_i - 1)/( F_i - 1) = i for i in range(n_radios): f_i = i + 2 if (i % flat_interval) == 0: flats["flats_%06d" % i] = np.zeros(radio_shape, "f") + f_i flats_urls[i] = DataUrl(file_path=testffname, data_path=str("/flats/flats_%06d" % i), scheme="silx") radios[i] = i * (f_i - 1) + 1 dark = {"dark_0000": dark_data} dicttoh5(flats, testffname, h5path="/flats", mode="w") dicttoh5(dark, testffname, h5path="/dark", mode="a") dark_url = {0: DataUrl(file_path=testffname, data_path="/dark/dark_0000", scheme="silx")} return radios, flats_urls, dark_url @pytest.fixture(scope="class") def bootstrap_multiflats(request): cls = request.cls n_radios = 50 radio_shape = (20, 21) cls.flat_interval = 11 h5_fname = "testff.h5" radios, flats, dark = generate_test_flatfield(n_radios, radio_shape, cls.flat_interval, h5_fname) cls.radios = radios cls.flats_urls = flats cls.darks_urls = dark cls.expected_results = np.arange(n_radios) cls.tol = 5e-4 cls.tol_std = 1e-4 yield # tear down os.remove(dark[0].file_path()) os.rmdir(os.path.dirname(dark[0].file_path())) @pytest.mark.usefixtures("bootstrap_multiflats") class TestFlatFieldMultiFlat: def check_normalization(self, projs): # Check that each projection is filled with the same values std_projs = np.std(projs, axis=(-2, -1)) assert np.max(np.abs(std_projs)) < self.tol_std # Check that the normalized radios are equal to 0, 1, 2, ... 
stop = (projs.shape[0] // self.flat_interval) * self.flat_interval errs = projs[:stop, 0, 0] - self.expected_results[:stop] assert np.max(np.abs(errs)) < self.tol, "Something wrong with flat-field normalization" def test_flatfield(self): flatfield = FlatFieldDataUrls(self.radios.shape, self.flats_urls, self.darks_urls, interpolation="linear") projs = np.copy(self.radios) flatfield.normalize_radios(projs) print(projs[:, 0, 0]) self.check_normalization(projs) @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test") def test_cuda_flatfield(self): d_projs = garray.to_gpu(self.radios) cuda_flatfield = CudaFlatFieldDataUrls( self.radios.shape, self.flats_urls, self.darks_urls, ) cuda_flatfield.normalize_radios(d_projs) projs = d_projs.get() self.check_normalization(projs) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/preproc/tests/test_paganin.py0000644000175000017500000000420600000000000021015 0ustar00pierrepierreimport pytest import numpy as np from nabu.preproc.phase import PaganinPhaseRetrieval from nabu.testutils import get_data from nabu.thirdparty.tomopy_phase import retrieve_phase from nabu.cuda.utils import __has_pycuda__, __has_cufft__ if __has_pycuda__: from nabu.preproc.phase_cuda import CudaPaganinPhaseRetrieval scenarios = [ { "distance": 1, "energy": 35, "delta_beta": 1e1, "margin": ((50, 50), (0, 0)), } ] @pytest.fixture(scope="class", params=scenarios) def bootstrap(request): cls = request.cls cls.paganin_config = request.param cls.data = get_data("mri_proj_astra.npz")["data"] cls.rtol = 1.1e-6 cls.rtol_pag = 5e-3 cls.paganin = PaganinPhaseRetrieval(cls.data.shape, **cls.paganin_config) @pytest.mark.usefixtures("bootstrap") class TestPaganin: """ Test the Paganin phase retrieval. The reference implementation is tomopy. 
""" def crop_to_margin(self, data): s0, s1 = self.paganin.shape_inner ((U, _), (L, _)) = self.paganin.margin return data[U : U + s0, L : L + s1] def test_paganin(self): data_tomopy = np.atleast_3d(np.copy(self.data)).T res_tomopy = retrieve_phase( data_tomopy, pixel_size=self.paganin.pixel_size_micron * 1e-4, dist=self.paganin.distance_cm, energy=self.paganin.energy_kev, alpha=1.0 / (4 * 3.141592**2 * self.paganin.delta_beta), ) res_tomopy = self.crop_to_margin(res_tomopy[0].T) res = self.paganin.apply_filter(self.data) errmax = np.max(np.abs(res - res_tomopy) / np.max(res_tomopy)) assert errmax < self.rtol_pag, "Max error is too high" @pytest.mark.skipif(not (__has_pycuda__ and __has_cufft__), reason="Need pycuda and scikit-cuda for this test") def test_gpu_paganin(self): gpu_paganin = CudaPaganinPhaseRetrieval(self.data.shape, **self.paganin_config) ref = self.paganin.apply_filter(self.data) res = gpu_paganin.apply_filter(self.data) errmax = np.max(np.abs((res - ref) / np.max(ref))) assert errmax < self.rtol, "Max error is too high" ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/preproc/tests/test_vshift.py0000644000175000017500000000435400000000000020707 0ustar00pierrepierreimport pytest import numpy as np from scipy.ndimage import shift as ndshift from nabu.preproc.shift import VerticalShift from nabu.cuda.utils import __has_pycuda__, get_cuda_context if __has_pycuda__: from nabu.preproc.shift_cuda import CudaVerticalShift, garray @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls data = np.zeros([13, 11], "f") slope = 100 + np.arange(13) data[:] = slope[:, None] cls.radios = np.array([data] * 17) cls.shifts = 0.3 + np.arange(17) cls.indexes = range(17) # given the shifts and the radios we build the golden reference golden = [] for iradio in range(17): projection_number = cls.indexes[iradio] my_shift = cls.shifts[projection_number] padded_radio = np.concatenate( [cls.radios[iradio], np.zeros([1, 11], "f")], axis=0 ) # needs padding because ndshifs does not work as expected shifted_padded_radio = ndshift(padded_radio, [-my_shift, 0], mode="constant", cval=0.0, order=1).astype("f") shifted_radio = shifted_padded_radio[:-1] golden.append(shifted_radio) cls.golden = np.array(golden) cls.tol = 1e-5 if __has_pycuda__: cls.ctx = get_cuda_context() @pytest.mark.usefixtures("bootstrap") class TestVerticalShift: def test_vshift(self): radios = self.radios.copy() new_radios = np.zeros_like(radios) Shifter = VerticalShift(radios.shape, self.shifts) Shifter.apply_vertical_shifts(radios, self.indexes, output=new_radios) assert abs(new_radios - self.golden).max() < self.tol Shifter.apply_vertical_shifts(radios, self.indexes) assert abs(radios - self.golden).max() < self.tol @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test") def test_cuda_vshift(self): d_radios = garray.to_gpu(self.radios) d_out = garray.zeros_like(d_radios) Shifter = CudaVerticalShift(d_radios.shape, self.shifts) Shifter.apply_vertical_shifts(d_radios, self.indexes, output=d_out) assert abs(d_out.get() - self.golden).max() < self.tol Shifter.apply_vertical_shifts(d_radios, self.indexes) assert abs(d_radios.get() - self.golden).max() < self.tol ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4687333 nabu-2023.1.1/nabu/reconstruction/0000755000175000017500000000000000000000000016232 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 
xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/reconstruction/__init__.py0000644000175000017500000000024300000000000020342 0ustar00pierrepierrefrom .reconstructor import Reconstructor from .rings import MunchDeringer, munchetal_filter from .sinogram import SinoBuilder, convert_halftomo, SinoNormalization ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/reconstruction/cone.py0000644000175000017500000002546100000000000017540 0ustar00pierrepierreimport numpy as np try: import astra __have_astra__ = True except ImportError: __have_astra__ = False from ..cuda.processing import CudaProcessing class ConebeamReconstructor: """ A reconstructor for cone-beam geometry using the astra toolbox. """ def __init__( self, sinos_shape, source_origin_dist, origin_detector_dist, angles=None, volume_shape=None, rot_center=None, relative_z_position=None, pixel_size=None, cuda_options=None, ): """ Initialize a cone beam reconstructor. This reconstructor works on slabs of data, meaning that one partial volume is obtained from one stack of sinograms. To reconstruct a full volume, the reconstructor must be called on a series of sinograms stacks, with an updated "relative_z_position" each time. Parameters ----------- sinos_shape: tuple Shape of the sinograms stack, in the form (n_sinos, n_angles, prj_width) source_origin_dist: float Distance, in pixel units, between the beam source (cone apex) and the "origin". The origin is defined as the center of the sample origin_detector_dist: float Distance, in pixel units, between the center of the sample and the detector. angles: array, optional Rotation angles in radians. If provided, its length should be equal to sinos_shape[1]. volume_shape: tuple of int, optional Shape of the output volume slab, in the form (n_z, n_y, n_x). If not provided, the output volume slab shape is (sinos_shape[0], sinos_shape[2], sinos_shape[2]). rot_center: float, optional Rotation axis position. Default is `(detector_width - 1)/2.0` relative_z_position: float, optional Position of the central slice of the slab, with respect to the full stack of slices. By default it is set to zero, meaning that the current slab is assumed in the middle of the stack pixel_size: float or tuple, optional Size of the pixel. Possible options: - Nothing is provided (default): in this case, all lengths are normalized with respect to the pixel size, i.e 'source_origin_dist' and 'origin_detector_dist' should be expressed in pixels (and 'pixel_size' is set to 1). - A scalar number is provided: in this case it is the spacing between two pixels (in each dimension) - A tuple is provided: in this case it is the spacing between two pixels in both dimensions, vertically then horizontally, i.e (detector_spacing_y, detector_spacing_x) Notes ------ This reconstructor is using the astra toolbox [1]. Therefore the implementation uses Astra's reference frame, which is centered on the sample (source and detector move around the sample). For more information see Fig. 2 of paper [1]. To define the cone-beam geometry, two distances are needed: - Source-origin distance (hereby d1) - Origin-detector distance (hereby d2) The magnification at distance d2 is m = 1+d2/d1, so given a detector pixel size p_s, the sample voxel size is p_s/m. 
To make things simpler, this class internally uses a different (but equivalent) geometry: - d2 is set to zero, meaning that the detector is (virtually) moved to the center of the sample - The detector is "re-scaled" to have a pixel size equal to the voxel size (p_s/m) Having the detector in the same plane as the sample center simplifies things when it comes to slab-wise reconstruction: defining a volume slab (in terms of z_min, z_max) is equivalent to define the detector bounds, like in parallel geometry. References ----------- [1] Aarle, Wim & Palenstijn, Willem & Cant, Jeroen & Janssens, Eline & Bleichrodt, Folkert & Dabravolski, Andrei & De Beenhouwer, Jan & Batenburg, Kees & Sijbers, Jan. (2016). Fast and flexible X-ray tomography using the ASTRA toolbox. Optics Express. 24. 25129-25147. 10.1364/OE.24.025129. """ self._init_cuda(cuda_options) self._init_geometry( sinos_shape, source_origin_dist, origin_detector_dist, pixel_size, angles, volume_shape, rot_center, relative_z_position, ) self._alg_id = None self._vol_id = None self._proj_id = None def _init_cuda(self, cuda_options): cuda_options = cuda_options or {} self.cuda = CudaProcessing(**cuda_options) def _set_sino_shape(self, sinos_shape): if len(sinos_shape) != 3: raise ValueError("Expected a 3D shape") self.sinos_shape = sinos_shape self.n_sinos, self.n_angles, self.prj_width = sinos_shape def _set_pixel_size(self, pixel_size): if pixel_size is None: det_spacing_y = det_spacing_x = 1 elif np.iterable(pixel_size): det_spacing_y, det_spacing_x = pixel_size else: # assuming scalar det_spacing_y = det_spacing_x = pixel_size self._det_spacing_y = det_spacing_y self._det_spacing_x = det_spacing_x def _init_geometry( self, sinos_shape, source_origin_dist, origin_detector_dist, pixel_size, angles, volume_shape, rot_center, relative_z_position, ): self._set_sino_shape(sinos_shape) if angles is None: self.angles = np.linspace(0, 2 * np.pi, self.n_angles, endpoint=True) else: self.angles = angles if volume_shape is None: volume_shape = (self.sinos_shape[0], self.sinos_shape[2], self.sinos_shape[2]) self.volume_shape = volume_shape self.n_z, self.n_y, self.n_x = self.volume_shape self.source_origin_dist = source_origin_dist self.origin_detector_dist = origin_detector_dist self.magnification = 1 + origin_detector_dist / source_origin_dist self.vol_geom = astra.create_vol_geom(self.n_y, self.n_x, self.n_z) self.vol_shape = astra.geom_size(self.vol_geom) self._cor_shift = 0.0 self.rot_center = rot_center if rot_center is not None: self._cor_shift = (self.prj_width - 1) / 2.0 - rot_center self._set_pixel_size(pixel_size) self._create_astra_proj_geometry(relative_z_position) def _create_astra_proj_geometry(self, relative_z_position): # This object has to be re-created each time, because once the modifications below are done, # it is no more a "cone" geometry but a "cone_vec" geometry, and cannot be updated subsequently # (see astra/functions.py:271) self.proj_geom = astra.create_proj_geom( "cone", self._det_spacing_x, self._det_spacing_y, self.n_sinos, self.prj_width, self.angles, self.source_origin_dist, self.origin_detector_dist, ) self.relative_z_position = relative_z_position or 0.0 self.proj_geom = astra.geom_postalignment(self.proj_geom, (self._cor_shift, 0)) # Component 2 of vecs is the z coordinate of the source, component 5 is the z component of the detector position # We need to re-create the same inclination of the cone beam, thus we need to keep the inclination of the two z positions. 
# The detector is centered on the rotation axis, thus moving it up or down, just moves it out of the reconstruction volume. # We can bring back the detector in the correct volume position, by applying a rigid translation of both the detector and the source. # The translation is exactly the amount that brought the detector up or down, but in the opposite direction. vecs = self.proj_geom["Vectors"] vecs[:, 2] = -self.relative_z_position def _set_output(self, volume): if volume is not None: self.cuda.check_array(volume, self.vol_shape) self.cuda.set_array("output", volume, self.vol_shape) if volume is None: self.cuda.allocate_array("output", self.vol_shape) d_volume = self.cuda.get_array("output") z, y, x = d_volume.shape self._vol_link = astra.data3d.GPULink(d_volume.ptr, x, y, z, d_volume.strides[-2]) self._vol_id = astra.data3d.link("-vol", self.vol_geom, self._vol_link) def _set_input(self, sinos): self.cuda.check_array(sinos, self.sinos_shape) self.cuda.set_array("sinos", sinos, sinos.shape) # self.cuda.sinos is now a GPU array # TODO don't create new link/proj_id if ptr is the same ? # But it seems Astra modifies the input sinogram while doing FDK, so this might be not relevant d_sinos = self.cuda.get_array("sinos") # self._proj_data_link = astra.data3d.GPULink(d_sinos.ptr, self.prj_width, self.n_angles, self.n_z, sinos.strides[-2]) self._proj_data_link = astra.data3d.GPULink( d_sinos.ptr, self.prj_width, self.n_angles, self.n_sinos, sinos.strides[-2] ) self._proj_id = astra.data3d.link("-sino", self.proj_geom, self._proj_data_link) def _update_reconstruction(self): cfg = astra.astra_dict("FDK_CUDA") cfg["ReconstructionDataId"] = self._vol_id cfg["ProjectionDataId"] = self._proj_id # TODO more options "eg. filter" ? if self._alg_id is not None: astra.algorithm.delete(self._alg_id) self._alg_id = astra.algorithm.create(cfg) def reconstruct(self, sinos, output=None, relative_z_position=None): """ sinos: numpy.ndarray or pycuda.gpuarray Sinograms, with shape (n_sinograms, n_angles, width) output: pycuda.gpuarray, optional Output array. If not provided, a new numpy array is returned relative_z_position: int, optional Position of the central slice of the slab, with respect to the full stack of slices. By default it is set to zero, meaning that the current slab is assumed in the middle of the stack """ self._create_astra_proj_geometry(relative_z_position) self._set_input(sinos) self._set_output(output) self._update_reconstruction() astra.algorithm.run(self._alg_id) result = self.cuda.get_array("output") if output is None: result = result.get() self.cuda.recover_arrays_references(["sinos", "output"]) return result def __del__(self): if self._alg_id is not None: astra.algorithm.delete(self._alg_id) if self._vol_id is not None: astra.data3d.delete(self._vol_id) if self._proj_id is not None: astra.data3d.delete(self._proj_id) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/reconstruction/fbp.py0000644000175000017500000003611400000000000017360 0ustar00pierrepierreimport numpy as np from math import sqrt, pi from ..utils import updiv, get_cuda_srcfile, _sizeof, nextpow2, convert_index, deprecation_warning from ..cuda.utils import copy_array from ..cuda.processing import CudaProcessing from ..cuda.kernel import CudaKernel from .filtering import SinoFilter import pycuda.driver as cuda from pycuda import gpuarray as garray class Backprojector: """ Cuda Backprojector. 
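A minimal usage sketch, assuming `sino` is a float32 numpy array of shape
`(n_angles, detector_width)`: `B = Backprojector(sino.shape)` followed by
`rec = B.fbp(sino)` (shorthand for `filtered_backprojection`) returns the reconstructed
slice; by default the angles are equispaced in [0, pi[ and the rotation axis is at the
detector center.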
""" default_padding_mode = "zeros" cuda_fname = "backproj.cu" cuda_kernel_name = "backproj" default_extra_options = { "padding_mode": None, "axis_correction": None, "centered_axis": False, "clip_outer_circle": False, "scale_factor": None, "filter_cutoff": 1.0, } def __init__( self, sino_shape, slice_shape=None, angles=None, rot_center=None, padding_mode=None, filter_name=None, slice_roi=None, scale_factor=None, extra_options=None, cuda_options=None, ): """ Initialize a Cuda Backprojector. Parameters ----------- sino_shape: tuple Shape of the sinogram, in the form `(n_angles, detector_width)` (for backprojecting one sinogram) or `(n_sinos, n_angles, detector_width)`. slice_shape: int or tuple, optional Shape of the slice. By default, the slice shape is (n_x, n_x) where `n_x = detector_width` angles: array-like, optional Rotation anles in radians. By default, angles are equispaced between [0, pi[. rot_center: float, optional Rotation axis position. Default is `(detector_width - 1)/2.0` padding_mode: str, optional Padding mode when filtering the sinogram. Can be "zeros" (default) or "edges". filter_name: str, optional Name of the filter for filtered-backprojection. slice_roi: tuple, optional. Whether to backproject in a restricted area. If set, it must be in the form (start_x, end_x, start_y, end_y). `end_x` and `end_y` are non inclusive ! For example if the detector has 2048 pixels horizontally, then you can choose `start_x=0` and `end_x=2048`. If one of the value is set to None, it is replaced with a default value (0 for start, n_x and n_y for end) scale_factor: float, optional Scaling factor for backprojection. For example, to get the linear absorption coefficient in 1/cm, this factor has to be set as the pixel size in cm. DEPRECATED - please use this parameter in "extra_options" extra_options: dict, optional Advanced extra options. See the "Extra options" section for more information. cuda_options: dict, optional Cuda options passed to the CudaProcessing class. Other parameters ----------------- extra_options: dict, optional Dictionary with a set of advanced options. The default are the following: - "padding_mode": "zeros" Padding mode when filtering the sinogram. Can be "zeros" or "edges". DEPRECATED - please use "padding_mode" directly in parameters. - "axis_correction": None Whether to set a correction for the rotation axis. If set, this should be an array with as many elements as the number of angles. This is useful when there is an horizontal displacement of the rotation axis. - centered_axis: bool Whether to "center" the slice on the rotation axis position. If set to True, then the reconstructed region is centered on the rotation axis. - scale_factor: float Scaling factor for backprojection. For example, to get the linear absorption coefficient in 1/cm, this factor has to be set as the pixel size in cm. - clip_outer_circle: False Whether to set to zero the pixels outside the reconstruction mask - filter_cutoff: float Cut-off frequency usef for Fourier filter. Default is 1.0 """ self.cuda_processing = CudaProcessing(**(cuda_options or {})) self._configure_extra_options(scale_factor, padding_mode, extra_options=extra_options) self._init_geometry(sino_shape, slice_shape, angles, rot_center, slice_roi) self._init_filter(filter_name) self._allocate_memory() self._compute_angles() self._compile_kernels() self._bind_textures() def _configure_extra_options(self, scale_factor, padding_mode, extra_options=None): extra_options = extra_options or {} # compat. 
scale_factor = None if scale_factor is not None: deprecation_warning( "Please use the parameter 'scale_factor' in the 'extra_options' dict", do_print=True, func_name="fbp_scale_factor", ) scale_factor = extra_options.get("scale_factor", None) or scale_factor or 1.0 # if "padding_mode" in extra_options: deprecation_warning( "Please use 'padding_mode' directly in Backprojector arguments, not in 'extra_options'", do_print=True, func_name="fbp_padding_mode", ) # self._backproj_scale_factor = scale_factor self._axis_array = None self.extra_options = self.default_extra_options.copy() self.extra_options.update(extra_options) self.padding_mode = padding_mode or self.extra_options["padding_mode"] or self.default_padding_mode self._axis_array = self.extra_options["axis_correction"] def _init_geometry(self, sino_shape, slice_shape, angles, rot_center, slice_roi): if slice_shape is not None and slice_roi is not None: raise ValueError("slice_shape and slice_roi cannot be used together") self.sino_shape = sino_shape if len(sino_shape) == 2: n_angles, dwidth = sino_shape else: raise ValueError("Expected 2D sinogram") self.dwidth = dwidth self._set_slice_shape(slice_shape) self.rot_center = rot_center or (self.dwidth - 1) / 2.0 self.axis_pos = self.rot_center self._set_angles(angles, n_angles) self._set_slice_roi(slice_roi) # # offset = start - move # move = 0 if not(centered_axis) else start + (n-1)/2. - c if self.extra_options["centered_axis"]: self.offsets = { "x": round(self.rot_center - (self.n_x - 1) / 2.0), "y": round(self.rot_center - (self.n_y - 1) / 2.0), } # self._set_axis_corr() def _set_slice_shape(self, slice_shape): n_y = self.dwidth n_x = self.dwidth if slice_shape is not None: if np.isscalar(slice_shape): slice_shape = (slice_shape, slice_shape) n_y, n_x = slice_shape self.n_x = n_x self.n_y = n_y self.slice_shape = (n_y, n_x) def _set_angles(self, angles, n_angles): self.n_angles = n_angles if angles is None: angles = n_angles if np.isscalar(angles): angles = np.linspace(0, np.pi, angles, False) else: assert len(angles) == self.n_angles self.angles = angles def _set_slice_roi(self, slice_roi): self.offsets = {"x": 0, "y": 0} self.slice_roi = slice_roi if slice_roi is None: return start_x, end_x, start_y, end_y = slice_roi # convert negative indices dwidth = self.dwidth start_x = convert_index(start_x, dwidth, 0) start_y = convert_index(start_y, dwidth, 0) end_x = convert_index(end_x, dwidth, dwidth) end_y = convert_index(end_y, dwidth, dwidth) self.slice_shape = (end_y - start_y, end_x - start_x) self.n_x = self.slice_shape[-1] self.n_y = self.slice_shape[-2] self.offsets = {"x": start_x, "y": start_y} def _allocate_memory(self): self._d_sino_cua = cuda.np_to_array(np.zeros(self.sino_shape, "f"), "C") # 1D textures are not supported in pycuda self.h_msin = np.zeros((1, self.n_angles), "f") self.h_cos = np.zeros((1, self.n_angles), "f") self._d_sino = garray.zeros(self.sino_shape, "f") self.cuda_processing.init_arrays_to_none(["_d_slice"]) def _compute_angles(self): self.h_cos[0] = np.cos(self.angles).astype("f") self.h_msin[0] = (-np.sin(self.angles)).astype("f") self._d_msin = garray.to_gpu(self.h_msin[0]) self._d_cos = garray.to_gpu(self.h_cos[0]) if self._axis_correction is not None: self._d_axcorr = garray.to_gpu(self._axis_correction) def _set_axis_corr(self): axcorr = self.extra_options["axis_correction"] self._axis_correction = axcorr if axcorr is None: return if len(axcorr) != self.n_angles: raise ValueError("Expected %d angles but got %d" % (self.n_angles, len(axcorr))) 
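        # Store the per-angle corrections as a (1, n_angles) float32 array (same layout as h_cos / h_msin)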
self._axis_correction = np.zeros((1, self.n_angles), dtype=np.float32) self._axis_correction[0, :] = axcorr[:] # pylint: disable=E1136 def _init_filter(self, filter_name): self.filter_name = filter_name self.sino_filter = SinoFilter( self.sino_shape, filter_name=self.filter_name, padding_mode=self.padding_mode, extra_options={"cutoff": self.extra_options.get("filter_cutoff", 1.0)}, cuda_options={"ctx": self.cuda_processing.ctx}, ) def _get_kernel_signature(self): kern_full_sig = list("PiifiiiiPPPf") if self._axis_correction is None: kern_full_sig[10] = "" return "".join(kern_full_sig) def _get_kernel_options(self): tex_name = "tex_projections" sourcemodule_options = [] # We use blocks of 16*16 (see why in kernel doc), and one thread # handles 2 pixels per dimension. block = (16, 16, 1) # The Cuda kernel is optimized for 16x16 threads blocks # If one of the dimensions is smaller than 16, it has to be addapted if self.n_x < 16 or self.n_y < 16: tpb_x = min(int(nextpow2(self.n_x)), 16) tpb_y = min(int(nextpow2(self.n_y)), 16) block = (tpb_x, tpb_y, 1) sourcemodule_options.append("-DSHARED_SIZE=%d" % (tpb_x * tpb_y)) grid = (updiv(updiv(self.n_x, block[0]), 2), updiv(updiv(self.n_y, block[1]), 2)) if self.extra_options["clip_outer_circle"]: sourcemodule_options.append("-DCLIP_OUTER_CIRCLE") shared_size = int(np.prod(block)) * 2 if self._axis_correction is not None: sourcemodule_options.append("-DDO_AXIS_CORRECTION") shared_size += int(np.prod(block)) shared_size *= 4 # sizeof(float32) self._kernel_options = { "file_name": get_cuda_srcfile(self.cuda_fname), "kernel_name": self.cuda_kernel_name, "kernel_signature": self._get_kernel_signature(), "texture_name": tex_name, "sourcemodule_options": sourcemodule_options, "grid": grid, "block": block, "shared_size": shared_size, } def _compile_kernels(self): self._get_kernel_options() kern_opts = self._kernel_options # Configure backprojector self.gpu_projector = CudaKernel( kern_opts["kernel_name"], filename=kern_opts["file_name"], options=kern_opts["sourcemodule_options"] ) self.texref_proj = self.gpu_projector.module.get_texref(kern_opts["texture_name"]) self.texref_proj.set_filter_mode(cuda.filter_mode.LINEAR) self.gpu_projector.prepare(kern_opts["kernel_signature"], [self.texref_proj]) # Prepare kernel arguments self.kern_proj_args = [ None, # output d_slice holder self.n_angles, self.dwidth, self.axis_pos, self.n_x, self.n_y, self.offsets["x"], self.offsets["y"], self._d_cos, self._d_msin, self._backproj_scale_factor, ] if self._axis_correction is not None: self.kern_proj_args.insert(-1, self._d_axcorr) self.kern_proj_kwargs = { "grid": kern_opts["grid"], "block": kern_opts["block"], "shared_size": kern_opts["shared_size"], } def _bind_textures(self): self.texref_proj.set_array(self._d_sino_cua) def _set_output(self, output, check=False): if output is None: self.cuda_processing.allocate_array("_d_slice", self.slice_shape, dtype=np.float32) return self.cuda_processing._d_slice if check: assert output.dtype == np.float32 assert output.shape == self.slice_shape, "Expected output shape %s but got %s" % ( self.slice_shape, output.shape, ) if isinstance(output, garray.GPUArray): return output.gpudata else: # pycuda.driver.DeviceAllocation ? 
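            # assume a raw device allocation was passed and return it unchanged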
return output def backproj(self, sino, output=None, do_checks=True): copy_array(self._d_sino_cua, sino, check=do_checks) d_slice = self._set_output(output, check=do_checks) self.kern_proj_args[0] = d_slice self.gpu_projector(*self.kern_proj_args, **self.kern_proj_kwargs) if output is not None: return output else: return self.cuda_processing._d_slice.get() def filtered_backprojection(self, sino, output=None): self.sino_filter(sino, output=self._d_sino) return self.backproj(self._d_sino, output=output) fbp = filtered_backprojection # shorthand class PolarBackprojector(Backprojector): """ Cuda Backprojector with output in polar coordinates. """ cuda_fname = "backproj_polar.cu" cuda_kernel_name = "backproj_polar" # patch parent method: force slice_shape to (n_angles, n_x) def _set_angles(self, angles, n_angles): Backprojector._set_angles(self, angles, n_angles) self.slice_shape = (self.n_angles, self.n_x) # patch parent method: def _set_slice_roi(self, slice_roi): if slice_roi is not None: raise ValueError("slice_roi is not supported with this class") Backprojector._set_slice_roi(self, slice_roi) # patch parent method: don't do the 4X compute-workload optimization for this kernel def _get_kernel_options(self): Backprojector._get_kernel_options(self) block = self._kernel_options["block"] self._kernel_options["grid"] = (updiv(self.n_x, block[0]), updiv(self.n_y, block[1])) # patch parent method: update kernel args def _compile_kernels(self): n_y = self.n_y self.n_y = self.n_angles Backprojector._compile_kernels(self) self.n_y = n_y ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/reconstruction/fbp_opencl.py0000644000175000017500000000707400000000000020723 0ustar00pierrepierreimport numpy as np from silx.opencl.backprojection import Backprojection from ..utils import deprecation_warning # Compatibility layer Nabu/silx class Backprojector: def __init__( self, sino_shape, slice_shape=None, angles=None, rot_center=None, filter_name=None, padding_mode=None, slice_roi=None, scale_factor=None, ctx=None, devicetype="all", platformid=None, deviceid=None, profile=False, extra_options=None, ): if slice_roi and ( slice_roi[0] > 0 or slice_roi[2] > 0 or slice_roi[1] < sino_shape[1] or slice_roi[3] < sino_shape[1] ): raise ValueError("Not implemented yet in the OpenCL back-end") self._configure_extra_options(extra_options, padding_mode) self._get_scale_factor(scale_factor) self.backprojector = Backprojection( sino_shape, slice_shape=slice_shape, axis_position=rot_center, # angles=angles, filter_name=filter_name, ctx=ctx, devicetype=devicetype, platformid=platformid, deviceid=deviceid, profile=profile, extra_options=self._silx_fbp_extra_options, ) def _configure_extra_options(self, extra_options, padding_mode): self.extra_options = extra_options or {} self._silx_fbp_extra_options = {} if "padding_mode" in self.extra_options: deprecation_warning( "Please use 'padding_mode' directly in Backprojector arguments, not in 'extra_options'", do_print=True, func_name="ocl_fbp_padding_mode", ) if self.extra_options.get("clip_outer_circle", False) or self.extra_options.get("center_slice", False): raise NotImplementedError() if padding_mode is not None: self._silx_fbp_extra_options["padding_mode"] = padding_mode self._silx_fbp_extra_options["cutoff"] = self.extra_options.get("fbp_filter_cutoff", 1.0) def _get_scale_factor(self, scale_factor): if scale_factor is not None: deprecation_warning( "Please use the 'scale_factor' parameter in extra_options", 
func_name="ocl_fbp_scale_factor" ) self.scale_factor = scale_factor or self.extra_options.get("scale_factor", None) # scale_factor is not implemented in the opencl code def _fbp_with_scale_factor(self, sino, output=None): return self.backprojector.filtered_backprojection(sino * self.scale_factor, output=output) def _fbp(self, sino, output=None): return self.backprojector.filtered_backprojection(sino, output=output) def filtered_backprojection(self, sino, output=None): input_sino = sino # TODO scale_factor is not implemented in the silx opencl code # This makes a copy of the input array if self.scale_factor is not None: input_sino = sino * self.scale_factor # if output is None or isinstance(output, np.ndarray): res = self.backprojector.filtered_backprojection(input_sino) if output is not None: output[:] = res[:] return output return res else: # assuming pyopencl array return self.backprojector.filtered_backprojection(input_sino, output=output) fbp = filtered_backprojection def backproj(self, *args, **kwargs): # TODO scale_factor ? return self.backprojector.backprojection(*args, **kwargs) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/reconstruction/filtering.py0000644000175000017500000001750400000000000020576 0ustar00pierrepierrefrom math import pi import pycuda.gpuarray as garray import numpy as np from silx.image.tomography import compute_fourier_filter, get_next_power from ..cuda.kernel import CudaKernel from ..cuda.processing import CudaProcessing from ..utils import get_cuda_srcfile, check_supported, updiv class SinoFilter: available_padding_modes = ["zeros", "edges"] def __init__( self, sino_shape, filter_name=None, padding_mode="zeros", extra_options=None, cuda_options=None, ): """ Build a sinogram filter process. """ self.cuda = CudaProcessing(**(cuda_options or {})) self._init_extra_options(extra_options) self._calculate_shapes(sino_shape) self._init_fft() self._allocate_memory(sino_shape) self._compute_filter(filter_name) self._set_padding_mode(padding_mode) self._init_kernels() def _init_extra_options(self, extra_options): self.extra_options = { "cutoff": 1.0, } if extra_options is not None: self.extra_options.update(extra_options) def _set_padding_mode(self, padding_mode): # Compat. if padding_mode == "edge": padding_mode = "edges" # check_supported(padding_mode, self.available_padding_modes, "padding mode") self.padding_mode = padding_mode def _calculate_shapes(self, sino_shape): self.ndim = len(sino_shape) if self.ndim == 2: n_angles, dwidth = sino_shape n_sinos = 1 elif self.ndim == 3: n_sinos, n_angles, dwidth = sino_shape else: raise ValueError("Invalid sinogram number of dimensions") self.sino_shape = sino_shape self.n_angles = n_angles self.dwidth = dwidth # int() is crucial here ! Otherwise some pycuda arguments (ex. 
memcpy2D) # will not work with numpy.int64 (as for 2018.X) self.dwidth_padded = int(get_next_power(2 * self.dwidth)) self.sino_padded_shape = (n_angles, self.dwidth_padded) if self.ndim == 3: self.sino_padded_shape = (n_sinos,) + self.sino_padded_shape sino_f_shape = list(self.sino_padded_shape) sino_f_shape[-1] = sino_f_shape[-1] // 2 + 1 self.sino_f_shape = tuple(sino_f_shape) # self.pad_left = (self.dwidth_padded - self.dwidth) // 2 self.pad_right = self.dwidth_padded - self.dwidth - self.pad_left def _init_fft(self): # Import has to be done here, otherwise scikit-cuda creates a cuda/cublas context at import from silx.math.fft.cufft import CUFFT # self.fft = CUFFT( self.sino_padded_shape, dtype=np.float32, axes=(-1,), ) def _allocate_memory(self, sino_shape): self.d_filter_f = garray.zeros((self.sino_f_shape[-1],), np.complex64) self.d_sino_padded = self.fft.data_in self.d_sino_f = self.fft.data_out def set_filter(self, h_filt, normalize=True): """ Set a filter for sinogram filtering. :param h_filt: Filter. Each line of the sinogram will be filtered with this filter. It has to be the Real-to-Complex Fourier Transform of some real filter, padded to 2*sinogram_width. :param normalize: Whether to normalize the filter with pi/num_angles. """ if h_filt.size != self.sino_f_shape[-1]: raise ValueError( """ Invalid filter size: expected %d, got %d. Please check that the filter is the Fourier R2C transform of some real 1D filter. """ % (self.sino_f_shape[-1], h_filt.size) ) if not (np.iscomplexobj(h_filt)): print("Warning: expected a complex Fourier filter") self.filter_f = h_filt if normalize: self.filter_f *= pi / self.n_angles self.filter_f = self.filter_f.astype(np.complex64) self.d_filter_f[:] = self.filter_f[:] def _compute_filter(self, filter_name): self.filter_name = filter_name or "ram-lak" # TODO add this one into silx if self.filter_name == "hilbert": freqs = np.fft.fftfreq(self.dwidth_padded) filter_f = 1.0 / (2 * pi * 1j) * np.sign(freqs) # else: filter_f = compute_fourier_filter( self.dwidth_padded, self.filter_name, cutoff=self.extra_options["cutoff"], ) filter_f = filter_f[: self.dwidth_padded // 2 + 1] # R2C self.set_filter(filter_f, normalize=True) def _init_kernels(self): fname = get_cuda_srcfile("ElementOp.cu") if self.ndim == 2: kernel_name = "inplace_complex_mul_2Dby1D" kernel_sig = "PPii" else: kernel_name = "inplace_complex_mul_3Dby1D" kernel_sig = "PPiii" self.mult_kernel = CudaKernel(kernel_name, filename=fname, signature=kernel_sig) self.kern_args = (self.d_sino_f, self.d_filter_f) self.kern_args += self.d_sino_f.shape[::-1] self._pad_edges_kernel = CudaKernel( "padding_edge", filename=get_cuda_srcfile("padding.cu"), signature="Piiiiiiii" ) self._pad_block = (32, 32, 1) self._pad_grid = tuple([updiv(n, p) for n, p in zip(self.sino_padded_shape[::-1], self._pad_block)]) def _check_array(self, arr): if arr.dtype != np.float32: raise ValueError("Expected data type = numpy.float32") if arr.shape != self.sino_shape: raise ValueError("Expected sinogram shape %s, got %s" % (self.sino_shape, arr.shape)) def _pad_sino(self, sino): if self.padding_mode == "edges": self.d_sino_padded[:, : self.dwidth] = sino[:] self._pad_edges_kernel( self.d_sino_padded, self.dwidth, self.n_angles, self.dwidth_padded, self.n_angles, self.pad_left, self.pad_right, 0, 0, grid=self._pad_grid, block=self._pad_block, ) else: # zeros self.d_sino_padded.fill(0) if self.ndim == 2: self.d_sino_padded[:, : self.dwidth] = sino[:] else: self.d_sino_padded[:, :, : self.dwidth] = sino[:] def filter_sino(self, 
sino, output=None, no_output=False): """ Perform the sinogram siltering. Parameters ---------- sino: numpy.ndarray or pycuda.gpuarray.GPUArray Input sinogram (2D or 3D) output: numpy.ndarray or pycuda.gpuarray.GPUArray, optional Output array. no_output: bool, optional If set to True, no copy is be done. The resulting data lies in self.d_sino_padded. """ self._check_array(sino) # copy2d/copy3d self._pad_sino(sino) # FFT self.fft.fft(self.d_sino_padded, output=self.d_sino_f) # multiply padded sinogram with filter in the Fourier domain self.mult_kernel(*self.kern_args) # TODO tune block size ? # iFFT self.fft.ifft(self.d_sino_f, output=self.d_sino_padded) # return if no_output: return self.d_sino_padded if output is None: res = np.zeros(self.sino_shape, dtype=np.float32) # can't do memcpy2d D->H ? (self.d_sino_padded[:, w]) I have to get() sino_ref = self.d_sino_padded.get() else: res = output sino_ref = self.d_sino_padded if self.ndim == 2: res[:] = sino_ref[:, : self.dwidth] else: res[:] = sino_ref[:, :, : self.dwidth] return res __call__ = filter_sino ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/reconstruction/reconstructor.py0000644000175000017500000001620300000000000021522 0ustar00pierrepierreimport numpy as np from ..utils import convert_index class Reconstructor: """ Abstract base class for reconstructors. A `Reconstructor` is a helper to reconstruct slices in arbitrary directions (not only usual "horizontal slices") in parallel-beam tomography. Current limitations: - Limitation to the three main axes - One instance of Reconstructor can only reconstruct successive slices Typical scenarios examples: - "I want to reconstruct several slices along 'z'", where `z` is the vertical axis. In this case, we reconstruct "horizontal slices" in planes perpendicular to the rotation axis. - "I want to reconstruct slices along 'y'". Here `y` is an axis perpendicular to `z`, i.e we reconstruct "vertical slices". A `Reconstructor` is tied to the set of slices to reconstruct (axis and orientation). Once defined, it cannot be changed ; i.e another class has to be instantiated to reconstruct slices in other axes/indices. The volume geometry conventions are defined below:: __________ / /| / / | z / / | ^ /_________/ | | | | | | y | | / | / | | / | / | | / | / |__________|/ |/ ---------- > x The axis `z` parallel to the rotation axis. The usual parallel-beam tomography setting reconstructs slices along `z`, i.e in planes parallel to (x, y). """ def __init__(self, shape, indices_range, axis="z", vol_type="sinograms", slices_roi=None): """ Initialize a reconstructor. Parameters ----------- shape: tuple Shape of the stack of sinograms or projections. indices_range: tuple Range of indices to reconstruct, in the form (start, end). As the standard Python behavior, the upper bound is not included. For example, to reconstruct 100 slices (numbered from 0 to 99), then you can provide (0, 100) or (0, None). Providing (0, 99) or (0, -1) will omit the last slice. axis: str Axis along which the slices are reconstructed. This axis is orthogonal to the slices planes. This parameter can be either "x", "y", or "z". Default is "z" (reconstruct slices perpendicular to the rotation axis). vol_type: str, optional Whether the parameter `shape` describes a volume of sinograms or projections. The two are the same except that axes 0 and 1 are swapped. Can be "sinograms" (default) or "projections". 
slices_roi: tuple, optional Define a Region Of Interest to reconstruct a part of each slice. By default, the whole slice is reconstructed for each slice index. This parameter is in the form `(start_u, end_u, start_v, end_v)`, where `u` and `v` are horizontal and vertical axes on the reconstructed slice respectively, regardless of its orientation. If one of the values is set to None, it will be replaced by the corresponding default value. Examples --------- To reconstruct the first two horizontal slices, i.e along `z`: `R = Reconstructor(vol_shape, [0, 1])` To reconstruct vertical slices 0-100 along the `y` axis: `R = Reconstructor(vol_shape, (0, 100), axis="y")` """ self._set_shape(shape, vol_type) self._set_axis(axis) self._set_indices(indices_range) self._configure_geometry(slices_roi) def _set_shape(self, shape, vol_type): if "sinogram" in vol_type.lower(): self.vol_type = "sinograms" elif "projection" in vol_type.lower(): self.vol_type = "projections" else: raise ValueError("vol_type can be either 'sinograms' or 'projections'") if len(shape) != 3: raise ValueError("Expected a 3D array description, but shape does not have 3 dims") self.shape = shape if self.vol_type == "sinograms": n_z, n_a, n_x = shape else: n_a, n_z, n_x = shape self.sinos_shape = (n_z, n_a, n_x) self.projs_shape = (n_a, n_z, n_x) self.data_shape = self.sinos_shape if self.vol_type == "sinograms" else self.projs_shape self.n_a = n_a self.n_x = n_x self.n_y = n_x # square slice by default self.n_z = n_z self._n = {"x": self.n_x, "y": self.n_y, "z": self.n_z} def _set_axis(self, axis): if axis.lower() not in ["x", "y", "z"]: raise ValueError("axis can be either 'x', 'y' or 'z' (got %s)" % axis) self.axis = axis def _set_indices(self, indices_range): start, end = indices_range npix = self._n[self.axis] start = convert_index(start, npix, 0) end = convert_index(end, npix, npix) self.indices_range = (start, end) self._idx_start = start self._idx_end = end self.indices = np.arange(start, end) def _configure_geometry(self, slices_roi): self.slices_roi = slices_roi or (None, None, None, None) start_u, end_u, start_v, end_v = self.slices_roi uv_to_xyz = { "z": ("x", "y"), # reconstruct along z: u = x, v = y "y": ("y", "z"), # reconstruct along y: u = y, v = z "x": ("y", "z"), # reconstruct along x: u = y, v = z } rotated_axes = uv_to_xyz[self.axis] u_max = self._n[rotated_axes[0]] v_max = self._n[rotated_axes[1]] start_u = convert_index(start_u, u_max, 0) end_u = convert_index(end_u, u_max, u_max) start_v = convert_index(start_v, v_max, 0) end_v = convert_index(end_v, v_max, v_max) self.slices_roi = (start_u, end_u, start_v, end_v) if self.axis == "z": self.backprojector_roi = self.slices_roi start_z, end_z = self._idx_start, self._idx_end if self.axis == "y": self.backprojector_roi = (start_u, end_u, self._idx_start, self._idx_end) start_z, end_z = start_v, end_v if self.axis == "x": self.backprojector_roi = (self._idx_start, self._idx_end, start_u, end_u) start_z, end_z = start_v, end_v self._z_indices = np.arange(start_z, end_z) self.output_shape = ( self._z_indices.size, self.backprojector_roi[3] - self.backprojector_roi[2], self.backprojector_roi[1] - self.backprojector_roi[0], ) def _check_data(self, data): if data.shape != self.data_shape: raise ValueError( "Invalid data shape: expected %s shape %s, but got %s" % (self.vol_type, self.data_shape, data.shape) ) if data.dtype != np.float32: raise ValueError("Expected float32 data type") def reconstruct(self): raise ValueError("Base class") __call__ = reconstruct 
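# A minimal sketch of the index/geometry bookkeeping performed by this base class
# (illustrative values; the actual reconstruction is implemented in subclasses such as
# the Cuda one in the next module):
if __name__ == "__main__":
    # Stack of 100 sinograms, 500 angles, detector width of 2048 pixels
    rec = Reconstructor((100, 500, 2048), (0, 50), axis="z", vol_type="sinograms")
    print(rec.output_shape)  # (50, 2048, 2048): 50 horizontal slices over the full field of view
    print(rec.indices[:3], rec.indices[-1])  # [0 1 2] 49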
././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/reconstruction/reconstructor_cuda.py0000644000175000017500000000333700000000000022522 0ustar00pierrepierreimport numpy as np import pycuda.gpuarray as garray from ..cuda.kernel import CudaKernel from ..cuda.processing import CudaProcessing from .fbp import Backprojector from .reconstructor import Reconstructor class CudaReconstructor(Reconstructor): def __init__(self, shape, indices, axis="z", vol_type="sinograms", slices_roi=None, **backprojector_options): Reconstructor.__init__(self, shape, indices, axis=axis, vol_type=vol_type, slices_roi=slices_roi) self._init_backprojector(**backprojector_options) def _init_backprojector(self, **backprojector_options): self.backprojector = Backprojector( self.sinos_shape[1:], slice_roi=self.backprojector_roi, **backprojector_options ) def reconstruct(self, data, output=None): """ Reconstruct from sinograms or projections. """ self._check_data(data) B = self.backprojector if output is None: output = np.zeros(self.output_shape, "f") new_output = True else: assert output.shape == self.output_shape, str( "Expected output_shape = %s, got %s" % (str(self.output_shape), str(output.shape)) ) assert output.dtype == np.float32 new_output = False def reconstruct_fbp(data, output, i, i0): if self.vol_type == "sinograms": current_sino = data[i] else: current_sino = data[:, i, :] if new_output: output[i0] = B.fbp(current_sino) else: B.fbp(current_sino, output=output[i0]) for i0, i in enumerate(self._z_indices): reconstruct_fbp(data, output, i, i0) return output ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/reconstruction/rings.py0000644000175000017500000000454500000000000017736 0ustar00pierrepierrefrom ..thirdparty.pore3d_deringer_munch import munchetal_filter from .sinogram import SinoBuilder class MunchDeringer: def __init__(self, sigma, sinos_shape=None, levels=None, wname="db15"): """ Initialize a "Munch Et Al" sinogram deringer. See References for more information. Parameters ----------- sigma: float Standard deviation of the damping parameter. The higher value of sigma, the more important the filtering effect on the rings. levels: int, optional Number of wavelets decomposition levels. By default (None), the maximum number of decomposition levels is used. wname: str, optional Default is "db15" (Daubechies, 15 vanishing moments) sinos_shape: tuple, optional Shape of the sinogram (or sinograms stack). References ---------- B. Munch, P. Trtik, F. Marone, M. Stampanoni, Stripe and ring artifact removal with combined wavelet-Fourier filtering, Optics Express 17(10):8567-8591, 2009. """ self._get_shapes(sinos_shape=sinos_shape) self.sigma = sigma self.levels = levels self.wname = wname self._check_can_use_wavelets() _get_shapes = SinoBuilder._get_shapes def _check_can_use_wavelets(self): if munchetal_filter is None: raise ValueError("Need pywavelets to use this class") def _destripe_2D(self, sino, output): res = munchetal_filter(sino, self.levels, self.sigma, wname=self.wname) output[:] = res return output def remove_rings(self, sinos, output=None): """ Main function to performs rings artefacts removal on sinogram(s). CAUTION: this function defaults to in-place processing, meaning that the sinogram(s) you pass will be overwritten. Parameters ---------- sinos: numpy.ndarray Sinogram or stack of sinograms. output: numpy.ndarray, optional Output array. 
If set to None (default), the output overwrites the input. """ if output is None: output = sinos if sinos.ndim == 2: return self._destripe_2D(sinos, output) n_sinos = sinos.shape[0] for i in range(n_sinos): self._destripe_2D(sinos[i], output[i]) return output ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1678380095.0 nabu-2023.1.1/nabu/reconstruction/rings_cuda.py0000644000175000017500000001230400000000000020722 0ustar00pierrepierreimport numpy as np import pycuda.gpuarray as garray from ..utils import get_cuda_srcfile from ..cuda.processing import CudaProcessing from ..cuda.kernel import CudaKernel from .rings import MunchDeringer try: from pycudwt import Wavelets __have_pycudwt__ = True except ImportError: __have_pycudwt__ = False try: from skcuda.fft import Plan from skcuda.fft import fft as cufft from skcuda.fft import ifft as cuifft __have_skcuda__ = True except Exception as exc: # We have to catch this very broad exception, because # skcuda.cublas.cublasError cannot be evaluated without error when no cuda GPU is found __have_skcuda__ = False class CudaMunchDeringer(MunchDeringer): def __init__(self, sigma, sinos_shape=None, levels=None, wname="db15", cuda_options=None): """ Initialize a "Munch Et Al" sinogram deringer with the Cuda backend. See References for more information. Parameters ----------- sigma: float Standard deviation of the damping parameter. The higher value of sigma, the more important the filtering effect on the rings. levels: int, optional Number of wavelets decomposition levels. By default (None), the maximum number of decomposition levels is used. wname: str, optional Default is "db15" (Daubechies, 15 vanishing moments) sinos_shape: tuple, optional Shape of the sinogram (or sinograms stack). References ---------- B. Munch, P. Trtik, F. Marone, M. Stampanoni, Stripe and ring artifact removal with combined wavelet-Fourier filtering, Optics Express 17(10):8567-8591, 2009. 
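A minimal usage sketch, assuming `d_sinos` is a float32 stack of sinograms already on the GPU
(e.g. a `pycuda.gpuarray.GPUArray`): `deringer = CudaMunchDeringer(1.0, sinos_shape=d_sinos.shape)`
then `deringer.remove_rings(d_sinos)`; as in the base class, the processing is done in-place
unless an `output` array is given.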
""" super().__init__(sigma, levels=levels, wname=wname, sinos_shape=sinos_shape) self._check_can_use_wavelets() self.cuda_processing = CudaProcessing(**(cuda_options or {})) self._init_pycudwt() self._init_fft() self._setup_fw_kernel() def _check_can_use_wavelets(self): if not (__have_pycudwt__ and __have_skcuda__): raise ValueError("Needs pycudwt and scikit-cuda to use this class") def _init_fft(self): self._fft_plans = {} for level, d_vcoeff in self._d_vertical_coeffs.items(): n_angles, dwidth = d_vcoeff.shape # Batched vertical 1D FFT - need advanced data layout # http://docs.nvidia.com/cuda/cufft/#advanced-data-layout p_f = Plan( (n_angles,), np.float32, np.complex64, batch=dwidth, inembed=np.int32([0]), istride=dwidth, idist=1, onembed=np.int32([0]), ostride=dwidth, odist=1, ) p_i = Plan( (n_angles,), np.complex64, np.float32, batch=dwidth, inembed=np.int32([0]), istride=dwidth, idist=1, onembed=np.int32([0]), ostride=dwidth, odist=1, ) self._fft_plans[level] = {"forward": p_f, "inverse": p_i} def _init_pycudwt(self): if self.levels is None: self.levels = 100 # will be clipped by pycudwt self.sino_shape = self.sinos_shape[1:] self.cudwt = Wavelets(np.zeros(self.sino_shape, "f"), self.wname, self.levels) self.levels = self.cudwt.levels # Access memory allocated by "pypwt" from pycuda self._d_sino = garray.empty(self.sino_shape, np.float32, gpudata=self.cudwt.image_int_ptr()) self._get_vertical_coeffs() def _get_vertical_coeffs(self): self._d_vertical_coeffs = {} self._d_sino_f = {} # Transfer the (0-memset) coefficients in order to get all the shapes coeffs = self.cudwt.coeffs for i in range(self.cudwt.levels): shape = coeffs[i + 1][1].shape self._d_vertical_coeffs[i + 1] = garray.empty( shape, np.float32, gpudata=self.cudwt.coeff_int_ptr(3 * i + 2) ) self._d_sino_f[i + 1] = garray.zeros((shape[0] // 2 + 1, shape[1]), dtype=np.complex64) def _setup_fw_kernel(self): self._fw_kernel = CudaKernel( "kern_fourierwavelets", filename=get_cuda_srcfile("fourier_wavelets.cu"), signature="Piif", ) def _destripe_2D(self, d_sino, output): # set the "image" for DWT (memcpy D2D) self._d_sino.set(d_sino) # perform forward DWT self.cudwt.forward() for i in range(self.cudwt.levels): level = i + 1 d_coeffs = self._d_vertical_coeffs[level] d_sino_f = self._d_sino_f[level] Ny, Nx = d_coeffs.shape # Batched FFT along axis 0 cufft(d_coeffs, d_sino_f, self._fft_plans[level]["forward"]) # Dampen the wavelets coefficients self._fw_kernel(d_sino_f, Nx, Ny, self.sigma) # IFFT cuifft(d_sino_f, d_coeffs, self._fft_plans[level]["inverse"]) # Finally, inverse DWT self.cudwt.inverse() output.set(self._d_sino) return output ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/reconstruction/sinogram.py0000644000175000017500000003470000000000000020427 0ustar00pierrepierreimport numpy as np from scipy.interpolate import interp1d from ..utils import get_2D_3D_shape, check_supported, deprecated_class, deprecated class SinoBuilder: """ A class to build sinograms. """ def __init__( self, sinos_shape=None, radios_shape=None, rot_center=None, halftomo=False, angles=None, interpolate=False ): """ Initialize a SinoBuilder instance. Parameters ---------- sinos_shape: tuple of int Shape of the stack of sinograms, in the form `(n_z, n_angles, n_x)`. If not provided, it is derived from `radios_shape`. radios_shape: tuple of int Shape of the chunk of radios, in the form `(n_angles, n_z, n_x)`. If not provided, it is derived from `sinos_shape`. 
rot_center: int or array Rotation axis position. A scalar indicates the same rotation axis position for all the projections. halftomo: bool Whether "half tomography" is enabled. Default is False. interpolate: bool, optional Only used if halftomo=True. Whether to re-grid the second part of sinograms to match projection k with projection k + n_a/2. This forces each pair of projection (k, k + n_a/2) to be separated by exactly 180 degrees. angles: array, optional Rotation angles (in radians). Used and required only when halftomo and interpolate are True. """ self._get_shapes(sinos_shape, radios_shape) self.set_rot_center(rot_center) self._configure_halftomo(halftomo, interpolate, angles) def _get_shapes(self, sinos_shape, radios_shape=None): if (sinos_shape is None) and (radios_shape is None): raise ValueError("Need to provide sinos_shape and/or radios_shape") if sinos_shape is None: n_a, n_z, n_x = get_2D_3D_shape(radios_shape) sinos_shape = (n_z, n_a, n_x) elif len(sinos_shape) == 2: sinos_shape = (1,) + sinos_shape if radios_shape is None: n_z, n_a, n_x = get_2D_3D_shape(sinos_shape) radios_shape = (n_a, n_z, n_x) elif len(radios_shape) == 2: radios_shape = (1,) + radios_shape self.sinos_shape = sinos_shape self.radios_shape = radios_shape n_a, n_z, n_x = radios_shape self.n_angles = n_a self.n_z = n_z self.n_x = n_x def set_rot_center(self, rot_center): """ Set the rotation axis position for the current radios/sinos stack. rot_center: int or array Rotation axis position. A scalar indicates the same rotation axis position for all the projections. """ if rot_center is None: rot_center = (self.n_x - 1) / 2.0 if not (np.isscalar(rot_center)): rot_center = np.array(rot_center) if rot_center.size != self.n_angles: raise ValueError( "Expected rot_center to have %d elements but got %d" % (self.n_angles, rot_center.size) ) self.rot_center = rot_center def _configure_halftomo(self, halftomo, interpolate, angles): self.halftomo = halftomo self.interpolate = interpolate self.angles = angles self._halftomo_flip = False if not self.halftomo: return if interpolate and (angles is None): raise ValueError("The parameter 'angles' has to be provided when using halftomo=True and interpolate=True") self.extended_sino_width = get_extended_sinogram_width(self.n_x, self.rot_center) # If CoR is on the left: "flip" the logic if self.rot_center < (self.n_x - 1) / 2: self.rot_center = self.n_x - 1 - self.rot_center self._halftomo_flip = True # if abs(self.rot_center - ((self.n_x - 1) / 2.0)) < 1: # which tol ? raise ValueError("Half tomography: incompatible rotation axis position: %.2f" % self.rot_center) self.sinos_halftomo_shape = (self.n_z, (self.n_angles + 1) // 2, self.extended_sino_width) def _check_array_shape(self, array, kind="radio"): expected_shape = self.radios_shape if "radio" in kind else self.sinos_shape assert array.shape == expected_shape, "Expected radios shape %s, but got %s" % (expected_shape, array.shape) @property def output_shape(self): """ Get the output sinograms shape. 
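When halftomo is enabled, this is the shape of the extended (half-acquisition) sinogram stack;
otherwise it is simply `sinos_shape`.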
""" if self.halftomo: return self.sinos_halftomo_shape return self.sinos_shape # # 2D # def _get_sino_simple(self, radios, i): return radios[:, i, :] # view def _get_sino_halftomo(self, radios, i, output=None): # TODO output is ignored for now sino = radios[:, i, :] if self.interpolate: match_half_sinos_parts(sino, self.angles) elif self.n_angles & 1: # Odd number of projections - add one line in the end sino = np.vstack([sino, np.zeros_like(sino[-1])]) if self._halftomo_flip: sino = sino[:, ::-1] if self.rot_center > self.n_x: # (hopefully rare) case where CoR is outside FoV result = _convert_halftomo_right(sino, self.extended_sino_width) else: # Standard case result = convert_halftomo(sino, self.extended_sino_width) if self._halftomo_flip: result = result[:, ::-1] return result def get_sino(self, radios, i, output=None): """ The the sinogram at a given index. Parameters ---------- radios: array 3D array with shape (n_z, n_angles, n_x) i: int Sinogram index Returns ------- sino: array Two dimensional array with shape (n_angles2, n_x2) where the dimensions are determined by the current settings. """ if self.halftomo: return self._get_sino_halftomo(radios, i, output=None) else: return self._get_sino_simple(radios, i) # # 3D # def _get_sinos_simple(self, radios, output=None): res = np.rollaxis(radios, 1, 0) # view if output is not None: output[...] = res[...] # copy return output return res def _get_sinos_halftomo(self, radios, output=None): n_a, n_z, n_x = radios.shape if output is None: output = np.zeros(self.sinos_halftomo_shape, dtype=np.float32) elif output.shape != self.output_shape: raise ValueError("Expected output shape to be %s, but got %s" % (str(output.shape), str(self.output_shape))) for i in range(n_z): output[i] = self._get_sino_halftomo(radios, i) return output def get_sinos(self, radios, output=None): if self.halftomo: return self._get_sinos_halftomo(radios, output=output) else: return self._get_sinos_simple(radios, output=output) @deprecated("Use get_sino() or get_sinos() instead", do_print=True) def radios_to_sinos(self, radios, output=None, copy=False): """ DEPRECATED. Use get_sinos() or get_sino() instead. """ return self.get_sinos(radios, output=output) SinoProcessing = deprecated_class("'SinoProcessing' was renamed 'SinoBuilder'", do_print=True)(SinoBuilder) def convert_halftomo(sino, extended_width, transition_width=None): """ Converts a sinogram into a sinogram with extended FOV with the "half tomography" setting. """ assert sino.ndim == 2 assert (sino.shape[0] % 2) == 0 na, nx = sino.shape na2 = na // 2 r = extended_width // 2 d = transition_width or nx - r res = np.zeros((na2, 2 * r), dtype="f") sino1 = sino[:na2, :] sino2 = sino[na2:, ::-1] res[:, : r - d] = sino1[:, : r - d] # w1 = np.linspace(0, 1, 2 * d, endpoint=True) res[:, r - d : r + d] = (1 - w1) * sino1[:, r - d :] + w1 * sino2[:, 0 : 2 * d] # res[:, r + d :] = sino2[:, 2 * d :] return res # This function can have a cuda counterpart, see test_interpolation.py def match_half_sinos_parts(sino, angles, output=None): """ Modifies the lower part of the half-acquisition sinogram so that each projection pair is separated by exactly 180 degrees. This means that `new_sino[k]` and `new_sino[k + n_angles//2]` will be 180 degrees apart. Parameters ---------- sino: numpy.ndarray Two dimensional array with the sinogram in the form (n_angles, n_x) angles: numpy.ndarray One dimensional array with the rotation angles. output: numpy.array, optional Output sinogram. By default, the array 'sino' is modified in-place. 
Notes ----- This function assumes that the angles are in an increasing order. """ n_a = angles.size n_a_2 = n_a // 2 sino_part1 = sino[:n_a_2, :] sino_part2 = sino[n_a_2:, :] angles = np.rad2deg(angles) # more numerically stable ? angles_1 = angles[:n_a_2] angles_2 = angles[n_a_2:] angles_2_target = angles_1 + 180.0 interpolator = interp1d(angles_2, sino_part2, axis=0, kind="linear", copy=False, fill_value="extrapolate") if output is None: output = sino else: output[:n_a_2, :] = sino[:n_a_2, :] output[n_a_2:, :] = interpolator(angles_2_target) return output # EXPERIMENTAL def _convert_halftomo_right(sino, extended_width): """ Converts a sinogram into a sinogram with extended FOV with the "half tomography" setting, with a CoR outside the image support. """ assert sino.ndim == 2 na, nx = sino.shape assert (na % 2) == 0 rotation_axis_position = extended_width // 2 assert rotation_axis_position > nx sino2 = np.pad(sino, ((0, 0), (0, rotation_axis_position - nx)), mode="reflect") return convert_halftomo(sino2, extended_width) def get_extended_sinogram_width(sino_width, rotation_axis_position): """ Compute the width (in pixels) of the extended sinogram for half-acquisition setting. """ if rotation_axis_position < (sino_width - 1) / 2.0: # CoR is on left side - flip to fallback on standard case rotation_axis_position = int(sino_width - 1 - rotation_axis_position) # TODO it would be more accurate to use int(round(2 * CoR)) # or perhaps 2*int(round(CoR))) to have a multiple of 2 # For now it is 2*int(CoR), so we can lose up to two pixels on the edges return 2 * int(rotation_axis_position) class SinoNormalization: """ A class for sinogram normalization utilities. """ kinds = [ "chebyshev", "subtraction", "division", ] operations = {"subtraction": np.subtract, "division": np.divide} def __init__(self, kind="chebyshev", sinos_shape=None, radios_shape=None, normalization_array=None): """ Initialize a SinoNormalization class. Parameters ----------- kind: str, optional Normalization type. They can be the following: - chebyshev: Each sinogram line is estimated by a Chebyshev polynomial of degree 2. This estimation is then subtracted from the sinogram. - subtraction: Each sinogram is subtracted with a user-provided array. The array can be 1D (angle-independent) and 2D (angle-dependent) - division: same as previously, but with a division operation. Default is "chebyshev" sinos_shape: tuple, optional Shape of the sinogram or sinogram stack. Either this parameter or 'radios_shape' has to be provided. radios_shape: tuple, optional Shape of the projections or projections stack. Either this parameter or 'sinos_shape' has to be provided. normalization_array: numpy.ndarray, optional Normalization array when kind='subtraction' or kind='division'. 
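Examples
---------
A minimal sketch, assuming `sinos` is a numpy array of sinograms:
`norm = SinoNormalization(kind="chebyshev", sinos_shape=sinos.shape)`, then `norm.normalize(sinos)`.
The normalization is done in-place, so the content of `sinos` is overwritten.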
""" self._get_shapes(sinos_shape, radios_shape) self._set_kind(kind, normalization_array) _get_shapes = SinoBuilder._get_shapes def _set_kind(self, kind, normalization_array): check_supported(kind, self.kinds, "sinogram normalization kind") self.normalization_kind = kind self._normalization_instance_method = self._normalize_chebyshev # default if kind in ["subtraction", "division"]: if not isinstance(normalization_array, np.ndarray): raise ValueError( "Expected 'normalization_array' to be provided as a numpy array for normalization kind='%s'" % kind ) if normalization_array.shape[-1] != self.sinos_shape[-1]: n_a, n_x = self.sinos_shape[-2:] raise ValueError("Expected normalization_array to have shape (%d, %d) or (%d, )" % (n_a, n_x, n_x)) self.norm_operation = self.operations[kind] self._normalization_instance_method = self._normalize_op self.normalization_array = normalization_array # # Chebyshev normalization # def _normalize_chebyshev_2D(self, sino): output = sino # inplace Nr, Nc = sino.shape J = np.arange(Nc) x = 2.0 * (J + 0.5 - Nc / 2) / Nc sum0 = Nc f2 = 3.0 * x * x - 1.0 sum1 = (x**2).sum() sum2 = (f2**2).sum() for i in range(Nr): ff0 = sino[i, :].sum() ff1 = (x * sino[i, :]).sum() ff2 = (f2 * sino[i, :]).sum() output[i, :] = sino[i, :] - (ff0 / sum0 + ff1 * x / sum1 + ff2 * f2 / sum2) return output def _normalize_chebyshev_3D(self, sino): for i in range(sino.shape[0]): self._normalize_chebyshev_2D(sino[i]) return sino def _normalize_chebyshev(self, sino): if sino.ndim == 2: self._normalize_chebyshev_2D(sino) else: self._normalize_chebyshev_3D(sino) return sino # # Array subtraction/division # def _normalize_op(self, sino): if sino.ndim == 2: self.norm_operation(sino, self.normalization_array, out=sino) else: for i in range(sino.shape[0]): self.norm_operation(sino[i], self.normalization_array, out=sino[i]) return sino # # Dispatch # def normalize(self, sino): """ Normalize a sinogram or stack of sinogram. The process is done in-place, meaning that the sinogram content is overwritten. """ return self._normalization_instance_method(sino) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/reconstruction/sinogram_cuda.py0000644000175000017500000002210500000000000021417 0ustar00pierrepierreimport numpy as np import pycuda.gpuarray as garray from ..cuda.kernel import CudaKernel from ..utils import get_cuda_srcfile, updiv, deprecated_class from .sinogram import SinoBuilder, SinoNormalization from .sinogram import _convert_halftomo_right # FIXME Temporary patch from ..cuda.processing import CudaProcessing class CudaSinoBuilder(SinoBuilder): def __init__( self, sinos_shape=None, radios_shape=None, rot_center=None, halftomo=False, angles=None, cuda_options=None ): """ Initialize a CudaSinoBuilder instance. Please see the documentation of nabu.reconstruction.sinogram.Builder and nabu.cuda.processing.CudaProcessing. """ super().__init__( sinos_shape=sinos_shape, radios_shape=radios_shape, rot_center=rot_center, halftomo=halftomo, angles=angles ) self.cuda_processing = CudaProcessing(**(cuda_options or {})) self._init_cuda_halftomo() def _init_cuda_halftomo(self): if not (self.halftomo): return kernel_name = "halftomo_kernel" self.halftomo_kernel = CudaKernel( kernel_name, get_cuda_srcfile("halftomo.cu"), signature="PPPiii", ) blk = (32, 32, 1) # tune ? 
self._halftomo_blksize = blk self._halftomo_gridsize = (updiv(self.extended_sino_width, blk[0]), updiv((self.n_angles + 1) // 2, blk[1]), 1) d = self.n_x - self.extended_sino_width // 2 # will have to be adapted for varying axis pos self.halftomo_weights = np.linspace(0, 1, 2 * abs(d), endpoint=True, dtype="f") self.d_halftomo_weights = garray.to_gpu(self.halftomo_weights) # Allocate one single sinogram (kernel needs c-contiguous array). # If odd number of angles: repeat last angle. self.d_sino = garray.zeros((self.n_angles + (self.n_angles & 1), self.n_x), "f") self.h_sino = self.d_sino.get() # self.cuda_processing.init_arrays_to_none(["d_output"]) if self._halftomo_flip: self.xflip_kernel = CudaKernel("reverse2D_x", get_cuda_srcfile("ElementOp.cu"), signature="Pii") blk = (32, 32, 1) self._xflip_blksize = blk self._xflip_gridsize_1 = (updiv(self.n_x, blk[0]), updiv(self.n_angles, blk[1]), 1) self._xflip_gridsize_2 = self._halftomo_gridsize # # 2D # def _get_sino_halftomo(self, radios, i, output=None): if output is None: output = self.cuda_processing.allocate_array("d_output", self.output_shape[1:]) elif output.shape != self.output_shape[1:]: raise ValueError("Expected output to have shape %s but got %s" % (self.output_shape[1:], output.shape)) d_sino = self.d_sino n_a, n_z, n_x = radios.shape d_sino[:n_a] = radios[:, i, :] if self.n_angles & 1: d_sino[-1, :].fill(0) if self._halftomo_flip: self.xflip_kernel(d_sino, n_x, n_a, grid=self._xflip_gridsize_1, block=self._xflip_blksize) # Sometimes CoR is set well outside the FoV. Not supported by cuda backend for now. # TODO/FIXME: TEMPORARY PATCH, waiting for cuda implementation if self.rot_center > self.n_x: d_sino.get(self.h_sino) # copy D2H res = _convert_halftomo_right(self.h_sino, self.extended_sino_width) output.set(res) # copy H2D # else: self.halftomo_kernel( d_sino, output, self.d_halftomo_weights, n_a, n_x, self.extended_sino_width // 2, grid=self._halftomo_gridsize, block=self._halftomo_blksize, ) if self._halftomo_flip: self.xflip_kernel( output, self.extended_sino_width, (n_a + 1) // 2, grid=self._xflip_gridsize_2, block=self._xflip_blksize ) return output # # 3D # def _get_sinos_simple(self, radios, output=None): if output is None: return radios.transpose(axes=(1, 0, 2)) # view else: # why can't I do a discontig single copy ? 
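            # fallback: copy sinogram by sinogram into the user-provided output array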
for i in range(radios.shape[1]): output[i] = radios[:, i, :] return output def _get_sinos_halftomo(self, radios, output=None): if output is None: output = garray.zeros(self.output_shape, "f") elif output.shape != self.output_shape: raise ValueError("Expected output to have shape %s but got %s" % (self.output_shape, output.shape)) for i in range(self.n_z): self._get_sino_halftomo(radios, i, output=output[i]) return output CudaSinoProcessing = deprecated_class("'CudaSinoProcessing' was renamed 'CudaSinoBuilder'", do_print=True)( CudaSinoBuilder ) class CudaSinoNormalization(SinoNormalization): def __init__( self, kind="chebyshev", sinos_shape=None, radios_shape=None, normalization_array=None, cuda_options=None ): super().__init__( kind=kind, sinos_shape=sinos_shape, radios_shape=radios_shape, normalization_array=normalization_array ) self._get_shapes(sinos_shape, radios_shape) self.cuda_processing = CudaProcessing(**(cuda_options or {})) self._init_cuda_normalization() _get_shapes = SinoBuilder._get_shapes # # Chebyshev normalization # def _init_cuda_normalization(self): self._d_tmp = garray.zeros(self.sinos_shape[-2:], "f") if self.normalization_kind == "chebyshev": self._chebyshev_kernel = CudaKernel( "normalize_chebyshev", filename=get_cuda_srcfile("normalization.cu"), signature="Piii", ) self._chebyshev_kernel_args = [np.int32(self.n_x), np.int32(self.n_angles), np.int32(self.n_z)] blk = (1, 64, 16) # TODO tune ? self._chebyshev_kernel_kwargs = { "block": blk, "grid": (1, int(updiv(self.n_angles, blk[1])), int(updiv(self.n_z, blk[2]))), } elif self.normalization_array is not None: normalization_array = self.normalization_array # If normalization_array is 1D, make a 2D array by repeating the line if normalization_array.ndim == 1: normalization_array = np.tile(normalization_array, (self.n_angles, 1)) self._d_normalization_array = garray.to_gpu(normalization_array.astype("f")) if self.normalization_kind == "subtraction": generic_op_val = 1 elif self.normalization_kind == "division": generic_op_val = 3 self._norm_kernel = CudaKernel( "inplace_generic_op_2Dby2D", filename=get_cuda_srcfile("ElementOp.cu"), signature="PPii", options=["-DGENERIC_OP=%d" % generic_op_val], ) self._norm_kernel_args = [self._d_normalization_array, np.int32(self.n_angles), np.int32(self.n_x)] blk = (32, 32, 1) self._norm_kernel_kwargs = { "block": blk, "grid": (int(updiv(self.n_angles, blk[0])), int(updiv(self.n_x, blk[1])), 1), } def _normalize_chebyshev(self, sinos): if sinos.flags.c_contiguous: self._chebyshev_kernel(sinos, *self._chebyshev_kernel_args, **self._chebyshev_kernel_kwargs) else: # This kernel seems to have an issue on arrays that are not C-contiguous. # We have to process image per image. nz = np.int32(1) nthreadsperblock = (1, 32, 1) # TODO tune nblocks = (1, int(updiv(self.n_angles, nthreadsperblock[1])), 1) for i in range(sinos.shape[0]): self._d_tmp[:] = sinos[i][:] self._chebyshev_kernel( self._d_tmp, np.int32(self.n_x), np.int32(self.n_angles), np.int32(1), grid=nblocks, block=nthreadsperblock, ) sinos[i][:] = self._d_tmp[:] return sinos # # Array subtraction/division # def _normalize_op(self, sino): if sino.ndim == 2: # Things can go wrong if "sino" is a non-contiguous 2D array # But this should be handled outside this function, as the processing is in-place self._norm_kernel(sino, *self._norm_kernel_args, **self._norm_kernel_kwargs) else: if sino.flags.forc: # Contiguous 3D array. But pycuda wants the same shape for both operands. 
for i in range(sino.shape[0]): self._norm_kernel(sino[i], *self._norm_kernel_args, **self._norm_kernel_kwargs) else: # Non-contiguous 2D array. Make a temp. copy for i in range(sino.shape[0]): self._d_tmp[:] = sino[i][:] self._norm_kernel(self._d_tmp, *self._norm_kernel_args, **self._norm_kernel_kwargs) sino[i][:] = self._d_tmp[:] return sino ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4687333 nabu-2023.1.1/nabu/reconstruction/tests/0000755000175000017500000000000000000000000017374 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1581878491.0 nabu-2023.1.1/nabu/reconstruction/tests/__init__.py0000644000175000017500000000000100000000000021474 0ustar00pierrepierre ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/reconstruction/tests/test_cone.py0000644000175000017500000002704500000000000021741 0ustar00pierrepierreimport pytest import numpy as np from scipy.ndimage import gaussian_filter from nabu.utils import subdivide_into_overlapping_segment try: import astra __has_astra__ = True except ImportError: __has_astra__ = False from nabu.cuda.utils import __has_pycuda__, get_cuda_context if __has_pycuda__: from nabu.reconstruction.cone import ConebeamReconstructor if __has_astra__: from astra.extrautils import clipCircle @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.vol_shape = (128, 126, 126) cls.n_angles = 180 cls.prj_width = 192 # detector larger than the sample cls.src_orig_dist = 1000 cls.orig_det_dist = 100 cls.volume, cls.cone_data = generate_hollow_cube_cone_sinograms( cls.vol_shape, cls.n_angles, cls.src_orig_dist, cls.orig_det_dist, prj_width=cls.prj_width ) if __has_pycuda__: cls.ctx = get_cuda_context() @pytest.mark.skipif(not (__has_pycuda__ and __has_astra__), reason="Need pycuda and astra for this test") @pytest.mark.usefixtures("bootstrap") class TestCone: def _create_cone_reconstructor(self, relative_z_position=None): return ConebeamReconstructor( self.cone_data.shape, self.src_orig_dist, self.orig_det_dist, relative_z_position=relative_z_position, volume_shape=self.volume.shape, cuda_options={"ctx": self.ctx}, ) def test_simple_cone_reconstruction(self): C = self._create_cone_reconstructor() res = C.reconstruct(self.cone_data) delta = np.abs(res - self.volume) # Can we do better ? We already had to lowpass-filter the volume! # First/last slices are OK assert np.max(delta[:8]) < 1e-5 assert np.max(delta[-8:]) < 1e-5 # Middle region has a relatively low error assert np.max(delta[40:-40]) < 0.11 # Transition zones between "zero" and "cube" has a large error assert np.max(delta[10:25]) < 0.2 assert np.max(delta[-25:-10]) < 0.2 # End of transition zones have a smaller error assert np.max(delta[25:40]) < 0.125 assert np.max(delta[-40:-25]) < 0.125 def test_against_explicit_astra_calls(self): C = self._create_cone_reconstructor() res = C.reconstruct(self.cone_data) # # Check that ConebeamReconstructor is consistent with these calls to astra # # "vol_geom" shape layout is (y, x, z). But here this geometry is used for the reconstruction # (i.e sinogram -> volume)and not for projection (volume -> sinograms). # So we assume a square slice. Mind that this is a particular case. 
vol_geom = astra.create_vol_geom(self.vol_shape[2], self.vol_shape[2], self.vol_shape[0]) angles = np.linspace(0, 2 * np.pi, self.n_angles, True) proj_geom = astra.create_proj_geom( "cone", 1.0, 1.0, self.cone_data.shape[0], self.prj_width, angles, self.src_orig_dist, self.orig_det_dist, ) sino_id = astra.data3d.create("-sino", proj_geom, data=self.cone_data) rec_id = astra.data3d.create("-vol", vol_geom) cfg = astra.astra_dict("FDK_CUDA") cfg["ReconstructionDataId"] = rec_id cfg["ProjectionDataId"] = sino_id alg_id = astra.algorithm.create(cfg) astra.algorithm.run(alg_id) res_astra = astra.data3d.get(rec_id) # housekeeping astra.algorithm.delete(alg_id) astra.data3d.delete(rec_id) astra.data3d.delete(sino_id) assert ( np.max(np.abs(res - res_astra)) < 5e-4 ), "ConebeamReconstructor results are inconsistent with plain calls to astra" def test_projection_full_vs_partial(self): """ In the ideal case, all the data volume (and reconstruction) fits in memory. In practice this is rarely the case, so we have to reconstruct the volume slabs by slabs. The slabs should be slightly overlapping to avoid "stitching" artefacts at the edges. """ # Astra seems to duplicate the projection data, even if all GPU memory is handled externally # Let's try with (n_z * n_y * n_x + 2 * n_a * n_z * n_x) * 4 < mem_limit # 256^3 seems OK with n_a = 200 (180 MB) n_z = n_y = n_x = 256 n_a = 200 src_orig_dist = 1000 orig_det_dist = 100 volume, cone_data = generate_hollow_cube_cone_sinograms( vol_shape=(n_z, n_y, n_x), n_angles=n_a, src_orig_dist=src_orig_dist, orig_det_dist=orig_det_dist ) C_full = ConebeamReconstructor(cone_data.shape, src_orig_dist, orig_det_dist, cuda_options={"ctx": self.ctx}) vol_geom = astra.create_vol_geom(n_y, n_x, n_z) proj_geom = astra.create_proj_geom("cone", 1.0, 1.0, n_z, n_x, C_full.angles, src_orig_dist, orig_det_dist) proj_id, projs_full_geom = astra.create_sino3d_gpu(volume, proj_geom, vol_geom) astra.data3d.delete(proj_id) # Do the same slab-by-slab inner_slab_size = 64 overlap = 16 slab_size = inner_slab_size + overlap * 2 slabs = subdivide_into_overlapping_segment(n_z, slab_size, overlap) projs_partial_geom = np.zeros_like(projs_full_geom) for slab in slabs: z_min, z_inner_min, z_inner_max, z_max = slab rel_z_pos = (z_min + z_max) / 2 - n_z / 2 subvolume = volume[z_min:z_max, :, :] C = ConebeamReconstructor( (z_max - z_min, n_a, n_x), src_orig_dist, orig_det_dist, relative_z_position=rel_z_pos, cuda_options={"ctx": self.ctx}, ) proj_id, projs = astra.create_sino3d_gpu(subvolume, C.proj_geom, C.vol_geom) astra.data3d.delete(proj_id) projs_partial_geom[z_inner_min:z_inner_max] = projs[z_inner_min - z_min : z_inner_max - z_min] error_profile = [ np.max(np.abs(proj_partial - proj_full)) for proj_partial, proj_full in zip(projs_partial_geom, projs_full_geom) ] assert np.all(np.isclose(error_profile, 0.0, atol=0.0375)), "Mismatch between full-cone and slab geometries" def test_cone_reconstruction_magnified_vs_demagnified(self): """ This will only test the astra toolbox. When reconstructing a volume from cone-beam data, the volume "should" have a smaller shape than the projection data shape (because of cone magnification). But astra provides the same results when backprojecting on a "de-magnified grid" and the original grid shape. 
""" n_z = n_y = n_x = 256 n_a = 500 src_orig_dist = 1000 orig_det_dist = 100 magnification = 1 + orig_det_dist / src_orig_dist angles = np.linspace(0, 2 * np.pi, n_a, True) volume, cone_data = generate_hollow_cube_cone_sinograms( vol_shape=(n_z, n_y, n_x), n_angles=n_a, src_orig_dist=src_orig_dist, orig_det_dist=orig_det_dist, apply_filter=False, ) rec_original_grid = astra_cone_beam_reconstruction( cone_data, angles, src_orig_dist, orig_det_dist, demagnify_volume=False ) rec_reduced_grid = astra_cone_beam_reconstruction( cone_data, angles, src_orig_dist, orig_det_dist, demagnify_volume=True ) m_z = (n_z - int(n_z / magnification)) // 2 m_y = (n_y - int(n_y / magnification)) // 2 m_x = (n_x - int(n_x / magnification)) // 2 assert np.allclose(rec_original_grid[m_z:-m_z, m_y:-m_y, m_x:-m_x], rec_reduced_grid) def test_reconstruction_full_vs_partial(self): n_z = n_y = n_x = 256 n_a = 500 src_orig_dist = 1000 orig_det_dist = 100 angles = np.linspace(0, 2 * np.pi, n_a, True) volume, cone_data = generate_hollow_cube_cone_sinograms( vol_shape=(n_z, n_y, n_x), n_angles=n_a, src_orig_dist=src_orig_dist, orig_det_dist=orig_det_dist, apply_filter=False, ) rec_full_volume = astra_cone_beam_reconstruction(cone_data, angles, src_orig_dist, orig_det_dist) rec_partial = np.zeros_like(rec_full_volume) inner_slab_size = 64 overlap = 18 slab_size = inner_slab_size + overlap * 2 slabs = subdivide_into_overlapping_segment(n_z, slab_size, overlap) for slab in slabs: z_min, z_inner_min, z_inner_max, z_max = slab m1, m2 = z_inner_min - z_min, z_max - z_inner_max C = ConebeamReconstructor((z_max - z_min, n_a, n_x), src_orig_dist, orig_det_dist) rec = C.reconstruct( cone_data[z_min:z_max], relative_z_position=((z_min + z_max) / 2) - n_z / 2, # (z_min + z_max)/2. ) rec_partial[z_inner_min:z_inner_max] = rec[m1 : (-m2) or None] # Compare volumes in inner circle for i in range(n_z): clipCircle(rec_partial[i]) clipCircle(rec_full_volume[i]) diff = np.abs(rec_partial - rec_full_volume) err_max_profile = np.max(diff, axis=(-1, -2)) err_median_profile = np.median(diff, axis=(-1, -2)) assert np.max(err_max_profile) < 2e-3 assert np.max(err_median_profile) < 5e-6 def generate_hollow_cube_cone_sinograms( vol_shape, n_angles, src_orig_dist, orig_det_dist, prj_width=None, apply_filter=True ): # Adapted from Astra toolbox python samples n_z, n_y, n_x = vol_shape vol_geom = astra.create_vol_geom(n_y, n_x, n_z) prj_width = prj_width or n_x prj_height = n_z angles = np.linspace(0, 2 * np.pi, n_angles, True) proj_geom = astra.create_proj_geom("cone", 1.0, 1.0, prj_width, prj_width, angles, src_orig_dist, orig_det_dist) magnification = 1 + orig_det_dist / src_orig_dist # hollow cube cube = np.zeros(astra.geom_size(vol_geom), dtype="f") d = int(min(n_x, n_y) / 2 * (1 - np.sqrt(2) / 2)) cube[20:-20, d:-d, d:-d] = 1 cube[40:-40, d + 20 : -(d + 20), d + 20 : -(d + 20)] = 0 # d = int(min(n_x, n_y) / 2 * (1 - np.sqrt(2) / 2) * magnification) # d1 = d + 10 # d2 = d + 20 # cube[40:-40, d1:-d1, d1:-d1] = 1 # cube[60:-60, d2 : -d2, d2 : -d2] = 0 # High-frequencies yield cannot be accurately retrieved if apply_filter: cube = gaussian_filter(cube, (1.0, 1.0, 1.0)) proj_id, proj_data = astra.create_sino3d_gpu(cube, proj_geom, vol_geom) astra.data3d.delete(proj_id) # (n_z, n_angles, n_x) return cube, proj_data def astra_cone_beam_reconstruction(cone_data, angles, src_orig_dist, orig_det_dist, demagnify_volume=False): """ Handy (but data-inefficient) function to reconstruct data from cone-beam geometry """ n_z, n_a, n_x = cone_data.shape proj_geom 
= astra.create_proj_geom("cone", 1.0, 1.0, n_z, n_x, angles, src_orig_dist, orig_det_dist) sino_id = astra.data3d.create("-sino", proj_geom, data=cone_data) m = 1 + orig_det_dist / src_orig_dist if demagnify_volume else 1.0 n_z_vol, n_y_vol, n_x_vol = int(n_z / m), int(n_x / m), int(n_x / m) vol_geom = astra.create_vol_geom(n_y_vol, n_x_vol, n_z_vol) rec_id = astra.data3d.create("-vol", vol_geom) cfg = astra.astra_dict("FDK_CUDA") cfg["ReconstructionDataId"] = rec_id cfg["ProjectionDataId"] = sino_id alg_id = astra.algorithm.create(cfg) astra.algorithm.run(alg_id) rec = astra.data3d.get(rec_id) astra.data3d.delete(sino_id) astra.data3d.delete(rec_id) astra.algorithm.delete(alg_id) return rec ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/reconstruction/tests/test_deringer.py0000644000175000017500000000622500000000000022611 0ustar00pierrepierreimport numpy as np import pytest from nabu.utils import clip_circle from nabu.testutils import get_data, compare_arrays from nabu.reconstruction.rings import MunchDeringer from nabu.thirdparty.pore3d_deringer_munch import munchetal_filter from nabu.cuda.utils import __has_pycuda__, get_cuda_context __have_gpuderinger__ = False if __has_pycuda__: import pycuda.gpuarray as garray from nabu.reconstruction.rings_cuda import CudaMunchDeringer, __have_pycudwt__, __have_skcuda__ if __have_pycudwt__ and __have_skcuda__: __have_gpuderinger__ = True @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.sino = get_data("mri_sino500.npz")["data"] cls.tol = 5e-3 cls.rings = {150: 0.5, -150: 0.5} cls.fw_levels = 4 cls.fw_sigma = 1.0 cls.fw_wname = "db5" if __have_gpuderinger__: cls.ctx = get_cuda_context(cleanup_at_exit=False) yield if __have_gpuderinger__: cls.ctx.pop() @pytest.mark.usefixtures("bootstrap") class TestMunchDeringer: @staticmethod def add_stripes_to_sino(sino, rings_desc): """ Create a new sinogram by adding synthetic stripes to an existing one. Parameters ---------- sino: array-like Sinogram. rings_desc: dict Dictionary describing the stripes locations and intensity. The location is an integer in [0, N[ where N is the number of columns. The intensity is a float: percentage of the current column mean value. 
""" sino_out = np.copy(sino) for loc, intensity in rings_desc.items(): sino_out[:, loc] += sino[:, loc].mean() * intensity return sino_out @pytest.mark.skipif(munchetal_filter is None, reason="Need PyWavelets for this test") def test_munch_deringer(self): deringer = MunchDeringer(self.fw_sigma, levels=self.fw_levels, wname=self.fw_wname, sinos_shape=self.sino.shape) sino = self.add_stripes_to_sino(self.sino, self.rings) # Reference destriping with pore3d "munchetal_filter" ref = munchetal_filter(sino, self.fw_levels, self.fw_sigma, wname=self.fw_wname) # Wrapping with DeRinger res = np.zeros((1,) + sino.shape, dtype=np.float32) deringer.remove_rings(sino, output=res) err_max = np.max(np.abs(res[0] - ref)) assert err_max < self.tol, "Max error is too high" @pytest.mark.skipif( not (__have_gpuderinger__) or munchetal_filter is None, reason="Need pycuda, pycudwt and scikit-cuda for this test", ) def test_cuda_munch_deringer(self): sino = self.add_stripes_to_sino(self.sino, self.rings) deringer = CudaMunchDeringer( self.fw_sigma, levels=self.fw_levels, wname=self.fw_wname, sinos_shape=self.sino.shape, cuda_options={"ctx": self.ctx}, ) d_sino = garray.to_gpu(sino) deringer.remove_rings(d_sino) res = d_sino.get() ref = munchetal_filter(sino, self.fw_levels, self.fw_sigma, wname=self.fw_wname) err_max = np.max(np.abs(res - ref)) assert err_max < 1e-1, "Max error is too high" ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/reconstruction/tests/test_fbp.py0000644000175000017500000001450700000000000021563 0ustar00pierrepierreimport numpy as np import pytest from scipy.ndimage import shift from nabu.pipeline.params import fbp_filters from nabu.utils import clip_circle from nabu.testutils import get_data from nabu.cuda.utils import __has_pycuda__, __has_cufft__ if __has_pycuda__: from nabu.reconstruction.fbp import Backprojector @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.sino_512 = get_data("mri_sino500.npz")["data"] cls.ref_512 = get_data("mri_rec_astra.npz")["data"] cls.sino_511 = cls.sino_512[:, :-1] cls.tol = 5e-2 @pytest.mark.skipif(not (__has_pycuda__ and __has_cufft__), reason="Need pycuda and scikit-cuda for this test") @pytest.mark.usefixtures("bootstrap") class TestFBP: @staticmethod def clip_to_inner_circle(img, radius_factor=0.99): radius = int(radius_factor * max(img.shape) / 2) return clip_circle(img, radius=radius) def test_fbp_512(self): """ Simple test of a FBP on a 512x512 slice """ B = Backprojector((500, 512)) res = B.fbp(self.sino_512) delta_clipped = self.clip_to_inner_circle(res - self.ref_512) err_max = np.max(np.abs(delta_clipped)) assert err_max < self.tol, "Max error is too high" def test_fbp_511(self): """ Test FBP of a 511x511 slice where the rotation axis is at (512-1)/2.0 """ B = Backprojector((500, 511), rot_center=255.5) res = B.fbp(self.sino_511) ref = self.ref_512[:-1, :-1] delta_clipped = self.clip_to_inner_circle(res - ref) err_max = np.max(np.abs(delta_clipped)) assert err_max < self.tol, "Max error is too high" def test_fbp_roi(self): """ Test FBP in region of interest """ sino = self.sino_511 B0 = Backprojector(sino.shape, rot_center=255.5) ref = B0.fbp(sino) def backproject_roi(roi, reference): B = Backprojector(sino.shape, rot_center=255.5, slice_roi=roi) res = B.fbp(sino) err_max = np.max(np.abs(res - ref)) return err_max cases = { # Test 1: use slice_roi=(0, -1, 0, -1), i.e plain FBP of whole slice 1: [(0, None, 0, None), ref], # Test 2: horizontal strip 2: 
[(0, None, 50, 100), ref[50:100, :]], # Test 3: vertical strip 3: [(60, 65, 0, None), ref[:, 60:65]], # Test 4: rectangular inner ROI 4: [(157, 162, 260, -10), ref[260:-10, 157:162]], } for roi, ref in cases.values(): err_max = backproject_roi(roi, ref) assert err_max < self.tol, str("backproject_roi: max error is too high for ROI=%s" % str(roi)) def test_fbp_axis_corr(self): """ Test the "axis correction" feature """ sino = self.sino_512 # Create a sinogram with a drift in the rotation axis def create_drifted_sino(sino, drifts): out = np.zeros_like(sino) for i in range(sino.shape[0]): out[i] = shift(sino[i], drifts[i]) return out drifts = np.linspace(0, 20, sino.shape[0]) sino = create_drifted_sino(sino, drifts) B = Backprojector(sino.shape, extra_options={"axis_correction": drifts}) res = B.fbp(sino) delta_clipped = clip_circle(res - self.ref_512, radius=200) err_max = np.max(np.abs(delta_clipped)) # Max error is relatively high, migh be due to interpolation of scipy shift in sinogram assert err_max < 10.0, "Max error is too high" def test_fbp_clip_circle(self): """ Test the "clip outer circle" parameter in (extra options) """ sino = self.sino_512 for rot_center in [None, sino.shape[1] / 2.0 - 10, sino.shape[1] / 2.0 + 15]: B = Backprojector(sino.shape, rot_center=rot_center, extra_options={"clip_outer_circle": True}) res = B.fbp(sino) B0 = Backprojector(sino.shape, rot_center=rot_center, extra_options={"clip_outer_circle": False}) res_noclip = B0.fbp(sino) ref = self.clip_to_inner_circle(res_noclip, radius_factor=1) err_max = np.max(np.abs(res - ref)) assert err_max < 1e-5, "Max error is too high" def test_fbp_centered_axis(self): """ Test the "centered_axis" parameter (in extra options) """ sino = np.pad(self.sino_512, ((0, 0), (100, 0))) rot_center = (self.sino_512.shape[1] - 1) / 2.0 + 100 B0 = Backprojector(self.sino_512.shape) ref = B0.fbp(self.sino_512) # Check that "centered_axis" worked B = Backprojector(sino.shape, rot_center=rot_center, extra_options={"centered_axis": True}) res = B.fbp(sino) # The outside region (outer circle) is different as "res" is a wider slice diff = self.clip_to_inner_circle(res[50:-50, 50:-50] - ref) err_max = np.max(np.abs(diff)) assert err_max < 5e-2, "centered_axis without clip_circle: something wrong" # Check that "clip_outer_circle" works when used jointly with "centered_axis" B = Backprojector( sino.shape, rot_center=rot_center, extra_options={ "centered_axis": True, "clip_outer_circle": True, }, ) res2 = B.fbp(sino) diff = res2 - self.clip_to_inner_circle(res, radius_factor=1) err_max = np.max(np.abs(diff)) assert err_max < 1e-5, "centered_axis with clip_circle: something wrong" def test_fbp_filters(self): for filter_name in set(fbp_filters.values()): if filter_name in [None, "ramlak"]: continue fbp = Backprojector(self.sino_512.shape, filter_name=filter_name) res = fbp.fbp(self.sino_512) # not sure what to check in this case def test_differentiated_backprojection(self): # test Hilbert + DBP sino_diff = np.diff(self.sino_512, axis=1, prepend=0).astype("f") # Need to translate the axis a little bit, because of non-centered differentiation. # prepend -> +0.5 ; append -> -0.5 fbp = Backprojector(sino_diff.shape, filter_name="hilbert", rot_center=255.5 + 0.5) rec = fbp.fbp(sino_diff) # Looks good, but all frequencies are not recovered. Use a metric like SSIM or FRC ? 
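# Basic Backprojector usage, summarizing the API exercised by the tests above.
# This is a sketch assuming a CUDA-capable machine with pycuda and scikit-cuda; the sinogram is synthetic.
import numpy as np
from nabu.reconstruction.fbp import Backprojector

sino = np.random.rand(500, 512).astype("f")                   # (n_angles, n_x)
B = Backprojector(sino.shape, rot_center=255.5, extra_options={"clip_outer_circle": True})
rec = B.fbp(sino)                                             # (n_x, n_x) slice, zeroed outside the inscribed circle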
././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/reconstruction/tests/test_halftomo.py0000644000175000017500000001552300000000000022624 0ustar00pierrepierreimport os import numpy as np import pytest import h5py try: from algotom.prep.conversion import convert_sinogram_360_to_180 __has_algotom__ = True except ImportError: __has_algotom__ = False from nabu.testutils import compare_arrays, utilstest from nabu.reconstruction.sinogram import SinoBuilder from nabu.cuda.utils import __has_pycuda__ if __has_pycuda__: from nabu.cuda.utils import get_cuda_context import pycuda.gpuarray as garray from nabu.reconstruction.sinogram_cuda import CudaSinoBuilder @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls sino, sino_ref, cor = get_data_h5("halftomo_new.h5") cls.sino = sino cls.radios = convert_sino_to_radios_stack(sino) cls.sino_ref = sino_ref cls.rot_center = cor cls.tol = 5e-3 if __has_pycuda__: cls.ctx = get_cuda_context() def convert_sino_to_radios_stack(sino): return np.moveaxis(np.tile(sino, (1, 1, 1)), 1, 0) def get_data_h5(*dataset_path): dataset_relpath = os.path.join(*dataset_path) dataset_path = utilstest.getfile(dataset_relpath) with h5py.File(dataset_path, "r") as hf: sino = hf["entry/radio/results/data"][()] sino_extended_ref = hf["entry/sino/results/data"][()] cor = hf["entry/sino/configuration/configuration/rotation_axis_position"][()] return sino, sino_extended_ref, cor def generate_halftomo_sinogram_algotom(sino, cor): """ Generate the 180 degrees sinogram with algotom. """ n_angles, dwidth = sino.shape # If the sinogram has an even number of projections n_a that are exactly matched, # then the resulting sinogram should have n_a//2 angles. # # 0 0 180 # 1 1 181 # 2 2 182 # ... -- convert --> ... ... # 180 179 360-1 # 181 # ... # 360-1 # # Yet for some reason algotom yields n_a//2 + 1 angles, because the "180 degrees" # radio is used in both "sino_top" and "sino_bottom". # Thus we have to cheat a little bit and use an odd number of projections. if n_angles % 2 == 0: sino = np.vstack([sino, sino[-1]]) # could even be zeros for the last line # In nabu we use the following overlap width. 
# algotom default is 2 * (dwidth - cor - 1), which yields 2 extra pixels overlap_width = 2 * (dwidth - int(cor)) sino_halftomo, new_cor = convert_sinogram_360_to_180( sino, (overlap_width, 1), norm=False # 1 means that CoR is on the right ) if n_angles % 2 == 0: sino_halftomo = sino_halftomo[:-1, :] return sino_halftomo @pytest.mark.usefixtures("bootstrap") class TestHalftomo: def _build_sinos(self, radios, output=None, backend="python"): sinobuilder_cls = CudaSinoBuilder if backend == "cuda" else SinoBuilder sino_builder = sinobuilder_cls(radios_shape=radios.shape, rot_center=self.rot_center, halftomo=True) sinos_halftomo = sino_builder.get_sinos(radios, output=output) return sinos_halftomo def _check_result(self, sino, test_description): _, err = compare_arrays(sino, self.sino_ref, self.tol, return_residual=True) assert err < self.tol, "Something wrong with %s" % test_description def test_halftomo(self): sinos_halftomo = self._build_sinos(self.radios, backend="python") self._check_result(sinos_halftomo[0], "SinoBuilder.get_sinos, halftomo=True") @pytest.mark.skipif(not (__has_pycuda__), reason="Need pycuda for this test") def test_cuda_halftomo(self): d_radios = garray.to_gpu(self.radios) d_sinos = garray.zeros((1,) + self.sino_ref.shape, "f") self._build_sinos(d_radios, output=d_sinos, backend="cuda") self._check_result(d_sinos.get()[0], "CudaSinoBuilder.get_sinos, halftomo=True") def _get_sino_with_odd_nprojs(self): n_a = self.sino.shape[0] # dummy line inserted in the middle, # so that result should match reference sinogram with an even number of angles sino_odd = np.vstack([self.sino[: n_a // 2], self.sino[-1], self.sino[n_a // 2 :]]) # dummy line return sino_odd def test_halftomo_odd(self): sino_odd = self._get_sino_with_odd_nprojs() radios = convert_sino_to_radios_stack(sino_odd) assert radios.shape[0] & 1, "Radios must have a odd number of angles" sinos = self._build_sinos(radios, backend="python") sino_halftomo = sinos[0][:-1, :] self._check_result(sino_halftomo, "SinoBuilder.get_sinos, halftomo=True, odd number of projs") @pytest.mark.skipif(not (__has_pycuda__), reason="Need pycuda for this test") def test_cuda_halftomo_odd(self): sino_odd = self._get_sino_with_odd_nprojs() radios = convert_sino_to_radios_stack(sino_odd) assert radios.shape[0] & 1, "Radios must have a odd number of angles" d_radios = garray.to_gpu(radios) d_out = garray.zeros((1, self.sino_ref.shape[0] + 1, self.sino_ref.shape[1]), "f") self._build_sinos(d_radios, output=d_out, backend="cuda") sino_halftomo = d_out.get()[0][:-1, :] self._check_result(sino_halftomo, "CudaSinoBuilder.get_sinos, halftomo=True, odd number of projs") @staticmethod def _flip_array(arr): if arr.ndim == 2: return np.fliplr(arr) res = np.zeros_like(arr) for i in range(arr.shape[0]): res[i] = np.fliplr(arr[i]) return res def test_halftomo_left(self): na, nz, nx = self.radios.shape left_cor = nx - 1 - self.rot_center radios = self._flip_array(self.radios) sino_builder = SinoBuilder(radios_shape=radios.shape, rot_center=left_cor, halftomo=True) sinos_halftomo = sino_builder.get_sinos(radios) _, err = compare_arrays( sinos_halftomo[0], self._flip_array(self.sino_ref), self.tol, return_residual=True, ) assert err < self.tol, "Something wrong with SinoBuilder.radios_to_sino, halftomo=True" @pytest.mark.skipif(not (__has_pycuda__), reason="Need pycuda for this test") def test_cuda_halftomo_left(self): na, nz, nx = self.radios.shape left_cor = nx - 1 - self.rot_center radios = self._flip_array(self.radios) sino_processing = 
CudaSinoBuilder(radios_shape=radios.shape, rot_center=left_cor, halftomo=True) d_radios = garray.to_gpu(radios) d_sinos = garray.zeros(sino_processing.sinos_halftomo_shape, "f") sino_processing.get_sinos(d_radios, output=d_sinos) sino_halftomo = d_sinos.get()[0] _, err = compare_arrays(sino_halftomo, self._flip_array(self.sino_ref), self.tol, return_residual=True) assert err < self.tol, "Something wrong with SinoBuilder.radios_to_sino, halftomo=True" ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/reconstruction/tests/test_reconstructor.py0000644000175000017500000000622500000000000023726 0ustar00pierrepierreimport numpy as np import pytest from nabu.testutils import ( get_big_data, __big_testdata_dir__, compare_arrays, generate_tests_scenarios, __do_long_tests__, ) from nabu.cuda.utils import __has_pycuda__, __has_cufft__, get_cuda_context __has_cuda_fbp__ = __has_cufft__ and __has_pycuda__ if __has_cuda_fbp__: from nabu.reconstruction.reconstructor_cuda import CudaReconstructor from nabu.reconstruction.fbp import Backprojector as CudaBackprojector import pycuda.gpuarray as garray scenarios = generate_tests_scenarios( { "axis": ["z", "y", "x"], "vol_type": ["sinograms", "projections"], "indices": [(300, 310)], # reconstruct 10 slices "slices_roi": [None, (None, None, 200, 400), (250, 300, None, None), (120, 330, 301, 512)], } ) @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.projs = get_big_data("MRI512_projs.npy") if __has_cuda_fbp__: cls.ctx = get_cuda_context() cls.d_projs = garray.to_gpu(cls.projs) cls.ref = None cls.tol = 5e-2 @pytest.mark.skipif( __big_testdata_dir__ is None or not (__do_long_tests__), reason="need environment variable NABU_BIGDATA_DIR and NABU_LONG_TESTS=1", ) @pytest.mark.usefixtures("bootstrap") class TestReconstructor: @pytest.mark.skipif(not (__has_cuda_fbp__), reason="need pycuda and scikit-cuda") @pytest.mark.parametrize("config", scenarios) def test_cuda_reconstructor(self, config): data = self.projs d_data = self.d_projs if config["vol_type"] == "sinograms": data = np.moveaxis(self.projs, 1, 0) d_data = self.d_projs.transpose(axes=(1, 0, 2)) # view reconstructor = CudaReconstructor( data.shape, config["indices"], axis=config["axis"], vol_type=config["vol_type"], slices_roi=config["slices_roi"], ) res = reconstructor.reconstruct(d_data) ref = self.get_ref() ref = self.crop_array(ref, config) err_max = np.max(np.abs(res - ref)) assert err_max < self.tol, "something wrong with reconstructor, config = %s" % str(config) def get_ref(self): if self.ref is not None: return self.ref if __has_cuda_fbp__: fbp_cls = CudaBackprojector ref = np.zeros((512, 512, 512), "f") fbp = fbp_cls((self.projs.shape[0], self.projs.shape[-1])) for i in range(512): ref[i] = fbp.fbp(self.d_projs[:, i, :]) self.ref = ref return self.ref @staticmethod def crop_array(arr, config): indices = config["indices"] axis = config["axis"] slices_roi = config["slices_roi"] or (None, None, None, None) i_slice = slice(*indices) u_slice = slice(*slices_roi[:2]) v_slice = slice(*slices_roi[-2:]) if axis == "z": z_slice, y_slice, x_slice = i_slice, v_slice, u_slice if axis == "y": z_slice, y_slice, x_slice = v_slice, i_slice, u_slice if axis == "x": z_slice, y_slice, x_slice = v_slice, u_slice, i_slice return arr[z_slice, y_slice, x_slice] ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 
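# Sketch of half-acquisition sinogram building with SinoBuilder, as exercised by the tests above.
# Shapes and center of rotation are hypothetical; the CoR is assumed to lie on the right-hand side.
import numpy as np
from nabu.reconstruction.sinogram import SinoBuilder

n_angles, n_z, n_x = 1000, 1, 2048
radios = np.random.rand(n_angles, n_z, n_x).astype("f")
builder = SinoBuilder(radios_shape=radios.shape, rot_center=1900.5, halftomo=True)
sinos = builder.get_sinos(radios)
# 'sinos' holds roughly (n_angles + 1) // 2 angles per slice, on a grid extended to about twice the CoR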
nabu-2023.1.1/nabu/reconstruction/tests/test_sino_normalization.py0000644000175000017500000000715400000000000024732 0ustar00pierrepierreimport os.path as path import numpy as np import pytest from nabu.testutils import get_data from nabu.cuda.utils import __has_pycuda__ from nabu.reconstruction.sinogram import SinoNormalization if __has_pycuda__: from nabu.reconstruction.sinogram_cuda import CudaSinoNormalization import pycuda.gpuarray as garray @pytest.fixture(scope="class") def bootstrap(request): cls = request.cls cls.sino = get_data("sino_refill.npy") cls.tol = 1e-7 cls.norm_array_1D = np.arange(cls.sino.shape[-1]) + 1 cls.norm_array_2D = np.arange(cls.sino.size).reshape(cls.sino.shape) + 1 @pytest.mark.usefixtures("bootstrap") class TestSinoNormalization: def test_sino_normalization(self): sino_proc = SinoNormalization(kind="chebyshev", sinos_shape=self.sino.shape) sino = self.sino.copy() sino_proc.normalize(sino) @pytest.mark.skipif(not (__has_pycuda__), reason="Need pycuda for sinogram normalization with cuda backend") def test_sino_normalization_cuda(self): sino_proc = SinoNormalization(kind="chebyshev", sinos_shape=self.sino.shape) sino = self.sino.copy() ref = sino_proc.normalize(sino) cuda_sino_proc = CudaSinoNormalization(kind="chebyshev", sinos_shape=self.sino.shape) d_sino = garray.to_gpu(self.sino) cuda_sino_proc.normalize(d_sino) res = d_sino.get() assert np.max(np.abs(res - ref)) < self.tol def get_normalization_reference_result(self, op, normalization_arr): # Perform explicit operations to compare with numpy.divide, numpy.subtract, etc if op == "subtraction": ref = self.sino - normalization_arr elif op == "division": ref = self.sino / normalization_arr return ref def test_sino_array_subtraction_and_division(self): with pytest.raises(ValueError): SinoNormalization(kind="subtraction", sinos_shape=self.sino.shape) def compare_normalizations(normalization_arr, op): sino_normalization = SinoNormalization( kind=op, sinos_shape=self.sino.shape, normalization_array=normalization_arr ) sino = self.sino.copy() sino_normalization.normalize(sino) ref = self.get_normalization_reference_result(op, normalization_arr) assert np.allclose(sino, ref), "operation=%s, normalization_array dims=%d" % (op, normalization_arr.ndim) compare_normalizations(self.norm_array_1D, "subtraction") compare_normalizations(self.norm_array_1D, "division") compare_normalizations(self.norm_array_2D, "subtraction") compare_normalizations(self.norm_array_2D, "division") @pytest.mark.skipif(not (__has_pycuda__), reason="Need pycuda for sinogram normalization with cuda backend") def test_sino_array_subtraction_cuda(self): with pytest.raises(ValueError): CudaSinoNormalization(kind="subtraction", sinos_shape=self.sino.shape) def compare_normalizations(normalization_arr, op): sino_normalization = CudaSinoNormalization( kind=op, sinos_shape=self.sino.shape, normalization_array=normalization_arr ) sino = garray.to_gpu(self.sino) sino_normalization.normalize(sino) ref = self.get_normalization_reference_result(op, normalization_arr) assert np.allclose(sino.get(), ref) compare_normalizations(self.norm_array_1D, "subtraction") compare_normalizations(self.norm_array_2D, "subtraction") compare_normalizations(self.norm_array_1D, "division") compare_normalizations(self.norm_array_2D, "division") ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4687333 nabu-2023.1.1/nabu/resources/0000755000175000017500000000000000000000000015163 
5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1581878491.0 nabu-2023.1.1/nabu/resources/__init__.py0000644000175000017500000000000000000000000017262 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4687333 nabu-2023.1.1/nabu/resources/cli/0000755000175000017500000000000000000000000015732 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1581878491.0 nabu-2023.1.1/nabu/resources/cli/__init__.py0000644000175000017500000000000000000000000020031 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1633071556.0 nabu-2023.1.1/nabu/resources/cor.py0000644000175000017500000000025200000000000016317 0ustar00pierrepierrefrom ..pipeline.estimators import CORFinder as ECorFinder # This class has moved # The future location will be nabu.pipeline.estimators.CORFinder CORFinder = ECorFinder ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682590465.0 nabu-2023.1.1/nabu/resources/dataset_analyzer.py0000644000175000017500000003164300000000000021076 0ustar00pierrepierreimport os import numpy as np from silx.io import get_data from silx.io.url import DataUrl from tomoscan.esrf.scan.edfscan import EDFTomoScan from tomoscan.esrf.scan.hdf5scan import HDF5TomoScan from ..utils import check_supported from ..io.utils import get_compacted_dataslices from .utils import is_hdf5_extension, get_values_from_file from .logger import LoggerOrPrint # Wait for next tomoscan release to ship "nexus_version" from packaging.version import parse as parse_version from tomoscan.version import version as tomoscan_version _tomoscan_has_nxversion = parse_version(tomoscan_version) > parse_version("0.6.0") # class DatasetAnalyzer: _scanner = None kind = "none" """ Base class for datasets analyzers. """ def __init__(self, location, extra_options=None, logger=None): """ Initialize a Dataset analyzer. Parameters ---------- location: str Dataset location (directory or file name) extra_options: dict, optional Extra options on how to interpret the dataset. logger: logging object, optional Logger. If not set, messages will just be printed in stdout. """ self.logger = LoggerOrPrint(logger) self.location = location self._set_extra_options(extra_options) self._get_excluded_projections() self._set_default_dataset_values() self._init_dataset_scan() self._finish_init() def _set_extra_options(self, extra_options): if extra_options is None: extra_options = {} # COMPAT. 
advanced_options = { "force_flatfield": False, "output_dir": None, "exclude_projections": None, "hdf5_entry": None, } if _tomoscan_has_nxversion: advanced_options["nx_version"] = 1.0 # -- advanced_options.update(extra_options) self.extra_options = advanced_options def _get_excluded_projections(self): excluded_projs = self.extra_options["exclude_projections"] if excluded_projs is None: return projs_idx = get_values_from_file(excluded_projs, any_size=True).astype(np.int32).tolist() self.logger.info("Ignoring projections: %s" % (str(projs_idx))) self.extra_options["exclude_projections"] = projs_idx def _init_dataset_scan(self, **kwargs): if self._scanner is None: raise ValueError("Base class") if self._scanner is HDF5TomoScan: if self.extra_options.get("hdf5_entry", None) is not None: kwargs["entry"] = self.extra_options["hdf5_entry"] if self.extra_options.get("nx_version", None) is not None: kwargs["nx_version"] = self.extra_options["nx_version"] if self._scanner is EDFTomoScan: # Assume 1 frame per file (otherwise too long to open each file) kwargs["n_frames"] = 1 self.dataset_scanner = self._scanner( # pylint: disable=E1102 self.location, ignore_projections=self.extra_options["exclude_projections"], **kwargs ) self.projections = self.dataset_scanner.projections self.flats = self.dataset_scanner.flats self.darks = self.dataset_scanner.darks self.n_angles = len(self.dataset_scanner.projections) self.radio_dims = (self.dataset_scanner.dim_1, self.dataset_scanner.dim_2) self._radio_dims_notbinned = self.radio_dims # COMPAT def _finish_init(self): pass def _set_default_dataset_values(self): self._detector_tilt = None self.translations = None self.ctf_translations = None self.axis_position = None self._rotation_angles = None self.z_per_proj = None self.x_per_proj = None self._energy = None self._pixel_size = None self._distance = None self.flats_srcurrent = None @property def energy(self): """ Return the energy in kev. """ if self._energy is None: self._energy = self.dataset_scanner.energy return self._energy @energy.setter def energy(self, val): self._energy = val @property def distance(self): """ Return the sample-detector distance in meters. """ if self._distance is None: self._distance = abs(self.dataset_scanner.distance) return self._distance @distance.setter def distance(self, val): self._distance = val @property def pixel_size(self): """ Return the pixel size in microns. """ # TODO X and Y pixel size if self._pixel_size is None: self._pixel_size = self.dataset_scanner.pixel_size * 1e6 return self._pixel_size @pixel_size.setter def pixel_size(self, val): self._pixel_size = val def _get_rotation_angles(self): return self._rotation_angles # None by default @property def rotation_angles(self): """ Return the rotation angles in radians. """ return self._get_rotation_angles() @rotation_angles.setter def rotation_angles(self, angles): self._rotation_angles = angles def _is_halftomo(self): return None # base class @property def is_halftomo(self): """ Indicates whether the current dataset was performed with half acquisition. """ return self._is_halftomo() @property def detector_tilt(self): """ Return the detector tilt in degrees """ return self._detector_tilt @detector_tilt.setter def detector_tilt(self, tilt): self._detector_tilt = tilt @property def projections_srcurrent(self): """ Return the synchrotron electric current for each projection. 
""" srcurrent = self.dataset_scanner.electric_current if srcurrent is None or len(srcurrent) == 0: return None srcurrent_all = np.array(srcurrent) projections_indices = np.array(sorted(self.projections.keys())) if np.any(projections_indices >= len(srcurrent_all)): self.logger.error("Something wrong with SRCurrent: not enough values!") return None return srcurrent_all[projections_indices].astype("f") def check_defined_attribute(self, name, error_msg=None): """ Utility function to check that a given attribute is defined. """ if getattr(self, name, None) is None: raise ValueError(error_msg or str("No information on %s was found in the dataset" % name)) class EDFDatasetAnalyzer(DatasetAnalyzer): """ EDF Dataset analyzer for legacy ESRF acquisitions """ _scanner = EDFTomoScan kind = "edf" def _finish_init(self): self.remove_unused_radios() def remove_unused_radios(self): """ Remove "unused" radios. This is used for legacy ESRF scans. """ # Extraneous projections are assumed to be on the end projs_indices = sorted(self.projections.keys()) used_radios_range = range(projs_indices[0], len(self.projections)) radios_not_used = [] for idx in self.projections.keys(): if idx not in used_radios_range: radios_not_used.append(idx) for idx in radios_not_used: self.projections.pop(idx) return radios_not_used def _get_flats_darks(self): return @property def hdf5_entry(self): """ Return the HDF5 entry of the current dataset. Not applicable for EDF (return None) """ return None def _is_halftomo(self): return None def _get_rotation_angles(self): if self._rotation_angles is None: scan_range = self.dataset_scanner.scan_range if scan_range is not None: fullturn = abs(scan_range - 360) < abs(scan_range - 180) angles = np.linspace(0, scan_range, num=self.dataset_scanner.tomo_n, endpoint=fullturn, dtype="f") self._rotation_angles = np.deg2rad(angles) return self._rotation_angles class HDF5DatasetAnalyzer(DatasetAnalyzer): """ HDF5 dataset analyzer """ _scanner = HDF5TomoScan kind = "hdf5" @property def z_translation(self): raw_data = np.array(self.dataset_scanner.z_translation) projs_idx = np.array(list(self.projections.keys())) filtered_data = raw_data[projs_idx] return 1.0e6 * filtered_data / self.pixel_size @property def x_translation(self): raw_data = np.array(self.dataset_scanner.x_translation) projs_idx = np.array(list(self.projections.keys())) filtered_data = raw_data[projs_idx] return 1.0e6 * filtered_data / self.pixel_size def _get_rotation_angles(self): if self._rotation_angles is None: angles = np.array(self.dataset_scanner.rotation_angle) projs_idx = np.array(list(self.projections.keys())) angles = angles[projs_idx] self._rotation_angles = np.deg2rad(angles) return self._rotation_angles def _get_dataset_hdf5_url(self): if len(self.projections) > 0: frames_to_take = self.projections elif len(self.flats) > 0: frames_to_take = self.flats elif len(self.darks) > 0: frames_to_take = self.darks else: raise ValueError("No projections, no flats and no darks ?!") first_proj_idx = sorted(frames_to_take.keys())[0] first_proj_url = frames_to_take[first_proj_idx] return DataUrl( file_path=first_proj_url.file_path(), data_path=first_proj_url.data_path(), data_slice=None, scheme="silx" ) @property def dataset_hdf5_url(self): return self._get_dataset_hdf5_url() @property def hdf5_entry(self): """ Return the HDF5 entry of the current dataset """ return self.dataset_scanner.entry def _is_halftomo(self): try: is_halftomo = self.dataset_scanner.field_of_view.value.lower() == "half" except: is_halftomo = None return is_halftomo 
def get_data_slices(self, what): """ Return indices in the data volume where images correspond to a given kind. Parameters ---------- what: str Which keys to get. Can be "projections", "flats", "darks" Returns -------- slices: list of slice A list where each item is a slice. """ check_supported(what, ["projections", "flats", "darks"], "image type") images = getattr(self, what) # dict # we can't directly use set() on slice() object (unhashable). Use tuples slices = set() for du in get_compacted_dataslices(images).values(): if du.data_slice() is not None: s = (du.data_slice().start, du.data_slice().stop) else: s = None slices.add(s) slices_list = [slice(item[0], item[1]) if item is not None else None for item in list(slices)] return slices_list def analyze_dataset(dataset_path, extra_options=None, logger=None): if not (os.path.isdir(dataset_path)): if not (os.path.isfile(dataset_path)): raise ValueError("Error: %s no such file or directory" % dataset_path) if not (is_hdf5_extension(os.path.splitext(dataset_path)[-1].replace(".", ""))): raise ValueError("Error: expected a HDF5 file") dataset_analyzer_class = HDF5DatasetAnalyzer else: # directory -> assuming EDF dataset_analyzer_class = EDFDatasetAnalyzer dataset_structure = dataset_analyzer_class(dataset_path, extra_options=extra_options, logger=logger) return dataset_structure def get_0_180_radios(dataset_info, return_indices=False): """ Get the radios at 0 degres and 180 degrees. Parameters ---------- dataset_info: `DatasetAnalyzer` instance Data structure with the dataset information return_indices: bool, optional Whether to return radios indices along with the radios array. Returns ------- res: array or tuple If return_indices is True, return a tuple (radios, indices). Otherwise, return an array with the radios. """ radios_indices = [] radios_indices = sorted(dataset_info.projections.keys()) angles = dataset_info.rotation_angles angles = angles - angles.min() i_0 = np.argmin(np.abs(angles)) i_180 = np.argmin(np.abs(angles - np.pi)) _min_indices = [i_0, i_180] radios_indices = [radios_indices[i_0], radios_indices[i_180]] n_radios = 2 radios = np.zeros((n_radios,) + dataset_info.radio_dims[::-1], "f") for i in range(n_radios): radio_idx = radios_indices[i] radios[i] = get_data(dataset_info.projections[radio_idx]).astype("f") if return_indices: return radios, radios_indices else: return radios ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/resources/gpu.py0000644000175000017500000001310500000000000016330 0ustar00pierrepierre""" gpu.py: general-purpose utilities for GPU """ from ..utils import check_supported try: from pycuda.driver import Device as CudaDevice from pycuda.driver import device_attribute as dev_attrs __has_pycuda__ = True except ImportError: __has_pycuda__ = False CudaDevice = type(None) try: from pyopencl import Device as CLDevice __has_pyopencl__ = True except ImportError: CLDevice = type(None) __has_pyopencl__ = False # # silx.opencl.common.Device cannot be supported as long as # silx.opencl instantiates the "ocl" singleton in __init__, # leaving opencl contexts all over the place in some cases # class GPUDescription: """ Simple description of a Graphical Processing Unit. This class is designed to be simple to understand, and to be serializable for being used by dask.distributed. """ def __init__(self, device, vendor=None, device_id=None): """ Create a description from a device. 
Parameters ---------- device: `pycuda.driver.Device` or `pyopencl.Device` Class describing a GPU device. """ is_cuda_device = isinstance(device, CudaDevice) is_cl_device = isinstance(device, CLDevice) if is_cuda_device: self._init_from_cuda_device(device) elif is_cl_device: self._init_from_cl_device(device) self._set_other_attrs(vendor, device_id) else: raise ValueError("Expected `pycuda.driver.Device` or `pyopencl.Device`") def _init_from_cuda_device(self, device): self._dict = { "type": "cuda", "name": device.name(), "memory_GB": device.total_memory() / 1e9, "compute_capability": device.compute_capability(), "device_id": device.get_attribute(dev_attrs.MULTI_GPU_BOARD_GROUP_ID), } def _init_from_cl_device(self, device): self._dict = { "type": "opencl", "name": device.name, "memory_GB": device.global_mem_size / 1e9, "vendor": device.vendor, } def _set_other_attrs(self, vendor, device_id): if vendor is not None: self._dict["vendor"] = vendor if device_id is not None: self._dict["device_id"] = device_id # device ID for OpenCL (!= platform ID) def _dict_to_self(self): for key, val in self._dict.items(): setattr(self, key, val) def get_dict(self): return self._dict GPU_PICK_METHODS = ["cuda", "auto"] def pick_gpus(method, cuda_gpus, opencl_platforms, n_gpus): check_supported(method, GPU_PICK_METHODS, "GPU picking method") if method == "cuda": return pick_gpus_nvidia(cuda_gpus, n_gpus) elif method == "auto": return pick_gpus_auto(cuda_gpus, opencl_platforms, n_gpus) else: return [] # TODO Fix this function, it is broken: # - returns something when n_gpus = 0 # - POCL increments device_id by 1 ! def pick_gpus_auto(cuda_gpus, opencl_platforms, n_gpus): """ Pick `n_gpus` devices with the best available driver. This function browse the visible Cuda GPUs and Opencl platforms to pick the GPUs with the best driver. A worker might see several implementations of a GPU driver. For example with Nvidia hardware, we can see: - The Cuda implementation (nvidia-cuda-toolkit) - OpenCL implementation by Nvidia (nvidia-opencl-icd) - OpenCL implementation by Portable OpenCL Parameters ---------- cuda_gpu: dict Dictionary where each key is an ID, and the value is a dictionary describing some attributes of the GPU (obtained with `GPUDescription`) opencl_platforms: dict Dictionary where each key is the platform name, and the value is a list of dictionary descriptions. n_gpus: int Number of GPUs to pick. """ def gpu_equal(gpu1, gpu2): # TODO find a better test ? # Some information are not always available depending on the opencl vendor ! return (gpu1["device_id"] == gpu2["device_id"]) and (gpu1["name"] == gpu2["name"]) def is_in_gpus(avail_gpus, query_gpu): for gpu in avail_gpus: if gpu_equal(gpu, query_gpu): return True return False # If some Nvidia hardware is visible, add it without question. # In the case we don't want it, we should either re-run resources discovery # with `try_cuda=False`, or mask individual devices with CUDA_VISIBLE_DEVICES. chosen_gpus = list(cuda_gpus.values()) if len(chosen_gpus) >= n_gpus: return chosen_gpus for platform, gpus in opencl_platforms.items(): for gpu_id, gpu in gpus.items(): if not (is_in_gpus(chosen_gpus, gpu)): # TODO prioritize some OpenCL implementations ? chosen_gpus.append(gpu) if len(chosen_gpus) < n_gpus: raise ValueError("Not enough GPUs: could only collect %d/%d" % (len(chosen_gpus), n_gpus)) return chosen_gpus def pick_gpus_nvidia(cuda_gpus, n_gpus): """ Pick one or more Nvidia GPUs. 
""" if len(cuda_gpus) < n_gpus: raise ValueError("Not enough Nvidia GPU: requested %d, but can get only %d" % (n_gpus, len(cuda_gpus))) # Sort GPUs by computing capabilities, pick the "best" ones gpus_cc = [] for gpu_id, gpu in cuda_gpus.items(): cc = gpu["compute_capability"] gpus_cc.append((gpu_id, cc[0] + 0.1 * cc[1])) gpus_cc_sorted = sorted(gpus_cc, key=lambda x: x[1], reverse=True) res = [] for i in range(n_gpus): res.append(cuda_gpus[gpus_cc_sorted[i][0]]) return res ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/resources/logger.py0000644000175000017500000000725600000000000017026 0ustar00pierrepierreimport logging import logging.config class Logger(object): def __init__(self, loggername, level="DEBUG", logfile="logger.log", console=True): """ Configure a Logger object. Parameters ----------- loggername: str Logger name. level: str, optional Logging level. Can be "debug", "info", "warning", "error", "critical". Default is "debug". logfile: str, optional File where the logs are written. If set to None, the logs are not written in a file. Default is "logger.log". console: bool, optional If set to True, the logs are (also) written in stdout/stderr. Default is True. """ self.loggername = loggername self.level = level self.logfile = logfile self.console = console self._configure_logger() def _configure_logger(self): conf = self._get_default_config_dict() for handler in conf["handlers"].keys(): conf["handlers"][handler]["level"] = self.level.upper() conf["loggers"][self.loggername]["level"] = self.level.upper() if not (self.console): conf["loggers"][self.loggername]["handlers"].remove("console") self.config = conf logging.config.dictConfig(conf) self.logger = logging.getLogger(self.loggername) def _get_default_config_dict(self): conf = { "version": 1, "formatters": { "default": {"format": "%(asctime)s - %(levelname)s - %(message)s", "datefmt": "%d-%m-%Y %H:%M:%S"}, "console": {"format": "%(message)s"}, }, "handlers": { "console": { "level": "DEBUG", "class": "logging.StreamHandler", "formatter": "console", "stream": "ext://sys.stdout", }, "file": { "level": "DEBUG", "class": "logging.FileHandler", "formatter": "default", "filename": self.logfile, }, }, "loggers": { self.loggername: { "level": "DEBUG", "handlers": ["console", "file"], # This logger inherits from root logger - dont propagate ! "propagate": False, } }, "disable_existing_loggers": False, } return conf def info(self, msg): return self.logger.info(msg) def debug(self, msg): return self.logger.debug(msg) def warning(self, msg): return self.logger.warning(msg) warn = warning def error(self, msg): return self.logger.error(msg) def fatal(self, msg): return self.logger.fatal(msg) def critical(self, msg): return self.logger.critical(msg) def LoggerOrPrint(logger): """ Logger that is either a "true" logger object, or a fall-back to "print". 
""" if logger is None: return PrinterLogger() return logger class PrinterLogger(object): def __init__(self): methods = [ "debug", "warn", "warning", "info", "error", "fatal", "critical", ] for method in methods: self.__setattr__(method, print) LogLevel = { "notset": logging.NOTSET, "debug": logging.DEBUG, "info": logging.INFO, "warn": logging.WARN, "warning": logging.WARNING, "error": logging.ERROR, "critical": logging.CRITICAL, "fatal": logging.FATAL, } ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/resources/nxflatfield.py0000644000175000017500000002176400000000000020047 0ustar00pierrepierreimport os import numpy as np from silx.io.url import DataUrl from tomoscan.io import HDF5File from tomoscan.esrf.scan.hdf5scan import HDF5TomoScan from ..utils import check_supported, is_writeable def get_frame_possible_urls(dataset_info, user_dir, output_dir, frame_type): """ Return a list with the possible location of reduced dark/flat frames. Parameters ---------- dataset_info: DatasetAnalyzer object DatasetAnalyzer object: data structure containing information on the parsed dataset user_dir: str or None User-provided directory location for the reduced frames. output_dir: str or None Output processing directory frame_type: str Frame type, can be "flats" or "darks". """ check_supported(frame_type, ["flats", "darks"], "frame type") h5scan = dataset_info.dataset_scanner # tomoscan object if frame_type == "flats": dataurl_default_template = h5scan.REDUCED_FLATS_DATAURLS[0] else: dataurl_default_template = h5scan.REDUCED_DARKS_DATAURLS[0] def make_dataurl(dirname): # The template formatting should be done by tomoscan in principle, but this complicates logging. rel_file_path = dataurl_default_template.file_path().format( scan_prefix=dataset_info.dataset_scanner.get_dataset_basename() ) return DataUrl( file_path=os.path.join(dirname, rel_file_path), data_path=dataurl_default_template.data_path(), data_slice=dataurl_default_template.data_slice(), # not sure if needed scheme="silx", ) urls = {"user": None, "dataset": None, "output": None} if user_dir is not None: urls["user"] = make_dataurl(user_dir) # tomoscan.esrf.scan.hdf5scan.REDUCED_{DARKS|FLATS}_DATAURLS.file_path() is a relative path # Create a absolute path instead urls["dataset"] = make_dataurl(os.path.dirname(h5scan.master_file)) if output_dir is not None: urls["output"] = make_dataurl(output_dir) return urls def get_metadata_url(url, frame_type): """ Return the url of the metadata stored alongside flats/darks """ check_supported(frame_type, ["flats", "darks"], "frame type") template_url = getattr(HDF5TomoScan, "REDUCED_%s_METADATAURLS" % frame_type.upper())[0] return DataUrl( file_path=url.file_path(), data_path=template_url.data_path(), scheme="silx", ) def tomoscan_load_reduced_frames(dataset_info, frame_type, url): tomoscan_method = getattr(dataset_info.dataset_scanner, "load_reduced_%s" % frame_type) return tomoscan_method( inputs_urls=[url], return_as_url=True, return_info=True, metadata_input_urls=[get_metadata_url(url, frame_type)], ) def tomoscan_save_reduced_frames(dataset_info, frame_type, url, frames, info): tomoscan_method = getattr(dataset_info.dataset_scanner, "save_reduced_%s" % frame_type) kwargs = {"%s_infos" % frame_type: info} return tomoscan_method( frames, output_urls=[url], metadata_output_urls=[get_metadata_url(url, frame_type)], **kwargs ) # pylint: disable=E1136 def update_dataset_info_flats_darks(dataset_info, flatfield_mode, output_dir=None, 
darks_flats_dir=None): """ Update a DatasetAnalyzer object with reduced flats/darks (hereafter "reduced frames"). How the reduced frames are loaded/computed/saved will depend on the "flatfield_mode" parameter. The principle is the following: (1) Attempt at loading already-computed reduced frames (XXX_darks.h5 and XXX_flats.h5): - First check files in the user-defined directory 'darks_flats_dir' - Then try to load from files located alongside the .nx dataset (dataset directory) - Then try to load from output_dir, if provided (2) If loading fails, or flatfield_mode == "force_compute", compute the reduced frames. (3) Save these reduced frames - Save in darks_flats_dir, if provided by user - Otherwise, save in the data directory (next to the .nx file), if write access OK - Otherwise, save in output directory """ if flatfield_mode is False: return logger = dataset_info.logger frames_types = ["darks", "flats"] reduced_frames_urls = {} for frame_type in frames_types: reduced_frames_urls[frame_type] = get_frame_possible_urls(dataset_info, darks_flats_dir, output_dir, frame_type) reduced_frames = dict.fromkeys(frames_types, None) # # Try to load frames # def load_reduced_frame(url, frame_type, frames_loaded, reduced_frames): if frames_loaded[frame_type]: return frames, info = tomoscan_load_reduced_frames(dataset_info, frame_type, url) if frames not in (None, {}): dataset_info.logger.info("Loaded %s from %s" % (frame_type, url.file_path())) frames_loaded[frame_type] = True reduced_frames[frame_type] = frames, info else: msg = "Could not load %s from %s" % (frame_type, url.file_path()) logger.error(msg) frames_loaded = dict.fromkeys(frames_types, False) if flatfield_mode != "force-compute": for load_from in ["user", "dataset", "output"]: # in that order for frame_type in frames_types: url = reduced_frames_urls[frame_type][load_from] if url is None: continue # cannot load from this source (eg. undefined folder) load_reduced_frame(url, frame_type, frames_loaded, reduced_frames) if all(frames_loaded.values()): break if not all(frames_loaded.values()) and flatfield_mode == "force-load": raise ValueError("Could not load darks/flats (using 'force-load')") # # COMPAT. Keep DataUrl - won't be needed in future versions when pipeline will use FlatField # instead of FlatFieldDataUrl frames_urls = reduced_frames.copy() # # Compute reduced frames, if needed # if reduced_frames["flats"] is None: reduced_frames["flats"] = dataset_info.dataset_scanner.compute_reduced_flats(return_info=True) if reduced_frames["darks"] is None: reduced_frames["darks"] = dataset_info.dataset_scanner.compute_reduced_darks(return_info=True) if reduced_frames["darks"][0] == {} or reduced_frames["flats"][0] == {}: raise ValueError( "Could not get any reduced flat/dark. This probably means that no already-reduced flats/darks were found and that the dataset itself does not have any flat/dark" ) # # Save reduced frames # def save_reduced_frame(url, frame_type, frames_saved): frames, info = reduced_frames[frame_type] tomoscan_save_reduced_frames(dataset_info, frame_type, url, frames, info) dataset_info.logger.info("Saved reduced %s to %s" % (frame_type, url.file_path())) frames_saved[frame_type] = True frames_saved = dict.fromkeys(frames_types, False) if not all(frames_loaded.values()): for save_to in ["user", "dataset", "output"]: # in that order for frame_type in frames_types: if frames_loaded[frame_type]: continue # already loaded url = reduced_frames_urls[frame_type][save_to] if url is None: continue # cannot load from this source (eg. 
undefined folder) if not is_writeable(os.path.dirname(url.file_path())): continue save_reduced_frame(url, frame_type, frames_saved) # COMPAT. if frames_urls[frame_type] is None: frames_urls[frame_type] = tomoscan_load_reduced_frames(dataset_info, frame_type, url) # if all(frames_saved.values()): break dataset_info.flats = frames_urls["flats"][0] # reduced_frames["flats"] # in future versions dataset_info.flats_srcurrent = frames_urls["flats"][1].machine_electric_current # This is an extra check to avoid having more than 1 (reduced) dark. # FlatField only works with exactly 1 (reduced) dark (having more than 1 series of darks makes little sense) # This is normally prevented by tomoscan HDF5FramesReducer, but let's add this extra check darks_ = frames_urls["darks"][0] # reduced_frames["darks"] # in future versions if len(darks_) > 1: dark_idx = sorted(darks_.keys())[0] dataset_info.logger.error("Found more that one series of darks. Keeping only the first one") darks_ = {dark_idx: darks_[dark_idx]} # dataset_info.darks = darks_ dataset_info.darks_srcurrent = frames_urls["darks"][1].machine_electric_current # tomoscan "compute_reduced_XX" is quite slow. If needed, here is an alternative implementation def my_reduce_flats(di): res = {} with HDF5File(di.dataset_hdf5_url.file_path(), "r") as f: for data_slice in di.get_data_slices("flats"): data = f[di.dataset_hdf5_url.data_path()][data_slice.start : data_slice.stop] res[data_slice.start] = np.median(data, axis=0) return res ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4687333 nabu-2023.1.1/nabu/resources/templates/0000755000175000017500000000000000000000000017161 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1669280031.0 nabu-2023.1.1/nabu/resources/templates/__init__.py0000644000175000017500000000000000000000000021260 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/resources/templates/bm05_pag.conf0000644000175000017500000000061200000000000021421 0ustar00pierrepierre# # ESRF BM05 phase contrast # # # Write here your custom configuration as a python dictionary. # Any parameter not present will take the default value. # [preproc] flatfield = 1 take_logarithm = 1 [phase] method = paganin delta_beta = 100 [reconstruction] rotation_axis_position = sliding-window enable_halftomo = auto clip_outer_circle = 1 centered_axis = 1 [output] location = ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/resources/templates/id16_ctf.conf0000644000175000017500000000151200000000000021426 0ustar00pierrepierre# # ESRF ID16 single-distance CTF # # # Write here your custom configuration as a python dictionary. # Any parameter not present will take the default value. 
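#
# Note (illustrative): composite options such as flat_distortion_params,
# ctf_geometry and ctf_advanced_params are written as 'key=value' pairs
# separated by ';', each value being a Python literal (the same syntax parsed
# by extract_parameters in nabu/resources/utils.py). A purely hypothetical
# filled-in example could read:
#   ctf_geometry = z1_v=0.05; z1_h=0.05; detec_pixel_size=1e-6; magnification=True
#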
# [preproc] flatfield = 1 flat_distortion_correction_enabled = 0 flat_distortion_params = tile_size=100; interpolation_kind='linear'; padding_mode='edge'; correction_spike_threshold=None take_logarithm = 0 ccd_filter_enabled = 1 ccd_filter_threshold = 0.04 [phase] method = ctf delta_beta = 80 ctf_geometry = z1_v=None; z1_h=None; detec_pixel_size=None; magnification=True ctf_advanced_params = length_scale=1e-5; lim1=1e-5; lim2=0.2; normalize_by_mean=True [reconstruction] rotation_axis_position = global cor_options = translation_movements_file = enable_halftomo = 0 clip_outer_circle = 1 centered_axis = 1 [output] location = file_format = tiff tiff_single_file = 1 ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/resources/templates/id16_holo.conf0000644000175000017500000000140500000000000021614 0ustar00pierrepierre# # ESRF ID16 holo-tomography # # # Write here your custom configuration as a python dictionary. # Any parameter not present will take the default value. # [preproc] flatfield = 0 flat_distortion_correction_enabled = 0 flat_distortion_params = tile_size=100; interpolation_kind='linear'; padding_mode='edge'; correction_spike_threshold=None take_logarithm = 0 [phase] method = none ctf_geometry = z1_v=None; z1_h=None; detec_pixel_size=None; magnification=True ctf_advanced_params = length_scale=1e-5; lim1=1e-5; lim2=0.2; normalize_by_mean=True [reconstruction] rotation_axis_position = global cor_options = translation_movements_file = enable_halftomo = 0 clip_outer_circle = 1 centered_axis = 1 [output] location = file_format = tiff tiff_single_file = 1 ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1669280031.0 nabu-2023.1.1/nabu/resources/templates/id19_pag.conf0000644000175000017500000000060000000000000021421 0ustar00pierrepierre# # ESRF ID19 phase contrast # # # Write here your custom configuration as a python dictionary. # Any parameter not present will take the default value. 
# [preproc] flatfield = 1 take_logarithm = 1 [phase] method = paganin delta_beta = 100 [reconstruction] rotation_axis_position = sliding-window translation_movements_file = enable_halftomo = auto [output] location = ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4687333 nabu-2023.1.1/nabu/resources/tests/0000755000175000017500000000000000000000000016325 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1599550945.0 nabu-2023.1.1/nabu/resources/tests/__init__.py0000644000175000017500000000000000000000000020424 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/resources/tests/test_nxflatfield.py0000644000175000017500000001031200000000000022233 0ustar00pierrepierrefrom os import path from tempfile import mkdtemp from shutil import rmtree import pytest import numpy as np from silx.io import get_data from tomoscan.esrf.scan.hdf5scan import ImageKey from nabu.testutils import generate_nx_dataset from nabu.resources.nxflatfield import update_dataset_info_flats_darks from nabu.resources.dataset_analyzer import HDF5DatasetAnalyzer test_nxflatfield_scenarios = [ { "name": "simple", "flats_pos": [slice(1, 6)], "darks_pos": [slice(0, 1)], "output_dir": None, }, { "name": "simple_with_save", "flats_pos": [slice(1, 6)], "darks_pos": [slice(0, 1)], "output_dir": None, }, { "name": "multiple_with_save", "flats_pos": [slice(0, 10), slice(30, 40)], "darks_pos": [slice(95, 100)], "output_dir": path.join("{tempdir}", "output_reduced"), }, ] # parametrize with fixture and "params=" will launch a new class for each scenario. # the attributes set to "cls" will remain for all the tests done in this class # with the current scenario. @pytest.fixture(scope="class", params=test_nxflatfield_scenarios) def bootstrap(request): cls = request.cls cls.n_projs = 265 cls.params = request.param cls.tempdir = mkdtemp(prefix="nabu_") yield rmtree(cls.tempdir) @pytest.mark.usefixtures("bootstrap") class TestNXFlatField: _reduction_func = {"flats": np.median, "darks": np.mean} def get_nx_filename(self): return path.join(self.tempdir, "dataset_" + self.params["name"] + ".nx") def get_image_key(self): keys = np.zeros(self.n_projs, np.int32) for what, val in [("flats_pos", ImageKey.FLAT_FIELD.value), ("darks_pos", ImageKey.DARK_FIELD.value)]: for pos in self.params[what]: keys[pos.start : pos.stop] = val return keys @staticmethod def check_image_key(dataset_info, frame_type, expected_slices): data_slices = dataset_info.get_data_slices(frame_type) assert set(map(str, data_slices)) == set(map(str, expected_slices)) def test_nxflatfield(self): dataset_fname = self.get_nx_filename() image_key = self.get_image_key() generate_nx_dataset(dataset_fname, image_key) dataset_info = HDF5DatasetAnalyzer(dataset_fname) # When parsing a "raw" dataset, flats/darks are a series of images. # dataset_info.flats is a dictionary where each key is the index of the frame. # For example dataset_info.flats.keys() = [10, 11, 12, 13, ..., 19, 1200, 1201, 1202, ..., 1219] for frame_type in ["darks", "flats"]: self.check_image_key(dataset_info, frame_type, self.params[frame_type + "_pos"]) output_dir = self.params.get("output_dir", None) if output_dir is not None: output_dir = output_dir.format(tempdir=self.tempdir) update_dataset_info_flats_darks(dataset_info, True, output_dir=output_dir) # After reduction (median/mean), the flats/darks are located in another file. 
# median(series_1) goes to entry/flats/idx1, mean(series_2) goes to entry/flats/idx2, etc. assert set(dataset_info.flats.keys()) == set(s.start for s in self.params["flats_pos"]) assert set(dataset_info.darks.keys()) == set(s.start for s in self.params["darks_pos"]) # Check that the computations were correct # Loads the entire volume in memory ! So keep the data volume small for the tests data_volume = get_data(dataset_info.dataset_hdf5_url) expected_flats = {} for s in self.params["flats_pos"]: expected_flats[s.start] = self._reduction_func["flats"](data_volume[s.start : s.stop], axis=0) expected_darks = {} for s in self.params["darks_pos"]: expected_darks[s.start] = self._reduction_func["darks"](data_volume[s.start : s.stop], axis=0) flats = {idx: get_data(dataset_info.flats[idx]) for idx in dataset_info.flats.keys()} for idx in flats.keys(): assert np.allclose(flats[idx], expected_flats[idx]) darks = {idx: get_data(dataset_info.darks[idx]) for idx in dataset_info.darks.keys()} for idx in darks.keys(): assert np.allclose(darks[idx], expected_darks[idx]) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/resources/tests/test_units.py0000644000175000017500000000415300000000000021103 0ustar00pierrepierreimport pytest from nabu.utils import compare_dicts from nabu.resources.utils import get_quantities_and_units class TestUnits: expected_results = { "distance = 1 m ; pixel_size = 2.0 um": {"distance": 1.0, "pixel_size": 2e-6}, "distance = 1 m ; pixel_size = 2.6 micrometer": {"distance": 1.0, "pixel_size": 2.6e-6}, "distance = 10 m ; pixel_size = 2e-6 m": {"distance": 10, "pixel_size": 2e-6}, "distance = .5 m ; pixel_size = 2.6e-4 centimeter": {"distance": 0.5, "pixel_size": 2.6e-6}, "distance = 10 cm ; pixel_size = 2.5 micrometer ; energy = 1 ev": { "distance": 0.1, "pixel_size": 2.5e-6, "energy": 1.0e-3, }, "distance = 10 cm ; pixel_size = 9.0e-3 millimeter ; energy = 19 kev": { "distance": 0.1, "pixel_size": 9e-6, "energy": 19.0, }, } expected_failures = { # typo ("ke" instead of "kev") "distance = 10 cm ; energy = 10 ke": ValueError("Cannot convert: ke"), # No units "distance = 10 ; energy = 10 kev": ValueError("not enough values to unpack (expected 2, got 1)"), # Unit not separated by space "distance = 10m; energy = 10 kev": ValueError("not enough values to unpack (expected 2, got 1)"), # Invalid separator "distance = 10 m, energy = 10 kev": ValueError("too many values to unpack (expected 2)"), } def test_conversion(self): for test_str, expected_result in self.expected_results.items(): res = get_quantities_and_units(test_str) err_msg = str( "Something wrong with quantities/units extraction from '%s': expected %s, got %s" % (test_str, str(expected_result), str(res)) ) assert compare_dicts(res, expected_result) is None, err_msg def test_valid_input(self): for test_str, expected_failure in self.expected_failures.items(): with pytest.raises(type(expected_failure)) as e_info: get_quantities_and_units(test_str) assert e_info.value.args[0] == str(expected_failure) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/resources/utils.py0000644000175000017500000001255500000000000016705 0ustar00pierrepierrefrom ast import literal_eval import numpy as np from psutil import virtual_memory, cpu_count from tomoscan.unitsystem.metricsystem import MetricSystem def get_values_from_file(fname, n_values=None, shape=None, sep=None, any_size=False): """ Read a text file and scan the values 
inside. This function expects one value per line, or values separated with a separator defined with the `sep` parameter. Parameters ---------- fname: str Path of the text file n_values: int, optional If set to a value, this function will check that it scans exactly this number of values. Ignored if `shape` is provided shape: tuple, optional Generalization of n_values for higher dimensions. sep: str, optional Separator between values. Default is white space any_size: bool, optional If set to True, then the parameters 'n_values' and 'shape' are ignored. Returns -------- arr: numpy.ndarray An array containing the values scanned from the text file """ if not (any_size) and not ((n_values is not None) ^ (shape is not None)): raise ValueError("Please provide either n_values or shape") arr = np.loadtxt(fname, ndmin=1) if (n_values is not None) and (arr.shape[0] != n_values): if any_size: arr = arr[:n_values] else: raise ValueError("Expected %d values, but could get %d values" % (n_values, arr.shape[0])) if (shape is not None) and (arr.shape != shape): if any_size: arr = arr[: shape[0], : shape[1]] # TODO handle more dimensions ? else: raise ValueError("Expected shape %s, but got shape %s" % (shape, arr.shape)) return arr def get_memory_per_node(max_mem, is_percentage=True): """ Get the available memory per node in GB. Parameters ---------- max_mem: float If is_percentage is False, then number is interpreted as an absolute number in GigaBytes. Otherwise, it should be a number between 0 and 100 and is interpreted as a percentage. is_percentage: bool A boolean indicating whether the parameter max_mem is to be interpreted as a percentage of available system memory. """ sys_avail_mem = virtual_memory().available / 1e9 if is_percentage: return (max_mem / 100.0) * sys_avail_mem else: return min(max_mem, sys_avail_mem) def get_threads_per_node(max_threads, is_percentage=True): """ Get the available memory per node in GB. Parameters ---------- max_threads: float If is_percentage is False, then number is interpreted as an absolute number of threads. Otherwise, it should be a number between 0 and 100 and is interpreted as a percentage. is_percentage: bool A boolean indicating whether the parameter max_threads is to be interpreted as a percentage of available system memory. """ sys_n_threads = cpu_count(logical=True) if is_percentage: return (max_threads / 100.0) * sys_n_threads else: return min(max_threads, sys_n_threads) def extract_parameters(params_str, sep=";"): """ Extract the named parameters from a string. Example -------- The function can be used as follows: >>> extract_parameters("window_width=None; median_filt_shape=(3,3); padding_mode='wrap'") ... {'window_width': None, 'median_filt_shape': (3, 3), 'padding_mode': 'wrap'} """ if params_str in ("", None): return {} params_list = params_str.strip(sep).split(sep) res = {} for param_str in params_list: param_name, param_val_str = param_str.strip().split("=") param_name = param_name.strip() param_val_str = param_val_str.strip() param_val = literal_eval(param_val_str) res[param_name] = param_val return res def compact_parameters(params_dict, sep=";"): """ Compact the parameters from a dict into a string. This is the inverse of extract_parameters. It can be used for example in tomwer to convert parameters into a string, for example for cor_options, prior to calling a nabu method which is expecting an argument to be in the form of a string containing options. 
Example -------- The function can be used as follows: >>> compact_parameters( {"side":"near", "near_pos":300 } ) ... "side=near; nearpos= 300;" """ if params_dict in ({}, None): return "" res = "" for key, val in params_dict.items(): res = res + "{key} = {val} " + sep return res def is_hdf5_extension(ext): return ext.lower() in ["h5", "hdf5", "nx"] def get_quantities_and_units(string, sep=";"): """ Return a dictionary with quantities as keys, and values in SI. Example ------- get_quantities_and_units("pixel_size = 1.2 um ; distance = 1 m") Will return {"pixel_size": 1.2e-6, "distance": 1} """ result = {} quantities = string.split(sep) for quantity in quantities: quantity_name, value_and_unit = quantity.split("=") quantity_name = quantity_name.strip() value_and_unit = value_and_unit.strip() value, unit = value_and_unit.split() val = float(value) # Convert to SI conversion_factor = MetricSystem.from_str(unit).value result[quantity_name] = val * conversion_factor return result ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4687333 nabu-2023.1.1/nabu/stitching/0000755000175000017500000000000000000000000015145 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/stitching/__init__.py0000644000175000017500000000000000000000000017244 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/stitching/config.py0000644000175000017500000004375500000000000017002 0ustar00pierrepierre# coding: utf-8 # /*########################################################################## # # Copyright (c) 2016-2017 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # # ###########################################################################*/ __authors__ = ["H. 
Payno"] __license__ = "MIT" __date__ = "10/05/2022" from dataclasses import dataclass import numpy from tomoscan.identifier import VolumeIdentifier from tomoscan.identifier import ScanIdentifier from tomoscan.nexus.paths import nxtomo from silx.utils.enum import Enum as _Enum from typing import Optional, Union from nabu.pipeline.config_validators import ( integer_validator, list_of_shift_validator, list_of_tomoscan_identifier, optional_directory_location_validator, boolean_validator, optional_positive_integer_validator, output_file_format_validator, optional_tuple_of_floats_validator, optional_file_name_validator, ) from nabu.stitching.overlap import OverlapStichingStrategy from nabu.utils import concatenate_dict, convert_str_to_tuple from nabu.io.utils import get_output_volume from tomoscan.factory import Factory KEY_X_CROSS_CORRELATION_FUNC = "x_cross_correlation_function" KEY_Y_CROSS_CORRELATION_FUNC = "y_cross_correlation_function" KEY_CROSS_CORRELATION_SLICE = "do_slice_index_correlation_from_slice" _DEFAULT_AUTO_REL_SHIFT_PARAMS = ( f"{KEY_CROSS_CORRELATION_SLICE}=middle;{KEY_X_CROSS_CORRELATION_FUNC}=skimage;{KEY_Y_CROSS_CORRELATION_FUNC}=None" ) _OUTPUT_SECTION = "output" _INPUTS_SECTION = "inputs" _INPUT_DATASETS_FIELD = "input_dataset" _STITCHING_SECTION = "stitching" _STITCHING_STRATEGY_FIELD = "stitching_strategy" _STITCHING_TYPE_FIELD = "type" _DATA_FILE_FIELD = "location" _FILE_PREFIX_FIELD = "file_prefix" _FILE_FORMAT_FIELD = "file_format" _OVERWRITE_RESULTS_FIELD = "overwrite_results" _DATA_PATH_FIELD = "data_path" _X_RELATIVE_SHIFTS_FIELD = "x_relative_shifts_in_px" _VERTICAL_OVERLAP_FIELD = "vertical_overlap_area_in_px" _STITCHING_HEIGTH_FIELD = "stitching_height_in_px" _AUTO_RELATIVE_SHIFT_PARAMS_FIELD = "auto_relative_shifts_params" _NEXUS_VERSION_FIELD = "nexus_version" _OUTPUT_DTYPE = "data_type" _OUTPUT_VOLUME = "output_volume" def _str_to_dict(my_str): """convert a string as key_1=value_2;key_2=value_2 to a dict""" res = {} for key_value in my_str.split(";"): key, value = key_value.split("=") res[key] = value return res def _valid_relative_shift_params(my_dict): valid_keys = ( KEY_CROSS_CORRELATION_SLICE, KEY_X_CROSS_CORRELATION_FUNC, KEY_Y_CROSS_CORRELATION_FUNC, ) for key in my_dict.keys(): if not key in valid_keys: raise KeyError(f"{key} is a unrecognized key") return my_dict def _str_to_int_or_auto(my_str): ids = my_str.replace(" ", "").split(",") try: res = tuple([int(val) if val not in ("auto", "'auto'", '"auto"') else "auto" for val in ids]) except ValueError: raise ValueError(f"Fail to convert {my_str} to a list of int or to 'auto'") else: if len(res) == 1: return res[0] @dataclass class _StitchingConfiguration: """ bass class to define stitching configuration """ overlap_height: Union[int, tuple] # overlap area in pixel between each scan stitching_height: Union[None, int, tuple] # height to take in the overlap to apply stitching stitching_strategy: OverlapStichingStrategy output_dtype: numpy.dtype overwrite_results: bool x_shifts: Union[str, tuple] def to_dict(self): """dump configuration to a dict. 
Must be serializable because might be dump to HDF5 file""" return { _STITCHING_SECTION: { _VERTICAL_OVERLAP_FIELD: self.overlap_height, _STITCHING_HEIGTH_FIELD: self.stitching_height, _STITCHING_STRATEGY_FIELD: OverlapStichingStrategy.from_value(self.stitching_strategy).value, _X_RELATIVE_SHIFTS_FIELD: self.x_shifts, }, _OUTPUT_SECTION: { _OUTPUT_DTYPE: str(self.output_dtype), _OVERWRITE_RESULTS_FIELD: self.overwrite_results, }, } class StitchingType(_Enum): Z_PREPROC = "z-preproc" Z_POSTPROC = "z-postproc" @dataclass class ZStitchingConfiguration(_StitchingConfiguration): """ base class to define z-stitching parameters """ auto_relative_shift_params: dict def to_dict(self): return concatenate_dict( super().to_dict(), { _STITCHING_SECTION: { _AUTO_RELATIVE_SHIFT_PARAMS_FIELD: ";".join( [f"{key}={value}" for key, value in self.auto_relative_shift_params.items()] ), }, }, ) @dataclass class PreProcessedZStitchingConfiguration(ZStitchingConfiguration): """ base class to define z-stitching parameters """ input_scans: tuple # tuple of ScanBase output_file_path: str output_data_path: str output_nexus_version: Optional[float] def to_dict(self): return concatenate_dict( super().to_dict(), { _INPUTS_SECTION: { _INPUT_DATASETS_FIELD: [str(scan.get_identifier()) for scan in self.input_scans], }, _OUTPUT_SECTION: { _DATA_FILE_FIELD: self.output_file_path, _DATA_PATH_FIELD: self.output_data_path, _NEXUS_VERSION_FIELD: self.output_nexus_version, }, }, ) @staticmethod def from_dict(config: dict): if not isinstance(config, dict): raise TypeError(f"config is expected to be a dict and not {type(config)}") inputs_scans_str = config.get(_INPUTS_SECTION, {}).get(_INPUT_DATASETS_FIELD, None) if inputs_scans_str in (None, ""): input_scans = [] else: input_scans = _get_identifiers(inputs_scans_str) output_file_path = config.get(_OUTPUT_SECTION, {}).get(_DATA_FILE_FIELD, None) if output_file_path is None: raise ValueError("output location not provided") nexus_version = config.get(_OUTPUT_SECTION, {}).get(_NEXUS_VERSION_FIELD, None) if nexus_version in (None, ""): nexus_version = nxtomo.LATEST_VERSION else: nexus_version = float(nexus_version) return PreProcessedZStitchingConfiguration( overlap_height=_str_to_int_or_auto(config[_STITCHING_SECTION][_VERTICAL_OVERLAP_FIELD]), stitching_height=_str_to_int_or_auto(config[_STITCHING_SECTION].get(_STITCHING_HEIGTH_FIELD, "auto")), stitching_strategy=OverlapStichingStrategy.from_value( config[_STITCHING_SECTION].get( _STITCHING_STRATEGY_FIELD, OverlapStichingStrategy.COSINUS_WEIGHTS, ), ), x_shifts=_str_to_int_or_auto(config[_STITCHING_SECTION][_X_RELATIVE_SHIFTS_FIELD]), auto_relative_shift_params=_valid_relative_shift_params( _str_to_dict( config[_STITCHING_SECTION].get( _AUTO_RELATIVE_SHIFT_PARAMS_FIELD, _DEFAULT_AUTO_REL_SHIFT_PARAMS, ), ), ), input_scans=input_scans, output_file_path=output_file_path, output_data_path=config.get(_OUTPUT_SECTION, {}).get(_DATA_PATH_FIELD, "entry_from_stitchig"), overwrite_results=config[_STITCHING_SECTION].get(_OVERWRITE_RESULTS_FIELD, True), output_nexus_version=nexus_version, output_dtype=config[_OUTPUT_SECTION].get(_OUTPUT_DTYPE, numpy.float32), ) @dataclass class PostProcessedZStitchingConfiguration(ZStitchingConfiguration): """ base class to define z-stitching parameters """ input_volumes: tuple # tuple of VolumeBase output_volume: VolumeIdentifier @staticmethod def from_dict(config: dict): if not isinstance(config, dict): raise TypeError(f"config is expected to be a dict and not {type(config)}") inputs_volumes_str = 
config.get(_INPUTS_SECTION, {}).get(_INPUT_DATASETS_FIELD, None) if inputs_volumes_str in (None, ""): input_volumes = [] else: input_volumes = _get_identifiers(inputs_volumes_str) output_volume = get_output_volume( location=config.get(_OUTPUT_SECTION, {}).get(_DATA_FILE_FIELD, None), file_prefix=config.get(_OUTPUT_SECTION, {}).get(_FILE_PREFIX_FIELD, None), file_format=config.get(_OUTPUT_SECTION, {}).get(_FILE_FORMAT_FIELD, "hdf5"), ) # on the next section the one with a default value qre the optionnal one return PostProcessedZStitchingConfiguration( overlap_height=config[_STITCHING_SECTION][_VERTICAL_OVERLAP_FIELD], stitching_height=config[_STITCHING_SECTION].get(_STITCHING_HEIGTH_FIELD, "auto"), stitching_strategy=OverlapStichingStrategy.from_value( config[_STITCHING_SECTION].get( _STITCHING_STRATEGY_FIELD, OverlapStichingStrategy.COSINUS_WEIGHTS, ), ), x_shifts=config[_STITCHING_SECTION][_X_RELATIVE_SHIFTS_FIELD], auto_relative_shift_params=_valid_relative_shift_params( _str_to_dict( config[_STITCHING_SECTION].get( _AUTO_RELATIVE_SHIFT_PARAMS_FIELD, _DEFAULT_AUTO_REL_SHIFT_PARAMS, ), ), ), input_volumes=input_volumes, output_volume=output_volume, overwrite_results=config[_STITCHING_SECTION].get(_OVERWRITE_RESULTS_FIELD, True), output_dtype=config[_OUTPUT_SECTION].get(_OUTPUT_DTYPE, numpy.float32), ) def to_dict(self): return concatenate_dict( super().to_dict(), { _INPUTS_SECTION: { _INPUT_DATASETS_FIELD: [str(volume.get_identifier()) for volume in self.input_volumes], }, _OUTPUT_SECTION: { _OUTPUT_VOLUME: self.output_volume, }, }, ) def _get_identifiers(list_identifiers_as_str: str) -> tuple: # convert str to a list of str that should represent identifiers identifiers_as_str = convert_str_to_tuple(list_identifiers_as_str) # convert identifiers as string to IdentifierType instances return [Factory.create_tomo_object_from_identifier(identifier_as_str) for identifier_as_str in identifiers_as_str] def dict_to_config_obj(config: dict): if not isinstance(config, dict): raise TypeError stitching_type = config.get(_STITCHING_SECTION, {}).get(_STITCHING_TYPE_FIELD, None) if stitching_type is None: raise ValueError("Unagle to find stitching type from config dict") else: stitching_type = StitchingType.from_value(stitching_type) if stitching_type is StitchingType.Z_POSTPROC: return PostProcessedZStitchingConfiguration.from_dict(config) elif stitching_type is StitchingType.Z_PREPROC: return PreProcessedZStitchingConfiguration.from_dict(config) else: raise NotImplementedError(f"stitching type {stitching_type.value} not handled yet") def get_default_stitching_config(stitching_type: Optional[Union[StitchingType, str]]) -> tuple: """ Return a default configuration for doing stitching. :param stitching_type: if None then return a configuration were use can provide inputs for any of the stitching. 
Else return config dict dedicated to a particular stitching :return: (config, section comments) """ if stitching_type is None: return concatenate_dict(z_postproc_stitching_config, z_preproc_stitching_config) stitching_type = StitchingType.from_value(stitching_type) if stitching_type is StitchingType.Z_POSTPROC: return z_postproc_stitching_config elif stitching_type is StitchingType.Z_PREPROC: return z_preproc_stitching_config else: raise NotImplementedError SECTIONS_COMMENTS = { _STITCHING_SECTION: "section dedicated to stich parameters\n", _OUTPUT_SECTION: "section dedicated to output parameters\n", _INPUTS_SECTION: "section dedicated to inputs\n", } _stitching_config = { _STITCHING_SECTION: { _VERTICAL_OVERLAP_FIELD: { "default": "auto", "help": "Overlap area between two scans in pixel. Can be an int or a list of int. If 'auto' will try to deduce it from the magnification and z_translations value", "type": "required", "validator": list_of_shift_validator, }, _STITCHING_HEIGTH_FIELD: { "default": "auto", "help": "Height of the stich to apply on the overlap region. If set to 'auto' then will take the largest one possible (equal overlap height)", "type": "advanced", "validator": list_of_shift_validator, }, _STITCHING_STRATEGY_FIELD: { "default": "cosinus weights", "help": f"Policy to apply to compute the overlap area. Must be in {OverlapStichingStrategy.values()}.", "type": "required", }, }, _OUTPUT_SECTION: { _OVERWRITE_RESULTS_FIELD: { "default": "1", "help": "What to do in the case where the output file exists.\nBy default, the output data is never overwritten and the process is interrupted if the file already exists.\nSet this option to 1 if you want to overwrite the output files.", "validator": boolean_validator, "type": "required", }, }, _INPUTS_SECTION: { _INPUT_DATASETS_FIELD: { "default": "", "help": f"Dataset to stitch together. Must be volume for {StitchingType.Z_PREPROC.value} and NXtomo for {StitchingType.Z_POSTPROC.value}", "type": "required", }, }, } _z_stitching_config = concatenate_dict( _stitching_config, { _STITCHING_SECTION: { _X_RELATIVE_SHIFTS_FIELD: { "default": "auto", "help": "relative shift between two set of frames or volumes.", "type": "required", "validator": list_of_shift_validator, }, _AUTO_RELATIVE_SHIFT_PARAMS_FIELD: { "default": _DEFAULT_AUTO_REL_SHIFT_PARAMS, "help": "options to find shift automatically", "type": "advanced", }, }, }, ) z_preproc_stitching_config = concatenate_dict( { _STITCHING_SECTION: { _STITCHING_TYPE_FIELD: { "default": StitchingType.Z_PREPROC.value, "help": f"Which type of stitching to do. Must be in {StitchingType.values}", "type": "required", }, }, _OUTPUT_SECTION: { _DATA_FILE_FIELD: { "default": "", "help": "HDF5 file to save the generated NXtomo (.nx extension recommanded).", "validator": optional_directory_location_validator, "type": "required", }, _FILE_FORMAT_FIELD: { "default": "hdf5", "help": "Output file format. Available are: hdf5, tiff, jp2, edf", "validator": output_file_format_validator, "type": "optional", }, _NEXUS_VERSION_FIELD: { "default": "", "help": "output NXtomo version to use for saving stitched NXtomo. If not provided will take the latest version available.", "type": "optional", }, }, }, _z_stitching_config, ) z_postproc_stitching_config = concatenate_dict( { _STITCHING_SECTION: { _STITCHING_TYPE_FIELD: { "default": StitchingType.Z_POSTPROC.value, "help": f"Which type of stitching to do. 
Must be in {StitchingType.values}", "type": "required", }, }, _OUTPUT_SECTION: { _DATA_FILE_FIELD: { "default": "", "help": "Directory where the output reconstruction is stored.", "validator": optional_directory_location_validator, "type": "required", }, _FILE_PREFIX_FIELD: { "default": "", "help": "File prefix. Optional, by default it is inferred from the scanned dataset.", "validator": optional_file_name_validator, "type": "optional", }, _FILE_FORMAT_FIELD: { "default": "hdf5", "help": "Output file format. Available are: hdf5, tiff, jp2, edf", "validator": output_file_format_validator, "type": "optional", }, }, }, _z_stitching_config, ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1677956536.0 nabu-2023.1.1/nabu/stitching/frame_composition.py0000644000175000017500000001633100000000000021240 0ustar00pierrepierrefrom copy import copy from dataclasses import dataclass import numpy from nabu.stitching.overlap import ZStichOverlapKernel @dataclass class _FrameCompositionBase: def compose(self, output_frame: numpy.ndarray, input_frames: tuple): raise NotImplementedError("Base class") @dataclass class ZFrameComposition(_FrameCompositionBase): """ class used to define intervals to know where to dump raw data or stitched data according to requested policy. The idea is to create this once for all for one stitching operation and reuse it for each frame. """ local_start_y: tuple local_end_y: tuple global_start_y: tuple global_end_y: tuple def browse(self): for i in range(len(self.local_start_y)): yield ( self.local_start_y[i], self.local_end_y[i], self.global_start_y[i], self.global_end_y[i], ) def compose(self, output_frame: numpy.ndarray, input_frames: tuple): if not output_frame.ndim == 2: raise TypeError(f"output_frame is expected to be 2D and not {output_frame.ndim}") for ( global_start_y, global_end_y, local_start_y, local_end_y, input_frame, ) in zip( self.global_start_y, self.global_end_y, self.local_start_y, self.local_end_y, input_frames, ): if input_frame is not None: output_frame[global_start_y:global_end_y] = input_frame[local_start_y:local_end_y] @staticmethod def compute_raw_frame_compositions(frames: tuple, y_shifts: tuple, overlap_kernels: tuple): """ compute frame composition for raw data """ assert len(frames) == len(overlap_kernels) + 1 == len(y_shifts) + 1 global_start_y = [] local_start_y = [] global_end_y = [] local_end_y = [] frame_height_sum = 0 # extend shifts and kernels to have a first shift of 0 and two overlaps values at 0 to # generalize processing lower_shifts = [0] lower_shifts.extend(y_shifts) upper_shifts = list(copy(y_shifts)) upper_shifts.append(0) overlaps = [kernel.overlap_height for kernel in overlap_kernels] overlaps.append(0) overlaps.insert(0, 0) for ( frame, lower_shift, upper_shift, lower_overlap_kernel, upper_overlap_kernel, ) in zip(frames, lower_shifts, upper_shifts, overlaps[:-1], overlaps[1:]): if lower_shift > 0: raise ValueError( "Unexpected shift value found; TODO: handle positive shift (this mean no overlap between frames" ) lower_remaining = abs(lower_shift) - lower_overlap_kernel upper_o_s_diff = abs(upper_shift) - upper_overlap_kernel # policy with overlap that needs to take one more line on one side ((lower_remaining) % 2 == 1) # take this line on the lower frame side. 
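                # Worked example (hypothetical numbers, for illustration only): with
                # abs(lower_shift) = 120 px and lower_overlap_kernel = 100 px,
                # lower_remaining = 120 - 100 = 20, so new_local_start_y below is
                # 100 + 20 // 2 - 20 % 2 = 110. With an odd remainder (e.g. 21) it
                # becomes 100 + 21 // 2 - 21 % 2 = 109, i.e. this frame keeps one
                # extra row at that boundary, which is the policy described above.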
new_local_start_y = lower_overlap_kernel + (lower_remaining) // 2 - (lower_remaining) % 2 new_local_end_y = frame.shape[0] - (upper_overlap_kernel + (upper_o_s_diff) // 2) new_global_start_y = frame_height_sum + new_local_start_y new_global_end_y = frame_height_sum + new_local_end_y frame_height_sum += frame.shape[0] - abs(upper_shift) # check values are coherent if new_local_start_y < 0 or new_local_end_y < 0 or new_global_start_y < 0 or new_global_end_y < 0: raise ValueError( "Incoherence found on the computing raw frame composition. Are you sure overlap height is no larger than frame height" ) global_start_y.append(int(new_global_start_y)) global_end_y.append(int(new_global_end_y)) local_start_y.append(int(new_local_start_y)) local_end_y.append(int(new_local_end_y)) return ZFrameComposition( local_start_y=tuple(local_start_y), local_end_y=tuple(local_end_y), global_start_y=tuple(global_start_y), global_end_y=tuple(global_end_y), ) @staticmethod def compute_stitch_frame_composition(frames, y_shifts: tuple, overlap_kernels: tuple): """ compute frame composition for stiching. """ assert len(frames) == len(overlap_kernels) + 1 global_start_y = [] global_end_y = [] local_start_y = [0] * len(y_shifts) # stiched is expected to be at the expected size already local_end_y = [kernel.overlap_height for kernel in overlap_kernels] frame_height_sum = 0 for frame, kernel, y_shift in zip(frames[:-1], overlap_kernels, y_shifts): if y_shift > 0: raise ValueError( "Unexpected shift value found; TODO: handle positive shift (this mean no overlap between frames" ) assert isinstance(kernel, ZStichOverlapKernel) new_global_start_y = ( frame_height_sum + frame.shape[0] - (kernel.overlap_height + (abs(y_shift) - kernel.overlap_height) // 2) ) new_global_end_y = frame_height_sum + frame.shape[0] - (abs(y_shift) - kernel.overlap_height) // 2 if new_global_start_y < 0 or new_global_end_y < 0: raise ValueError( "Incoherence found on the computing raw frame composition. 
Are you sure overlap height is no larger than frame height" ) global_start_y.append(int(new_global_start_y)) global_end_y.append(int(new_global_end_y)) frame_height_sum += frame.shape[0] + y_shift return ZFrameComposition( local_start_y=tuple(local_start_y), local_end_y=tuple(local_end_y), global_start_y=tuple(global_start_y), global_end_y=tuple(global_end_y), ) @staticmethod def pprint_z_stitching(raw_composition, stitch_composition): """ util to display what the output of the z stitch will looks like from composition """ for i_frame, (raw_comp, stitch_comp) in enumerate(zip(raw_composition.browse(), stitch_composition.browse())): raw_local_start, raw_local_end, raw_global_start, raw_global_end = raw_comp print( f"stitch_frame[{raw_global_start}:{raw_global_end}] = frame_{i_frame}[{raw_local_start}:{raw_local_end}]" ) ( stitch_local_start, stitch_local_end, stitch_global_start, stitch_global_end, ) = stitch_comp print( f"stitch_frame[{stitch_global_start}:{stitch_global_end}] = stitched_frame_{i_frame}[{stitch_local_start}:{stitch_local_end}]" ) else: i_frame += 1 raw_local_start, raw_local_end, raw_global_start, raw_global_end = list(raw_composition.browse())[-1] print( f"stitch_frame[{raw_global_start}:{raw_global_end}] = frame_{i_frame}[{raw_local_start}:{raw_local_end}]" ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/stitching/overlap.py0000644000175000017500000001404200000000000017170 0ustar00pierrepierre# coding: utf-8 # /*########################################################################## # # Copyright (c) 2016-2017 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # # ###########################################################################*/ __authors__ = ["H. 
Payno"] __license__ = "MIT" __date__ = "10/05/2022" from typing import Optional from silx.utils.enum import Enum as _Enum import numpy class OverlapStichingStrategy(_Enum): MEAN = "mean" COSINUS_WEIGHTS = "cosinus weights" LINEAR_WEIGHTS = "linear weights" CLOSEST = "closest" DEFAULT_OVERLAP_STRATEGY = OverlapStichingStrategy.COSINUS_WEIGHTS DEFAULT_OVERLAP_HEIGHT = 400 class OverlapKernelBase: pass class ZStichOverlapKernel(OverlapKernelBase): """ class used to define overlap between two scans and create stitch between frames (`stitch` function) """ def __init__( self, frame_width: int, stitching_strategy: OverlapStichingStrategy = DEFAULT_OVERLAP_STRATEGY, overlap_height: int = DEFAULT_OVERLAP_HEIGHT, ) -> None: """ """ if not isinstance(overlap_height, int) or (overlap_height != -1 and not overlap_height > 0): raise TypeError( f"overlap_height is expected to be a positive int, {overlap_height} - not {overlap_height} ({type(overlap_height)})" ) if not isinstance(frame_width, int) or not frame_width > 0: raise TypeError( f"frame_width is expected to be a positive int, {frame_width} - not {frame_width} ({type(frame_width)})" ) self._overlap_height = overlap_height self._frame_width = frame_width self._stitching_strategy = OverlapStichingStrategy.from_value(stitching_strategy) self._weights_img_1 = None self._weights_img_2 = None @staticmethod def __check_img(img, name): if not isinstance(img, numpy.ndarray) and img.ndim == 2: raise ValueError(f"{name} is expected to be 2D numpy array") @property def overlap_height(self) -> int: return self._overlap_height @overlap_height.setter def overlap_height(self, height: int): if not isinstance(height, int): raise TypeError(f"height expects a int ({type(height)} provided instead)") if not height >= 0: raise ValueError(f"height is expected to be positive") self._overlap_height = height # update weights if needed if self._weights_img_1 is not None or self._weights_img_2 is not None: self.compute_weights() @property def img_2(self) -> numpy.ndarray: return self._img_2 @property def weights_img_1(self) -> Optional[numpy.ndarray]: return self._weights_img_1 @property def weights_img_2(self) -> Optional[numpy.ndarray]: return self._weights_img_2 @property def stitching_strategy(self) -> OverlapStichingStrategy: return self._stitching_strategy def compute_weights(self): if self.stitching_strategy is OverlapStichingStrategy.MEAN: weights_img_1 = numpy.ones(self._overlap_height) * 0.5 weights_img_2 = weights_img_1[::-1] elif self.stitching_strategy is OverlapStichingStrategy.CLOSEST: n_item = self._overlap_height // 2 + self._overlap_height % 2 weights_img_1 = numpy.concatenate( [ numpy.ones(n_item), numpy.zeros(self._overlap_height - n_item), ] ) weights_img_2 = weights_img_1[::-1] elif self.stitching_strategy is OverlapStichingStrategy.LINEAR_WEIGHTS: weights_img_1 = numpy.linspace(1.0, 0.0, self._overlap_height) weights_img_2 = weights_img_1[::-1] elif self.stitching_strategy is OverlapStichingStrategy.COSINUS_WEIGHTS: angles = numpy.linspace(0.0, numpy.pi / 2.0, self._overlap_height) weights_img_1 = numpy.cos(angles) ** 2 weights_img_2 = numpy.sin(angles) ** 2 else: raise NotImplementedError(f"{self.stitching_strategy} not implemented") self._weights_img_1 = weights_img_1.reshape(-1, 1) * numpy.ones(self._frame_width).reshape(1, -1) self._weights_img_2 = weights_img_2.reshape(-1, 1) * numpy.ones(self._frame_width).reshape(1, -1) def stitch(self, img_1, img_2, check_input=True) -> tuple: """Compute overlap region from the defined strategy""" if 
check_input: self.__check_img(img_1, "img_1") self.__check_img(img_2, "img_2") if img_1.shape != img_2.shape: raise ValueError( f"both images are expected to be of the same shape to apply stitch ({img_1.shape} vs {img_2.shape})" ) if self.weights_img_1 is None or self.weights_img_2 is None: self.compute_weights() return ( img_1 * self.weights_img_1 + img_2 * self.weights_img_2, self.weights_img_1, self.weights_img_2, ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/stitching/utils.py0000644000175000017500000001734500000000000016671 0ustar00pierrepierrefrom typing import Optional, Union import numpy from nabu.stitching.overlap import OverlapStichingStrategy, ZStichOverlapKernel from tomoscan.scanbase import TomoScanBase import logging from scipy.ndimage import shift as scipy_shift _logger = logging.getLogger(__name__) try: from skimage.registration import phase_cross_correlation except ImportError: _logger.warning( "Unable to load skimage. Please install it if you want to use it for finding shifts from `find_relative_shifts`" ) __has_sk_phase_correlation__ = False else: __has_sk_phase_correlation__ = True def test_overlap_stitching_strategy(overlap_1, overlap_2, stitching_strategies): """ stitch the two ovrelap with all the requested strategies. Return a dictionary with stitching strategy as key and a result dict as value. result dict keys are: 'weights_overlap_1', 'weights_overlap_2', 'stiching' """ res = {} for strategy in stitching_strategies: s = OverlapStichingStrategy.from_value(strategy) stitcher = ZStichOverlapKernel( stitching_strategy=s, overlap_height=overlap_1.shape[0], frame_width=overlap_1.shape[1], ) stiched_overlap, w1, w2 = stitcher.stitch(overlap_1, overlap_2, check_input=True) res[s.value] = { "stitching": stiched_overlap, "weights_overlap_1": w1, "weights_overlap_2": w2, } return res def find_relative_shifts( scan_0: TomoScanBase, scan_1: TomoScanBase, projection_for_shift: Union[int, str] = "middle", invert_order: bool = False, x_cross_correlation_function=None, y_cross_correlation_function=None, auto_flip: bool = True, ) -> tuple: """ deduce the relative shift between the two scans. Expected behavior: * compute expected overlap area from z_translations and (sample) pixel size * call a cross correlation function from the overlap area to compute the x shift and polish the y shift from `projection_for_shift` :param TomoScanBase scan_0: :param TomoScanBase scan_1: :param Union[int,str] projection_for_shift: index fo the projection to use (in projection space or in scan space ?. For now in projection) or str. If str must be in (`middle`, `first`, `last`) :param str x_cross_correlation_function: optional method to refine x shift from computing cross correlation. For now valid values are: ("skimage", "skimage-fourier") :param str y_cross_correlation_function: optional method to refine y shift from computing cross correlation. 
For now valid values are: ("skimage", "skimage-fourier") :param int minimal_overlap_area_for_cross_correlation: if first approximated overlap shift found from z_translation is lower than this value will fall back on taking the full image for the cross correlation and log a warning :param bool invert_order: are projections inverted between the two scans (case if rotation angle are inverted) :param bool auto_flip: if True then will automatically flip frames to get a "homogeneous" result based on unflipped frames :return: relative shift of scan_1 with scan_0 as reference: (y_shift, x_shift) :rtype: tuple :warning: this function will flip left-right and up-down frames by default. So it will return shift according to this information """ def get_flat_fielded_proj(scan: TomoScanBase, proj_index: int, reverse: bool, revert_x: bool, revert_y): first_proj_idx = sorted(scan_1.projections.keys(), reverse=reverse)[proj_index] ff = scan.flat_field_correction( (scan.projections[first_proj_idx],), (first_proj_idx,), )[0] if auto_flip and revert_x: ff = numpy.fliplr(ff) if auto_flip and revert_y: ff = numpy.flipud(ff) return ff if isinstance(projection_for_shift, str): if projection_for_shift.lower() == "first": projection_for_shift = 0 elif projection_for_shift.lower() == "last": projection_for_shift = -1 elif projection_for_shift.lower() == "middle": projection_for_shift = len(scan_0.projections) // 2 else: try: projection_for_shift = int(projection_for_shift) except ValueError: raise ValueError( f"{projection_for_shift} cannot be cast to an int and is not one of the possible ('first', 'last', 'middle')" ) elif not isinstance(projection_for_shift, (int, numpy.number)): raise TypeError( f"projection_for_shift is expected to be an int. Not {type(projection_for_shift)} - {projection_for_shift}" ) proj_0 = get_flat_fielded_proj( scan_0, projection_for_shift, reverse=False, revert_x=scan_0.get_x_flipped(default=False), revert_y=scan_0.get_y_flipped(default=False), ) proj_1 = get_flat_fielded_proj( scan_1, projection_for_shift, reverse=invert_order, revert_x=scan_1.get_x_flipped(default=False), revert_y=scan_1.get_y_flipped(default=False), ) # get overlap area from z scan_0_y_bb = scan_0.get_bounding_box(axis="z") scan_1_y_bb = scan_1.get_bounding_box(axis="z") scan_0_scan_1_overlap = scan_0_y_bb.get_overlap(scan_1_y_bb) if scan_0_scan_1_overlap is not None: overlap = scan_0_scan_1_overlap.max - scan_0_scan_1_overlap.min overlap_percentage = (overlap) / (scan_0_y_bb.max - scan_0_y_bb.min) y_overlap_frm_position_in_pixel = int(overlap_percentage * scan_0.dim_2) overlap_1 = proj_0[-y_overlap_frm_position_in_pixel:] overlap_2 = proj_1[:y_overlap_frm_position_in_pixel:] else: _logger.warning( "no overlap founds from scan metadata. Take the full image to try to find an overlap. Automatic shift deduction has an higher probability to fail" ) x_found_shift = 0 y_found_shift = -y_overlap_frm_position_in_pixel if x_cross_correlation_function in ("skimage", "skimage-fourier"): if not __has_sk_phase_correlation__: raise ValueError("scikit-image not installed. 
Cannot do phase correlation from it") if x_cross_correlation_function == "skimage-fourier": overlap_1 = numpy.fft.fftn(overlap_1) overlap_2 = numpy.fft.fftn(overlap_2) space = "fourier" else: space = "real" found_shift, _, _ = phase_cross_correlation(reference_image=overlap_1, moving_image=overlap_2, space=space) x_found_shift = found_shift[1] elif x_cross_correlation_function is not None: raise ValueError(f"requested cross correlation function not handled ({x_cross_correlation_function})") if y_cross_correlation_function in ("skimage", "skimage-fourier"): if not __has_sk_phase_correlation__: raise ValueError("scikit-image not installed. Cannot do phase correlation from it") if y_cross_correlation_function == "skimage-fourier": overlap_1 = numpy.fft.fftn(overlap_1) overlap_2 = numpy.fft.fftn(overlap_2) space = "fourier" else: space = "real" found_shift, _, _ = phase_cross_correlation(reference_image=overlap_1, moving_image=overlap_2, space=space) y_found_shift = found_shift[0] - y_overlap_frm_position_in_pixel elif y_cross_correlation_function is not None: raise ValueError(f"requested cross correlation function not handled ({y_cross_correlation_function})") # if y_found_shift > 0: _logger.warning( f"found a positive shift ({found_shift[0]}) when a negative one is expected. Are you sure about the scan z ordering. This is likely z stitching will fails" ) return tuple([int(y_found_shift), int(x_found_shift)]) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/stitching/z_stitching.py0000644000175000017500000012407600000000000020056 0ustar00pierrepierre# coding: utf-8 # /*########################################################################## # # Copyright (c) 2016-2017 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # # ###########################################################################*/ __authors__ = ["H. 
Payno"] __license__ = "MIT" __date__ = "10/05/2022" from copy import copy from datetime import datetime from multiprocessing.sharedctypes import Value import os from silx.io.utils import get_data from typing import Optional, Union from silx.io.url import DataUrl from silx.io.dictdump import dicttonx import numpy from tomoscan.esrf import HDF5TomoScan, EDFTomoScan from tomoscan.serie import Serie from tomoscan.esrf.scan.hdf5scan import ImageKey from tomoscan.nexus.paths.nxtomo import get_paths as _get_nexus_paths from tomoscan.scanbase import TomoScanBase from scipy.ndimage import shift as shift_scipy from nxtomomill.nexus import NXtomo from nabu.io.utils import DatasetReader from tomoscan.esrf.scan.utils import ( get_compacted_dataslices, ) # this version has a 'return_url_set' needed here. At one point they should be merged together from nabu.stitching.frame_composition import ZFrameComposition from nabu.stitching.utils import find_relative_shifts from nabu.stitching.config import ( KEY_CROSS_CORRELATION_SLICE, KEY_X_CROSS_CORRELATION_FUNC, KEY_Y_CROSS_CORRELATION_FUNC, PreProcessedZStitchingConfiguration, PostProcessedZStitchingConfiguration, ZStitchingConfiguration, ) from nabu.utils import Progress from nabu import version as nabu_version from nabu.io.writer import get_datetime from .overlap import ( DEFAULT_OVERLAP_STRATEGY, OverlapKernelBase, OverlapStichingStrategy, ZStichOverlapKernel, ) from tomoscan.io import HDF5File import h5py import logging _logger = logging.getLogger(__name__) def z_stitching(configuration: ZStitchingConfiguration, progress=None) -> DataUrl: """ Apply stitching from provided configuration. Return a DataUrl with the created NXtomo or Volume """ if isinstance(configuration, PreProcessedZStitchingConfiguration): stitcher = PreProcessZStitcher(configuration=configuration, progress=progress) elif isinstance(configuration, PostProcessedZStitchingConfiguration): raise NotImplementedError else: raise TypeError( f"configuration is expected to be in {(PreProcessedZStitchingConfiguration, PostProcessedZStitchingConfiguration)}. 
{type(configuration)} provided" ) return stitcher.stitch() class ZStitcher: def __init__(self, configuration, progress: Progress = None) -> None: if not isinstance(configuration, ZStitchingConfiguration): raise TypeError self._configuration = copy(configuration) # copy configuration because we will edit it self._progress = progress # z serie must be defined from daughter class assert hasattr(self, "_z_serie") def is_auto(param): return param in ("auto", ("auto",)) # 'expend' auto shift request if only set once for all if is_auto(self.configuration.x_shifts): self.configuration.x_shifts = [ "auto", ] * (len(self.z_serie) - 1) elif numpy.isscalar(self.configuration.x_shifts): self.configuration.x_shifts = [ self.configuration.x_shifts, ] * (len(self.z_serie) - 1) # 'expend' overlaph height and if is_auto(self.configuration.overlap_height): self.configuration.overlap_height = [ "auto", ] * (len(self.z_serie) - 1) elif numpy.isscalar(self.configuration.overlap_height): self.configuration.overlap_height = [self.configuration.overlap_height] * (len(self.z_serie) - 1) # 'expend' stitching height if is_auto(self.configuration.stitching_height): self.configuration.stitching_height = [ "auto", ] * (len(self.z_serie) - 1) elif numpy.isscalar(self._configuration.stitching_height): self.configuration.stitching_height = [self.configuration.stitching_height] * (len(self.z_serie) - 1) def stitch(self) -> DataUrl: """ Apply expected stitch from configuration and return the DataUrl of the object created """ raise NotImplementedError("base class") def _compute_shifts(self): raise NotImplementedError("base class") @property def z_serie(self) -> Serie: return self._z_serie @property def configuration(self) -> ZStitchingConfiguration: return self._configuration @property def progress(self) -> Optional[Progress]: return self._progress @staticmethod def get_overlap_areas( lower_frame: numpy.ndarray, upper_frame: numpy.ndarray, real_overlap: int, stitching_height: int, ): """ return the requested area from lower_frame and upper_frame. Lower_frame contains at the end of it the 'real overlap' with the upper_frame. Upper_frame contains the 'real overlap' at the end of it. For some reason the user can ask the stitching height to be smaller than the `real overlap`. Here are some drawing to have a better of view of those regions: .. image:: images/stitching/z_stitch_real_overlap.png :width: 600 .. 
image:: z_stitch_stitch_height.png :width: 600 """ assert real_overlap >= 0 assert stitching_height >= 0 if stitching_height > real_overlap: raise ValueError(f"stitching height ({stitching_height}) larger than existing overlap ({real_overlap}).") real_overlap_0 = lower_frame[-real_overlap:] real_overlap_1 = upper_frame[:real_overlap] if not real_overlap_0.shape == real_overlap_1.shape: raise RuntimeError( f"lower and upper frame have different overlap size ({real_overlap_0.shape} vs {real_overlap_1.shape})" ) low_pos = int(real_overlap // 2 - stitching_height // 2) hight_pos = int(real_overlap // 2 + stitching_height // 2) + (stitching_height) % 2 # if there is one more line to take on one side take it on the lower_frame side assert real_overlap_0[low_pos:hight_pos].shape == real_overlap_1[low_pos:hight_pos].shape return real_overlap_0[low_pos:hight_pos], real_overlap_1[low_pos:hight_pos] @staticmethod def stitch_frames( frames: tuple, x_relative_shifts: tuple, y_overlap_heights: tuple, output_dtype: numpy.ndarray, output_dataset: Optional[Union[h5py.Dataset, numpy.ndarray]] = None, check_inputs=True, shift_mode="nearest", overlap_kernels=None, i_frame=None, ) -> numpy.ndarray: """ shift frames according to provided `shifts` (as y, x tuples) then stitch all the shifted frames together and save them to output_dataset. :param tuple frames: element must be a DataUrl or a 2D numpy array """ if check_inputs: if len(frames) < 2: raise ValueError(f"Not enought frames provided for stitching ({len(frames)} provided)") if len(frames) != len(x_relative_shifts) + 1: raise ValueError( f"Incoherent number of shift provided ({len(x_relative_shifts)}) compare to number of frame ({len(frames)}). len(frames) - 1 expected" ) if len(x_relative_shifts) != len(y_overlap_heights): raise ValueError( f"expect to have the same number of x_relative_shifts ({len(x_relative_shifts)}) and y_overlap ({len(y_overlap_heights)})" ) def check_frame_is_2d(frame): if frame.ndim != 2: raise ValueError(f"2D frame expected when {frame.ndim}D provided") # step_0 load data if from url data = [] for frame in frames: if isinstance(frame, DataUrl): data_frame = get_data(frame) if check_inputs: check_frame_is_2d(data_frame) data.append(data_frame) elif isinstance(frame, numpy.ndarray): if check_inputs: check_frame_is_2d(frame) data.append(frame) else: raise TypeError(f"frames are expected to be DataUrl or 2D numpy array. Not {type(frame)}") # step 1: shift each frames (except the first one) x_shifted_data = [data[0]] for frame, x_relative_shift in zip(data[1:], x_relative_shifts): # note: for now we only shift data in x. 
the y shift is handled in the FrameComposition x_relative_shift = numpy.asarray(x_relative_shift).astype(numpy.int8) if x_relative_shift == 0: shifted_frame = frame else: # TO speed up: should use the Fourier transform shifted_frame = shift_scipy( frame, mode=shift_mode, shift=[0, -x_relative_shift], order=1, ) x_shifted_data.append(shifted_frame) # step 2: create stitched frame if overlap_kernels is None: overlap_kernels = ZStichOverlapKernel(frame_width=data[0].shape[1]) stitched_frame = z_stitch_raw_frames( frames=x_shifted_data, y_shifts=[-abs(y_overlap_height) for y_overlap_height in y_overlap_heights], overlap_kernels=overlap_kernels, check_inputs=check_inputs, output_dtype=output_dtype, ) # step 3: dump stitched frame if output_dataset is not None and i_frame is not None: output_dataset[i_frame] = stitched_frame return stitched_frame class PreProcessZStitcher(ZStitcher): def __init__(self, configuration, progress=None) -> None: # z serie must be defined first self._z_serie = Serie("z-serie", iterable=configuration.input_scans, use_identifiers=False) self._reading_orders = [] self._x_flips = [] self._y_flips = [] # some scan can have been taken in the opposite order (so must be read on the opposite order one from the other) super().__init__(configuration, progress) @property def reading_order(self): """ as scan can be take on one direction or the order (rotation goes from X to Y then from Y to X) we might need to read data from one direction or another """ return self._reading_orders @property def x_flips(self) -> list: return self._x_flips @property def y_flips(self) -> list: return self._y_flips def stitch(self): if self.progress is not None: self.progress.set_name("order scans") self._order_scans() if self.progress is not None: self.progress.set_name("check inputs") self._check_inputs() if self.progress is not None: self.progress.set_name("compute flat field") self._compute_reduced_flats_and_darks() if self.progress is not None: self.progress.set_name("compute shift") self._compute_shifts() if self.progress is not None: self.progress.set_name("stitch projections, save them and create NXtomo") self._create_nx_tomo() if self.progress is not None: self.progress.set_name("dump configuration") self._dump_stitching_configuration() return DataUrl( file_path=self.configuration.output_file_path, data_path=self.configuration.output_data_path, scheme="h5py", ) def _order_scans(self): """ ensure scans are in z increasing order """ def get_min_z(scan): return scan.get_bounding_box(axis="z").min sorted_z_serie = Serie( self.z_serie.name, sorted(self.z_serie[:], key=get_min_z, reverse=True), use_identifiers=False, ) if sorted_z_serie != self.z_serie: if sorted_z_serie[:] != self.z_serie[::-1]: raise ValueError("Unable to get comprehensive input. Z (decreasing) ordering is not respected.") else: _logger.warning( f"z decreasing order haven't been respected. Need to reorder z serie ({[str(scan) for scan in sorted_z_serie[:]]}). 
Will also reorder overlap height, stitching height and invert shifts" ) self.configuration.overlap_height = self.configuration.overlap_height[::-1] self.configuration.x_shifts = [ -x_shift if x_shift != "auto" else x_shift for x_shift in self.configuration.x_shifts ] self.configuration.stitching_height = self.configuration.stitching_height[::-1] self._z_serie = sorted_z_serie def _check_inputs(self): """ insure input data is coherent """ n_scans = len(self.z_serie) if n_scans == 0: raise ValueError("no scan to stich together") # check number of shift provided if len(self.configuration.x_shifts) != (n_scans - 1): raise ValueError(f"expect {n_scans -1} shift defined. Get {len(self.configuration.x_shifts)}") if len(self.configuration.overlap_height) != (n_scans - 1): raise ValueError(f"expect {n_scans - 1} overlap defined. Get {len(self.configuration.overlap_height)}") if len(self.configuration.stitching_height) != (n_scans - 1): raise ValueError( f"expect {n_scans - 1} stitching height defined. Get {len(self.configuration.overlap_height)}" ) for scan in self.z_serie: if scan.x_flipped is None or scan.y_flipped is None: _logger.warning( f"Found at least one scan with no frame flips information ({scan}). Will consider those are unflipped. Might end up with some inverted frame errors." ) break self._reading_orders = [] # the first scan will define the expected reading orderd, and expected flip. # if all scan are flipped then we will keep it this way self._reading_orders.append(1) # check scans are coherent (nb projections, rotation angle, energy...) for scan_0, scan_1 in zip(self.z_serie[0:-1], self.z_serie[1:]): if len(scan_0.projections) != len(scan_1.projections): raise ValueError(f"{scan_0} and {scan_1} have a different number of projections") if isinstance(scan_0, HDF5TomoScan) and isinstance(scan_1, HDF5TomoScan): # check rotation (only of is an HDF5TomoScan) scan_0_angles = numpy.asarray(scan_0.rotation_angle) scan_0_projections_angles = scan_0_angles[ numpy.asarray(scan_0.image_key_control) == ImageKey.PROJECTION.value ] scan_1_angles = numpy.asarray(scan_1.rotation_angle) scan_1_projections_angles = scan_1_angles[ numpy.asarray(scan_1.image_key_control) == ImageKey.PROJECTION.value ] if not numpy.allclose(scan_0_projections_angles, scan_1_projections_angles, atol=10e-1): if numpy.allclose( scan_0_projections_angles, scan_1_projections_angles[::-1], atol=10e-1, ): reading_order = -1 * self._reading_orders[-1] else: raise ValueError(f"Angles from {scan_0} and {scan_1} are different") else: reading_order = 1 * self._reading_orders[-1] self._reading_orders.append(reading_order) # check energy if scan_0.energy is None: _logger.warning(f"no energy found for {scan_0}") elif not numpy.isclose(scan_0.energy, scan_1.energy, rtol=1e-03): _logger.warning( f"different energy found between {scan_0} ({scan_0.energy}) and {scan_1} ({scan_1.energy})" ) # check FOV if not scan_0.field_of_view == scan_1.field_of_view: raise ValueError(f"{scan_0} and {scan_1} have different field of view") # check distance if scan_0.distance is None: _logger.warning(f"no distance found for {scan_0}") elif not numpy.isclose(scan_0.distance, scan_1.distance, rtol=10e-3): raise ValueError(f"{scan_0} and {scan_1} have different sample / detector distance") # check pixel size if not numpy.isclose(scan_0.x_pixel_size, scan_1.x_pixel_size): raise ValueError( f"{scan_0} and {scan_1} have different x pixel size. 
{scan_0.x_pixel_size} vs {scan_1.x_pixel_size}" ) if not numpy.isclose(scan_0.y_pixel_size, scan_1.y_pixel_size): raise ValueError( f"{scan_0} and {scan_1} have different y pixel size. {scan_0.y_pixel_size} vs {scan_1.y_pixel_size}" ) # check magnification (only if is HDF5TomoScan) if isinstance(scan_0, HDF5TomoScan) and isinstance(scan_1, HDF5TomoScan): if not numpy.isclose(scan_0.magnification, scan_1.magnification): raise ValueError( f"{scan_0} and {scan_1} have different magnification. {scan_0.magnification} vs {scan_1.magnification}" ) if scan_0.dim_1 != scan_1.dim_1: raise ValueError( f"projections width are expected to be the same. Not the canse for {scan_0} ({scan_0.dim_1} and {scan_1} ({scan_1.dim_1}))" ) for scan in self.z_serie: # check x, y and z translation are constant (only if is an HDF5TomoScan) if isinstance(scan_0, HDF5TomoScan) and isinstance(scan_1, HDF5TomoScan): if scan.x_translation is not None and not numpy.isclose( min(scan.x_translation), max(scan.x_translation) ): _logger.warning( "x translations appears to be evolving over time. Might end up with wrong stitching" ) if scan.y_translation is not None and not numpy.isclose( min(scan.y_translation), max(scan.y_translation) ): _logger.warning( "y translations appears to be evolving over time. Might end up with wrong stitching" ) if scan.z_translation is not None and not numpy.isclose( min(scan.z_translation), max(scan.z_translation) ): _logger.warning( "z translations appears to be evolving over time. Might end up with wrong stitching" ) def _compute_reduced_flats_and_darks(self): """ TODO: should be done with nabu stuff !!! """ for scan in self.z_serie: try: reduced_darks = scan.load_reduced_darks() except: _logger.info("no reduced dark found. Try to compute them.") if reduced_darks in (None, {}): reduced_darks = scan.compute_reduced_darks() try: # if we don't have write in the folder containing the .nx for example scan.save_reduced_darks(reduced_darks) except: pass scan.set_reduced_darks(reduced_darks) try: reduced_flats = scan.load_reduced_flats() except: _logger.info("no reduced flats found. 
Try to compute them.") if reduced_flats in (None, {}): reduced_flats = scan.compute_reduced_flats() try: # if we don't have write in the folder containing the .nx for example scan.save_reduced_flats(reduced_flats) except: pass scan.set_reduced_flats(reduced_flats) def _compute_shifts(self): """ compute all shift requested (set to 'auto' in the configuration) """ n_scans = len(self.configuration.input_scans) if n_scans == 0: raise ValueError("no scan to stich provided") # get shift final_shifts = [] projection_for_shift = self.configuration.auto_relative_shift_params.get(KEY_CROSS_CORRELATION_SLICE, "middle") x_cross_correlation_function = self.configuration.auto_relative_shift_params.get( KEY_X_CROSS_CORRELATION_FUNC, None ) y_cross_correlation_function = self.configuration.auto_relative_shift_params.get( KEY_Y_CROSS_CORRELATION_FUNC, None ) for scan_0, scan_1, order_s0, order_s1, x_relative_shift, y_overlap in zip( self.z_serie[:-1], self.z_serie[1:], self.reading_order[:-1], self.reading_order[1:], self.configuration.x_shifts, self.configuration.overlap_height, ): # compute relative shift if x_relative_shift == "auto" or y_overlap == "auto": found_y, found_x = find_relative_shifts( scan_0=scan_0, scan_1=scan_1, projection_for_shift=projection_for_shift, x_cross_correlation_function=x_cross_correlation_function, y_cross_correlation_function=y_cross_correlation_function, invert_order=order_s1 != order_s0, auto_flip=True, ) final_shift = ( found_y if y_overlap == "auto" else y_overlap, found_x if x_relative_shift == "auto" else x_relative_shift, ) _logger.info( f"between {scan_0} and {scan_1} found a shift of {final_shift}. cross_correlation function used are: {x_cross_correlation_function} for x and {y_cross_correlation_function} for y" ) final_shifts.append(final_shift) # set back values self.configuration.x_shifts = [final_shift[1] for final_shift in final_shifts] self.configuration.overlap_height = [abs(final_shift[0]) for final_shift in final_shifts] @staticmethod def _data_bunch_iterator(n_projections, bunch_size): proj_i = 0 while n_projections - proj_i > bunch_size: yield (proj_i - bunch_size, proj_i) proj_i += bunch_size else: yield (proj_i, n_projections - 1) @staticmethod def _get_bunch_of_data(bunch_start: int, bunch_end: int, scans: tuple, scans_projections_indexes: tuple): """ goal is to load contiguous projections as much as possible... :param scans_with_proj_indexes: tuple with scans and scan projection indexes to be loaded :return: list of list. 
For each frame we want to stitch contains the (flat fielded) frames to stich together """ assert len(scans) == len(scans_projections_indexes) scans_proj_urls = [] # for each scan store the real indices and the data url for scan, scan_projection_indexes in zip(scans, scans_projections_indexes): scan_proj_urls = {} # for each scan get the list of url to be loaded for i_proj in range(bunch_start, bunch_end): # for scan, scan_projections_indexes in zip( # scans, scans_projections_indexes # ): proj_index_in_full_scan = scan_projection_indexes[i_proj] scan_proj_urls[proj_index_in_full_scan] = scan.projections[proj_index_in_full_scan] scans_proj_urls.append(scan_proj_urls) # then load data all_scan_final_data = numpy.empty((bunch_end - bunch_start, len(scans)), dtype=object) for i_scan, scan_urls in enumerate(scans_proj_urls): i_frame = 0 _, set_of_compacted_slices = get_compacted_dataslices(scan_urls, return_url_set=True) for _, url in set_of_compacted_slices.items(): url = DataUrl( file_path=url.file_path(), data_path=url.data_path(), scheme="silx", data_slice=url.data_slice(), ) loaded_slices = get_data(url) if loaded_slices.ndim == 3: n_slice = loaded_slices.shape[0] else: n_slice = 1 loaded_slices = [ loaded_slices, ] scan_indexes = list(scan_urls.keys()) data = scan.flat_field_correction( loaded_slices, range( scan_indexes[i_frame], scan_indexes[i_frame] + n_slice, ), ) flip_lr = scans[i_scan].get_x_flipped(default=False) flip_ud = scans[i_scan].get_y_flipped(default=False) for frame in data: f_frame = frame if flip_lr: f_frame = numpy.fliplr(f_frame) if flip_ud: f_frame = numpy.flipud(f_frame) all_scan_final_data[i_frame, i_scan] = f_frame i_frame += 1 return all_scan_final_data def _create_nx_tomo(self): """ create final NXtomo with stitched frames. Policy: save all projections flat fielded. So this NXtomo will only contain projections (no dark and no flat). 
But nabu will be able to reconstruct it with field `flatfield` set to False """ if "auto" in self.configuration.x_shifts: raise RuntimeError("Looks like some shift haven't been computed") nx_tomo = NXtomo() nx_tomo.energy = self.z_serie[0].energy start_times = list(filter(None, [scan.start_time for scan in self.z_serie])) end_times = list(filter(None, [scan.end_time for scan in self.z_serie])) if len(start_times) > 0: nx_tomo.start_time = ( numpy.asarray([numpy.datetime64(start_time) for start_time in start_times]).min().astype(datetime) ) else: _logger.warning("Unable to find any start_time from input") if len(end_times) > 0: nx_tomo.end_time = ( numpy.asarray([numpy.datetime64(end_time) for end_time in end_times]).max().astype(datetime) ) else: _logger.warning("Unable to find any end_time from input") title = ",".join([scan.sequence_name or "" for scan in self.z_serie]) nx_tomo.title = f"stitch done from {title}" # handle detector (without frames) nx_tomo.instrument.detector.field_of_view = self.z_serie[0].field_of_view nx_tomo.instrument.detector.distance = self.z_serie[0].distance nx_tomo.instrument.detector.x_pixel_size = self.z_serie[0].x_pixel_size nx_tomo.instrument.detector.y_pixel_size = self.z_serie[0].y_pixel_size nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * len(self.z_serie[0].projections) nx_tomo.instrument.detector.tomo_n = len(self.z_serie[0].projections) if isinstance(self.z_serie[0], HDF5TomoScan): nx_tomo.instrument.detector.magnification = self.z_serie[0].magnification # note: stitching process insure unflipping of frames nx_tomo.instrument.detector.x_flipped = False nx_tomo.instrument.detector.y_flipped = False if isinstance(self.z_serie[0], HDF5TomoScan): # note: first scan is always the reference as order to read data (so no rotation_angle inversion here) rotation_angle = numpy.asarray(self.z_serie[0].rotation_angle) nx_tomo.sample.rotation_angle = rotation_angle[ numpy.asarray(self.z_serie[0].image_key_control) == ImageKey.PROJECTION.value ] elif isinstance(self.z_serie[0], EDFTomoScan): nx_tomo.sample.rotation_angle = numpy.linspace( start=0, stop=self.z_serie[0].scan_range, num=self.z_serie[0].tomo_n ) else: raise NotImplementedError( f"scan type ({type(self.z_serie[0])} is not handled)", HDF5TomoScan, isinstance(self.z_serie[0], HDF5TomoScan), ) # handle sample n_frames = len(nx_tomo.sample.rotation_angle) if False not in [isinstance(scan, HDF5TomoScan) for scan in self.z_serie]: # we consider the new x, y and z position to be at the center of the one created x_translation = [scan.x_translation for scan in self.z_serie if scan.x_translation is not None] nx_tomo.sample.x_translation = [numpy.asarray(x_translation).mean()] * n_frames y_translation = [scan.y_translation for scan in self.z_serie if scan.y_translation is not None] nx_tomo.sample.y_translation = [numpy.asarray(y_translation).mean()] * n_frames z_translation = [scan.z_translation for scan in self.z_serie if scan.z_translation is not None] nx_tomo.sample.z_translation = [numpy.asarray(z_translation).mean()] * n_frames nx_tomo.sample.name = self.z_serie[0].sample_name # compute stiched frame shape n_proj = len(self.z_serie[0].projections) y_overlaps = self.configuration.overlap_height stitched_frame_shape = ( n_proj, int( numpy.asarray([scan.dim_2 for scan in self.z_serie]).sum() - numpy.asarray([abs(overlap) for overlap in y_overlaps]).sum() ), self.z_serie[0].dim_1, ) # get expected output dataset first (just in case output and input files are the same) first_proj_idx = 
sorted(self.z_serie[0].projections.keys())[0] first_proj_url = self.z_serie[0].projections[first_proj_idx] if h5py.is_hdf5(first_proj_url.file_path()): first_proj_url = DataUrl( file_path=first_proj_url.file_path(), data_path=first_proj_url.data_path(), scheme="h5py", ) # first save the NXtomo entry without the frame # dicttonx will fail if the folder does not exists os.makedirs(os.path.dirname(self.configuration.output_file_path), exist_ok=True) nx_tomo.save( file_path=self.configuration.output_file_path, data_path=self.configuration.output_data_path, nexus_path_version=self.configuration.output_nexus_version, overwrite=self.configuration.overwrite_results, ) _logger.info( f"reading order is {self.reading_order}", ) # append frames ("instrument/detactor/data" dataset) with HDF5File(filename=self.configuration.output_file_path, mode="a") as h5f: # note: nx_tomo.save already handles the possible overwrite conflict by removing # self.configuration.output_file_path or raising an error stitched_frame_path = "/".join( [ self.configuration.output_data_path, _get_nexus_paths(self.configuration.output_nexus_version).PROJ_PATH, ] ) projection_dataset = h5f.create_dataset( name=stitched_frame_path, shape=stitched_frame_shape, dtype=self.configuration.output_dtype, ) # TODO: we could also create in several time and create a virtual dataset from it. scans_projections_indexes = [] for scan, reverse in zip(self.z_serie, self.reading_order): scans_projections_indexes.append(sorted(scan.projections.keys(), reverse=(reverse == -1))) if self.progress: self.progress.set_max_advancement(len(scan.projections.keys())) # for each indexes create a value which is the list of url to stitch together # for now only try to do the first two overlap_kernels = [] for overlap_height in self.configuration.overlap_height: overlap_kernels.append( ZStichOverlapKernel( frame_width=self.z_serie[0].dim_1, stitching_strategy=self.configuration.stitching_strategy, overlap_height=-1 if overlap_height == "auto" else overlap_height, ) ) i_proj = 0 for bunch_start, bunch_end in self._data_bunch_iterator(len(scan.projections), bunch_size=50): for data_frames in self._get_bunch_of_data( bunch_start, bunch_end, scans=self.z_serie, scans_projections_indexes=scans_projections_indexes, ): # TODO: try to do this in parallel or at least dump then in one go. but not sure this last one would speed up. 
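# Illustrative sketch (not executed here; sizes and values are made up): stitching two
# consecutive in-memory frames that share a known 20-pixel vertical overlap and require
# no horizontal shift would look like:
#
#     frame_a = numpy.random.rand(100, 50).astype(numpy.float32)
#     frame_b = numpy.random.rand(100, 50).astype(numpy.float32)
#     kernel = ZStichOverlapKernel(frame_width=50, overlap_height=20)
#     stitched = ZStitcher.stitch_frames(
#         frames=(frame_a, frame_b),
#         x_relative_shifts=(0,),
#         y_overlap_heights=(20,),
#         overlap_kernels=[kernel],
#         output_dtype=numpy.float32,
#     )
#     # stitched.shape == (180, 50): 100 + 100 rows minus the 20 overlapping ones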
# should be handled by the flushing mecanism ZStitcher.stitch_frames( frames=data_frames, x_relative_shifts=self.configuration.x_shifts, y_overlap_heights=self.configuration.overlap_height, output_dataset=projection_dataset, overlap_kernels=overlap_kernels, i_frame=i_proj, output_dtype=self.configuration.output_dtype, ) if self.progress is not None: self.progress.increase_advancement() i_proj += 1 # create link to this dataset that can be missing # "data/data" link if "data" in h5f[self.configuration.output_data_path]: data_group = h5f[self.configuration.output_data_path]["data"] if not stitched_frame_path.startswith("/"): stitched_frame_path = "/" + stitched_frame_path data_group["data"] = h5py.SoftLink(stitched_frame_path) if "default" not in h5f[self.configuration.output_data_path].attrs: h5f[self.configuration.output_data_path].attrs["default"] = "data" for attr_name, attr_value in zip( ("NX_class", "SILX_style/axis_scale_types", "signal"), ("NXdata", ["linear", "linear"], "data"), ): if attr_name not in data_group.attrs: data_group.attrs[attr_name] = attr_value return nx_tomo def _dump_stitching_configuration(self): """dump configuration used for stitching at the NXtomo entry""" process_name = "stitching_configuration" config_dict = self.configuration.to_dict() # adding nabu specific information nabu_process_info = { "@NX_class": "NXentry", f"{process_name}@NX_class": "NXprocess", f"{process_name}/program": "nabu-stitching", f"{process_name}/version": nabu_version, f"{process_name}/date": get_datetime(), f"{process_name}/configuration": config_dict, } dicttonx( nabu_process_info, h5file=self.configuration.output_file_path, h5path=self.configuration.output_data_path, update_mode="replace", mode="a", ) def z_stitch_raw_frames( frames: tuple, y_shifts: tuple, output_dtype: numpy.dtype = numpy.float32, check_inputs=True, overlap_kernels: Optional[Union[ZStichOverlapKernel, tuple]] = None, raw_frames_compositions: Optional[ZFrameComposition] = None, overlap_frames_compositions: Optional[ZFrameComposition] = None, ) -> numpy.ndarray: """ stitches raw frames (already shifted and flat fielded !!!) together using raw stitching (no pixel interpolation, y_overlap_in_px is expected to be a int) returns stitched_projection, raw_img_1, raw_img_2, computed_overlap proj_0 and pro_1 are already expected to be in a row. Having stitching_height_in_px in common. At top of proj_0 and at bottom of proj_1 :param tuple frames: tuple of 2D numpy array :param stitching_heights_in_px: scalar value of the stitching to apply or a list of len(frames) - 1 size with each stitching_height to apply between each couple of frame :param raw_frames_compositions: pre computed raw frame composition. If not provided will compute them. allow providing it to speed up calculation :param overlap_frames_compositions: pre computed stitched frame composition. If not provided will compute them. allow providing it to speed up calculation :param overlap_kernels: ZStichOverlapKernel overlap kernel to be used or a list of kernel (one per overlap). Define startegy and overlap heights :param numpy.dtype output_dtype: dataset dtype. For now must be provided because flat field corrcetion change data type (numpy.float32 for now) """ if overlap_kernels is None: # handle overlap area if overlap_kernels is None and len(frames) > 0: # FIXME ! 
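# When no kernel is provided, the fallback below builds a single ZStichOverlapKernel
# whose overlap_height spans the whole height of the first frame; callers such as
# _create_nx_tomo normally pass one pre-built kernel per overlap instead.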
# pylint: disable= E1123,E1120 proj_0 = frames[0] overlap_kernels = ZStichOverlapKernel( stitching_strategy=DEFAULT_OVERLAP_STRATEGY, overlap_height=proj_0.shape[0], frame_width=proj_0.shape[1], ) if isinstance(overlap_kernels, OverlapKernelBase): overlap_kernels = [copy(overlap_kernels) for _ in (len(frames) - 1)] if check_inputs: def check_proj(proj): if not isinstance(proj, numpy.ndarray) and proj.ndim == 2: raise ValueError(f"frames are expected to be 2D numpy array") [check_proj(frame) for frame in frames] for proj_0, proj_1 in zip(frames[:-1], frames[1:]): if proj_0.shape[1] != proj_1.shape[1]: raise ValueError("Both projections are expected to have the same width") for proj_0, proj_1, kernel in zip(frames[:-1], frames[1:], overlap_kernels): if proj_0.shape[0] <= kernel.overlap_height: raise ValueError( f"proj_0 height ({proj_0.shape[0]}) is less than kernel overlap ({kernel.overlap_height})" ) if proj_1.shape[0] <= kernel.overlap_height: raise ValueError( f"proj_1 height ({proj_1.shape[0]}) is less than kernel overlap ({kernel.overlap_height})" ) # cast shift in int: for now only case handled y_shifts = [int(y_shift) for y_shift in y_shifts] # step 0: create numpy array that will contain stitching stitched_projection_shape = ( # here we only handle frames because shift are already done # + because shift are expected to be negative int( numpy.asarray([frame.shape[0] for frame in frames]).sum() + numpy.asarray(y_shifts).sum(), ), frames[0].shape[1], ) stitch_array = numpy.empty(stitched_projection_shape, dtype=output_dtype) # step 1: set kernel overlap height if undefined for y_shift, kernel in zip(y_shifts, overlap_kernels): if kernel.overlap_height in (-1, None): kernel.overlap_height = abs(y_shift) # step 2: set raw data # fill stitch array with raw data raw data if raw_frames_compositions is None: raw_frames_compositions = ZFrameComposition.compute_raw_frame_compositions( frames=frames, overlap_kernels=overlap_kernels, y_shifts=y_shifts, ) raw_frames_compositions.compose( output_frame=stitch_array, input_frames=frames, ) # step 3 set stitched data # 3.1 create stitched overlaps stitched_overlap = [] for frame_0, frame_1, kernel, y_shift in zip(frames[:-1], frames[1:], overlap_kernels, y_shifts): if y_shift >= 0: raise ValueError("No overlap found. 
Unagle to do stitching on it") frame_0_overlap, frame_1_overlap = ZStitcher.get_overlap_areas( frame_0, frame_1, real_overlap=abs(y_shift), stitching_height=kernel.overlap_height, ) assert ( frame_0_overlap.shape[0] == frame_1_overlap.shape[0] == kernel.overlap_height ), f"{frame_0_overlap.shape[0]} == {frame_1_overlap.shape[0]} == {kernel.overlap_height}" stitched_overlap.append( kernel.stitch( frame_0_overlap, frame_1_overlap, )[0] ) # 3.2 fill stitched overlap on output array if overlap_frames_compositions is None: overlap_frames_compositions = ZFrameComposition.compute_stitch_frame_composition( frames=frames, overlap_kernels=overlap_kernels, y_shifts=y_shifts, ) overlap_frames_compositions.compose( output_frame=stitch_array, input_frames=stitched_overlap, ) return stitch_array ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1675761825.0 nabu-2023.1.1/nabu/tests.py0000644000175000017500000000257000000000000014671 0ustar00pierrepierre#!/usr/bin/env python # -*- coding: utf-8 -*- import sys import os import pytest from nabu.utils import get_folder_path from nabu import __nabu_modules__ as nabu_modules def get_modules_to_test(mods): sep = os.sep modules = [] extra_args = [] for mod in mods: if mod.startswith("-"): extra_args.append(mod) continue # Test a whole module if mod.lower() in nabu_modules: mod_abspath = os.path.join(get_folder_path(mod), "tests") # Test an individual file else: mod_path = mod.replace(".", sep) + ".py" mod_abspath = get_folder_path(mod_path) # test only one file mod_split = mod_abspath.split(sep) mod_split.insert(-1, "tests") mod_abspath = sep.join(mod_split) if not (os.path.exists(mod_abspath)): print("Error: no such file or directory: %s" % mod_abspath) exit(1) modules.append(mod_abspath) return modules, extra_args def nabu_test(): nabu_folder = get_folder_path() args = sys.argv[1:] modules_to_test, extra_args = get_modules_to_test(args) if len(modules_to_test) == 0: modules_to_test = [get_folder_path()] pytest_args = extra_args + modules_to_test return pytest.main(pytest_args) if __name__ == "__main__": ret = nabu_test() exit(ret) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1679642111.0 nabu-2023.1.1/nabu/testutils.py0000644000175000017500000002515100000000000015567 0ustar00pierrepierrefrom itertools import product import posixpath import tarfile import os import numpy as np from silx.resources import ExternalResources from silx.io.dictdump import dicttoh5, nxtodict, dicttonx from silx.io.url import DataUrl from tomoscan.io import HDF5File from .io.utils import get_compacted_dataslices utilstest = ExternalResources( project="nabu", url_base="http://www.silx.org/pub/nabu/data/", env_key="NABU_DATA", timeout=60 ) __big_testdata_dir__ = os.environ.get("NABU_BIGDATA_DIR") if __big_testdata_dir__ is None or not (os.path.isdir(__big_testdata_dir__)): __big_testdata_dir__ = None __do_long_tests__ = os.environ.get("NABU_LONG_TESTS", False) if __do_long_tests__: try: __do_long_tests__ = bool(int(__do_long_tests__)) except: __do_long_tests__ = False __do_large_mem_tests__ = os.environ.get("NABU_LARGE_MEM_TESTS", False) if __do_large_mem_tests__: try: __do_large_mem_tests__ = bool(int(__do_large_mem_tests__)) except: __do_large_mem_tests__ = False def generate_tests_scenarios(configurations): """ Generate "scenarios" of tests. 
The parameter is a dictionary where: - the key is the name of a parameter - the value is a list of possible parameters This function returns a list of dictionary where: - the key is the name of a parameter - the value is one value of this parameter """ scenarios = [{key: val for key, val in zip(configurations.keys(), p_)} for p_ in product(*configurations.values())] return scenarios def get_data(*dataset_path): """ Get a dataset file from silx.org/pub/nabu/data dataset_args is a list describing a nested folder structures, ex. ["path", "to", "my", "dataset.h5"] """ dataset_relpath = os.path.join(*dataset_path) dataset_downloaded_path = utilstest.getfile(dataset_relpath) return np.load(dataset_downloaded_path) def get_big_data(filename): if __big_testdata_dir__ is None: return None return np.load(os.path.join(__big_testdata_dir__, filename)) def uncompress_file(compressed_file_path, target_directory): with tarfile.open(compressed_file_path) as f: f.extractall(path=target_directory) def get_file(fname): downloaded_file = dataset_downloaded_path = utilstest.getfile(fname) if ".tar" in fname: uncompress_file(downloaded_file, os.path.dirname(downloaded_file)) downloaded_file = downloaded_file.split(".tar")[0] return downloaded_file def compare_arrays(arr1, arr2, tol, diff=None, absolute_value=True, percent=None, method="max", return_residual=False): """ Utility to compare two arrays. Parameters ---------- arr1: numpy.ndarray First array to compare arr2: numpy.ndarray Second array to compare tol: float Tolerance indicating whether arrays are close to eachother. diff: numpy.ndarray, optional Difference `arr1 - arr2`. If provided, this array is taken instead of `arr1` and `arr2`. absolute_value: bool, optional Whether to take absolute value of the difference. percent: float If set, a "relative" comparison is performed instead of a subtraction: `red(|arr1 - arr2|) / (red(|arr1|) * percent) < tol` where "red" is the reduction method (mean, max or median). method: Reduction method. Can be "max", "mean", or "median". Returns -------- (is_close, residual) if return_residual is set to True is_close otherwise Examples -------- When using method="mean" and absolute_value=True, this function computes the Mean Absolute Difference (MAD) metric. When also using percent=1.0, this computes the Relative Mean Absolute Difference (RMD) metric. """ reductions = { "max": np.max, "mean": np.mean, "median": np.median, } if method not in reductions: raise ValueError("reduction method should be in %s" % str(list(reductions.keys()))) if diff is None: diff = arr1 - arr2 if absolute_value is not None: diff = np.abs(diff) residual = reductions[method](diff) if percent is not None: a1 = np.abs(arr1) if absolute_value else arr1 residual /= reductions[method](a1) res = residual < tol if return_residual: res = res, residual return res class SimpleHDF5TomoScanMock: def __init__(self, image_key): self._image_key = image_key @property def image_key(self): return self._image_key @image_key.setter def image_key(self, image_key): self._image_key = image_key def save_reduced_flats(self, *args, **kwargs): pass def save_reduced_darks(self, *args, **kwargs): pass class NXDatasetMock: """ An alternative to tomoscan.esrf.mock.MockHDF5, with a different interface. Attributes are not supported ! 
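Example (illustrative, with made-up data and output path):

    import numpy
    data = numpy.random.rand(5, 16, 16).astype("float32")
    image_keys = numpy.zeros(5, dtype=numpy.int32)  # 0 <=> projection frames
    mock = NXDatasetMock(data, image_keys)
    mock.generate_hdf5_file("/tmp/mock_nxtomo.h5")
    url = mock.dataset_hdf5_url  # silx DataUrl pointing to the mocked detector data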
""" def __init__(self, data_volume, image_keys, rotation_angles=None, incident_energy=19.0, other_params=None): self.data_volume = data_volume self.n_proj = data_volume.shape[0] self.image_key = image_keys if rotation_angles is None: rotation_angles = np.linspace(0, 180, self.n_proj, False) self.rotation_angle = rotation_angles self.incident_energy = incident_energy assert image_keys.size == self.n_proj self._finalize_init(other_params) self.dataset_dict = None self.fname = None # Mocks more attributes self.dataset_scanner = SimpleHDF5TomoScanMock(image_key=self.image_key) self.kind = "hdf5" def _finalize_init(self, other_params): if other_params is None: other_params = {} default_params = { "detector": { "count_time": 0.05 * np.ones(self.n_proj, dtype="f"), "distance": 0.5, "field_of_view": "Full", "image_key_control": np.copy(self.image_key), "x_pixel_size": 6.5e-6, "y_pixel_size": 6.5e-6, "x_magnified_pixel_size": 6.5e-5, "y_magnified_pixel_size": 6.5e-5, }, "sample": { "name": "dummy sample", "x_translation": 5e-4 * np.ones(self.n_proj, dtype="f"), "y_translation": 5e-4 * np.ones(self.n_proj, dtype="f"), "z_translation": 5e-4 * np.ones(self.n_proj, dtype="f"), }, } default_params.update(other_params) self.other_params = default_params def generate_dict(self): beam_group = { "incident_energy": self.incident_energy, } detector_other_params = self.other_params["detector"] detector_group = { "count_time": detector_other_params["count_time"], "data": self.data_volume, "distance": detector_other_params["distance"], "field_of_view": detector_other_params["field_of_view"], "image_key": self.image_key, "image_key_control": detector_other_params["image_key_control"], "x_pixel_size": detector_other_params["x_pixel_size"], "y_pixel_size": detector_other_params["y_pixel_size"], "x_magnified_pixel_size": detector_other_params["x_magnified_pixel_size"], "y_magnified_pixel_size": detector_other_params["y_magnified_pixel_size"], } sample_other_params = self.other_params["sample"] sample_group = { "name": sample_other_params["name"], "rotation_angle": self.rotation_angle, "x_translation": sample_other_params["x_translation"], "y_translation": sample_other_params["y_translation"], "z_translation": sample_other_params["z_translation"], } self.dataset_dict = { "beam": beam_group, "instrument": { "detector": detector_group, }, "sample": sample_group, } def generate_hdf5_file(self, fname, h5path=None): self.fname = fname h5path = h5path or "/entry" if self.dataset_dict is None: self.generate_dict() dicttoh5(self.dataset_dict, fname, h5path=h5path, mode="a") # Patch the "data" field which is exported as string by dicttoh5 (?!) 
self.dataset_path = os.path.join(h5path, "instrument/detector/data") with HDF5File(fname, "a") as fid: del fid[self.dataset_path] fid[self.dataset_path] = self.dataset_dict["instrument"]["detector"]["data"] # Mock some of the HDF5DatasetAnalyzer attributes @property def dataset_hdf5_url(self): if self.fname is None: raise ValueError("generate_hdf5_file() was not called") return DataUrl(file_path=self.fname, data_path=self.dataset_path, scheme="silx") def _get_images_with_key(self, key): indices = np.arange(self.image_key.size)[self.image_key == key] urls = [ DataUrl( file_path=self.fname, data_path=self.dataset_path, data_slice=slice(img_idx, img_idx + 1), scheme="silx", ) for img_idx in indices ] return dict(zip(indices, urls)) @property def flats(self): return self._get_images_with_key(1) @property def darks(self): return self._get_images_with_key(2) def get_data_slices(self, what): images = getattr(self, what) # we can't directly use set() on slice() object (unhashable). Use tuples tuples_list = list( set((du.data_slice().start, du.data_slice().stop) for du in get_compacted_dataslices(images).values()) ) slices_list = [slice(item[0], item[1]) for item in tuples_list] return slices_list # To be improved def generate_nx_dataset(out_fname, image_key, data_volume=None, rotation_angle=None): nx_template_file = get_file("dummy.nx.tar.gz") nx_dict = nxtodict(nx_template_file) nx_entry = nx_dict["entry"] def _get_field(dict_, path): if path.startswith("/"): path = path[1:] if path.endswith("/"): path = path[:-1] split_path = path.split("/") if len(split_path) == 1: return dict_[split_path[0]] return _get_field(dict_[split_path[0]], "/".join(split_path[1:])) for name in ["image_key", "image_key_control"]: nx_entry["data"][name] = image_key nx_entry["instrument"]["detector"][name] = image_key if rotation_angle is not None: nx_entry["data"]["rotation_angle"] = rotation_angle nx_entry["sample"]["rotation_angle"] = rotation_angle dicttonx(nx_dict, out_fname) ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4687333 nabu-2023.1.1/nabu/thirdparty/0000755000175000017500000000000000000000000015343 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1581878491.0 nabu-2023.1.1/nabu/thirdparty/__init__.py0000644000175000017500000000000000000000000017442 0ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1581878491.0 nabu-2023.1.1/nabu/thirdparty/pore3d_deringer_munch.py0000644000175000017500000001062400000000000022165 0ustar00pierrepierre""" The following de-striping method is adapted from the pore3d software. :Organization: Elettra - Sincrotrone Trieste S.C.p.A. :Version: 2013.05.01 References ---------- [1] F. Brun, A. Accardo, G. Kourousias, D. Dreossi, R. Pugliese. Effective implementation of ring artifacts removal filters for synchrotron radiation microtomographic images. Proc. of the 8th International Symposium on Image and Signal Processing (ISPA), pp. 672-676, Sept. 4-6, Trieste (Italy), 2013. The license follows. """ # Copyright (c) 2013, Elettra - Sincrotrone Trieste S.C.p.A. # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. 
# * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # * Neither the name of the copyright holders nor the names of any # contributors may be used to endorse or promote products derived # from this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. import numpy as np try: import pywt __has_pywt__ = True except ImportError: __has_pywt__ = False def munchetal_filter(im, wlevel, sigma, wname='db15'): """Process a sinogram image with the Munch et al. de-striping algorithm. Parameters ---------- im : array_like Image data as numpy array. wname : {'haar', 'db1'-'db20', 'sym2'-'sym20', 'coif1'-'coif5', 'dmey'} The wavelet transform to use. wlevel : int Levels of the wavelet decomposition. sigma : float Cutoff frequency of the Butterworth low-pass filtering. Example (using tiffile.py) -------------------------- >>> im = imread('original.tif') >>> im = munchetal_filter(im, 'db15', 4, 1.0) >>> imsave('filtered.tif', im) References ---------- B. Munch, P. Trtik, F. Marone, M. Stampanoni, Stripe and ring artifact removal with combined wavelet-Fourier filtering, Optics Express 17(10):8567-8591, 2009. """ # Wavelet decomposition: coeffs = pywt.wavedec2(im.astype(np.float32), wname, level=wlevel, mode="periodization") coeffsFlt = [coeffs[0]] # FFT transform of horizontal frequency bands: for i in range(1, wlevel + 1): # FFT: fcV = np.fft.fftshift(np.fft.fft(coeffs[i][1], axis=0)) my, mx = fcV.shape # Damping of vertical stripes: damp = 1 - np.exp(-(np.arange(-np.floor(my / 2.), -np.floor(my / 2.) 
+ my) ** 2) / (2 * (sigma ** 2))) dampprime = np.kron(np.ones((1, mx)), damp.reshape((damp.shape[0], 1))) # np.tile(damp[:, np.newaxis], (1, mx)) fcV = fcV * dampprime # Inverse FFT: fcVflt = np.real(np.fft.ifft(np.fft.ifftshift(fcV), axis=0)) cVHDtup = (coeffs[i][0], fcVflt, coeffs[i][2]) coeffsFlt.append(cVHDtup) # Get wavelet reconstruction: im_f = np.real(pywt.waverec2(coeffsFlt, wname, mode="periodization")) # Return image according to input type: if (im.dtype == 'uint16'): # Check extrema for uint16 images: im_f[im_f < np.iinfo(np.uint16).min] = np.iinfo(np.uint16).min im_f[im_f > np.iinfo(np.uint16).max] = np.iinfo(np.uint16).max # Return filtered image (an additional row and/or column might be present): return im_f[0:im.shape[0], 0:im.shape[1]].astype(np.uint16) else: return im_f[0:im.shape[0], 0:im.shape[1]] if not(__has_pywt__): munchetal_filter = None ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1581878491.0 nabu-2023.1.1/nabu/thirdparty/tomopy_phase.py0000644000175000017500000002067200000000000020433 0ustar00pierrepierre#!/usr/bin/env python # -*- coding: utf-8 -*- # ######################################################################### # Copyright (c) 2015-2019, UChicago Argonne, LLC. All rights reserved. # # # # Copyright 2015-2019. UChicago Argonne, LLC. This software was produced # # under U.S. Government contract DE-AC02-06CH11357 for Argonne National # # Laboratory (ANL), which is operated by UChicago Argonne, LLC for the # # U.S. Department of Energy. The U.S. Government has rights to use, # # reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR # # UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR # # ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is # # modified to produce derivative works, such modified software should # # be clearly marked, so as not to confuse it with the version available # # from ANL. # # # # Additionally, redistribution and use in source and binary forms, with # # or without modification, are permitted provided that the following # # conditions are met: # # # # * Redistributions of source code must retain the above copyright # # notice, this list of conditions and the following disclaimer. # # # # * Redistributions in binary form must reproduce the above copyright # # notice, this list of conditions and the following disclaimer in # # the documentation and/or other materials provided with the # # distribution. # # # # * Neither the name of UChicago Argonne, LLC, Argonne National # # Laboratory, ANL, the U.S. Government, nor the names of its # # contributors may be used to endorse or promote products derived # # from this software without specific prior written permission. # # # # THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS # # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT # # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS # # FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL UChicago # # Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, # # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, # # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; # # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT # # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN # # ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # # POSSIBILITY OF SUCH DAMAGE. # # ######################################################################### """ Module for phase retrieval. This code is part of tomopy: https://github.com/tomopy/tomopy It was adapted for being stand-alone, without job distribution. See the license above for more information. """ import numpy as np fft2 = np.fft.fft2 ifft2 = np.fft.ifft2 __author__ = "Doga Gursoy" __credits__ = "Mark Rivers, Xianghui Xiao" __copyright__ = "Copyright (c) 2015, UChicago Argonne, LLC." __docformat__ = 'restructuredtext en' __all__ = ['retrieve_phase'] BOLTZMANN_CONSTANT = 1.3806488e-16 # [erg/k] SPEED_OF_LIGHT = 299792458e+2 # [cm/s] PI = 3.14159265359 PLANCK_CONSTANT = 6.58211928e-19 # [keV*s] def _wavelength(energy): return 2 * PI * PLANCK_CONSTANT * SPEED_OF_LIGHT / energy def retrieve_phase( tomo, pixel_size=1e-4, dist=50, energy=20, alpha=1e-3, pad=True, ncore=None, nchunk=None): """ Perform single-step phase retrieval from phase-contrast measurements. Parameters ---------- tomo : ndarray 3D tomographic data. pixel_size : float, optional Detector pixel size in cm. dist : float, optional Propagation distance of the wavefront in cm. energy : float, optional Energy of incident wave in keV. alpha : float, optional Regularization parameter. pad : bool, optional If True, extend the size of the projections by padding with zeros. ncore : int, optional Number of cores that will be assigned to jobs. nchunk : int, optional Chunk size for each core. Returns ------- ndarray Approximated 3D tomographic phase data. """ # New dimensions and pad value after padding. py, pz, val = _calc_pad(tomo, pixel_size, dist, energy, pad) # Compute the reciprocal grid. dx, dy, dz = tomo.shape w2 = _reciprocal_grid(pixel_size, dy + 2 * py, dz + 2 * pz) # Filter in Fourier space. phase_filter = np.fft.fftshift( _paganin_filter_factor(energy, dist, alpha, w2)) prj = np.full((dy + 2 * py, dz + 2 * pz), val, dtype='float32') arr = _retrieve_phase(tomo, phase_filter, py, pz, prj, pad) return arr def _retrieve_phase(tomo, phase_filter, px, py, prj, pad): dx, dy, dz = tomo.shape num_jobs = tomo.shape[0] normalized_phase_filter = phase_filter / phase_filter.max() for m in range(num_jobs): # Padding "constant" with border value # prj is initially filled with "val" prj[px:dy + px, py:dz + py] = tomo[m] prj[:px] = prj[px] prj[-px:] = prj[-px-1] prj[:, :py] = prj[:, py][:, np.newaxis] prj[:, -py:] = prj[:, -py-1][:, np.newaxis] fproj = fft2(prj) fproj *= normalized_phase_filter proj = np.real(ifft2(fproj)) if pad: proj = proj[px:dy + px, py:dz + py] tomo[m] = proj return tomo def _calc_pad(tomo, pixel_size, dist, energy, pad): """ Calculate new dimensions and pad value after padding. Parameters ---------- tomo : ndarray 3D tomographic data. pixel_size : float Detector pixel size in cm. dist : float Propagation distance of the wavefront in cm. energy : float Energy of incident wave in keV. pad : bool If True, extend the size of the projections by padding with zeros. 
Returns ------- int Pad amount in projection axis. int Pad amount in sinogram axis. float Pad value. """ dx, dy, dz = tomo.shape wavelength = _wavelength(energy) py, pz, val = 0, 0, 0 if pad: val = _calc_pad_val(tomo) py = _calc_pad_width(dy, pixel_size, wavelength, dist) pz = _calc_pad_width(dz, pixel_size, wavelength, dist) return py, pz, val def _paganin_filter_factor(energy, dist, alpha, w2): return 1 / (_wavelength(energy) * dist * w2 / (4 * PI) + alpha) def _calc_pad_width(dim, pixel_size, wavelength, dist): pad_pix = np.ceil(PI * wavelength * dist / pixel_size ** 2) return int((pow(2, np.ceil(np.log2(dim + pad_pix))) - dim) * 0.5) def _calc_pad_val(tomo): # mean of [(column 0 of radio) + (column -1 of radio)]/2. return np.mean((tomo[..., 0] + tomo[..., -1]) * 0.5) def _reciprocal_grid(pixel_size, nx, ny): """ Calculate reciprocal grid. Parameters ---------- pixel_size : float Detector pixel size in cm. nx, ny : int Size of the reciprocal grid along x and y axes. Returns ------- ndarray Grid coordinates. """ # Sampling in reciprocal space. indx = _reciprocal_coord(pixel_size, nx) indy = _reciprocal_coord(pixel_size, ny) np.square(indx, out=indx) np.square(indy, out=indy) return np.add.outer(indx, indy) def _reciprocal_coord(pixel_size, num_grid): """ Calculate reciprocal grid coordinates for a given pixel size and discretization. Parameters ---------- pixel_size : float Detector pixel size in cm. num_grid : int Size of the reciprocal grid. Returns ------- ndarray Grid coordinates. """ n = num_grid - 1 rc = np.arange(-n, num_grid, 2, dtype = np.float32) rc *= 0.5 / (n * pixel_size) return rc ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1621525307.0 nabu-2023.1.1/nabu/thirdparty/tomwer_load_flats_darks.py0000644000175000017500000001275300000000000022616 0ustar00pierrepierre""" script embedding all function needed to read flats and darks. For information calculation method is stored in the results/configuration dictionary """ """ IMPORTANT: this script is used as long as flat-fielding with "raw" flats/daks is not implemented in nabu. For now we load results from tomwer. 
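Typical use (illustrative file path and entry name):

    darks = get_darks_frm_process_file("my_scan_tomwer_processes.h5", entry="entry0000")
    flats = get_flats_frm_process_file("my_scan_tomwer_processes.h5", entry="entry0000")

Both calls return a mapping {frame index: silx DataUrl} (or None when no DarkRef process can be found).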
""" import typing import h5py from silx.io.url import DataUrl from tomoscan.io import HDF5File from silx.io.utils import h5py_read_dataset import logging MAX_DEPTH = 2 logger = logging.getLogger(__name__) def get_process_entries(root_node: h5py.Group, depth: int) -> tuple: """ return the list of 'Nxtomo' entries at the root level :param str file_path: :return: list of valid Nxtomo node (ordered alphabetically) :rtype: tuple ..note: entries are sorted to insure consistency """ def _get_entries(node, depth_): if isinstance(node, h5py.Dataset): return {} res_buf = {} if is_process_node(node) is True: res_buf[node.name] = int(node['sequence_index'][()]) assert isinstance(node, h5py.Group) if depth_ >= 1: for sub_node in node.values(): res_buf[node.name] = _get_entries(node=sub_node, depth_=depth_-1) return res_buf res = {} for node in root_node.values(): res.update(_get_entries(node=node, depth_=depth-1)) return res def is_process_node(node): return (node.name.split('/')[-1].startswith('tomwer_process_') and 'NX_class' in node.attrs and node.attrs['NX_class'] == "NXprocess" and 'program' in node and h5py_read_dataset(node['program']) == 'tomwer_dark_refs' and 'version' in node and 'sequence_index' in node) def get_darks_frm_process_file(process_file, entry) -> typing.Union[None, dict]: """ :param process_file: :return: """ if entry is None: with HDF5File(process_file, 'r', swmr=True) as h5f: entries = get_process_entries(root_node=h5f, depth=MAX_DEPTH) if len(entries) == 0: logger.info( 'unable to find a DarkRef process in %s' % process_file) return None elif len(entries) > 0: raise ValueError('several entry found, entry should be ' 'specify') else: entry = list(entries.keys())[0] logger.info('take %s as default entry' % entry) with HDF5File(process_file, 'r', swmr=True) as h5f: dark_nodes = get_process_entries(root_node=h5f[entry], depth=MAX_DEPTH-1) index_to_path = {} for key, value in dark_nodes.items(): index_to_path[key] = key if len(dark_nodes) == 0: return {} # take the last processed dark ref last_process_index = sorted(dark_nodes.keys())[-1] last_process_dark = index_to_path[last_process_index] if(len(index_to_path)) > 1: logger.warning('several processing found for dark-ref,' 'take the last one: %s' % last_process_dark) res = {} if 'results' in h5f[last_process_dark].keys(): results_node = h5f[last_process_dark]['results'] if 'darks' in results_node.keys(): darks = results_node['darks'] for index in darks: res[int(index)] = DataUrl( file_path=process_file, data_path=darks[index].name, scheme="silx" ) return res def get_flats_frm_process_file(process_file, entry) -> typing.Union[None, dict]: """ :param process_file: :return: """ if entry is None: with HDF5File(process_file, 'r', swmr=True) as h5f: entries = get_process_entries(root_node=h5f, depth=MAX_DEPTH) if len(entries) == 0: logger.info( 'unable to find a DarkRef process in %s' % process_file) return None elif len(entries) > 0: raise ValueError('several entry found, entry should be ' 'specify') else: entry = list(entries.keys())[0] logger.info('take %s as default entry' % entry) with HDF5File(process_file, 'r', swmr=True) as h5f: dkref_nodes = get_process_entries(root_node=h5f[entry], depth=MAX_DEPTH-1) if len(dkref_nodes) == 0: return {} index_to_path = {} for key, value in dkref_nodes.items(): index_to_path[key] = key # take the last processed dark ref last_process_index = sorted(dkref_nodes.keys())[-1] last_process_dkrf = index_to_path[last_process_index] if(len(index_to_path)) > 1: logger.warning('several processing found for 


def get_flats_frm_process_file(process_file, entry) -> typing.Union[None, dict]:
    """
    Read the flat frames saved by a tomwer 'DarkRefs' process.

    :param str process_file: path of the HDF5 file written by the process
    :param entry: entry to read in the file. If None, the file is scanned and the
                  single process entry found is used.
    :return: dictionary mapping frame index to a DataUrl pointing to the flat frame,
             or None if no process entry is found
    """
    if entry is None:
        with HDF5File(process_file, 'r', swmr=True) as h5f:
            entries = get_process_entries(root_node=h5f, depth=MAX_DEPTH)
            if len(entries) == 0:
                logger.info('unable to find a DarkRef process in %s' % process_file)
                return None
            elif len(entries) > 1:
                raise ValueError("several entries found, 'entry' should be specified")
            else:
                entry = list(entries.keys())[0]
                logger.info('take %s as default entry' % entry)

    with HDF5File(process_file, 'r', swmr=True) as h5f:
        dkref_nodes = get_process_entries(root_node=h5f[entry], depth=MAX_DEPTH - 1)
        if len(dkref_nodes) == 0:
            return {}
        # map each 'sequence_index' to the corresponding process node path
        index_to_path = {}
        for path, seq_index in dkref_nodes.items():
            index_to_path[seq_index] = path

        # take the last processed dark ref
        last_process_index = sorted(index_to_path.keys())[-1]
        last_process_dkrf = index_to_path[last_process_index]
        if len(index_to_path) > 1:
            logger.warning('several processings found for dark-ref, take the last one: %s' % last_process_dkrf)

        res = {}
        if 'results' in h5f[last_process_dkrf].keys():
            results_node = h5f[last_process_dkrf]['results']
            if 'flats' in results_node.keys():
                flats = results_node['flats']
                for index in flats:
                    res[int(index)] = DataUrl(
                        file_path=process_file,
                        data_path=flats[index].name,
                        scheme="silx",
                    )
        return res
././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/nabu/utils.py0000644000175000017500000005270400000000000014673 0ustar00pierrepierre
import os
import sys
from typing import Union
import typing
import warnings
from time import time
import posixpath
from itertools import product
from functools import lru_cache
import numpy as np


def nextpow2(N, dtype=np.int32):
    return 2 ** np.ceil(np.log2(N)).astype(dtype)


def previouspow2(N, dtype=np.int32):
    return 2 ** np.floor(np.log2(N)).astype(dtype)


def updiv(a, b):
    return (a + (b - 1)) // b


def convert_index(idx, idx_max, default_val):
    """
    Convert an index (possibly negative or None) to a non-negative integer.

    Parameters
    ----------
    idx: int or None
        Index
    idx_max: int
        Maximum value (upper bound) for the index.
    default_val: int
        Default value if idx is None

    Examples
    --------
    Given an integer `M`, `J = convert_index(i, M, XX)` returns an integer in the
    mathematical range [0, M] (or Python `range(0, M)`).
    `J` can then be used to define an upper bound of a range.
    """
    if idx is None:
        return default_val
    if idx > idx_max:
        return idx_max
    if idx < 0:
        return idx % idx_max
    return idx


def get_folder_path(foldername=""):
    _file_dir = os.path.dirname(os.path.realpath(__file__))
    package_dir = _file_dir
    return os.path.join(package_dir, foldername)


def get_cuda_srcfile(filename):
    src_relpath = os.path.join("cuda", "src")
    cuda_src_folder = get_folder_path(foldername=src_relpath)
    return os.path.join(cuda_src_folder, filename)


def get_resource_file(filename, subfolder=None):
    subfolder = subfolder or []
    relpath = os.path.join("resources", *subfolder)
    abspath = get_folder_path(foldername=relpath)
    return os.path.join(abspath, filename)


def get_available_threads():
    return len(os.sched_getaffinity(0))


def is_writeable(location):
    """
    Return True if a file/location is writeable.
""" return os.access(location, os.W_OK) def is_int(num, eps=1e-7): return abs(num - int(num)) < eps def _sizeof(Type): """ return the size (in bytes) of a scalar type, like the C behavior """ return np.dtype(Type).itemsize def get_ftype(url): """ return supposed filetype of an url """ if hasattr(url, "file_path"): return os.path.splitext(url.file_path())[-1].replace(".", "") else: return os.path.splitext(url)[-1].replace(".", "") def get_2D_3D_shape(shape): if len(shape) == 2: return (1,) + shape return shape def get_subregion(sub_region): ymin, ymax = None, None if sub_region is None: xmin, xmax, ymin, ymax = None, None, None, None elif len(sub_region) == 2: first_part, second_part = sub_region if np.iterable(first_part) and np.iterable(second_part): xmin, xmax = first_part ymin, ymax = second_part else: xmin, xmax = sub_region elif len(sub_region) == 4: xmin, xmax, ymin, ymax = sub_region else: raise ValueError("Expected parameter in the form (a, b, c, d) or ((a, b), (c, d))") return xmin, xmax, ymin, ymax def get_3D_subregion(sub_region): if sub_region is None: xmin, xmax, ymin, ymax, zmin, zmax = None, None, None, None, None, None elif len(sub_region) == 3: first_part, second_part, third_part = sub_region xmin, xmax = first_part ymin, ymax = second_part zmin, zmax = third_part elif len(sub_region) == 6: xmin, xmax, ymin, ymax, zmin, zmax = sub_region else: raise ValueError( "Expected parameter in the form (xmin, xmax, ymin, ymax, zmin, zmax) or ((xmin, xmax), (ymin, ymax), (zmin, zmax))" ) return xmin, xmax, ymin, ymax, zmin, zmax def to_3D_array(arr): """ Turn an array to a (C-Contiguous) 3D array with the layout (n_arrays, n_y, n_x). """ if arr.ndim == 3: return arr return np.tile(arr, (1, 1, 1)) def view_as_images_stack(img): """ View an image (2D array) as a stack of one image (3D array). No data is duplicated. """ return img.reshape((1,) + img.shape) def rescale_integers(items, new_tot): """ " From a given sequence of integers, create a new sequence where the sum of all items must be equal to "new_tot". The relative contribution of each item to the new total is approximately kept. Parameters ---------- items: array-like Sequence of integers new_tot: int Integer indicating that the sum of the new array must be equal to this value """ cur_items = np.array(items) new_items = np.ceil(cur_items / cur_items.sum() * new_tot).astype(np.int64) excess = new_items.sum() - new_tot i = 0 while excess > 0: ind = i % new_items.size if cur_items[ind] > 0: new_items[ind] -= 1 excess -= 1 i += 1 return new_items.tolist() def merged_shape(shapes, axis=0): n_img = sum(shape[axis] for shape in shapes) if axis == 0: return (n_img,) + shapes[0][1:] elif axis == 1: return (shapes[0][0], n_img, shapes[0][2]) elif axis == 2: return shapes[0][:2] + (n_img,) def is_device_backend(backend): return backend.lower() in ["cuda", "opencl"] def get_decay(curve, cutoff=1e3, vmax=None): """ Assuming a decreasing curve, get the first point below a certain threshold. Parameters ---------- curve: numpy.ndarray A 1D array cutoff: float, optional Threshold. Default is 1000. vmax: float, optional Curve maximum value """ if vmax is None: vmax = curve.max() return np.argmax(np.abs(curve) < vmax / cutoff) @lru_cache(maxsize=1) def generate_powers(): """ Generate a list of powers of [2, 3, 5, 7], up to (2**15)*(3**9)*(5**6)*(7**5). 
""" primes = [2, 3, 5, 7] maxpow = {2: 15, 3: 9, 5: 6, 7: 5} valuations = [] for prime in primes: # disallow any odd number (for R2C transform), and any number # not multiple of 4 (Ram-Lak filter behaves strangely when # dwidth_padded/2 is not even) minval = 2 if prime == 2 else 0 valuations.append(range(minval, maxpow[prime] + 1)) powers = product(*valuations) res = [] for pw in powers: res.append(np.prod(list(map(lambda x: x[0] ** x[1], zip(primes, pw))))) return np.unique(res) def calc_padding_lengths1D(length, length_padded): """ Compute the padding lengths at both side along one dimension. Parameters ---------- length: int Number of elements along one dimension of the original array length_padded: tuple Number of elements along one dimension of the padded array Returns ------- pad_lengths: tuple A tuple under the form (padding_left, padding_right). These are the lengths needed to pad the original array. """ pad_left = (length_padded - length) // 2 pad_right = length_padded - length - pad_left return (pad_left, pad_right) def calc_padding_lengths(shape, shape_padded): """ Multi-dimensional version of calc_padding_lengths1D. Please refer to the documentation of calc_padding_lengths1D. """ assert len(shape) == len(shape_padded) padding_lengths = [] for dim_len, dim_len_padded in zip(shape, shape_padded): pad0, pad1 = calc_padding_lengths1D(dim_len, dim_len_padded) padding_lengths.append((pad0, pad1)) return tuple(padding_lengths) def partition_dict(dict_, n_partitions): keys = np.sort(list(dict_.keys())) res = [] for keys_arr in np.array_split(keys, n_partitions): d = {} for key in keys_arr: d[key] = dict_[key] res.append(d) return res def subsample_dict(dic, subsampling_factor): """ Subsample a dict where keys are integers. """ res = {} indices = sorted(dic.keys()) for i in indices[::subsampling_factor]: res[i] = dic[i] return res def compare_dicts(dic1, dic2): """ Compare two dictionaries. Return None if and only iff the dictionaries are the same. Parameters ---------- dic1: dict First dictionary dic2: dict Second dictionary Returns ------- res: result which can be the following: - None: it means that dictionaries are the same - empty string (""): the dictionaries do not have the same keys - nonempty string: path to the first differing items """ if set(dic1.keys()) != set(dic2.keys()): return "" for key, val1 in dic1.items(): val2 = dic2[key] if isinstance(val1, dict): res = compare_dicts(val1, val2) if res is not None: return posixpath.join(key, res) # str elif isinstance(val1, str): if val1 != val2: return key # Scalars elif np.isscalar(val1): if not np.isclose(val1, val2): return key # NoneType elif val1 is None: if val2 is not None: return key # Array-like elif np.iterable(val1): arr1 = np.array(val1) arr2 = np.array(val2) if arr1.ndim != arr2.ndim or arr1.dtype != arr2.dtype or not np.allclose(arr1, arr2): return key else: raise ValueError("Don't know what to do with type %s" % str(type(val1))) return None def remove_items_from_list(list_, items_to_remove): """ Remove items from a list and return the removed elements. 
Parameters ---------- list_: list List containing items to remove items_to_remove: list List of items to remove Returns -------- reduced_list: list List with removed items removed_items: dict Dictionary where the keys are the indices of removed items, and values are the items """ removed_items = {} res = [] for i, val in enumerate(list_): if val in items_to_remove: removed_items[i] = val else: res.append(val) return res, removed_items def restore_items_in_list(list_, removed_items): """ Undo the effect of the function `remove_items_from_list` Parameters ---------- list_: list List where items were removed removed_items: dict Dictionary where the keys are the indices of removed items, and values are the items """ for idx, val in removed_items.items(): list_.insert(idx, val) def check_supported(param_value, available, param_desc): if param_value not in available: raise ValueError("Unsupported %s '%s'. Available are: %s" % (param_desc, param_value, str(available))) def check_supported_enum(param_value, enum_cls, param_desc): available = enum_cls.values() return check_supported(param_value, available, param_desc) def check_shape(shape, expected_shape, name): if shape != expected_shape: raise ValueError("Expected %s shape %s but got %s" % (name, str(expected_shape), str(shape))) def copy_dict_items(dict_, keys): """ Perform a shallow copy of a subset of a dictionary. The subset if done by a list of keys. """ res = {key: dict_[key] for key in keys} return res def recursive_copy_dict(dict_): """ Perform a shallow copy of a dictionary of dictionaries. This is NOT a deep copy ! Only reference to objects are kept. """ if not (isinstance(dict_, dict)): res = dict_ else: res = {k: recursive_copy_dict(v) for k, v in dict_.items()} return res def subdivide_into_overlapping_segment(N, window_width, half_overlap): """ Divide a segment into a number of possibly-overlapping sub-segments. Parameters ---------- N: int Total segment length window_width: int Length of each segment half_overlap: int Half-length of the overlap between each sub-segment. Returns ------- segments: list A list where each item is in the form (left_margin_start, inner_segment_start, inner_segment_end, right_margin_end) """ if half_overlap > 0 and half_overlap >= window_width // 2: raise ValueError("overlap must be smaller than window_width") w_in = window_width - 2 * half_overlap n_segments = N // w_in inner_start = w_in * np.arange(n_segments) inner_end = w_in * (np.arange(n_segments) + 1) margin_left_start = np.maximum(inner_start - half_overlap, 0) margin_right_end = np.minimum(inner_end + half_overlap, N) segments = [ (left_start, i_start, i_end, right_end) for left_start, i_start, i_end, right_end in zip( margin_left_start.tolist(), inner_start.tolist(), inner_end.tolist(), margin_right_end.tolist() ) ] if N % w_in: # additional sub-segment new_margin_left_start = inner_end[-1] - half_overlap new_inner_start = inner_end[-1] new_inner_end = N segments.append((new_margin_left_start, new_inner_start, new_inner_end, new_inner_end)) return segments def get_num_threads(n=None): """ Get a number of threads (ex. to be used by fftw). If the argument is None, returns the total number of CPU threads. If the argument is negative, the total number of available threads plus this number is returned. 
Parameters ----------- n: int, optional - If an positive integer `n` is provided, then `n` threads are used - If a negative integer `n` is provided, then `n_avail + n` threads are used (so -1 to use all available threads minus one) """ n_avail = get_available_threads() if n is None or n == 0: return n_avail if n < 0: return max(1, n_avail + n) else: return min(n_avail, n) class DictToObj(object): """utility class to transform a dictionary into an object with dictionary items as members. Example: >>> a=DictToObj( dict(i=1,j=1)) ... a.j+a.i """ def __init__(self, dictio): self.__dict__ = dictio def remove_parenthesis_or_brackets(input_str): """ clear string from left and or roght parenthesis / braquets """ if input_str.startswith("(") and input_str.endswith(")") or input_str.startswith("[") and input_str.endswith("]"): input_str = input_str[1:-1] return input_str def filter_str_def(elmt): """clean elemt if is a string defined from a text file. Remove some character that could have be put on left or right and some empty spaces""" if elmt is None: return None assert isinstance(elmt, str) elmt = elmt.lstrip(" ").rstrip(" ") for character in ("'", '"'): if elmt.startswith(character) and elmt.endswith(character): elmt = elmt[1:-1] return elmt def convert_str_to_tuple(input_str: str, none_if_empty: bool = False) -> Union[None, tuple]: """ :param str input_str: string to convert :param bool none_if_empty: if true and the conversion is an empty tuple return None instead of an empty tuple """ if isinstance(input_str, tuple): return input_str if not isinstance(input_str, str): raise TypeError("input_str should be a string not {}, {}".format(type(input_str), input_str)) input_str = input_str.lstrip(" ").lstrip("(").lstrip("[").lstrip(" ").rstrip(" ") input_str = remove_parenthesis_or_brackets(input_str) input_str = input_str.replace("\n", ",") elmts = input_str.split(",") elmts = [filter_str_def(elmt) for elmt in elmts] rm_empty_str = lambda a: a != "" elmts = list(filter(rm_empty_str, elmts)) if none_if_empty and len(elmts) == 0: return None else: return tuple(elmts) class Progress: """Simple interface for defining advancement on a 100 percentage base""" def __init__(self, name: str): self._name = name self.set_name(name) def set_name(self, name): self._name = name self.reset() def reset(self, max_: typing.Union[None, int] = None) -> None: """ reset the advancement to n and max advancement to max_ :param int max_: """ self._n_processed = 0 self._max_processed = max_ def start_process(self) -> None: self.set_advancement(0) def set_advancement(self, value: int) -> None: """ :param int value: set advancement to value """ length = 20 # modify this to change the length block = int(round(length * value / 100)) blocks_str = "#" * block + "-" * (length - block) msg = "\r{0}: [{1}] {2}%".format(self._name, blocks_str, round(value, 2)) if value >= 100: msg += " DONE\r\n" sys.stdout.write(msg) sys.stdout.flush() def end_process(self) -> None: """Set advancement to 100 %""" self.set_advancement(100) def set_max_advancement(self, n: int) -> None: """ :param int n: number of steps contained by the advancement. 
When advancement reach this value, advancement will be 100 % """ self._max_processed = n def increase_advancement(self, i: int = 1) -> None: """ :param int i: increase the advancement of n step """ self._n_processed += i advancement = int(float(self._n_processed / self._max_processed) * 100) self.set_advancement(advancement) def concatenate_dict(dict_1, dict_2) -> dict: """update dict which has dict as values. And we want concatenate those values to""" res = dict_1.copy() for key in dict_2: if key in dict_1 and key in dict_2: res[key].update(dict_2[key]) else: res[key] = dict_2[key] return res # ------------------------------------------------------------------------------ # ------------------------ Image (move elsewhere ?) ---------------------------- # ------------------------------------------------------------------------------ def generate_coords(img_shp, center=None): l_r, l_c = float(img_shp[0]), float(img_shp[1]) R, C = np.mgrid[:l_r, :l_c] # np.indices is faster if center is None: center0, center1 = l_r / 2.0, l_c / 2.0 else: center0, center1 = center R += 0.5 - center0 C += 0.5 - center1 return R, C def clip_circle(img, center=None, radius=None): R, C = generate_coords(img.shape, center) M = R**2 + C**2 res = np.zeros_like(img) res[M < radius**2] = img[M < radius**2] return res def extend_image_onepixel(img): # extend of one pixel img2 = np.zeros((img.shape[0] + 2, img.shape[1] + 2), dtype=img.dtype) img2[0, 1:-1] = img[0] img2[-1, 1:-1] = img[-1] img2[1:-1, 0] = img[:, 0] img2[1:-1, -1] = img[:, -1] # middle img2[1:-1, 1:-1] = img # corners img2[0, 0] = img[0, 0] img2[-1, 0] = img[-1, 0] img2[0, -1] = img[0, -1] img2[-1, -1] = img[-1, -1] return img2 def median2(img): """ 3x3 median filter for 2D arrays, with "reflect" boundary mode. Roughly same speed as scipy median filter, but more memory demanding. 
""" img2 = extend_image_onepixel(img) I = np.array( [ img2[0:-2, 0:-2], img2[0:-2, 1:-1], img2[0:-2, 2:], img2[1:-1, 0:-2], img2[1:-1, 1:-1], img2[1:-1, 2:], img2[2:, 0:-2], img2[2:, 1:-1], img2[2:, 2:], ] ) return np.median(I, axis=0) # ------------------------------------------------------------------------------ # ---------------------------- Decorators -------------------------------------- # ------------------------------------------------------------------------------ _warnings = {} def measure_time(func): def wrapper(*args, **kwargs): t0 = time() res = func(*args, **kwargs) el = time() - t0 return el, res return wrapper def wip(func): def wrapper(*args, **kwargs): func_name = func.__name__ if func_name not in _warnings: _warnings[func_name] = 1 print("Warning: function %s is a work in progress, it is likely to change in the future") return func(*args, **kwargs) return wrapper def warning(msg): def decorator(func): def wrapper(*args, **kwargs): func_name = func.__name__ if func_name not in _warnings: _warnings[func_name] = 1 print(msg) res = func(*args, **kwargs) return res return wrapper return decorator def deprecated(msg, do_print=False): def decorator(func): def wrapper(*args, **kwargs): deprecation_warning(msg, do_print=do_print, func_name=func.__name__) res = func(*args, **kwargs) return res return wrapper return decorator def deprecated_class(msg, do_print=False): def decorator(cls): class wrapper: def __init__(self, *args, **kwargs): deprecation_warning(msg, do_print=do_print, func_name=cls.__name__) self.wrapped = cls(*args, **kwargs) # This is so ugly :-( def __getattr__(self, name): return getattr(self.wrapped, name) return wrapper return decorator def deprecation_warning(message, do_print=True, func_name=None): func_name_msg = str("%s: " % func_name) if func_name is not None else "" func_name = func_name or "None" if _warnings.get(func_name, False): return warnings.warn(message, DeprecationWarning) if do_print: print("Deprecation warning: %s%s" % (func_name_msg, message)) _warnings[func_name] = 1 ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4527333 nabu-2023.1.1/nabu.egg-info/0000755000175000017500000000000000000000000014643 5ustar00pierrepierre././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682596037.0 nabu-2023.1.1/nabu.egg-info/PKG-INFO0000644000175000017500000000041600000000000015741 0ustar00pierrepierreMetadata-Version: 2.1 Name: nabu Version: 2023.1.1 Summary: Nabu - Tomography software Author: Pierre Paleo Author-email: pierre.paleo@esrf.fr Maintainer: Pierre Paleo Maintainer-email: pierre.paleo@esrf.fr Provides-Extra: full Provides-Extra: doc License-File: LICENSE ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682596037.0 nabu-2023.1.1/nabu.egg-info/SOURCES.txt0000644000175000017500000001462600000000000016540 0ustar00pierrepierreLICENSE README.md setup.py nabu/__init__.py nabu/tests.py nabu/testutils.py nabu/utils.py nabu.egg-info/PKG-INFO nabu.egg-info/SOURCES.txt nabu.egg-info/dependency_links.txt nabu.egg-info/entry_points.txt nabu.egg-info/requires.txt nabu.egg-info/top_level.txt nabu.egg-info/zip-safe nabu/app/__init__.py nabu/app/bootstrap.py nabu/app/bootstrap_stitching.py nabu/app/cast_volume.py nabu/app/cli_configs.py nabu/app/compare_volumes.py nabu/app/composite_cor.py nabu/app/create_distortion_map_from_poly.py nabu/app/double_flatfield.py nabu/app/generate_header.py nabu/app/histogram.py nabu/app/nx_z_splitter.py 
nabu/app/prepare_weights_double.py nabu/app/reconstruct.py nabu/app/reconstruct_helical.py nabu/app/rotate.py nabu/app/shrink_dataset.py nabu/app/stitching.py nabu/app/utils.py nabu/app/validator.py nabu/cuda/__init__.py nabu/cuda/convolution.py nabu/cuda/kernel.py nabu/cuda/medfilt.py nabu/cuda/padding.py nabu/cuda/processing.py nabu/cuda/utils.py nabu/cuda/src/ElementOp.cu nabu/cuda/src/backproj.cu nabu/cuda/src/backproj_polar.cu nabu/cuda/src/boundary.h nabu/cuda/src/convolution.cu nabu/cuda/src/fftshift.cu nabu/cuda/src/flatfield.cu nabu/cuda/src/fourier_wavelets.cu nabu/cuda/src/halftomo.cu nabu/cuda/src/helical_padding.cu nabu/cuda/src/histogram.cu nabu/cuda/src/hst_backproj.cu nabu/cuda/src/interpolation.cu nabu/cuda/src/medfilt.cu nabu/cuda/src/normalization.cu nabu/cuda/src/padding.cu nabu/cuda/src/rotation.cu nabu/cuda/tests/__init__.py nabu/cuda/tests/test_medfilt.py nabu/cuda/tests/test_padding.py nabu/estimation/__init__.py nabu/estimation/alignment.py nabu/estimation/cor.py nabu/estimation/cor_sino.py nabu/estimation/distortion.py nabu/estimation/focus.py nabu/estimation/tilt.py nabu/estimation/translation.py nabu/estimation/utils.py nabu/estimation/tests/__init__.py nabu/estimation/tests/test_alignment.py nabu/estimation/tests/test_cor.py nabu/estimation/tests/test_focus.py nabu/estimation/tests/test_tilt.py nabu/estimation/tests/test_translation.py nabu/io/__init__.py nabu/io/cast_volume.py nabu/io/detector_distortion.py nabu/io/reader.py nabu/io/reader_helical.py nabu/io/tiffwriter_zmm.py nabu/io/utils.py nabu/io/writer.py nabu/io/tests/__init__.py nabu/io/tests/test_cast_volume.py nabu/io/tests/test_detector_distortion.py nabu/io/tests/test_writers.py nabu/misc/__init__.py nabu/misc/binning.py nabu/misc/filters.py nabu/misc/fourier_filters.py nabu/misc/histogram.py nabu/misc/histogram_cuda.py nabu/misc/padding.py nabu/misc/rotation.py nabu/misc/rotation_cuda.py nabu/misc/unsharp.py nabu/misc/unsharp_cuda.py nabu/misc/unsharp_opencl.py nabu/misc/utils.py nabu/misc/tests/__init__.py nabu/misc/tests/test_binning.py nabu/misc/tests/test_histogram.py nabu/misc/tests/test_interpolation.py nabu/misc/tests/test_rotation.py nabu/misc/tests/test_unsharp.py nabu/opencl/__init__.py nabu/opencl/utils.py nabu/pipeline/__init__.py nabu/pipeline/config.py nabu/pipeline/config_validators.py nabu/pipeline/datadump.py nabu/pipeline/dataset_validator.py nabu/pipeline/detector_distortion_provider.py nabu/pipeline/estimators.py nabu/pipeline/fallback_utils.py nabu/pipeline/params.py nabu/pipeline/processconfig.py nabu/pipeline/utils.py nabu/pipeline/writer.py nabu/pipeline/fullfield/__init__.py nabu/pipeline/fullfield/chunked.py nabu/pipeline/fullfield/chunked_cuda.py nabu/pipeline/fullfield/computations.py nabu/pipeline/fullfield/dataset_validator.py nabu/pipeline/fullfield/nabu_config.py nabu/pipeline/fullfield/processconfig.py nabu/pipeline/fullfield/reconstruction.py nabu/pipeline/helical/__init__.py nabu/pipeline/helical/dataset_validator.py nabu/pipeline/helical/fbp.py nabu/pipeline/helical/filtering.py nabu/pipeline/helical/gridded_accumulator.py nabu/pipeline/helical/helical_chunked_regridded.py nabu/pipeline/helical/helical_chunked_regridded_cuda.py nabu/pipeline/helical/helical_reconstruction.py nabu/pipeline/helical/helical_utils.py nabu/pipeline/helical/nabu_config.py nabu/pipeline/helical/processconfig.py nabu/pipeline/helical/span_strategy.py nabu/pipeline/helical/utils.py nabu/pipeline/helical/weight_balancer.py nabu/pipeline/helical/tests/__init__.py 
nabu/pipeline/helical/tests/test_accumulator.py nabu/pipeline/helical/tests/test_pipeline_elements_full.py nabu/pipeline/helical/tests/test_strategy.py nabu/pipeline/xrdct/__init__.py nabu/preproc/__init__.py nabu/preproc/alignment.py nabu/preproc/ccd.py nabu/preproc/ccd_cuda.py nabu/preproc/ctf.py nabu/preproc/ctf_cuda.py nabu/preproc/distortion.py nabu/preproc/double_flatfield.py nabu/preproc/double_flatfield_cuda.py nabu/preproc/double_flatfield_variable_region.py nabu/preproc/flatfield.py nabu/preproc/flatfield_cuda.py nabu/preproc/flatfield_variable_region.py nabu/preproc/phase.py nabu/preproc/phase_cuda.py nabu/preproc/shift.py nabu/preproc/shift_cuda.py nabu/preproc/tests/__init__.py nabu/preproc/tests/test_ccd_corr.py nabu/preproc/tests/test_ctf.py nabu/preproc/tests/test_double_flatfield.py nabu/preproc/tests/test_flatfield.py nabu/preproc/tests/test_paganin.py nabu/preproc/tests/test_vshift.py nabu/reconstruction/__init__.py nabu/reconstruction/cone.py nabu/reconstruction/fbp.py nabu/reconstruction/fbp_opencl.py nabu/reconstruction/filtering.py nabu/reconstruction/reconstructor.py nabu/reconstruction/reconstructor_cuda.py nabu/reconstruction/rings.py nabu/reconstruction/rings_cuda.py nabu/reconstruction/sinogram.py nabu/reconstruction/sinogram_cuda.py nabu/reconstruction/tests/__init__.py nabu/reconstruction/tests/test_cone.py nabu/reconstruction/tests/test_deringer.py nabu/reconstruction/tests/test_fbp.py nabu/reconstruction/tests/test_halftomo.py nabu/reconstruction/tests/test_reconstructor.py nabu/reconstruction/tests/test_sino_normalization.py nabu/resources/__init__.py nabu/resources/cor.py nabu/resources/dataset_analyzer.py nabu/resources/gpu.py nabu/resources/logger.py nabu/resources/nxflatfield.py nabu/resources/utils.py nabu/resources/cli/__init__.py nabu/resources/templates/__init__.py nabu/resources/templates/bm05_pag.conf nabu/resources/templates/id16_ctf.conf nabu/resources/templates/id16_holo.conf nabu/resources/templates/id19_pag.conf nabu/resources/tests/__init__.py nabu/resources/tests/test_nxflatfield.py nabu/resources/tests/test_units.py nabu/stitching/__init__.py nabu/stitching/config.py nabu/stitching/frame_composition.py nabu/stitching/overlap.py nabu/stitching/utils.py nabu/stitching/z_stitching.py nabu/thirdparty/__init__.py nabu/thirdparty/pore3d_deringer_munch.py nabu/thirdparty/tomopy_phase.py nabu/thirdparty/tomwer_load_flats_darks.py././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682596037.0 nabu-2023.1.1/nabu.egg-info/dependency_links.txt0000644000175000017500000000000100000000000020711 0ustar00pierrepierre ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682596037.0 nabu-2023.1.1/nabu.egg-info/entry_points.txt0000644000175000017500000000170500000000000020144 0ustar00pierrepierre[console_scripts] nabu = nabu.app.reconstruct:main nabu-cast = nabu.app.cast_volume:main nabu-compare-volumes = nabu.app.compare_volumes:compare_volumes_cli nabu-composite-cor = nabu.app.composite_cor:main nabu-config = nabu.app.bootstrap:bootstrap nabu-double-flatfield = nabu.app.double_flatfield:dff_cli nabu-generate-info = nabu.app.generate_header:generate_merged_info_file nabu-helical = nabu.app.reconstruct_helical:main_helical nabu-helical-prepare-weights-double = nabu.app.prepare_weights_double:main nabu-histogram = nabu.app.histogram:histogram_cli nabu-poly2map = nabu.app.create_distortion_map_from_poly:horizontal_match nabu-rotate = nabu.app.rotate:rotate_cli nabu-shrink-dataset = 
nabu.app.shrink_dataset:shrink_cli nabu-stitching = nabu.app.stitching:main nabu-stitching-config = nabu.app.bootstrap_stitching:bootstrap_stitching nabu-test = nabu.tests:nabu_test nabu-validator = nabu.app.validator:main nabu-zsplit = nabu.app.nx_z_splitter:zsplit ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682596037.0 nabu-2023.1.1/nabu.egg-info/requires.txt0000644000175000017500000000031200000000000017237 0ustar00pierrepierrepsutil pytest numpy>1.9.0 scipy silx>=0.15.0 h5py>=3.0 tomoscan>=1.2.1 tifffile [doc] sphinx cloud_sptheme myst-parser nbsphinx [full] pyfftw scikit-image PyWavelets glymur pycuda scikit-cuda pycudwt ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682596037.0 nabu-2023.1.1/nabu.egg-info/top_level.txt0000644000175000017500000000000500000000000017370 0ustar00pierrepierrenabu ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1587030670.0 nabu-2023.1.1/nabu.egg-info/zip-safe0000644000175000017500000000000100000000000016273 0ustar00pierrepierre ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1682596037.4727333 nabu-2023.1.1/setup.cfg0000644000175000017500000000004600000000000014045 0ustar00pierrepierre[egg_info] tag_build = tag_date = 0 ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1682589911.0 nabu-2023.1.1/setup.py0000644000175000017500000000526700000000000013750 0ustar00pierrepierre# coding: utf-8 from setuptools import setup, find_packages import os from nabu import version def setup_package(): doc_requires = [ "sphinx", "cloud_sptheme", "myst-parser", "nbsphinx", ] setup( name="nabu", author="Pierre Paleo", version=version, author_email="pierre.paleo@esrf.fr", maintainer="Pierre Paleo", maintainer_email="pierre.paleo@esrf.fr", packages=find_packages(), package_data={ "nabu.cuda": [ "src/*.cu", "src/*.h", ], "nabu.resources": [ "templates/*.conf", ], }, include_package_data=True, install_requires=[ "psutil", "pytest", "numpy > 1.9.0", "scipy", "silx >= 0.15.0", "h5py>=3.0", "tomoscan >= 1.2.1", "tifffile", ], extras_require={ "full": [ "pyfftw", "scikit-image", "PyWavelets", "glymur", "pycuda", "scikit-cuda", "pycudwt", ], "doc": doc_requires, }, description="Nabu - Tomography software", entry_points={ "console_scripts": [ "nabu=nabu.app.reconstruct:main", "nabu-config=nabu.app.bootstrap:bootstrap", "nabu-test=nabu.tests:nabu_test", "nabu-histogram=nabu.app.histogram:histogram_cli", "nabu-zsplit=nabu.app.nx_z_splitter:zsplit", "nabu-rotate=nabu.app.rotate:rotate_cli", "nabu-double-flatfield=nabu.app.double_flatfield:dff_cli", "nabu-generate-info=nabu.app.generate_header:generate_merged_info_file", "nabu-validator=nabu.app.validator:main", "nabu-helical=nabu.app.reconstruct_helical:main_helical", "nabu-helical-prepare-weights-double=nabu.app.prepare_weights_double:main", "nabu-stitching-config=nabu.app.bootstrap_stitching:bootstrap_stitching", "nabu-stitching=nabu.app.stitching:main", "nabu-cast=nabu.app.cast_volume:main", "nabu-compare-volumes=nabu.app.compare_volumes:compare_volumes_cli", "nabu-shrink-dataset=nabu.app.shrink_dataset:shrink_cli", "nabu-composite-cor=nabu.app.composite_cor:main", "nabu-poly2map=nabu.app.create_distortion_map_from_poly:horizontal_match", ], }, zip_safe=True, ) if __name__ == "__main__": setup_package()
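
# Usage sketch (not part of setup.py): after installation, the console_scripts
# entry points declared above can be listed with importlib.metadata
# (the "group" keyword shown here requires Python >= 3.10):
#
#     from importlib.metadata import entry_points
#     for ep in entry_points(group="console_scripts"):
#         if ep.value.startswith("nabu."):
#             print(ep.name, "->", ep.value)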